{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a0b1c2d3",
   "metadata": {},
   "source": [
    "# Chapter 35: Character-Level Language Modeling\n",
    "\n",
    "In 2015, Andrej Karpathy demonstrated that a character-level LSTM trained on raw text could generate remarkably coherent prose. In this chapter, we build exactly this system on Shakespeare.\n",
    "\n",
    "Character-level language modeling is the simplest formulation of the problem that underpins modern AI: **next-token prediction**. Given a sequence of characters $c_1, c_2, \\ldots, c_t$, the model outputs a probability distribution over the next character $c_{t+1}$. This is the same objective used by GPT, BERT, and every large language model—the only difference is the granularity of the tokens. By working at the character level, we strip away the complexity of tokenization and see the core mechanism in its purest form."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1c2d3e4",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import time\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "\n",
    "BLUE = '#3b82f6'\n",
    "GREEN = '#059669'\n",
    "RED = '#dc2626'\n",
    "AMBER = '#d97706'\n",
    "INDIGO = '#4f46e5'\n",
    "\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'PyTorch version: {torch.__version__}')\n",
    "print(f'Device: {device}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2d3e4f5",
   "metadata": {},
   "source": [
    "## 35.1 The Character-Level Language Model\n",
    "\n",
    "A **language model** assigns a probability to a sequence of tokens:\n",
    "\n",
    "$$P(c_1, c_2, \\ldots, c_T) = \\prod_{t=1}^{T} P(c_t \\mid c_1, \\ldots, c_{t-1})$$\n",
    "\n",
    "In a character-level model, each token $c_t$ is a single character. The vocabulary $V$ is simply the set of unique characters in the training corpus—typically 50–80 characters for English text (letters, digits, punctuation, whitespace).\n",
    "\n",
    "```{admonition} Next-Token Prediction\n",
    ":class: important\n",
    "\n",
    "The training objective is to minimize the **cross-entropy** between the model's predicted distribution and the true next character:\n",
    "\n",
    "$$\\mathcal{L} = -\\frac{1}{T} \\sum_{t=1}^{T} \\log P_{\\text{model}}(c_t \\mid c_1, \\ldots, c_{t-1})$$\n",
    "\n",
    "This is identical to the loss function used by GPT-3, GPT-4, and other autoregressive language models. The only difference is scale: our vocabulary has ~65 characters instead of ~50,000 subword tokens, and our model has thousands of parameters instead of billions.\n",
    "```\n",
    "\n",
    "An RNN is a natural fit for this task: at each time step, the hidden state $h_t$ encodes the context $c_1, \\ldots, c_t$, and a linear layer maps $h_t$ to logits over the vocabulary."
   ]
  },
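  {
   "cell_type": "markdown",
   "id": "f0e1d2c3",
   "metadata": {},
   "source": [
    "Before building the full model, it helps to see this loss on a toy example. The cell below (all probabilities are invented for illustration) factorizes the probability of the string `'ab'` over a three-character vocabulary and computes the average cross-entropy the model would incur on that string."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0e1d2c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of the chain rule and the cross-entropy loss.\n",
    "# Vocabulary {'a', 'b', 'c'}; the \"model\" is two hand-picked distributions.\n",
    "import math\n",
    "\n",
    "p_first = {'a': 0.5, 'b': 0.3, 'c': 0.2}           # P(c_1)\n",
    "p_second_given_a = {'a': 0.1, 'b': 0.8, 'c': 0.1}  # P(c_2 | c_1 = 'a')\n",
    "\n",
    "# Chain rule: P('ab') = P('a') * P('b' | 'a')\n",
    "p_ab = p_first['a'] * p_second_given_a['b']\n",
    "print(f\"P('ab') = {p_first['a']} * {p_second_given_a['b']} = {p_ab:.2f}\")\n",
    "\n",
    "# Cross-entropy: average negative log-probability of the true characters\n",
    "loss = -(math.log(p_first['a']) + math.log(p_second_given_a['b'])) / 2\n",
    "print(f'Cross-entropy = {loss:.4f} nats per character')"
   ]
  },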
  {
   "cell_type": "markdown",
   "id": "d3e4f5a6",
   "metadata": {},
   "source": [
    "## 35.2 The Shakespeare Dataset\n",
    "\n",
    "We use a 100,000-character excerpt from Shakespeare's works as our training corpus. This is the same dataset used in Karpathy's influential blog post \"The Unreasonable Effectiveness of Recurrent Neural Networks\" (2015)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4f5a6b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Shakespeare text from local file\n",
    "with open('shakespeare.txt', 'r') as f:\n",
    "    text = f.read()\n",
    "\n",
    "# Basic statistics\n",
    "chars = sorted(set(text))\n",
    "vocab_size = len(chars)\n",
    "\n",
    "print(f'Total characters: {len(text):,}')\n",
    "print(f'Unique characters (vocab size): {vocab_size}')\n",
    "print(f'\\nCharacter set:')\n",
    "print(repr(''.join(chars)))\n",
    "print(f'\\n--- Sample passage (first 500 chars) ---')\n",
    "print(text[:500])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5a6b7c8",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Character frequency histogram\n",
    "from collections import Counter\n",
    "\n",
    "char_counts = Counter(text)\n",
    "# Sort by frequency\n",
    "sorted_chars = sorted(char_counts.items(), key=lambda x: -x[1])\n",
    "top_chars = sorted_chars[:30]\n",
    "\n",
    "labels = [repr(c)[1:-1] if c not in ('\\n', ' ', '\\t') else \n",
    "          {'\\n': '\\\\n', ' ': 'SP', '\\t': '\\\\t'}[c] \n",
    "          for c, _ in top_chars]\n",
    "counts = [cnt for _, cnt in top_chars]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 4.5))\n",
    "bars = ax.bar(range(len(labels)), counts, color=BLUE, alpha=0.8, edgecolor='white')\n",
    "ax.set_xticks(range(len(labels)))\n",
    "ax.set_xticklabels(labels, fontsize=9, fontfamily='monospace')\n",
    "ax.set_xlabel('Character', fontsize=11)\n",
    "ax.set_ylabel('Frequency', fontsize=11)\n",
    "ax.set_title('Top 30 Character Frequencies in Shakespeare Corpus', fontsize=13, fontweight='bold')\n",
    "\n",
    "# Highlight space and newline\n",
    "for i, (c, _) in enumerate(top_chars):\n",
    "    if c in (' ', '\\n'):\n",
    "        bars[i].set_color(AMBER)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(f'Most common: {labels[0]} ({counts[0]:,} occurrences, {counts[0]/len(text)*100:.1f}%)')\n",
    "print(f'Whitespace (space + newline): {char_counts.get(\" \", 0) + char_counts.get(chr(10), 0):,} '\n",
    "      f'({(char_counts.get(\" \", 0) + char_counts.get(chr(10), 0))/len(text)*100:.1f}%)')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6b7c8d9",
   "metadata": {},
   "source": [
    "## 35.3 Data Preparation\n",
    "\n",
    "We need three components:\n",
    "1. **Character-to-index mapping** (and its inverse) to convert between characters and integers.\n",
    "2. **Sequence windowing**: slide a window of length `seq_len` across the text to create input/target pairs.\n",
    "3. A PyTorch `Dataset` and `DataLoader` for batched training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7c8d9e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Character-to-index mapping\n",
    "char_to_idx = {ch: i for i, ch in enumerate(chars)}\n",
    "idx_to_char = {i: ch for i, ch in enumerate(chars)}\n",
    "\n",
    "# Encode the entire text as a tensor of indices\n",
    "encoded = torch.tensor([char_to_idx[ch] for ch in text], dtype=torch.long)\n",
    "print(f'Encoded tensor shape: {encoded.shape}')\n",
    "print(f'First 50 indices: {encoded[:50].tolist()}')\n",
    "print(f'Decoded back:     {repr(\"\".join(idx_to_char[i.item()] for i in encoded[:50]))}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d9e0f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequence windowing: create input/target pairs\n",
    "class ShakespeareDataset(Dataset):\n",
    "    \"\"\"Character-level language model dataset.\n",
    "    \n",
    "    Each sample is a pair (input_seq, target_seq) where:\n",
    "    - input_seq  = text[i : i + seq_len]\n",
    "    - target_seq = text[i+1 : i + seq_len + 1]\n",
    "    \n",
    "    The target is the input shifted by one character.\n",
    "    \"\"\"\n",
    "    def __init__(self, data, seq_len):\n",
    "        self.data = data\n",
    "        self.seq_len = seq_len\n",
    "    \n",
    "    def __len__(self):\n",
    "        return (len(self.data) - 1) // self.seq_len\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        start = idx * self.seq_len\n",
    "        end = start + self.seq_len\n",
    "        x = self.data[start:end]\n",
    "        y = self.data[start+1:end+1]\n",
    "        return x, y\n",
    "\n",
    "\n",
    "# Hyperparameters\n",
    "SEQ_LEN = 50\n",
    "BATCH_SIZE = 64\n",
    "EMBED_SIZE = 32\n",
    "HIDDEN_SIZE = 128\n",
    "LR = 0.003\n",
    "N_EPOCHS = 10\n",
    "\n",
    "# Create dataset and dataloader\n",
    "dataset = ShakespeareDataset(encoded, SEQ_LEN)\n",
    "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)\n",
    "\n",
    "print(f'Sequence length: {SEQ_LEN}')\n",
    "print(f'Batch size: {BATCH_SIZE}')\n",
    "print(f'Number of sequences: {len(dataset)}')\n",
    "print(f'Batches per epoch: {len(dataloader)}')\n",
    "\n",
    "# Peek at one batch\n",
    "x_batch, y_batch = next(iter(dataloader))\n",
    "print(f'\\nBatch shapes: x={x_batch.shape}, y={y_batch.shape}')\n",
    "print(f'Example input:  {repr(\"\".join(idx_to_char[i.item()] for i in x_batch[0][:30]))}')\n",
    "print(f'Example target: {repr(\"\".join(idx_to_char[i.item()] for i in y_batch[0][:30]))}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9e0f1a2",
   "metadata": {},
   "source": [
    "Notice how the target is simply the input shifted by one position. The model learns to predict each character given all previous characters in the window."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0f1a2b3",
   "metadata": {},
   "source": [
    "## 35.4 Vanilla RNN on Shakespeare\n",
    "\n",
    "Our model architecture is straightforward:\n",
    "1. **Embedding layer**: maps each character index to a dense vector of size `embed_size`.\n",
    "2. **RNN layer**: processes the sequence, producing hidden states at each time step.\n",
    "3. **Linear layer**: maps each hidden state to logits over the vocabulary.\n",
    "\n",
    "We start with a vanilla RNN to establish a baseline, then upgrade to LSTM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1a2b3c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CharRNN(nn.Module):\n",
    "    \"\"\"Character-level language model with configurable RNN type.\"\"\"\n",
    "    \n",
    "    def __init__(self, vocab_size, embed_size, hidden_size, rnn_type='rnn'):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.rnn_type = rnn_type\n",
    "        \n",
    "        # Character embedding\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        \n",
    "        # Recurrent layer\n",
    "        if rnn_type == 'rnn':\n",
    "            self.rnn = nn.RNN(embed_size, hidden_size, batch_first=True)\n",
    "        elif rnn_type == 'lstm':\n",
    "            self.rnn = nn.LSTM(embed_size, hidden_size, batch_first=True)\n",
    "        elif rnn_type == 'gru':\n",
    "            self.rnn = nn.GRU(embed_size, hidden_size, batch_first=True)\n",
    "        \n",
    "        # Output projection: hidden state -> vocabulary logits\n",
    "        self.fc = nn.Linear(hidden_size, vocab_size)\n",
    "    \n",
    "    def forward(self, x, hidden=None):\n",
    "        \"\"\"Forward pass.\n",
    "        \n",
    "        Args:\n",
    "            x: input indices, shape (batch, seq_len)\n",
    "            hidden: initial hidden state (optional)\n",
    "        \n",
    "        Returns:\n",
    "            logits: shape (batch, seq_len, vocab_size)\n",
    "            hidden: final hidden state\n",
    "        \"\"\"\n",
    "        emb = self.embedding(x)           # (batch, seq_len, embed_size)\n",
    "        out, hidden = self.rnn(emb, hidden)  # (batch, seq_len, hidden_size)\n",
    "        logits = self.fc(out)             # (batch, seq_len, vocab_size)\n",
    "        return logits, hidden\n",
    "    \n",
    "    def generate(self, start_str, length=200, temperature=1.0):\n",
    "        \"\"\"Generate text autoregressively.\"\"\"\n",
    "        self.eval()\n",
    "        chars_generated = list(start_str)\n",
    "        \n",
    "        # Encode the start string\n",
    "        input_idx = torch.tensor([[char_to_idx[ch] for ch in start_str]], dtype=torch.long)\n",
    "        \n",
    "        hidden = None\n",
    "        with torch.no_grad():\n",
    "            # Process the seed\n",
    "            logits, hidden = self.forward(input_idx, hidden)\n",
    "            \n",
    "            # Generate one character at a time\n",
    "            for _ in range(length):\n",
    "                # Use the last character's logits\n",
    "                last_logits = logits[0, -1, :] / temperature\n",
    "                probs = torch.softmax(last_logits, dim=0)\n",
    "                next_idx = torch.multinomial(probs, 1).item()\n",
    "                chars_generated.append(idx_to_char[next_idx])\n",
    "                \n",
    "                # Feed the generated character back\n",
    "                input_idx = torch.tensor([[next_idx]], dtype=torch.long)\n",
    "                logits, hidden = self.forward(input_idx, hidden)\n",
    "        \n",
    "        return ''.join(chars_generated)\n",
    "\n",
    "\n",
    "# Count parameters\n",
    "def count_params(model):\n",
    "    return sum(p.numel() for p in model.parameters())\n",
    "\n",
    "rnn_model = CharRNN(vocab_size, EMBED_SIZE, HIDDEN_SIZE, rnn_type='rnn')\n",
    "print(f'Vanilla RNN model:')\n",
    "print(f'  Parameters: {count_params(rnn_model):,}')\n",
    "print(f'  Embedding:  {vocab_size} x {EMBED_SIZE} = {vocab_size * EMBED_SIZE:,}')\n",
    "print(f'  RNN:        {sum(p.numel() for p in rnn_model.rnn.parameters()):,}')\n",
    "print(f'  Output FC:  {HIDDEN_SIZE} x {vocab_size} + {vocab_size} = {HIDDEN_SIZE * vocab_size + vocab_size:,}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2b3c4d5",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def train_model(model, dataloader, n_epochs, lr, model_name='Model'):\n",
    "    \"\"\"Train a character-level language model.\"\"\"\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    \n",
    "    losses = []\n",
    "    samples = {}  # epoch -> generated text\n",
    "    sample_epochs = {2, 5, 10}\n",
    "    \n",
    "    start_time = time.time()\n",
    "    \n",
    "    for epoch in range(1, n_epochs + 1):\n",
    "        model.train()\n",
    "        epoch_loss = 0.0\n",
    "        n_batches = 0\n",
    "        \n",
    "        for x_batch, y_batch in dataloader:\n",
    "            logits, _ = model(x_batch)\n",
    "            # Reshape for cross-entropy: (batch * seq_len, vocab_size) vs (batch * seq_len,)\n",
    "            loss = criterion(logits.reshape(-1, vocab_size), y_batch.reshape(-1))\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
    "            optimizer.step()\n",
    "            \n",
    "            epoch_loss += loss.item()\n",
    "            n_batches += 1\n",
    "        \n",
    "        avg_loss = epoch_loss / n_batches\n",
    "        losses.append(avg_loss)\n",
    "        \n",
    "        elapsed = time.time() - start_time\n",
    "        print(f'  Epoch {epoch:2d}/{n_epochs}  loss={avg_loss:.4f}  [{elapsed:.1f}s]')\n",
    "        \n",
    "        if epoch in sample_epochs:\n",
    "            sample = model.generate('KING ', length=150, temperature=0.8)\n",
    "            samples[epoch] = sample\n",
    "    \n",
    "    total_time = time.time() - start_time\n",
    "    print(f'  Training complete in {total_time:.1f}s')\n",
    "    \n",
    "    return losses, samples\n",
    "\n",
    "# Train vanilla RNN\n",
    "print('=== Training Vanilla RNN ===')\n",
    "torch.manual_seed(42)\n",
    "rnn_model = CharRNN(vocab_size, EMBED_SIZE, HIDDEN_SIZE, rnn_type='rnn')\n",
    "rnn_losses, rnn_samples = train_model(rnn_model, dataloader, N_EPOCHS, LR, 'RNN')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c4d5e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show generated text at different training stages\n",
    "print('=== Vanilla RNN: Generated Text at Different Epochs ===')\n",
    "for epoch in sorted(rnn_samples.keys()):\n",
    "    print(f'\\n--- Epoch {epoch} ---')\n",
    "    print(rnn_samples[epoch])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4d5e6f7",
   "metadata": {},
   "source": [
    "Watch how the generated text improves: early epochs produce near-random characters, while later epochs begin capturing word boundaries, common words, and rudimentary syntax."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5e6f7a8",
   "metadata": {},
   "source": [
    "## 35.5 LSTM on Shakespeare\n",
    "\n",
    "Now we train the same architecture with an LSTM backbone. The model has more parameters due to the four gate matrices, but should learn longer-range dependencies and produce more coherent text."
   ]
  },
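  {
   "cell_type": "markdown",
   "id": "e5d4c3b2",
   "metadata": {},
   "source": [
    "The parameter overhead is easy to verify. PyTorch's `nn.LSTM` stacks the four weight blocks into `weight_ih_l0` and `weight_hh_l0` plus two bias vectors, so the recurrent layer should hold $4(EH + H^2 + 2H)$ parameters. A quick sanity check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5d4c3b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: LSTM recurrent-layer parameter count vs. the closed form\n",
    "E, H = EMBED_SIZE, HIDDEN_SIZE\n",
    "expected = 4 * (E * H + H * H + 2 * H)  # 4 blocks: i, f, g, o (two biases each)\n",
    "actual = sum(p.numel() for p in nn.LSTM(E, H, batch_first=True).parameters())\n",
    "print(f'Formula 4*(E*H + H*H + 2*H) = {expected:,}')\n",
    "print(f'nn.LSTM parameter count     = {actual:,}')"
   ]
  },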
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f7a8b9",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Train LSTM\n",
    "print('=== Training LSTM ===')\n",
    "torch.manual_seed(42)\n",
    "lstm_model = CharRNN(vocab_size, EMBED_SIZE, HIDDEN_SIZE, rnn_type='lstm')\n",
    "print(f'LSTM parameters: {count_params(lstm_model):,}')\n",
    "lstm_losses, lstm_samples = train_model(lstm_model, dataloader, N_EPOCHS, LR, 'LSTM')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7a8b9c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side comparison of generated text\n",
    "print('=' * 80)\n",
    "print('COMPARISON: Generated text after 10 epochs')\n",
    "print('=' * 80)\n",
    "\n",
    "# Generate fresh samples with the same seed\n",
    "torch.manual_seed(99)\n",
    "rnn_text = rnn_model.generate('KING ', length=200, temperature=0.8)\n",
    "torch.manual_seed(99)\n",
    "lstm_text = lstm_model.generate('KING ', length=200, temperature=0.8)\n",
    "\n",
    "print('\\n--- Vanilla RNN ---')\n",
    "print(rnn_text)\n",
    "print('\\n--- LSTM ---')\n",
    "print(lstm_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8b9c0d1",
   "metadata": {},
   "source": [
    "The LSTM typically produces text with better structure: more consistent word lengths, more plausible Shakespearean vocabulary, and occasionally coherent phrases. The difference becomes more pronounced with longer training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9c0d1e2",
   "metadata": {},
   "source": [
    "## 35.6 Temperature Sampling\n",
    "\n",
    "Recall from Chapter 26 that the **softmax temperature** $T$ controls the entropy of the output distribution:\n",
    "\n",
    "$$P(c_i) = \\frac{\\exp(z_i / T)}{\\sum_j \\exp(z_j / T)}$$\n",
    "\n",
    "- **$T < 1$ (low temperature)**: The distribution becomes peaked—the model becomes more \"confident\" and repetitive. Generated text is more predictable but less diverse.\n",
    "- **$T = 1$**: The unmodified model distribution.\n",
    "- **$T > 1$ (high temperature)**: The distribution becomes flatter—the model explores more alternatives. Generated text is more creative but may become incoherent.\n",
    "\n",
    "```{admonition} The Temperature Tradeoff\n",
    ":class: tip\n",
    "\n",
    "Low temperature produces **safe, repetitive** text. High temperature produces **diverse, risky** text. There is no universally optimal temperature—the best value depends on the application. For creative writing, $T \\approx 0.7$–$0.9$ often works well.\n",
    "```"
   ]
  },
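  {
   "cell_type": "markdown",
   "id": "d6c5b4a3",
   "metadata": {},
   "source": [
    "Before sampling from the model, the cell below applies the formula directly to a fixed logit vector (toy values, chosen arbitrarily). Watch the probability mass concentrate on the top logit at low $T$ and spread out at high $T$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6c5b4a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# How temperature reshapes a fixed softmax distribution (toy logits)\n",
    "logits = torch.tensor([2.0, 1.0, 0.5, 0.1])\n",
    "for T in [0.5, 1.0, 2.0]:\n",
    "    probs = torch.softmax(logits / T, dim=0)\n",
    "    formatted = ', '.join(f'{p:.3f}' for p in probs.tolist())\n",
    "    print(f'T = {T}: [{formatted}]')"
   ]
  },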
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0d1e2f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate text at different temperatures\n",
    "temperatures = [0.5, 1.0, 1.5]\n",
    "\n",
    "print('=== LSTM: Temperature Sampling ===')\n",
    "for temp in temperatures:\n",
    "    torch.manual_seed(42)\n",
    "    generated = lstm_model.generate('HAMLET:\\n', length=250, temperature=temp)\n",
    "    print(f'\\n{\"=\"*60}')\n",
    "    print(f'Temperature = {temp}')\n",
    "    print(f'{\"=\"*60}')\n",
    "    print(generated)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1e2f3a4",
   "metadata": {},
   "source": [
    "Notice the tradeoff:\n",
    "- At $T = 0.5$, the text may repeat common patterns but maintains consistency.\n",
    "- At $T = 1.0$, the text is more varied and natural.\n",
    "- At $T = 1.5$, the text becomes more erratic, with unusual character combinations."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2f3a4b5",
   "metadata": {},
   "source": [
    "## 35.7 Gate Activation Visualization\n",
    "\n",
    "One of the most illuminating analyses of an LSTM is to visualize what the gates are doing on actual text. By feeding a sample passage through the trained LSTM and extracting the gate activations at each time step, we can see which characters trigger forgetting, storage, and output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a4b5c6",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def extract_gate_activations(model, text_str, n_units=10):\n",
    "    \"\"\"Extract LSTM gate activations for visualization.\n",
    "    \n",
    "    We hook into the LSTM to capture gate values at each step.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    \n",
    "    # Encode the text\n",
    "    indices = torch.tensor([[char_to_idx[ch] for ch in text_str]], dtype=torch.long)\n",
    "    \n",
    "    # We'll manually step through the LSTM to capture gates\n",
    "    emb = model.embedding(indices)  # (1, T, embed_size)\n",
    "    \n",
    "    T = len(text_str)\n",
    "    hidden_size = model.hidden_size\n",
    "    \n",
    "    # Get LSTM weights\n",
    "    lstm = model.rnn\n",
    "    W_ih = lstm.weight_ih_l0  # (4*H, input_size)\n",
    "    W_hh = lstm.weight_hh_l0  # (4*H, H)\n",
    "    b_ih = lstm.bias_ih_l0    # (4*H,)\n",
    "    b_hh = lstm.bias_hh_l0    # (4*H,)\n",
    "    \n",
    "    h = torch.zeros(1, hidden_size)\n",
    "    c = torch.zeros(1, hidden_size)\n",
    "    \n",
    "    forget_gates = []\n",
    "    input_gates = []\n",
    "    output_gates = []\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for t in range(T):\n",
    "            x_t = emb[0, t:t+1, :]  # (1, embed_size)\n",
    "            \n",
    "            gates = x_t @ W_ih.T + b_ih + h @ W_hh.T + b_hh\n",
    "            i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=1)\n",
    "            \n",
    "            i_t = torch.sigmoid(i_gate)\n",
    "            f_t = torch.sigmoid(f_gate)\n",
    "            g_t = torch.tanh(g_gate)\n",
    "            o_t = torch.sigmoid(o_gate)\n",
    "            \n",
    "            c = f_t * c + i_t * g_t\n",
    "            h = o_t * torch.tanh(c)\n",
    "            \n",
    "            forget_gates.append(f_t[0, :n_units].numpy())\n",
    "            input_gates.append(i_t[0, :n_units].numpy())\n",
    "            output_gates.append(o_t[0, :n_units].numpy())\n",
    "    \n",
    "    return {\n",
    "        'forget': np.array(forget_gates),   # (T, n_units)\n",
    "        'input': np.array(input_gates),\n",
    "        'output': np.array(output_gates),\n",
    "    }\n",
    "\n",
    "# Extract gates for a sample passage\n",
    "sample_text = 'First Citizen:\\nBefore we proceed any further, hear me speak.\\n\\nAll:\\nSpeak, speak.'\n",
    "gates = extract_gate_activations(lstm_model, sample_text, n_units=8)\n",
    "\n",
    "# Plot gate heatmaps\n",
    "fig, axes = plt.subplots(3, 1, figsize=(14, 10), sharex=True)\n",
    "\n",
    "gate_names = ['Forget Gate', 'Input Gate', 'Output Gate']\n",
    "gate_keys = ['forget', 'input', 'output']\n",
    "cmaps = ['Reds', 'Greens', 'Blues']\n",
    "\n",
    "char_labels = [repr(c)[1:-1] if c not in ('\\n', ' ') else \n",
    "               {'\\n': r'$\\hookleftarrow$', ' ': r'$\\sqcup$'}[c] \n",
    "               for c in sample_text]\n",
    "\n",
    "for ax, name, key, cmap in zip(axes, gate_names, gate_keys, cmaps):\n",
    "    im = ax.imshow(gates[key].T, aspect='auto', cmap=cmap, vmin=0, vmax=1)\n",
    "    ax.set_ylabel(f'{name}\\n(units)', fontsize=10)\n",
    "    ax.set_yticks(range(8))\n",
    "    ax.set_yticklabels([f'#{i}' for i in range(8)], fontsize=8)\n",
    "    plt.colorbar(im, ax=ax, shrink=0.8, label='Activation')\n",
    "\n",
    "# Character labels on bottom axis\n",
    "axes[-1].set_xticks(range(len(sample_text)))\n",
    "axes[-1].set_xticklabels(char_labels, fontsize=7, fontfamily='monospace', rotation=0)\n",
    "axes[-1].set_xlabel('Character position', fontsize=11)\n",
    "\n",
    "plt.suptitle('LSTM Gate Activations on Shakespeare Text', fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Highlight interesting patterns\n",
    "print('Gate activation statistics:')\n",
    "for key in gate_keys:\n",
    "    g = gates[key]\n",
    "    print(f'  {key:7s}: mean={g.mean():.3f}, std={g.std():.3f}, '\n",
    "          f'min={g.min():.3f}, max={g.max():.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4b5c6d7",
   "metadata": {},
   "source": [
    "Look for these patterns in the heatmap:\n",
    "- The **forget gate** often activates strongly (values near 1) during word interiors, preserving context, and drops at word boundaries or punctuation—signaling the network to update its representation.\n",
    "- The **input gate** tends to spike at the beginning of new words or after punctuation, indicating that new information is being written into the cell state.\n",
    "- The **output gate** may show interesting patterns around newlines and colons, which in Shakespeare mark speaker transitions."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5c6d7e8",
   "metadata": {},
   "source": [
    "## 35.8 Architecture Comparison\n",
    "\n",
    "We now train a GRU model on the same data and compare all three architectures: Vanilla RNN, LSTM, and GRU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6d7e8f9",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Train GRU\n",
    "print('=== Training GRU ===')\n",
    "torch.manual_seed(42)\n",
    "gru_model = CharRNN(vocab_size, EMBED_SIZE, HIDDEN_SIZE, rnn_type='gru')\n",
    "print(f'GRU parameters: {count_params(gru_model):,}')\n",
    "gru_losses, gru_samples = train_model(gru_model, dataloader, N_EPOCHS, LR, 'GRU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e8f9a0",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Training curves comparison\n",
    "fig, axes = plt.subplots(1, 2, figsize=(13, 5))\n",
    "\n",
    "# Loss curves\n",
    "ax = axes[0]\n",
    "ax.plot(range(1, N_EPOCHS + 1), rnn_losses, color=RED, marker='s', linewidth=2,\n",
    "        markersize=6, label=f'Vanilla RNN ({count_params(rnn_model):,} params)')\n",
    "ax.plot(range(1, N_EPOCHS + 1), lstm_losses, color=GREEN, marker='o', linewidth=2,\n",
    "        markersize=6, label=f'LSTM ({count_params(lstm_model):,} params)')\n",
    "ax.plot(range(1, N_EPOCHS + 1), gru_losses, color=BLUE, marker='^', linewidth=2,\n",
    "        markersize=6, label=f'GRU ({count_params(gru_model):,} params)')\n",
    "ax.set_xlabel('Epoch', fontsize=11)\n",
    "ax.set_ylabel('Cross-Entropy Loss', fontsize=11)\n",
    "ax.set_title('Training Loss Comparison', fontsize=12, fontweight='bold')\n",
    "ax.legend(fontsize=9)\n",
    "ax.set_xticks(range(1, N_EPOCHS + 1))\n",
    "\n",
    "# Final comparison table as bar chart\n",
    "ax = axes[1]\n",
    "models_data = {\n",
    "    'RNN': (count_params(rnn_model), rnn_losses[-1]),\n",
    "    'LSTM': (count_params(lstm_model), lstm_losses[-1]),\n",
    "    'GRU': (count_params(gru_model), gru_losses[-1]),\n",
    "}\n",
    "x_pos = np.arange(3)\n",
    "colors = [RED, GREEN, BLUE]\n",
    "final_losses = [rnn_losses[-1], lstm_losses[-1], gru_losses[-1]]\n",
    "bars = ax.bar(x_pos, final_losses, color=colors, alpha=0.85, edgecolor='white', width=0.6)\n",
    "ax.set_xticks(x_pos)\n",
    "ax.set_xticklabels(['Vanilla RNN', 'LSTM', 'GRU'], fontsize=11)\n",
    "ax.set_ylabel('Final Loss (epoch 10)', fontsize=11)\n",
    "ax.set_title('Final Loss Comparison', fontsize=12, fontweight='bold')\n",
    "for bar, loss in zip(bars, final_losses):\n",
    "    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,\n",
    "            f'{loss:.3f}', ha='center', fontsize=10, fontweight='bold')\n",
    "\n",
    "plt.suptitle('Character-Level Shakespeare: Architecture Comparison',\n",
    "             fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Print comparison table\n",
    "print(f'{\"\":-<60}')\n",
    "print(f'{\"Architecture\":<15} {\"Parameters\":>12} {\"Final Loss\":>12} {\"Loss/Param\":>15}')\n",
    "print(f'{\"\":-<60}')\n",
    "for name, (params, loss) in models_data.items():\n",
    "    print(f'{name:<15} {params:>12,} {loss:>12.4f} {loss/params:>15.2e}')\n",
    "print(f'{\"\":-<60}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8f9a0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side generated samples from all three models\n",
    "print('=' * 70)\n",
    "print('Generated Shakespeare: Final Models (T=0.8)')\n",
    "print('=' * 70)\n",
    "\n",
    "for name, model in [('Vanilla RNN', rnn_model), ('LSTM', lstm_model), ('GRU', gru_model)]:\n",
    "    torch.manual_seed(42)\n",
    "    sample = model.generate('ROMEO:\\n', length=200, temperature=0.8)\n",
    "    print(f'\\n--- {name} ---')\n",
    "    print(sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9a0b1c2",
   "metadata": {},
   "source": [
    "```{admonition} Comparison Summary\n",
    ":class: important\n",
    "\n",
    "| Feature | Vanilla RNN | LSTM | GRU |\n",
    "|:--------|:-----------:|:----:|:---:|\n",
    "| Gates | 0 | 3 (forget, input, output) | 2 (update, reset) |\n",
    "| State vectors | 1 ($h_t$) | 2 ($h_t$, $C_t$) | 1 ($h_t$) |\n",
    "| Relative parameters | 1.0x | ~4x | ~3x |\n",
    "| Long-range memory | Poor | Excellent | Good |\n",
    "| Training speed | Fastest | Slowest | Middle |\n",
    "\n",
    "For this small task, both LSTM and GRU outperform the vanilla RNN. On larger datasets and longer sequences, the difference becomes even more dramatic.\n",
    "```"
   ]
  },
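  {
   "cell_type": "markdown",
   "id": "c7b6a5f4",
   "metadata": {},
   "source": [
    "As a quick check on the \"relative parameters\" row, we can count the recurrent layers alone (the embedding and output layers are identical across the three architectures):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7b6a5f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recurrent-layer parameter counts only (embedding and FC excluded)\n",
    "layers = [('RNN', nn.RNN(EMBED_SIZE, HIDDEN_SIZE)),\n",
    "          ('LSTM', nn.LSTM(EMBED_SIZE, HIDDEN_SIZE)),\n",
    "          ('GRU', nn.GRU(EMBED_SIZE, HIDDEN_SIZE))]\n",
    "base = sum(p.numel() for p in layers[0][1].parameters())\n",
    "for name, layer in layers:\n",
    "    n = sum(p.numel() for p in layer.parameters())\n",
    "    print(f'{name:5s}: {n:7,} parameters ({n / base:.1f}x the vanilla RNN)')"
   ]
  },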
  {
   "cell_type": "markdown",
   "id": "a0b1c2d4",
   "metadata": {},
   "source": [
    "## 35.9 Framework Corner\n",
    "\n",
    "````{admonition} Same Char-RNN in Other Frameworks\n",
    ":class: tip dropdown\n",
    "\n",
    "**TensorFlow / Keras:**\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Embedding(vocab_size, 32),\n",
    "    tf.keras.layers.LSTM(128, return_sequences=True),\n",
    "    tf.keras.layers.Dense(vocab_size)\n",
    "])\n",
    "model.compile(\n",
    "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "    optimizer='adam'\n",
    ")\n",
    "model.fit(x_train, y_train, epochs=10, batch_size=64)\n",
    "```\n",
    "\n",
    "**JAX / Flax:**\n",
    "```python\n",
    "import jax\n",
    "from flax import linen as fnn\n",
    "\n",
    "class CharLSTM(fnn.Module):\n",
    "    vocab_size: int\n",
    "    hidden_size: int = 128\n",
    "\n",
    "    @fnn.compact\n",
    "    def __call__(self, x):\n",
    "        x = fnn.Embed(self.vocab_size, 32)(x)\n",
    "        carry = fnn.LSTMCell.initialize_carry(\n",
    "            jax.random.PRNGKey(0), (x.shape[0],), self.hidden_size\n",
    "        )\n",
    "        lstm = fnn.LSTMCell(features=self.hidden_size)\n",
    "        for t in range(x.shape[1]):\n",
    "            carry, _ = lstm(carry, x[:, t])\n",
    "        return fnn.Dense(self.vocab_size)(carry[0])\n",
    "```\n",
    "\n",
    "The architecture is identical across frameworks --- only the API differs.\n",
    "````"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1c2d3e5",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 35.1.** Compute the perplexity of each model (RNN, LSTM, GRU) on a held-out portion of the Shakespeare text. Recall that perplexity $= \\exp(\\mathcal{L})$ where $\\mathcal{L}$ is the cross-entropy loss. Which model achieves the lowest perplexity? How does perplexity relate to the \"quality\" of generated text?\n",
    "\n",
    "**Exercise 35.2.** Modify the `CharRNN` model to use a **2-layer** LSTM (set `num_layers=2` in `nn.LSTM`). Does the additional depth improve the loss or generated text quality? Report the parameter count and training curves.\n",
    "\n",
    "**Exercise 35.3.** Implement **top-k sampling**: instead of sampling from the full vocabulary distribution, restrict sampling to the $k$ most probable characters. Compare generated text quality for $k \\in \\{5, 10, 20, 65\\}$ (where 65 = full vocabulary). How does top-k interact with temperature?\n",
    "\n",
    "**Exercise 35.4.** Train the LSTM on a different corpus of your choice (e.g., a Python source file, a novel, song lyrics). How does the generated text reflect the structure of the training data? What features does the model learn to reproduce?\n",
    "\n",
    "**Exercise 35.5.** The current model processes fixed-length windows independently. Implement **stateful training** where the hidden state from the end of one batch is passed as the initial state of the next batch (with gradient detaching). Does this improve the loss? Why would maintaining state across batches be beneficial?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2d3e4f6",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- **Character-level language modeling** is next-token prediction at the character level—the same objective that powers GPT and other large language models, at a much smaller scale.\n",
    "- The **Shakespeare dataset** (100K characters, ~65 unique characters) is sufficient to train a small LSTM that captures word boundaries, common vocabulary, speaker-turn structure, and rudimentary grammar.\n",
    "- **Temperature sampling** controls the diversity-quality tradeoff: low $T$ produces safe, repetitive text; high $T$ produces creative but potentially incoherent text.\n",
    "- **Gate activation visualization** reveals that LSTM gates learn interpretable roles: forget gates reset at sentence boundaries, input gates fire at word onsets.\n",
    "- **LSTM and GRU** consistently outperform vanilla RNNs in both loss and text quality, confirming the practical importance of gating mechanisms."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3e4f5a7",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "1. A. Karpathy, \"The Unreasonable Effectiveness of Recurrent Neural Networks,\" blog post, 2015. http://karpathy.github.io/2015/05/21/rnn-effectiveness/\n",
    "\n",
    "2. S. Hochreiter and J. Schmidhuber, \"Long short-term memory,\" *Neural Computation*, vol. 9, no. 8, pp. 1735–1780, 1997.\n",
    "\n",
    "3. K. Cho, B. van Merrienboer, C. Gulcehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio, \"Learning phrase representations using RNN encoder-decoder for statistical machine translation,\" in *Proceedings of EMNLP*, 2014.\n",
    "\n",
    "4. I. Sutskever, J. Martens, and G. Hinton, \"Generating text with recurrent neural networks,\" in *Proceedings of ICML*, pp. 1017–1024, 2011."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}