{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a0b1c2d3",
   "metadata": {},
   "source": [
    "# Chapter 34: LSTM — The Gating Revolution\n",
    "\n",
    "The vanishing gradient problem identified by Hochreiter (1991) and Bengio et al. (1994) seemed to doom recurrent networks. In 1997, Hochreiter and Schmidhuber proposed an elegant solution: instead of fighting the gradient decay, engineer a pathway where gradients can flow unchanged.\n",
    "\n",
    "The **Long Short-Term Memory** (LSTM) network introduces a separate *cell state* $C_t$ that carries information forward through time via additive updates, bypassing the multiplicative bottleneck that causes gradients to vanish in vanilla RNNs. Three learnable *gates* control the flow of information into, out of, and within this cell state. The result is a network that can learn dependencies spanning hundreds of time steps—something that was practically impossible with the architectures we studied in previous chapters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1c2d3e4",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "\n",
    "BLUE = '#3b82f6'\n",
    "GREEN = '#059669'\n",
    "RED = '#dc2626'\n",
    "AMBER = '#d97706'\n",
    "INDIGO = '#4f46e5'\n",
    "\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "print('PyTorch version:', torch.__version__)\n",
    "print('Device:', 'cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2d3e4f5",
   "metadata": {},
   "source": [
    "## 34.1 The Constant Error Carousel\n",
    "\n",
    "Recall the fundamental problem with vanilla RNNs: the hidden state update\n",
    "\n",
    "$$h_t = \\tanh(W_{hh} h_{t-1} + W_{xh} x_t + b)$$\n",
    "\n",
    "involves a **multiplicative** interaction with $W_{hh}$ at every time step. When we backpropagate through $T$ steps, the gradient includes the product\n",
    "\n",
    "$$\\frac{\\partial h_T}{\\partial h_1} = \\prod_{t=2}^{T} \\frac{\\partial h_t}{\\partial h_{t-1}} = \\prod_{t=2}^{T} \\text{diag}(1 - h_t^2) \\, W_{hh}$$\n",
    "\n",
    "If the largest singular value of $W_{hh}$ is less than 1, this product vanishes exponentially. If it is greater than 1, the product explodes.\n",
    "\n",
    "Hochreiter's key insight was to replace this multiplicative chain with an **additive** update. The LSTM cell state update is:\n",
    "\n",
    "$$C_t = f_t \\odot C_{t-1} + i_t \\odot \\tilde{C}_t$$\n",
    "\n",
     "Holding the gate values fixed, the gradient of $C_t$ with respect to $C_{t-1}$ along this direct additive path is simply:\n",
     "\n",
     "$$\\frac{\\partial C_t}{\\partial C_{t-1}} = f_t$$\n",
     "\n",
     "where $f_t$ is the **forget gate**, a sigmoid output that can be close to 1. (The full gradient also contains indirect terms through the gates, which themselves depend on $h_{t-1}$, but the direct path dominates when the gates saturate.) When $f_t \\approx 1$, the gradient passes through unchanged—this is the **Constant Error Carousel** (CEC). Information stored in the cell state can persist indefinitely, and gradients flow back through time without decay.\n",
    "\n",
    "```{admonition} The Constant Error Carousel\n",
    ":class: important\n",
    "\n",
    "The CEC is the defining innovation of LSTM. By making the cell state update *additive* rather than *multiplicative*, the gradient $\\partial C_t / \\partial C_{t-1} = f_t$ can remain close to 1 for arbitrarily many time steps. This solves the vanishing gradient problem at its mathematical root.\n",
    "```\n",
    "\n",
    "The following figure compares gradient flow in a vanilla RNN versus an LSTM:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3e4f5a6",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Gradient flow comparison: Vanilla RNN vs LSTM\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))\n",
    "\n",
    "T = 30\n",
    "steps = np.arange(1, T + 1)\n",
    "\n",
    "# Vanilla RNN gradient decay\n",
    "ax = axes[0]\n",
    "for gamma, label, color, ls in [\n",
    "    (0.95, r'$\\|W_{hh}\\| = 0.95$', BLUE, '-'),\n",
    "    (0.85, r'$\\|W_{hh}\\| = 0.85$', AMBER, '--'),\n",
    "    (0.70, r'$\\|W_{hh}\\| = 0.70$', RED, '-.'),\n",
    "]:\n",
    "    grad_norms = gamma ** steps\n",
    "    ax.plot(steps, grad_norms, color=color, linestyle=ls, linewidth=2, label=label)\n",
    "\n",
    "ax.set_xlabel('Time steps back', fontsize=11)\n",
    "ax.set_ylabel('Gradient magnitude (relative)', fontsize=11)\n",
    "ax.set_title('Vanilla RNN: Gradient Decay', fontsize=12, fontweight='bold')\n",
    "ax.set_ylim(0, 1.1)\n",
    "ax.legend(fontsize=9)\n",
    "ax.axhline(y=0.01, color='gray', linestyle=':', alpha=0.5)\n",
    "ax.text(T - 1, 0.03, 'effectively zero', fontsize=8, color='gray', ha='right')\n",
    "\n",
    "# LSTM gradient preservation\n",
    "ax = axes[1]\n",
    "for f_val, label, color, ls in [\n",
    "    (1.00, r'$f_t = 1.0$ (perfect memory)', GREEN, '-'),\n",
    "    (0.98, r'$f_t = 0.98$', BLUE, '--'),\n",
    "    (0.90, r'$f_t = 0.90$', AMBER, '-.'),\n",
    "]:\n",
    "    grad_norms = f_val ** steps\n",
    "    ax.plot(steps, grad_norms, color=color, linestyle=ls, linewidth=2, label=label)\n",
    "\n",
    "ax.set_xlabel('Time steps back', fontsize=11)\n",
    "ax.set_ylabel('Gradient magnitude (relative)', fontsize=11)\n",
    "ax.set_title('LSTM: Gradient via Cell State', fontsize=12, fontweight='bold')\n",
    "ax.set_ylim(0, 1.1)\n",
    "ax.legend(fontsize=9)\n",
    "\n",
    "plt.suptitle('Gradient Flow: Vanilla RNN vs LSTM', fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4f5a6b7",
   "metadata": {},
   "source": [
     "The contrast is stark. A vanilla RNN with $\\|W_{hh}\\| = 0.85$ retains less than 2% of the gradient after 25 steps, and less than 1% after 30. An LSTM with forget gate values near 1 preserves the gradient almost perfectly over the same horizon."
   ]
  },
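  {
   "cell_type": "markdown",
   "id": "0a1b2c3d",
   "metadata": {},
   "source": [
    "We can also verify the CEC claim numerically with autograd. The following sketch (toy scalar cell state, gate values held as arbitrary constants rather than computed from learned weights) unrolls the additive update $C_t = f_t \\odot C_{t-1} + i_t \\odot \\tilde{C}_t$ for 30 steps and confirms that the gradient of $C_T$ with respect to $C_1$ equals the product of the forget gates exactly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2c3d4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numerical check of the CEC: with constant gates, dC_T/dC_1 = prod(f_t).\n",
    "# Gate values are arbitrary constants here, not learned parameters.\n",
    "T = 30\n",
    "c0 = torch.ones(1, requires_grad=True)\n",
    "f = torch.sigmoid(torch.randn(T) + 3.0)   # forget gates near 1\n",
    "i = torch.sigmoid(torch.randn(T) - 3.0)   # input gates near 0\n",
    "g = torch.tanh(torch.randn(T))            # cell candidates\n",
    "\n",
    "c = c0\n",
    "for t in range(T):\n",
    "    c = f[t] * c + i[t] * g[t]            # additive update\n",
    "c.backward()\n",
    "\n",
    "print(f'Autograd gradient dC_T/dC_1: {c0.grad.item():.6f}')\n",
    "print(f'Product of forget gates:     {f.prod().item():.6f}')"
   ]
  },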
  {
   "cell_type": "markdown",
   "id": "f5a6b7c8",
   "metadata": {},
   "source": [
    "## 34.2 LSTM Cell Architecture\n",
    "\n",
    "The LSTM cell maintains two state vectors: the **cell state** $C_t$ (the long-term memory highway) and the **hidden state** $h_t$ (the short-term output). At each time step, three gates regulate the information flow.\n",
    "\n",
    "```{admonition} Definition: LSTM Cell Equations\n",
    ":class: note\n",
    "\n",
    "Given input $x_t \\in \\mathbb{R}^d$, previous hidden state $h_{t-1} \\in \\mathbb{R}^n$, and previous cell state $C_{t-1} \\in \\mathbb{R}^n$:\n",
    "\n",
    "**Forget gate** (what to discard from cell state):\n",
    "\n",
    "$$f_t = \\sigma(W_f [h_{t-1}, x_t] + b_f)$$\n",
    "\n",
    "**Input gate** (what new information to store):\n",
    "\n",
    "$$i_t = \\sigma(W_i [h_{t-1}, x_t] + b_i)$$\n",
    "\n",
    "**Cell candidate** (proposed new content):\n",
    "\n",
    "$$\\tilde{C}_t = \\tanh(W_C [h_{t-1}, x_t] + b_C)$$\n",
    "\n",
    "**Cell state update** (the Constant Error Carousel):\n",
    "\n",
    "$$C_t = f_t \\odot C_{t-1} + i_t \\odot \\tilde{C}_t$$\n",
    "\n",
    "**Output gate** (what to reveal from cell state):\n",
    "\n",
    "$$o_t = \\sigma(W_o [h_{t-1}, x_t] + b_o)$$\n",
    "\n",
    "**Hidden state** (output at this time step):\n",
    "\n",
    "$$h_t = o_t \\odot \\tanh(C_t)$$\n",
    "\n",
    "Here $\\sigma$ is the sigmoid function, $\\odot$ denotes element-wise multiplication, and $[h_{t-1}, x_t]$ denotes concatenation.\n",
    "```\n",
    "\n",
    "Each gate is a full neural network layer with its own weights and biases. The sigmoid activation ensures gate values lie in $[0, 1]$, acting as soft switches:\n",
    "- $f_t \\approx 1$: keep the old cell state (remember).\n",
    "- $f_t \\approx 0$: erase the old cell state (forget).\n",
    "- $i_t \\approx 1$: write the candidate into the cell state.\n",
    "- $o_t \\approx 1$: expose the cell state to the outside.\n",
    "\n",
    "The following diagram illustrates the data flow through an LSTM cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6b7c8d9",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# LSTM Cell Diagram\n",
    "fig, ax = plt.subplots(figsize=(14, 8))\n",
    "ax.set_xlim(-1, 15)\n",
    "ax.set_ylim(-1, 10)\n",
    "ax.set_aspect('equal')\n",
    "ax.axis('off')\n",
    "\n",
    "import matplotlib.patches as mpatches\n",
    "\n",
    "# Colors\n",
    "gate_colors = {'forget': RED, 'input': GREEN, 'output': BLUE, 'candidate': AMBER}\n",
    "\n",
    "def draw_gate(ax, x, y, label, color, w=1.8, h=0.9):\n",
    "    rect = mpatches.FancyBboxPatch(\n",
    "        (x - w/2, y - h/2), w, h,\n",
    "        boxstyle=mpatches.BoxStyle('Round', pad=0.1),\n",
    "        facecolor=color, edgecolor='white', linewidth=2, alpha=0.85\n",
    "    )\n",
    "    ax.add_patch(rect)\n",
    "    ax.text(x, y, label, ha='center', va='center', fontsize=10,\n",
    "            fontweight='bold', color='white')\n",
    "\n",
    "def draw_op(ax, x, y, symbol, size=0.45):\n",
    "    circle = plt.Circle((x, y), size, facecolor='white', edgecolor='#334155',\n",
    "                         linewidth=1.5, zorder=5)\n",
    "    ax.add_patch(circle)\n",
    "    ax.text(x, y, symbol, ha='center', va='center', fontsize=14,\n",
    "            fontweight='bold', color='#334155', zorder=6)\n",
    "\n",
    "# Cell state highway (top)\n",
    "ax.annotate('', xy=(13, 8), xytext=(1, 8),\n",
    "            arrowprops=dict(arrowstyle='->', lw=3, color='#475569'))\n",
    "ax.text(0.3, 8, '$C_{t-1}$', fontsize=13, fontweight='bold', color='#475569')\n",
    "ax.text(13.3, 8, '$C_t$', fontsize=13, fontweight='bold', color='#475569')\n",
    "ax.text(7, 9, 'Cell State (Long-Term Memory Highway)', fontsize=11,\n",
    "        ha='center', fontstyle='italic', color='#64748b')\n",
    "\n",
    "# Hidden state (bottom)\n",
    "ax.annotate('', xy=(13, 2), xytext=(1, 2),\n",
    "            arrowprops=dict(arrowstyle='->', lw=3, color='#475569'))\n",
    "ax.text(0.3, 2, '$h_{t-1}$', fontsize=13, fontweight='bold', color='#475569')\n",
    "ax.text(13.3, 2, '$h_t$', fontsize=13, fontweight='bold', color='#475569')\n",
    "\n",
    "# Input\n",
    "ax.text(7, 0, '$x_t$', fontsize=13, fontweight='bold', ha='center', color='#475569')\n",
    "ax.annotate('', xy=(7, 1.2), xytext=(7, 0.4),\n",
    "            arrowprops=dict(arrowstyle='->', lw=2, color='#94a3b8'))\n",
    "\n",
    "# Forget gate\n",
    "draw_gate(ax, 3.5, 4.5, r'Forget gate' + '\\n' + r'$\\sigma$', RED)\n",
    "ax.text(3.5, 3.5, '$f_t$', fontsize=11, ha='center', color=RED, fontweight='bold')\n",
    "# Arrow from concat to forget gate\n",
    "ax.annotate('', xy=(3.5, 4.0), xytext=(3.5, 2.5),\n",
    "            arrowprops=dict(arrowstyle='->', lw=1.5, color='#94a3b8'))\n",
    "# Multiply on cell state\n",
    "draw_op(ax, 3.5, 8, r'$\\times$')\n",
    "ax.annotate('', xy=(3.5, 7.5), xytext=(3.5, 5.0),\n",
    "            arrowprops=dict(arrowstyle='->', lw=1.5, color=RED))\n",
    "\n",
    "# Input gate\n",
    "draw_gate(ax, 6.5, 4.5, r'Input gate' + '\\n' + r'$\\sigma$', GREEN)\n",
    "ax.text(6.5, 3.5, '$i_t$', fontsize=11, ha='center', color=GREEN, fontweight='bold')\n",
    "ax.annotate('', xy=(6.5, 4.0), xytext=(6.5, 2.5),\n",
    "            arrowprops=dict(arrowstyle='->', lw=1.5, color='#94a3b8'))\n",
    "\n",
    "# Candidate\n",
    "draw_gate(ax, 8.5, 4.5, r'Candidate' + '\\n' + r'tanh', AMBER)\n",
    "ax.text(8.5, 3.5, r'$\\tilde{C}_t$', fontsize=11, ha='center', color=AMBER, fontweight='bold')\n",
    "ax.annotate('', xy=(8.5, 4.0), xytext=(8.5, 2.5),\n",
    "            arrowprops=dict(arrowstyle='->', lw=1.5, color='#94a3b8'))\n",
    "\n",
    "# i_t * C_tilde -> multiply\n",
    "draw_op(ax, 7.5, 6.5, r'$\\times$')\n",
    "ax.annotate('', xy=(7.1, 6.5), xytext=(6.5, 5.0),\n",
    "            arrowprops=dict(arrowstyle='->', lw=1.5, color=GREEN))\n",
    "ax.annotate('', xy=(7.9, 6.5), xytext=(8.5, 5.0),\n",
    "            arrowprops=dict(arrowstyle='->', lw=1.5, color=AMBER))\n",
    "\n",
    "# Add on cell state\n",
    "draw_op(ax, 7.5, 8, '+')\n",
    "ax.annotate('', xy=(7.5, 7.5), xytext=(7.5, 7.0),\n",
    "            arrowprops=dict(arrowstyle='->', lw=1.5, color='#64748b'))\n",
    "\n",
    "# Output gate\n",
    "draw_gate(ax, 10.5, 4.5, r'Output gate' + '\\n' + r'$\\sigma$', BLUE)\n",
    "ax.text(10.5, 3.5, '$o_t$', fontsize=11, ha='center', color=BLUE, fontweight='bold')\n",
    "ax.annotate('', xy=(10.5, 4.0), xytext=(10.5, 2.5),\n",
    "            arrowprops=dict(arrowstyle='->', lw=1.5, color='#94a3b8'))\n",
    "\n",
    "# tanh on cell state -> output\n",
    "draw_op(ax, 11.5, 6.5, 'tanh')\n",
    "ax.annotate('', xy=(11.5, 6.1), xytext=(11.5, 8),\n",
    "            arrowprops=dict(arrowstyle='<-', lw=1.5, color='#64748b'))\n",
    "\n",
    "# output gate * tanh(C_t) -> h_t\n",
    "draw_op(ax, 11.5, 2, r'$\\times$')\n",
    "ax.annotate('', xy=(11.5, 2.45), xytext=(11.5, 6.05),\n",
    "            arrowprops=dict(arrowstyle='<-', lw=1.5, color='#64748b'))\n",
    "ax.annotate('', xy=(11.1, 2), xytext=(10.5, 5.0),\n",
    "            arrowprops=dict(arrowstyle='<-', lw=1.5, color=BLUE, connectionstyle='arc3,rad=-0.3'))\n",
    "\n",
    "# Concat indicator\n",
    "ax.text(7, 1.5, '$[h_{t-1}, x_t]$ concatenated', fontsize=9,\n",
    "        ha='center', fontstyle='italic', color='#94a3b8')\n",
    "\n",
    "ax.set_title('LSTM Cell Architecture', fontsize=14, fontweight='bold', pad=15)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
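  {
   "cell_type": "markdown",
   "id": "2c3d4e5f",
   "metadata": {},
   "source": [
    "Before building an efficient implementation, it helps to transcribe the six equations literally, one line each. The sketch below uses toy dimensions and random (untrained) weights, with a separate weight matrix per equation acting on the concatenation $[h_{t-1}, x_t]$—a direct reading of the definition, not a trained model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d4e5f6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Literal one-step transcription of the LSTM equations (toy, untrained weights)\n",
    "d, n = 3, 5                                  # input size, hidden size\n",
    "x_t = torch.randn(1, d)\n",
    "h_prev = torch.zeros(1, n)\n",
    "C_prev = torch.zeros(1, n)\n",
    "\n",
    "# One weight matrix and bias per equation, acting on [h_{t-1}, x_t]\n",
    "W_f, b_f = torch.randn(n, n + d), torch.zeros(n)\n",
    "W_i, b_i = torch.randn(n, n + d), torch.zeros(n)\n",
    "W_C, b_C = torch.randn(n, n + d), torch.zeros(n)\n",
    "W_o, b_o = torch.randn(n, n + d), torch.zeros(n)\n",
    "\n",
    "hx = torch.cat([h_prev, x_t], dim=1)         # [h_{t-1}, x_t]\n",
    "f_t = torch.sigmoid(hx @ W_f.T + b_f)        # forget gate\n",
    "i_t = torch.sigmoid(hx @ W_i.T + b_i)        # input gate\n",
    "C_tilde = torch.tanh(hx @ W_C.T + b_C)       # cell candidate\n",
    "C_t = f_t * C_prev + i_t * C_tilde           # cell state update (CEC)\n",
    "o_t = torch.sigmoid(hx @ W_o.T + b_o)        # output gate\n",
    "h_t = o_t * torch.tanh(C_t)                  # hidden state\n",
    "\n",
    "print('h_t:', h_t.shape, ' C_t:', C_t.shape)"
   ]
  },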
  {
   "cell_type": "markdown",
   "id": "b7c8d9e0",
   "metadata": {},
   "source": [
    "## 34.3 Building LSTM from Scratch\n",
    "\n",
    "To truly understand the LSTM, we implement it using raw PyTorch tensor operations—no `nn.LSTMCell`. The key implementation insight is that all four linear transformations (for $f_t$, $i_t$, $\\tilde{C}_t$, and $o_t$) take the same input $[h_{t-1}, x_t]$, so we can concatenate them into a single large matrix multiplication and then chunk the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d9e0f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ManualLSTMCell:\n",
    "    \"\"\"LSTM cell implemented from scratch using raw tensor operations.\n",
    "    \n",
    "    All four gates share a single weight matrix for efficiency:\n",
    "    W @ [h, x] + b -> chunk into (i, f, g, o)\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        \n",
    "        # Single weight matrix for all 4 gates: [input_gate, forget_gate, cell_candidate, output_gate]\n",
    "        # Input weights: maps x_t -> 4 * hidden_size\n",
    "        k = 1.0 / np.sqrt(hidden_size)\n",
    "        self.W_ih = torch.empty(4 * hidden_size, input_size).uniform_(-k, k)\n",
    "        self.b_ih = torch.empty(4 * hidden_size).uniform_(-k, k)\n",
    "        \n",
    "        # Hidden weights: maps h_{t-1} -> 4 * hidden_size  \n",
    "        self.W_hh = torch.empty(4 * hidden_size, hidden_size).uniform_(-k, k)\n",
    "        self.b_hh = torch.empty(4 * hidden_size).uniform_(-k, k)\n",
    "    \n",
    "    def forward(self, x_t, h_prev, c_prev):\n",
    "        \"\"\"Single LSTM step.\n",
    "        \n",
    "        Args:\n",
    "            x_t: input at time t, shape (batch, input_size)\n",
    "            h_prev: previous hidden state, shape (batch, hidden_size)\n",
    "            c_prev: previous cell state, shape (batch, hidden_size)\n",
    "        \n",
    "        Returns:\n",
    "            h_t: new hidden state\n",
    "            c_t: new cell state\n",
    "        \"\"\"\n",
    "        # Combined linear transformation\n",
    "        gates = (x_t @ self.W_ih.T + self.b_ih +\n",
    "                 h_prev @ self.W_hh.T + self.b_hh)\n",
    "        \n",
    "        # Chunk into 4 gates (PyTorch LSTM convention: i, f, g, o)\n",
    "        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=1)\n",
    "        \n",
    "        # Apply activations\n",
    "        i_t = torch.sigmoid(i_gate)    # Input gate\n",
    "        f_t = torch.sigmoid(f_gate)    # Forget gate\n",
    "        c_tilde = torch.tanh(g_gate)   # Cell candidate\n",
    "        o_t = torch.sigmoid(o_gate)    # Output gate\n",
    "        \n",
    "        # Cell state update (the Constant Error Carousel!)\n",
    "        c_t = f_t * c_prev + i_t * c_tilde\n",
    "        \n",
    "        # Hidden state\n",
    "        h_t = o_t * torch.tanh(c_t)\n",
    "        \n",
    "        return h_t, c_t\n",
    "\n",
    "\n",
    "# Test our implementation\n",
    "input_size = 4\n",
    "hidden_size = 8\n",
    "batch_size = 2\n",
    "\n",
    "cell = ManualLSTMCell(input_size, hidden_size)\n",
    "\n",
    "x = torch.randn(batch_size, input_size)\n",
    "h0 = torch.zeros(batch_size, hidden_size)\n",
    "c0 = torch.zeros(batch_size, hidden_size)\n",
    "\n",
    "h1, c1 = cell.forward(x, h0, c0)\n",
    "print(f'Input shape:        {x.shape}')\n",
    "print(f'Hidden state shape: {h1.shape}')\n",
    "print(f'Cell state shape:   {c1.shape}')\n",
    "print(f'h1 range:           [{h1.min().item():.4f}, {h1.max().item():.4f}]')\n",
    "print(f'c1 range:           [{c1.min().item():.4f}, {c1.max().item():.4f}]')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9e0f1a2",
   "metadata": {},
   "source": [
    "Now let us verify that our manual implementation produces identical results to PyTorch's built-in `nn.LSTMCell` when initialized with the same weights:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f1a2b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify against PyTorch's nn.LSTMCell\n",
    "torch.manual_seed(123)\n",
    "\n",
    "input_size = 4\n",
    "hidden_size = 8\n",
    "batch_size = 3\n",
    "\n",
    "# Create our manual cell\n",
    "manual_cell = ManualLSTMCell(input_size, hidden_size)\n",
    "\n",
    "# Create PyTorch's cell with SAME weights\n",
    "pytorch_cell = nn.LSTMCell(input_size, hidden_size)\n",
    "with torch.no_grad():\n",
    "    pytorch_cell.weight_ih.copy_(manual_cell.W_ih)\n",
    "    pytorch_cell.weight_hh.copy_(manual_cell.W_hh)\n",
    "    pytorch_cell.bias_ih.copy_(manual_cell.b_ih)\n",
    "    pytorch_cell.bias_hh.copy_(manual_cell.b_hh)\n",
    "\n",
    "# Run both on same input\n",
    "x = torch.randn(batch_size, input_size)\n",
    "h_prev = torch.randn(batch_size, hidden_size)\n",
    "c_prev = torch.randn(batch_size, hidden_size)\n",
    "\n",
    "h_manual, c_manual = manual_cell.forward(x, h_prev, c_prev)\n",
    "h_pytorch, c_pytorch = pytorch_cell(x, (h_prev, c_prev))\n",
    "\n",
    "h_diff = (h_manual - h_pytorch).abs().max().item()\n",
    "c_diff = (c_manual - c_pytorch).abs().max().item()\n",
    "\n",
    "print(f'Max absolute difference in h_t: {h_diff:.2e}')\n",
    "print(f'Max absolute difference in C_t: {c_diff:.2e}')\n",
    "print(f'Match: {\"YES\" if h_diff < 1e-6 and c_diff < 1e-6 else \"NO\"}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1a2b3c4",
   "metadata": {},
   "source": [
    "The match confirms our implementation is correct. The key efficiency trick is computing all four gate transformations with a single matrix multiply and then chunking the result.\n",
    "\n",
    "```{admonition} Implementation Note\n",
    ":class: tip\n",
    "\n",
     "PyTorch's `nn.LSTMCell` uses the gate ordering **(i, f, g, o)** — input, forget, cell candidate (called `g` internally), output. This differs from the ordering used in many textbooks, (f, i, g, o). When copying weights between implementations, be sure to match this convention.\n",
    "```"
   ]
  },
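  {
   "cell_type": "markdown",
   "id": "4e5f6a7b",
   "metadata": {},
   "source": [
    "As a concrete illustration, converting fused gate weights stored in the textbook order (f, i, g, o) to PyTorch's (i, f, g, o) amounts to swapping the first two row blocks. A minimal sketch with a toy matrix (the names `W_textbook` and `W_pytorch` are illustrative, not part of any API):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f6a7b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reordering fused gate weights between conventions (toy example).\n",
    "# Each contiguous block of n rows corresponds to one gate.\n",
    "n = 4                                        # toy hidden size\n",
    "W_textbook = torch.arange(4.0 * n * 2).reshape(4 * n, 2)    # (f, i, g, o) blocks\n",
    "f_blk, i_blk, g_blk, o_blk = W_textbook.chunk(4, dim=0)\n",
    "W_pytorch = torch.cat([i_blk, f_blk, g_blk, o_blk], dim=0)  # (i, f, g, o)\n",
    "\n",
    "# The forget-gate block now occupies the second slot\n",
    "print(torch.equal(W_pytorch.chunk(4, dim=0)[1], f_blk))"
   ]
  },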
  {
   "cell_type": "markdown",
   "id": "a2b3c4d5",
   "metadata": {},
   "source": [
    "## 34.4 Forget Gates\n",
    "\n",
    "```{admonition} Historical Note\n",
    ":class: note\n",
    "\n",
    "The original LSTM architecture proposed by Hochreiter and Schmidhuber (1997) had **no forget gate**. The cell state could only accumulate information—it could never discard it. This was problematic for tasks requiring the network to reset its memory (e.g., processing multiple independent sequences).\n",
    "\n",
    "The forget gate was added by Gers, Schmidhuber & Cummins (2000), completing the modern LSTM. They showed that the forget gate is essential for tasks involving continuous input streams where old information must eventually be discarded.\n",
    "```\n",
    "\n",
    "To illustrate the importance of the forget gate, consider a **counting task**: the network receives a stream of 0s and 1s and must output the running count of 1s, modulo some number. Without a forget gate, the cell state monotonically accumulates, eventually saturating and failing.\n",
    "\n",
    "We demonstrate with a simpler diagnostic: the network must remember a signal from the start of a sequence but **reset** when it sees a special token."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3c4d5e6",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Demonstrate forget gate importance with a counting task\n",
    "# Task: count the number of 1s in a binary sequence, modulo 4\n",
    "# Without forget gate, the cell state can only grow.\n",
    "\n",
    "def generate_counting_data(n_samples=500, seq_len=20):\n",
    "    \"\"\"Generate binary sequences and their running count mod 4.\"\"\"\n",
    "    X = torch.randint(0, 2, (n_samples, seq_len, 1)).float()\n",
    "    # Target: count of 1s at each step, mod 4\n",
    "    counts = X.squeeze(-1).cumsum(dim=1) % 4\n",
    "    return X, counts.long()\n",
    "\n",
    "class CountingLSTM(nn.Module):\n",
    "    def __init__(self, use_forget_gate=True):\n",
    "        super().__init__()\n",
    "        self.hidden_size = 16\n",
    "        self.lstm = nn.LSTMCell(1, self.hidden_size)\n",
    "        self.fc = nn.Linear(self.hidden_size, 4)  # 4 classes: 0,1,2,3\n",
    "        self.use_forget_gate = use_forget_gate\n",
    "        \n",
    "        if not use_forget_gate:\n",
    "            # Disable forget gate by setting its bias very high (f_t -> 1 always)\n",
    "            # In PyTorch's (i,f,g,o) layout, forget gate bias is indices [hidden_size:2*hidden_size]\n",
    "            with torch.no_grad():\n",
    "                self.lstm.bias_ih[self.hidden_size:2*self.hidden_size] = 100.0\n",
    "                self.lstm.bias_hh[self.hidden_size:2*self.hidden_size] = 0.0\n",
     "                # (bias_ih is re-clamped inside forward() so training cannot undo this)\n",
    "        \n",
    "    def forward(self, x_seq):\n",
    "        batch_size, seq_len, _ = x_seq.shape\n",
    "        h = torch.zeros(batch_size, self.hidden_size)\n",
    "        c = torch.zeros(batch_size, self.hidden_size)\n",
    "        outputs = []\n",
    "        \n",
    "        for t in range(seq_len):\n",
    "            h, c = self.lstm(x_seq[:, t, :], (h, c))\n",
    "            \n",
    "            if not self.use_forget_gate:\n",
    "                # Clamp forget gate bias to keep it at ~1\n",
    "                with torch.no_grad():\n",
    "                    self.lstm.bias_ih.data[self.hidden_size:2*self.hidden_size] = 100.0\n",
    "            \n",
    "            outputs.append(self.fc(h))\n",
    "        \n",
    "        return torch.stack(outputs, dim=1)  # (batch, seq_len, 4)\n",
    "\n",
    "def train_counting(use_forget_gate, n_epochs=80):\n",
    "    torch.manual_seed(42)\n",
    "    model = CountingLSTM(use_forget_gate=use_forget_gate)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    \n",
    "    X_train, y_train = generate_counting_data(500, 20)\n",
    "    X_test, y_test = generate_counting_data(200, 20)\n",
    "    \n",
    "    losses = []\n",
    "    accs = []\n",
    "    \n",
    "    for epoch in range(n_epochs):\n",
    "        model.train()\n",
    "        out = model(X_train)\n",
    "        loss = criterion(out.reshape(-1, 4), y_train.reshape(-1))\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        losses.append(loss.item())\n",
    "        \n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            test_out = model(X_test)\n",
    "            preds = test_out.argmax(dim=-1)\n",
    "            acc = (preds == y_test).float().mean().item()\n",
    "            accs.append(acc)\n",
    "    \n",
    "    return losses, accs\n",
    "\n",
    "losses_with_fg, accs_with_fg = train_counting(use_forget_gate=True)\n",
    "losses_no_fg, accs_no_fg = train_counting(use_forget_gate=False)\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))\n",
    "\n",
    "ax = axes[0]\n",
    "ax.plot(losses_with_fg, color=GREEN, linewidth=2, label='With forget gate')\n",
    "ax.plot(losses_no_fg, color=RED, linewidth=2, label='Without forget gate', linestyle='--')\n",
    "ax.set_xlabel('Epoch', fontsize=11)\n",
    "ax.set_ylabel('Cross-Entropy Loss', fontsize=11)\n",
    "ax.set_title('Counting Task: Training Loss', fontsize=12, fontweight='bold')\n",
    "ax.legend(fontsize=10)\n",
    "\n",
    "ax = axes[1]\n",
    "ax.plot(accs_with_fg, color=GREEN, linewidth=2, label='With forget gate')\n",
    "ax.plot(accs_no_fg, color=RED, linewidth=2, label='Without forget gate', linestyle='--')\n",
    "ax.set_xlabel('Epoch', fontsize=11)\n",
    "ax.set_ylabel('Accuracy', fontsize=11)\n",
    "ax.set_title('Counting Task: Test Accuracy', fontsize=12, fontweight='bold')\n",
    "ax.legend(fontsize=10)\n",
    "ax.set_ylim(0, 1.05)\n",
    "\n",
    "plt.suptitle('The Forget Gate is Essential for Counting (mod 4)',\n",
    "             fontsize=13, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(f'Final accuracy WITH forget gate:    {accs_with_fg[-1]:.3f}')\n",
    "print(f'Final accuracy WITHOUT forget gate:  {accs_no_fg[-1]:.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4d5e6f7",
   "metadata": {},
   "source": [
     "The counting task requires the network to track a value that wraps around (modulo 4). The standard LSTM with a forget gate can learn to reset the count at the right moments, while the version with the forget gate locked to 1 struggles: its cell state can only integrate new contributions on top of what it already stores—it can never rescale or erase old content.\n",
    "\n",
    "```{admonition} Citation\n",
    ":class: note\n",
    "\n",
    "F. A. Gers, J. Schmidhuber, and F. Cummins, \"Learning to forget: Continual prediction with LSTM,\" *Neural Computation*, vol. 12, no. 10, pp. 2451–2471, 2000.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5e6f7a8",
   "metadata": {},
   "source": [
    "## 34.5 GRU: A Simplified Alternative\n",
    "\n",
    "In 2014, Cho et al. proposed the **Gated Recurrent Unit** (GRU), a streamlined variant that merges the cell state and hidden state into a single vector and uses only two gates instead of three.\n",
    "\n",
    "```{admonition} Definition: GRU Equations\n",
    ":class: note\n",
    "\n",
    "Given input $x_t$, previous hidden state $h_{t-1}$:\n",
    "\n",
     "**Update gate** (analogous to LSTM's forget + input gates):\n",
     "\n",
     "$$z_t = \\sigma(W_z [h_{t-1}, x_t] + b_z)$$\n",
     "\n",
     "**Reset gate** (controls how much of the past the candidate sees):\n",
     "\n",
     "$$r_t = \\sigma(W_r [h_{t-1}, x_t] + b_r)$$\n",
     "\n",
     "**Candidate hidden state**:\n",
     "\n",
     "$$\\tilde{h}_t = \\tanh(W [r_t \\odot h_{t-1}, x_t] + b)$$\n",
     "\n",
     "**Hidden state update** (convex combination, following Cho et al. and PyTorch's convention):\n",
     "\n",
     "$$h_t = z_t \\odot h_{t-1} + (1 - z_t) \\odot \\tilde{h}_t$$\n",
     "\n",
     "```\n",
     "\n",
     "The GRU has no separate cell state. The update gate $z_t$ plays a dual role: when $z_t \\approx 1$, the hidden state is copied forward unchanged (like an LSTM with $f_t \\approx 1$ and $i_t \\approx 0$); when $z_t \\approx 0$, the hidden state is replaced with the candidate. (Some texts write the combination with the roles of $z_t$ and $1 - z_t$ swapped; the two conventions are equivalent up to relabeling the gate.)\n",
    "\n",
    "```{admonition} LSTM vs GRU\n",
    ":class: tip\n",
    "\n",
    "| Feature | LSTM | GRU |\n",
    "|:--------|:----:|:---:|\n",
    "| State vectors | 2 ($h_t$, $C_t$) | 1 ($h_t$) |\n",
    "| Gates | 3 (forget, input, output) | 2 (update, reset) |\n",
    "| Parameters per unit | $4n(n+d) + 4n$ | $3n(n+d) + 3n$ |\n",
    "| Ratio | 1.0x | 0.75x |\n",
    "\n",
     "Where $n$ = hidden size, $d$ = input size, counting one bias vector per gate (PyTorch stores two bias vectors per gate, which does not change the ratio). GRU has 25% fewer parameters.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6f7a8b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ManualGRUCell:\n",
    "    \"\"\"GRU cell implemented from scratch using raw tensor operations.\n",
    "    \n",
     "    Gates are chunked from fused weight matrices in PyTorch's (r, z, n) order.\n",
     "    Note: PyTorch applies the reset gate AFTER the hidden linear transform,\n",
     "    i.e. r_t * (W_hn @ h + b_hn), so the input and hidden contributions are\n",
     "    computed separately rather than from a single concatenated [h, x].\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, input_size, hidden_size):\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        \n",
    "        k = 1.0 / np.sqrt(hidden_size)\n",
    "        # Input weights: maps x_t -> 3 * hidden_size (for r, z, n)\n",
    "        self.W_ih = torch.empty(3 * hidden_size, input_size).uniform_(-k, k)\n",
    "        self.b_ih = torch.empty(3 * hidden_size).uniform_(-k, k)\n",
    "        \n",
    "        # Hidden weights: maps h_{t-1} -> 3 * hidden_size\n",
    "        self.W_hh = torch.empty(3 * hidden_size, hidden_size).uniform_(-k, k)\n",
    "        self.b_hh = torch.empty(3 * hidden_size).uniform_(-k, k)\n",
    "    \n",
    "    def forward(self, x_t, h_prev):\n",
    "        \"\"\"Single GRU step.\n",
    "        \n",
    "        Args:\n",
    "            x_t: input, shape (batch, input_size)\n",
    "            h_prev: previous hidden state, shape (batch, hidden_size)\n",
    "        \n",
    "        Returns:\n",
    "            h_t: new hidden state\n",
    "        \"\"\"\n",
    "        # Compute input and hidden contributions separately\n",
    "        # (needed because reset gate is applied to hidden part of candidate only)\n",
    "        gi = x_t @ self.W_ih.T + self.b_ih\n",
    "        gh = h_prev @ self.W_hh.T + self.b_hh\n",
    "        \n",
    "        # Chunk: (reset, update, new/candidate)\n",
    "        i_r, i_z, i_n = gi.chunk(3, dim=1)\n",
    "        h_r, h_z, h_n = gh.chunk(3, dim=1)\n",
    "        \n",
    "        r_t = torch.sigmoid(i_r + h_r)   # Reset gate\n",
    "        z_t = torch.sigmoid(i_z + h_z)   # Update gate\n",
    "        \n",
    "        # Candidate: reset gate applied to hidden contribution only\n",
    "        h_tilde = torch.tanh(i_n + r_t * h_n)\n",
    "        \n",
     "        # Convex combination. PyTorch's convention: z_t gates the OLD state,\n",
     "        # so z_t ~ 1 copies h_{t-1} forward unchanged.\n",
     "        h_t = (1 - z_t) * h_tilde + z_t * h_prev\n",
    "        \n",
    "        return h_t\n",
    "\n",
    "\n",
    "# Test and verify against nn.GRUCell\n",
    "torch.manual_seed(456)\n",
    "\n",
    "input_size = 4\n",
    "hidden_size = 8\n",
    "batch_size = 3\n",
    "\n",
    "manual_gru = ManualGRUCell(input_size, hidden_size)\n",
    "\n",
    "pytorch_gru = nn.GRUCell(input_size, hidden_size)\n",
    "with torch.no_grad():\n",
    "    pytorch_gru.weight_ih.copy_(manual_gru.W_ih)\n",
    "    pytorch_gru.weight_hh.copy_(manual_gru.W_hh)\n",
    "    pytorch_gru.bias_ih.copy_(manual_gru.b_ih)\n",
    "    pytorch_gru.bias_hh.copy_(manual_gru.b_hh)\n",
    "\n",
    "x = torch.randn(batch_size, input_size)\n",
    "h_prev = torch.randn(batch_size, hidden_size)\n",
    "\n",
    "h_manual = manual_gru.forward(x, h_prev)\n",
    "h_pytorch = pytorch_gru(x, h_prev)\n",
    "\n",
    "diff = (h_manual - h_pytorch).abs().max().item()\n",
    "print(f'ManualGRUCell vs nn.GRUCell max diff: {diff:.2e}')\n",
    "print(f'Match: {\"YES\" if diff < 1e-6 else \"NO\"}')\n",
    "print('\\nParameter comparison:')\n",
    "lstm_params = 4 * hidden_size * (input_size + hidden_size) + 4 * hidden_size * 2\n",
    "gru_params = 3 * hidden_size * (input_size + hidden_size) + 3 * hidden_size * 2\n",
    "print(f'  LSTM parameters (h={hidden_size}, d={input_size}): {lstm_params}')\n",
    "print(f'  GRU parameters  (h={hidden_size}, d={input_size}): {gru_params}')\n",
    "print(f'  GRU/LSTM ratio: {gru_params/lstm_params:.2f}')"
   ]
  },
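  {
   "cell_type": "markdown",
   "id": "c6d7e8f9",
   "metadata": {},
   "source": [
    "The hand-derived formulas above can be cross-checked against PyTorch directly, since `nn.LSTMCell` and `nn.GRUCell` store their parameters in the same stacked-gate layout with the same double-bias convention. A minimal sanity check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e8f9a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-check the hand-derived parameter formulas against PyTorch's cells\n",
    "d, h = input_size, hidden_size\n",
    "n_lstm = sum(p.numel() for p in nn.LSTMCell(d, h).parameters())\n",
    "n_gru = sum(p.numel() for p in nn.GRUCell(d, h).parameters())\n",
    "\n",
    "print(f'nn.LSTMCell: {n_lstm} params  (formula: {4 * h * (d + h) + 8 * h})')\n",
    "print(f'nn.GRUCell:  {n_gru} params  (formula: {3 * h * (d + h) + 6 * h})')\n",
    "print(f'GRU/LSTM ratio: {n_gru / n_lstm:.2f}')"
   ]
  },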
  {
   "cell_type": "markdown",
   "id": "f7a8b9c0",
   "metadata": {},
   "source": [
    "```{admonition} Citation\n",
    ":class: note\n",
    "\n",
    "K. Cho, B. van Merrienboer, C. Gulcehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio, \"Learning phrase representations using RNN encoder-decoder for statistical machine translation,\" in *Proceedings of EMNLP*, 2014.\n",
    "```"
   ]
  },
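  {
   "cell_type": "markdown",
   "id": "e8f9a0b1",
   "metadata": {},
   "source": [
    "One more mechanical property is worth seeing before the payoff experiment. Because the GRU transition is a per-unit interpolation, forcing $z_t = 0$ copies the previous state exactly, while $z_t = 1$ replaces it with the candidate. The sketch below isolates just that interpolation step with raw tensor operations (it does not reuse `ManualGRUCell`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9a0b1c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The GRU update is a per-unit convex combination:\n",
    "# z = 0 copies the previous state unchanged; z = 1 replaces it entirely.\n",
    "h_prev_demo = torch.randn(2, 4)\n",
    "h_tilde_demo = torch.tanh(torch.randn(2, 4))\n",
    "\n",
    "for z_val in [0.0, 0.5, 1.0]:\n",
    "    z = torch.full_like(h_prev_demo, z_val)\n",
    "    h_t = (1 - z) * h_prev_demo + z * h_tilde_demo\n",
    "    copy_err = (h_t - h_prev_demo).abs().max().item()\n",
    "    new_err = (h_t - h_tilde_demo).abs().max().item()\n",
    "    print(f'z = {z_val:.1f}   max|h_t - h_prev| = {copy_err:.3f}   max|h_t - h_tilde| = {new_err:.3f}')"
   ]
  },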
  {
   "cell_type": "markdown",
   "id": "a8b9c0d1",
   "metadata": {},
   "source": [
    "## 34.6 The Payoff: \"Remember the First\" Revisited\n",
    "\n",
    "We now return to the diagnostic task that exposed the vanishing gradient problem in vanilla RNNs: **remember the first element of a sequence**.\n",
    "\n",
    "The task is simple: a sequence begins with a signal $x_1 \\in \\{0, 1\\}$, followed by $T-1$ noise steps. The network must output $x_1$ at the final time step. For vanilla RNNs, accuracy degrades sharply as $T$ increases beyond ~10–15 steps. If the LSTM truly solves the vanishing gradient problem, it should handle $T = 50$ or more with ease."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9c0d1e2",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def generate_remember_first(n_samples, seq_len, noise_dim=5):\n",
    "    \"\"\"Generate 'remember the first' task data.\n",
    "    \n",
    "    x_1 is a binary label (0 or 1), embedded at position 0.\n",
    "    Remaining positions are Gaussian noise.\n",
    "    Target: predict x_1 from the final hidden state.\n",
    "    \"\"\"\n",
    "    X = torch.randn(n_samples, seq_len, noise_dim)\n",
    "    labels = torch.randint(0, 2, (n_samples,))\n",
    "    # Embed the label in the first time step's first feature\n",
    "    X[:, 0, 0] = labels.float()\n",
    "    return X, labels\n",
    "\n",
    "class SeqClassifier(nn.Module):\n",
    "    \"\"\"Sequence classifier using RNN, LSTM, or GRU.\"\"\"\n",
    "    def __init__(self, input_size, hidden_size, rnn_type='lstm'):\n",
    "        super().__init__()\n",
    "        self.rnn_type = rnn_type\n",
    "        if rnn_type == 'rnn':\n",
    "            self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)\n",
    "        elif rnn_type == 'lstm':\n",
    "            self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)\n",
    "        elif rnn_type == 'gru':\n",
    "            self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_size, 2)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out, _ = self.rnn(x)\n",
    "        return self.fc(out[:, -1, :])  # Use final hidden state\n",
    "\n",
    "def train_remember_first(rnn_type, seq_len, hidden_size=32, n_epochs=100, lr=0.003):\n",
    "    torch.manual_seed(42)\n",
    "    input_size = 5\n",
    "    model = SeqClassifier(input_size, hidden_size, rnn_type)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    \n",
    "    X_train, y_train = generate_remember_first(800, seq_len)\n",
    "    X_test, y_test = generate_remember_first(200, seq_len)\n",
    "    \n",
    "    best_acc = 0.5\n",
    "    for epoch in range(n_epochs):\n",
    "        model.train()\n",
    "        out = model(X_train)\n",
    "        loss = criterion(out, y_train)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        optimizer.step()\n",
    "        \n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            preds = model(X_test).argmax(dim=1)\n",
    "            acc = (preds == y_test).float().mean().item()\n",
    "            best_acc = max(best_acc, acc)\n",
    "    \n",
    "    return best_acc\n",
    "\n",
    "# Test across sequence lengths\n",
    "seq_lengths = [5, 10, 15, 20, 30, 50]\n",
    "results = {'rnn': [], 'lstm': [], 'gru': []}\n",
    "\n",
    "print('Training \"Remember the First\" task across sequence lengths...')\n",
    "for rnn_type in ['rnn', 'lstm', 'gru']:\n",
    "    for T in seq_lengths:\n",
    "        acc = train_remember_first(rnn_type, T)\n",
    "        results[rnn_type].append(acc)\n",
    "        print(f'  {rnn_type.upper():4s}  T={T:3d}  acc={acc:.3f}')\n",
    "\n",
    "# Plot results\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "\n",
    "styles = {\n",
    "    'rnn':  (RED, 's', '--', 'Vanilla RNN'),\n",
    "    'lstm': (GREEN, 'o', '-', 'LSTM'),\n",
    "    'gru':  (BLUE, '^', '-.', 'GRU'),\n",
    "}\n",
    "\n",
    "for rnn_type, (color, marker, ls, label) in styles.items():\n",
    "    ax.plot(seq_lengths, results[rnn_type], color=color, marker=marker,\n",
    "            linestyle=ls, linewidth=2, markersize=8, label=label)\n",
    "\n",
    "ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5)\n",
    "ax.text(max(seq_lengths) - 1, 0.52, 'chance level', fontsize=9, color='gray')\n",
    "ax.set_xlabel('Sequence Length T', fontsize=12)\n",
    "ax.set_ylabel('Best Test Accuracy', fontsize=12)\n",
    "ax.set_title('\"Remember the First\": RNN vs LSTM vs GRU', fontsize=13, fontweight='bold')\n",
    "ax.legend(fontsize=11)\n",
    "ax.set_ylim(0.4, 1.05)\n",
    "ax.set_xticks(seq_lengths)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0d1e2f3",
   "metadata": {},
   "source": [
    "The results confirm the theoretical analysis:\n",
    "\n",
    "- **Vanilla RNN** accuracy degrades as sequence length increases, falling toward chance level (50%) for $T \\geq 20$.\n",
    "- **LSTM** maintains high accuracy even at $T = 50$, thanks to the Constant Error Carousel.\n",
    "- **GRU** performs comparably to LSTM on this task, with fewer parameters.\n",
    "\n",
    "This is the payoff for the gating architecture: the vanishing gradient problem, which seemed like a fundamental barrier, yields to an elegant engineering insight: give the gradient an additive path and let learned gates decide what flows along it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1e2f3a4",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Gradient norm comparison: track gradient norms during training\n",
    "def measure_gradient_norms(rnn_type, seq_len=30, n_steps=50):\n",
    "        \"\"\"Track gradient norms of the recurrent layer's parameters during training.\"\"\"\n",
    "    torch.manual_seed(42)\n",
    "    input_size = 5\n",
    "    hidden_size = 32\n",
    "    model = SeqClassifier(input_size, hidden_size, rnn_type)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=0.003)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    \n",
    "    X, y = generate_remember_first(400, seq_len)\n",
    "    \n",
    "    grad_norms = []\n",
    "    for step in range(n_steps):\n",
    "        model.train()\n",
    "        out = model(X)\n",
    "        loss = criterion(out, y)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        \n",
    "        # Measure gradient norm of RNN weights\n",
    "        total_norm = 0.0\n",
    "        for p in model.rnn.parameters():\n",
    "            if p.grad is not None:\n",
    "                total_norm += p.grad.data.norm(2).item() ** 2\n",
    "        total_norm = total_norm ** 0.5\n",
    "        grad_norms.append(total_norm)\n",
    "        \n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        optimizer.step()\n",
    "    \n",
    "    return grad_norms\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 4.5))\n",
    "\n",
    "for rnn_type, (color, marker, ls, label) in styles.items():\n",
    "    norms = measure_gradient_norms(rnn_type, seq_len=30)\n",
    "    ax.plot(norms, color=color, linestyle=ls, linewidth=2, label=label, alpha=0.8)\n",
    "\n",
    "ax.set_xlabel('Training Step', fontsize=11)\n",
    "ax.set_ylabel('Gradient Norm (before clipping)', fontsize=11)\n",
    "ax.set_title('Gradient Norms During Training (T=30)', fontsize=13, fontweight='bold')\n",
    "ax.legend(fontsize=10)\n",
    "ax.set_yscale('log')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2f3a4b5",
   "metadata": {},
   "source": [
    "The gradient norm plot reveals the mechanism at work: the vanilla RNN's gradients are orders of magnitude smaller than those of the LSTM and GRU, confirming that the error signal is attenuated as it propagates back through the 30-step sequence, leaving the recurrent weights with almost no learning signal about the first element."
   ]
  },
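  {
   "cell_type": "markdown",
   "id": "a1b2c3d4",
   "metadata": {},
   "source": [
    "The same conclusion can be reached without training at all. Along the direct cell-state path, $\\frac{\\partial C_T}{\\partial C_1} = \\prod_{t=2}^{T} \\mathrm{diag}(f_t)$, so forget-gate values near 1 preserve the gradient across many steps, whereas the vanilla RNN multiplies in a factor typically below 1 (a $\\tanh$ derivative times a weight term) at every step. A scalar back-of-the-envelope sketch, using illustrative constants rather than values fitted to the models above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c3d4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Back-of-the-envelope: how much gradient survives T steps?\n",
    "# LSTM cell-state path: a product of forget-gate values (illustrative f = 0.97).\n",
    "# Vanilla RNN path: a product of |w| * tanh' factors (illustrative 0.9 each).\n",
    "T = 50\n",
    "lstm_grad = 0.97 ** (T - 1)\n",
    "rnn_grad = 0.9 ** (T - 1)\n",
    "\n",
    "print(f'After T = {T} steps:')\n",
    "print(f'  LSTM cell-state path: ~{lstm_grad:.3f}')\n",
    "print(f'  Vanilla RNN path:     ~{rnn_grad:.2e}')\n",
    "print(f'  Ratio: ~{lstm_grad / rnn_grad:.0f}x')"
   ]
  },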
  {
   "cell_type": "markdown",
   "id": "f3a4b5c6",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 34.1.** Starting from the LSTM cell state update $C_t = f_t \\odot C_{t-1} + i_t \\odot \\tilde{C}_t$, derive the gradient $\\partial L / \\partial C_{t-1}$ and show explicitly how the forget gate $f_t$ prevents gradient vanishing compared to the vanilla RNN's $\\partial h_t / \\partial h_{t-1}$.\n",
    "\n",
    "**Exercise 34.2.** Modify the `ManualLSTMCell` class to add **peephole connections** (Gers & Schmidhuber, 2000), where the gates also receive the cell state as input: $f_t = \\sigma(W_f[h_{t-1}, x_t] + w_f \\odot C_{t-1} + b_f)$ (and similarly for $i_t$ and $o_t$). Test whether peepholes improve performance on the counting task.\n",
    "\n",
    "**Exercise 34.3.** Count the total number of trainable parameters in an LSTM with input size $d = 10$ and hidden size $n = 64$. Break down the count by gate. Repeat for a GRU with the same dimensions.\n",
    "\n",
    "**Exercise 34.4.** The original 1997 LSTM used $C_t = C_{t-1} + i_t \\odot \\tilde{C}_t$ (no forget gate). Implement this variant as `OriginalLSTMCell` and show on the counting task that it fails to learn modular arithmetic. Explain mathematically why.\n",
    "\n",
    "**Exercise 34.5.** The GRU update $h_t = (1-z_t) \\odot h_{t-1} + z_t \\odot \\tilde{h}_t$ is a convex combination. Prove that $\\|h_t\\|$ is bounded if $\\|\\tilde{h}_t\\|$ is bounded (which it is, since tanh outputs are in $[-1, 1]$). Why does the LSTM need a separate output gate to achieve a similar bound on $h_t$?\n",
    "\n",
    "**Exercise 34.6.** Run the \"remember the first\" experiment with sequence lengths $T \\in \\{75, 100, 150, 200\\}$. At what length does the LSTM begin to struggle? Does increasing the hidden size from 32 to 64 help? Report your findings with accuracy plots."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4b5c6d7",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "- The **Constant Error Carousel** is LSTM's core innovation: additive cell state updates allow gradients to flow unchanged through time, solving the vanishing gradient problem.\n",
    "- The LSTM cell uses three gates—**forget**, **input**, and **output**—to control information flow, each learned independently via backpropagation.\n",
    "- The **forget gate** (Gers et al., 2000) is essential: without it, the cell state can only accumulate, never release information.\n",
    "- The **GRU** (Cho et al., 2014) simplifies the LSTM by merging cell and hidden states and using two gates, achieving comparable performance with 25% fewer parameters.\n",
    "- On the \"remember the first\" task, both LSTM and GRU maintain near-perfect accuracy at $T = 50$, where vanilla RNNs fall to chance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5c6d7e8",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "1. S. Hochreiter, \"Untersuchungen zu dynamischen neuronalen Netzen,\" Diploma thesis, Technische Universität München, 1991.\n",
    "\n",
    "2. Y. Bengio, P. Simard, and P. Frasconi, \"Learning long-term dependencies with gradient descent is difficult,\" *IEEE Transactions on Neural Networks*, vol. 5, no. 2, pp. 157–166, 1994.\n",
    "\n",
    "3. S. Hochreiter and J. Schmidhuber, \"Long short-term memory,\" *Neural Computation*, vol. 9, no. 8, pp. 1735–1780, 1997.\n",
    "\n",
    "4. F. A. Gers, J. Schmidhuber, and F. Cummins, \"Learning to forget: Continual prediction with LSTM,\" *Neural Computation*, vol. 12, no. 10, pp. 2451–2471, 2000.\n",
    "\n",
    "5. K. Cho, B. van Merrienboer, C. Gulcehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio, \"Learning phrase representations using RNN encoder-decoder for statistical machine translation,\" in *Proceedings of EMNLP*, 2014."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}