"""Event-based MIDI tokenizer for the genre-comparison experiment.

Each music21 score is rendered to a sequence of events of three kinds:
    TIME_SHIFT_<d>   d in 16th-note units, d ∈ {1, 2, 3, ..., 32}
    NOTE_ON_<p>      MIDI pitch p ∈ {0, ..., 127}
    NOTE_OFF_<p>     MIDI pitch p ∈ {0, ..., 127}

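Special tokens occupy indices 0-2; the event families follow back to back:
      0        [PAD]
      1        [BOS]   start-of-piece
      2        [EOS]   end-of-piece
      3-34     TIME_SHIFT_1 .. TIME_SHIFT_32
     35-162    NOTE_ON_0 .. NOTE_ON_127
    163-290    NOTE_OFF_0 .. NOTE_OFF_127

Total vocab: 3 specials + 32 time-shifts + 128 note-ons + 128 note-offs = 291.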
"""
from __future__ import annotations

import io
import warnings
from collections import Counter

from music21 import converter, midi as m21midi
import pretty_midi

# Suppress noisy music21 / pretty_midi warnings during bulk processing.
# NOTE: no category or module is given, so this filter is process-wide: once
# this module is imported, every warning from every library is silenced.
warnings.filterwarnings("ignore")


# Reserved control tokens sit at the bottom of the vocabulary.
PAD_ID, BOS_ID, EOS_ID = 0, 1, 2
N_SPECIAL = 3

# Sizes of the event families.
N_TIME = 32      # TIME_SHIFT_1 .. TIME_SHIFT_32, i.e. 1 .. 32 sixteenths (up to 2 whole notes)
N_PITCH = 128    # one token per MIDI pitch 0..127

# First id of each family; the families are laid out back to back after the
# specials: time-shifts, then note-ons, then note-offs.
TIME_BASE = N_SPECIAL                            # 3
NOTEON_BASE = N_SPECIAL + N_TIME                 # 35
NOTEOFF_BASE = N_SPECIAL + N_TIME + N_PITCH      # 163
VOCAB_SIZE = N_SPECIAL + N_TIME + 2 * N_PITCH    # 291


def token_str(tid: int) -> str:
    """Return a short printable label for a token id (debugging aid).

    Special tokens map to "[PAD]", "[BOS]" and "[EOS]". Time-shift ids map
    to "TIME_<d>" with d counted from 1, note-on ids to "ON_<pitch>" and
    note-off ids to "OFF_<pitch>". Any id outside the vocabulary is rendered
    as "?<id>".
    """
    # Equality (not a dict lookup) so that any int-like id compares the same
    # way the range checks below do.
    for special_id, label in ((PAD_ID, "[PAD]"), (BOS_ID, "[BOS]"), (EOS_ID, "[EOS]")):
        if tid == special_id:
            return label

    if TIME_BASE <= tid < NOTEON_BASE:
        label = f"TIME_{tid - TIME_BASE + 1}"
    elif NOTEON_BASE <= tid < NOTEOFF_BASE:
        label = f"ON_{tid - NOTEON_BASE}"
    elif NOTEOFF_BASE <= tid < VOCAB_SIZE:
        label = f"OFF_{tid - NOTEOFF_BASE}"
    else:
        label = f"?{tid}"
    return label


def tokenize_score(score, sixteenths_per_quarter: int = 4, max_shift: int = N_TIME) -> list[int]:
    """Convert a music21 Score (or a path to a score/MIDI file) into a token list.

    Strategy:
      1. Write the score to a temporary MIDI file on disk via music21.
      2. Parse that file with pretty_midi to get (start, end, pitch) note
         events across all non-drum instruments.
      3. Quantize start/end times to grid steps (16th notes by default).
         Quantization goes through MIDI ticks, so tempo changes inside the
         piece do not skew the grid.
      4. Build the sorted event timeline: at each timestamp, emit any
         NOTE_OFFs that fire there, then any NOTE_ONs.
      5. Emit TIME_SHIFT tokens between successive timestamps.

    Args:
        score: A music21 Score, or a ``str`` / ``os.PathLike`` that is handed
            to ``music21.converter.parse``.
        sixteenths_per_quarter: Grid steps per quarter note (4 = 16th notes).
        max_shift: Largest single TIME_SHIFT to emit, in grid steps. Longer
            gaps are split into several TIME_SHIFT tokens. Must lie in
            ``1..N_TIME`` so every emitted id stays inside the TIME_SHIFT range.

    Returns:
        Token ids framed by BOS_ID and EOS_ID. Silence before the first note
        is dropped, notes shorter than one grid step are stretched to one
        step, and a piece without pitched notes yields ``[BOS_ID, EOS_ID]``.

    Raises:
        ValueError: If ``sixteenths_per_quarter`` is not positive or
            ``max_shift`` is outside ``1..N_TIME``.
    """
    import os
    import tempfile

    if sixteenths_per_quarter <= 0:
        raise ValueError(
            f"sixteenths_per_quarter must be positive, got {sixteenths_per_quarter!r}"
        )
    # A shift above N_TIME would produce ids inside the NOTE_ON range, and a
    # shift below 1 would never make progress in the splitting loop below.
    if not 1 <= max_shift <= N_TIME:
        raise ValueError(f"max_shift must be in 1..{N_TIME}, got {max_shift!r}")

    if isinstance(score, (str, os.PathLike)):
        score = converter.parse(os.fspath(score))

    # Write music21 -> temp MIDI file -> read with pretty_midi
    # (music21's MidiFile.close() destroys in-memory buffers, so disk is simpler).
    # The file is created and closed first so it can be reopened by path.
    with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        score.write("midi", fp=tmp_path)
        pm = pretty_midi.PrettyMIDI(tmp_path)
    finally:
        # Best-effort cleanup: a leftover temp file must not mask the real error.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

    # pm.resolution is ticks per quarter note; ticks are tempo-independent,
    # unlike seconds, so the grid stays aligned across tempo changes.
    ticks_per_step = pm.resolution / sixteenths_per_quarter

    def to_step(seconds) -> int:
        return int(round(pm.time_to_tick(seconds) / ticks_per_step))

    kind_on, kind_off = 0, 1
    events = []  # list of (time_in_steps, kind, pitch)
    for inst in pm.instruments:
        if inst.is_drum:
            continue
        for note in inst.notes:
            t_on = to_step(note.start)
            # Guarantee at least one grid step of duration.
            t_off = max(t_on + 1, to_step(note.end))
            pitch = int(note.pitch)
            events.append((t_on, kind_on, pitch))
            events.append((t_off, kind_off, pitch))

    if not events:
        return [BOS_ID, EOS_ID]

    # Sort by time, then OFF before ON at the same time (let notes release
    # first). The sort is stable, so ties keep their order of discovery.
    events.sort(key=lambda e: (e[0], -e[1]))

    out = [BOS_ID]
    # Initial silence is ignored: the first event defines time 0 in token space.
    cur_t = events[0][0]
    for t, kind, pitch in events:
        dt = t - cur_t
        # Split long gaps into chunks of at most max_shift steps.
        while dt > 0:
            shift = min(dt, max_shift)
            out.append(TIME_BASE + (shift - 1))
            dt -= shift
        base = NOTEON_BASE if kind == kind_on else NOTEOFF_BASE
        out.append(base + pitch)
        cur_t = t
    out.append(EOS_ID)
    return out


def detokenize(tokens: list[int], sixteenths_per_quarter: int = 4, bpm: float = 120.0,
               instrument_program: int = 0) -> pretty_midi.PrettyMIDI:
    """Inverse of tokenize_score: build a pretty_midi object from a token list.

    All notes go to a single instrument with a fixed velocity of 80.

    Decoding rules:
      * PAD/BOS/EOS and ids outside the vocabulary are skipped; EOS does not
        stop decoding.
      * A NOTE_ON for a pitch that is already sounding ends the earlier note
        at the current time and starts a new one.
      * A NOTE_OFF for a pitch that is not sounding is ignored.
      * Notes that would have zero length are dropped.
      * Notes that get a NOTE_ON without a matching NOTE_OFF are released
        4 grid steps after the final time position of the sequence; their
        duration is not otherwise capped.

    Args:
        tokens: Token ids as produced by tokenize_score.
        sixteenths_per_quarter: Grid steps per quarter note (4 = 16th notes).
        bpm: Tempo, in quarter notes per minute, used to convert steps to seconds.
        instrument_program: MIDI program number of the output instrument.

    Returns:
        A pretty_midi.PrettyMIDI holding one instrument with the decoded notes.

    Raises:
        ValueError: If ``bpm`` or ``sixteenths_per_quarter`` is not positive.
    """
    if bpm <= 0:
        raise ValueError(f"bpm must be positive, got {bpm!r}")
    if sixteenths_per_quarter <= 0:
        raise ValueError(
            f"sixteenths_per_quarter must be positive, got {sixteenths_per_quarter!r}"
        )

    hanging_release_steps = 4  # tail given to notes that never receive a NOTE_OFF
    sec_per_sixteenth = (60.0 / bpm) / sixteenths_per_quarter

    pm = pretty_midi.PrettyMIDI(initial_tempo=bpm)
    inst = pretty_midi.Instrument(program=instrument_program)
    pm.instruments.append(inst)

    def add_note(pitch: int, start: float, end: float) -> None:
        # Zero-length notes are silently dropped.
        if end > start:
            inst.notes.append(pretty_midi.Note(velocity=80, pitch=pitch, start=start, end=end))

    cur_t_sx = 0                       # current time in sixteenths
    open_notes: dict[int, float] = {}  # pitch -> start time in seconds

    for tid in tokens:
        if tid in (PAD_ID, BOS_ID, EOS_ID):
            continue
        if TIME_BASE <= tid < NOTEON_BASE:
            cur_t_sx += (tid - TIME_BASE + 1)
        elif NOTEON_BASE <= tid < NOTEOFF_BASE:
            pitch = tid - NOTEON_BASE
            now = cur_t_sx * sec_per_sixteenth
            if pitch in open_notes:
                # Re-trigger: close the previous one
                add_note(pitch, open_notes[pitch], now)
            open_notes[pitch] = now
        elif NOTEOFF_BASE <= tid < VOCAB_SIZE:
            pitch = tid - NOTEOFF_BASE
            if pitch in open_notes:
                add_note(pitch, open_notes.pop(pitch), cur_t_sx * sec_per_sixteenth)

    # Close any still-open notes shortly after the end of the sequence.
    end_t = (cur_t_sx + hanging_release_steps) * sec_per_sixteenth
    for pitch, start in open_notes.items():
        add_note(pitch, start, end_t)

    return pm
