"""ch42_complete.py — Chapter 42 (Tokenizers) in one file.

The full code for Chapter 42 of "Classical Foundations of Artificial Neural
Networks" (https://bnaskrecki.faculty.wmi.amu.edu.pl/nnets/). This single
file contains:

    - the from-scratch BPETokenizer (~70 lines, standard library only) — §42.2
    - a worked toy-corpus walkthrough — §42.2
    - the multi-tokenizer comparison applet from §42.7 (uses HuggingFace
      `transformers` to load real GPT-2 and BERT tokenizers)
    - the arithmetic-pathology table (GPT-2 on consecutive integers)
    - the compression / coverage estimates from §42.9 (measured on the same
      80 KB text the BPE is trained on, i.e. in-sample)

Run as:
    python ch42_complete.py

Requires:
    pip install torch numpy transformers

The HuggingFace tokenizers (gpt-2 / bert-base-uncased) are downloaded
from HuggingFace Hub on first use; ~3 MB total.

Total runtime: under 10 seconds (BPE training on 80 KB Shakespeare is
the longest step at ~3 s).
"""

from __future__ import annotations

import os
import re
import sys
import urllib.request
from collections import Counter

import torch

# =============================================================================
# Corpus loader (same as Ch 41 — included for self-containment)
# =============================================================================

# Raw Tiny Shakespeare corpus from Karpathy's char-rnn repository.
SHAKESPEARE_URL = (
    "https://raw.githubusercontent.com/"
    "karpathy/char-rnn/master/"
    "data/tinyshakespeare/input.txt"
)


def load_shakespeare(max_chars: int | None = 80_000, path: str = "shakespeare.txt") -> str:
    if not os.path.exists(path):
        print(f"Downloading Tiny Shakespeare → {path} ...")
        urllib.request.urlretrieve(SHAKESPEARE_URL, path)
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    if max_chars is not None:
        text = text[:max_chars]
    return text


# =============================================================================
# § 42.2 — BPETokenizer (from scratch, no libraries)
# =============================================================================

class BPETokenizer:
    """Byte-pair-encoding tokenizer (Sennrich, Haddow & Birch 2016; the
    algorithm itself is from Gage 1994). Pedagogical implementation;
    not tuned for speed."""

    END_OF_WORD = "</w>"

    def __init__(self):
        self.merges: list[tuple[str, str]] = []
        self.vocab: set[str] = set()

    @staticmethod
    def _word_to_tuple(word: str) -> tuple[str, ...]:
        return tuple(list(word) + [BPETokenizer.END_OF_WORD])

    @staticmethod
    def _count_pairs(corpus: dict[tuple[str, ...], int]) -> Counter:
        pairs: Counter = Counter()
        for word, freq in corpus.items():
            for i in range(len(word) - 1):
                pairs[(word[i], word[i + 1])] += freq
        return pairs

    @staticmethod
    def _apply_merge(pair, corpus):
        merged = pair[0] + pair[1]
        new_corpus = {}
        for word, freq in corpus.items():
            new_word, i = [], 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                    new_word.append(merged); i += 2
                else:
                    new_word.append(word[i]); i += 1
            new_corpus[tuple(new_word)] = freq
        return new_corpus

    def train(self, text: str, vocab_size: int = 500, log_every: int | None = None):
        words = text.split()
        word_freqs: Counter = Counter(words)
        corpus = {self._word_to_tuple(w): f for w, f in word_freqs.items()}
        chars: set[str] = {ch for word in corpus for ch in word}
        self.vocab = set(chars)
        self.merges = []
        while len(self.vocab) < vocab_size:
            pairs = self._count_pairs(corpus)
            if not pairs:
                break
            best_pair, best_count = pairs.most_common(1)[0]
            if best_count < 2:
                break
            corpus = self._apply_merge(best_pair, corpus)
            self.merges.append(best_pair)
            self.vocab.add(best_pair[0] + best_pair[1])
            if log_every and len(self.merges) % log_every == 0:
                print(f"  merge {len(self.merges):4d}: {best_pair!r} "
                      f"(count={best_count}); |V|={len(self.vocab)}")
        return self.merges

    def encode_word(self, word: str) -> list[str]:
        symbols = list(word) + [self.END_OF_WORD]
        for left, right in self.merges:
            new_symbols, i = [], 0
            while i < len(symbols):
                if i < len(symbols) - 1 and symbols[i] == left and symbols[i + 1] == right:
                    new_symbols.append(left + right); i += 2
                else:
                    new_symbols.append(symbols[i]); i += 1
            symbols = new_symbols
        return symbols

    def encode(self, text: str) -> list[str]:
        return [tok for word in text.split() for tok in self.encode_word(word)]


# =============================================================================
# § 42.2 — Toy corpus walkthrough
# =============================================================================

def demo_toy_corpus():
    """BPE walkthrough on a deliberately small corpus.

    The corpus has 11 base symbols: 10 distinct letters plus the end-of-word
    marker. Training runs until the vocabulary reaches 20 symbols, which on
    this corpus takes 9 merges, each one logged. The final segmentation of
    every distinct word is then printed.
    """
    toy = ("low low low low low lower lower newest newest newest newest "
           "newest widest widest widest")
    words = toy.split()
    print(f"Toy corpus ({len(words)} word tokens, "
          f"{len(set(words))} unique):  {toy!r}")
    print("\nTraining log:")
    bpe = BPETokenizer()
    bpe.train(toy, vocab_size=20, log_every=1)
    print("\nEncodings after training:")
    for w in ("low", "lower", "newest", "widest"):
        print(f"  {w!r:>10s}  →  {bpe.encode_word(w)}")


# =============================================================================
# § 42.7 — Multi-tokenizer comparison
# =============================================================================

# (tag, sentence) pairs, each stressing a different tokenizer weakness:
# plain prose, digit runs, diacritics, CJK without spaces, code, emoji/punct.
SAMPLES = [
    ("English", "It is the east, and Juliet is the sun."),
    ("Numbers", "The temperature was 12345 degrees Fahrenheit yesterday."),
    ("Polish", "Sieci neuronowe są podstawą współczesnej sztucznej inteligencji."),
    ("Chinese", "神经网络是现代人工智能的基础。"),
    ("Python code", "for i in range(10):\n" "    print(i**2)"),
    ("Emoji+punct", "WOW!!! That is amazing 🎉🚀 — definitely 100% true."),
]


def compare_tokenizers(bpe: BPETokenizer):
    """Print a side-by-side table of token counts for the SAMPLES inputs.

    There is one row per (tag, sentence) in SAMPLES and one column per
    tokenizer: character split, whitespace split, *bpe*, GPT-2 BPE and
    BERT WordPiece.

    Requires `pip install transformers`. The GPT-2 / BERT tokenizers are
    loaded with ``from_pretrained``. If the package is missing, or the
    tokenizers cannot be loaded (e.g. offline with an empty cache), a skip
    message is printed and the function returns without raising.
    """
    try:
        from transformers import GPT2TokenizerFast, BertTokenizerFast
    except ImportError:
        print("\n[skip] Install `transformers` to run the multi-tokenizer comparison.")
        return

    try:
        gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")
        bert = BertTokenizerFast.from_pretrained("bert-base-uncased")
    except OSError as err:
        # from_pretrained raises OSError when the files are neither cached nor
        # downloadable. Skip like the ImportError branch so main() can still
        # run the remaining demos.
        print(f"\n[skip] Could not load the GPT-2 / BERT tokenizers: {err}")
        return

    tokenizers = [
        ("char",                        list),
        ("whitespace",                  str.split),
        # Label with the real vocabulary size rather than a hard-coded value.
        (f"our BPE ({len(bpe.vocab)})", bpe.encode),
        ("GPT-2 BPE",                   gpt2.tokenize),
        ("BERT WordPiece",              bert.tokenize),
    ]

    print(f"\n{'tag':<14s} " + " ".join(f"{n:>15s}" for n, _ in tokenizers))
    print("-" * (14 + 16 * len(tokenizers)))
    for tag, sentence in SAMPLES:
        counts = [len(fn(sentence)) for _, fn in tokenizers]
        print(f"{tag:<14s} " + " ".join(f"{c:>15d}" for c in counts))


# =============================================================================
# § 42.7 — Arithmetic pathology
# =============================================================================

def arithmetic_pathology():
    """Show that consecutive integers get wildly different GPT-2 tokenisations.

    Prints the GPT-2 token split for 123..129 and a few larger numbers.
    Requires `pip install transformers`. If the package is missing, or the
    GPT-2 tokenizer cannot be loaded (e.g. offline with an empty cache), a
    skip message is printed instead of raising.
    """
    try:
        from transformers import GPT2TokenizerFast
    except ImportError:
        print("\n[skip] Install `transformers` to run the arithmetic-pathology demo.")
        return
    try:
        gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")
    except OSError as err:
        # Same best-effort policy as the ImportError branch: report the
        # problem and let the remaining demos run.
        print(f"\n[skip] Could not load the GPT-2 tokenizer: {err}")
        return
    print(f"\n{'number':>15s}  GPT-2 tokens")
    print("-" * 70)
    for n in [*range(123, 130), 12345, 56789, 1_000_000, 1_000_000_000]:
        s = str(n)
        toks = gpt2.tokenize(s)
        print(f"{s:>15s}  {toks}  ({len(toks)} tokens)")


# =============================================================================
# § 42.9 — Compression and effective-coverage numbers
# =============================================================================

def compression_numbers(bpe: BPETokenizer, text: str, *, d_model: int = 96,
                        char_vocab_size: int = 62):
    """The §42.9 numerical sketch — what swapping CharTokenizer for our BPE
    would have done to the Ch 41 model.

    Encodes *text* with *bpe* and prints a CharTokenizer-vs-BPE table with
    three rows:

    - token count;
    - characters per token (the compression ratio ``c``);
    - embedding-table parameters at width *d_model*.

    Finally it prints ``c**2`` as the effective-coverage factor for a fixed
    T^2 attention budget.

    Args:
        bpe: trained tokenizer; only ``bpe.encode(text)`` and
            ``len(bpe.vocab)`` are used.
        text: corpus to measure (``main`` passes the BPE training text).
        d_model: embedding width for the parameter counts (default 96).
        char_vocab_size: CharTokenizer vocabulary size used for the char
            embedding count (default 62, the value in the original table).

    Raises:
        ValueError: if *text* yields no BPE tokens (empty or whitespace-only),
            because the compression ratio is then undefined.
    """
    pieces = bpe.encode(text)
    n_chars = len(text)
    n_bpe = len(pieces)
    if n_bpe == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError("compression_numbers: `text` produced no BPE tokens")
    compression = n_chars / n_bpe
    # Labels come from the actual inputs so the table stays truthful for any
    # corpus / tokenizer / width (for main's 80_000 chars this prints "80 KB").
    tokens_label = f"tokens for {n_chars / 1000:.0f} KB corpus"
    embed_label = f"embedding params at d={d_model}"
    print(f"\n{'':30s} CharTokenizer   BPETokenizer ({len(bpe.vocab)})")
    print(f"{tokens_label:30s} {n_chars:>10,d}      {n_bpe:>10,d}")
    print(f"{'chars per token':30s} {1.0:>10.2f}x     {compression:>10.2f}x")
    char_embed = char_vocab_size * d_model
    bpe_embed = len(bpe.vocab) * d_model
    print(f"{embed_label:30s} {char_embed:>10,d}      {bpe_embed:>10,d}")
    print(f"\n  Effective coverage per fixed T^2 attention budget: {compression**2:.1f}x")


# =============================================================================
# Main
# =============================================================================

def main():
    """Run every Chapter 42 demo in book order, each under its own banner."""
    rule = "=" * 70

    def section(title: str, *, first: bool = False) -> None:
        # Every banner after the first starts with a blank line that
        # separates it from the previous section's output.
        print(rule if first else "\n" + rule)
        print(title)
        print(rule)

    section("§ 42.2 — Toy corpus walkthrough", first=True)
    demo_toy_corpus()

    section("§ 42.2 — Train BPE on 80 KB Shakespeare (vocab_size = 500)")
    corpus = load_shakespeare(max_chars=80_000)
    tokenizer = BPETokenizer()
    tokenizer.train(corpus, vocab_size=500, log_every=100)
    vocab_size, n_merges = len(tokenizer.vocab), len(tokenizer.merges)
    print(f"\nFinal vocab : {vocab_size}")
    print(f"Merges learned : {n_merges}")
    print(f"First 5 merges : {tokenizer.merges[:5]}")

    section("§ 42.7 — Multi-tokenizer comparison (5 tokenizers x 6 samples)")
    compare_tokenizers(tokenizer)

    section("§ 42.7 — Arithmetic pathology (GPT-2 on consecutive integers)")
    arithmetic_pathology()

    section("§ 42.9 — What swapping the tokenizer would buy on the Ch 41 model")
    compression_numbers(tokenizer, corpus)


# Run the chapter demos only when this file is executed as a script
# (`python ch42_complete.py`), not when it is imported as a module.
if __name__ == "__main__":
    main()
