"""
Dynamic AOIs — areas of interest that MOVE over a video stimulus.

A static AOI (analysis/aoi.py) is a fixed rectangle/polygon. On video, the thing
you care about (a face, a product, a logo) moves frame to frame, so the AOI has to
move with it. A `DynamicAOI` is a set of keyframes — (time, region) — and the
region at any moment is linearly interpolated between the surrounding keyframes.

Given timestamped gaze and a stimulus onset, this computes the standard metrics
against the moving AOI: dwell time, time to first hit (TTFF — the first raw gaze
sample inside the AOI; no fixation detection is done here), hit, and revisits.
Keyframes can be authored by hand, or generated automatically by tracking an
object across the video (`track_aoi`, OpenCV, lazy-imported).

Times are seconds relative to the video/stimulus onset. Coordinates are normalized
[0,1] like the rest of biosync.
"""

from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np

from . import load as A


@dataclass
class DynamicAOI:
    """An area of interest whose region moves over time.

    The region is defined by keyframes (time, region). Between two keyframes it
    is linearly interpolated; outside the keyframe range it is clamped to the
    nearest keyframe. Times are seconds relative to the stimulus onset.
    """

    name: str
    # (t, region) pairs kept sorted by t. A region is a rect (x0, y0, x1, y1)
    # or a polygon [(x, y), ...]; all keyframes must use the same kind.
    keyframes: list = field(default_factory=list)
    kind: str = "rect"                 # "rect" | "polygon"

    def add(self, t: float, region) -> "DynamicAOI":
        """Insert keyframe (t, region), keep keyframes time-ordered, return self."""
        frames = self.keyframes
        frames.append((float(t), region))
        frames.sort(key=lambda kf: kf[0])
        return self

    def region_at(self, t: float):
        """Region at time `t`, interpolated between the bracketing keyframes.

        Returns the first region for t at/before the first keyframe, the last
        region for t at/after the last keyframe, and None with no keyframes.
        """
        frames = self.keyframes
        if not frames:
            return None
        head, tail = frames[0], frames[-1]
        if t <= head[0]:
            return head[1]
        if t >= tail[0]:
            return tail[1]
        for prev, nxt in zip(frames, frames[1:]):
            t0, r0 = prev
            t1, r1 = nxt
            if t0 <= t <= t1:
                # `or 1e-9` guards a zero-length span (duplicate timestamps).
                frac = (t - t0) / ((t1 - t0) or 1e-9)
                return _lerp_region(r0, r1, frac, self.kind)
        # Reached only when no pair brackets t (e.g. t is NaN).
        return tail[1]

    def contains(self, x: float, y: float, t: float) -> bool:
        """True if gaze point (x, y) lies in the AOI's region at time t."""
        region = self.region_at(t)
        # NaN is the only value unequal to itself: invalid gaze never hits.
        if region is None or x != x or y != y:
            return False
        if self.kind == "rect":
            return _in_rect(x, y, region)
        return _in_poly(x, y, region)


def _lerp_region(r0, r1, f, kind):
    if kind == "rect":
        return tuple(a + (b - a) * f for a, b in zip(r0, r1))
    return [(ax + (bx - ax) * f, ay + (by - ay) * f)
            for (ax, ay), (bx, by) in zip(r0, r1)]


def _in_rect(x, y, r):
    x0, y0, x1, y1 = r
    return min(x0, x1) <= x <= max(x0, x1) and min(y0, y1) <= y <= max(y0, y1)


def _in_poly(x, y, poly):
    inside = False; n = len(poly); j = n - 1
    for i in range(n):
        xi, yi = poly[i]; xj, yj = poly[j]
        if ((yi > y) != (yj > y)) and \
           (x < (xj - xi) * (y - yi) / ((yj - yi) or 1e-12) + xi):
            inside = not inside
        j = i
    return inside


# --------------------------------------------------------------------------
def metrics(session: dict, dynamic_aois, *, stimulus: str | None = None,
            onset: float | None = None, window: float | None = None,
            gaze_stream="EyeTracker", gaze_xy=(0, 1)) -> dict:
    """
    Dwell / TTFF / hit / revisits for each DynamicAOI over the stimulus window.

    Provide either `stimulus` (uses its stim_on/stim_off markers) or an explicit
    `onset` (+ optional `window`). AOI times are relative to that onset.
    """
    s = session["streams"][gaze_stream]
    t = np.asarray(s["t"], float)
    X = np.asarray(s["x"], float)
    gx, gy = X[:, gaze_xy[0]], X[:, gaze_xy[1]]

    if stimulus is not None:
        lo, hi = _stimulus_window(session, stimulus)
    else:
        lo = onset if onset is not None else (t[0] if len(t) else 0.0)
        hi = lo + window if window is not None else (t[-1] if len(t) else lo)
    m = (t >= lo) & (t <= hi)
    t, gx, gy = t[m], gx[m], gy[m]
    rel = t - lo                                 # time since onset
    dt = np.gradient(t) if len(t) > 1 else np.array([0.0])

    out = {}
    for aoi in dynamic_aois:
        inside = np.array([aoi.contains(px, py, rt)
                           for px, py, rt in zip(gx, gy, rel)])
        dwell = float(np.sum(dt[inside])) if inside.any() else 0.0
        ttff = float(rel[np.argmax(inside)]) if inside.any() else None
        out[aoi.name] = {
            "dwell_time_s": dwell,
            "hit": bool(inside.any()),
            "ttff_s": ttff,
            "revisits": _revisits(inside),
            "coverage": float(np.mean(inside)) if len(inside) else 0.0,
        }
    return out


def _stimulus_window(session, stimulus):
    lo = None
    for mt, label in A.markers(session):
        label = str(label)
        if label == f"stim_on:{stimulus}" and lo is None:
            lo = float(mt)
        elif label == f"stim_off:{stimulus}" and lo is not None:
            return lo, float(mt)
    t = session["streams"]["EyeTracker"]["t"]
    return (lo if lo is not None else (t[0] if len(t) else 0.0)), (t[-1] if len(t) else 0.0)


def _revisits(inside: np.ndarray) -> int:
    if not inside.any():
        return 0
    entries = int(np.sum(np.diff(inside.astype(int)) == 1)) + (1 if inside[0] else 0)
    return max(0, entries - 1)


# --------------------------------------------------------------------------
def track_aoi(video_path: str, name: str, init_box, *, every: float = 0.25,
              tracker: str = "CSRT") -> DynamicAOI:
    """
    Auto-generate AOI keyframes by tracking an object across a video.

    init_box: (x, y, w, h) in normalized coords on the first frame.
    Samples a keyframe every `every` seconds. Uses OpenCV (lazy import); the open
    replacement for hand-drawing a moving AOI on every frame.

    The tracker is fed every frame, since visual trackers search near the previous
    position and lose fast-moving objects when frames are skipped. Only the
    keyframes are thinned to one per `every` seconds (rounded to whole frames).
    Frames where the tracker reports failure add no keyframe, so interpolation
    bridges the gap.

    Raises:
        RuntimeError: the video cannot be read, or the tracker is unavailable.
    """
    import cv2  # lazy
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0   # 0 when the container lacks it
        ok, frame = cap.read()
        if not ok:
            raise RuntimeError(f"cannot read video {video_path}")
        H, W = frame.shape[:2]
        x, y, w, h = init_box
        box = (int(x * W), int(y * H), int(w * W), int(h * H))
        trk = _make_tracker(cv2, tracker)
        trk.init(frame, box)
        aoi = DynamicAOI(name, kind="rect")
        aoi.add(0.0, _box_to_rect(box, W, H))
        # round() rather than int(): e.g. 29.97 fps * 0.1 s should be 3 frames, not 2.
        step = max(1, round(fps * every))
        fi = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            fi += 1
            ok, box = trk.update(frame)
            if ok and fi % step == 0:
                aoi.add(fi / fps, _box_to_rect(box, W, H))
    finally:
        # Release the capture even when reading or tracker setup raises.
        cap.release()
    return aoi


def _make_tracker(cv2, name):
    for attr in (f"Tracker{name}_create", f"legacy.Tracker{name}_create"):
        obj = cv2
        try:
            for part in attr.split("."):
                obj = getattr(obj, part)
            return obj()
        except AttributeError:
            continue
    raise RuntimeError(f"OpenCV tracker {name} unavailable")


def _box_to_rect(box, W, H):
    x, y, w, h = box
    return (x / W, y / H, (x + w) / W, (y + h) / H)


# --------------------------------------------------------------------------
# persistence — the visual editor saves/loads these
# --------------------------------------------------------------------------
def to_dict(aoi: DynamicAOI) -> dict:
    """Serialize an AOI to a JSON-ready dict: name, kind, [[t, region], ...]."""
    frames = []
    for t, region in aoi.keyframes:
        frames.append([float(t), list(region)])
    return {"name": aoi.name, "kind": aoi.kind, "keyframes": frames}


def from_dict(d: dict) -> DynamicAOI:
    """Rebuild a DynamicAOI from the dict shape produced by to_dict.

    "kind" defaults to "rect" and "keyframes" to empty. Rect regions become
    tuples; polygon regions become lists of (x, y) tuples.
    """
    kind = d.get("kind", "rect")
    aoi = DynamicAOI(d["name"], kind=kind)
    for t, region in d.get("keyframes", []):
        if kind == "rect":
            shape = tuple(region)
        else:
            shape = [tuple(pt) for pt in region]
        aoi.add(float(t), shape)
    return aoi


def load_aois(path: str) -> dict:
    """aois.json -> {stimulus: [DynamicAOI, ...]}.

    Returns {} when `path` does not exist. The file is read as UTF-8 (the JSON
    encoding) and closed before its entries are converted with from_dict.
    """
    import json
    import os
    if not os.path.exists(path):
        return {}
    with open(path, encoding="utf-8") as fh:
        raw = json.load(fh)
    return {stim: [from_dict(a) for a in aois] for stim, aois in raw.items()}


def save_aois(path: str, aois_by_stimulus: dict) -> str:
    """{stimulus: [DynamicAOI|dict, ...]} -> aois.json. Returns `path`.

    Dict entries are written as-is; other entries go through to_dict. The whole
    document is serialized before the file is opened, so a serialization error
    (e.g. a non-JSON value) leaves an existing file untouched. The file is
    written as UTF-8 and closed (flushed) before returning.
    """
    import json
    out = {stim: [a if isinstance(a, dict) else to_dict(a) for a in aois]
           for stim, aois in aois_by_stimulus.items()}
    text = json.dumps(out, indent=2)
    with open(path, "w", encoding="utf-8") as fh:
        fh.write(text)
    return path
