{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-00-infrastructure",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# ── Utility functions ──────────────────────────────────────────────\n",
    "\n",
    "def softmax(logits):\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    exp_values = np.exp(shifted)\n",
    "    return exp_values / exp_values.sum(axis=1, keepdims=True)\n",
    "\n",
    "def softmax_cross_entropy(logits, targets):\n",
    "    probabilities = softmax(logits)\n",
    "    clipped = np.clip(probabilities[np.arange(targets.shape[0]), targets], 1e-12, 1.0)\n",
    "    loss = -np.log(clipped).mean()\n",
    "    d_logits = probabilities.copy()\n",
    "    d_logits[np.arange(targets.shape[0]), targets] -= 1.0\n",
    "    d_logits /= targets.shape[0]\n",
    "    return loss, probabilities, d_logits\n",
    "\n",
    "# ── Layer definitions ──────────────────────────────────────────────\n",
    "\n",
    "class Conv2D:\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, seed=42):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        fan_in = in_channels * kernel_size * kernel_size\n",
    "        scale = np.sqrt(2.0 / fan_in)\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.kernel_size = kernel_size\n",
    "        self.weights = rng.normal(0.0, scale, size=(out_channels, in_channels, kernel_size, kernel_size))\n",
    "        self.bias = np.zeros(out_channels)\n",
    "        self.d_weights = np.zeros_like(self.weights)\n",
    "        self.d_bias = np.zeros_like(self.bias)\n",
    "        self.last_input = None\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        batch_size, _, height, width = x.shape\n",
    "        out_h = height - self.kernel_size + 1\n",
    "        out_w = width - self.kernel_size + 1\n",
    "        output = np.zeros((batch_size, self.out_channels, out_h, out_w))\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                patch = x[:, :, row:row+self.kernel_size, col:col+self.kernel_size]\n",
    "                output[:, :, row, col] = np.tensordot(patch, self.weights, axes=([1,2,3],[1,2,3])) + self.bias\n",
    "        return output\n",
    "    def backward(self, d_output):\n",
    "        x = self.last_input\n",
    "        _, _, out_h, out_w = d_output.shape\n",
    "        self.d_weights.fill(0.0)\n",
    "        self.d_bias = d_output.sum(axis=(0, 2, 3))\n",
    "        d_input = np.zeros_like(x)\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                patch = x[:, :, row:row+self.kernel_size, col:col+self.kernel_size]\n",
    "                self.d_weights += np.tensordot(d_output[:, :, row, col], patch, axes=([0], [0]))\n",
    "                d_input[:, :, row:row+self.kernel_size, col:col+self.kernel_size] += np.tensordot(\n",
    "                    d_output[:, :, row, col], self.weights, axes=([1], [0]))\n",
    "        return d_input\n",
    "    def step(self, lr):\n",
    "        self.weights -= lr * self.d_weights\n",
    "        self.bias -= lr * self.d_bias\n",
    "    @property\n",
    "    def parameter_count(self):\n",
    "        return int(self.weights.size + self.bias.size)\n",
    "\n",
    "class ReLU:\n",
    "    def __init__(self):\n",
    "        self.last_input = None\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        return np.maximum(0.0, x)\n",
    "    def backward(self, d_output):\n",
    "        return d_output * (self.last_input > 0.0)\n",
    "\n",
    "class MaxPool2D:\n",
    "    def __init__(self, pool_size=2):\n",
    "        self.pool_size = pool_size\n",
    "        self.last_input = None\n",
    "        self.last_mask = None\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        batch_size, channels, height, width = x.shape\n",
    "        out_h = height // self.pool_size\n",
    "        out_w = width // self.pool_size\n",
    "        output = np.zeros((batch_size, channels, out_h, out_w))\n",
    "        mask = np.zeros_like(x)\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                rs, cs = row * self.pool_size, col * self.pool_size\n",
    "                window = x[:, :, rs:rs+self.pool_size, cs:cs+self.pool_size]\n",
    "                output[:, :, row, col] = window.max(axis=(2, 3))\n",
    "                flat_idx = window.reshape(batch_size, channels, -1).argmax(axis=2)\n",
    "                for b in range(batch_size):\n",
    "                    for c in range(channels):\n",
    "                        winner = flat_idx[b, c]\n",
    "                        mask[b, c, rs + winner // self.pool_size, cs + winner % self.pool_size] = 1.0\n",
    "        self.last_mask = mask\n",
    "        return output\n",
    "    def backward(self, d_output):\n",
    "        d_input = np.zeros_like(self.last_input)\n",
    "        out_h, out_w = d_output.shape[2], d_output.shape[3]\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                rs, cs = row * self.pool_size, col * self.pool_size\n",
    "                mask_w = self.last_mask[:, :, rs:rs+self.pool_size, cs:cs+self.pool_size]\n",
    "                d_input[:, :, rs:rs+self.pool_size, cs:cs+self.pool_size] += (\n",
    "                    mask_w * d_output[:, :, row, col][:, :, None, None])\n",
    "        return d_input\n",
    "\n",
    "class Flatten:\n",
    "    def __init__(self):\n",
    "        self.last_shape = None\n",
    "    def forward(self, x):\n",
    "        self.last_shape = x.shape\n",
    "        return x.reshape(x.shape[0], -1)\n",
    "    def backward(self, d_output):\n",
    "        return d_output.reshape(self.last_shape)\n",
    "\n",
    "class Dense:\n",
    "    def __init__(self, in_features, out_features, seed=42):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        scale = np.sqrt(2.0 / in_features)\n",
    "        self.weights = rng.normal(0.0, scale, size=(in_features, out_features))\n",
    "        self.bias = np.zeros(out_features)\n",
    "        self.d_weights = np.zeros_like(self.weights)\n",
    "        self.d_bias = np.zeros_like(self.bias)\n",
    "        self.last_input = None\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        return x @ self.weights + self.bias\n",
    "    def backward(self, d_output):\n",
    "        self.d_weights = self.last_input.T @ d_output\n",
    "        self.d_bias = d_output.sum(axis=0)\n",
    "        return d_output @ self.weights.T\n",
    "    def step(self, lr):\n",
    "        self.weights -= lr * self.d_weights\n",
    "        self.bias -= lr * self.d_bias\n",
    "    @property\n",
    "    def parameter_count(self):\n",
    "        return int(self.weights.size + self.bias.size)\n",
    "\n",
    "# ── TinyCNN ─────────────────────────────────────────────────────────\n",
    "\n",
    "class TinyCNN:\n",
    "    def __init__(self, seed=3, input_size=8, num_classes=3, conv_filters=3):\n",
    "        rng_seed = seed\n",
    "        self.conv = Conv2D(1, conv_filters, 3, seed=rng_seed)\n",
    "        self.relu = ReLU()\n",
    "        self.pool = MaxPool2D(2)\n",
    "        self.flatten = Flatten()\n",
    "        conv_out = input_size - 3 + 1\n",
    "        pooled = conv_out // 2\n",
    "        self.dense = Dense(conv_filters * pooled * pooled, num_classes, seed=rng_seed + 1)\n",
    "        self.num_classes = num_classes\n",
    "        self.conv_filters = conv_filters\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.dense.forward(self.flatten.forward(self.pool.forward(self.relu.forward(self.conv.forward(x)))))\n",
    "\n",
    "    def forward_with_trace(self, x):\n",
    "        \"\"\"Forward pass returning all intermediate activations.\"\"\"\n",
    "        conv_out = self.conv.forward(x)\n",
    "        relu_out = self.relu.forward(conv_out)\n",
    "        pool_out = self.pool.forward(relu_out)\n",
    "        flat_out = self.flatten.forward(pool_out)\n",
    "        logits = self.dense.forward(flat_out)\n",
    "        probs = softmax(logits)\n",
    "        return {\n",
    "            'input': x,\n",
    "            'conv': conv_out,\n",
    "            'relu': relu_out,\n",
    "            'pool': pool_out,\n",
    "            'logits': logits,\n",
    "            'probs': probs,\n",
    "        }\n",
    "\n",
    "    def loss_and_grad(self, x, y):\n",
    "        logits = self.forward(x)\n",
    "        loss, probs, d_logits = softmax_cross_entropy(logits, y)\n",
    "        d = self.dense.backward(d_logits)\n",
    "        d = self.flatten.backward(d)\n",
    "        d = self.pool.backward(d)\n",
    "        d = self.relu.backward(d)\n",
    "        self.conv.backward(d)\n",
    "        return loss, probs\n",
    "\n",
    "    def step(self, lr):\n",
    "        self.conv.step(lr)\n",
    "        self.dense.step(lr)\n",
    "\n",
    "    def evaluate(self, x, y):\n",
    "        logits = self.forward(x)\n",
    "        loss, _, _ = softmax_cross_entropy(logits, y)\n",
    "        accuracy = float((logits.argmax(axis=1) == y).mean())\n",
    "        return loss, accuracy\n",
    "\n",
    "    def fit(self, x_train, y_train, x_val, y_val, epochs=80, lr=0.12, batch_size=18, seed=11, snapshot_epochs=(0,1,5,15,40,80)):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        history = []\n",
    "        snapshots = []\n",
    "        def record(ep):\n",
    "            tl, ta = self.evaluate(x_train, y_train)\n",
    "            _, va = self.evaluate(x_val, y_val)\n",
    "            history.append({\"epoch\": ep, \"train_loss\": tl, \"train_acc\": ta, \"val_acc\": va})\n",
    "            if ep in snapshot_epochs:\n",
    "                snapshots.append({\"epoch\": ep, \"kernels\": self.conv.weights.copy()})\n",
    "        record(0)\n",
    "        for ep in range(1, epochs + 1):\n",
    "            order = rng.permutation(x_train.shape[0])\n",
    "            sx, sy = x_train[order], y_train[order]\n",
    "            for start in range(0, x_train.shape[0], batch_size):\n",
    "                self.loss_and_grad(sx[start:start+batch_size], sy[start:start+batch_size])\n",
    "                self.step(lr)\n",
    "            record(ep)\n",
    "        return history, snapshots\n",
    "\n",
    "    @property\n",
    "    def total_parameters(self):\n",
    "        return self.conv.parameter_count + self.dense.parameter_count\n",
    "\n",
    "# ── TinyMLP ─────────────────────────────────────────────────────────\n",
    "\n",
    "class TinyMLP:\n",
    "    def __init__(self, seed=5, input_size=8, hidden_size=18, num_classes=3):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        flat_size = input_size * input_size\n",
    "        self.flatten = Flatten()\n",
    "        self.hidden = Dense(flat_size, hidden_size, seed=seed)\n",
    "        self.relu = ReLU()\n",
    "        self.output = Dense(hidden_size, num_classes, seed=seed + 1)\n",
    "        self.num_classes = num_classes\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.output.forward(self.relu.forward(self.hidden.forward(self.flatten.forward(x))))\n",
    "\n",
    "    def loss_and_grad(self, x, y):\n",
    "        logits = self.forward(x)\n",
    "        loss, probs, d_logits = softmax_cross_entropy(logits, y)\n",
    "        d = self.output.backward(d_logits)\n",
    "        d = self.relu.backward(d)\n",
    "        d = self.hidden.backward(d)\n",
    "        self.flatten.backward(d)\n",
    "        return loss, probs\n",
    "\n",
    "    def step(self, lr):\n",
    "        self.hidden.step(lr)\n",
    "        self.output.step(lr)\n",
    "\n",
    "    def evaluate(self, x, y):\n",
    "        logits = self.forward(x)\n",
    "        loss, _, _ = softmax_cross_entropy(logits, y)\n",
    "        accuracy = float((logits.argmax(axis=1) == y).mean())\n",
    "        return loss, accuracy\n",
    "\n",
    "    def fit(self, x_train, y_train, x_val, y_val, epochs=80, lr=0.12, batch_size=18, seed=13):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        history = []\n",
    "        def record(ep):\n",
    "            tl, ta = self.evaluate(x_train, y_train)\n",
    "            _, va = self.evaluate(x_val, y_val)\n",
    "            history.append({\"epoch\": ep, \"train_loss\": tl, \"train_acc\": ta, \"val_acc\": va})\n",
    "        record(0)\n",
    "        for ep in range(1, epochs + 1):\n",
    "            order = rng.permutation(x_train.shape[0])\n",
    "            sx, sy = x_train[order], y_train[order]\n",
    "            for start in range(0, x_train.shape[0], batch_size):\n",
    "                self.loss_and_grad(sx[start:start+batch_size], sy[start:start+batch_size])\n",
    "                self.step(lr)\n",
    "            record(ep)\n",
    "        return history\n",
    "\n",
    "    @property\n",
    "    def total_parameters(self):\n",
    "        return self.hidden.parameter_count + self.output.parameter_count\n",
    "\n",
    "# ── Dataset ──────────────────────────────────────────────────────────\n",
    "\n",
    "CLASS_NAMES = (\"vertical\", \"horizontal\", \"diagonal\")\n",
    "DOT_CLASS_NAMES = (\"vertical\", \"horizontal\", \"diagonal\", \"dot\")\n",
    "\n",
    "def _pattern_for_class(name, size):\n",
    "    pattern = np.zeros((size, size))\n",
    "    c = size // 2\n",
    "    if name == \"vertical\": pattern[:, c-1:c+1] = 1.0\n",
    "    elif name == \"horizontal\": pattern[c-1:c+1, :] = 1.0\n",
    "    elif name == \"diagonal\":\n",
    "        np.fill_diagonal(pattern, 1.0)\n",
    "        pattern += 0.35 * np.eye(size, k=1) + 0.35 * np.eye(size, k=-1)\n",
    "    elif name == \"dot\": pattern[c-1:c+1, c-1:c+1] = 1.0\n",
    "    return np.clip(pattern, 0.0, 1.0)\n",
    "\n",
    "def make_dataset(train_per_class=60, val_per_class=30, size=8, seed=7, class_names=CLASS_NAMES):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    total = train_per_class + val_per_class\n",
    "    patterns = [_pattern_for_class(n, size) for n in class_names]\n",
    "    all_x, all_y = [], []\n",
    "    for label, pat in enumerate(patterns):\n",
    "        for _ in range(total):\n",
    "            img = pat.copy()\n",
    "            sy, sx = rng.integers(-1, 2, size=2)\n",
    "            shifted = np.zeros_like(img)\n",
    "            src_ys = max(0, -sy); src_ye = size - max(0, sy)\n",
    "            src_xs = max(0, -sx); src_xe = size - max(0, sx)\n",
    "            dst_ys = max(0, sy); dst_xs = max(0, sx)\n",
    "            shifted[dst_ys:dst_ys+(src_ye-src_ys), dst_xs:dst_xs+(src_xe-src_xs)] = img[src_ys:src_ye, src_xs:src_xe]\n",
    "            canvas = shifted * rng.uniform(0.85, 1.15)\n",
    "            canvas += rng.normal(0, 0.12, (size, size))\n",
    "            all_x.append(np.clip(canvas, 0, 1))\n",
    "            all_y.append(label)\n",
    "    x = np.array(all_x)[:, None, :, :]\n",
    "    y = np.array(all_y, dtype=np.int64)\n",
    "    rng2 = np.random.default_rng(seed + 99)\n",
    "    xt, yt, xv, yv = [], [], [], []\n",
    "    for label in range(len(class_names)):\n",
    "        idx = np.where(y == label)[0]\n",
    "        order = rng2.permutation(len(idx))\n",
    "        idx = idx[order]\n",
    "        xt.append(x[idx[:train_per_class]]); yt.append(y[idx[:train_per_class]])\n",
    "        xv.append(x[idx[train_per_class:]]); yv.append(y[idx[train_per_class:]])\n",
    "    x_train = np.concatenate(xt); y_train = np.concatenate(yt)\n",
    "    x_val = np.concatenate(xv); y_val = np.concatenate(yv)\n",
    "    order_t = rng2.permutation(len(x_train)); order_v = rng2.permutation(len(x_val))\n",
    "    return x_train[order_t], y_train[order_t], x_val[order_v], y_val[order_v]\n",
    "\n",
    "# ── Plot styling ─────────────────────────────────────────────────────\n",
    "\n",
    "CREAM = '#fdf6ec'\n",
    "CNN_BLUE = '#22577a'\n",
    "MLP_RED = '#bc4749'\n",
    "GREEN_LIGHT = '#55a630'\n",
    "GREEN_DARK = '#2b9348'\n",
    "\n",
    "def style_ax(ax, title=None, xlabel=None, ylabel=None):\n",
    "    ax.set_facecolor(CREAM)\n",
    "    if title: ax.set_title(title, fontsize=13, fontweight='bold', pad=8)\n",
    "    if xlabel: ax.set_xlabel(xlabel, fontsize=11)\n",
    "    if ylabel: ax.set_ylabel(ylabel, fontsize=11)\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    ax.grid(True, alpha=0.3, linestyle='--')\n",
    "\n",
    "print('Infrastructure loaded: Conv2D, ReLU, MaxPool2D, Flatten, Dense, TinyCNN, TinyMLP, make_dataset')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-01-title",
   "metadata": {},
   "source": [
    "# Chapter 25: CNN Experiments and Analysis\n",
    "\n",
    "In the previous chapter we built a convolutional neural network from scratch and\n",
    "watched it learn to classify simple 8$\\times$8 patterns. Now we turn to the\n",
    "**experimental** side: How does the CNN compare to a plain fully-connected\n",
    "network? How sensitive is the MLP baseline to its hidden-layer width? What\n",
    "happens when we add a fourth pattern class? And what, exactly, do the learned\n",
    "convolutional filters look like?\n",
    "\n",
    "This chapter is heavy on **visualization and analysis**. Every figure is\n",
    "generated from our pure-NumPy implementations; no external deep-learning\n",
    "library is used."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-02-sec25-1-title",
   "metadata": {},
   "source": [
    "## 25.1 CNN vs MLP Baseline\n",
    "\n",
    "Our first experiment is the most natural one: **train both architectures on the\n",
    "same 3-class data and compare them**.\n",
    "\n",
    "The `TinyCNN` uses 3 convolutional filters of size $3\\times 3$ followed by\n",
    "ReLU, $2\\times 2$ max-pooling, flattening, and a dense output layer. The\n",
    "`TinyMLP` flattens the image directly and passes it through a hidden layer of\n",
    "18 units with ReLU, then a dense output layer.\n",
    "\n",
    "We train both for 80 epochs with learning rate $\\eta = 0.12$ and batch size 18."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-03-baseline-train",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Train CNN and MLP on the 3-class dataset\n",
    "x_train, y_train, x_val, y_val = make_dataset()\n",
    "\n",
    "cnn = TinyCNN(seed=3)\n",
    "cnn_history, cnn_snapshots = cnn.fit(x_train, y_train, x_val, y_val)\n",
    "\n",
    "mlp = TinyMLP(seed=5)\n",
    "mlp_history = mlp.fit(x_train, y_train, x_val, y_val)\n",
    "\n",
    "# Two-panel figure: training curves and parameter counts\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4.2))\n",
    "fig.patch.set_facecolor(CREAM)\n",
    "\n",
    "# Left: validation accuracy curves\n",
    "cnn_epochs = [h['epoch'] for h in cnn_history]\n",
    "cnn_val = [h['val_acc'] for h in cnn_history]\n",
    "mlp_epochs = [h['epoch'] for h in mlp_history]\n",
    "mlp_val = [h['val_acc'] for h in mlp_history]\n",
    "\n",
    "ax1.plot(cnn_epochs, cnn_val, color=CNN_BLUE, linewidth=2.2, label=f'TinyCNN ({cnn.total_parameters} params)')\n",
    "ax1.plot(mlp_epochs, mlp_val, color=MLP_RED, linewidth=2.2, linestyle='--', label=f'TinyMLP ({mlp.total_parameters} params)')\n",
    "ax1.axhline(1.0, color='grey', linewidth=0.8, linestyle=':')\n",
    "ax1.set_ylim(0.0, 1.08)\n",
    "style_ax(ax1, title='Validation Accuracy', xlabel='Epoch', ylabel='Accuracy')\n",
    "ax1.legend(fontsize=10, loc='lower right', framealpha=0.9)\n",
    "\n",
    "# Right: parameter count bar chart\n",
    "models = ['TinyCNN', 'TinyMLP']\n",
    "params = [cnn.total_parameters, mlp.total_parameters]\n",
    "colors = [CNN_BLUE, MLP_RED]\n",
    "bars = ax2.bar(models, params, color=colors, width=0.5, edgecolor='white', linewidth=1.5)\n",
    "for bar, p in zip(bars, params):\n",
    "    ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 20,\n",
    "             f'{p:,}', ha='center', va='bottom', fontsize=12, fontweight='bold')\n",
    "style_ax(ax2, title='Parameter Count', ylabel='Number of Parameters')\n",
    "ax2.set_ylim(0, max(params) * 1.25)\n",
    "\n",
    "fig.suptitle('CNN vs MLP on 3-Class Pattern Recognition', fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Print final results\n",
    "cnn_final = cnn_history[-1]\n",
    "mlp_final = mlp_history[-1]\n",
    "print(f'CNN  final accuracy: train={cnn_final[\"train_acc\"]:.3f}, val={cnn_final[\"val_acc\"]:.3f}  ({cnn.total_parameters} parameters)')\n",
    "print(f'MLP  final accuracy: train={mlp_final[\"train_acc\"]:.3f}, val={mlp_final[\"val_acc\"]:.3f}  ({mlp.total_parameters} parameters)')\n",
    "print(f'Parameter ratio:     MLP / CNN = {mlp.total_parameters / cnn.total_parameters:.1f}x')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-04-baseline-discussion",
   "metadata": {},
   "source": [
    "```{admonition} Key Observation\n",
    ":class: important\n",
    "On this toy task, both architectures achieve excellent accuracy. The CNN\n",
    "advantage is not in final accuracy but in **parameter efficiency** and\n",
    "**interpretability**: the CNN uses roughly 10$\\times$ fewer parameters, and its\n",
    "learned filters have a clear visual meaning (as we will see in Section 25.5).\n",
    "```\n",
    "\n",
    "The MLP must learn separate weights for every pixel position, while the CNN\n",
    "re-uses the same small $3\\times 3$ kernel across the entire image. This\n",
    "**weight sharing** is the source of the parameter savings."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-05-sec25-2-title",
   "metadata": {},
   "source": [
    "## 25.2 Hidden-Width Capacity Sweep\n",
    "\n",
    "The TinyMLP with 18 hidden units has more than enough capacity for our\n",
    "3-class task. But what happens if we **shrink** the hidden layer? At what point\n",
    "does the MLP fail, and how does its stability change?\n",
    "\n",
    "We sweep over hidden-layer widths $h \\in \\{1, 2, 3, 4, 6, 10, 18\\}$ and\n",
    "train 5 random seeds per width. For each width we report the **mean**,\n",
    "**minimum**, and **maximum** validation accuracy across seeds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-06-capacity-sweep",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Capacity sweep: MLP hidden-width from 1 to 18\n",
    "widths = [1, 2, 3, 4, 6, 10, 18]\n",
    "seeds = list(range(5))\n",
    "sweep_results = {}\n",
    "\n",
    "for w in widths:\n",
    "    accs = []\n",
    "    for s in seeds:\n",
    "        model = TinyMLP(seed=s, hidden_size=w)\n",
    "        hist = model.fit(x_train, y_train, x_val, y_val, seed=s + 100)\n",
    "        accs.append(hist[-1]['val_acc'])\n",
    "    sweep_results[w] = {'mean': np.mean(accs), 'min': np.min(accs),\n",
    "                        'max': np.max(accs), 'all': accs}\n",
    "\n",
    "# CNN reference (single value)\n",
    "cnn_ref_acc = cnn_history[-1]['val_acc']\n",
    "\n",
    "# Band plot\n",
    "fig, ax = plt.subplots(figsize=(8, 4.5))\n",
    "fig.patch.set_facecolor(CREAM)\n",
    "\n",
    "means = [sweep_results[w]['mean'] for w in widths]\n",
    "mins = [sweep_results[w]['min'] for w in widths]\n",
    "maxs = [sweep_results[w]['max'] for w in widths]\n",
    "\n",
    "ax.fill_between(widths, mins, maxs, color=MLP_RED, alpha=0.18, label='MLP min/max range')\n",
    "ax.plot(widths, means, color=MLP_RED, linewidth=2.2, marker='o', markersize=6, label='MLP mean accuracy')\n",
    "ax.axhline(cnn_ref_acc, color=CNN_BLUE, linewidth=2.0, linestyle='--',\n",
    "           label=f'CNN ({cnn.total_parameters} params)')\n",
    "\n",
    "# Mark the \"first stable\" width\n",
    "ax.axvline(4, color=GREEN_DARK, linewidth=1.2, linestyle=':', alpha=0.7)\n",
    "ax.annotate('first stable\\nwidth', xy=(4, 0.55), fontsize=9, color=GREEN_DARK,\n",
    "            ha='center', style='italic')\n",
    "\n",
    "style_ax(ax, title='MLP Validation Accuracy vs Hidden-Layer Width',\n",
    "         xlabel='Hidden units', ylabel='Validation accuracy')\n",
    "ax.set_xticks(widths)\n",
    "ax.set_ylim(0.2, 1.08)\n",
    "ax.legend(fontsize=10, loc='lower right', framealpha=0.9)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Print table\n",
    "print(f'{\"Width\":>6} {\"Mean\":>8} {\"Min\":>8} {\"Max\":>8}')\n",
    "print('-' * 34)\n",
    "for w in widths:\n",
    "    r = sweep_results[w]\n",
    "    print(f'{w:>6} {r[\"mean\"]:>8.3f} {r[\"min\"]:>8.3f} {r[\"max\"]:>8.3f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-07-capacity-discussion",
   "metadata": {},
   "source": [
    "The sweep reveals a clear pattern:\n",
    "\n",
     "- **Width 1** consistently fails: a single hidden unit projects every image\n",
     "  onto one scalar feature, which is not enough to reliably separate three classes.\n",
    "- **Width 2--3** is unstable: some seeds converge, others do not. The network\n",
    "  is operating at the edge of its representational capacity.\n",
    "- **Width 4** is the first width where all five seeds achieve high accuracy.\n",
     "  This matches the intuition that a $k$-class problem needs a few hidden units\n",
     "  beyond the bare minimum to form robust, non-degenerate decision boundaries.\n",
    "- **Width 10--18** gives uniformly high accuracy with no seed sensitivity.\n",
    "\n",
    "```{admonition} Capacity vs Inductive Bias\n",
    ":class: note\n",
    "The MLP needs at least 4 hidden units to reliably solve a 3-class problem on\n",
    "8$\\times$8 images. The CNN, by contrast, solves the same problem with only 3\n",
    "convolutional filters plus a small output head. The difference is **inductive\n",
    "bias**: the CNN's built-in assumptions (local connectivity, weight sharing,\n",
    "translation equivariance) make the hypothesis space much smaller, so fewer\n",
    "parameters suffice.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-08-sec25-3-title",
   "metadata": {},
   "source": [
    "## 25.3 Adding a Fourth Class\n",
    "\n",
    "Our original dataset has three pattern classes: vertical, horizontal, and\n",
    "diagonal bars. What happens when we add a fourth class -- a centered **dot**\n",
    "pattern?\n",
    "\n",
    "For the CNN, we keep the same 3 convolutional filters and simply grow the\n",
    "output layer from 3 to 4 units. For the MLP, we keep hidden\\_size=18 and also\n",
    "grow the output layer.\n",
    "\n",
    "The central question: **Does the CNN need more filters?**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-09-four-class",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Four-class dataset\n",
    "x4_train, y4_train, x4_val, y4_val = make_dataset(class_names=DOT_CLASS_NAMES)\n",
    "\n",
    "# Train CNN (3 filters, 4 classes)\n",
    "cnn4 = TinyCNN(seed=3, num_classes=4, conv_filters=3)\n",
    "cnn4_history, _ = cnn4.fit(x4_train, y4_train, x4_val, y4_val)\n",
    "\n",
    "# Train MLP (hidden=18, 4 classes)\n",
    "mlp4 = TinyMLP(seed=5, hidden_size=18, num_classes=4)\n",
    "mlp4_history = mlp4.fit(x4_train, y4_train, x4_val, y4_val)\n",
    "\n",
    "# Two-panel comparison: 3-class vs 4-class\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4.2))\n",
    "fig.patch.set_facecolor(CREAM)\n",
    "\n",
    "# Left: 3-class results\n",
    "ax1.plot(cnn_epochs, cnn_val, color=CNN_BLUE, linewidth=2.2, label='CNN (3 filters)')\n",
    "ax1.plot(mlp_epochs, mlp_val, color=MLP_RED, linewidth=2.2, linestyle='--', label='MLP (h=18)')\n",
    "ax1.axhline(1.0, color='grey', linewidth=0.8, linestyle=':')\n",
    "ax1.set_ylim(0.0, 1.08)\n",
    "style_ax(ax1, title='3-Class Task', xlabel='Epoch', ylabel='Validation Accuracy')\n",
    "ax1.legend(fontsize=9, loc='lower right', framealpha=0.9)\n",
    "\n",
    "# Right: 4-class results\n",
    "cnn4_epochs = [h['epoch'] for h in cnn4_history]\n",
    "cnn4_val = [h['val_acc'] for h in cnn4_history]\n",
    "mlp4_epochs = [h['epoch'] for h in mlp4_history]\n",
    "mlp4_val = [h['val_acc'] for h in mlp4_history]\n",
    "\n",
    "ax2.plot(cnn4_epochs, cnn4_val, color=CNN_BLUE, linewidth=2.2, label=f'CNN (3 filters, {cnn4.total_parameters} params)')\n",
    "ax2.plot(mlp4_epochs, mlp4_val, color=MLP_RED, linewidth=2.2, linestyle='--', label=f'MLP (h=18, {mlp4.total_parameters} params)')\n",
    "ax2.axhline(1.0, color='grey', linewidth=0.8, linestyle=':')\n",
    "ax2.set_ylim(0.0, 1.08)\n",
    "style_ax(ax2, title='4-Class Task (+ dot)', xlabel='Epoch', ylabel='Validation Accuracy')\n",
    "ax2.legend(fontsize=9, loc='lower right', framealpha=0.9)\n",
    "\n",
    "fig.suptitle('Scaling from 3 to 4 Classes', fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(f'4-class CNN:  val_acc = {cnn4_history[-1][\"val_acc\"]:.3f}  ({cnn4.total_parameters} params, 3 filters)')\n",
    "print(f'4-class MLP:  val_acc = {mlp4_history[-1][\"val_acc\"]:.3f}  ({mlp4.total_parameters} params, h=18)')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-10-four-class-discussion",
   "metadata": {},
   "source": [
    "The answer is **no** -- the CNN does not need more convolutional filters to\n",
    "handle the fourth class. The convolutional layers act as a general-purpose\n",
    "**feature extractor**: they detect oriented edges and local intensity patterns\n",
    "regardless of how many classes the output head must distinguish. Adding a\n",
    "fourth class only requires one more output neuron in the dense layer, adding\n",
     "just $3 \\times 3 \\times 3 + 1 = 28$ parameters (one weight per pooled feature map\n",
     "location across the 3 filters, plus one bias).\n",
    "\n",
    "This separation between the feature extractor and the classifier head is one\n",
    "of the most important architectural ideas in deep learning. It is the basis\n",
    "for **transfer learning**: pre-train the convolutional layers on a large\n",
    "dataset, then replace only the output head for a new task."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-11-sec25-4-title",
   "metadata": {},
   "source": [
    "## 25.4 Inference Trace\n",
    "\n",
    "To build intuition for what happens inside the CNN, we take a single input\n",
    "image and **trace** it through every stage of the pipeline:\n",
    "\n",
    "$$\n",
    "\\text{Input} \\;\\xrightarrow{\\text{Conv}}\\; \\text{Feature maps}\n",
    "\\;\\xrightarrow{\\text{ReLU}}\\; \\text{Activations}\n",
    "\\;\\xrightarrow{\\text{MaxPool}}\\; \\text{Pooled maps}\n",
    "\\;\\xrightarrow{\\text{Dense + Softmax}}\\; \\text{Class probabilities}\n",
    "$$\n",
    "\n",
    "We use the CNN trained on the 3-class data (from Section 25.1) and pass a\n",
    "single **vertical-bar** sample through it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-12-inference-trace",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
     "# Trace one validation image through every stage of the trained CNN and\n",
     "# visualize the intermediate representations (conv, ReLU, pool, softmax).\n",
     "# NOTE(review): relies on names defined in earlier cells (cnn, x_val, y_val,\n",
     "# CREAM, CNN_BLUE, CLASS_NAMES) -- run the notebook top to bottom.\n",
     "\n",
     "# Pick a vertical-bar sample from validation set\n",
     "vert_idx = np.where(y_val == 0)[0][0]\n",
     "sample = x_val[vert_idx:vert_idx+1]  # shape (1, 1, 8, 8)\n",
     "\n",
     "# Trace through the trained CNN\n",
     "# trace is expected to expose the per-stage arrays used below under the\n",
     "# keys 'conv', 'relu', 'pool', and 'probs'.\n",
     "trace = cnn.forward_with_trace(sample)\n",
     "\n",
     "# One column per learned filter; 5 rows of panels (input, conv, relu,\n",
     "# pool, softmax), with relative heights set via height_ratios.\n",
     "n_filters = cnn.conv_filters\n",
     "fig, axes = plt.subplots(5, n_filters, figsize=(9, 12),\n",
     "                         gridspec_kw={'height_ratios': [1.2, 1.0, 1.0, 0.8, 1.0]})\n",
     "fig.patch.set_facecolor(CREAM)\n",
     "\n",
     "# Row 0: Input image (span all filter columns)\n",
     "# The grid axes in this row are hidden; a manually placed axes centers the\n",
     "# single input image across all filter columns.\n",
     "for j in range(n_filters):\n",
     "    axes[0, j].set_visible(False)\n",
     "ax_input = fig.add_axes([0.35, 0.82, 0.3, 0.14])  # manual positioning\n",
     "ax_input.imshow(sample[0, 0], cmap='gray', vmin=0, vmax=1, aspect='equal')\n",
     "ax_input.set_title('Input (8x8)', fontsize=11, fontweight='bold')\n",
     "ax_input.set_xticks([]); ax_input.set_yticks([])\n",
     "ax_input.patch.set_facecolor(CREAM)\n",
     "\n",
     "# Row 1: Conv output (3 feature maps, 6x6)\n",
     "# Diverging colormap: conv outputs can be negative before the ReLU.\n",
     "for j in range(n_filters):\n",
     "    im = axes[1, j].imshow(trace['conv'][0, j], cmap='RdBu_r', aspect='equal')\n",
     "    axes[1, j].set_title(f'Conv filter {j}', fontsize=9)\n",
     "    axes[1, j].set_xticks([]); axes[1, j].set_yticks([])\n",
     "    axes[1, j].patch.set_facecolor(CREAM)\n",
     "fig.text(0.02, 0.68, 'Conv\\nOutput', fontsize=10, fontweight='bold', va='center', ha='center')\n",
     "\n",
     "# Row 2: ReLU output\n",
     "# Sequential colormap anchored at 0: ReLU outputs are non-negative.\n",
     "for j in range(n_filters):\n",
     "    axes[2, j].imshow(trace['relu'][0, j], cmap='Oranges', aspect='equal', vmin=0)\n",
     "    axes[2, j].set_title(f'ReLU {j}', fontsize=9)\n",
     "    axes[2, j].set_xticks([]); axes[2, j].set_yticks([])\n",
     "    axes[2, j].patch.set_facecolor(CREAM)\n",
     "fig.text(0.02, 0.52, 'ReLU\\nOutput', fontsize=10, fontweight='bold', va='center', ha='center')\n",
     "\n",
     "# Row 3: Pooled output (3x3)\n",
     "for j in range(n_filters):\n",
     "    axes[3, j].imshow(trace['pool'][0, j], cmap='Oranges', aspect='equal', vmin=0)\n",
     "    axes[3, j].set_title(f'Pool {j}', fontsize=9)\n",
     "    axes[3, j].set_xticks([]); axes[3, j].set_yticks([])\n",
     "    axes[3, j].patch.set_facecolor(CREAM)\n",
     "fig.text(0.02, 0.36, 'MaxPool\\nOutput', fontsize=10, fontweight='bold', va='center', ha='center')\n",
     "\n",
     "# Row 4: Softmax bar chart (span all columns)\n",
     "# Grid axes hidden again; a wide manual axes holds the probability bars.\n",
     "for j in range(n_filters):\n",
     "    axes[4, j].set_visible(False)\n",
     "ax_bar = fig.add_axes([0.2, 0.05, 0.6, 0.15])\n",
     "probs = trace['probs'][0]\n",
     "# Class 0 is the sample's true class (vertical bar); highlight its bar in blue.\n",
     "bar_colors = [CNN_BLUE if i == 0 else '#aaaaaa' for i in range(len(probs))]\n",
     "bars = ax_bar.barh(CLASS_NAMES[:len(probs)], probs, color=bar_colors, edgecolor='white', height=0.5)\n",
     "# Print the numeric probability just past the end of each bar.\n",
     "for bar, p in zip(bars, probs):\n",
     "    ax_bar.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height()/2,\n",
     "               f'{p:.3f}', va='center', fontsize=10, fontweight='bold')\n",
     "ax_bar.set_xlim(0, 1.15)\n",
     "ax_bar.set_title('Softmax Probabilities', fontsize=11, fontweight='bold')\n",
     "ax_bar.set_facecolor(CREAM)\n",
     "ax_bar.spines['top'].set_visible(False)\n",
     "ax_bar.spines['right'].set_visible(False)\n",
     "\n",
     "fig.suptitle('Inference Trace: Vertical Bar through TinyCNN', fontsize=14,\n",
     "             fontweight='bold', y=0.99)\n",
     "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-13-trace-discussion",
   "metadata": {},
   "source": [
    "Reading the trace from top to bottom:\n",
    "\n",
    "1. **Input**: the 8$\\times$8 grayscale image shows a vertical bar (two bright\n",
    "   columns in the center).\n",
    "2. **Conv output**: each of the 3 learned filters responds differently to the\n",
    "   input. Some filters produce strong positive responses where the bar is,\n",
    "   others are more muted.\n",
    "3. **ReLU output**: negative activations are zeroed out. Only the regions\n",
    "   where a filter truly \"fires\" remain.\n",
    "4. **MaxPool output**: the 6$\\times$6 feature maps are reduced to 3$\\times$3\n",
    "   by taking the maximum in each 2$\\times$2 window. This provides a small\n",
    "   degree of translation invariance.\n",
    "5. **Softmax output**: the dense layer combines the 27 pooled features into\n",
    "   class logits, and softmax converts them to probabilities. The network\n",
    "   correctly assigns the highest probability to \"vertical\"."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-14-sec25-5-title",
   "metadata": {},
   "source": [
    "## 25.5 What the Filters Learn\n",
    "\n",
    "Perhaps the most compelling aspect of convolutional networks is that the\n",
    "learned filters are **interpretable**. Each $3\\times 3$ kernel is a tiny\n",
    "template that the network slides across the image, computing a local\n",
    "similarity score at each position.\n",
    "\n",
    "Let us visualize the three learned kernels from the trained CNN and see what\n",
    "patterns they detect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-15-filter-viz",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": "# Visualize learned kernels\nkernels = cnn.conv.weights  # shape (3, 1, 3, 3)\n\nfig, axes = plt.subplots(1, 3, figsize=(9, 3.5))\nfig.patch.set_facecolor(CREAM)\n\nvmax = np.abs(kernels).max()\nfor i in range(3):\n    k = kernels[i, 0]  # shape (3, 3)\n    im = axes[i].imshow(k, cmap='RdBu_r', vmin=-vmax, vmax=vmax, aspect='equal')\n    axes[i].set_title(f'Filter {i}', fontsize=12, fontweight='bold')\n    axes[i].set_xticks(range(3)); axes[i].set_yticks(range(3))\n    axes[i].patch.set_facecolor(CREAM)\n    # Annotate each weight value\n    for r in range(3):\n        for c in range(3):\n            val = k[r, c]\n            color = 'white' if abs(val) > vmax * 0.6 else 'black'\n            axes[i].text(c, r, f'{val:.2f}', ha='center', va='center',\n                        fontsize=9, fontweight='bold', color=color)\n\nfig.colorbar(im, ax=axes, fraction=0.02, pad=0.04, label='Weight value')\nfig.suptitle('Learned Convolutional Filters (3x3)', fontsize=14, fontweight='bold')\nfig.subplots_adjust(top=0.85, bottom=0.05, wspace=0.3)\nplt.show()"
  },
  {
   "cell_type": "markdown",
   "id": "cell-16-filter-discussion",
   "metadata": {},
   "source": [
    "The trained filters typically show clear **oriented-edge** structure:\n",
    "\n",
    "- One filter develops a vertical gradient (strong positive weights in a\n",
    "  column, negative or near-zero elsewhere) -- it responds to vertical edges.\n",
    "- Another develops a horizontal gradient -- it responds to horizontal edges.\n",
    "- The third often captures diagonal or more complex patterns.\n",
    "\n",
    "These are exactly the features needed to distinguish our three pattern classes.\n",
    "The network has discovered, through gradient descent alone, that oriented edge\n",
    "detection is the right strategy.\n",
    "\n",
    "```{admonition} Connection to Neuroscience\n",
    ":class: tip\n",
    "In 1959, David Hubel and Torsten Wiesel discovered that neurons in the cat's\n",
    "primary visual cortex respond selectively to **oriented edges** at specific\n",
    "positions in the visual field. They called these **simple cells**. The filters\n",
    "learned by our CNN bear a striking resemblance to simple-cell receptive\n",
    "fields: small, spatially localized, and orientation-selective. This parallel\n",
    "between biological vision and artificial convolutional networks is not\n",
    "accidental -- Kunihiko Fukushima explicitly cited Hubel and Wiesel's work\n",
    "when he designed the Neocognitron (1980), the direct ancestor of modern CNNs.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-17-sec25-6-title",
   "metadata": {},
   "source": [
    "## 25.6 Scaling Up: From 8x8 to Real Images\n",
    "\n",
    "Our `TinyCNN` operates on 8$\\times$8 grayscale images with 3 filters and a\n",
    "single convolutional layer. Real-world convolutional networks follow the same\n",
    "architectural pattern -- **Conv $\\to$ ReLU $\\to$ Pool $\\to$ Dense** -- but at\n",
    "vastly larger scale. Let us trace the historical progression:\n",
    "\n",
    "| Network | Year | Input Size | Layers | Parameters | Key Innovation |\n",
    "|---------|------|-----------|--------|------------|----------------|\n",
    "| **LeNet-5** (LeCun) | 1998 | 32$\\times$32 | 7 | 60K | Proven on MNIST digits |\n",
    "| **AlexNet** (Krizhevsky) | 2012 | 227$\\times$227 | 8 | 60M | GPU training, dropout, ReLU |\n",
    "| **VGG-16** (Simonyan) | 2014 | 224$\\times$224 | 16 | 138M | Uniform 3$\\times$3 filters throughout |\n",
    "| **ResNet-50** (He) | 2015 | 224$\\times$224 | 50 | 25M | Skip connections, batch normalization |\n",
    "\n",
    "Several patterns emerge:\n",
    "\n",
    "**Depth increases.** LeNet-5 has 2 convolutional layers; ResNet-50 has 49.\n",
    "Deeper networks can learn hierarchical features: early layers detect edges,\n",
    "middle layers detect textures and parts, and deep layers detect whole objects.\n",
    "\n",
    "**Filter counts grow with depth.** A typical pattern is 64 filters in the\n",
    "first layer, 128 in the second, 256 in the third, and so on. As spatial\n",
    "resolution decreases (through pooling), the number of feature channels\n",
    "increases.\n",
    "\n",
    "**The core operation is unchanged.** Every network in the table above uses the\n",
    "same cross-correlation operation we implemented in `Conv2D.forward`. The\n",
    "mathematical foundations from our toy example carry over directly.\n",
    "\n",
    "```{admonition} From MNIST to ImageNet\n",
    ":class: note\n",
    "LeNet-5 was designed for the MNIST handwritten digit dataset (10 classes,\n",
    "60,000 training images of 28$\\times$28 pixels). AlexNet was the first CNN to\n",
    "win the ImageNet Large Scale Visual Recognition Challenge (1,000 classes,\n",
    "1.2 million training images of roughly 256$\\times$256 pixels). The jump from\n",
    "MNIST to ImageNet required not just bigger networks, but also GPU computing,\n",
    "data augmentation, and regularization techniques like dropout -- topics that\n",
    "go beyond our classical foundations but build directly on the principles we\n",
    "have studied.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-18-exercises-title",
   "metadata": {},
   "source": "## Exercises"
  },
  {
   "cell_type": "markdown",
   "id": "cell-19-exercises",
   "metadata": {},
   "source": [
    "**Exercise 25.1.** Add a fifth class to the dataset (for example, a\n",
    "\"cross\" pattern that combines vertical and horizontal bars). Train the\n",
    "CNN with 3 filters. Does it still achieve high accuracy? At what point\n",
    "(how many classes) do 3 filters become insufficient?\n",
    "\n",
    "**Exercise 25.2.** Modify `Conv2D` to support a $2\\times 2$ kernel\n",
    "instead of $3\\times 3$. Train on the 3-class data and compare the\n",
    "learned filters. Are $2\\times 2$ kernels expressive enough to\n",
    "distinguish the patterns? What about $5\\times 5$? (Note: a $5\\times 5$\n",
    "kernel on an $8\\times 8$ input produces only a $4\\times 4$ feature map,\n",
    "and after $2\\times 2$ pooling you get $2\\times 2$. This still works but\n",
    "leaves very little spatial information.)\n",
    "\n",
    "**Exercise 25.3.** Implement **strided convolution** by adding a\n",
    "`stride` parameter to `Conv2D`. With `stride=2`, the kernel moves two\n",
    "pixels at a time instead of one, reducing the output size without\n",
    "needing a separate pooling layer. Train a CNN that uses `Conv2D` with\n",
    "`stride=2` and no `MaxPool2D`. Compare the results.\n",
    "\n",
    "**Exercise 25.4.** Implement a simple **dropout** layer that randomly\n",
    "zeroes out activations during training (with probability $p = 0.3$)\n",
    "and scales the remaining activations by $\\frac{1}{1-p}$. During\n",
    "evaluation, dropout should be a no-op. Insert it between the flatten\n",
    "and dense layers of `TinyCNN`. Does it help on this toy task? (Hint:\n",
    "dropout is most useful when the model is overfitting, which our small\n",
    "CNN does not.)\n",
    "\n",
    "**Exercise 25.5.** Run the capacity sweep from Section 25.2 but for\n",
    "the **CNN** instead: vary the number of convolutional filters from 1 to\n",
    "8. How does the CNN's accuracy change? Is the CNN more or less\n",
    "sensitive to this hyperparameter than the MLP is to its hidden width?"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}