"""
Eye-tracking analysis suite (iMotions-class gaze analytics).

Turns a recorded gaze stream (normalized [0,1] screen coords) into the standard
eye-movement events and metrics, plus the heatmap / scanpath visualizations:

  * I-VT classification — velocity-threshold split into fixations and saccades,
    in real degrees of visual angle (needs screen geometry; sensible defaults),
  * fixation metrics — count, durations, fixation rate,
  * saccade metrics — count, amplitudes (deg), peak velocity,
  * scanpath length, gaze sample validity,
  * heatmap (gaze density) and scanpath (ordered fixations) Plotly figures.

AOI dwell/TTFF live in `analysis/aoi.py`; calibration in `analysis/calibration.py`.
All pure numpy (+ optional plotly for the figures). Works on any gaze session.
"""

from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np

from . import load as A


# --------------------------------------------------------------------------
# screen geometry: convert normalized gaze -> degrees of visual angle
# --------------------------------------------------------------------------
@dataclass
class Screen:
    """Physical display geometry used to express normalized gaze as visual angle.

    Defaults: a 1920x1080 px panel measuring 530 x 300 mm, viewed from 600 mm.
    """
    width_px: int = 1920
    height_px: int = 1080
    width_mm: float = 530.0
    height_mm: float = 300.0
    distance_mm: float = 600.0          # eye-to-screen distance

    def norm_to_deg(self, x: np.ndarray, y: np.ndarray):
        """Map normalized [0,1] gaze to per-axis visual angle (deg) from screen center.

        Each axis is handled independently: the offset from the screen center
        (in mm) is converted to an angle via arctan(offset / distance).
        Returns the tuple (deg_x, deg_y).
        """
        def _axis_deg(norm, extent_mm):
            # signed physical offset from the screen center along this axis
            offset_mm = (np.asarray(norm) - 0.5) * extent_mm
            return np.degrees(np.arctan2(offset_mm, self.distance_mm))

        return _axis_deg(x, self.width_mm), _axis_deg(y, self.height_mm)


# --------------------------------------------------------------------------
# events
# --------------------------------------------------------------------------
@dataclass
class Fixation:
    """A single fixation event: its time span and gaze centroid."""
    t_start: float      # start time of the event
    t_end: float        # end time of the event (same clock as t_start)
    x: float            # normalized centroid
    y: float

    @property
    def duration(self) -> float:
        """Time elapsed between t_start and t_end, in the timestamps' units."""
        end, start = self.t_end, self.t_start
        return end - start


@dataclass
class Saccade:
    """A single saccade event (as built by classify_ivt: the transition between
    two consecutive fixations)."""
    t_start: float          # start time of the transition
    t_end: float            # end time of the transition
    amplitude_deg: float    # angular size of the movement, degrees
    peak_velocity: float    # deg/s, as computed by classify_ivt

    @property
    def duration(self) -> float:
        """Time elapsed between t_start and t_end (mirrors Fixation.duration)."""
        return self.t_end - self.t_start


def _gaze(session: dict, stream="EyeTracker", xy=(0, 1)):
    """Extract timestamps and two gaze coordinate columns from a session dict.

    Reads ``session["streams"][stream]["t"]`` (timestamps) and ``["x"]``
    (one row per sample; columns ``xy[0]`` / ``xy[1]`` are returned as the
    x / y coordinates).

    An empty ``"x"`` (e.g. a plain ``[]``, which numpy turns into shape (0,))
    yields empty coordinate arrays instead of raising IndexError, so empty
    recordings can flow through the callers.

    Returns (t, x, y) as float numpy arrays.
    """
    s = session["streams"][stream]
    t = np.asarray(s["t"], float)
    X = np.asarray(s["x"], float)
    if X.size == 0 and X.ndim < 2:
        # give the empty array a column axis wide enough for both indices
        width = max(max(xy) + 1, -min(xy))
        X = X.reshape(0, width)
    return t, X[:, xy[0]], X[:, xy[1]]


def classify_ivt(session: dict, *, stream="EyeTracker", screen: Screen | None = None,
                 velocity_threshold=30.0, min_fix_duration=0.06,
                 merge_time=0.075, merge_angle_deg=0.5, min_saccade_deg=0.5):
    """
    I-VT: label each sample as fixation (slow) or saccade (fast) by angular
    velocity (deg/s; 30 is the common default), group runs into fixations, then
    apply post-processing in the same order as the Tobii I-VT filter (whose
    defaults these parameters mirror):

      * merge adjacent fixations separated by <= `merge_time` and within
        `merge_angle_deg` (collapses noise-induced splits),
      * discard fixations shorter than `min_fix_duration`. This happens *after*
        merging, so a fixation split by brief noise is not lost just because
        each fragment on its own is short,
      * derive saccades as the transitions between surviving fixations, keeping
        only those with amplitude >= `min_saccade_deg` (drops micro-saccades).

    Saccade peak velocity is the largest finite sample velocity strictly
    between the two fixations. When no such sample exists (e.g. the gap is
    pure data loss), it falls back to amplitude / transition time.

    Args:
        session: session dict read via `_gaze` (``session["streams"][stream]``).
        stream: name of the gaze stream.
        screen: display geometry for degree conversion (defaults to `Screen()`).
        velocity_threshold: deg/s; samples strictly slower are fixation candidates.
        min_fix_duration: minimum fixation duration, in timestamp units.
        merge_time, merge_angle_deg: merge limits for adjacent fixations.
        min_saccade_deg: minimum saccade amplitude to report.

    Returns (fixations, saccades, sample_velocity_degps). The velocity array
    has one entry per sample and holds +inf where velocity cannot be computed
    (NaN gaze, non-increasing timestamps). With fewer than two samples no
    velocity can be estimated, so the result is ([], [], all-inf array).
    """
    screen = screen or Screen()
    t, xn, yn = _gaze(session, stream)
    n = len(t)
    if n < 2:
        # np.gradient needs at least two samples; nothing to classify.
        return [], [], np.full(n, np.inf)

    valid = ~(np.isnan(xn) | np.isnan(yn))
    dx, dy = screen.norm_to_deg(xn, yn)

    # np.gradient: central differences inside, one-sided at the ends
    dt = np.gradient(t)
    dt[dt <= 0] = np.nan
    vel = np.sqrt(np.gradient(dx) ** 2 + np.gradient(dy) ** 2) / dt
    vel = np.nan_to_num(vel, nan=np.inf)

    # 1) raw fixation runs: maximal runs of consecutive slow, valid samples
    is_fix = (vel < velocity_threshold) & valid
    raw = []
    i = 0
    while i < n:
        if not is_fix[i]:
            i += 1
            continue
        j = i
        while j < n and is_fix[j]:
            j += 1
        seg = slice(i, j)
        raw.append(Fixation(float(t[i]), float(t[j - 1]),
                            float(np.nanmean(xn[seg])), float(np.nanmean(yn[seg]))))
        i = j

    # 2) merge fixations split by brief noise, then 3) drop the short ones
    merged = _merge_fixations(raw, screen, merge_time, merge_angle_deg)
    fixations = [f for f in merged if f.duration >= min_fix_duration]

    # 4) saccades = transitions between surviving fixations
    saccades = []
    for a, b in zip(fixations, fixations[1:]):
        ax, ay = screen.norm_to_deg(a.x, a.y)
        bx, by = screen.norm_to_deg(b.x, b.y)
        amp = float(np.hypot(bx - ax, by - ay))
        if amp < min_saccade_deg:
            continue
        # samples strictly between the fixations; +inf means "not computable"
        inner = vel[(t > a.t_end) & (t < b.t_start)]
        inner = inner[np.isfinite(inner)]
        if inner.size:
            peak = float(inner.max())
        else:
            span = (b.t_start - a.t_end) or 1e-3
            peak = amp / span
        saccades.append(Saccade(a.t_end, b.t_start, amp, peak))
    return fixations, saccades, vel


def _merge_fixations(fix, screen, merge_time, merge_angle_deg):
    """Collapse consecutive fixations that are close in both time and space.

    Each fixation is compared with the last entry of the output list, which
    may itself already be a merge result. It is merged into that entry when the
    gap (its start minus the previous end) is <= merge_time and the angular
    distance between the centroids (via `screen.norm_to_deg`) is
    <= merge_angle_deg. The merged centroid is the duration-weighted mean, with
    a zero duration counting as 1e-3. Returns a new list; merged entries are
    new Fixation objects, and the inputs are not mutated.
    """
    merged = []
    for cur in fix:
        if not merged:
            merged.append(cur)
            continue
        prev = merged[-1]
        prev_dx, prev_dy = screen.norm_to_deg(np.array([prev.x]), np.array([prev.y]))
        cur_dx, cur_dy = screen.norm_to_deg(np.array([cur.x]), np.array([cur.y]))
        separation = float(np.hypot(cur_dx[0] - prev_dx[0], cur_dy[0] - prev_dy[0]))
        close_in_time = (cur.t_start - prev.t_end) <= merge_time
        if not (close_in_time and separation <= merge_angle_deg):
            merged.append(cur)
            continue
        # weighted-centroid merge
        w_prev = prev.duration or 1e-3
        w_cur = cur.duration or 1e-3
        total = w_prev + w_cur
        merged[-1] = Fixation(prev.t_start, cur.t_end,
                              (prev.x * w_prev + cur.x * w_cur) / total,
                              (prev.y * w_prev + cur.y * w_cur) / total)
    return merged


def gaze_metrics(session: dict, **kw) -> dict:
    """Summary eye-movement metrics for a whole gaze recording.

    ``kw`` is forwarded to `classify_ivt`; its ``stream`` entry (default
    "EyeTracker") also selects the samples used for duration and validity.

    Returned keys:
      duration_s            last minus first timestamp (0.0 for < 2 samples)
      valid_ratio           fraction of samples whose x and y are not NaN
      fixation_count        number of I-VT fixations
      fixation_rate_hz      fixations per unit of duration_s (0.0 if no duration)
      mean_fixation_ms      mean fixation duration x 1000 (0.0 if none)
      saccade_count         number of saccades
      mean_saccade_amp_deg  mean saccade amplitude in degrees (0.0 if none)
      scanpath_len_norm     summed distance between consecutive fixation
                            centroids, in normalized screen units
    """
    t, xn, yn = _gaze(session, kw.get("stream", "EyeTracker"))
    fixations, saccades, _vel = classify_ivt(session, **kw)
    has_xy = ~np.isnan(xn) & ~np.isnan(yn)
    span_s = (t[-1] - t[0]) if len(t) > 1 else 0.0

    mean_fix = np.mean([f.duration for f in fixations]) if fixations else 0.0
    mean_amp = np.mean([s.amplitude_deg for s in saccades]) if saccades else 0.0

    # scanpath length over fixation centroids (normalized units)
    path = 0.0
    if len(fixations) > 1:
        cx = np.array([f.x for f in fixations])
        cy = np.array([f.y for f in fixations])
        path = float(np.sum(np.hypot(np.diff(cx), np.diff(cy))))

    return {
        "duration_s": float(span_s),
        "valid_ratio": float(np.mean(has_xy)) if len(has_xy) else 0.0,
        "fixation_count": len(fixations),
        "fixation_rate_hz": float(len(fixations) / span_s) if span_s else 0.0,
        "mean_fixation_ms": float(mean_fix * 1000),
        "saccade_count": len(saccades),
        "mean_saccade_amp_deg": float(mean_amp),
        "scanpath_len_norm": path,
    }


# --------------------------------------------------------------------------
# visualizations
# --------------------------------------------------------------------------
def heatmap_grid(session: dict, *, stream="EyeTracker", bins=64, sigma=1.6):
    """2-D gaze density on the normalized screen (gaussian-smoothed histogram).

    Samples with NaN x or y are skipped, and samples outside [0,1] fall outside
    the histogram range. Returns (density, x_edges, y_edges), where density is
    laid out rows=y, columns=x.
    """
    _, gx, gy = _gaze(session, stream)
    keep = ~np.isnan(gx) & ~np.isnan(gy)
    counts, x_edges, y_edges = np.histogram2d(
        gx[keep], gy[keep], bins=bins, range=[[0, 1], [0, 1]]
    )
    smoothed = _gauss_blur(counts, sigma)
    # histogram2d indexes [x_bin, y_bin]; transpose so rows are y
    return smoothed.T, x_edges, y_edges


def fig_heatmap(session: dict, **kw):
    """Plotly heatmap of gaze density over the normalized screen.

    ``kw`` is forwarded to `heatmap_grid` (stream, bins, sigma). Cells are
    placed at the bin centers. Plotly is imported lazily, so it is only needed
    when a figure is requested.
    """
    import plotly.graph_objects as go
    H, xe, ye = heatmap_grid(session, **kw)
    fig = go.Figure(go.Heatmap(z=H, x=(xe[:-1]+xe[1:])/2, y=(ye[:-1]+ye[1:])/2,
                               colorscale="Turbo", showscale=True, zsmooth="best"))
    # Fixed [0,1] screen with y reversed (0 at the top). An explicit
    # autorange="reversed" takes precedence over `range` in plotly, so turn
    # autorange off and reverse via the range order instead.
    fig.update_yaxes(autorange=False, range=[1, 0], title="screen y")
    fig.update_xaxes(range=[0, 1], title="screen x")
    fig.update_layout(title="Gaze heatmap (normalized screen)", height=420,
                      margin=dict(l=50, r=20, t=50, b=40))
    return fig


def fig_scanpath(session: dict, **kw):
    """Plotly scanpath: fixation centroids in temporal order on the normalized screen.

    ``kw`` is forwarded to `classify_ivt`. A thin line joins consecutive
    fixations. Markers are labelled 1..N in order and sized
    8 + 60 * duration. Plotly is imported lazily.
    """
    import plotly.graph_objects as go
    fix, _, _ = classify_ivt(session, **kw)
    fig = go.Figure()
    if fix:
        xs = [f.x for f in fix]; ys = [f.y for f in fix]
        sizes = [8 + 60 * f.duration for f in fix]
        fig.add_trace(go.Scatter(x=xs, y=ys, mode="lines", line=dict(color="#3a9aa8", width=1.5),
                                 hoverinfo="skip", showlegend=False))
        fig.add_trace(go.Scatter(x=xs, y=ys, mode="markers+text",
                                 text=[str(i+1) for i in range(len(fix))],
                                 textposition="middle center", textfont=dict(size=9, color="#fff"),
                                 marker=dict(size=sizes, color="#2f6470", opacity=.75,
                                             line=dict(color="#fff", width=1)),
                                 showlegend=False))
    # Show the whole [0,1] screen with y reversed (0 at the top). An explicit
    # autorange="reversed" takes precedence over `range` in plotly and would
    # zoom to the fixations, so turn autorange off and reverse via range order.
    fig.update_yaxes(autorange=False, range=[1, 0], title="screen y")
    fig.update_xaxes(range=[0, 1], title="screen x")
    fig.update_layout(title="Scanpath (fixation order; size = duration)", height=420,
                      margin=dict(l=50, r=20, t=50, b=40))
    return fig


def _gauss_blur(a, sigma):
    """Separable gaussian blur of a 2-D array, zero-padded at the borders.

    The kernel is truncated at radius int(3 * sigma) and normalized to sum 1.
    The output always has the same shape as ``a``, including when an axis is
    shorter than the kernel. ``sigma <= 0`` returns ``a`` itself, unchanged.
    """
    if sigma <= 0:
        return a
    r = int(3 * sigma)
    k = np.exp(-(np.arange(-r, r + 1) ** 2) / (2 * sigma ** 2))
    k /= k.sum()

    def _smooth(v):
        # 'full' + center slice equals np.convolve(..., "same") when
        # len(v) >= len(k). Unlike "same", which returns max(len(v), len(k))
        # values, it never grows the axis.
        return np.convolve(v, k, "full")[r:r + len(v)]

    out = np.apply_along_axis(_smooth, 0, a)
    out = np.apply_along_axis(_smooth, 1, out)
    return out
