from PIL import Image
from pathlib import Path
import asyncio
import aiohttp
import os
import cv2
from lib.card_detector import CardDetector
from lib.config import load_config
from lib.camera import Camera

# Tuning values for the VL6180X distance sensor, forwarded to CardDetector.
INT_PIN: int = 4  # interrupt pin number, passed to CardDetector as int_pin
DETECTION_THRESHOLD_MM: int = 25  # detection distance in mm, passed as threshold_mm
INTERMEASUREMENT_MS: int = 5  # sensor loop period in ms, passed as loop_period_ms
DEBOUNCE_WINDOW_S: float = 0.2  # debounce window in seconds, passed as debounce_s


class MotionHandler:
    def __init__(
        self,
        i2c_bus,
        buzzer=None,
        screen=None,
        i2c_lock=None,
        wait_time=1,
        output_dir=Path("./out"),
        output_filename="latest.jpg",
        use_sharpness_detection=True,
        max_wait_time=3.0,
        min_stable_frames=2,
        sharpness_threshold_increase=0.15,
        good_enough_sharpness=150.0,
        analysis_resize_factor=0.5,
        min_stabilization_time=0.8,
    ):
        """Set up the scanner: output paths, feedback devices, camera and card detector.

        Args:
            i2c_bus: Bus handed to ``CardDetector``.
            buzzer: Audio feedback device (required); ``quick_beep()`` is called once
                during initialisation.
            screen: Display device (required); shows "Init..." during initialisation.
            i2c_lock: Optional lock handed to ``CardDetector``.
            wait_time: Fixed settle time in seconds, used when sharpness detection is
                disabled or fails.
            output_dir: Directory for preview images; created (with parents) if
                missing. A ``str`` is accepted and converted to ``Path``.
            output_filename: File name of the test-mode preview inside ``output_dir``.
            use_sharpness_detection, max_wait_time, min_stable_frames,
            sharpness_threshold_increase, good_enough_sharpness,
            analysis_resize_factor, min_stabilization_time: Tuning for
                ``wait_for_stable_image`` / ``calculate_image_sharpness``; can be
                changed later with ``configure_sharpness_detection``.

        Raises:
            ValueError: If ``buzzer`` or ``screen`` is missing. ``ValueError`` is a
                subclass of ``Exception``, so existing handlers keep working.
        """
        # Validate the required devices first so a misconfiguration fails fast,
        # before any config is loaded, directories are created or hardware is touched.
        if buzzer is None:
            raise ValueError("No buzzer provided")
        if screen is None:
            raise ValueError("No screen provided")

        self.scanned_count = 0
        self.wait_time = wait_time
        self.output_dir = Path(output_dir)
        self.output_file = self.output_dir / output_filename
        # parents=True so nested output directories can be used as well.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.send_to = None
        self.current_bucket = None
        self.current_bucket_label = None
        self.collection_id = None
        # Filled in by set_ready(); initialised here so the upload/notify code reads
        # a defined attribute even before the scanner has been armed.
        self.token = None
        self.test_mode = False
        self.config = load_config()

        # Sharpness detection parameters
        self.use_sharpness_detection = use_sharpness_detection
        self.max_wait_time = max_wait_time
        self.min_stable_frames = min_stable_frames
        self.sharpness_threshold_increase = sharpness_threshold_increase
        self.good_enough_sharpness = good_enough_sharpness
        self.analysis_resize_factor = analysis_resize_factor
        self.min_stabilization_time = min_stabilization_time

        self._task = None
        # Serialises capture_image / capture_instant_image.
        self._lock = asyncio.Lock()
        self._capture_in_progress = False
        self._paused = False

        self.buzzer = buzzer
        self.screen = screen
        self.buzzer.quick_beep()
        self.screen.show_text("Init...")

        print("Initializing camera...")
        self.camera = Camera(self.config)

        print("Initializing card detector...")
        self._detector = CardDetector(
            i2c_bus=i2c_bus,
            on_card_detected=self._on_card_detected,
            i2c_lock=i2c_lock,
            int_pin=INT_PIN,
            threshold_mm=DETECTION_THRESHOLD_MM,
            loop_period_ms=INTERMEASUREMENT_MS,
            debounce_s=DEBOUNCE_WINDOW_S,
        )
        print("Card detector initialized (will start on first async loop)")

    async def set_ready(self, send_to, bucket, label, collection_id, token):
        """Arm the scanner for a new bucket.

        Resets the card sensor, stores the upload target and credentials, zeroes the
        per-bucket counter, clears a stale capture flag, then signals the operator
        on screen and buzzer that scanning can begin.

        Args:
            send_to: Upload URL used by ``send_to_server``.
            bucket: Number of cards expected in this bucket (compared against
                ``scanned_count``).
            label: Bucket label shown on screen and sent to the server.
            collection_id: Identifier of the collection being scanned.
            token: Bearer token for server requests.
        """
        await self._detector.reset_sensor()

        # Where and how to upload.
        self.send_to, self.token = send_to, token
        # What is being scanned.
        self.current_bucket, self.current_bucket_label = bucket, label
        self.collection_id = collection_id
        # Start the bucket from scratch.
        self.scanned_count, self._capture_in_progress = 0, False

        prompt = f"Scanning\nbucket {label}"
        self.screen.show_text(prompt)
        self.buzzer.ready()

    def reset(self):
        """Forget the active bucket and collection and zero the scan counter.

        The upload URL, token and test-mode flag are left as they are.
        """
        self.current_bucket = self.current_bucket_label = self.collection_id = None
        self.scanned_count = 0

    def can_scan(self):
        """Return True when a new capture may start.

        Test mode always allows captures. Otherwise the upload URL, bucket size and
        collection id must all be set (truthy), and ``scanned_count`` must still be
        below the bucket size.
        """
        if self.test_mode:
            return True
        configured = self.send_to and self.current_bucket and self.collection_id
        return bool(configured) and self.scanned_count < self.current_bucket

    async def _on_card_detected(self):
        """Handle a debounced card-detection event from CardDetector.

        The event is dropped (with a log line) when the scanner is paused, a capture
        is already running, or ``can_scan()`` refuses. Otherwise the capture flag is
        claimed and ``start_capture_process`` is awaited.
        """
        print("Card detection callback triggered")

        rejection = None
        if self._paused:
            rejection = "Card detected while scanner is paused, ignoring"
        elif self._capture_in_progress:
            rejection = "Card detected while capture in progress, ignoring"
        elif not self.can_scan():
            rejection = "Card detected but scanner is not ready (not in test mode or scan mode), ignoring"

        if rejection is not None:
            print(rejection)
            return

        # Claim the flag before awaiting so overlapping callbacks hit the check above.
        self._capture_in_progress = True
        await self.start_capture_process()

    def calculate_image_sharpness(self, image_array):
        """Return a focus score for *image_array*: the variance of its Laplacian.

        Higher values mean more high-frequency detail, i.e. a sharper image. When
        ``analysis_resize_factor`` is below 1.0 the frame is downscaled first to make
        the analysis cheaper.

        Accepted inputs are 2-D grayscale arrays, ``(H, W, 1)`` single-channel
        arrays, 3-channel RGB arrays, and 4-channel RGBA/RGBX arrays (the fourth
        channel is ignored). Any failure is printed and reported as ``0.0``, so the
        caller's stabilization loop keeps going.
        """
        try:
            resized = image_array
            if self.analysis_resize_factor < 1.0:
                height, width = image_array.shape[:2]
                # Clamp to at least 1 px: cv2.resize rejects a zero-sized target,
                # which a tiny frame or a very small (or non-positive) factor would
                # otherwise produce.
                new_width = max(1, int(width * self.analysis_resize_factor))
                new_height = max(1, int(height * self.analysis_resize_factor))
                resized = cv2.resize(image_array, (new_width, new_height))

            if len(resized.shape) == 2:
                gray = resized
            else:
                channels = resized.shape[2]
                if channels == 1:
                    # cvtColor(RGB2GRAY) rejects single-channel input; just drop the axis.
                    gray = resized[:, :, 0]
                elif channels == 4:
                    # cvtColor(RGB2GRAY) rejects 4-channel input; discard the 4th channel.
                    gray = cv2.cvtColor(resized, cv2.COLOR_RGBA2GRAY)
                else:
                    gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY)

            # Calculate Laplacian variance (measures sharpness)
            return cv2.Laplacian(gray, cv2.CV_64F).var()
        except Exception as e:
            print(f"Error calculating sharpness: {e}")
            return 0.0

    async def wait_for_stable_image(self):
        """Sample frames until the image is sharp and steady; return the sharpest one.

        With ``use_sharpness_detection`` disabled this only sleeps ``wait_time``
        seconds and returns ``None``; the caller then captures a fresh frame.

        Otherwise frames are captured in a loop and scored with
        ``calculate_image_sharpness``. Sampling stops when the first of these
        conditions is met:

        * ``max_wait_time`` seconds have elapsed (safety timeout);
        * a new best score reaches ``good_enough_sharpness`` after at least three
          frames and ``min_stabilization_time`` seconds (early exit);
        * ``min_stable_frames`` non-improving frames in a row each differ from the
          previous frame by less than ``sharpness_threshold_increase`` times the
          best score, and the best score exceeds a minimum acceptable level.

        Returns:
            A copy of the sharpest frame seen. Returns ``None`` when detection is
            disabled, when no frame scored above 0 before the timeout, or when an
            error occurred before any usable frame was captured. In the last case it
            first sleeps ``wait_time`` seconds, as the fixed-wait fallback.
        """
        if not self.use_sharpness_detection:
            # Fall back to fixed wait time
            await asyncio.sleep(self.wait_time)
            return None

        # A "stable" result is only accepted when the best score exceeds this.
        min_acceptable_sharpness = 50
        # Frames take roughly 350ms to process (per the original tuning notes), so
        # sample about every 0.4s; the first few frames are sampled a bit faster.
        base_sample_interval = 0.4

        loop = asyncio.get_running_loop()
        print("Starting sharpness-based stabilization detection...")
        start_time = loop.time()
        stable_frame_count = 0
        best_sharpness = 0.0
        best_image = None
        last_sharpness = 0.0
        frame_count = 0
        total_capture_time = 0.0
        total_sharpness_time = 0.0
        elapsed = 0.0

        while True:
            elapsed = loop.time() - start_time

            # Safety timeout
            if elapsed > self.max_wait_time:
                print(f"Timeout reached ({self.max_wait_time}s), using best image so far")
                break

            try:
                capture_start = loop.time()
                array = self.camera.capture_array()
                capture_duration = loop.time() - capture_start
                total_capture_time += capture_duration

                sharpness_start = loop.time()
                sharpness = self.calculate_image_sharpness(array)
                sharpness_duration = loop.time() - sharpness_start
                total_sharpness_time += sharpness_duration

                frame_count += 1

                print(
                    f"Frame {frame_count}: Time: {elapsed:.2f}s, Sharpness: {sharpness:.1f}, "
                    f"Capture: {capture_duration * 1000:.1f}ms, Analysis: {sharpness_duration * 1000:.1f}ms"
                )

                if sharpness > best_sharpness:
                    # New sharpest frame: remember it and restart the stability count.
                    best_sharpness = sharpness
                    best_image = array.copy()
                    stable_frame_count = 0
                    print(f"  → New best sharpness: {best_sharpness:.1f}")

                    if (
                        best_sharpness >= self.good_enough_sharpness
                        and frame_count >= 3
                        and elapsed >= self.min_stabilization_time
                    ):
                        print(
                            f"  → Early exit: sharpness {best_sharpness:.1f} exceeds threshold {self.good_enough_sharpness} after {elapsed:.2f}s"
                        )
                        break
                else:
                    # Not improving: count it as stable if it barely differs from the
                    # previous frame (relative to the best score so far).
                    sharpness_diff = abs(sharpness - last_sharpness)
                    if sharpness_diff < (best_sharpness * self.sharpness_threshold_increase):
                        stable_frame_count += 1
                        print(f"  → Stable frame {stable_frame_count}/{self.min_stable_frames}")
                    else:
                        stable_frame_count = 0

                if stable_frame_count >= self.min_stable_frames and best_sharpness > min_acceptable_sharpness:
                    avg_capture = (total_capture_time / frame_count) * 1000
                    avg_sharpness = (total_sharpness_time / frame_count) * 1000
                    print(f"Image stabilized after {elapsed:.2f}s with sharpness {best_sharpness:.1f}")
                    print(
                        f"Performance: {frame_count} frames, avg capture: {avg_capture:.1f}ms, "
                        f"avg analysis: {avg_sharpness:.1f}ms"
                    )
                    break

                last_sharpness = sharpness

                # Subtract this frame's processing time from the sampling interval.
                processing_time = loop.time() - capture_start
                if frame_count <= 3:
                    current_sample_interval = base_sample_interval * 0.75
                else:
                    current_sample_interval = base_sample_interval
                await asyncio.sleep(max(0.05, current_sample_interval - processing_time))

            except Exception as e:
                print(f"Error during sharpness detection: {e}")
                if best_image is not None:
                    # Keep the sharpest frame already captured instead of throwing it away.
                    print(f"  → Using best frame so far (sharpness {best_sharpness:.1f})")
                    return best_image
                # Fall back to fixed wait time
                await asyncio.sleep(self.wait_time)
                return None

        self._print_stabilization_summary(frame_count, total_capture_time, total_sharpness_time, elapsed)
        return best_image

    def _print_stabilization_summary(self, frame_count, total_capture_time, total_sharpness_time, elapsed):
        """Print timing statistics for a finished stabilization run (no-op without frames)."""
        if frame_count == 0:
            return
        avg_capture = (total_capture_time / frame_count) * 1000
        avg_sharpness = (total_sharpness_time / frame_count) * 1000
        total_processing = total_capture_time + total_sharpness_time
        print("Final performance summary:")
        print(f"  Total frames processed: {frame_count}")
        print(f"  Average capture time: {avg_capture:.1f}ms")
        print(f"  Average sharpness analysis: {avg_sharpness:.1f}ms")
        print(f"  Total processing time: {total_processing * 1000:.1f}ms")
        # elapsed can be 0.0 on a coarse monotonic clock when the very first frame
        # ends the run (e.g. min_stable_frames == 0); avoid dividing by zero.
        if elapsed > 0:
            print(f"  Processing overhead: {(total_processing / elapsed) * 100:.1f}% of total time")

    async def capture_image(self):
        """Capture one scan, then preview it (test mode) or upload it (scan mode).

        Runs under ``self._lock`` so it never overlaps ``capture_instant_image``.
        It waits for a sharp frame via ``wait_for_stable_image``, or falls back to
        an immediate ``camera.capture_array()``. The frame is converted to PIL and
        rotated by the ``camera_rotation`` config value (default 180). Then:

        * test mode: the JPEG is written to ``self.output_file`` via a temp file and
          confirmed on screen/buzzer;
        * scan mode (upload URL, collection id and bucket label set): the upload is
          started in the background and ``scanned_count`` is bumped. The operator
          is then prompted for the next card; when the bucket is full, the server
          is notified in the background instead.

        Failures are printed and shown as "Error" on the screen; nothing is raised.
        """
        print("Acquiring lock...")
        async with self._lock:
            self.screen.show_wait()
            loop = asyncio.get_running_loop()
            capture_start_time = loop.time()
            print("Capturing image...")
            try:
                image = await self._capture_rotated_frame()

                if self.test_mode:
                    await self._save_preview(image)
                else:
                    await self._upload_and_advance(image)

                total_capture_time = loop.time() - capture_start_time
                print(f"📸 TOTAL CAPTURE TIME: {total_capture_time:.2f}s")

            except Exception as e:
                total_capture_time = loop.time() - capture_start_time
                print(f"Failed to capture image: {e} (after {total_capture_time:.2f}s)")
                self.screen.show_cross("Error")

    async def _capture_rotated_frame(self):
        """Return the best available frame as a PIL image, rotated per the config."""
        loop = asyncio.get_running_loop()
        stabilization_start = loop.time()
        best_array = await self.wait_for_stable_image()
        stabilization_time = loop.time() - stabilization_start

        if best_array is not None:
            array = best_array
            print(f"Using sharpness-optimized frame (stabilization took {stabilization_time:.2f}s)")
        else:
            # Fall back to immediate capture
            array_start = loop.time()
            array = self.camera.capture_array()
            print(f"Using immediate capture (fallback, took {(loop.time() - array_start) * 1000:.1f}ms)")

        pil_start = loop.time()
        image = Image.fromarray(array)
        rotation = self.config.get("camera_rotation", 180)
        image = image.rotate(rotation, expand=True)
        pil_time = loop.time() - pil_start
        print(f"PIL processing took {pil_time * 1000:.1f}ms (rotation: {rotation}°)")
        return image

    async def _save_preview(self, image):
        """Test mode: write *image* to ``self.output_file`` and confirm on screen/buzzer.

        The JPEG goes to a ``.tmp`` file first and is then moved over the target. If
        that fails, a direct save is attempted. If both fail, the error is printed
        and no confirmation is shown.
        """
        loop = asyncio.get_running_loop()
        try:
            temp_file = str(self.output_file) + ".tmp"
            print(f"Saving image to temp file: {temp_file}")

            save_op_start = loop.time()
            await asyncio.to_thread(image.save, temp_file, "JPEG")
            save_duration = loop.time() - save_op_start

            move_start = loop.time()
            # os.replace (unlike os.rename) also overwrites an existing target on
            # Windows, so the swap works on every platform.
            await asyncio.to_thread(os.replace, temp_file, str(self.output_file))
            move_duration = loop.time() - move_start

            print(f"Image saved successfully to {self.output_file}")
            print(
                f"Save timing: JPEG write {save_duration * 1000:.1f}ms, atomic move {move_duration * 1000:.1f}ms"
            )

        except Exception as save_error:
            print(f"Error saving image: {save_error}")
            # Fallback to direct save if atomic operation fails
            try:
                fallback_start = loop.time()
                await asyncio.to_thread(image.save, str(self.output_file))
                print(
                    f"Fallback save successful to {self.output_file} (took {(loop.time() - fallback_start) * 1000:.1f}ms)"
                )
            except Exception as fallback_error:
                print(f"Fallback save also failed: {fallback_error}")
                return

        self.screen.show_check("Captured!")
        self.buzzer.quick_beep()

    async def _upload_and_advance(self, image):
        """Scan mode: upload *image* in the background and advance the bucket counter.

        Shows "go n/total" for the next card, or "Stop-Next bucket" plus a
        background bucket-finished notification once the bucket is full.
        """
        import io

        def _encode_jpeg():
            buffer = io.BytesIO()
            image.save(buffer, format="JPEG")
            return buffer.getvalue()

        # JPEG encoding is CPU work; keep it off the event loop.
        image_bytes = await asyncio.to_thread(_encode_jpeg)

        # Re-check readiness after the await: the bucket may have been reset meanwhile.
        if not (self.send_to and self.collection_id and self.current_bucket_label):
            print("Scanner not ready for upload, discarding capture")
            return

        self._spawn_background(self.send_to_server(image_bytes), "upload")
        self.scanned_count += 1

        if self.scanned_count >= self.current_bucket:
            self.screen.show_cross("Stop-Next bucket")
            self.buzzer.wait()
            self._spawn_background(self.notify_bucket_finished(), "bucket-finished notification")
        else:
            self.screen.show_text(f"go {self.scanned_count}/{self.current_bucket}")
            self.buzzer.quick_beep()

    def _spawn_background(self, coro, description):
        """Run *coro* as a background task that stays referenced until it finishes.

        The event loop only keeps weak references to tasks, so an unreferenced
        fire-and-forget task can be garbage-collected mid-flight. Failures are
        printed here instead of surfacing as "Task exception was never retrieved".
        """
        tasks = getattr(self, "_background_tasks", None)
        if tasks is None:
            tasks = self._background_tasks = set()

        task = asyncio.create_task(coro)
        tasks.add(task)

        def _on_done(done_task):
            tasks.discard(done_task)
            if done_task.cancelled():
                return
            error = done_task.exception()
            if error is not None:
                print(f"Background {description} failed: {error}")

        task.add_done_callback(_on_done)
        return task

    async def start_capture_process(self):
        """Start a capture in the background, or report why scanning is not possible.

        If ``can_scan()`` refuses, the method sounds an error and shows the reason:
        "Test Error" in test mode, "Stop-Next bucket" when the bucket is full,
        otherwise "Stop-Not Ready". It then releases ``_capture_in_progress``.
        Otherwise it schedules ``capture_image`` as ``self._task``, and the flag is
        cleared when that task finishes or is cancelled.
        """
        print("Starting capture process")

        if not self.can_scan():
            try:
                self.buzzer.error()
                # current_bucket stays None until set_ready() arms a bucket; comparing
                # None with an int raises TypeError, which would escape here and
                # leave _capture_in_progress stuck at True.
                bucket_full = (
                    self.current_bucket is not None
                    and self.scanned_count >= self.current_bucket
                )
                if self.test_mode:
                    self.screen.show_cross("Test Error")
                elif bucket_full:
                    self.screen.show_cross("Stop-Next bucket")
                else:
                    self.screen.show_cross("Stop-Not Ready")
            finally:
                # Always release the flag so later detections are not ignored forever.
                self._capture_in_progress = False
            return

        async def delayed_capture():
            try:
                await self.capture_image()
            except asyncio.CancelledError:
                print("Delayed capture cancelled")
            finally:
                self._capture_in_progress = False

        self._task = asyncio.create_task(delayed_capture())

    async def notify_bucket_finished(self):
        """Tell the server that the current bucket has been fully scanned.

        Sends ``{"collection_id", "bucket_label"}`` as JSON, with the bearer token,
        to ``send_to`` with its last path segment replaced by
        ``bucket_scan_finished``. Also clears ``_capture_in_progress``.

        Failures are printed rather than raised. This covers non-200 responses as
        well as client and timeout errors, because ``capture_image`` schedules this
        coroutine as a fire-and-forget task that no caller awaits.
        """
        url = f"{self.send_to.rsplit('/', 1)[0]}/bucket_scan_finished"
        self._capture_in_progress = False

        headers = {"Authorization": f"Bearer {self.token}"}
        data = {
            "collection_id": self.collection_id,
            "bucket_label": self.current_bucket_label,
        }

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, json=data, headers=headers) as response:
                    if response.status != 200:
                        print(f"Failed to notify bucket finished: {await response.text()}")
                    else:
                        print("Successfully notified bucket finished")
        except (aiohttp.ClientError, asyncio.TimeoutError) as e:
            # Handle network trouble like a non-200 answer: report it, don't crash the task.
            print(f"Failed to notify bucket finished: {e}")

    async def send_to_server(self, image_bytes):
        """Upload one JPEG scan to ``self.send_to`` as multipart form data.

        Form fields: ``image`` (the JPEG bytes, filename ``scan.jpg``),
        ``collection_id`` and ``bucket_label``. The request is authenticated with
        the bearer token.

        Raises:
            Exception: If the server answers with a status other than 200.
        """
        url = self.send_to  # Full upload URL provided by the Go server

        data = aiohttp.FormData()
        data.add_field(
            "image",
            image_bytes,
            filename="scan.jpg",
            content_type="image/jpeg",
        )
        data.add_field("collection_id", str(self.collection_id))
        # Multipart form values must be str/bytes, so stringify the label just like
        # collection_id in case it is not already a str (e.g. a numeric label).
        data.add_field("bucket_label", str(self.current_bucket_label))

        headers = {"Authorization": f"Bearer {self.token}"}

        async with aiohttp.ClientSession() as session:
            async with session.post(url, data=data, headers=headers) as response:
                if response.status != 200:
                    raise Exception(f"Failed to upload image: {await response.text()}")
                print(f"Successfully uploaded image to {url}")

    async def enable_test_mode(self):
        """Switch the scanner into test mode.

        The card sensor is reset first. Afterwards ``can_scan()`` allows captures
        without an upload URL or bucket, and ``capture_image`` saves a local preview
        instead of uploading.
        """
        detector = self._detector
        await detector.reset_sensor()
        self.test_mode = True
        print("Test mode enabled")

    def disable_test_mode(self):
        """Leave test mode; captures again need a bucket armed via ``set_ready``."""
        self.test_mode = False
        print("Test mode disabled")

    def configure_sharpness_detection(
        self,
        use_sharpness_detection=None,
        max_wait_time=None,
        min_stable_frames=None,
        sharpness_threshold_increase=None,
        good_enough_sharpness=None,
        analysis_resize_factor=None,
        min_stabilization_time=None,
    ):
        """Update any subset of the sharpness-detection tuning parameters.

        Each argument left as ``None`` keeps its current value; every other argument
        overwrites the attribute of the same name. The resulting configuration is
        printed.
        """
        overrides = {
            "use_sharpness_detection": use_sharpness_detection,
            "max_wait_time": max_wait_time,
            "min_stable_frames": min_stable_frames,
            "sharpness_threshold_increase": sharpness_threshold_increase,
            "good_enough_sharpness": good_enough_sharpness,
            "analysis_resize_factor": analysis_resize_factor,
            "min_stabilization_time": min_stabilization_time,
        }
        for attr_name, new_value in overrides.items():
            if new_value is not None:
                setattr(self, attr_name, new_value)

        print(
            f"Sharpness detection configured: enabled={self.use_sharpness_detection}, "
            f"max_wait={self.max_wait_time}s, min_stable={self.min_stable_frames}, "
            f"threshold={self.sharpness_threshold_increase}, good_enough={self.good_enough_sharpness}, "
            f"resize_factor={self.analysis_resize_factor}, min_stabilization={self.min_stabilization_time}s"
        )

    def pause(self):
        """Pause scanning; ``_on_card_detected`` ignores detections until ``resume()``.

        If a screen is attached it shows "PAUSED", followed by the active bucket
        label when one is set.
        """
        self._paused = True
        print("Scanner paused")
        if not self.screen:
            return
        label = self.current_bucket_label
        message = f"PAUSED\nbucket {label}" if label else "PAUSED"
        self.screen.show_text(message)

    def resume(self):
        """Resume scanning after ``pause()``; card detections are handled again.

        If a screen is attached it shows the scanning prompt for the active bucket,
        or "Ready" when no bucket label is set.
        """
        self._paused = False
        print("Scanner resumed")
        if not self.screen:
            return
        label = self.current_bucket_label
        message = f"Scanning\nbucket {label}" if label else "Ready"
        self.screen.show_text(message)

    def is_paused(self):
        """Return the pause flag set by ``pause()`` and cleared by ``resume()``."""
        paused = self._paused
        return paused

    def reload_config(self):
        """Re-read the configuration and push the new settings to the camera.

        The freshly loaded config replaces ``self.config``, is handed to
        ``camera.reconfigure_resolution``, and the effective rotation is printed.
        """
        fresh_config = load_config()
        self.config = fresh_config
        self.camera.reconfigure_resolution(fresh_config)
        rotation = fresh_config.get('camera_rotation', 180)
        print(f"Config reloaded - camera_rotation: {rotation}°")

    async def init_gpio(self):
        """Set up the detector's GPIO interrupt handling.

        This delegates to ``CardDetector.init_gpio``; per the original notes, sensor
        initialisation itself is deferred until the first scan.
        """
        detector = self._detector
        await detector.init_gpio()
        print("GPIO interrupt handling initialized")

    async def cleanup(self):
        """Stop any in-flight capture, then release the card detector and the camera.

        A still-running ``self._task`` is cancelled and awaited. The hardware is
        released even if that task ends with an unexpected error, which is printed.
        The camera is also cleaned up even if ``detector.shutdown()`` raises, so one
        failing step does not leak the other resources.
        """
        task = self._task
        if task and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            except Exception as e:
                print(f"Capture task failed during cleanup: {e}")

        try:
            self._detector.shutdown()
        finally:
            self.camera.cleanup()

    async def capture_instant_image(self):
        """Capture a frame immediately (no sharpness wait) and save it as a preview.

        The frame is rotated by the ``camera_rotation`` config value (default 180).
        It is written to ``<output_dir>/latest_instant.jpg`` via a temporary file
        that is then moved into place. The method holds ``self._lock`` so it never
        overlaps ``capture_image``.

        Returns:
            True on success. False on any failure, after the error is printed, shown
            on the screen/buzzer (when present), and any leftover temporary file is
            removed.
        """
        print("Capturing instant image...")

        async with self._lock:
            temp_file = None
            try:
                rotation = self.config.get("camera_rotation", 180)

                # No stabilization: take whatever the camera delivers right now.
                array = self.camera.capture_array()
                image = Image.fromarray(array).rotate(rotation, expand=True)

                instant_file = self.output_dir / "latest_instant.jpg"
                temp_file = str(instant_file) + ".tmp"
                await asyncio.to_thread(image.save, temp_file, "JPEG")
                # os.replace (unlike os.rename) also overwrites an existing target on
                # Windows, so the swap works on every platform.
                await asyncio.to_thread(os.replace, temp_file, str(instant_file))

                print(f"Instant image saved to {instant_file}")

                if self.screen:
                    self.screen.show_check("Instant Captured!")
                if self.buzzer:
                    self.buzzer.friendly_beep()

                return True

            except Exception as e:
                print(f"Error capturing instant image: {e}")
                if temp_file is not None:
                    # Best effort: don't leave a partially written temp file behind.
                    try:
                        Path(temp_file).unlink(missing_ok=True)
                    except OSError:
                        pass
                if self.screen:
                    self.screen.show_cross("Error")
                if self.buzzer:
                    self.buzzer.error()
                return False