{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-0",
   "metadata": {},
   "source": [
    "# Chapter 33: Backpropagation Through Time\n",
    "\n",
    "In Chapter 16, we derived backpropagation for feedforward networks by applying\n",
    "the chain rule layer by layer. For recurrent networks, the same principle\n",
    "applies -- but the chain extends through **time**. This temporal unrolling\n",
    "reveals a fundamental problem: gradients can **vanish** or **explode**\n",
    "exponentially.\n",
    "\n",
    "The vanishing gradient problem, first identified by Hochreiter in his 1991\n",
    "diploma thesis and formally analyzed by Bengio, Simard & Frasconi (1994),\n",
    "explains why simple RNNs fail to learn long-range dependencies. Understanding\n",
    "this failure is essential -- it motivates the LSTM architecture that solved\n",
    "the problem and launched the modern era of sequence modeling.\n",
    "\n",
    "In this chapter we derive the **backpropagation through time** (BPTT) algorithm,\n",
    "prove why gradients vanish or explode, demonstrate the failure empirically on\n",
    "a \"remember the first character\" task, and introduce two practical mitigations:\n",
    "gradient clipping and truncated BPTT.\n",
    "\n",
    "```{admonition} Prerequisites\n",
    ":class: note\n",
    "This chapter builds directly on Chapter 16 (backpropagation derivation) and\n",
    "Chapter 32 (simple RNN). Familiarity with matrix norms and eigenvalues is\n",
    "helpful but not strictly required.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-1",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "from copy import deepcopy\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "plt.rcParams.update({\n",
    "    'figure.facecolor': '#FAF8F0',\n",
    "    'axes.facecolor': '#FAF8F0',\n",
    "    'font.size': 11,\n",
    "})\n",
    "\n",
    "# Project colour palette\n",
    "BLUE = '#3b82f6'\n",
    "BLUE_DARK = '#2563eb'\n",
    "GREEN = '#059669'\n",
    "GREEN_LIGHT = '#10b981'\n",
    "AMBER = '#d97706'\n",
    "RED = '#dc2626'\n",
    "BURGUNDY = '#8c2f39'\n",
    "PURPLE = '#7c3aed'\n",
    "GRAY = '#6b7280'\n",
    "\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "print('Imports loaded: numpy, torch, matplotlib')\n",
    "print(f'PyTorch version: {torch.__version__}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-2",
   "metadata": {},
   "source": [
    "## 33.1 Unrolling the RNN\n",
    "\n",
    "Recall the simple RNN equations from Chapter 32:\n",
    "\n",
    "$$h_t = \\tanh(W_h h_{t-1} + W_x x_t + b_h) \\tag{RNN-1}$$\n",
    "\n",
    "$$y_t = W_y h_t + b_y \\tag{RNN-2}$$\n",
    "\n",
    "When we process a sequence of length $T$, the RNN applies these equations\n",
    "$T$ times, with the same weights at each step. For the purpose of computing\n",
    "gradients, we can **unroll** the RNN into a feedforward network with $T$\n",
    "layers -- one per time step.\n",
    "\n",
    "```{admonition} Unrolling = Depth\n",
    ":class: important\n",
    "An RNN processing a sequence of length $T$ is equivalent, for gradient\n",
    "computation, to a feedforward network with $T$ layers that share weights.\n",
    "A sequence of length 100 becomes a 100-layer deep network. The depth of\n",
    "this unrolled network is the source of the vanishing/exploding gradient\n",
    "problem.\n",
    "```\n",
    "\n",
    "At each time step $t$, we may incur a loss $\\ell_t$ (e.g., cross-entropy between\n",
    "the predicted and actual next character). The total loss over the sequence is:\n",
    "\n",
    "$$L = \\sum_{t=1}^T \\ell_t$$"
   ]
  },
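  {
   "cell_type": "markdown",
   "id": "cell-2b",
   "metadata": {},
   "source": [
    "To make the unrolling concrete, here is a minimal sketch of the forward pass written\n",
    "as an explicit loop. The tensor names and sizes are illustrative only (they are not\n",
    "reused by later cells); the point is that the *same* $W_h$, $W_x$, $b_h$ are applied\n",
    "at every one of the $T$ steps.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "T, input_size, hidden_size = 100, 8, 32\n",
    "W_x = torch.randn(hidden_size, input_size) * 0.1\n",
    "W_h = torch.randn(hidden_size, hidden_size) * 0.1\n",
    "b_h = torch.zeros(hidden_size)\n",
    "\n",
    "x = torch.randn(T, input_size)   # one sequence of length T\n",
    "h = torch.zeros(hidden_size)     # h_0\n",
    "\n",
    "# Unrolled forward pass: T applications of equation (RNN-1) with tied weights --\n",
    "# for gradient purposes, a 100-layer feedforward network.\n",
    "for t in range(T):\n",
    "    h = torch.tanh(W_h @ h + W_x @ x[t] + b_h)\n",
    "```"
   ]
  },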
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Diagram: Folded RNN -> Unrolled computation graph\n",
    "from matplotlib.patches import FancyBboxPatch\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "# Left: Folded view\n",
    "ax = axes[0]\n",
    "ax.set_xlim(-1, 6)\n",
    "ax.set_ylim(-1, 6)\n",
    "ax.set_aspect('equal')\n",
    "\n",
    "rnn_box = FancyBboxPatch((1.5, 1.5), 2.5, 2.5, boxstyle='round,pad=0.2',\n",
    "                          facecolor=BLUE, edgecolor=BLUE_DARK, linewidth=2, alpha=0.25)\n",
    "ax.add_patch(rnn_box)\n",
    "ax.text(2.75, 2.75, 'RNN\\nCell', ha='center', va='center',\n",
    "        fontsize=15, fontweight='bold', color=BLUE_DARK)\n",
    "\n",
    "# Input\n",
    "ax.annotate('', xy=(2.75, 1.5), xytext=(2.75, 0),\n",
    "            arrowprops=dict(arrowstyle='->', lw=2, color='black'))\n",
    "ax.text(2.75, -0.3, '$x_t$', ha='center', fontsize=14, fontweight='bold')\n",
    "\n",
    "# Loss\n",
    "ax.annotate('', xy=(2.75, 5.5), xytext=(2.75, 4),\n",
    "            arrowprops=dict(arrowstyle='->', lw=2, color='black'))\n",
    "ax.text(2.75, 5.7, '$\\\\ell_t$', ha='center', fontsize=14, fontweight='bold')\n",
    "\n",
    "# Self-loop\n",
    "ax.annotate('', xy=(4.0, 3.5), xytext=(4.7, 2.75),\n",
    "            arrowprops=dict(arrowstyle='->', color=RED, lw=2.5,\n",
    "                          connectionstyle='arc3,rad=-0.8'))\n",
    "ax.text(5.1, 3.5, '$h_t$', ha='left', fontsize=13, fontweight='bold', color=RED)\n",
    "\n",
    "ax.set_title('Folded RNN', fontsize=13, fontweight='bold')\n",
    "ax.axis('off')\n",
    "\n",
    "# Right: Unrolled view with gradient flow arrows\n",
    "ax = axes[1]\n",
    "ax.set_xlim(-1, 14)\n",
    "ax.set_ylim(-2, 7)\n",
    "ax.set_aspect('equal')\n",
    "\n",
    "T_draw = 4\n",
    "x_positions = [1.5, 4.5, 7.5, 10.5]\n",
    "labels = ['1', '2', '...', 'T']\n",
    "\n",
    "for i, (px, lt) in enumerate(zip(x_positions, labels)):\n",
    "    box = FancyBboxPatch((px - 0.9, 1.5), 1.8, 1.8,\n",
    "                          boxstyle='round,pad=0.1',\n",
    "                          facecolor=BLUE, edgecolor=BLUE_DARK,\n",
    "                          linewidth=2, alpha=0.25)\n",
    "    ax.add_patch(box)\n",
    "    ax.text(px, 2.4, 'RNN', ha='center', va='center',\n",
    "            fontsize=10, fontweight='bold', color=BLUE_DARK)\n",
    "\n",
    "    # Input\n",
    "    ax.annotate('', xy=(px, 1.5), xytext=(px, 0.2),\n",
    "                arrowprops=dict(arrowstyle='->', lw=1.5, color='black'))\n",
    "    ax.text(px, -0.1, f'$x_{{{lt}}}$', ha='center', fontsize=12, fontweight='bold')\n",
    "\n",
    "    # Loss\n",
    "    ax.annotate('', xy=(px, 5.2), xytext=(px, 3.3),\n",
    "                arrowprops=dict(arrowstyle='->', lw=1.5, color='black'))\n",
    "    ax.text(px, 5.5, f'$\\\\ell_{{{lt}}}$', ha='center', fontsize=12, fontweight='bold')\n",
    "\n",
    "# Forward hidden state arrows\n",
    "for i in range(len(x_positions) - 1):\n",
    "    ax.annotate('', xy=(x_positions[i+1] - 0.9, 2.4),\n",
    "                xytext=(x_positions[i] + 0.9, 2.4),\n",
    "                arrowprops=dict(arrowstyle='->', color=RED, lw=2))\n",
    "    mid = (x_positions[i] + x_positions[i+1]) / 2\n",
    "    ax.text(mid, 2.9, f'$h_{{{labels[i]}}}$', ha='center',\n",
    "            fontsize=10, fontweight='bold', color=RED)\n",
    "\n",
    "# Initial h\n",
    "ax.annotate('', xy=(x_positions[0] - 0.9, 2.4), xytext=(-0.5, 2.4),\n",
    "            arrowprops=dict(arrowstyle='->', color=RED, lw=2))\n",
    "ax.text(-0.7, 2.9, '$h_0$', ha='center', fontsize=10, fontweight='bold', color=RED)\n",
    "\n",
    "# Backward gradient arrows (dashed)\n",
    "for i in range(len(x_positions) - 1, 0, -1):\n",
    "    ax.annotate('', xy=(x_positions[i-1] + 0.9, 1.7),\n",
    "                xytext=(x_positions[i] - 0.9, 1.7),\n",
    "                arrowprops=dict(arrowstyle='->', color=AMBER, lw=2,\n",
    "                              linestyle='dashed'))\n",
    "\n",
    "ax.text(6.0, -1.3, 'Gradient flow (backward)',\n",
    "        ha='center', fontsize=11, fontstyle='italic', color=AMBER)\n",
    "ax.annotate('', xy=(3, -1.0), xytext=(9, -1.0),\n",
    "            arrowprops=dict(arrowstyle='->', color=AMBER, lw=2, linestyle='dashed'))\n",
    "\n",
    "ax.set_title('Unrolled (T steps) with Gradient Flow', fontsize=13, fontweight='bold')\n",
    "ax.axis('off')\n",
    "\n",
    "fig.suptitle('RNN Unrolling for Backpropagation Through Time',\n",
    "             fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print('Each copy of the RNN cell shares the SAME weights.')\n",
    "print('Gradients flow backward through every time step (dashed arrows).')\n",
    "print('The longer the sequence, the deeper the effective network.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-4",
   "metadata": {},
   "source": [
    "## 33.2 BPTT Derivation\n",
    "\n",
    "We now derive the backpropagation through time algorithm, extending the\n",
    "chain rule analysis of Chapter 16 to the temporal dimension.\n",
    "\n",
    "### Setup\n",
    "\n",
    "Let $\\ell_t$ be the loss at time step $t$ (e.g., cross-entropy between the\n",
    "predicted next token and the ground truth). The total loss is $L = \\sum_{t=1}^T \\ell_t$.\n",
    "We need the gradients $\\frac{\\partial L}{\\partial W_h}$, $\\frac{\\partial L}{\\partial W_x}$,\n",
    "and $\\frac{\\partial L}{\\partial b_h}$ to update the shared parameters.\n",
    "\n",
    "### The Chain Through Time\n",
    "\n",
    "Since $W_h$ is used at *every* time step, its gradient accumulates contributions\n",
    "from all time steps:\n",
    "\n",
    "$$\\frac{\\partial L}{\\partial W_h} = \\sum_{t=1}^T \\frac{\\partial \\ell_t}{\\partial W_h}$$\n",
    "\n",
    "The loss $\\ell_t$ depends on $W_h$ through the chain:\n",
    "\n",
    "$$\\ell_t \\leftarrow y_t \\leftarrow h_t \\leftarrow h_{t-1} \\leftarrow \\cdots \\leftarrow h_1 \\leftarrow h_0$$\n",
    "\n",
    "Applying the chain rule:\n",
    "\n",
    "$$\\frac{\\partial \\ell_t}{\\partial W_h} = \\sum_{k=1}^t \\frac{\\partial \\ell_t}{\\partial h_t}\n",
    "\\underbrace{\\left(\\prod_{j=k+1}^t \\frac{\\partial h_j}{\\partial h_{j-1}}\\right)}_{\\text{temporal Jacobian product}}\n",
    "\\frac{\\partial h_k}{\\partial W_h}$$\n",
    "\n",
    "```{admonition} Theorem (BPTT Gradient)\n",
    ":class: note\n",
    "The gradient of the total loss with respect to the hidden-to-hidden weight\n",
    "matrix is:\n",
    "\n",
    "$$\\frac{\\partial L}{\\partial W_h} = \\sum_{t=1}^T \\sum_{k=1}^t\n",
    "\\frac{\\partial \\ell_t}{\\partial h_t}\n",
    "\\left(\\prod_{j=k+1}^t \\frac{\\partial h_j}{\\partial h_{j-1}}\\right)\n",
    "\\frac{\\partial h_k}{\\partial W_h}$$\n",
    "\n",
    "where the **temporal Jacobian** at each step is:\n",
    "\n",
    "$$\\frac{\\partial h_j}{\\partial h_{j-1}} = \\text{diag}\\left(1 - h_j^2\\right) W_h$$\n",
    "\n",
    "using the fact that $\\tanh'(z) = 1 - \\tanh^2(z)$ and $h_j = \\tanh(W_h h_{j-1} + W_x x_j + b_h)$.\n",
    "```\n",
    "\n",
    "### Connection to Chapter 16\n",
    "\n",
    "In Chapter 16, we derived four equations BP1--BP4 for feedforward networks.\n",
    "BPTT is the same chain rule, but applied to a network with **shared weights**\n",
    "across layers and **multiple loss terms** (one per time step):\n",
    "\n",
    "| Feedforward (Ch. 16) | Recurrent (BPTT) |\n",
    "|---|---|\n",
    "| One loss at the output | Loss at each time step |\n",
    "| Different $W^{(l)}$ per layer | Same $W_h$ at every step |\n",
    "| Chain through $L$ layers | Chain through $T$ time steps |\n",
    "| $\\delta^{(l)} = \\sigma'(z^{(l)}) \\odot (W^{(l+1)})^\\top \\delta^{(l+1)}$ | $\\delta_t = (1 - h_t^2) \\odot (W_h^\\top \\delta_{t+1} + \\frac{\\partial \\ell_t}{\\partial h_t})$ |\n",
    "\n",
    "### BPTT Algorithm\n",
    "\n",
    "```{admonition} Algorithm: Backpropagation Through Time\n",
    ":class: important\n",
    "\n",
    "**Input:** Sequence $x_1, \\ldots, x_T$; targets $y_1^*, \\ldots, y_T^*$; parameters $W_x, W_h, W_y, b_h, b_y$.\n",
    "\n",
    "**Forward pass:**\n",
    "1. Set $h_0 = \\mathbf{0}$\n",
    "2. For $t = 1, \\ldots, T$:\n",
    "   - $h_t = \\tanh(W_h h_{t-1} + W_x x_t + b_h)$\n",
    "   - $y_t = W_y h_t + b_y$\n",
    "   - Compute loss $\\ell_t = \\text{Loss}(y_t, y_t^*)$\n",
    "\n",
    "**Backward pass:**\n",
    "3. Initialize $\\delta_{T+1}^h = \\mathbf{0}$ (no future gradient)\n",
    "4. For $t = T, T-1, \\ldots, 1$:\n",
    "   - $\\delta_t^y = \\frac{\\partial \\ell_t}{\\partial y_t}$ (output gradient)\n",
    "   - $\\delta_t^h = W_y^\\top \\delta_t^y + W_h^\\top \\delta_{t+1}^h$ (total gradient at $h_t$)\n",
    "   - $\\delta_t^z = \\delta_t^h \\odot (1 - h_t^2)$ (through tanh)\n",
    "   - Accumulate: $\\Delta W_h \\mathrel{+}= \\delta_t^z \\, h_{t-1}^\\top$\n",
    "   - Accumulate: $\\Delta W_x \\mathrel{+}= \\delta_t^z \\, x_t^\\top$\n",
    "   - Accumulate: $\\Delta b_h \\mathrel{+}= \\delta_t^z$\n",
    "   - Pass backward: $\\delta_t^h = \\delta_t^z$ (for next iteration, used as $\\delta_{t+1}^h$... but we already computed $W_h^\\top \\delta_{t+1}^h$ above)\n",
    "\n",
    "**Update:** $W_h \\leftarrow W_h - \\eta \\, \\Delta W_h$, etc.\n",
    "```"
   ]
  },
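  {
   "cell_type": "markdown",
   "id": "cell-4b",
   "metadata": {},
   "source": [
    "The algorithm translates almost line for line into code. Below is a minimal NumPy\n",
    "sketch (not the training code used later in this chapter): it uses a squared-error\n",
    "loss $\\ell_t = \\tfrac{1}{2}\\|y_t - y_t^*\\|^2$ at every step so that\n",
    "$\\partial \\ell_t / \\partial y_t = y_t - y_t^*$, and all sizes are illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "T, d_in, d_h, d_out = 6, 3, 4, 2\n",
    "\n",
    "# Shared parameters\n",
    "W_x = rng.normal(0, 0.3, (d_h, d_in))\n",
    "W_h = rng.normal(0, 0.3, (d_h, d_h))\n",
    "W_y = rng.normal(0, 0.3, (d_out, d_h))\n",
    "b_h, b_y = np.zeros(d_h), np.zeros(d_out)\n",
    "\n",
    "xs = rng.normal(0, 1, (T, d_in))\n",
    "targets = rng.normal(0, 1, (T, d_out))\n",
    "\n",
    "# Forward pass: keep every hidden state for the backward pass\n",
    "hs, ys = [np.zeros(d_h)], []\n",
    "for t in range(T):\n",
    "    hs.append(np.tanh(W_h @ hs[-1] + W_x @ xs[t] + b_h))\n",
    "    ys.append(W_y @ hs[-1] + b_y)\n",
    "\n",
    "# Backward pass (BPTT)\n",
    "dW_h, dW_x, db_h = np.zeros_like(W_h), np.zeros_like(W_x), np.zeros_like(b_h)\n",
    "delta_z_next = np.zeros(d_h)                       # delta^z_{T+1} = 0\n",
    "for t in range(T, 0, -1):\n",
    "    delta_y = ys[t-1] - targets[t-1]               # dl_t / dy_t\n",
    "    delta_h = W_y.T @ delta_y + W_h.T @ delta_z_next\n",
    "    delta_z = delta_h * (1 - hs[t]**2)             # through tanh\n",
    "    dW_h += np.outer(delta_z, hs[t-1])\n",
    "    dW_x += np.outer(delta_z, xs[t-1])\n",
    "    db_h += delta_z\n",
    "    delta_z_next = delta_z                         # carried to step t-1\n",
    "\n",
    "print('||dL/dW_h|| =', np.linalg.norm(dW_h))\n",
    "```\n",
    "\n",
    "The same gradients could be checked against `torch.autograd` (Exercise 33.4 asks\n",
    "for exactly this check on a 3-step example)."
   ]
  },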
  {
   "cell_type": "markdown",
   "id": "cell-5",
   "metadata": {},
   "source": [
    "## 33.3 The Vanishing and Exploding Gradient Problem\n",
    "\n",
    "The BPTT formula contains the product of temporal Jacobians:\n",
    "\n",
    "$$\\prod_{j=k+1}^t \\frac{\\partial h_j}{\\partial h_{j-1}} = \\prod_{j=k+1}^t \\text{diag}(1 - h_j^2) \\, W_h$$\n",
    "\n",
    "This is a product of $t - k$ matrices. What happens to such a product as $t - k$\n",
    "grows large?\n",
    "\n",
    "```{admonition} Theorem (Gradient Magnitude Bound)\n",
    ":class: note\n",
    "\n",
    "Let $\\sigma_{\\max}$ denote the largest singular value of $W_h$, and let\n",
    "$\\gamma = \\max_z |\\tanh'(z)| = 1$. Then:\n",
    "\n",
    "$$\\left\\|\\prod_{j=k+1}^t \\text{diag}(1 - h_j^2) \\, W_h\\right\\| \\le (\\gamma \\cdot \\sigma_{\\max})^{t-k}$$\n",
    "\n",
    "**Proof sketch.** Each factor in the product has norm at most\n",
    "$\\|\\text{diag}(1 - h_j^2)\\| \\cdot \\|W_h\\| \\le \\gamma \\cdot \\sigma_{\\max}$.\n",
    "By sub-multiplicativity of the operator norm:\n",
    "\n",
    "$$\\left\\|\\prod_{j=k+1}^t A_j\\right\\| \\le \\prod_{j=k+1}^t \\|A_j\\| \\le (\\gamma \\cdot \\sigma_{\\max})^{t-k}$$\n",
    "\n",
    "Three regimes emerge:\n",
    "\n",
    "| Condition | Behavior | Consequence |\n",
    "|---|---|---|\n",
    "| $\\gamma \\cdot \\sigma_{\\max} < 1$ | Gradients decay as $(\\gamma \\sigma_{\\max})^{t-k}$ | **Vanishing**: early inputs are forgotten |\n",
    "| $\\gamma \\cdot \\sigma_{\\max} = 1$ | Gradients remain bounded | Ideal (but unstable equilibrium) |\n",
    "| $\\gamma \\cdot \\sigma_{\\max} > 1$ | Gradients grow as $(\\gamma \\sigma_{\\max})^{t-k}$ | **Exploding**: training diverges |\n",
    "```\n",
    "\n",
    "For $\\tanh$, $\\gamma = 1$, so the critical quantity is $\\sigma_{\\max}(W_h)$.\n",
    "In practice, the diagonal factors $\\text{diag}(1 - h_j^2)$ have entries in $[0, 1]$,\n",
    "so even when $\\sigma_{\\max}(W_h) = 1$, the gradient typically *vanishes*.\n",
    "\n",
    "```{admonition} Historical Note\n",
    ":class: note\n",
    "Hochreiter (1991) first identified the vanishing gradient problem in his\n",
    "diploma thesis (in German). Bengio, Simard & Frasconi (1994) published the\n",
    "first widely-read English analysis, proving that learning long-range\n",
    "dependencies with gradient descent is \"difficult\" -- the gradient signal\n",
    "decays exponentially with the temporal distance. This paper is one of the\n",
    "most cited in all of deep learning.\n",
    "```"
   ]
  },
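  {
   "cell_type": "markdown",
   "id": "cell-5b",
   "metadata": {},
   "source": [
    "Before working with full Jacobian matrices, a scalar back-of-the-envelope calculation\n",
    "already shows how quickly $(\\gamma \\sigma_{\\max})^{t-k}$ collapses or blows up over a\n",
    "modest number of steps (the three factors below are illustrative stand-ins):\n",
    "\n",
    "```python\n",
    "for factor in (0.9, 1.0, 1.1):\n",
    "    # (gamma * sigma_max)^(t - k) with t - k = 50 steps\n",
    "    print(f'{factor}: {factor ** 50:.3e}')\n",
    "```"
   ]
  },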
  {
   "cell_type": "markdown",
   "id": "cell-6",
   "metadata": {},
   "source": [
    "Let us verify the theory numerically. We create a random $W_h$ matrix and\n",
    "compute the product of Jacobians for increasing numbers of steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numerical verification: Jacobian product norms\n",
    "np.random.seed(42)\n",
    "\n",
    "hidden_size = 32\n",
    "\n",
    "# Case 1: sigma_max(W_h) < 1 (vanishing)\n",
    "W_h_small = np.random.randn(hidden_size, hidden_size) * 0.3\n",
    "sigma_max_small = np.linalg.svd(W_h_small, compute_uv=False)[0]\n",
    "\n",
    "# Case 2: sigma_max(W_h) > 1 (exploding)\n",
    "W_h_large = np.random.randn(hidden_size, hidden_size) * 0.7\n",
    "sigma_max_large = np.linalg.svd(W_h_large, compute_uv=False)[0]\n",
    "\n",
    "print(f'Case 1 (vanishing): sigma_max = {sigma_max_small:.3f}')\n",
    "print(f'Case 2 (exploding): sigma_max = {sigma_max_large:.3f}')\n",
    "print()\n",
    "\n",
    "# Compute ||product of Jacobians|| for T steps\n",
    "T_max = 50\n",
    "norms_small = []\n",
    "norms_large = []\n",
    "\n",
    "for T in range(1, T_max + 1):\n",
    "    # Simulate: use random hidden states for the diagonal\n",
    "    rng = np.random.default_rng(T)\n",
    "    \n",
    "    prod_small = np.eye(hidden_size)\n",
    "    prod_large = np.eye(hidden_size)\n",
    "    \n",
    "    for j in range(T):\n",
    "        # Random hidden state for tanh derivative\n",
    "        h_j = np.tanh(rng.normal(0, 1, hidden_size))\n",
    "        diag_j = np.diag(1 - h_j**2)\n",
    "        \n",
    "        prod_small = diag_j @ W_h_small @ prod_small\n",
    "        prod_large = diag_j @ W_h_large @ prod_large\n",
    "    \n",
    "    norms_small.append(np.linalg.norm(prod_small))\n",
    "    norms_large.append(np.linalg.norm(prod_large))\n",
    "\n",
    "print(f'After {T_max} steps:')\n",
    "print(f'  Vanishing case: ||Jacobian product|| = {norms_small[-1]:.2e}')\n",
    "print(f'  Exploding case: ||Jacobian product|| = {norms_large[-1]:.2e}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-8",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Plot: Jacobian product norms vs number of time steps\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))\n",
    "\n",
    "steps = list(range(1, T_max + 1))\n",
    "\n",
    "ax1.semilogy(steps, norms_small, color=BLUE, linewidth=2,\n",
    "             label=f'$\\\\sigma_{{\\\\max}} = {sigma_max_small:.2f}$')\n",
    "ax1.set_xlabel('Number of Time Steps (t - k)')\n",
    "ax1.set_ylabel('$\\\\|\\\\prod \\\\partial h_j / \\\\partial h_{{j-1}}\\\\|$')\n",
    "ax1.set_title('Vanishing Gradients', fontweight='bold')\n",
    "ax1.legend(fontsize=11)\n",
    "ax1.axhline(y=1, color=GRAY, linestyle='--', alpha=0.5)\n",
    "\n",
    "ax2.semilogy(steps, norms_large, color=RED, linewidth=2,\n",
    "             label=f'$\\\\sigma_{{\\\\max}} = {sigma_max_large:.2f}$')\n",
    "ax2.set_xlabel('Number of Time Steps (t - k)')\n",
    "ax2.set_ylabel('$\\\\|\\\\prod \\\\partial h_j / \\\\partial h_{{j-1}}\\\\|$')\n",
    "ax2.set_title('Exploding Gradients', fontweight='bold')\n",
    "ax2.legend(fontsize=11)\n",
    "ax2.axhline(y=1, color=GRAY, linestyle='--', alpha=0.5)\n",
    "\n",
    "fig.suptitle('Temporal Jacobian Product Norms',\n",
    "             fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print('Left: gradients shrink exponentially -> network forgets distant inputs.')\n",
    "print('Right: gradients grow exponentially -> training becomes unstable.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-9",
   "metadata": {},
   "source": [
    "## 33.4 Empirical Demonstration: The \"Remember the First\" Task\n",
    "\n",
    "To make the vanishing gradient problem tangible, we design a task that\n",
    "**requires** long-range memory: given a sequence of $T$ random characters,\n",
    "the model must output the **first** character at the very end.\n",
    "\n",
    "$$\\underbrace{x_1}_{\\text{remember this}}, \\;x_2, \\;x_3, \\;\\ldots, \\;x_T \\;\\longrightarrow \\;x_1$$\n",
    "\n",
    "For this task, the gradient from the loss at time $T$ must flow all the way\n",
    "back to time $1$. If gradients vanish over $T$ steps, the network cannot\n",
    "learn to solve this task.\n",
    "\n",
    "We train a simple RNN with `hidden_size=32` for 200 epochs on this task\n",
    "with varying sequence lengths $T \\in \\{5, 10, 20, 50\\}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"Remember the first character\" task\n",
    "\n",
    "class RememberFirstRNN(nn.Module):\n",
    "    \"\"\"RNN that must predict the first character of a sequence at the end.\"\"\"\n",
    "    \n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.rnn = nn.RNN(vocab_size, hidden_size, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_size, vocab_size)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        \"\"\"x: (batch, seq_len, vocab_size) one-hot encoded.\n",
    "        Returns logits for predicting the first character.\"\"\"\n",
    "        out, _ = self.rnn(x)         # (batch, seq_len, hidden_size)\n",
    "        last_h = out[:, -1, :]       # (batch, hidden_size) -- final step\n",
    "        logits = self.fc(last_h)     # (batch, vocab_size)\n",
    "        return logits\n",
    "\n",
    "\n",
    "def generate_remember_first_data(n_samples, seq_len, n_classes=8, seed=42):\n",
    "    \"\"\"Generate data for the 'remember the first' task.\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    X : tensor, shape (n_samples, seq_len, n_classes) -- one-hot\n",
    "    y : tensor, shape (n_samples,) -- class of first character\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    indices = rng.integers(0, n_classes, size=(n_samples, seq_len))\n",
    "    X = torch.zeros(n_samples, seq_len, n_classes)\n",
    "    for i in range(n_samples):\n",
    "        for t in range(seq_len):\n",
    "            X[i, t, indices[i, t]] = 1.0\n",
    "    y = torch.tensor(indices[:, 0], dtype=torch.long)  # first character\n",
    "    return X, y\n",
    "\n",
    "\n",
    "print('RememberFirstRNN class and data generator defined.')\n",
    "print('Task: given a sequence of T random characters, predict the first one.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on different sequence lengths\n",
    "seq_lengths = [5, 10, 20, 50]\n",
    "n_classes = 8\n",
    "hidden_size = 32\n",
    "n_train = 512\n",
    "n_test = 128\n",
    "n_epochs = 200\n",
    "batch_size = 64\n",
    "\n",
    "results = {}  # seq_len -> {'losses': [...], 'accs': [...], 'grad_norms': [...]}\n",
    "\n",
    "for seq_len in seq_lengths:\n",
    "    torch.manual_seed(42)\n",
    "    \n",
    "    # Generate data\n",
    "    X_train, y_train = generate_remember_first_data(n_train, seq_len, n_classes, seed=42)\n",
    "    X_test, y_test = generate_remember_first_data(n_test, seq_len, n_classes, seed=99)\n",
    "    \n",
    "    # Create model\n",
    "    model = RememberFirstRNN(n_classes, hidden_size)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    \n",
    "    losses = []\n",
    "    accs = []\n",
    "    grad_norms_wh = []\n",
    "    \n",
    "    for epoch in range(n_epochs):\n",
    "        model.train()\n",
    "        epoch_loss = 0.0\n",
    "        n_batches = 0\n",
    "        \n",
    "        # Mini-batch training\n",
    "        perm = torch.randperm(n_train)\n",
    "        for start in range(0, n_train, batch_size):\n",
    "            idx = perm[start:start+batch_size]\n",
    "            xb = X_train[idx]\n",
    "            yb = y_train[idx]\n",
    "            \n",
    "            logits = model(xb)\n",
    "            loss = loss_fn(logits, yb)\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            \n",
    "            # Record gradient norm of W_hh\n",
    "            wh_grad = model.rnn.weight_hh_l0.grad\n",
    "            if wh_grad is not None:\n",
    "                grad_norms_wh.append(wh_grad.norm().item())\n",
    "            \n",
    "            optimizer.step()\n",
    "            epoch_loss += loss.item()\n",
    "            n_batches += 1\n",
    "        \n",
    "        losses.append(epoch_loss / n_batches)\n",
    "        \n",
    "        # Test accuracy\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            test_logits = model(X_test)\n",
    "            test_preds = test_logits.argmax(dim=1)\n",
    "            acc = (test_preds == y_test).float().mean().item()\n",
    "        accs.append(acc)\n",
    "    \n",
    "    results[seq_len] = {\n",
    "        'losses': losses,\n",
    "        'accs': accs,\n",
    "        'grad_norms': grad_norms_wh\n",
    "    }\n",
    "    \n",
    "    print(f'T={seq_len:3d}: final acc = {accs[-1]:.1%}, '\n",
    "          f'final loss = {losses[-1]:.3f}, '\n",
    "          f'chance = {1/n_classes:.1%}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-12",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Plot: accuracy vs sequence length and gradient norms\n",
    "colors_seq = {5: BLUE, 10: GREEN, 20: AMBER, 50: RED}\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))\n",
    "\n",
    "# Panel 1: Training loss\n",
    "ax = axes[0]\n",
    "for T in seq_lengths:\n",
    "    ax.plot(results[T]['losses'], color=colors_seq[T], linewidth=1.5,\n",
    "            label=f'T={T}')\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('Loss')\n",
    "ax.set_title('Training Loss', fontweight='bold')\n",
    "ax.legend()\n",
    "chance_loss = -np.log(1/n_classes)\n",
    "ax.axhline(y=chance_loss, color=GRAY, linestyle='--', alpha=0.5,\n",
    "           label='Chance')\n",
    "\n",
    "# Panel 2: Test accuracy\n",
    "ax = axes[1]\n",
    "for T in seq_lengths:\n",
    "    ax.plot(results[T]['accs'], color=colors_seq[T], linewidth=1.5,\n",
    "            label=f'T={T}')\n",
    "ax.axhline(y=1/n_classes, color=GRAY, linestyle='--', alpha=0.5,\n",
    "           label='Chance (12.5%)')\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('Accuracy')\n",
    "ax.set_title('Test Accuracy', fontweight='bold')\n",
    "ax.set_ylim(0, 1.05)\n",
    "ax.legend()\n",
    "\n",
    "# Panel 3: Final accuracy vs sequence length\n",
    "ax = axes[2]\n",
    "final_accs = [results[T]['accs'][-1] for T in seq_lengths]\n",
    "bar_colors = [colors_seq[T] for T in seq_lengths]\n",
    "bars = ax.bar([str(T) for T in seq_lengths], final_accs, color=bar_colors,\n",
    "              edgecolor='white', linewidth=1.5)\n",
    "ax.axhline(y=1/n_classes, color=GRAY, linestyle='--', alpha=0.5,\n",
    "           label='Chance')\n",
    "ax.set_xlabel('Sequence Length T')\n",
    "ax.set_ylabel('Final Test Accuracy')\n",
    "ax.set_title('Accuracy vs Sequence Length', fontweight='bold')\n",
    "ax.set_ylim(0, 1.05)\n",
    "for bar, acc in zip(bars, final_accs):\n",
    "    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,\n",
    "            f'{acc:.0%}', ha='center', va='bottom', fontweight='bold', fontsize=11)\n",
    "ax.legend()\n",
    "\n",
    "fig.suptitle('\"Remember the First Character\" Task: Simple RNN Performance',\n",
    "             fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print('Short sequences (T=5, 10): RNN can learn to remember the first character.')\n",
    "print('Long sequences (T=20, 50): accuracy collapses toward chance level.')\n",
    "print('This is the vanishing gradient problem in action.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-13",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Gradient norm at each time step (for a single example)\n",
    "# We compute the gradient of the loss w.r.t. the hidden state at each step\n",
    "\n",
    "torch.manual_seed(42)\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))\n",
    "\n",
    "for plot_idx, seq_len in enumerate([10, 50]):\n",
    "    model = RememberFirstRNN(n_classes, hidden_size)\n",
    "    X_demo, y_demo = generate_remember_first_data(1, seq_len, n_classes, seed=42)\n",
    "    \n",
    "    # Manual forward pass to get hidden states with gradients\n",
    "    x_input = X_demo  # (1, seq_len, n_classes)\n",
    "    h = torch.zeros(1, 1, hidden_size)\n",
    "    \n",
    "    # Store hidden states\n",
    "    hidden_states = []\n",
    "    rnn_cell = model.rnn\n",
    "    \n",
    "    # Use the RNN layer step by step\n",
    "    h_t = torch.zeros(1, 1, hidden_size, requires_grad=True)\n",
    "    all_h = []\n",
    "    \n",
    "    # Forward through RNN step by step\n",
    "    out, _ = model.rnn(x_input, h_t)\n",
    "    last_h = out[:, -1, :]\n",
    "    logits = model.fc(last_h)\n",
    "    loss = nn.CrossEntropyLoss()(logits, y_demo)\n",
    "    loss.backward()\n",
    "    \n",
    "    # Compute gradient norms at each time step using hooks\n",
    "    # Alternative: compute numerically by looking at how much the loss changes\n",
    "    # when we perturb h_t\n",
    "    model2 = RememberFirstRNN(n_classes, hidden_size)\n",
    "    # Copy weights\n",
    "    model2.load_state_dict(model.state_dict())\n",
    "    \n",
    "    grad_norms_per_step = []\n",
    "    \n",
    "    for t_probe in range(seq_len):\n",
    "        # Forward pass, but make hidden state at step t_probe require grad\n",
    "        model2.eval()\n",
    "        x_in = X_demo.clone()\n",
    "        \n",
    "        # Manually unroll to inject gradient tracking at step t_probe\n",
    "        W_ih = model2.rnn.weight_ih_l0  # (hidden_size, input_size)\n",
    "        W_hh = model2.rnn.weight_hh_l0  # (hidden_size, hidden_size)\n",
    "        b_ih = model2.rnn.bias_ih_l0\n",
    "        b_hh = model2.rnn.bias_hh_l0\n",
    "        \n",
    "        h_cur = torch.zeros(hidden_size)\n",
    "        h_list = []\n",
    "        \n",
    "        for t in range(seq_len):\n",
    "            x_t = x_in[0, t]  # (n_classes,)\n",
    "            z = W_ih @ x_t + b_ih + W_hh @ h_cur + b_hh\n",
    "            h_cur = torch.tanh(z)\n",
    "            if t == t_probe:\n",
    "                h_cur = h_cur.detach().requires_grad_(True)\n",
    "                h_probe = h_cur\n",
    "            h_list.append(h_cur)\n",
    "        \n",
    "        final_logits = model2.fc(h_list[-1].unsqueeze(0))\n",
    "        probe_loss = nn.CrossEntropyLoss()(final_logits, y_demo)\n",
    "        probe_loss.backward()\n",
    "        \n",
    "        grad_norm = h_probe.grad.norm().item()\n",
    "        grad_norms_per_step.append(grad_norm)\n",
    "    \n",
    "    ax = axes[plot_idx]\n",
    "    ax.semilogy(range(seq_len), grad_norms_per_step,\n",
    "                color=BLUE if seq_len == 10 else RED,\n",
    "                linewidth=2, marker='o', markersize=3)\n",
    "    ax.set_xlabel('Time Step t')\n",
    "    ax.set_ylabel('$\\\\|\\\\partial L / \\\\partial h_t\\\\|$ (log scale)')\n",
    "    ax.set_title(f'Gradient Norm at Each Step (T={seq_len})',\n",
    "                 fontweight='bold')\n",
    "    ax.axvline(x=0, color=GREEN, linestyle='--', alpha=0.5, label='Target info (t=0)')\n",
    "    ax.legend()\n",
    "\n",
    "fig.suptitle('Gradient Signal Decay Through Time',\n",
    "             fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print('The gradient at t=0 (where the target information resides) is much')\n",
    "print('smaller than at t=T-1 (where the loss is computed).')\n",
    "print('For T=50, the signal at t=0 is essentially zero -- the network')\n",
    "print('cannot learn from the first character.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-14",
   "metadata": {},
   "source": [
    "```{admonition} The Key Insight\n",
    ":class: danger\n",
    "The \"remember the first character\" experiment reveals the core failure mode\n",
    "of simple RNNs: the gradient signal from the loss at the end of the sequence\n",
    "**decays exponentially** as it propagates backward through time. For long\n",
    "sequences, the gradient at early time steps is effectively zero, making it\n",
    "impossible to learn dependencies that span many steps.\n",
    "\n",
    "This is not a matter of training longer or using a better optimizer -- it is\n",
    "a **structural** limitation of the simple RNN architecture. Overcoming it\n",
    "requires architectural changes (LSTM, GRU) that create alternative gradient\n",
    "pathways through the network.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-15",
   "metadata": {},
   "source": [
    "## 33.5 Gradient Clipping\n",
    "\n",
    "While the vanishing gradient problem has no simple fix within the simple RNN\n",
    "architecture, the **exploding** gradient problem can be mitigated with a\n",
    "straightforward technique: **gradient clipping**.\n",
    "\n",
    "The idea, introduced by Pascanu, Mikolov & Bengio (2013) in their paper\n",
    "*\"On the difficulty of training recurrent neural networks\"*, is to rescale\n",
    "the gradient whenever its norm exceeds a threshold $\\theta$:\n",
    "\n",
    "$$\\hat{g} = \\begin{cases}\n",
    "\\frac{\\theta}{\\|g\\|} g & \\text{if } \\|g\\| > \\theta \\\\\n",
    "g & \\text{otherwise}\n",
    "\\end{cases}$$\n",
    "\n",
    "In PyTorch, this is a single line:\n",
    "```python\n",
    "torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=theta)\n",
    "```\n",
    "\n",
    "```{admonition} Clipping Prevents Explosion but Not Vanishing\n",
    ":class: warning\n",
    "Gradient clipping is a practical necessity for training RNNs, but it only\n",
    "addresses one half of the problem. It prevents gradients from exploding\n",
    "(causing NaN losses or wild parameter updates), but it does nothing to\n",
    "amplify gradients that have vanished. A clipped gradient of $10^{-15}$ is\n",
    "still $10^{-15}$.\n",
    "```\n",
    "\n",
    "Let us demonstrate gradient clipping on a sequence where the exploding\n",
    "gradient problem would otherwise cause training to diverge."
   ]
  },
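  {
   "cell_type": "markdown",
   "id": "cell-15b",
   "metadata": {},
   "source": [
    "First, the clipping rule written out by hand -- a minimal sketch operating on a list\n",
    "of gradient tensors (the helper name `clip_by_global_norm` is our own; the experiment\n",
    "below uses PyTorch's built-in `clip_grad_norm_` instead):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def clip_by_global_norm(grads, theta):\n",
    "    \"\"\"Rescale gradients in place so their joint L2 norm is at most theta.\"\"\"\n",
    "    total_norm = sum(float(g.norm()) ** 2 for g in grads) ** 0.5\n",
    "    if total_norm > theta:\n",
    "        for g in grads:\n",
    "            g.mul_(theta / total_norm)   # g_hat = (theta / ||g||) * g\n",
    "    return total_norm\n",
    "\n",
    "# Typical use after loss.backward():\n",
    "#   clip_by_global_norm([p.grad for p in model.parameters() if p.grad is not None], 5.0)\n",
    "```"
   ]
  },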
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demonstrate gradient clipping\n",
    "torch.manual_seed(42)\n",
    "\n",
    "seq_len_clip = 20\n",
    "X_clip, y_clip = generate_remember_first_data(256, seq_len_clip, n_classes, seed=42)\n",
    "\n",
    "# Train with and without gradient clipping using a larger learning rate\n",
    "# to provoke instability\n",
    "configs = [\n",
    "    ('No clipping', None, 0.01),\n",
    "    ('Clip norm=5', 5.0, 0.01),\n",
    "    ('Clip norm=1', 1.0, 0.01),\n",
    "]\n",
    "\n",
    "clip_results = {}\n",
    "\n",
    "for name, clip_val, lr in configs:\n",
    "    torch.manual_seed(42)\n",
    "    model = RememberFirstRNN(n_classes, hidden_size)\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    \n",
    "    losses = []\n",
    "    grad_norms = []\n",
    "    \n",
    "    for epoch in range(100):\n",
    "        model.train()\n",
    "        logits = model(X_clip)\n",
    "        loss = loss_fn(logits, y_clip)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        \n",
    "        # Record gradient norm BEFORE clipping\n",
    "        total_norm = 0.0\n",
    "        for p in model.parameters():\n",
    "            if p.grad is not None:\n",
    "                total_norm += p.grad.data.norm(2).item() ** 2\n",
    "        total_norm = total_norm ** 0.5\n",
    "        grad_norms.append(total_norm)\n",
    "        \n",
    "        # Apply clipping if specified\n",
    "        if clip_val is not None:\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_val)\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        loss_val = loss.item()\n",
    "        if np.isnan(loss_val) or np.isinf(loss_val):\n",
    "            losses.append(float('nan'))\n",
    "            # Fill remaining with NaN\n",
    "            losses.extend([float('nan')] * (99 - epoch))\n",
    "            grad_norms.extend([float('nan')] * (99 - epoch))\n",
    "            break\n",
    "        losses.append(loss_val)\n",
    "    \n",
    "    clip_results[name] = {'losses': losses, 'grad_norms': grad_norms}\n",
    "    final_loss = losses[-1] if not np.isnan(losses[-1]) else 'DIVERGED'\n",
    "    print(f'{name:20s}: final loss = {final_loss}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-17",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Plot: effect of gradient clipping\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4.5))\n",
    "\n",
    "clip_colors = {'No clipping': RED, 'Clip norm=5': AMBER, 'Clip norm=1': GREEN}\n",
    "\n",
    "for name, res in clip_results.items():\n",
    "    valid_losses = [l for l in res['losses'] if not np.isnan(l)]\n",
    "    ax1.plot(range(len(valid_losses)), valid_losses,\n",
    "             color=clip_colors[name], linewidth=1.5, label=name)\n",
    "    \n",
    "    valid_gn = [g for g in res['grad_norms'] if not np.isnan(g)]\n",
    "    ax2.semilogy(range(len(valid_gn)), valid_gn,\n",
    "                 color=clip_colors[name], linewidth=1.5, label=name, alpha=0.8)\n",
    "\n",
    "ax1.set_xlabel('Epoch')\n",
    "ax1.set_ylabel('Loss')\n",
    "ax1.set_title('Training Loss', fontweight='bold')\n",
    "ax1.legend()\n",
    "ax1.set_ylim(0, 5)\n",
    "\n",
    "ax2.set_xlabel('Epoch')\n",
    "ax2.set_ylabel('Gradient Norm (before clipping)')\n",
    "ax2.set_title('Gradient Norms During Training', fontweight='bold')\n",
    "ax2.legend()\n",
    "\n",
    "fig.suptitle('Effect of Gradient Clipping on RNN Training',\n",
    "             fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print('Gradient clipping prevents loss spikes and divergence.')\n",
    "print('However, it does NOT help the network learn long-range dependencies.')\n",
    "print('The vanishing gradient problem remains -- only explosion is tamed.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-18",
   "metadata": {},
   "source": [
    "## 33.6 Truncated BPTT\n",
    "\n",
    "Full BPTT propagates gradients through the entire sequence of length $T$.\n",
    "This has two costs:\n",
    "\n",
    "1. **Memory:** We must store all $T$ hidden states for the backward pass.\n",
    "2. **Time:** The backward pass is $O(T)$, which can be slow for long sequences.\n",
    "\n",
    "**Truncated backpropagation through time** (TBPTT) addresses both costs by\n",
    "limiting the backward pass to only $K$ steps, where $K \\ll T$.\n",
    "\n",
    "### How It Works\n",
    "\n",
    "Instead of backpropagating through the entire sequence:\n",
    "\n",
    "1. Process the sequence in chunks of $K$ time steps.\n",
    "2. After each chunk, compute the loss and backpropagate through the $K$ steps.\n",
    "3. **Detach** the hidden state before starting the next chunk, severing\n",
    "   the gradient connection to earlier time steps.\n",
    "\n",
    "In PyTorch, detaching is a single operation: `h = h.detach()`.\n",
    "\n",
    "```{admonition} The Truncation Trade-off\n",
    ":class: important\n",
    "Truncated BPTT trades **long-range gradient flow** for **computational\n",
    "efficiency**. With truncation length $K$:\n",
    "\n",
    "- The network can still *use* long-range information (via the hidden state,\n",
    "  which propagates forward without truncation).\n",
    "- But it can only *learn* from dependencies up to $K$ steps apart\n",
    "  (because gradients are cut beyond $K$ steps).\n",
    "\n",
    "Choosing $K$ is an engineering judgment: too small and the network cannot\n",
    "learn medium-range patterns; too large and you lose the computational\n",
    "benefits (and still face vanishing gradients).\n",
    "```"
   ]
  },
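  {
   "cell_type": "markdown",
   "id": "cell-18b",
   "metadata": {},
   "source": [
    "Schematically, a TBPTT training loop with per-chunk updates looks like the following\n",
    "minimal, self-contained sketch (a dummy regression task with illustrative sizes; it is\n",
    "not the experiment run below, which adapts the idea to the \"remember the first\" task):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "T, K, input_size, hidden_size = 40, 10, 8, 16\n",
    "cell = nn.RNNCell(input_size, hidden_size)\n",
    "readout = nn.Linear(hidden_size, 1)\n",
    "optimizer = torch.optim.Adam(list(cell.parameters()) + list(readout.parameters()), lr=1e-3)\n",
    "\n",
    "x = torch.randn(T, 1, input_size)     # (seq_len, batch=1, input_size)\n",
    "target = torch.randn(T, 1, 1)\n",
    "\n",
    "h = torch.zeros(1, hidden_size)       # carried forward across the whole sequence\n",
    "for chunk_start in range(0, T, K):\n",
    "    chunk_loss = 0.0\n",
    "    for t in range(chunk_start, min(chunk_start + K, T)):\n",
    "        h = cell(x[t], h)                                    # forward one step\n",
    "        chunk_loss = chunk_loss + ((readout(h) - target[t]) ** 2).mean()\n",
    "    optimizer.zero_grad()\n",
    "    chunk_loss.backward()             # gradients reach back at most K steps\n",
    "    optimizer.step()\n",
    "    h = h.detach()                    # sever the graph before the next chunk\n",
    "```"
   ]
  },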
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demonstrate truncated BPTT\n",
    "def train_with_tbptt(model, X, y, K, n_epochs=100, lr=0.005):\n",
    "    \"\"\"Train an RNN using truncated BPTT with truncation length K.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    model : RememberFirstRNN\n",
    "    X : tensor (batch, seq_len, n_classes)\n",
    "    y : tensor (batch,)\n",
    "    K : int, truncation length (0 = full BPTT)\n",
    "    \"\"\"\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    seq_len = X.shape[1]\n",
    "    losses = []\n",
    "    \n",
    "    for epoch in range(n_epochs):\n",
    "        model.train()\n",
    "        \n",
    "        if K == 0 or K >= seq_len:\n",
    "            # Full BPTT\n",
    "            logits = model(X)\n",
    "            loss = loss_fn(logits, y)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
    "            optimizer.step()\n",
    "            losses.append(loss.item())\n",
    "        else:\n",
    "            # Truncated BPTT\n",
    "            W_ih = model.rnn.weight_ih_l0\n",
    "            W_hh = model.rnn.weight_hh_l0\n",
    "            b_ih = model.rnn.bias_ih_l0\n",
    "            b_hh = model.rnn.bias_hh_l0\n",
    "            \n",
    "            batch_size_local = X.shape[0]\n",
    "            h = torch.zeros(batch_size_local, model.hidden_size)\n",
    "            \n",
    "            total_loss = 0.0\n",
    "            steps_in_chunk = 0\n",
    "            \n",
    "            for t in range(seq_len):\n",
    "                x_t = X[:, t, :]  # (batch, n_classes)\n",
    "                z = x_t @ W_ih.t() + b_ih + h @ W_hh.t() + b_hh\n",
    "                h = torch.tanh(z)\n",
    "                steps_in_chunk += 1\n",
    "                \n",
    "                # At chunk boundaries (or end), detach\n",
    "                if steps_in_chunk >= K and t < seq_len - 1:\n",
    "                    h = h.detach()\n",
    "                    steps_in_chunk = 0\n",
    "            \n",
    "            # Final prediction\n",
    "            logits = model.fc(h)\n",
    "            loss = loss_fn(logits, y)\n",
    "            \n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)\n",
    "            optimizer.step()\n",
    "            losses.append(loss.item())\n",
    "    \n",
    "    return losses\n",
    "\n",
    "\n",
    "# Compare full BPTT vs truncated on T=20 sequence\n",
    "torch.manual_seed(42)\n",
    "seq_len_tbptt = 20\n",
    "X_tbptt, y_tbptt = generate_remember_first_data(256, seq_len_tbptt, n_classes, seed=42)\n",
    "X_test_tbptt, y_test_tbptt = generate_remember_first_data(128, seq_len_tbptt, n_classes, seed=99)\n",
    "\n",
    "tbptt_results = {}\n",
    "truncation_lengths = [0, 20, 10, 5]  # 0 = full BPTT\n",
    "labels_tbptt = ['Full BPTT', 'K=20 (full)', 'K=10', 'K=5']\n",
    "\n",
    "for K, label in zip(truncation_lengths, labels_tbptt):\n",
    "    torch.manual_seed(42)\n",
    "    model = RememberFirstRNN(n_classes, hidden_size)\n",
    "    losses = train_with_tbptt(model, X_tbptt, y_tbptt, K, n_epochs=150, lr=0.005)\n",
    "    \n",
    "    # Test accuracy\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        test_logits = model(X_test_tbptt)\n",
    "        acc = (test_logits.argmax(1) == y_test_tbptt).float().mean().item()\n",
    "    \n",
    "    tbptt_results[label] = {'losses': losses, 'acc': acc}\n",
    "    print(f'{label:15s}: final acc = {acc:.1%}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-20",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Plot truncated BPTT comparison\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "\n",
    "tbptt_colors = {'Full BPTT': BLUE, 'K=20 (full)': GREEN, 'K=10': AMBER, 'K=5': RED}\n",
    "\n",
    "for label, res in tbptt_results.items():\n",
    "    ax.plot(res['losses'], color=tbptt_colors[label], linewidth=1.5,\n",
    "            label=f'{label} (acc={res[\"acc\"]:.0%})')\n",
    "\n",
    "ax.set_xlabel('Epoch', fontsize=12)\n",
    "ax.set_ylabel('Loss', fontsize=12)\n",
    "ax.set_title('Truncated BPTT: Training Loss for \"Remember First\" (T=20)',\n",
    "             fontweight='bold', fontsize=13)\n",
    "ax.legend(fontsize=11)\n",
    "ax.axhline(y=-np.log(1/n_classes), color=GRAY, linestyle='--', alpha=0.5)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print('Truncated BPTT with K < T cuts off gradient flow to early time steps.')\n",
    "print('For the \"remember first\" task, K=5 is too short -- the network cannot')\n",
    "print('learn to propagate information from step 0 to step 19.')\n",
    "print()\n",
    "print('In practice, truncated BPTT is used with K=20-200 to balance efficiency')\n",
    "print('and learning range. But for truly long-range dependencies, architectural')\n",
    "print('solutions (LSTM, Transformer) are needed.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-21",
   "metadata": {},
   "source": [
    "## Looking Ahead\n",
    "\n",
    "The vanishing gradient problem is not just a practical nuisance -- it is a\n",
    "**theoretical barrier** that limits what simple RNNs can learn. Hochreiter\n",
    "and Schmidhuber recognized this in 1997 and proposed the **Long Short-Term\n",
    "Memory** (LSTM) architecture, which introduces gating mechanisms that create\n",
    "a \"gradient highway\" through time, allowing information to flow across\n",
    "hundreds of time steps without decay.\n",
    "\n",
    "The LSTM is the subject of our next chapter. Understanding *why* it works\n",
    "requires understanding *why* the simple RNN fails -- and that is precisely\n",
    "what we have established in this chapter:\n",
    "\n",
    "1. Gradients are products of Jacobians along the time axis.\n",
    "2. These products shrink (or grow) exponentially with sequence length.\n",
    "3. Gradient clipping fixes explosion but not vanishing.\n",
    "4. Truncated BPTT reduces cost but limits the learning horizon.\n",
    "\n",
    "The LSTM's solution is elegant: instead of multiplying by the same $W_h$\n",
    "at every step, it learns **gates** that control what to remember, what to\n",
    "forget, and what to output -- creating a cell state that can carry\n",
    "information across arbitrary distances."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-22",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 33.1.** Derive the BPTT gradient for $W_x$ (the input-to-hidden\n",
    "weight matrix). Show that it has the same product-of-Jacobians structure as\n",
    "the gradient for $W_h$, and explain why the vanishing gradient problem\n",
    "affects $W_x$ equally.\n",
    "\n",
    "**Exercise 33.2.** Consider a linear RNN (no activation function):\n",
    "$h_t = W_h h_{t-1} + W_x x_t$. Show that the temporal Jacobian product\n",
    "simplifies to $W_h^{t-k}$. If $W_h$ has eigenvalues $\\lambda_1, \\ldots, \\lambda_n$,\n",
    "express the gradient in terms of $\\lambda_i^{t-k}$ and discuss when\n",
    "vanishing/exploding occurs.\n",
    "\n",
    "**Exercise 33.3.** Run the \"remember the first\" experiment with\n",
    "`hidden_size=128` instead of 32. Does increasing the hidden size help with\n",
    "the vanishing gradient problem? Why or why not?\n",
    "\n",
    "**Exercise 33.4.** Implement BPTT manually for a 3-step RNN (without using\n",
    "`loss.backward()`). Given a concrete $W_h$, $W_x$, $b_h$, input sequence\n",
    "$(x_1, x_2, x_3)$, and target $y_3$:\n",
    "- Compute the forward pass.\n",
    "- Compute $\\frac{\\partial L}{\\partial W_h}$ by hand using the BPTT formula.\n",
    "- Verify your result against PyTorch's autograd.\n",
    "\n",
    "**Exercise 33.5.** Pascanu et al. (2013) also propose **gradient norm\n",
    "rescaling** as an alternative to clipping: instead of clipping to a maximum\n",
    "norm, rescale so the gradient always has a fixed norm $\\theta$. Implement\n",
    "this and compare training dynamics to standard clipping on the\n",
    "\"remember the first\" task with $T=20$."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}