"""
Facial Expression Analysis source — the open replacement for iMotions' locked
Affectiva module (HARDWARE.md, tier=ml).

A webcam frame goes through py-feat (open-source FACS: action units + emotions,
OpenFace/RetinaFace under the hood). We push one sample per processed frame:
the action units plus a valence/arousal estimate, as an LSL stream the recorder
captures and the dashboard circumplex consumes — no Affectiva licence, runs offline.

Everything heavy is lazy-imported, and the analysis step is injectable so the
source can be unit-tested and so you can swap py-feat for OpenFace/MediaPipe.
"""

from __future__ import annotations

import time
from typing import Callable

from ..drivers._base import DriverSource

# FACS action units carried on the stream — the subset that drives the common
# affect mappings. Order here is the order of the leading sample values.
AU_CHANNELS = [
    "AU01",
    "AU02",
    "AU04",
    "AU06",
    "AU12",
]
# Complete per-sample layout: every action unit, then valence, then arousal.
CHANNELS = [*AU_CHANNELS, "valence", "arousal"]


class FEASource(DriverSource):
    """
    Facial-expression source: grabs frames, analyses them, and yields one
    sample per analysed frame. A sample is the AU_CHANNELS values followed by
    valence and arousal (the order of CHANNELS), paired with an LSL
    ``local_clock()`` timestamp taken when the frame was grabbed.

    Parameters (all via kwargs through drivers.make):
      camera       : OpenCV camera index (default 0).
      fps          : target processing rate (default 10; FEA is compute-bound).
                     Must be a positive number.
      detector     : a pre-built py-feat Detector (else one is created lazily).
      frame_source : callable -> frame ndarray | None. Overrides the webcam
                     (used for tests / video files / an external grabber).
                     If it has a ``close`` attribute, that is called when
                     ``read()`` finishes.
      analyze      : callable(frame) -> (aus: dict, emotions: dict). Overrides the
                     whole py-feat path (used for tests / alternative models).
    """

    def __init__(self, device, *, camera=0, fps=10, detector=None,
                 frame_source: Callable | None = None,
                 analyze: Callable | None = None, **opts):
        # Fill in the channel list only when the device entry left it empty.
        # NOTE(review): a non-empty catalog channel list is left untouched, so
        # it must already match CHANNELS for the stream metadata to be right.
        # object.__setattr__ bypasses any __setattr__ override on the device
        # (presumably a frozen record — confirm against the device type).
        if not device.channels:
            object.__setattr__(device, "channels", tuple(CHANNELS))
        super().__init__(device, camera=camera, fps=fps, **opts)
        self._detector = detector
        self._frame_source = frame_source
        self._analyze = analyze

    # --- the py-feat analysis (overridable) ------------------------------
    def _get_detector(self):
        """Return the detector, building a default py-feat one on first use."""
        if self._detector is None:
            from feat import Detector            # lazy: heavy (torch)
            self._detector = Detector()
        return self._detector

    def _default_analyze(self, frame):
        """Run the detector on *frame* and return ``(aus, emotions)`` dicts.

        Uses ``detect(frame, data_type="image")`` when the detector has a
        ``detect`` method, otherwise ``detect_image(frame)``. Either dict is
        empty when the result has no matching ``aus`` / ``emotions`` attribute.
        """
        det = self._get_detector()
        fex = det.detect(frame, data_type="image") if hasattr(det, "detect") \
            else det.detect_image(frame)
        aus = _row_to_dict(fex.aus) if hasattr(fex, "aus") else {}
        emo = _row_to_dict(fex.emotions) if hasattr(fex, "emotions") else {}
        return aus, emo

    # --- capture loop ----------------------------------------------------
    def read(self):
        """Yield ``(sample, timestamp)`` pairs until ``self.stopping`` is true.

        ``sample`` is a list of floats in CHANNELS order; action units missing
        from the analysis result are reported as 0.0. ``timestamp`` is
        ``pylsl.local_clock()`` read right after the frame was grabbed, so it
        is not delayed by the analysis step. Iterations where the grabber
        returns ``None`` produce no sample. The loop is paced to at most
        ``fps`` iterations per second.

        Raises:
            ValueError: if the ``fps`` option is not a positive number.
        """
        from pylsl import local_clock  # lazy import, like the other heavy deps

        # Work out the period before touching the camera: a bad fps must not
        # leave an opened capture device behind with nothing to release it.
        fps = float(self.opts["fps"])
        if not fps > 0:  # also rejects NaN
            raise ValueError(
                f"fps must be a positive number, got {self.opts['fps']!r}")
        period = 1.0 / fps

        analyze = self._analyze or self._default_analyze
        grab = self._frame_source or _webcam_grabber(self.opts["camera"])
        try:
            while not self.stopping:
                # Monotonic clock: pacing must not be disturbed by wall-clock
                # adjustments (NTP steps, DST, manual changes).
                started = time.monotonic()
                frame = grab()
                if frame is not None:
                    # Stamp at capture; analysis can take a large part of the
                    # frame period and would otherwise skew the timestamp.
                    stamp = local_clock()
                    aus, emo = analyze(frame)
                    val, aro = valence_arousal(emo)
                    sample = [float(aus.get(a, 0.0)) for a in AU_CHANNELS] + [val, aro]
                    yield sample, stamp
                # Sleep off whatever is left of this iteration's time budget.
                time.sleep(max(0.0, period - (time.monotonic() - started)))
        finally:
            # Release the grabber (e.g. the webcam) however the loop ended.
            closer = getattr(grab, "close", None)
            if closer:
                closer()


# --------------------------------------------------------------------------
def valence_arousal(emotions: dict) -> tuple[float, float]:
    """Collapse categorical emotion scores into a (valence, arousal) pair.

    Russell-circumplex style mapping: valence is happiness minus the summed
    unpleasant emotions (sadness, anger, disgust, fear); arousal is the summed
    activating emotions (anger, fear, surprise, happiness) minus neutral.
    Emotions absent from *emotions* count as 0.0, and both results are passed
    through ``_clip`` before being returned.
    """
    def score(name):
        # Missing categories contribute nothing to either axis.
        return float(emotions.get(name, 0.0))

    unpleasant = score("sadness") + score("anger") + score("disgust") + score("fear")
    valence = score("happiness") - unpleasant

    activating = score("anger") + score("fear") + score("surprise") + score("happiness")
    arousal = activating - score("neutral")

    return _clip(valence), _clip(arousal)


def _clip(v, lo=-1.0, hi=1.0):
    return float(max(lo, min(hi, v)))


def _row_to_dict(df_like):
    """Take the first row of a py-feat AUs/emotions frame as a plain dict."""
    try:
        row = df_like.iloc[0]
        return {str(k): float(v) for k, v in row.items()}
    except Exception:
        return dict(df_like) if isinstance(df_like, dict) else {}


def _webcam_grabber(camera):
    """Open OpenCV capture *camera* and return a frame-grabbing callable.

    Calling the returned ``grab()`` gives the next frame, or ``None`` when the
    read fails. ``grab.close`` releases the capture device.

    If the device cannot be opened a warning is logged and the grabber is
    still returned; every ``grab()`` call is then expected to return ``None``.
    """
    import logging  # local: only used on the failure path below

    import cv2  # lazy

    cap = cv2.VideoCapture(camera)
    if not cap.isOpened():
        # Without this, a missing or busy camera yields None forever and the
        # stream simply stays empty with no hint as to why.
        logging.getLogger(__name__).warning(
            "FEA: camera %r could not be opened; no frames will be produced",
            camera)

    def grab():
        ok, frame = cap.read()
        return frame if ok else None

    # Exposed as an attribute so the caller can release the device when done.
    grab.close = cap.release
    return grab
