"""ch41_complete.py — Chapter 41 (Pretraining: BERT and GPT) in one file.

The full code for Chapter 41 of "Classical Foundations of Artificial Neural
Networks" (https://bnaskrecki.faculty.wmi.amu.edu.pl/nnets/). This is the
single-file consolidation of:

    - the building blocks normally imported from `utils.py`
      (CharTokenizer, TransformerBlock, MultiHeadAttention, FeedForward,
      Config, causal_mask, sinusoidal_positional_encoding, count_params,
      load_shakespeare, sample_batch)
    - the chapter-specific classes (GPTLike, BertLike) and training loops
      (train_lm for CLM, train_mlm for MLM, make_mlm_batch for 80-10-10
      corruption, held_out_mlm_accuracy for §41.5's measurement, LinearProbe
      for the fine-tuning experiment).

Run as:
    python ch41_complete.py

Requires:
    pip install torch numpy matplotlib

External data:
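    shakespeare.txt — the Tiny Shakespeare corpus (~1.1 MB). The script will
    auto-download it from karpathy/char-rnn if it is not found in the current
    directory. Only the first 80,000 characters are used for pretraining; the
    remainder serves as held-out text for §41.5.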

Total runtime on a modern CPU: ~100 seconds.
"""

from __future__ import annotations

import math
import os
import random
import time
import urllib.request
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# =============================================================================
# § corpus + tokenizer (inlined from utils.py)
# =============================================================================

# Raw-file URL of the Tiny Shakespeare corpus in karpathy/char-rnn.
SHAKESPEARE_URL = (
    "https://raw.githubusercontent.com/karpathy/char-rnn/master/"
    "data/tinyshakespeare/input.txt"
)


def load_shakespeare(max_chars: int | None = 80_000, path: str = "shakespeare.txt") -> str:
    """Return the Tiny Shakespeare text, downloading it first if needed.

    Args:
        max_chars: If not None, truncate the returned text to its first
            ``max_chars`` characters. ``None`` returns the whole file.
        path: Local file to read. If it does not exist, the corpus is
            downloaded from ``SHAKESPEARE_URL`` and saved there.

    Returns:
        The (possibly truncated) corpus as a ``str``.

    The download is written to a temporary ``<path>.part`` file and only
    renamed onto ``path`` once it has completed. An interrupted download
    therefore never leaves a truncated ``path`` that later runs would
    silently reuse.
    """
    if not os.path.exists(path):
        print(f"Downloading Tiny Shakespeare → {path} ...")
        tmp_path = f"{path}.part"
        try:
            with urllib.request.urlopen(SHAKESPEARE_URL, timeout=60) as resp:
                payload = resp.read()
            with open(tmp_path, "wb") as out:
                out.write(payload)
            os.replace(tmp_path, path)          # atomic rename into place
        finally:
            # Only left behind if the download or write failed part-way.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    if max_chars is not None:
        text = text[:max_chars]
    return text


class CharTokenizer:
    """Character-level tokenizer with id 0 reserved for a [MASK] symbol.

    The vocabulary is the sorted set of characters in the constructor's text,
    shifted up by one so that index 0 stays free for the mask sentinel.
    """

    MASK_ID = 0
    MASK_STR = "\x00"  # sentinel for [MASK]; assumed never to occur in real text

    def __init__(self, text: str) -> None:
        self.itos = [self.MASK_STR, *sorted(set(text))]
        self.stoi = dict(zip(self.itos, range(len(self.itos))))
        self.vocab_size = len(self.itos)

    def encode(self, text: str) -> torch.Tensor:
        """Map each character to its id (KeyError for characters not in the vocab)."""
        lookup = self.stoi
        ids = [lookup[ch] for ch in text]
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, ids) -> str:
        """Map ids (tensor or iterable of ints) back to text; [MASK] renders as "_"."""
        seq = ids.tolist() if torch.is_tensor(ids) else ids
        pieces = ["_" if idx == self.MASK_ID else self.itos[idx] for idx in seq]
        return "".join(pieces)


# =============================================================================
# § Transformer building blocks (inlined from utils.py — same as Ch 40)
# =============================================================================

def causal_mask(T: int, device=None) -> torch.Tensor:
    """Build a (1, 1, T, T) lower-triangular attention mask.

    Entry [..., i, j] is 1.0 when query position i may attend to key j
    (j <= i) and 0.0 otherwise. The two leading singleton axes broadcast
    over batch and heads.
    """
    allowed = torch.ones(T, T, device=device).tril()
    return allowed[None, None, :, :]


def sinusoidal_positional_encoding(T: int, d_model: int) -> torch.Tensor:
    """Return the fixed sinusoidal position table (Vaswani et al., 2017).

    Column 2i holds sin(pos * 10000^(-2i/d_model)) and column 2i+1 holds the
    matching cosine. Odd ``d_model`` is supported: the last (even-indexed)
    column receives a sine and has no cosine partner.

    Returns:
        A float32 tensor of shape (T, d_model).
    """
    pe = torch.zeros(T, d_model)
    position = torch.arange(0, T, dtype=torch.float).unsqueeze(1)      # (T, 1)
    div = torch.exp(torch.arange(0, d_model, 2).float() *
                    -(math.log(10000.0) / d_model))                    # (ceil(d/2),)
    angles = position * div                                            # (T, ceil(d/2))
    pe[:, 0::2] = torch.sin(angles)
    # There are only floor(d/2) odd columns, so with odd d_model the last
    # frequency has no cosine slot; slicing keeps the shapes aligned.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    return pe                                              # (T, d_model)


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute softmax(Q Kᵀ / sqrt(d_k)) V with an optional 0/1 mask.

    Positions where ``mask == 0`` are set to -inf before the softmax and so
    get zero weight. ``mask`` must broadcast against the (..., T_q, T_k)
    score matrix. Returns ``(output, attention_weights)``.
    """
    scale = math.sqrt(Q.size(-1))
    logits = torch.matmul(Q, K.transpose(-2, -1)) / scale
    if mask is not None:
        blocked = mask == 0
        logits = logits.masked_fill(blocked, float("-inf"))
    weights = logits.softmax(dim=-1)
    return torch.matmul(weights, V), weights


class MultiHeadAttention(nn.Module):
    """Multi-head attention with separate Q/K/V and output projections.

    ``d_model`` is split evenly into ``n_heads`` heads of width
    ``d_k = d_model // n_heads``. Query, key and value inputs are passed
    separately, so the same module serves self- and cross-attention.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        # Creation order (q, k, v, o) fixes seeded initialization and the
        # state_dict key order.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split(self, x):
        """Reshape (B, T, d_model) into (B, n_heads, T, d_k)."""
        batch, length, _ = x.shape
        heads = x.view(batch, length, self.n_heads, self.d_k)
        return heads.transpose(1, 2)

    def forward(self, Qx, Kx, Vx, mask=None, return_attn=False):
        """Attend from ``Qx`` to ``Kx``/``Vx``.

        Returns the projected output of shape (B, T_q, d_model). When
        ``return_attn`` is True, also returns the per-head attention weights.
        """
        q = self._split(self.W_q(Qx))
        k = self._split(self.W_k(Kx))
        v = self._split(self.W_v(Vx))
        ctx, weights = scaled_dot_product_attention(q, k, v, mask)
        batch, length = Qx.size(0), Qx.size(1)
        merged = ctx.transpose(1, 2).contiguous().view(batch, length, self.d_model)
        projected = self.W_o(merged)
        if return_attn:
            return projected, weights
        return projected


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model→d_ff) → ReLU → Linear(d_ff→d_model)."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        # Layers are created in forward order and kept inside one Sequential
        # so seeded init and the net.0 / net.2 state_dict keys are stable.
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        contract = nn.Linear(d_ff, d_model)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Apply the MLP independently at every position of ``x``."""
        out = self.net(x)
        return out


class TransformerBlock(nn.Module):
    """Pre-LN Transformer block without cross-attention.

    x -> x + Dropout(SelfAttn(LN1(x), mask)) -> x + Dropout(FFN(LN2(x))).
    Used unmodified for both GPT-like (with causal mask) and BERT-like
    (with mask=None) models.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the block; ``mask`` (or None) is forwarded to self-attention."""
        # Self-attention uses the same normalized tensor for Q, K and V, so
        # normalize once instead of once per argument.
        h = self.ln1(x)
        a = self.attn(h, h, h, mask)
        x = x + self.drop(a)
        f = self.ff(self.ln2(x))
        x = x + self.drop(f)
        return x


@dataclass
class Config:
    """Model hyperparameters shared by the GPT-like and BERT-like models.

    The defaults describe a tiny 2-layer model meant to keep both
    pretraining runs under ~60 s on a CPU.
    """

    vocab_size: int = 62   # number of token ids, [MASK] included
    d_model: int = 96      # hidden width; must be divisible by n_heads
    n_heads: int = 4       # attention heads per layer
    d_ff: int = 256        # inner width of the position-wise MLP
    n_layers: int = 2      # number of stacked Transformer blocks
    max_len: int = 64      # rows in the positional-encoding table
    dropout: float = 0.1   # dropout probability inside each block


def count_params(module: nn.Module) -> int:
    """Return the number of trainable (``requires_grad``) scalars in ``module``."""
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def sample_batch(data: torch.Tensor, block_size: int, batch_size: int, device="cpu",
                 generator: torch.Generator | None = None):
    """Sample random contiguous windows for next-token (causal) LM training.

    Args:
        data: 1-D tensor of token ids.
        block_size: window length T.
        batch_size: number of windows B.
        device: device the returned tensors are moved to.
        generator: optional ``torch.Generator`` for reproducible sampling.
            ``None`` uses PyTorch's global RNG, as before.

    Returns:
        ``(x, y)``, both (B, T) tensors with ``data``'s dtype. ``y`` is ``x``
        shifted one position to the right in ``data`` (the next-token
        targets).

    Raises:
        ValueError: if ``data`` has fewer than ``block_size + 2`` tokens.
    """
    # Start offsets are drawn from [0, len - block_size - 1). This bound is
    # kept as-is so that seeded runs reproduce the same batches.
    n_starts = len(data) - block_size - 1
    if n_starts < 1:
        raise ValueError(
            f"data has {len(data)} tokens; need at least {block_size + 2} "
            f"to sample windows of length {block_size}")
    ix = torch.randint(0, n_starts, (batch_size,), generator=generator)
    # Gather all windows in one indexing op: rows[b, t] = ix[b] + t.
    rows = (ix.unsqueeze(1) + torch.arange(block_size)).to(data.device)
    x = data[rows].to(device)
    y = data[rows + 1].to(device)
    return x, y


# =============================================================================
# § 41.2 — GPT-style decoder-only Transformer
# =============================================================================

class GPTLike(nn.Module):
    """Decoder-only (GPT-style) Transformer language model.

    Token embeddings plus fixed sinusoidal positions feed a stack of the Ch 40
    TransformerBlock. Every block receives a causal mask, so position t only
    attends to positions <= t. A final LayerNorm and a bias-free linear head
    produce next-token logits.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Submodules are registered in this order so seeded initialization and
        # the state_dict layout stay stable.
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        table = sinusoidal_positional_encoding(cfg.max_len, cfg.d_model)
        # Fixed table: a buffer (follows .to()) that is not saved in state_dict.
        self.register_buffer("pos_emb", table, persistent=False)
        layers = [TransformerBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.dropout)
                  for _ in range(cfg.n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def forward(self, x):
        """Map (B, T) token ids to (B, T, vocab_size) next-token logits."""
        _, seq_len = x.shape
        assert seq_len <= self.cfg.max_len
        h = self.tok_emb(x) + self.pos_emb[:seq_len][None]
        mask = causal_mask(seq_len, x.device)
        for layer in self.blocks:
            h = layer(h, mask)
        logits = self.lm_head(self.ln_f(h))
        return logits


def train_lm(model: GPTLike, data: torch.Tensor, tok: CharTokenizer,
             *, steps=600, block=48, batch=64, lr=3e-3, log_every=50,
             device="cpu") -> list[float]:
    """Train ``model`` with the causal (next-token) LM objective (§41.2).

    Uses Adam with a cosine-annealed learning rate over ``steps`` updates and
    gradient-norm clipping at 1.0. Each batch is a set of random windows from
    ``data`` drawn by ``sample_batch``. Progress is printed every
    ``log_every`` steps and on the last step. Returns the per-step losses.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps)
    vocab = model.cfg.vocab_size
    history: list[float] = []
    start = time.time()
    model.train()
    for step in range(steps):
        inputs, targets = sample_batch(data, block, batch, device)
        logits = model(inputs)
        # The [MASK] id is presumably absent from plain-text targets, so this
        # ignore_index acts only as a safety net.
        loss = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1),
                               ignore_index=tok.MASK_ID)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        value = loss.item()
        history.append(value)
        if step % log_every == 0 or step == steps - 1:
            print(f"  step {step:4d}  loss={value:.4f}  "
                  f"({time.time() - start:.1f}s elapsed)")
    print(f"Done in {time.time() - start:.1f}s.")
    return history


@torch.no_grad()
def generate(model: GPTLike, prompt_ids: torch.Tensor, *,
             max_new=200, temperature=1.0, top_k: int | None = None,
             device="cpu", rng=None) -> torch.Tensor:
    """Autoregressively extend ``prompt_ids`` by ``max_new`` tokens (§41.2).

    Args:
        model: causal LM returning (B, T, V) logits. It is switched to eval
            mode and left there.
        prompt_ids: non-empty 1-D tensor of token ids.
        max_new: number of tokens to append.
        temperature: logits are divided by this before sampling. A value
            ``<= 0`` selects greedy decoding (argmax), which avoids a
            division by zero.
        top_k: if given, sample only among the ``top_k`` most likely tokens.
            Values larger than the vocabulary are clamped to its size.
        device: device holding the running sequence.
        rng: ``torch.Generator`` used for sampling. Defaults to a fresh CPU
            generator seeded with 0, so repeated calls are reproducible.

    Returns:
        1-D tensor: the prompt followed by ``max_new`` generated ids.

    Raises:
        ValueError: if ``prompt_ids`` is empty.
    """
    if prompt_ids.numel() == 0:
        raise ValueError("generate() needs a non-empty prompt")
    model.eval()
    rng = rng if rng is not None else torch.Generator(device="cpu").manual_seed(0)
    out = prompt_ids.clone().to(device)
    window = model.cfg.max_len - 1          # context kept strictly below max_len
    for _ in range(max_new):
        ctx = out[-window:]
        logits = model(ctx.unsqueeze(0))[0, -1]
        if temperature <= 0:
            nxt = int(logits.argmax())
        else:
            logits = logits / temperature
            if top_k is not None:
                k = min(top_k, logits.size(-1))
                topv, topi = torch.topk(logits, k)
                probs = torch.zeros_like(logits)
                probs[topi] = F.softmax(topv, dim=-1)
            else:
                probs = F.softmax(logits, dim=-1)
            nxt = torch.multinomial(probs, 1, generator=rng).item()
        out = torch.cat([out, torch.tensor([nxt], device=device)])
    return out


# =============================================================================
# § 41.3 — BERT-style encoder-only Transformer + MLM training
# =============================================================================

def make_mlm_batch(data: torch.Tensor, block: int, batch: int,
                   vocab_size: int, mask_token_id: int,
                   mask_prob: float = 0.15, device="cpu", rng=None):
    """Sample an MLM batch with the 80-10-10 corruption recipe (Devlin et al. 2019).

    Each position is selected for prediction with probability ``mask_prob``.
    Among the selected positions, about 80 % are replaced by
    ``mask_token_id``, about 10 % by a uniformly random id in
    ``[1, vocab_size)``, and about 10 % are left unchanged.

    Args:
        data: 1-D tensor of token ids.
        block: window length T.
        batch: number of windows B.
        vocab_size: vocabulary size (random replacements never use id 0).
        mask_token_id: id written at "[MASK]" positions.
        mask_prob: per-position selection probability.
        device: device for the returned tensors.
        rng: CPU ``torch.Generator``. If ``None``, a fresh generator seeded
            with 0 is used, so every such call returns the *same* batch.
            Pass one persistent generator when drawing many batches.

    Returns:
        ``(x_corrupted, targets)``, both (B, T). ``targets`` holds the
        original id at selected positions and -100 elsewhere (the
        ``F.cross_entropy`` default ``ignore_index``).

    Raises:
        ValueError: if ``data`` has fewer than ``block + 2`` tokens.
    """
    g = rng if rng is not None else torch.Generator(device="cpu").manual_seed(0)
    n_starts = len(data) - block - 1
    if n_starts < 1:
        raise ValueError(
            f"data has {len(data)} tokens; need at least {block + 2} "
            f"to sample windows of length {block}")
    # All random draws happen on the CPU generator and are then moved to
    # `device`: a CPU generator cannot sample directly on a CUDA device.
    # The draw order (starts, selection, kind, replacement ids) is fixed so
    # seeded batches are reproducible.
    ix = torch.randint(0, n_starts, (batch,), generator=g)
    x = torch.stack([data[i:i + block] for i in ix]).to(device)
    selected = (torch.rand(x.shape, generator=g) < mask_prob).to(device)
    kind = torch.rand(x.shape, generator=g).to(device)
    is_mask = selected & (kind < 0.80)
    is_random = selected & (kind >= 0.80) & (kind < 0.90)
    rand_tokens = torch.randint(1, vocab_size, x.shape, generator=g).to(device)

    x_corrupted = x.clone()
    x_corrupted[is_mask] = mask_token_id
    x_corrupted[is_random] = rand_tokens[is_random]
    targets = torch.full_like(x, fill_value=-100)
    targets[selected] = x[selected]
    return x_corrupted, targets


class BertLike(nn.Module):
    """Encoder-only Transformer: identical stack, no causal mask, MLM head on top.

    ``encode`` returns the final normalized hidden states (every position
    attends to every other). ``forward`` adds the masked-LM vocabulary head.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        # Registration order is kept stable so seeded init matches.
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        # Fixed sinusoidal table: a buffer that follows .to() but is not saved.
        self.register_buffer(
            "pos_emb",
            sinusoidal_positional_encoding(cfg.max_len, cfg.d_model),
            persistent=False,
        )
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg.d_model, cfg.n_heads, cfg.d_ff, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.mlm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=True)

    def encode(self, x):
        """Contextual hidden states — used by fine-tuning and §41.5 probing.

        Args:
            x: (B, T) token ids.

        Returns:
            (B, T, d_model) hidden states.

        Raises:
            ValueError: if T exceeds ``cfg.max_len`` (the size of the
                positional table), instead of an opaque broadcast error.
        """
        _, T = x.shape
        if T > self.cfg.max_len:
            raise ValueError(
                f"sequence length {T} exceeds cfg.max_len={self.cfg.max_len}")
        h = self.tok_emb(x) + self.pos_emb[:T].unsqueeze(0)
        for block in self.blocks:
            h = block(h, mask=None)                # NO causal mask — bidirectional
        return self.ln_f(h)                         # (B, T, d_model)

    def forward(self, x):
        """Return MLM logits of shape (B, T, vocab_size)."""
        return self.mlm_head(self.encode(x))        # (B, T, V)


def train_mlm(model: BertLike, data: torch.Tensor, tok: CharTokenizer,
              *, steps=1500, block=48, batch=64, lr=3e-3, log_every=80,
              device="cpu") -> list[float]:
    """Train ``model`` with the masked-LM objective (§41.3).

    Each step draws a fresh 80-10-10 corrupted batch via ``make_mlm_batch``,
    using a dedicated CPU generator seeded with 0. The batch sequence is
    therefore independent of the global seed. The loss is cross-entropy over
    selected positions only (targets equal to -100 are ignored). Uses Adam,
    cosine LR annealing over ``steps`` and gradient-norm clipping at 1.0.
    Returns the per-step losses.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=steps)
    vocab = model.cfg.vocab_size
    history: list[float] = []
    start = time.time()
    batch_rng = torch.Generator(device="cpu").manual_seed(0)
    model.train()
    for step in range(steps):
        corrupted, targets = make_mlm_batch(data, block, batch, vocab, tok.MASK_ID,
                                            device=device, rng=batch_rng)
        logits = model(corrupted)
        loss = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1),
                               ignore_index=-100)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        value = loss.item()
        history.append(value)
        if step % log_every == 0 or step == steps - 1:
            print(f"  step {step:4d}  loss={value:.4f}  "
                  f"({time.time() - start:.1f}s elapsed)")
    print(f"Done in {time.time() - start:.1f}s.")
    return history


# =============================================================================
# § 41.5 — Held-out MLM accuracy (the cleanest measure of pretraining quality)
# =============================================================================

@torch.no_grad()
def held_out_mlm_accuracy(model: BertLike, corpus: torch.Tensor, tok: CharTokenizer,
                          *, n_batches=80, block=48, batch=64, mask_prob=0.15,
                          device="cpu") -> tuple[float, float]:
    """Top-1 and top-5 mask-fill accuracy on a held-out corpus.

    Draws ``n_batches`` batches with the same 80-10-10 corruption used in
    training. A fixed generator seeded with 0 makes the result reproducible.
    Every selected position is scored, including those left unchanged or
    randomly replaced. Leaves ``model`` in eval mode.

    Returns:
        ``(top1, top5)`` as fractions in [0, 1]. If the vocabulary has fewer
        than 5 tokens, "top-5" covers the whole vocabulary.

    Raises:
        ValueError: if no position was scored (e.g. ``n_batches < 1`` or
            ``mask_prob == 0``), instead of dividing by zero.
    """
    model.eval()
    rng = torch.Generator(device="cpu").manual_seed(0)
    correct1 = correct5 = total = 0
    for _ in range(n_batches):
        x, y = make_mlm_batch(corpus, block, batch,
                              model.cfg.vocab_size, tok.MASK_ID,
                              mask_prob=mask_prob, device=device, rng=rng)
        logits = model(x)
        scored = y != -100
        pred = logits.argmax(-1)
        correct1 += ((pred == y) & scored).sum().item()
        k = min(5, logits.size(-1))           # topk fails if k > vocab size
        topk_ids = logits.topk(k, dim=-1).indices
        hit = (topk_ids == y.unsqueeze(-1)).any(dim=-1) & scored
        correct5 += hit.sum().item()
        total += scored.sum().item()
    if total == 0:
        raise ValueError("no masked positions were scored; "
                         "check n_batches and mask_prob")
    return correct1 / total, correct5 / total


# =============================================================================
# § 41.5 — Linear probe (frozen encoder + Linear head, sentiment task sketch)
# =============================================================================

class LinearProbe(nn.Module):
    """Frozen encoder + Linear classification head. The cleanest diagnostic
    for 'are the pretrained features linearly separable for this task?'

    Only ``self.cls`` is trainable. Hidden states from ``encoder.encode`` are
    mean-pooled over positions whose id differs from ``mask_id``, then
    classified.
    """

    def __init__(self, encoder: BertLike, d_model: int, n_classes: int = 2,
                 mask_id: int = 0):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()
        self.cls = nn.Linear(d_model, n_classes)
        self.mask_id = mask_id

    def train(self, mode: bool = True):
        """Set train/eval mode on the probe while keeping the encoder in eval.

        ``nn.Module.train`` recurses into children. Without this override,
        ``probe.train()`` would switch the frozen encoder's dropout back on
        and make its "frozen" features stochastic.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x):
        """Return (B, n_classes) logits for a (B, T) batch of token ids."""
        with torch.no_grad():
            h = self.encoder.encode(x)
        valid = (x != self.mask_id).float().unsqueeze(-1)      # (B, T, 1)
        # Mean over non-mask positions; clamp avoids 0/0 for all-mask rows.
        pooled = (h * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1)
        return self.cls(pooled)


# =============================================================================
# Main — run the full Ch 41 pipeline end-to-end
# =============================================================================

def main():
    """Run the full Ch 41 pipeline end-to-end.

    1. Trains a GPT-like model with causal LM (§41.2) and prints a sample.
    2. Trains a BERT-like model with masked LM (§41.3).
    3. Compares held-out mask-fill accuracy of the pretrained encoder with a
       randomly initialized one (§41.5).
    """
    device = torch.device("cpu")
    torch.manual_seed(0); random.seed(0)

    # --- Load corpus + tokenizer. The file is read once: the first 80k chars
    # are the pretraining corpus, everything after them is held out (§41.5).
    n_train_chars = 80_000
    full_text = load_shakespeare(max_chars=None)
    text = full_text[:n_train_chars]
    tok = CharTokenizer(text)
    data = tok.encode(text)
    print(f"Corpus      : {len(text):,} chars, {tok.vocab_size} tokens (incl. [MASK])")

    cfg = Config(vocab_size=tok.vocab_size, d_model=96, n_heads=4, d_ff=256,
                 n_layers=2, max_len=64, dropout=0.1)

    # ---------------- §41.2 — train GPT-like ----------------
    print("\n=== §41.2 — Training GPT-like (decoder-only) ===")
    torch.manual_seed(0); random.seed(0)
    gpt = GPTLike(cfg).to(device)
    print(f"  Parameters: {count_params(gpt):,}")
    train_lm(gpt, data, tok, steps=600, block=48, batch=64, lr=3e-3, device=device)

    print("\nGenerated sample (τ=0.8):")
    prompt = tok.encode("ROMEO:").to(device)
    out = generate(gpt, prompt, max_new=160, temperature=0.8, device=device)
    print(tok.decode(out))

    # ---------------- §41.3 — train BERT-like ----------------
    print("\n=== §41.3 — Training BERT-like (encoder-only, MLM) ===")
    torch.manual_seed(0); random.seed(0)
    bert = BertLike(cfg).to(device)
    print(f"  Parameters: {count_params(bert):,}")
    train_mlm(bert, data, tok, steps=1500, block=48, batch=64, lr=3e-3, device=device)

    # ---------------- §41.5 — measure pretraining quality ----------------
    print("\n=== §41.5 — Held-out mask-fill accuracy ===")
    # Hold out the remainder of the corpus (everything after the first 80k
    # chars), which pretraining never saw. The tokenizer was built from the
    # training slice only, so characters missing from it have no id. Drop
    # them rather than let tok.encode raise KeyError.
    raw_held_out = full_text[n_train_chars:]
    held_out_text = "".join(
        c for c in raw_held_out if c in tok.stoi and c != tok.MASK_STR)
    n_dropped = len(raw_held_out) - len(held_out_text)
    held_out = tok.encode(held_out_text)
    print(f"  Held-out corpus: {len(held_out)} tokens.")
    if n_dropped:
        print(f"  (dropped {n_dropped} held-out chars unseen during pretraining)")

    acc1_pre, acc5_pre = held_out_mlm_accuracy(bert, held_out, tok, device=device)
    print(f"  PRETRAINED:   top-1 = {acc1_pre:.3f}   top-5 = {acc5_pre:.3f}")

    torch.manual_seed(0)
    bert_random = BertLike(cfg).to(device)
    acc1_rnd, acc5_rnd = held_out_mlm_accuracy(bert_random, held_out, tok, device=device)
    print(f"  RANDOM-INIT:  top-1 = {acc1_rnd:.3f}   top-5 = {acc5_rnd:.3f}")
    # [MASK] is never a target, so chance is over the real characters only.
    chance = 1.0 / (tok.vocab_size - 1)
    print(f"  Uniform chance: top-1 = {chance:.3f}")
    print("\n  Pretraining advantage:")
    print(f"    top-1: {acc1_pre / max(acc1_rnd, 1e-9):4.1f}x over random "
          f"({acc1_pre:.3f} vs {acc1_rnd:.3f})")
    print(f"    top-5: {acc5_pre / max(acc5_rnd, 1e-9):4.1f}x over random "
          f"({acc5_pre:.3f} vs {acc5_rnd:.3f})")


if __name__ == "__main__":
    main()
