"""build_comparison.py — once all 12 genres are trained, regenerate the
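comparison plot and collect per-genre stats from samples/<genre>_sample.mid.

This runs over the existing sample MIDI files and does not retrain anything.
It writes samples/comparison.png (the piano-roll grid) and stats.json (the
per-genre note statistics), both relative to this script's directory.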
"""
from __future__ import annotations

import json
import sys
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import pretty_midi

# Silence every warning category for the whole process.
# NOTE(review): presumably aimed at noisy MIDI-parsing/plotting warnings; this
# also hides anything else — confirm a blanket filter is intended.
warnings.filterwarnings("ignore")

# Directory containing this script; samples/ and stats.json are resolved against it.
HERE = Path(__file__).resolve().parent
# Put the script directory first on the import path.
# NOTE(review): no sibling-module import is visible in this file — confirm this is still needed.
sys.path.insert(0, str(HERE))

# Ordered (genre key, plot title) pairs for the comparison plot.
# The key selects samples/<key>_sample.mid and is the entry name in stats.json;
# the title is shown above the genre's panel. List order is panel order, and
# main() lays panels out three per row, hence the "Row N" groups below.
GENRE_ORDER = [
    # Row 1: pre-baroque polyphony
    ("trecento",      "Trecento (Italian ars nova, c.1370)"),
    ("palestrina",    "Palestrina (Renaissance, c.1570)"),
    ("monteverdi",    "Monteverdi (late Renaissance, c.1600)"),
    # Row 2: tonal western classical
    ("bach",          "Bach chorales (c.1720)"),
    ("beethoven",     "Beethoven (string quartets, c.1810)"),
    ("essenFolksong", "Essen Folksong (German folk)"),
    # Row 3: pre-modern dance, atonal control, and the hip-hop beat
    ("ryansMammoth",  "Ryan's Mammoth (Irish/Scottish dance, 1880s)"),
    ("atonal",        "Synthetic 12-tone (atonal control)"),
    ("rap",           "Hip-hop beat (synthetic)"),
    # Row 4: modern popular genres
    ("rock",          "Rock (synthetic)"),
    ("pop",           "Pop (synthetic)"),
    ("metal",         "Heavy metal (synthetic)"),
]


def piano_roll(notes, ax, title: str, t_max: float = 30.0):
    """Draw *notes* on *ax* as a piano roll limited to the window ``[0, t_max]``.

    Every note becomes one horizontal bar centred on its pitch, running from
    its start to its end (the end is clipped to *t_max*) and coloured by pitch
    with the viridis colormap. An empty *notes* produces only a centred
    "(empty)" label plus the title; no axis limits, labels or grid are set.

    Args:
        notes: Sequence of objects exposing ``start``, ``end`` and ``pitch``.
        ax: Matplotlib axes to draw on.
        title: Text for the axes title.
        t_max: Right edge of the time axis; notes starting after it are skipped.
    """
    if not notes:
        ax.text(0.5, 0.5, "(empty)", ha="center", va="center", transform=ax.transAxes)
        ax.set_title(title, fontsize=9)
        return

    for note in notes:
        if note.start > t_max:
            continue
        bar_width = min(note.end, t_max) - note.start
        # Pitch 30..100 spans the colormap's 0..1 input range.
        bar_colour = plt.cm.viridis((note.pitch - 30) / 70)
        ax.broken_barh(
            [(note.start, bar_width)],
            (note.pitch - 0.4, 0.8),
            facecolors=bar_colour,
            alpha=0.85,
        )

    # Fit the pitch axis to the notes that begin inside the window, with a
    # three-semitone margin on each side.
    visible_pitches = [note.pitch for note in notes if note.start < t_max]
    if visible_pitches:
        lowest, highest = min(visible_pitches), max(visible_pitches)
        ax.set_ylim(lowest - 3, highest + 3)

    ax.set_xlim(0, t_max)
    ax.set_xlabel("time (s)", fontsize=8)
    ax.set_ylabel("pitch", fontsize=8)
    ax.set_title(title, fontsize=9)
    ax.grid(alpha=0.25)
    ax.tick_params(labelsize=7)


def stats_from_notes(notes) -> dict:
    """Summarise a note list as a flat dict of simple statistics.

    Args:
        notes: Sequence of objects exposing ``pitch``, ``start`` and ``end``.

    Returns:
        A dict with, in this order:
        ``n_notes`` (note count), ``pitch_mean`` and ``pitch_std`` (population
        standard deviation, as computed by ``numpy.std``), ``density`` (notes
        per unit of time) and ``polyphony`` (summed note durations divided by
        the covered time span). The span runs from the earliest start to the
        latest end; when it is not positive, ``density`` and ``polyphony`` are
        0.0. An empty *notes* yields all-zero values.
    """
    if not notes:
        return {
            "n_notes": 0,
            "pitch_mean": 0.0,
            "pitch_std": 0.0,
            "density": 0.0,
            "polyphony": 0.0,
        }

    # Imported lazily so the empty case never touches numpy.
    import numpy as np

    count = len(notes)
    pitch_values = [note.pitch for note in notes]
    sounding_time = sum(note.end - note.start for note in notes)
    last_release = max(note.end for note in notes)
    first_onset = min(note.start for note in notes)
    span = last_release - first_onset

    if span > 0:
        density = count / span
        polyphony = sounding_time / span
    else:
        # All notes collapse onto a single instant: rates are undefined, report 0.
        density = 0.0
        polyphony = 0.0

    return {
        "n_notes": count,
        "pitch_mean": float(np.mean(pitch_values)),
        "pitch_std": float(np.std(pitch_values)),
        "density": density,
        "polyphony": polyphony,
    }


def main():
    """Render the genre comparison figure and dump per-genre note statistics.

    For each ``GENRE_ORDER`` entry this looks for ``samples/<genre>_sample.mid``
    next to this script. A file that exists is drawn as a piano roll (first
    30 seconds) and summarised with ``stats_from_notes``; a missing one gets a
    "not trained yet" placeholder panel and is named in a warning line.

    Side effects: saves ``samples/comparison.png``, prints the output path, the
    optional warning and a stats table, then writes the stats to ``stats.json``
    beside this script. Genres without a sample do not appear in the stats.
    """
    sample_root = HERE / "samples"
    n_cols = 3
    total = len(GENRE_ORDER)
    n_rows = -(-total // n_cols)  # ceiling division
    _, grid = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5.5, n_rows * 3.2))
    panels = grid.flatten()

    per_genre = {}
    absent = []
    for panel, (genre, label) in zip(panels, GENRE_ORDER):
        sample_file = sample_root / f"{genre}_sample.mid"
        if sample_file.exists():
            midi = pretty_midi.PrettyMIDI(str(sample_file))
            # Only the first instrument track is plotted and measured.
            # NOTE(review): any further tracks are ignored — confirm samples are single-track.
            track_notes = []
            if midi.instruments:
                track_notes = midi.instruments[0].notes
            piano_roll(track_notes, panel, label, t_max=30)
            per_genre[genre] = {"desc": label, **stats_from_notes(track_notes)}
        else:
            # No sample yet: show a greyed-out placeholder in this genre's cell.
            panel.text(0.5, 0.5, f"[{genre}]\nnot trained yet", ha="center", va="center",
                       transform=panel.transAxes, color="#94a3b8", fontsize=10)
            panel.set_title(label, fontsize=9)
            panel.axis("off")
            absent.append(genre)

    # Blank out any grid cells beyond the last genre.
    for spare in panels[total:]:
        spare.axis("off")

    plt.suptitle("MIDI Genre Transformer — same architecture, twelve different corpora",
                 fontsize=13, fontweight="bold", y=1.005)
    plt.tight_layout()
    figure_path = sample_root / "comparison.png"
    plt.savefig(figure_path, dpi=110, bbox_inches="tight")
    plt.close()
    print(f"Wrote {figure_path}")
    if absent:
        print(f"\n[WARN] missing samples for: {', '.join(absent)}")

    # Console summary, one row per genre that had a sample.
    print(f"\n{'genre':<16s} {'notes':>6s} {'mean p':>8s} {'std p':>7s} {'density':>9s} {'poly':>7s}")
    print("-" * 60)
    for genre, row in per_genre.items():
        print(f"{genre:<16s} {row['n_notes']:>6d} {row['pitch_mean']:>8.1f} "
              f"{row['pitch_std']:>7.1f} {row['density']:>9.2f} {row['polyphony']:>7.2f}")

    # Persist the same numbers as JSON for the landing page to consume.
    stats_file = HERE / "stats.json"
    stats_file.write_text(json.dumps(per_genre, indent=2))


# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
