"""
Eye-tracker calibration & validation (iMotions-class).

Calibration here means the *quality check* every eye-tracking study needs before
recording: show known target points, collect gaze while the subject fixates each,
then report accuracy and precision in degrees of visual angle.

  * accuracy  — angular distance between the gaze centroid and the target
                (systematic error), averaged over targets. Lower is better;
                research-grade aims < ~1deg.
  * precision — RMS sample-to-sample angular distance at a fixed target
                (jitter / noise). Lower is better.

`CalibrationRoutine` defines the target layout and accumulates gaze per target so
it works both live (feed samples from a gaze Source) and offline. The math is pure
numpy and fully testable with synthetic gaze.
"""

from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np

from .eyetracking import Screen


# Built-in target layouts keyed by point count. Each entry is a list of (x, y)
# targets in normalized screen coordinates; list order is the order in which
# CalibrationRoutine presents the targets.
LAYOUTS = {
    # centre first, then the four corners (ordered by y, then x)
    5: [(0.5, 0.5)] + [(x, y) for y in (0.1, 0.9) for x in (0.1, 0.9)],
    # 3x3 grid, ordered by y, then x
    9: [(x, y) for y in (0.1, 0.5, 0.9) for x in (0.1, 0.5, 0.9)],
    # 4 columns x 3 rows, ordered by y, then x
    12: [(x, y) for y in (0.12, 0.5, 0.88) for x in (0.08, 0.36, 0.64, 0.92)],
    # centre, the remaining eight 3x3-grid points, then an inner 2x2 set
    13: ([(0.5, 0.5)]
         + [(x, y) for y in (0.1, 0.5, 0.9) for x in (0.1, 0.5, 0.9)
            if (x, y) != (0.5, 0.5)]
         + [(x, y) for y in (0.3, 0.7) for x in (0.3, 0.7)]),
}


@dataclass
class PointResult:
    """Validation outcome for a single calibration target.

    Angular values are in degrees of visual angle. When a target did not gather
    enough usable gaze samples, ``valid`` is False and both metrics are NaN.
    """

    # target position (x, y) in normalized screen coordinates
    target: tuple[float, float]
    # number of usable gaze samples that went into the metrics
    n_samples: int
    # distance between the gaze centroid and the target (systematic error)
    accuracy_deg: float
    # RMS of successive sample-to-sample distances (jitter)
    precision_deg: float
    # False when too few samples were available to score this target
    valid: bool


@dataclass
class CalibrationResult:
    """Aggregate outcome of a calibration validation run.

    ``accuracy_deg`` / ``precision_deg`` are the overall figures (degrees of
    visual angle); ``passed`` records whether the run met ``threshold_deg``.
    """

    # one PointResult per evaluated target, valid or not
    points: list
    accuracy_deg: float
    precision_deg: float
    passed: bool
    threshold_deg: float

    def summary(self) -> dict:
        """Return a compact report: metrics rounded to 3 decimals, pass flag,
        number of evaluated targets and the threshold that was applied."""
        return dict(
            accuracy_deg=round(self.accuracy_deg, 3),
            precision_deg=round(self.precision_deg, 3),
            passed=self.passed,
            n_points=len(self.points),
            threshold_deg=self.threshold_deg,
        )


def evaluate_point(target, gaze_xy, screen: Screen, *,
                   min_samples: int = 3) -> PointResult:
    """Accuracy + precision (deg) for one target from its collected gaze samples.

    Args:
        target: ``(x, y)`` target position in normalized screen coordinates.
        gaze_xy: Gaze samples for this target, one ``[gx, gy]`` row per sample,
            in the same normalized coordinates. Any row containing NaN or ±inf is
            discarded. An empty sequence or a single ``[gx, gy]`` pair is accepted.
        screen: Object providing ``norm_to_deg(x, y)``, used to convert normalized
            coordinates to degrees of visual angle.
        min_samples: Minimum number of usable samples needed to score the target
            (keyword-only, default 3).

    Returns:
        PointResult. If fewer than ``min_samples`` usable samples remain (or none
        at all), accuracy and precision are NaN and ``valid`` is False.
    """
    g = np.asarray(gaze_xy, dtype=float)
    if g.size == 0:
        # np.asarray([]) is 1-D; give it the row/column shape the filter expects.
        g = g.reshape(0, 2)
    # A lone [gx, gy] pair becomes a single row.
    g = np.atleast_2d(g)
    # Keep only rows whose every value is finite (drops NaN and ±inf alike).
    g = g[np.isfinite(g).all(axis=1)]
    n = len(g)
    if n == 0 or n < min_samples:
        return PointResult(tuple(target), n, float("nan"), float("nan"), False)

    gx_deg, gy_deg = screen.norm_to_deg(g[:, 0], g[:, 1])
    tx_deg, ty_deg = screen.norm_to_deg(np.array([target[0]]), np.array([target[1]]))
    # accuracy: distance from the mean gaze position to the target
    acc = float(np.hypot(np.mean(gx_deg) - tx_deg[0], np.mean(gy_deg) - ty_deg[0]))
    # precision: RMS of successive sample-to-sample distances
    d = np.hypot(np.diff(gx_deg), np.diff(gy_deg))
    # a single sample has no successive distances; report zero jitter as before
    prec = float(np.sqrt(np.mean(d ** 2))) if len(d) else 0.0
    return PointResult(tuple(target), n, acc, prec, True)


def evaluate(points: dict, *, screen: Screen | None = None,
             threshold_deg: float = 1.0,
             min_valid_points: int = 1) -> CalibrationResult:
    """
    Score a whole calibration from per-target gaze samples.

    points: {(tx,ty): [[gx,gy], ...]} mapping each target (normalized screen
        coords) to the gaze samples collected while it was shown.
    screen: geometry used for the degree conversion; ``Screen()`` when None.
    threshold_deg: pass/fail limit applied to the mean accuracy.
    min_valid_points: how many targets must produce a valid result before the
        run can pass (default 1, i.e. any valid target suffices).

    Returns overall accuracy/precision (means across valid points, NaN when no
    point is valid) and a pass/fail against `threshold_deg`.
    """
    if screen is None:
        screen = Screen()
    results = [evaluate_point(t, g, screen) for t, g in points.items()]
    valid = [r for r in results if r.valid]
    if valid:
        acc = float(np.mean([r.accuracy_deg for r in valid]))
        prec = float(np.mean([r.precision_deg for r in valid]))
    else:
        acc = prec = float("nan")
    # At least one valid point is always required; with none, acc would be NaN.
    enough = len(valid) >= max(min_valid_points, 1)
    passed = enough and acc <= threshold_deg
    return CalibrationResult(results, acc, prec, passed, threshold_deg)


class CalibrationRoutine:
    """
    Drives a calibration: iterate targets, collect gaze per target, evaluate.

    Live use:
        rt = CalibrationRoutine(points=9)
        for target in rt.targets:
            show_dot(target)                       # your presenter
            for x, y in read_gaze_for(1.0):        # ~1s of gaze samples
                rt.add(target, x, y)
        result = rt.result()

    Offline: feed previously recorded gaze the same way.

    `points` is either an integer key of ``LAYOUTS`` (a built-in layout) or an
    explicit sequence of ``(x, y)`` targets in normalized screen coordinates.
    Targets are stored as tuples, so lists or arrays are accepted as well.
    """

    def __init__(self, points: int | list = 9, screen: Screen | None = None,
                 threshold_deg: float = 1.0):
        if isinstance(points, (int, np.integer)):
            try:
                layout = LAYOUTS[points]
            except KeyError:
                raise KeyError(
                    f"no built-in {points}-point layout; choose one of {sorted(LAYOUTS)}"
                ) from None
        else:
            layout = points
        # Fresh list of tuples: tuples are hashable dict keys (lists are not), and
        # the copy stops callers from mutating the shared LAYOUTS entries.
        self.targets = [tuple(t) for t in layout]
        self.screen = screen if screen is not None else Screen()
        self.threshold_deg = threshold_deg
        # target tuple -> list of [x, y] gaze samples recorded for it
        self._acc: dict = {t: [] for t in self.targets}

    def add(self, target, x: float, y: float):
        """Record one gaze sample for `target`; raises KeyError if it is not one of `targets`."""
        self._acc[tuple(target)].append([x, y])

    def result(self) -> CalibrationResult:
        """Evaluate every target that received at least one sample."""
        pts = {t: np.array(v) for t, v in self._acc.items() if v}
        return evaluate(pts, screen=self.screen, threshold_deg=self.threshold_deg)
