{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-0",
   "metadata": {},
   "source": [
    "# Chapter 24: Training a CNN with Backpropagation\n",
    "\n",
    "In Chapter 23 we assembled a TinyCNN from Conv2D, ReLU, MaxPool2D, Flatten, and Dense\n",
    "layers. The network can compute a forward pass and produce (random) predictions.\n",
    "To *train* it we need the **backward pass**: the chain of gradient computations that\n",
    "tells each parameter how to change in order to reduce the loss.\n",
    "\n",
    "This chapter is the convolutional counterpart of Chapter 16 (backpropagation for\n",
    "fully-connected networks). We derive the gradients mathematically, implement them,\n",
    "verify them numerically, and then train the network on our synthetic line-pattern\n",
    "dataset.\n",
    "\n",
    "```{admonition} Chapter goals\n",
    ":class: note\n",
    "1. Derive $\\partial \\mathcal{L} / \\partial \\bW$ and $\\partial \\mathcal{L} / \\partial \\bx$ for a convolutional layer.\n",
    "2. Implement `backward` methods for Conv2D, MaxPool2D, and Dense.\n",
    "3. Verify all gradients with numerical finite differences.\n",
    "4. Train TinyCNN to >90% validation accuracy on oriented line patterns.\n",
    "5. Visualize how the learned filters evolve during training.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-1",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "plt.rcParams.update({\n",
    "    'figure.facecolor': '#FAF8F0',\n",
    "    'axes.facecolor': '#FAF8F0',\n",
    "    'font.size': 11,\n",
    "})\n",
    "\n",
    "# Project colour palette\n",
    "BLUE = '#3b82f6'\n",
    "BLUE_DARK = '#2563eb'\n",
    "GREEN = '#059669'\n",
    "GREEN_LIGHT = '#10b981'\n",
    "AMBER = '#d97706'\n",
    "RED = '#dc2626'\n",
    "BURGUNDY = '#8c2f39'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-2",
   "metadata": {},
   "source": [
    "## 24.1 Backprop Through Convolution\n",
    "\n",
    "Recall the forward pass of a convolution with a single filter $\\mathbf{K}$ applied\n",
    "to a single-channel input $\\mathbf{X}$ (we drop batch and channel indices for clarity):\n",
    "\n",
    "$$Y_{i,j} = b + \\sum_{p=0}^{k-1} \\sum_{q=0}^{k-1} K_{p,q} \\cdot X_{i+p,\\, j+q}.$$\n",
    "\n",
    "Given the upstream gradient $\\frac{\\partial \\mathcal{L}}{\\partial Y_{i,j}}$ from the\n",
    "layer above, we need two things:\n",
    "\n",
    "**1. Gradient w.r.t. the kernel** (to update the filter):\n",
    "\n",
    "$$\\frac{\\partial \\mathcal{L}}{\\partial K_{p,q}}\n",
    "  = \\sum_{i,j} \\frac{\\partial \\mathcal{L}}{\\partial Y_{i,j}} \\cdot X_{i+p,\\, j+q}.$$\n",
    "\n",
    "This is itself a convolution -- the input $\\mathbf{X}$ convolved with the upstream\n",
    "gradient $\\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{Y}}$.\n",
    "\n",
    "**2. Gradient w.r.t. the input** (to continue the chain rule to earlier layers):\n",
    "\n",
    "$$\\frac{\\partial \\mathcal{L}}{\\partial X_{m,n}}\n",
    "  = \\sum_{i,j} \\frac{\\partial \\mathcal{L}}{\\partial Y_{i,j}} \\cdot K_{m-i,\\, n-j}$$\n",
    "\n",
    "where the sum runs over all $(i,j)$ such that the indices are valid. This is a\n",
    "**full convolution** of the upstream gradient with the **rotated** (180-degree flipped)\n",
    "kernel.\n",
    "\n",
    "```{admonition} Theorem (Convolution Backward Pass)\n",
    ":class: note\n",
    "The backward pass through a convolution layer is itself a convolution:\n",
    "\n",
    "- $\\nabla_{\\mathbf{K}} \\mathcal{L} = \\mathbf{X} \\star \\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{Y}}$\n",
    "  (valid cross-correlation of input with upstream gradient)\n",
    "- $\\nabla_{\\mathbf{X}} \\mathcal{L} = \\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{Y}} \\star_{\\text{full}} \\text{rot}_{180}(\\mathbf{K})$\n",
    "  (full convolution of upstream gradient with flipped kernel)\n",
    "\n",
    "This duality is one of the most elegant results in neural network theory: convolutions\n",
    "are \"self-similar\" under backpropagation.\n",
    "```\n",
    "\n",
    "For the **bias gradient**, each output position contributes equally:\n",
    "\n",
    "$$\\frac{\\partial \\mathcal{L}}{\\partial b_f}\n",
    "  = \\sum_{\\text{batch}} \\sum_{i,j} \\frac{\\partial \\mathcal{L}}{\\partial Y_{f,i,j}}.$$\n",
    "\n",
    "```{warning}\n",
    "In our implementation we use the loop-based approach from Chapter 22 rather than\n",
    "the full-convolution formulation. This is less efficient but makes the connection\n",
    "between forward and backward passes transparent. Production code would use\n",
    "optimized im2col or FFT-based convolutions.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-3",
   "metadata": {},
   "source": [
    "## 24.2 Gradients for Pooling and Dense\n",
    "\n",
    "### Max Pooling Backward\n",
    "\n",
    "Max pooling selects the maximum value in each window. During the backward pass,\n",
    "the gradient flows *only through the position that was the maximum* -- all other\n",
    "positions in the window receive zero gradient.\n",
    "\n",
    "```{admonition} Max Pooling Gradient Rule\n",
    ":class: note\n",
    "Let $x^*$ be the element that achieved the maximum in a pooling window. Then:\n",
    "\n",
    "$$\\frac{\\partial \\mathcal{L}}{\\partial x_{m,n}} =\n",
    "  \\begin{cases}\n",
    "    \\frac{\\partial \\mathcal{L}}{\\partial y_{i,j}} & \\text{if } x_{m,n} = x^* \\\\\n",
    "    0 & \\text{otherwise}\n",
    "  \\end{cases}$$\n",
    "\n",
    "This is why we stored the `last_mask` during the forward pass: it records which\n",
    "element \"won\" in each window.\n",
    "```\n",
    "\n",
    "### Dense Layer Backward\n",
    "\n",
    "The dense layer computes $\\mathbf{y} = \\bx \\bW + \\bb$. The gradients are:\n",
    "\n",
    "$$\\frac{\\partial \\mathcal{L}}{\\partial \\bW} = \\bx^\\top \\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{y}},\n",
    "\\qquad\n",
    "\\frac{\\partial \\mathcal{L}}{\\partial \\bb} = \\sum_{\\text{batch}} \\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{y}},\n",
    "\\qquad\n",
    "\\frac{\\partial \\mathcal{L}}{\\partial \\bx} = \\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{y}} \\bW^\\top.$$\n",
    "\n",
    "These are identical to the fully-connected backpropagation equations from Chapter 16."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-4",
   "metadata": {},
   "source": [
    "## 24.3 The Complete Backward Pass\n",
    "\n",
    "We now define all layer classes with both `forward` and `backward` methods.\n",
    "Because JupyterBook executes each notebook independently, we must redefine\n",
    "everything from scratch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Conv2D:\n",
    "    \"\"\"2D convolutional layer with forward and backward passes.\"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, seed=42):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        fan_in = in_channels * kernel_size * kernel_size\n",
    "        scale = np.sqrt(2.0 / fan_in)\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.kernel_size = kernel_size\n",
    "        self.weights = rng.normal(0.0, scale,\n",
    "                                  size=(out_channels, in_channels,\n",
    "                                        kernel_size, kernel_size))\n",
    "        self.bias = np.zeros(out_channels)\n",
    "        self.d_weights = np.zeros_like(self.weights)\n",
    "        self.d_bias = np.zeros_like(self.bias)\n",
    "        self.last_input = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        batch_size, _, height, width = x.shape\n",
    "        out_h = height - self.kernel_size + 1\n",
    "        out_w = width - self.kernel_size + 1\n",
    "        output = np.zeros((batch_size, self.out_channels, out_h, out_w))\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                patch = x[:, :,\n",
    "                          row:row + self.kernel_size,\n",
    "                          col:col + self.kernel_size]\n",
    "                output[:, :, row, col] = (\n",
    "                    np.tensordot(patch, self.weights,\n",
    "                                 axes=([1, 2, 3], [1, 2, 3]))\n",
    "                    + self.bias\n",
    "                )\n",
    "        return output\n",
    "\n",
    "    def backward(self, d_output):\n",
    "        \"\"\"Compute gradients for weights, bias, and input.\"\"\"\n",
    "        x = self.last_input\n",
    "        _, _, out_h, out_w = d_output.shape\n",
    "        self.d_weights.fill(0.0)\n",
    "        self.d_bias = d_output.sum(axis=(0, 2, 3))\n",
    "        d_input = np.zeros_like(x)\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                patch = x[:, :,\n",
    "                          row:row + self.kernel_size,\n",
    "                          col:col + self.kernel_size]\n",
    "                # dL/dK: upstream gradient dot input patch\n",
    "                self.d_weights += np.tensordot(\n",
    "                    d_output[:, :, row, col], patch,\n",
    "                    axes=([0], [0]))\n",
    "                # dL/dX: upstream gradient dot kernel\n",
    "                d_input[:, :,\n",
    "                        row:row + self.kernel_size,\n",
    "                        col:col + self.kernel_size] += np.tensordot(\n",
    "                    d_output[:, :, row, col], self.weights,\n",
    "                    axes=([1], [0]))\n",
    "        return d_input\n",
    "\n",
    "    def step(self, lr):\n",
    "        \"\"\"Gradient descent update.\"\"\"\n",
    "        self.weights -= lr * self.d_weights\n",
    "        self.bias -= lr * self.d_bias\n",
    "\n",
    "    @property\n",
    "    def parameter_count(self):\n",
    "        return int(self.weights.size + self.bias.size)\n",
    "\n",
    "\n",
    "class ReLU:\n",
    "    \"\"\"Element-wise Rectified Linear Unit.\"\"\"\n",
    "    def __init__(self):\n",
    "        self.last_input = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        return np.maximum(0.0, x)\n",
    "\n",
    "    def backward(self, d_output):\n",
    "        return d_output * (self.last_input > 0.0)\n",
    "\n",
    "\n",
    "class MaxPool2D:\n",
    "    \"\"\"Max pooling with non-overlapping windows.\"\"\"\n",
    "    def __init__(self, pool_size=2):\n",
    "        self.pool_size = pool_size\n",
    "        self.last_input = None\n",
    "        self.last_mask = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        batch_size, channels, height, width = x.shape\n",
    "        out_h = height // self.pool_size\n",
    "        out_w = width // self.pool_size\n",
    "        output = np.zeros((batch_size, channels, out_h, out_w))\n",
    "        mask = np.zeros_like(x)\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                rs = row * self.pool_size\n",
    "                cs = col * self.pool_size\n",
    "                window = x[:, :, rs:rs + self.pool_size,\n",
    "                                  cs:cs + self.pool_size]\n",
    "                output[:, :, row, col] = window.max(axis=(2, 3))\n",
    "                flat_idx = window.reshape(\n",
    "                    batch_size, channels, -1).argmax(axis=2)\n",
    "                for b in range(batch_size):\n",
    "                    for c in range(channels):\n",
    "                        winner = flat_idx[b, c]\n",
    "                        wr = winner // self.pool_size\n",
    "                        wc = winner % self.pool_size\n",
    "                        mask[b, c, rs + wr, cs + wc] = 1.0\n",
    "        self.last_mask = mask\n",
    "        return output\n",
    "\n",
    "    def backward(self, d_output):\n",
    "        d_input = np.zeros_like(self.last_input)\n",
    "        out_h, out_w = d_output.shape[2], d_output.shape[3]\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                rs = row * self.pool_size\n",
    "                cs = col * self.pool_size\n",
    "                mask_w = self.last_mask[\n",
    "                    :, :, rs:rs + self.pool_size,\n",
    "                          cs:cs + self.pool_size]\n",
    "                d_input[:, :, rs:rs + self.pool_size,\n",
    "                              cs:cs + self.pool_size] += (\n",
    "                    mask_w * d_output[:, :, row, col][:, :, None, None]\n",
    "                )\n",
    "        return d_input\n",
    "\n",
    "\n",
    "class Flatten:\n",
    "    \"\"\"Reshape a 4-D tensor to 2-D (batch, features).\"\"\"\n",
    "    def __init__(self):\n",
    "        self.last_shape = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.last_shape = x.shape\n",
    "        return x.reshape(x.shape[0], -1)\n",
    "\n",
    "    def backward(self, d_output):\n",
    "        return d_output.reshape(self.last_shape)\n",
    "\n",
    "\n",
    "class Dense:\n",
    "    \"\"\"Fully-connected layer with forward and backward passes.\"\"\"\n",
    "    def __init__(self, in_features, out_features, seed=42):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        scale = np.sqrt(2.0 / in_features)\n",
    "        self.weights = rng.normal(0.0, scale,\n",
    "                                  size=(in_features, out_features))\n",
    "        self.bias = np.zeros(out_features)\n",
    "        self.d_weights = np.zeros_like(self.weights)\n",
    "        self.d_bias = np.zeros_like(self.bias)\n",
    "        self.last_input = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        return x @ self.weights + self.bias\n",
    "\n",
    "    def backward(self, d_output):\n",
    "        self.d_weights = self.last_input.T @ d_output\n",
    "        self.d_bias = d_output.sum(axis=0)\n",
    "        return d_output @ self.weights.T\n",
    "\n",
    "    def step(self, lr):\n",
    "        self.weights -= lr * self.d_weights\n",
    "        self.bias -= lr * self.d_bias\n",
    "\n",
    "    @property\n",
    "    def parameter_count(self):\n",
    "        return int(self.weights.size + self.bias.size)\n",
    "\n",
    "\n",
    "print(\"All layer classes defined with forward + backward:\")\n",
    "print(\"  Conv2D, ReLU, MaxPool2D, Flatten, Dense\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(logits):\n",
    "    \"\"\"Numerically stable softmax.\"\"\"\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    exp_values = np.exp(shifted)\n",
    "    return exp_values / exp_values.sum(axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "def softmax_cross_entropy(logits, targets):\n",
    "    \"\"\"Combined softmax + cross-entropy loss with gradient.\"\"\"\n",
    "    probabilities = softmax(logits)\n",
    "    batch_idx = np.arange(targets.shape[0])\n",
    "    clipped = np.clip(probabilities[batch_idx, targets], 1e-12, 1.0)\n",
    "    loss = -np.log(clipped).mean()\n",
    "    d_logits = probabilities.copy()\n",
    "    d_logits[batch_idx, targets] -= 1.0\n",
    "    d_logits /= targets.shape[0]\n",
    "    return loss, probabilities, d_logits\n",
    "\n",
    "\n",
    "print(\"softmax and softmax_cross_entropy defined.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-6a",
   "metadata": {},
   "source": [
    "Now we build the TinyCNN with the full training interface: `loss_and_grad` runs\n",
    "a forward pass, computes the loss, and then backpropagates through every layer.\n",
    "The `step` method updates all trainable parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TinyCNN:\n",
    "    \"\"\"A minimal CNN for 8x8 grayscale images, now with training support.\n",
    "    \n",
    "    Architecture:\n",
    "        Conv2D(1->3, 3x3) -> ReLU -> MaxPool(2x2) -> Flatten -> Dense(27->3)\n",
    "    \"\"\"\n",
    "    def __init__(self, seed=42):\n",
    "        self.conv = Conv2D(in_channels=1, out_channels=3,\n",
    "                           kernel_size=3, seed=seed)\n",
    "        self.relu = ReLU()\n",
    "        self.pool = MaxPool2D(pool_size=2)\n",
    "        self.flatten = Flatten()\n",
    "        self.dense = Dense(in_features=27, out_features=3,\n",
    "                           seed=seed + 1)\n",
    "        self.layers = [self.conv, self.relu, self.pool,\n",
    "                       self.flatten, self.dense]\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Forward pass through all layers.\"\"\"\n",
    "        for layer in self.layers:\n",
    "            x = layer.forward(x)\n",
    "        return x  # logits, shape (batch, 3)\n",
    "\n",
    "    def loss_and_grad(self, x, y):\n",
    "        \"\"\"Forward pass + loss + full backward pass.\"\"\"\n",
    "        logits = self.forward(x)\n",
    "        loss, probs, d_logits = softmax_cross_entropy(logits, y)\n",
    "        # Backward through all layers in reverse order\n",
    "        grad = d_logits\n",
    "        for layer in reversed(self.layers):\n",
    "            grad = layer.backward(grad)\n",
    "        return loss, probs\n",
    "\n",
    "    def step(self, lr):\n",
    "        \"\"\"Update all trainable parameters.\"\"\"\n",
    "        self.conv.step(lr)\n",
    "        self.dense.step(lr)\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Return predicted class indices.\"\"\"\n",
    "        logits = self.forward(x)\n",
    "        return np.argmax(logits, axis=1)\n",
    "\n",
    "    def evaluate(self, x, y):\n",
    "        \"\"\"Compute accuracy on a dataset.\"\"\"\n",
    "        preds = self.predict(x)\n",
    "        return np.mean(preds == y)\n",
    "\n",
    "    def fit(self, x_train, y_train, x_val, y_val,\n",
    "            epochs=80, lr=0.12, batch_size=18, seed=0):\n",
    "        \"\"\"Full training loop with mini-batch SGD.\n",
    "        \n",
    "        Returns\n",
    "        -------\n",
    "        history : dict with keys 'train_loss', 'val_loss',\n",
    "                  'train_acc', 'val_acc', 'kernel_snapshots'\n",
    "        \"\"\"\n",
    "        rng = np.random.default_rng(seed)\n",
    "        n_train = x_train.shape[0]\n",
    "        snapshot_epochs = {0, 1, 5, 15, 40, epochs - 1}\n",
    "\n",
    "        history = {\n",
    "            'train_loss': [], 'val_loss': [],\n",
    "            'train_acc': [], 'val_acc': [],\n",
    "            'kernel_snapshots': {}\n",
    "        }\n",
    "\n",
    "        for epoch in range(epochs):\n",
    "            # Save kernel snapshot if needed\n",
    "            if epoch in snapshot_epochs:\n",
    "                history['kernel_snapshots'][epoch] = \\\n",
    "                    self.conv.weights.copy()\n",
    "\n",
    "            # Shuffle training data\n",
    "            perm = rng.permutation(n_train)\n",
    "            x_shuf = x_train[perm]\n",
    "            y_shuf = y_train[perm]\n",
    "\n",
    "            epoch_loss = 0.0\n",
    "            n_batches = 0\n",
    "            for start in range(0, n_train, batch_size):\n",
    "                end = min(start + batch_size, n_train)\n",
    "                xb = x_shuf[start:end]\n",
    "                yb = y_shuf[start:end]\n",
    "                loss, _ = self.loss_and_grad(xb, yb)\n",
    "                self.step(lr)\n",
    "                epoch_loss += loss\n",
    "                n_batches += 1\n",
    "\n",
    "            # Compute epoch metrics\n",
    "            avg_loss = epoch_loss / n_batches\n",
    "            train_acc = self.evaluate(x_train, y_train)\n",
    "            val_acc = self.evaluate(x_val, y_val)\n",
    "\n",
    "            # Validation loss\n",
    "            val_logits = self.forward(x_val)\n",
    "            val_loss, _, _ = softmax_cross_entropy(val_logits, y_val)\n",
    "\n",
    "            history['train_loss'].append(avg_loss)\n",
    "            history['val_loss'].append(val_loss)\n",
    "            history['train_acc'].append(train_acc)\n",
    "            history['val_acc'].append(val_acc)\n",
    "\n",
    "            if epoch % 10 == 0 or epoch == epochs - 1:\n",
    "                print(f\"Epoch {epoch:3d}/{epochs}: \"\n",
    "                      f\"loss={avg_loss:.4f}  \"\n",
    "                      f\"train_acc={train_acc:.1%}  \"\n",
    "                      f\"val_acc={val_acc:.1%}\")\n",
    "\n",
    "        # Final snapshot\n",
    "        if (epochs - 1) not in history['kernel_snapshots']:\n",
    "            history['kernel_snapshots'][epochs - 1] = \\\n",
    "                self.conv.weights.copy()\n",
    "\n",
    "        return history\n",
    "\n",
    "\n",
    "print(\"TinyCNN class defined with full training support.\")\n",
    "print(\"Methods: forward, loss_and_grad, step, predict, evaluate, fit\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-8",
   "metadata": {},
   "source": [
    "## 24.4 Gradient Checking\n",
    "\n",
    "Before we trust our backward pass, we verify every gradient numerically using\n",
    "two-sided finite differences, exactly as we did for fully-connected networks\n",
    "in Section 18.1.\n",
    "\n",
    "```{danger}\n",
    "**ALWAYS verify your gradient implementation numerically before training.**\n",
    "A wrong gradient will silently produce bad results -- the network will still\n",
    "\"train\" but learn nonsense. This is especially dangerous for convolutional layers\n",
    "where the index arithmetic is error-prone.\n",
    "```\n",
    "\n",
    "```{tip}\n",
    "We use $\\varepsilon = 10^{-5}$ for finite differences. The relative error between\n",
    "analytical and numerical gradients should be below $10^{-5}$ for a correct\n",
    "implementation.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradient checking for TinyCNN\n",
    "\n",
    "def numerical_gradient(model, x, y, param_array, epsilon=1e-5):\n",
    "    \"\"\"Compute numerical gradient for a parameter array via finite differences.\"\"\"\n",
    "    grad = np.zeros_like(param_array)\n",
    "    it = np.nditer(param_array, flags=['multi_index'], op_flags=['readwrite'])\n",
    "    while not it.finished:\n",
    "        idx = it.multi_index\n",
    "        old_val = param_array[idx]\n",
    "\n",
    "        param_array[idx] = old_val + epsilon\n",
    "        logits_plus = model.forward(x)\n",
    "        loss_plus, _, _ = softmax_cross_entropy(logits_plus, y)\n",
    "\n",
    "        param_array[idx] = old_val - epsilon\n",
    "        logits_minus = model.forward(x)\n",
    "        loss_minus, _, _ = softmax_cross_entropy(logits_minus, y)\n",
    "\n",
    "        param_array[idx] = old_val\n",
    "        grad[idx] = (loss_plus - loss_minus) / (2 * epsilon)\n",
    "        it.iternext()\n",
    "    return grad\n",
    "\n",
    "\n",
    "# Small model for gradient checking\n",
    "model_check = TinyCNN(seed=99)\n",
    "\n",
    "# Small random batch\n",
    "rng_check = np.random.default_rng(123)\n",
    "x_check = rng_check.normal(0, 1, size=(4, 1, 8, 8))\n",
    "y_check = np.array([0, 1, 2, 1])\n",
    "\n",
    "# Analytical gradients (from backward pass)\n",
    "model_check.loss_and_grad(x_check, y_check)\n",
    "\n",
    "print(\"Gradient Checking Results:\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "# Check Conv2D weights\n",
    "num_grad_cw = numerical_gradient(model_check, x_check, y_check,\n",
    "                                  model_check.conv.weights)\n",
    "ana_grad_cw = model_check.conv.d_weights\n",
    "diff = np.linalg.norm(ana_grad_cw - num_grad_cw)\n",
    "norm_sum = np.linalg.norm(ana_grad_cw) + np.linalg.norm(num_grad_cw) + 1e-8\n",
    "rel_err = diff / norm_sum\n",
    "status = \"OK\" if rel_err < 1e-5 else \"FAIL\"\n",
    "print(f\"Conv2D weights: relative error = {rel_err:.2e}  [{status}]\")\n",
    "\n",
    "# Check Conv2D bias\n",
    "num_grad_cb = numerical_gradient(model_check, x_check, y_check,\n",
    "                                  model_check.conv.bias)\n",
    "ana_grad_cb = model_check.conv.d_bias\n",
    "diff = np.linalg.norm(ana_grad_cb - num_grad_cb)\n",
    "norm_sum = np.linalg.norm(ana_grad_cb) + np.linalg.norm(num_grad_cb) + 1e-8\n",
    "rel_err = diff / norm_sum\n",
    "status = \"OK\" if rel_err < 1e-5 else \"FAIL\"\n",
    "print(f\"Conv2D bias:    relative error = {rel_err:.2e}  [{status}]\")\n",
    "\n",
    "# Check Dense weights\n",
    "num_grad_dw = numerical_gradient(model_check, x_check, y_check,\n",
    "                                  model_check.dense.weights)\n",
    "ana_grad_dw = model_check.dense.d_weights\n",
    "diff = np.linalg.norm(ana_grad_dw - num_grad_dw)\n",
    "norm_sum = np.linalg.norm(ana_grad_dw) + np.linalg.norm(num_grad_dw) + 1e-8\n",
    "rel_err = diff / norm_sum\n",
    "status = \"OK\" if rel_err < 1e-5 else \"FAIL\"\n",
    "print(f\"Dense weights:  relative error = {rel_err:.2e}  [{status}]\")\n",
    "\n",
    "# Check Dense bias\n",
    "num_grad_db = numerical_gradient(model_check, x_check, y_check,\n",
    "                                  model_check.dense.bias)\n",
    "ana_grad_db = model_check.dense.d_bias\n",
    "diff = np.linalg.norm(ana_grad_db - num_grad_db)\n",
    "norm_sum = np.linalg.norm(ana_grad_db) + np.linalg.norm(num_grad_db) + 1e-8\n",
    "rel_err = diff / norm_sum\n",
    "status = \"OK\" if rel_err < 1e-5 else \"FAIL\"\n",
    "print(f\"Dense bias:     relative error = {rel_err:.2e}  [{status}]\")\n",
    "\n",
    "print(\"=\" * 60)\n",
    "print(\"All relative errors should be < 1e-5.\")\n",
    "print(\"GRADIENT CHECK PASSED\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-10",
   "metadata": {},
   "source": [
    "## 24.5 Training the TinyCNN\n",
    "\n",
    "With verified gradients, we can now train the network on our synthetic line-pattern\n",
    "dataset. We first redefine the dataset generation functions (since this notebook\n",
    "executes independently)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-11",
   "metadata": {},
   "outputs": [],
   "source": [
    "CLASS_NAMES = (\"vertical\", \"horizontal\", \"diagonal\")\n",
    "\n",
    "\n",
    "def _pattern_for_class(class_name, size):\n",
    "    \"\"\"Generate the base pattern for a given class.\"\"\"\n",
    "    pattern = np.zeros((size, size))\n",
    "    center = size // 2\n",
    "    if class_name == \"vertical\":\n",
    "        pattern[:, center - 1:center + 1] = 1.0\n",
    "    elif class_name == \"horizontal\":\n",
    "        pattern[center - 1:center + 1, :] = 1.0\n",
    "    elif class_name == \"diagonal\":\n",
    "        np.fill_diagonal(pattern, 1.0)\n",
    "        pattern += 0.35 * np.eye(size, k=1)\n",
    "        pattern += 0.35 * np.eye(size, k=-1)\n",
    "    return np.clip(pattern, 0.0, 1.0)\n",
    "\n",
    "\n",
    "def make_dataset_bundle(train_per_class=60, val_per_class=30,\n",
    "                         size=8, seed=7, class_names=CLASS_NAMES):\n",
    "    \"\"\"Generate shifted, noisy variants of base patterns.\"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    all_x, all_y = [], []\n",
    "    total_per_class = train_per_class + val_per_class\n",
    "\n",
    "    for cls_idx, cls_name in enumerate(class_names):\n",
    "        base = _pattern_for_class(cls_name, size)\n",
    "        for _ in range(total_per_class):\n",
    "            shift_r = rng.integers(-2, 3)\n",
    "            shift_c = rng.integers(-2, 3)\n",
    "            shifted = np.roll(np.roll(base, shift_r, axis=0),\n",
    "                              shift_c, axis=1)\n",
    "            noisy = shifted + rng.normal(0, 0.15, size=(size, size))\n",
    "            noisy = np.clip(noisy, 0.0, 1.0)\n",
    "            all_x.append(noisy)\n",
    "            all_y.append(cls_idx)\n",
    "\n",
    "    all_x = np.array(all_x)[:, None, :, :]\n",
    "    all_y = np.array(all_y, dtype=int)\n",
    "\n",
    "    perm = rng.permutation(len(all_y))\n",
    "    all_x, all_y = all_x[perm], all_y[perm]\n",
    "\n",
    "    n_train = train_per_class * len(class_names)\n",
    "    x_train, y_train = all_x[:n_train], all_y[:n_train]\n",
    "    x_val, y_val = all_x[n_train:], all_y[n_train:]\n",
    "\n",
    "    return x_train, y_train, x_val, y_val, class_names\n",
    "\n",
    "\n",
    "# Generate dataset\n",
    "x_train, y_train, x_val, y_val, class_names = make_dataset_bundle()\n",
    "print(f\"Training:   {x_train.shape[0]} images\")\n",
    "print(f\"Validation: {x_val.shape[0]} images\")\n",
    "print(f\"Classes: {class_names}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-11a",
   "metadata": {},
   "source": [
    "Now we train the TinyCNN for 80 epochs with learning rate 0.12 and batch size 18."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training the TinyCNN\n",
    "model = TinyCNN(seed=42)\n",
    "\n",
    "print(\"Training TinyCNN on synthetic line patterns...\")\n",
    "print(f\"Architecture: Conv2D(1->3, 3x3) -> ReLU -> MaxPool(2) -> Flatten -> Dense(27->3)\")\n",
    "print(f\"Parameters: {model.conv.parameter_count + model.dense.parameter_count}\")\n",
    "print(f\"Hyperparameters: epochs=80, lr=0.12, batch_size=18\")\n",
    "print()\n",
    "\n",
    "history = model.fit(x_train, y_train, x_val, y_val,\n",
    "                    epochs=80, lr=0.12, batch_size=18, seed=0)\n",
    "\n",
    "print(f\"\\nFinal train accuracy: {history['train_acc'][-1]:.1%}\")\n",
    "print(f\"Final val accuracy:   {history['val_acc'][-1]:.1%}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-13",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Loss curves\n",
    "fig, axes = plt.subplots(1, 2, figsize=(13, 5))\n",
    "\n",
    "epochs_range = range(len(history['train_loss']))\n",
    "\n",
    "# Loss\n",
    "axes[0].plot(epochs_range, history['train_loss'], linewidth=2,\n",
    "             color=BLUE, label='Train loss')\n",
    "axes[0].plot(epochs_range, history['val_loss'], linewidth=2,\n",
    "             color=AMBER, linestyle='--', label='Val loss')\n",
    "axes[0].set_xlabel('Epoch', fontsize=12)\n",
    "axes[0].set_ylabel('Cross-Entropy Loss', fontsize=12)\n",
    "axes[0].set_title('Training and Validation Loss', fontsize=13)\n",
    "axes[0].legend(fontsize=11)\n",
    "axes[0].grid(True, alpha=0.3)\n",
    "\n",
    "# Accuracy\n",
    "axes[1].plot(epochs_range, [a * 100 for a in history['train_acc']],\n",
    "             linewidth=2, color=GREEN, label='Train accuracy')\n",
    "axes[1].plot(epochs_range, [a * 100 for a in history['val_acc']],\n",
    "             linewidth=2, color=BURGUNDY, linestyle='--',\n",
    "             label='Val accuracy')\n",
    "axes[1].axhline(y=100/3, color='gray', linestyle=':', alpha=0.5,\n",
    "                label='Chance (33.3%)')\n",
    "axes[1].set_xlabel('Epoch', fontsize=12)\n",
    "axes[1].set_ylabel('Accuracy (%)', fontsize=12)\n",
    "axes[1].set_title('Training and Validation Accuracy', fontsize=13)\n",
    "axes[1].set_ylim(0, 105)\n",
    "axes[1].legend(fontsize=11, loc='lower right')\n",
    "axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-14",
   "metadata": {},
   "source": [
    "## 24.6 Filter Evolution\n",
    "\n",
    "One of the most compelling aspects of CNNs is that the learned filters become\n",
    "*interpretable*: they specialize into oriented edge detectors that match the\n",
    "structure of the data. The figure below shows snapshots of the three $3 \\times 3$\n",
    "convolutional filters at various stages of training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-15",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Filter evolution snapshots\n",
    "snapshot_epochs = sorted(history['kernel_snapshots'].keys())\n",
    "n_snapshots = len(snapshot_epochs)\n",
    "n_filters = 3\n",
    "\n",
    "fig, axes = plt.subplots(n_filters, n_snapshots,\n",
    "                          figsize=(2.5 * n_snapshots, 2.5 * n_filters))\n",
    "\n",
    "# Find global min/max for consistent colorscale\n",
    "all_vals = np.concatenate(\n",
    "    [w.ravel() for w in history['kernel_snapshots'].values()])\n",
    "vmax = max(abs(all_vals.min()), abs(all_vals.max()))\n",
    "\n",
    "for col, epoch in enumerate(snapshot_epochs):\n",
    "    kernels = history['kernel_snapshots'][epoch]\n",
    "    for f in range(n_filters):\n",
    "        ax = axes[f, col]\n",
    "        im = ax.imshow(kernels[f, 0], cmap='RdBu_r',\n",
    "                       vmin=-vmax, vmax=vmax,\n",
    "                       interpolation='nearest')\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        if f == 0:\n",
    "            ax.set_title(f'Epoch {epoch}', fontsize=11,\n",
    "                         fontweight='bold')\n",
    "        if col == 0:\n",
    "            ax.set_ylabel(f'Filter {f+1}', fontsize=11,\n",
    "                          fontweight='bold')\n",
    "\n",
    "fig.suptitle('Convolutional Filter Evolution During Training',\n",
    "             fontsize=14, fontweight='bold', y=1.02)\n",
    "fig.colorbar(im, ax=axes, orientation='vertical', fraction=0.02,\n",
    "             pad=0.04, label='Weight value')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"The filters evolve from random noise (epoch 0) to structured\")\n",
    "print(\"patterns that detect specific orientations in the input.\")"
   ]
  },
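  {
   "cell_type": "markdown",
   "id": "cell-15a",
   "metadata": {},
   "source": [
    "As a rough illustration of what an oriented detector means, the sketch below applies\n",
    "a hand-made kernel (not a learned one) to two toy images. The helper\n",
    "`correlate2d_valid`, the kernel values, and the $8 \\times 8$ image size are assumptions\n",
    "chosen for illustration; they are independent of the `Conv2D` class used in this chapter.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def correlate2d_valid(image, kernel):\n",
    "    # Valid-mode 2D cross-correlation: no padding, stride 1.\n",
    "    kh, kw = kernel.shape\n",
    "    out_h = image.shape[0] - kh + 1\n",
    "    out_w = image.shape[1] - kw + 1\n",
    "    out = np.zeros((out_h, out_w))\n",
    "    for i in range(out_h):\n",
    "        for j in range(out_w):\n",
    "            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)\n",
    "    return out\n",
    "\n",
    "# Hypothetical hand-made kernel: positive middle row, negative rows above and below.\n",
    "horizontal_kernel = np.array([[-1.0, -1.0, -1.0],\n",
    "                              [ 2.0,  2.0,  2.0],\n",
    "                              [-1.0, -1.0, -1.0]])\n",
    "\n",
    "# Two toy images (assumed size 8x8): one horizontal line, one vertical line.\n",
    "horizontal_line = np.zeros((8, 8))\n",
    "horizontal_line[4, :] = 1.0\n",
    "vertical_line = np.zeros((8, 8))\n",
    "vertical_line[:, 4] = 1.0\n",
    "\n",
    "# Strongest activation of the kernel anywhere in each image.\n",
    "resp_h = correlate2d_valid(horizontal_line, horizontal_kernel).max()\n",
    "resp_v = correlate2d_valid(vertical_line, horizontal_kernel).max()\n",
    "print('max response to horizontal line:', resp_h)\n",
    "print('max response to vertical line:  ', resp_v)\n",
    "```\n",
    "\n",
    "The kernel responds strongly to the horizontal line and not at all to the vertical\n",
    "one, because each of its columns sums to zero. A learned filter that drifts toward\n",
    "this shape would behave in a similar way."
   ]
  },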
  {
   "cell_type": "markdown",
   "id": "cell-16",
   "metadata": {},
   "source": [
    "## 24.7 Confusion Matrix\n",
    "\n",
    "The confusion matrix reveals which classes the network confuses. For our three-class\n",
    "problem, diagonal entries represent correct predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-17",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Confusion matrix on validation set\n",
    "val_preds = model.predict(x_val)\n",
    "n_classes = len(class_names)\n",
    "conf_matrix = np.zeros((n_classes, n_classes), dtype=int)\n",
    "for true, pred in zip(y_val, val_preds):\n",
    "    conf_matrix[true, pred] += 1\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 5))\n",
    "im = ax.imshow(conf_matrix, cmap='Blues')\n",
    "\n",
    "# Annotate cells\n",
    "for i in range(n_classes):\n",
    "    for j in range(n_classes):\n",
    "        val = conf_matrix[i, j]\n",
    "        color = 'white' if val > conf_matrix.max() / 2 else 'black'\n",
    "        ax.text(j, i, str(val), ha='center', va='center',\n",
    "                fontsize=16, fontweight='bold', color=color)\n",
    "\n",
    "ax.set_xticks(range(n_classes))\n",
    "ax.set_yticks(range(n_classes))\n",
    "ax.set_xticklabels(class_names, fontsize=11)\n",
    "ax.set_yticklabels(class_names, fontsize=11)\n",
    "ax.set_xlabel('Predicted', fontsize=12)\n",
    "ax.set_ylabel('True', fontsize=12)\n",
    "ax.set_title('Validation Confusion Matrix', fontsize=13,\n",
    "             fontweight='bold')\n",
    "fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Per-class accuracy\n",
    "print(\"Per-class accuracy:\")\n",
    "for i, name in enumerate(class_names):\n",
    "    total = conf_matrix[i].sum()\n",
    "    correct = conf_matrix[i, i]\n",
    "    print(f\"  {name:>12s}: {correct}/{total} = {correct/total:.1%}\")\n",
    "print(f\"\\nOverall validation accuracy: {model.evaluate(x_val, y_val):.1%}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-18",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 24.1.** Implement **momentum** for the Conv2D and Dense layers. Add a\n",
    "`velocity` attribute initialized to zero, and modify `step` to use\n",
    "$v \\leftarrow \\beta v + \\nabla \\mathcal{L}$, $\\theta \\leftarrow \\theta - \\eta v$\n",
    "with $\\beta = 0.9$. Compare convergence with and without momentum.\n",
    "\n",
    "**Exercise 24.2.** Experiment with the learning rate. Train TinyCNN with\n",
    "$\\eta \\in \\{0.01, 0.05, 0.12, 0.5, 1.0\\}$ and plot the loss curves on the same\n",
    "axes. What happens when the learning rate is too large?\n",
    "\n",
    "**Exercise 24.3.** Modify TinyCNN to use two convolutional layers:\n",
    "`Conv2D(1,3,3) -> ReLU -> Conv2D(3,6,3) -> ReLU -> MaxPool(2) -> Flatten -> Dense`.\n",
    "How many parameters does this deeper architecture have? Does it achieve higher\n",
    "accuracy on the line-pattern task?\n",
    "\n",
    "**Exercise 24.4.** Derive the backward pass for **average pooling** (where each\n",
    "window is replaced by its mean instead of its maximum). Implement it and compare\n",
    "performance with max pooling on the same dataset.\n",
    "\n",
    "**Exercise 24.5.** Add a fourth class `\"dot\"` to the dataset and retrain the network\n",
    "with `Dense(27, 4)`. Inspect the learned filters -- does a new filter emerge that\n",
    "responds to the dot pattern? If not, explain why three $3 \\times 3$ filters might\n",
    "be insufficient and propose a solution."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
