#!/usr/bin/env python3
"""Aggregate repeated PoC harness runs across trial directories."""

from __future__ import annotations

import csv
import math
import statistics
import sys
from pathlib import Path

from parse_results import run_sort_key, summarize_run


def mean(values: list[float]) -> float | None:
    return statistics.fmean(values) if values else None


def stdev(values: list[float]) -> float | None:
    if len(values) < 2:
        return 0.0 if values else None
    return statistics.stdev(values)


def fmt(value: float | int | None) -> str:
    if value is None:
        return ""
    if isinstance(value, int):
        return str(value)
    if math.isfinite(value):
        return f"{value:.4f}"
    return ""


def load_rows(root: Path) -> list[dict[str, object]]:
    """Collect one summary row per run directory across every trial under *root*.

    Each immediate subdirectory of *root* is treated as a trial, visited in
    sorted path order. Inside a trial, only subdirectories that contain a
    ``summary.json`` file count as runs. They are ordered with
    ``run_sort_key`` and summarised with ``summarize_run``. Each returned row
    is tagged with a ``"trial"`` key holding the trial directory's name.
    """
    collected: list[dict[str, object]] = []
    trial_dirs = sorted(entry for entry in root.iterdir() if entry.is_dir())
    for trial in trial_dirs:
        run_dirs = [
            entry
            for entry in trial.iterdir()
            if entry.is_dir() and (entry / "summary.json").exists()
        ]
        run_dirs.sort(key=run_sort_key)
        for run in run_dirs:
            summary = summarize_run(run)
            summary["trial"] = trial.name
            collected.append(summary)
    return collected


# Column layout shared by the per-run rows and the aggregate rows. Aggregate
# rows repurpose some columns; see _aggregate_row.
_FIELDS = [
    "section",
    "trial",
    "run_dir",
    "mode",
    "run_id",
    "host_write_bytes_delta",
    "fdp_hbmw_delta",
    "fdp_mbmw_delta",
    "fdp_mbe_delta",
    "write_p99_ms",
    "all_p99_ms",
    "records",
    "failures",
    "waf_status",
]

# Canonical order in which modes are summarised. Modes found in the data but
# not listed here are still summarised, after these, in sorted order.
_MODE_ORDER = ("no_fdp", "mixed", "separated")


def _present(
    rows: list[dict[str, object]], key: str, cast: type[int] | type[float]
) -> list:
    """Return ``cast(row[key])`` for every row whose *key* value is not None."""
    return [cast(row[key]) for row in rows if row[key] is not None]


def _ordered_modes(modes: dict[str, object]) -> list[str]:
    """Return the canonical modes present in *modes*, then any others sorted."""
    known = [mode for mode in _MODE_ORDER if mode in modes]
    extra = sorted(mode for mode in modes if mode not in _MODE_ORDER)
    return known + extra


def _aggregate_row(mode: str, mode_rows: list[dict[str, object]]) -> dict[str, object]:
    """Build the ``aggregate`` CSV row that summarises all runs of one *mode*.

    Columns are reused as follows:
    - ``trial`` holds ``runs=N``.
    - The numeric columns hold the mean of the non-missing per-run values.
    - ``failures`` holds the total across runs.
    - ``waf_status`` carries the host / hbmw / write-p99 sample standard deviations.

    ``None`` values are skipped instead of being passed to ``int()``/``float()``,
    which would raise ``TypeError``.
    """
    host = _present(mode_rows, "host_write_bytes_delta", int)
    hbmw = _present(mode_rows, "fdp_hbmw_delta", int)
    write_p99 = _present(mode_rows, "write_p99_ms", float)
    all_p99 = _present(mode_rows, "all_p99_ms", float)
    records = _present(mode_rows, "records", int)
    failures_total = sum(_present(mode_rows, "failures", int))
    return {
        "section": "aggregate",
        "trial": f"runs={len(mode_rows)}",
        "run_dir": "",
        "mode": mode,
        "run_id": "",
        "host_write_bytes_delta": fmt(mean(host)),
        "fdp_hbmw_delta": fmt(mean(hbmw)),
        "fdp_mbmw_delta": "",
        "fdp_mbe_delta": "",
        "write_p99_ms": fmt(mean(write_p99)),
        "all_p99_ms": fmt(mean(all_p99)),
        "records": fmt(mean(records)),
        "failures": failures_total,
        "waf_status": (
            f"host_stdev={fmt(stdev(host))};"
            f"hbmw_stdev={fmt(stdev(hbmw))};"
            f"write_p99_stdev={fmt(stdev(write_p99))}"
        ),
    }


def main() -> int:
    """Aggregate every run under ``BATCH_DIR`` (argv[1]) into a CSV on stdout.

    The output is a single CSV with one header (``_FIELDS``). It contains one
    ``per_run`` row per run directory, then one ``aggregate`` row per mode.
    The header is always written, so an empty batch still yields the same
    schema as a populated one.

    Returns 2 on a usage error or when ``BATCH_DIR`` is not a directory,
    and 0 otherwise.
    """
    if len(sys.argv) != 2:
        print(f"usage: {sys.argv[0]} BATCH_DIR", file=sys.stderr)
        return 2

    root = Path(sys.argv[1])
    if not root.is_dir():
        print(f"{sys.argv[0]}: not a directory: {root}", file=sys.stderr)
        return 2

    rows = load_rows(root)

    writer = csv.DictWriter(sys.stdout, fieldnames=_FIELDS)
    writer.writeheader()
    for row in rows:
        writer.writerow({"section": "per_run", **{key: row[key] for key in _FIELDS[1:]}})

    by_mode: dict[str, list[dict[str, object]]] = {}
    for row in rows:
        by_mode.setdefault(str(row["mode"]), []).append(row)
    for mode in _ordered_modes(by_mode):
        writer.writerow(_aggregate_row(mode, by_mode[mode]))
    return 0


if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    sys.exit(main())
