"""train_genres.py — Train one GPT-style Transformer per musical genre (up to
twelve; see GENRES) and compare the generated output.

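Genres (see GENRES for the full list of twelve). Seven come from music21's
built-in corpora (no downloads needed), including:
    bach          433 chorales              (tonal 4-voice polyphony, ~1720)
    palestrina   1318 Renaissance pieces    (modal sacred polyphony, ~1570)
    ryansMammoth 1059 Irish/Scottish tunes  (jigs, reels, hornpipes, ~1880)
    trecento      103 14th-c. Italian       (ars nova polyphony, ~1370)
plus monteverdi, beethoven and essenFolksong. The other five (atonal, metal,
rock, pop, rap) are synthesised programmatically as token streams. Up to 80
pieces per genre are used (30 with --quick).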

The model is the Ch 41 GPTLike, slightly bigger to handle 256-token windows.
Compute budget: ~2 minutes per genre on CPU.

Usage:
    python train_genres.py [--genres bach,palestrina,...] [--steps 1200] [--quick]

Outputs:
    samples/<genre>_sample.mid          — up to 400 generated tokens, rendered to MIDI
    samples/<genre>_pianoroll.png       — piano-roll visualisation
    samples/comparison.png              — piano-roll comparison across genres
    checkpoints/<genre>.pt              — trained model weights
"""
from __future__ import annotations

import argparse
import json
import math
import random
import time
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from music21 import corpus

warnings.filterwarnings("ignore")

# Import the tokenizer
import sys
HERE = Path(__file__).resolve().parent
sys.path.insert(0, str(HERE))
from midi_tokenizer import (
    tokenize_score, detokenize, token_str,
    VOCAB_SIZE, BOS_ID, EOS_ID, PAD_ID,
    TIME_BASE, NOTEON_BASE, NOTEOFF_BASE,
)

# Import the Ch 41 building blocks
sys.path.insert(0, str(HERE.parent.parent / "part12_pretraining"))
from utils import (
    TransformerBlock, sinusoidal_positional_encoding, causal_mask, count_params,
)


# ---------------------------------------------------------------------------
# Model — same GPTLike as Ch 41, just bigger d_model and max_len
# ---------------------------------------------------------------------------

class GPTMIDI(nn.Module):
    """Decoder-only (GPT-style) Transformer over MIDI event tokens.

    Token embeddings are summed with fixed sinusoidal position encodings,
    run through a stack of causally-masked TransformerBlocks, layer-normed,
    and projected back to vocabulary logits by a bias-free linear head.

    Args:
        vocab_size: number of distinct token ids.
        d_model: embedding / residual-stream width.
        n_heads: attention heads per block.
        n_layers: number of TransformerBlocks.
        d_ff: hidden width passed to each block's feed-forward layer.
        max_len: number of rows in the position table, i.e. the longest
            sequence forward() can accept.
        dropout: dropout probability passed to each block.
    """

    def __init__(self, vocab_size: int, d_model: int = 128, n_heads: int = 4,
                 n_layers: int = 3, d_ff: int = 384, max_len: int = 256,
                 dropout: float = 0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # Fixed (non-learned) position table. persistent=False keeps it out
        # of state_dict(), so saved checkpoints hold only trainable weights.
        position_table = sinusoidal_positional_encoding(max_len, d_model)
        self.register_buffer("pos_emb", position_table, persistent=False)
        layers = [TransformerBlock(d_model, n_heads, d_ff, dropout)
                  for _layer in range(n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        """Map a (batch, seq) tensor of token ids to next-token logits.

        The returned logits have vocab_size as their last dimension. seq must
        not exceed max_len, because the position table is sliced to seq rows.
        """
        _batch, seq_len = x.shape
        positions = self.pos_emb[:seq_len].unsqueeze(0)  # broadcast over batch
        hidden = self.tok_emb(x) + positions
        attn_mask = causal_mask(seq_len, x.device)
        for layer in self.blocks:
            hidden = layer(hidden, attn_mask)
        hidden = self.ln_f(hidden)
        return self.lm_head(hidden)


# ---------------------------------------------------------------------------
# Corpus loaders
# ---------------------------------------------------------------------------

# Genre registry: key -> (corpus_key, human-readable description).
# Dict order is the default training order used by main's --genres option.
# corpus_key prefixes: "composer:" and "opus:" are resolved against music21's
# built-in corpus by expand_paths; "synthetic:" selects one of the
# gen_synthetic_* generators inside load_genre_tokens.
GENRES = {
    # --- Original 4 (pre-19th-century Western) ---
    "bach": (
        "composer:bach",
        "Bach chorales (tonal polyphony, c.1720)",
    ),
    "palestrina": (
        "composer:palestrina",
        "Palestrina (Renaissance polyphony, c.1570)",
    ),
    "ryansMammoth": (
        "composer:ryansMammoth",
        "Irish/Scottish dance tunes (Ryan's Mammoth, 1880s)",
    ),
    "trecento": (
        "composer:trecento",
        "Trecento Italian polyphony (ars nova, c.1370)",
    ),
    # --- Second batch: Western classical at greater stylistic distance ---
    "monteverdi": (
        "composer:monteverdi",
        "Monteverdi madrigals (chromatic late Renaissance, c.1600)",
    ),
    "beethoven": (
        "composer:beethoven",
        "Beethoven string quartets (Classical/Romantic, c.1810)",
    ),
    "essenFolksong": (
        "opus:essenFolksong",
        "Essen Folksong Collection (German monophonic folk songs)",
    ),
    "atonal": (
        "synthetic:atonal",
        "Synthetic 12-tone (Schoenberg-style atonal, programmatic)",
    ),
    # --- Third batch: modern popular-music genres (synthesised, because
    # public-domain MIDI for them is scarce) ---
    "metal": (
        "synthetic:metal",
        "Heavy metal (low-register power chords, fast 16ths)",
    ),
    "rock": (
        "synthetic:rock",
        "Rock (I-IV-V chord progressions, mid register, backbeat)",
    ),
    "pop": (
        "synthetic:pop",
        "Pop (I-V-vi-IV progressions, melody-driven, high register)",
    ),
    "rap": (
        "synthetic:rap",
        "Hip-hop beat (kick/snare/hat drum + low bass, sparse melody)",
    ),
}


def expand_paths(corpus_key: str) -> list[tuple[object, str]]:
    """Resolve a 'composer:NAME' or 'opus:NAME' key to a list of work items.

    The result is a list of ``(item, kind)`` pairs, not parsed scores:

    * ``kind == "score"``: ``item`` is a corpus file path that still has to be
      parsed (callers hand ``str(item)`` to the tokenizer);
    * ``kind == "opus_score"``: ``item`` is an already-parsed music21 object
      taken from an Opus file's ``scores``.

    'composer:' keys are resolved without parsing anything. 'opus:' keys parse
    every file so that the pieces inside each Opus can be listed
    individually. Files that fail to parse are skipped (best-effort), and how
    many were skipped is reported.

    Raises:
        ValueError: if the key has neither prefix.
    """
    if corpus_key.startswith("composer:"):
        name = corpus_key.split(":", 1)[1]
        return [(path, "score") for path in corpus.getComposer(name)]
    if corpus_key.startswith("opus:"):
        name = corpus_key.split(":", 1)[1]
        out: list[tuple[object, str]] = []
        failed = 0
        for path in corpus.getComposer(name):
            try:
                parsed = corpus.parse(str(path))
                if hasattr(parsed, "scores"):
                    # An Opus bundles many short pieces; list each one.
                    out.extend((s, "opus_score") for s in parsed.scores)
                else:
                    out.append((path, "score"))
            except Exception:  # best-effort: one broken file must not abort the corpus
                failed += 1
        if failed:
            print(f"    [{name}] skipped {failed} unparseable opus file(s)")
        return out
    raise ValueError(f"Unknown corpus key prefix: {corpus_key}")


def gen_synthetic_atonal(n_pieces: int = 80, notes_per_piece: int = 64,
                         seed: int = 0) -> list[int]:
    """Generate a token stream in a Schoenberg-style 12-tone-row idiom.

    Tokens are emitted directly (no music21 round-trip). Each piece is
    wrapped in BOS/EOS and built from one random ordering ("row") of the 12
    consecutive semitones around a random centre pitch in 48..72. The row is
    played ``max(2, notes_per_piece // 12)`` times. Each pass applies a
    randomly chosen classical transform:

    * prime (P): the row as is;
    * retrograde (R): reversed;
    * inversion (I): mirrored about the row's first note;
    * retrograde-inversion (RI): mirrored, then reversed.

    Each note lasts 1-4 sixteenths, and pitches are clipped to 21..108 by
    ``_emit_note``. This is the deliberately atonal control case.

    A private ``random.Random(seed)`` is used. The output is reproducible for
    a given seed, and the global ``random`` module state is left untouched.
    (The same sequence of draws is made as before, so the tokens are
    unchanged.)

    Args:
        n_pieces: number of BOS...EOS pieces to generate.
        notes_per_piece: target notes per piece, rounded down to whole rows
            (minimum two rows).
        seed: RNG seed.

    Returns:
        The concatenated token ids of all pieces.
    """
    rng = random.Random(seed)
    out: list[int] = []

    for _piece in range(n_pieces):
        out.append(BOS_ID)
        center = rng.randint(48, 72)        # row centre
        row = list(range(center - 6, center + 6))
        rng.shuffle(row)
        n_iter = max(2, notes_per_piece // 12)
        for _pass in range(n_iter):
            transform = rng.choice(["P", "R", "I", "RI"])
            seq = row[:]
            if transform in ("I", "RI"):
                # Inversion about the first note of the (un-reversed) row.
                p0 = seq[0]
                seq = [2 * p0 - p for p in seq]
            if transform in ("R", "RI"):
                seq.reverse()
            for pitch in seq:
                # NOTE_ON, one TIME_SHIFT of `dur` sixteenths, NOTE_OFF;
                # _emit_note also clips the pitch to 21..108.
                _emit_note(out, pitch, rng.choice([1, 2, 2, 3, 4]))
        out.append(EOS_ID)
    return out


def _emit_note(out: list[int], pitch: int, duration: int):
    """Append a single note to *out* as NOTE_ON, TIME_SHIFT, NOTE_OFF tokens.

    *pitch* is clamped to MIDI 21..108 (the 88-key piano range). *duration*
    (in sixteenth-note units) is clamped to 1..32 and encoded as the token
    ``TIME_BASE + duration - 1``.
    """
    clipped = max(21, min(108, pitch))
    sixteenths = max(1, min(32, duration))
    out.extend((
        NOTEON_BASE + clipped,
        TIME_BASE + sixteenths - 1,
        NOTEOFF_BASE + clipped,
    ))


def _emit_chord(out: list[int], pitches: list[int], duration: int):
    """Append a block chord to *out*.

    Every pitch (each clamped to MIDI 21..108) gets a NOTE_ON, then a single
    TIME_SHIFT of *duration* sixteenths (clamped to 1..32), then a NOTE_OFF
    for every pitch in the same order. *pitches* itself is not modified.
    """
    voiced = [max(21, min(108, p)) for p in pitches]
    out.extend(NOTEON_BASE + p for p in voiced)
    out.append(TIME_BASE + max(1, min(32, duration)) - 1)
    out.extend(NOTEOFF_BASE + p for p in voiced)

def gen_synthetic_metal(n_pieces: int = 80, seed: int = 0) -> list[int]:
    """Heavy metal: low-register power chords (root + fifth + octave) struck
    as fast 16th-note hits over short riff progressions.

    Each piece (BOS...EOS) picks a root in 28..40 (E1..E2) and one
    progression from PROGS, and repeats it for 3-6 bars. Every chord is
    struck 8, 12 or 16 times as a one-sixteenth chord, with a 15% chance of
    a one-sixteenth rest after each hit.

    A private ``random.Random(seed)`` is used, so results are reproducible
    and the global ``random`` state is not modified. The draw order matches
    the previous global-RNG version, so the tokens are unchanged.
    """
    rng = random.Random(seed)
    # Chord roots as semitone offsets from the tonic (i). Power chords have
    # no third, so the numerals only describe the root movement.
    PROGS = [
        [0,  0, -2,  0],   # i - i - bVII - i   (whole step below the tonic)
        [0, -2, -1,  0],   # i - bVII - VII - i (chromatic climb back up)
        [0,  0, -5,  0],   # i - i - V - i      (dominant, a fourth below)
        [0, -5, -2,  0],   # i - V - bVII - i
        [0,  3,  5,  7],   # i - bIII - IV - V  (ascending)
    ]
    out: list[int] = []
    for _piece in range(n_pieces):
        out.append(BOS_ID)
        key_root = rng.randint(28, 40)   # E1 .. E2 (very low)
        prog = rng.choice(PROGS)
        bars = rng.randint(3, 6)
        for _bar in range(bars):
            for chord_offset in prog:
                root = key_root + chord_offset
                power_chord = [root, root + 7, root + 12]  # root, fifth, octave
                # 8, 12 or 16 rapid 16th-note hits (palm-muted feel)
                n_hits = rng.choice([8, 12, 16])
                for _hit in range(n_hits):
                    _emit_chord(out, power_chord, duration=1)
                    if rng.random() < 0.15:
                        out.append(TIME_BASE + 0)  # quick rest (1 sixteenth)
        out.append(EOS_ID)
    return out


def gen_synthetic_rock(n_pieces: int = 80, seed: int = 0) -> list[int]:
    """Rock: mid-register triads over I-IV-V-family progressions, with a
    melody note answering every chord stab.

    Each piece (BOS...EOS) picks a key root in 48..55 and one progression
    from PROGS, and gives each chord one bar. On each of the 4 beats, an
    eighth-note chord stab is followed by an eighth-note melody note. The
    melody note is chosen from the chord's own tones an octave up, plus the
    9th above the chord root.

    A private ``random.Random(seed)`` is used, so results are reproducible
    and the global ``random`` state is not modified.
    """
    rng = random.Random(seed)
    MAJOR_OFFSETS = [0, 4, 7]    # major triad
    MINOR_OFFSETS = [0, 3, 7]    # minor triad
    PROGS = [
        [(0, "M"), (5, "M"), (7, "M")],                # I-IV-V
        [(0, "M"), (7, "M"), (9, "m"), (5, "M")],      # I-V-vi-IV
        [(0, "M"), (9, "m"), (5, "M"), (7, "M")],      # I-vi-IV-V (50s doo-wop)
        [(0, "M"), (0, "M"), (0, "M"), (0, "M"),
         (5, "M"), (5, "M"), (0, "M"), (0, "M"),
         (7, "M"), (5, "M"), (0, "M"), (7, "M")],       # 12-bar blues sketch
    ]
    out: list[int] = []
    for _piece in range(n_pieces):
        out.append(BOS_ID)
        key_root = rng.randint(48, 55)  # C3..G3 — guitar range
        prog = rng.choice(PROGS)
        for chord_degree, quality in prog:
            offsets = MAJOR_OFFSETS if quality == "M" else MINOR_OFFSETS
            root = key_root + chord_degree
            chord = [root + o for o in offsets]
            # Melody pool: the chord's tones an octave up, then the 9th.
            # The third comes from `offsets`, so minor chords (vi) no longer
            # get a major-3rd melody note. For major chords the pool (and its
            # order) is the same as before.
            mel_choices = [tone + 12 for tone in chord] + [root + 14]
            # One bar = 4 beats x (8th-note chord stab + 8th-note melody note).
            for _beat in range(4):
                _emit_chord(out, chord, duration=2)
                _emit_note(out, rng.choice(mel_choices), duration=2)
        out.append(EOS_ID)
    return out


def gen_synthetic_pop(n_pieces: int = 80, seed: int = 0) -> list[int]:
    """Pop: I-V-vi-IV-family progressions with an eighth-note melody above.

    Each piece (BOS...EOS) picks a key root in 48..53 and one progression
    from PROGS, and plays it twice, one chord per bar. A bar is an
    eighth-note chord stab followed by seven eighth-note melody notes
    (2 + 7 * 2 = 16 sixteenths). Each melody note is a random degree of a
    major scale (MELODY_DEGREES), built on the current chord's root one
    octave above the chord.

    NOTE(review): the scale follows the chord root rather than the key root.
    As a result, melody notes over the V, vi and IV chords can fall outside
    the key (e.g. a major 3rd over the minor vi chord). Confirm this is
    intended; changing it would change the training data.

    A private ``random.Random(seed)`` is used, so results are reproducible
    and the global ``random`` state is not modified.
    """
    rng = random.Random(seed)
    MAJOR_OFFSETS = [0, 4, 7]
    MINOR_OFFSETS = [0, 3, 7]
    # The most famous progression in pop (axis of awesome): I-V-vi-IV
    PROGS = [
        [(0, "M"), (7, "M"), (9, "m"), (5, "M")],     # I-V-vi-IV (canonical)
        [(9, "m"), (5, "M"), (0, "M"), (7, "M")],     # vi-IV-I-V (sad pop)
        [(0, "M"), (5, "M"), (9, "m"), (7, "M")],     # I-IV-vi-V
    ]
    MELODY_DEGREES = [0, 2, 4, 5, 7, 9, 11, 12]       # major scale, root to octave
    mel_offset = 12                                   # melody one octave above chords
    out: list[int] = []
    for _piece in range(n_pieces):
        out.append(BOS_ID)
        key_root = rng.randint(48, 53)   # C3..F3 for chords
        prog = rng.choice(PROGS)
        # 4-chord progression, one chord per bar, played twice (8 bars).
        for _rep in range(2):
            for chord_degree, quality in prog:
                offsets = MAJOR_OFFSETS if quality == "M" else MINOR_OFFSETS
                root = key_root + chord_degree
                _emit_chord(out, [root + o for o in offsets], duration=2)  # short stab
                # Seven 8th-note melody notes fill the rest of the bar.
                for _note in range(7):
                    deg = rng.choice(MELODY_DEGREES)
                    _emit_note(out, root + mel_offset + deg, duration=2)
        out.append(EOS_ID)
    return out


def gen_synthetic_rap(n_pieces: int = 80, seed: int = 0) -> list[int]:
    """Hip-hop beat: drum-machine pitches (kick/snare/hi-hat) on a fixed 4/4
    sixteenth-note grid, plus a repeating one-note-per-beat bass line.

    Each piece (BOS...EOS) picks a bass root in 28..36 and a 4-step bass
    pattern, then plays 4-8 bars of 16 sixteenth-note steps. On each step,
    every sounding pitch is emitted as one 1-sixteenth chord. Steps with
    nothing to play emit a 1-sixteenth rest.

    NOTE(review): a step's pitch list can contain duplicates. For example,
    the step-0 bass note equals KICK (36) when bass_root is 36. This produces
    repeated NOTE_ON/NOTE_OFF tokens for the same pitch; check that
    detokenize handles this.

    A private ``random.Random(seed)`` is used, so results are reproducible
    and the global ``random`` state is not modified.
    """
    rng = random.Random(seed)
    # General-MIDI percussion pitches (channel 10 in real MIDI; here they are
    # just pitches that happen to fall in the percussion range).
    KICK = 36         # C2
    SNARE = 38        # D2
    HAT_CLOSED = 42   # F#2
    HAT_OPEN = 46     # A#2
    # Bass offsets from the root, one per beat of the bar.
    BASS_PATTERNS = [
        [0, 0, 7, 0],
        [0, -2, 0, 5],
        [0, 0, 5, 0],
        [0, 7, 0, -2],
    ]

    out: list[int] = []
    for _piece in range(n_pieces):
        out.append(BOS_ID)
        bass_root = rng.randint(28, 36)    # E1..C2
        bass_pattern = rng.choice(BASS_PATTERNS)
        bars = rng.randint(4, 8)
        for _bar in range(bars):
            for sx in range(16):           # sixteenth-note step within the bar
                hits = []
                if sx % 2 == 0:            # closed hat on every 8th
                    hits.append(HAT_CLOSED)
                # Open hat on the 'and' of beat 4, in ~40% of bars.
                if sx == 14 and rng.random() < 0.4:
                    hits.append(HAT_OPEN)
                if sx in (0, 8):           # kick on beats 1 and 3
                    hits.append(KICK)
                # Syncopated kick on the 'and' of beat 3, half the time.
                if sx == 10 and rng.random() < 0.5:
                    hits.append(KICK)
                if sx in (4, 12):          # snare backbeat on beats 2 and 4
                    hits.append(SNARE)
                if sx % 4 == 0:            # bass on every beat: 1, 2, 3, 4
                    hits.append(bass_root + bass_pattern[sx // 4])
                if hits:
                    _emit_chord(out, hits, duration=1)
                else:
                    out.append(TIME_BASE + 0)  # rest
        out.append(EOS_ID)
    return out


def _read_token_cache(path: Path, max_pieces: int) -> list[int] | None:
    """Return tokens cached at *path* for *max_pieces*, or None if unusable.

    A cache is usable only if it parses as JSON into a dict whose
    ``max_pieces`` equals the requested value. Legacy caches (a bare token
    list) record no piece count, so they are treated as stale.
    """
    try:
        payload = json.loads(path.read_text())
    except (OSError, ValueError):  # unreadable file or corrupt JSON
        return None
    if not isinstance(payload, dict) or payload.get("max_pieces") != max_pieces:
        return None
    tokens = payload.get("tokens")
    return tokens if isinstance(tokens, list) else None


def _write_token_cache(path: Path, max_pieces: int, tokens: list[int]) -> None:
    """Write *tokens* to *path*, tagged with the max_pieces that produced them."""
    path.write_text(json.dumps({"max_pieces": max_pieces, "tokens": tokens}))


def _synthetic_tokens(genre_key: str, kind: str, max_pieces: int) -> list[int]:
    """Generate *max_pieces* pieces with the synthetic generator named *kind*."""
    generators = {
        "atonal": gen_synthetic_atonal,
        "metal": gen_synthetic_metal,
        "rock": gen_synthetic_rock,
        "pop": gen_synthetic_pop,
        "rap": gen_synthetic_rap,
    }
    generator = generators.get(kind)
    if generator is None:
        raise ValueError(f"Unknown synthetic kind: {kind}")
    tokens = generator(n_pieces=max_pieces)
    print(f"    [{genre_key}] generated synthetic '{kind}' — {len(tokens):,} tokens")
    return tokens


def _corpus_tokens(genre_key: str, corpus_key: str, max_pieces: int) -> list[int]:
    """Tokenise up to *max_pieces* items of a music21 'composer:'/'opus:' corpus.

    Pieces the tokenizer fails on are skipped (best-effort) and counted.
    """
    items = expand_paths(corpus_key)[:max_pieces]
    all_tokens: list[int] = []
    t0 = time.time()
    skipped = 0
    for i, (item, kind) in enumerate(items):
        try:
            # "score" items are file paths; "opus_score" items are parsed scores.
            tokens = tokenize_score(str(item) if kind == "score" else item)
        except Exception:  # best-effort: one bad piece must not abort the genre
            skipped += 1
            continue
        all_tokens.extend(tokens)
        if (i + 1) % 20 == 0:
            print(f"    [{genre_key}] {i + 1}/{len(items)} pieces, "
                  f"{len(all_tokens):,} tokens, {time.time() - t0:.1f}s elapsed")
    print(f"    [{genre_key}] {len(items) - skipped} pieces ok, {skipped} skipped, "
          f"{len(all_tokens):,} tokens total")
    return all_tokens


def load_genre_tokens(genre_key: str, max_pieces: int = 80,
                      cache_path: Path | None = None) -> list[int]:
    """Return one concatenated token stream for *genre_key*.

    Real corpora ('composer:'/'opus:' keys in GENRES) are tokenised piece by
    piece, up to *max_pieces* items. 'synthetic:' keys are generated with
    ``n_pieces=max_pieces``.

    If *cache_path* is given, the result is cached there as JSON, together
    with the *max_pieces* that produced it. A cache built for a different
    piece count is ignored and rebuilt, so a ``--quick`` run (30 pieces) no
    longer leaks into a full run (80 pieces). Corrupt or legacy
    (untagged) caches are also rebuilt.

    Raises:
        KeyError: if *genre_key* is not in GENRES.
        ValueError: for an unknown corpus prefix or synthetic kind.
    """
    if cache_path is not None and cache_path.exists():
        cached = _read_token_cache(cache_path, max_pieces)
        if cached is not None:
            return cached
        print(f"    [{genre_key}] ignoring stale/unreadable cache {cache_path.name}")

    corpus_key = GENRES[genre_key][0]
    if corpus_key.startswith("synthetic:"):
        tokens = _synthetic_tokens(genre_key, corpus_key.split(":", 1)[1], max_pieces)
    else:
        tokens = _corpus_tokens(genre_key, corpus_key, max_pieces)

    if cache_path is not None:
        _write_token_cache(cache_path, max_pieces, tokens)
    return tokens


# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------

def sample_batch(data: torch.Tensor, block_size: int, batch_size: int):
    """Draw a random batch of (input, next-token target) windows from *data*.

    Args:
        data: 1-D tensor of token ids.
        block_size: window length T.
        batch_size: number of windows B.

    Returns:
        ``(x, y)``, each of shape (B, T). For a random start ``s``,
        ``x[b] == data[s : s+T]`` and ``y[b] == data[s+1 : s+T+1]``.

    Raises:
        ValueError: if *data* has no more than *block_size* tokens, leaving no
            room for a window plus its one-token-shifted target.
    """
    # Valid starts are 0 .. len(data) - block_size - 1 inclusive; the last
    # one puts the target window's end exactly on the final token. The upper
    # bound of torch.randint is exclusive, so pass len(data) - block_size.
    # (The previous "- 1" made the final window unreachable.)
    n_starts = len(data) - block_size
    if n_starts < 1:
        raise ValueError(
            f"need more than block_size={block_size} tokens, got {len(data)}")
    starts = torch.randint(0, n_starts, (batch_size,))
    idx = starts.unsqueeze(1) + torch.arange(block_size)  # (B, T) gather indices
    return data[idx], data[idx + 1]


def train_one_genre(name: str, tokens: list[int], *,
                    steps: int = 1200, block: int = 128, batch: int = 32,
                    lr: float = 3e-3, log_every: int = 100) -> tuple[GPTMIDI, list[float]]:
    """Train a freshly initialised GPTMIDI on one genre's token stream.

    Seeds both torch and ``random`` with 0, then runs *steps* Adam updates.
    Each update uses a random batch of *batch* windows of length *block*, a
    cosine-annealed learning rate starting at *lr*, and gradient-norm
    clipping at 1.0. Target positions equal to PAD_ID are excluded from the
    cross-entropy loss. Progress is printed every *log_every* steps and on
    the final step.

    Returns:
        ``(model, losses)``: the trained model (still in train mode) and the
        training loss recorded at every step.
    """
    torch.manual_seed(0)
    random.seed(0)
    stream = torch.tensor(tokens, dtype=torch.long)
    model = GPTMIDI(VOCAB_SIZE)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps)
    history: list[float] = []
    started = time.time()
    final_step = steps - 1
    model.train()
    for step in range(steps):
        inputs, targets = sample_batch(stream, block, batch)
        logits = model(inputs)
        loss = F.cross_entropy(
            logits.reshape(-1, VOCAB_SIZE),
            targets.reshape(-1),
            ignore_index=PAD_ID,
        )
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        loss_value = loss.item()
        history.append(loss_value)
        should_log = step % log_every == 0 or step == final_step
        if should_log:
            print(f"  [{name}] step {step:4d}  loss={loss_value:.3f}  "
                  f"({time.time() - started:.1f}s)")
    print(f"  [{name}] done in {time.time() - started:.1f}s.")
    return model, history


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------

@torch.no_grad()
def generate(model: GPTMIDI, prompt_ids: list[int], *,
             max_new: int = 200, temperature: float = 0.9,
             top_k: int | None = 40, seed: int = 0) -> list[int]:
    """Autoregressively sample a continuation of *prompt_ids* from *model*.

    At each step:

    1. Feed the most recent ``model.max_len - 1`` tokens to the model.
    2. Divide the last position's logits by *temperature*.
    3. If *top_k* is given, restrict sampling to the *top_k* highest logits.
       It is capped at the vocabulary size, so a large value cannot make
       ``torch.topk`` fail.
    4. Draw one token with a CPU generator seeded by *seed*, so results are
       reproducible.

    Generation stops after *max_new* tokens, or right after EOS_ID is
    sampled. The model is switched to eval mode and left there.

    Returns:
        The prompt followed by the sampled tokens (including EOS if sampled).

    Raises:
        ValueError: if *prompt_ids* is empty, *temperature* is not positive,
            or *top_k* is less than 1.
    """
    if not prompt_ids:
        raise ValueError("prompt_ids must contain at least one token")
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be >= 1 or None, got {top_k}")

    model.eval()
    rng = torch.Generator(device="cpu").manual_seed(seed)
    out = list(prompt_ids)
    window = model.max_len - 1
    for _ in range(max_new):
        ctx = out[-window:]
        x = torch.tensor([ctx], dtype=torch.long)
        logits = model(x)[0, -1] / temperature
        if top_k is not None:
            k = min(top_k, logits.size(-1))
            topv, topi = torch.topk(logits, k)
            # Zero probability outside the top-k; renormalise inside it.
            probs = torch.zeros_like(logits)
            probs[topi] = F.softmax(topv, dim=-1)
        else:
            probs = F.softmax(logits, dim=-1)
        nxt = int(torch.multinomial(probs, 1, generator=rng).item())
        out.append(nxt)
        if nxt == EOS_ID:
            break
    return out


# ---------------------------------------------------------------------------
# Visualisation + quantitative comparison
# ---------------------------------------------------------------------------

def piano_roll(notes, ax, title: str, t_max: float = 30.0):
    """Draw *notes* as a piano roll on the matplotlib axis *ax*.

    Each note becomes a horizontal bar at its MIDI pitch, running from its
    start to its end (truncated at *t_max* seconds) and coloured by pitch
    with the viridis colormap. Only notes that start before *t_max* are
    drawn, and the y-range is fitted to exactly those notes. An empty
    *notes* list renders an "(empty)" placeholder instead.

    Args:
        notes: objects with ``start``/``end`` (seconds) and ``pitch``
            attributes, e.g. pretty_midi notes.
        ax: matplotlib Axes to draw on.
        title: axis title.
        t_max: right edge of the plotted time window, in seconds.
    """
    if not notes:
        ax.text(0.5, 0.5, "(empty)", ha="center", va="center", transform=ax.transAxes)
        ax.set_title(title, fontsize=10)
        return
    # One criterion for both drawing and y-limits. Previously, notes that
    # started exactly at t_max were drawn as zero-width bars but excluded
    # from the y-range.
    visible = [n for n in notes if n.start < t_max]
    for n in visible:
        ax.broken_barh([(n.start, min(n.end, t_max) - n.start)],
                       (n.pitch - 0.4, 0.8),
                       # pitches ~30..100 span the colormap
                       facecolors=plt.cm.viridis((n.pitch - 30) / 70), alpha=0.85)
    if visible:
        pitches = [n.pitch for n in visible]
        ax.set_ylim(min(pitches) - 3, max(pitches) + 3)
    ax.set_xlim(0, t_max)
    ax.set_xlabel("time (s)")
    ax.set_ylabel("MIDI pitch")
    ax.set_title(title, fontsize=10)
    ax.grid(alpha=0.25)


def stats_from_notes(notes) -> dict:
    """Summarise a generated piece numerically.

    Returns a dict with these keys:

    * ``n_notes``: the number of notes;
    * ``pitch_mean`` / ``pitch_std``: the population mean and standard
      deviation of pitch;
    * ``density``: notes per second over the piece's span (first onset to
      last release);
    * ``polyphony``: summed note durations divided by that span. This is a
      cheap proxy for how many notes sound at once; it can be below 1 when
      there are gaps.

    An empty list, or a zero-length span, gives 0 for the affected values.
    """
    if not notes:
        return dict(n_notes=0, pitch_mean=0, pitch_std=0, density=0, polyphony=0)
    pitch_values = [note.pitch for note in notes]
    sounding_time = sum(note.end - note.start for note in notes)
    first_onset = min(note.start for note in notes)
    last_release = max(note.end for note in notes)
    span = last_release - first_onset
    has_span = span > 0
    return {
        "n_notes": len(notes),
        "pitch_mean": float(np.mean(pitch_values)),
        "pitch_std": float(np.std(pitch_values)),
        "density": len(notes) / span if has_span else 0,
        "polyphony": sounding_time / span if has_span else 0,
    }


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _first_track_notes(pm) -> list:
    """Return the notes of *pm*'s first instrument, or [] if it has none.

    Indexing ``pm.instruments[0]`` directly would raise IndexError for a
    MIDI object without instrument tracks.
    """
    return pm.instruments[0].notes if pm.instruments else []


def _plot_comparison(results: dict, out_path: Path) -> None:
    """Save a grid of piano rolls, one panel per trained genre, to *out_path*.

    The grid has two columns and as many rows as needed, so every genre in
    *results* is shown. (A fixed 2x2 grid silently dropped all but the first
    four.) Each panel re-reads the genre's saved sample MIDI file.
    """
    import pretty_midi  # only needed here, to re-read the written samples

    n = len(results)
    ncols = 2
    nrows = math.ceil(n / ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(14, 4 * nrows), squeeze=False)
    axes = axes.flatten()
    for ax, (g, info) in zip(axes, results.items()):
        pm = pretty_midi.PrettyMIDI(info["midi_path"])
        piano_roll(_first_track_notes(pm), ax, f"{g}  —  {info['desc']}", t_max=30)
    for ax in axes[n:]:
        ax.axis("off")  # blank the spare cell when n is odd
    plt.suptitle("Generated samples — same architecture, different genre",
                 fontsize=13, fontweight="bold", y=1.01)
    plt.tight_layout()
    plt.savefig(out_path, dpi=110, bbox_inches="tight")
    plt.close(fig)


def main():
    """Command-line entry point: train, sample and visualise each genre.

    For every requested genre, this:

    1. loads (or synthesises) the genre's token stream;
    2. trains a fresh GPTMIDI on it;
    3. saves the weights to checkpoints/<genre>.pt;
    4. samples up to 400 new tokens after BOS;
    5. writes that sample as samples/<genre>_sample.mid plus a piano-roll
       PNG.

    Finally it writes samples/comparison.png, with one panel per trained
    genre, and prints a summary table.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--genres", default=",".join(GENRES.keys()),
                   help="Comma-separated subset of " + ",".join(GENRES.keys()))
    p.add_argument("--steps", type=int, default=1200)
    p.add_argument("--quick", action="store_true",
                   help="Quick sanity-check run (fewer pieces, fewer steps).")
    args = p.parse_args()

    if args.quick:
        args.steps = 200

    SAMPLES = HERE / "samples"
    CKPT = HERE / "checkpoints"
    SAMPLES.mkdir(exist_ok=True)
    CKPT.mkdir(exist_ok=True)

    genre_list = [g.strip() for g in args.genres.split(",") if g.strip()]
    print(f"Genres to train: {genre_list}")
    print(f"Steps per genre: {args.steps}")
    print(f"Vocab size: {VOCAB_SIZE}\n")

    max_pieces = 30 if args.quick else 80
    results = {}

    for g in genre_list:
        if g not in GENRES:
            print(f"  [skip] unknown genre {g!r}")
            continue
        _corpus_key, desc = GENRES[g]
        print(f"\n{'=' * 70}")
        print(f"  Genre: {g}  —  {desc}")
        print(f"{'=' * 70}")
        tokens = load_genre_tokens(g, max_pieces=max_pieces,
                                   cache_path=HERE / f".cache_{g}.json")
        if len(tokens) < 200:
            print(f"  [skip] only {len(tokens)} tokens, not enough to train")
            continue
        print(f"  Training GPTMIDI on {len(tokens):,} tokens...")
        model, losses = train_one_genre(g, tokens, steps=args.steps)

        # Save weights
        torch.save(model.state_dict(), CKPT / f"{g}.pt")

        # Generate and convert to MIDI
        sample_tokens = generate(model, [BOS_ID], max_new=400,
                                 temperature=0.9, top_k=40, seed=0)
        pm = detokenize(sample_tokens)
        midi_path = SAMPLES / f"{g}_sample.mid"
        pm.write(str(midi_path))
        notes = _first_track_notes(pm)
        print(f"  Generated -> {midi_path.name}: {len(notes)} notes, "
              f"{pm.get_end_time():.1f}s")

        # Piano-roll preview
        fig, ax = plt.subplots(figsize=(9, 4))
        piano_roll(notes, ax, f"{g} — generated sample", t_max=30)
        plt.tight_layout()
        plt.savefig(SAMPLES / f"{g}_pianoroll.png", dpi=110, bbox_inches="tight")
        plt.close(fig)

        results[g] = dict(
            desc=desc,
            train_tokens=len(tokens),
            # --steps 0 yields no losses; report NaN rather than crash.
            final_loss=losses[-1] if losses else float("nan"),
            n_notes=len(notes),
            stats=stats_from_notes(notes),
            midi_path=str(midi_path),
        )

    # Comparison grid across all trained genres
    if not results:
        print("\nNo results to compare.")
        return
    comparison_path = SAMPLES / "comparison.png"
    _plot_comparison(results, comparison_path)
    print(f"\nWrote comparison -> {comparison_path}")

    # Print summary table
    print(f"\n{'genre':<14s} {'tokens':>8s} {'final loss':>10s} {'notes':>6s} {'p̄':>6s} {'σ_p':>6s} {'density':>8s} {'poly':>6s}")
    for g, info in results.items():
        s = info["stats"]
        print(f"{g:<14s} {info['train_tokens']:>8,d} {info['final_loss']:>10.3f} "
              f"{s['n_notes']:>6d} {s['pitch_mean']:>6.1f} {s['pitch_std']:>6.1f} "
              f"{s['density']:>8.2f} {s['polyphony']:>6.2f}")


if __name__ == "__main__":
    main()
