{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-0",
   "metadata": {},
   "source": [
    "# Chapter 23: Building a CNN from Scratch\n",
    "\n",
    "In Part V we built a fully-connected neural network from the ground up. Every input\n",
    "neuron was connected to every hidden neuron -- a sensible choice for small, unstructured\n",
    "inputs. But images have *spatial structure*: nearby pixels are related, and the same\n",
    "edge or texture can appear anywhere in the field of view. A **convolutional neural\n",
    "network** (CNN) exploits this structure by replacing the dense matrix multiply with a\n",
    "local, sliding-window operation called *convolution*.\n",
    "\n",
    "In this chapter we assemble all the building blocks of a small CNN:\n",
    "\n",
    "1. **Conv2D** -- the convolutional layer (defined in Chapter 22),\n",
    "2. **ReLU** -- the nonlinearity inserted after each convolution,\n",
    "3. **MaxPool2D** -- spatial downsampling that preserves the strongest activations,\n",
    "4. **Flatten** and **Dense** -- the bridge from spatial feature maps to class scores.\n",
    "\n",
    "We then combine them into a complete **TinyCNN** class and generate a synthetic\n",
    "dataset of oriented line patterns to test it on.\n",
    "\n",
    "```{admonition} Prerequisites\n",
    ":class: note\n",
    "This chapter assumes familiarity with the `Conv2D` layer introduced in Chapter 22.\n",
    "All code is pure NumPy -- no deep learning frameworks are used.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-1",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "plt.rcParams.update({\n",
    "    'figure.facecolor': '#FAF8F0',\n",
    "    'axes.facecolor': '#FAF8F0',\n",
    "    'font.size': 11,\n",
    "})\n",
    "\n",
    "# Project colour palette\n",
    "BLUE = '#3b82f6'\n",
    "BLUE_DARK = '#2563eb'\n",
    "GREEN = '#059669'\n",
    "GREEN_LIGHT = '#10b981'\n",
    "AMBER = '#d97706'\n",
    "RED = '#dc2626'\n",
    "BURGUNDY = '#8c2f39'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-2",
   "metadata": {},
   "source": [
    "We begin by defining the `Conv2D` layer. In this chapter we implement only the\n",
    "**forward pass**; the backward pass (needed for training) will be derived and\n",
    "implemented in Chapter 24.\n",
    "\n",
    "```{admonition} Definition (2D Convolution Layer)\n",
    ":class: note\n",
    "A **Conv2D** layer with $C_{\\text{out}}$ filters of size $k \\times k$ applied to\n",
    "an input with $C_{\\text{in}}$ channels computes\n",
    "\n",
    "$$y_{f,i,j} = b_f + \\sum_{c=1}^{C_{\\text{in}}} \\sum_{p=0}^{k-1} \\sum_{q=0}^{k-1}\n",
    "  K_{f,c,p,q} \\cdot x_{c,\\, i+p,\\, j+q}$$\n",
    "\n",
    "for each filter $f = 1, \\ldots, C_{\\text{out}}$ and each valid spatial position\n",
    "$(i, j)$. The output spatial dimensions are $H_{\\text{out}} = H - k + 1$ and\n",
    "$W_{\\text{out}} = W - k + 1$ (no padding, stride 1).\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Conv2D:\n",
    "    \"\"\"2D convolutional layer (forward pass only).\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    in_channels : int\n",
    "        Number of input channels.\n",
    "    out_channels : int\n",
    "        Number of filters (output channels).\n",
    "    kernel_size : int\n",
    "        Spatial size of each filter (kernel_size x kernel_size).\n",
    "    seed : int\n",
    "        Random seed for reproducibility.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, seed=42):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        fan_in = in_channels * kernel_size * kernel_size\n",
    "        scale = np.sqrt(2.0 / fan_in)          # He initialization\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.kernel_size = kernel_size\n",
    "        self.weights = rng.normal(0.0, scale,\n",
    "                                  size=(out_channels, in_channels,\n",
    "                                        kernel_size, kernel_size))\n",
    "        self.bias = np.zeros(out_channels)\n",
    "        self.last_input = None                  # cached for backward (ch24)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Slide each filter across the input and accumulate dot products.\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "        x : ndarray, shape (batch, in_channels, height, width)\n",
    "        \n",
    "        Returns\n",
    "        -------\n",
    "        output : ndarray, shape (batch, out_channels, out_h, out_w)\n",
    "        \"\"\"\n",
    "        self.last_input = x\n",
    "        batch_size, _, height, width = x.shape\n",
    "        out_h = height - self.kernel_size + 1\n",
    "        out_w = width - self.kernel_size + 1\n",
    "        output = np.zeros((batch_size, self.out_channels, out_h, out_w))\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                patch = x[:, :,\n",
    "                          row:row + self.kernel_size,\n",
    "                          col:col + self.kernel_size]\n",
    "                # tensordot over (in_channels, kH, kW) axes\n",
    "                output[:, :, row, col] = (\n",
    "                    np.tensordot(patch, self.weights,\n",
    "                                 axes=([1, 2, 3], [1, 2, 3]))\n",
    "                    + self.bias\n",
    "                )\n",
    "        return output\n",
    "\n",
    "    @property\n",
    "    def parameter_count(self):\n",
    "        return int(self.weights.size + self.bias.size)\n",
    "\n",
    "\n",
    "print(\"Conv2D class defined (forward only).\")\n",
    "print(\"Attributes: weights, bias, last_input\")\n",
    "print(\"Methods: forward (plus the parameter_count property)\")"
   ]
  },
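  {
   "cell_type": "markdown",
   "id": "cell-3b",
   "metadata": {},
   "source": [
    "As a quick sanity check -- a verification cell added here, not part of the layer\n",
    "itself -- we can confirm that the forward pass matches the definition above: the\n",
    "output shape follows $H_{\\text{out}} = H - k + 1$, and a single output value agrees\n",
    "with an explicit sum over the corresponding input patch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape check: an 8x8 input with k=3 should give a 6x6 output\n",
    "check_conv = Conv2D(in_channels=1, out_channels=3, kernel_size=3, seed=1)\n",
    "check_out = check_conv.forward(np.zeros((2, 1, 8, 8)))\n",
    "assert check_out.shape == (2, 3, 6, 6)\n",
    "print(f\"Input (2, 1, 8, 8) -> output {check_out.shape}\")\n",
    "\n",
    "# Value check: output[0, 0, i, j] equals bias + sum(filter * patch)\n",
    "rng_chk = np.random.default_rng(2)\n",
    "x_chk = rng_chk.normal(size=(1, 2, 5, 5))\n",
    "conv_chk = Conv2D(in_channels=2, out_channels=1, kernel_size=3, seed=3)\n",
    "y_chk = conv_chk.forward(x_chk)\n",
    "manual = conv_chk.bias[0] + np.sum(conv_chk.weights[0] * x_chk[0, :, 1:4, 2:5])\n",
    "assert np.isclose(y_chk[0, 0, 1, 2], manual)\n",
    "print(\"Entry (1, 2) matches the explicit triple sum.\")"
   ]
  },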
  {
   "cell_type": "markdown",
   "id": "cell-4",
   "metadata": {},
   "source": [
    "## 23.1 ReLU After Convolution\n",
    "\n",
    "A convolution is a *linear* operation: it computes weighted sums over local patches.\n",
    "Stacking two linear operations without a nonlinearity in between collapses to a single\n",
    "linear operation -- exactly the lesson of Chapter 8 (the XOR problem). To build depth\n",
    "that matters, we insert a **rectified linear unit (ReLU)** after every convolution:\n",
    "\n",
    "$$\\text{ReLU}(z) = \\max(0, z).$$\n",
    "\n",
    "```{admonition} Why ReLU?\n",
    ":class: note\n",
    "ReLU is the default activation in modern CNNs for three reasons:\n",
    "1. **Sparse activation** -- with zero-mean inputs, roughly half the units output exactly zero, yielding efficient representations.\n",
    "2. **Non-saturating gradient** -- the derivative is 1 for every positive input, so gradients do not shrink through deep stacks the way they do with sigmoid or tanh (though units stuck in the negative region can \"die\").\n",
    "3. **Computational simplicity** -- a single comparison per element, much cheaper than the exponentials in sigmoid or tanh.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReLU:\n",
    "    \"\"\"Element-wise Rectified Linear Unit.\"\"\"\n",
    "    def __init__(self):\n",
    "        self.last_input = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        return np.maximum(0.0, x)\n",
    "\n",
    "    def backward(self, d_output):\n",
    "        return d_output * (self.last_input > 0.0)\n",
    "\n",
    "\n",
    "print(\"ReLU class defined.\")\n",
    "print(\"forward:  max(0, x)\")\n",
    "print(\"backward: pass gradient where x > 0, zero elsewhere\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-5a",
   "metadata": {},
   "source": [
    "The figure below shows what happens to a feature map before and after ReLU. All\n",
    "negative activations are set to zero, producing a sparser representation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-6",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Visualize pre- and post-ReLU feature maps\n",
    "rng_demo = np.random.default_rng(0)\n",
    "demo_input = rng_demo.normal(0, 1, size=(1, 1, 8, 8))  # single 8x8 image\n",
    "conv_demo = Conv2D(in_channels=1, out_channels=3, kernel_size=3, seed=0)\n",
    "relu_demo = ReLU()\n",
    "\n",
    "pre_relu = conv_demo.forward(demo_input)    # (1, 3, 6, 6)\n",
    "post_relu = relu_demo.forward(pre_relu)\n",
    "\n",
    "fig, axes = plt.subplots(2, 3, figsize=(10, 6.5))\n",
    "\n",
    "for f in range(3):\n",
    "    vmin = pre_relu[0, f].min()\n",
    "    vmax = pre_relu[0, f].max()\n",
    "    vm = max(abs(vmin), abs(vmax))\n",
    "    axes[0, f].imshow(pre_relu[0, f], cmap='RdBu_r', vmin=-vm, vmax=vm)\n",
    "    axes[0, f].set_title(f'Filter {f+1} (pre-ReLU)', fontsize=11)\n",
    "    axes[0, f].axis('off')\n",
    "    axes[1, f].imshow(post_relu[0, f], cmap='RdBu_r', vmin=-vm, vmax=vm)\n",
    "    axes[1, f].set_title(f'Filter {f+1} (post-ReLU)', fontsize=11)\n",
    "    axes[1, f].axis('off')\n",
    "\n",
    "fig.suptitle('Feature Maps Before and After ReLU', fontsize=13, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
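  {
   "cell_type": "markdown",
   "id": "cell-6b",
   "metadata": {},
   "source": [
    "We can also quantify the sparsity claim from the admonition above. With zero-mean\n",
    "Gaussian inputs, about half of the pre-activations are negative, so about half of\n",
    "the ReLU outputs should be exactly zero. (A small verification cell added here.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Measure the fraction of exactly-zero activations after ReLU\n",
    "rng_sp = np.random.default_rng(1)\n",
    "z = rng_sp.normal(0, 1, size=(1, 1, 100, 100))\n",
    "sparsity = np.mean(ReLU().forward(z) == 0.0)\n",
    "print(f\"Fraction of zero activations after ReLU: {sparsity:.1%}\")"
   ]
  },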
  {
   "cell_type": "markdown",
   "id": "cell-7",
   "metadata": {},
   "source": [
    "## 23.2 Max Pooling\n",
    "\n",
    "After the convolution + ReLU pair has extracted local features, we want to\n",
    "**downsample** the feature maps. This serves two purposes:\n",
    "\n",
    "1. **Reduce computation** -- fewer spatial positions means fewer operations in later layers.\n",
    "2. **Translation tolerance** -- small shifts of a feature within a pooling window do not change the pooled output.\n",
    "\n",
    "```{admonition} Definition (Max Pooling)\n",
    ":class: note\n",
    "**Max pooling** with pool size $p$ partitions each channel of the feature map into\n",
    "non-overlapping $p \\times p$ windows and replaces each window with its maximum value:\n",
    "\n",
    "$$y_{c,i,j} = \\max_{0 \\le r,s < p}\\; x_{c,\\, ip+r,\\, jp+s}$$\n",
    "\n",
    "The output spatial dimensions are $H_{\\text{out}} = \\lfloor H/p \\rfloor$ and\n",
    "$W_{\\text{out}} = \\lfloor W/p \\rfloor$.\n",
    "```\n",
    "\n",
    "```{tip}\n",
    "Max pooling retains the *strongest* activation in each window, discarding precise\n",
    "spatial location. This is exactly the trade-off we want: \"Is there an edge *somewhere*\n",
    "in this region?\" rather than \"Is there an edge at pixel $(3, 7)$?\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaxPool2D:\n",
    "    \"\"\"Max pooling with non-overlapping windows.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    pool_size : int\n",
    "        Side length of each pooling window.\n",
    "    \"\"\"\n",
    "    def __init__(self, pool_size=2):\n",
    "        self.pool_size = pool_size\n",
    "        self.last_input = None\n",
    "        self.last_mask = None      # records winner positions for backward\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        batch_size, channels, height, width = x.shape\n",
    "        out_h = height // self.pool_size\n",
    "        out_w = width // self.pool_size\n",
    "        output = np.zeros((batch_size, channels, out_h, out_w))\n",
    "        mask = np.zeros_like(x)\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                rs = row * self.pool_size\n",
    "                cs = col * self.pool_size\n",
    "                window = x[:, :, rs:rs + self.pool_size,\n",
    "                                  cs:cs + self.pool_size]\n",
    "                output[:, :, row, col] = window.max(axis=(2, 3))\n",
    "                # flatten each window; argmax returns the first max on ties\n",
    "                flat_idx = window.reshape(\n",
    "                    batch_size, channels, -1).argmax(axis=2)\n",
    "                for b in range(batch_size):\n",
    "                    for c in range(channels):\n",
    "                        winner = flat_idx[b, c]\n",
    "                        wr = winner // self.pool_size\n",
    "                        wc = winner % self.pool_size\n",
    "                        mask[b, c, rs + wr, cs + wc] = 1.0\n",
    "        self.last_mask = mask\n",
    "        return output\n",
    "\n",
    "    def backward(self, d_output):\n",
    "        d_input = np.zeros_like(self.last_input)\n",
    "        out_h, out_w = d_output.shape[2], d_output.shape[3]\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                rs = row * self.pool_size\n",
    "                cs = col * self.pool_size\n",
    "                mask_w = self.last_mask[\n",
    "                    :, :, rs:rs + self.pool_size,\n",
    "                          cs:cs + self.pool_size]\n",
    "                d_input[:, :, rs:rs + self.pool_size,\n",
    "                              cs:cs + self.pool_size] += (\n",
    "                    mask_w * d_output[:, :, row, col][:, :, None, None]\n",
    "                )\n",
    "        return d_input\n",
    "\n",
    "\n",
    "print(\"MaxPool2D class defined.\")\n",
    "print(\"forward:  take the max in each pool_size x pool_size window\")\n",
    "print(\"backward: route gradient only to the max position\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-9",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Visualize max pooling on a 6x6 feature map -> 3x3\n",
    "rng_pool = np.random.default_rng(42)\n",
    "fmap = rng_pool.integers(0, 10, size=(1, 1, 6, 6)).astype(float)\n",
    "pool_demo = MaxPool2D(pool_size=2)\n",
    "pooled = pool_demo.forward(fmap)\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 4))\n",
    "\n",
    "# Input 6x6\n",
    "ax = axes[0]\n",
    "im = ax.imshow(fmap[0, 0], cmap='Blues', vmin=0, vmax=10)\n",
    "for i in range(6):\n",
    "    for j in range(6):\n",
    "        val = int(fmap[0, 0, i, j])\n",
    "        is_max = pool_demo.last_mask[0, 0, i, j] > 0\n",
    "        color = RED if is_max else 'black'\n",
    "        weight = 'bold' if is_max else 'normal'\n",
    "        ax.text(j, i, str(val), ha='center', va='center',\n",
    "                fontsize=13, color=color, fontweight=weight)\n",
    "# Draw pool boundaries\n",
    "for k in range(0, 7, 2):\n",
    "    ax.axhline(k - 0.5, color='gray', linewidth=2)\n",
    "    ax.axvline(k - 0.5, color='gray', linewidth=2)\n",
    "ax.set_title('Input 6x6 (max positions in red)', fontsize=12)\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "\n",
    "# Output 3x3\n",
    "ax = axes[1]\n",
    "ax.imshow(pooled[0, 0], cmap='Blues', vmin=0, vmax=10)\n",
    "for i in range(3):\n",
    "    for j in range(3):\n",
    "        ax.text(j, i, str(int(pooled[0, 0, i, j])),\n",
    "                ha='center', va='center', fontsize=15,\n",
    "                fontweight='bold', color=BLUE_DARK)\n",
    "ax.set_title('Output 3x3 (after MaxPool 2x2)', fontsize=12)\n",
    "ax.set_xticks([])\n",
    "ax.set_yticks([])\n",
    "\n",
    "fig.suptitle('Max Pooling: 6x6 -> 3x3 with pool_size=2',\n",
    "             fontsize=13, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"Each 2x2 block is replaced by its maximum value.\")\n",
    "print(\"The spatial resolution is halved in each dimension.\")"
   ]
  },
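  {
   "cell_type": "markdown",
   "id": "cell-9b",
   "metadata": {},
   "source": [
    "The claim that pooling tolerates small shifts can be checked directly: a feature\n",
    "that moves *within* its $2 \\times 2$ pooling window leaves the pooled output\n",
    "unchanged, while a shift that crosses a window boundary does alter it. (A small\n",
    "illustrative check added here; pooling gives tolerance, not exact invariance.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One bright pixel in pooling window (0, 0) of a 4x4 map\n",
    "fm = np.zeros((1, 1, 4, 4))\n",
    "fm[0, 0, 0, 0] = 5.0\n",
    "fm_within = np.roll(fm, 1, axis=3)   # shift right by 1: still window (0, 0)\n",
    "fm_across = np.roll(fm, 2, axis=3)   # shift right by 2: now in window (0, 1)\n",
    "\n",
    "pool_chk = MaxPool2D(pool_size=2)\n",
    "base_out = pool_chk.forward(fm)\n",
    "print(\"Shift within window   -> output unchanged:\",\n",
    "      np.array_equal(base_out, pool_chk.forward(fm_within)))\n",
    "print(\"Shift across boundary -> output unchanged:\",\n",
    "      np.array_equal(base_out, pool_chk.forward(fm_across)))"
   ]
  },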
  {
   "cell_type": "markdown",
   "id": "cell-10",
   "metadata": {},
   "source": [
    "## 23.3 Flatten and Dense\n",
    "\n",
    "After convolution, ReLU, and pooling, we have a 3-D tensor of shape\n",
    "`(batch, channels, height, width)`. To produce class scores we need a\n",
    "fully-connected (dense) layer, which expects a 1-D vector per sample.\n",
    "The **Flatten** layer reshapes the spatial feature maps into a flat vector:\n",
    "\n",
    "$$(B, C, H, W) \\longrightarrow (B, C \\cdot H \\cdot W).$$\n",
    "\n",
    "The **Dense** layer then computes the familiar linear transformation:\n",
    "\n",
    "$$\\mathbf{y} = \\mathbf{x} \\mathbf{W} + \\mathbf{b}$$\n",
    "\n",
    "where $\\mathbf{W} \\in \\mathbb{R}^{D_{\\text{in}} \\times D_{\\text{out}}}$ and\n",
    "$\\mathbf{b} \\in \\mathbb{R}^{D_{\\text{out}}}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-11",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Flatten:\n",
    "    \"\"\"Reshape a 4-D tensor to 2-D (batch, features).\"\"\"\n",
    "    def __init__(self):\n",
    "        self.last_shape = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.last_shape = x.shape\n",
    "        return x.reshape(x.shape[0], -1)\n",
    "\n",
    "    def backward(self, d_output):\n",
    "        return d_output.reshape(self.last_shape)\n",
    "\n",
    "\n",
    "class Dense:\n",
    "    \"\"\"Fully-connected layer (forward pass only in this chapter).\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    in_features : int\n",
    "        Dimensionality of each input vector.\n",
    "    out_features : int\n",
    "        Number of output units.\n",
    "    seed : int\n",
    "        Random seed for reproducibility.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_features, out_features, seed=42):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        scale = np.sqrt(2.0 / in_features)  # He initialization\n",
    "        self.weights = rng.normal(0.0, scale,\n",
    "                                  size=(in_features, out_features))\n",
    "        self.bias = np.zeros(out_features)\n",
    "        self.last_input = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        return x @ self.weights + self.bias\n",
    "\n",
    "    @property\n",
    "    def parameter_count(self):\n",
    "        return int(self.weights.size + self.bias.size)\n",
    "\n",
    "\n",
    "print(\"Flatten and Dense classes defined (forward only).\")\n",
    "print(\"Flatten: (B, C, H, W) -> (B, C*H*W)\")\n",
    "print(\"Dense:   y = x @ W + b\")"
   ]
  },
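  {
   "cell_type": "markdown",
   "id": "cell-11b",
   "metadata": {},
   "source": [
    "A quick shape trace through the bridge layers (a verification cell added here):\n",
    "flattening a pooled tensor of shape $(2, 3, 3, 3)$ yields 27 features per sample,\n",
    "the dense layer maps those to 3 class scores, and `Flatten.backward` must restore\n",
    "the original spatial shape."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-11c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trace shapes through Flatten -> Dense, then round-trip Flatten.backward\n",
    "pooled_chk = np.zeros((2, 3, 3, 3))    # (batch, channels, height, width)\n",
    "flatten_chk = Flatten()\n",
    "dense_chk = Dense(in_features=27, out_features=3, seed=0)\n",
    "\n",
    "flat_out = flatten_chk.forward(pooled_chk)\n",
    "scores = dense_chk.forward(flat_out)\n",
    "print(f\"{pooled_chk.shape} -> {flat_out.shape} -> {scores.shape}\")\n",
    "\n",
    "assert flatten_chk.backward(flat_out).shape == pooled_chk.shape\n",
    "print(\"Flatten.backward restores the original 4-D shape.\")"
   ]
  },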
  {
   "cell_type": "markdown",
   "id": "cell-12",
   "metadata": {},
   "source": [
    "We also need the **softmax** function and the **cross-entropy loss** for multi-class\n",
    "classification. These are identical to what we would use in a fully-connected network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-13",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax(logits):\n",
    "    \"\"\"Numerically stable softmax.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    logits : ndarray, shape (batch, num_classes)\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    probabilities : ndarray, same shape as logits\n",
    "    \"\"\"\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    exp_values = np.exp(shifted)\n",
    "    return exp_values / exp_values.sum(axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "def softmax_cross_entropy(logits, targets):\n",
    "    \"\"\"Combined softmax + cross-entropy loss.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    logits : ndarray, shape (batch, num_classes)\n",
    "    targets : ndarray of int, shape (batch,)\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    loss : float\n",
    "    probabilities : ndarray, shape (batch, num_classes)\n",
    "    d_logits : ndarray, shape (batch, num_classes)\n",
    "    \"\"\"\n",
    "    probabilities = softmax(logits)\n",
    "    batch_idx = np.arange(targets.shape[0])\n",
    "    clipped = np.clip(probabilities[batch_idx, targets], 1e-12, 1.0)\n",
    "    loss = -np.log(clipped).mean()\n",
    "    d_logits = probabilities.copy()\n",
    "    d_logits[batch_idx, targets] -= 1.0\n",
    "    d_logits /= targets.shape[0]\n",
    "    return loss, probabilities, d_logits\n",
    "\n",
    "\n",
    "print(\"softmax and softmax_cross_entropy defined.\")"
   ]
  },
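  {
   "cell_type": "markdown",
   "id": "cell-13b",
   "metadata": {},
   "source": [
    "Two quick sanity checks (added here for verification): every row of the softmax\n",
    "output must sum to one, and with all-equal logits the loss equals\n",
    "$\\ln 3 \\approx 1.0986$ for three classes -- the baseline an untrained classifier\n",
    "should hover around."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-13c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of softmax sum to 1; all-equal logits give loss = ln(3)\n",
    "logits_chk = np.zeros((4, 3))          # uniform probabilities\n",
    "targets_chk = np.array([0, 1, 2, 0])\n",
    "loss_chk, probs_chk, _ = softmax_cross_entropy(logits_chk, targets_chk)\n",
    "\n",
    "print(f\"Row sums:     {probs_chk.sum(axis=1)}\")\n",
    "print(f\"Uniform loss: {loss_chk:.4f}   (ln 3 = {np.log(3):.4f})\")\n",
    "assert np.allclose(probs_chk.sum(axis=1), 1.0)\n",
    "assert np.isclose(loss_chk, np.log(3))"
   ]
  },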
  {
   "cell_type": "markdown",
   "id": "cell-14",
   "metadata": {},
   "source": [
    "## 23.4 The TinyCNN Class\n",
    "\n",
    "We now assemble the pieces into a complete network. Our architecture is deliberately\n",
    "small so that it trains in seconds on the CPU:\n",
    "\n",
    "$$\\text{Input}(1, 8, 8) \\;\\xrightarrow{\\text{Conv2D}(1{\\to}3,\\; 3{\\times}3)}\\;\n",
    "  (3, 6, 6) \\;\\xrightarrow{\\text{ReLU}}\\;\n",
    "  (3, 6, 6) \\;\\xrightarrow{\\text{MaxPool}(2)}\\;\n",
    "  (3, 3, 3) \\;\\xrightarrow{\\text{Flatten}}\\;\n",
    "  (27) \\;\\xrightarrow{\\text{Dense}(27{\\to}3)}\\;\n",
    "  (3)$$\n",
    "\n",
    "```{admonition} Architecture Summary\n",
    ":class: important\n",
    "**TinyCNN** has five layers:\n",
    "1. `Conv2D(1, 3, 3)` -- 3 filters of size $3 \\times 3$ on 1 input channel\n",
    "2. `ReLU` -- element-wise nonlinearity\n",
    "3. `MaxPool2D(2)` -- $2 \\times 2$ max pooling\n",
    "4. `Flatten` -- reshape $(3, 3, 3) \\to (27)$\n",
    "5. `Dense(27, 3)` -- fully-connected layer producing 3 class scores\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-15",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TinyCNN:\n",
    "    \"\"\"A minimal convolutional neural network for 8x8 grayscale images.\n",
    "    \n",
    "    Architecture:\n",
    "        Conv2D(1->3, 3x3) -> ReLU -> MaxPool(2x2) -> Flatten -> Dense(27->3)\n",
    "    \"\"\"\n",
    "    def __init__(self, seed=42):\n",
    "        self.conv = Conv2D(in_channels=1, out_channels=3,\n",
    "                           kernel_size=3, seed=seed)\n",
    "        self.relu = ReLU()\n",
    "        self.pool = MaxPool2D(pool_size=2)\n",
    "        self.flatten = Flatten()\n",
    "        self.dense = Dense(in_features=27, out_features=3,\n",
    "                           seed=seed + 1)\n",
    "        self.layers = [self.conv, self.relu, self.pool,\n",
    "                       self.flatten, self.dense]\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Forward pass through all layers.\"\"\"\n",
    "        for layer in self.layers:\n",
    "            x = layer.forward(x)\n",
    "        return x  # logits, shape (batch, 3)\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Return predicted class indices.\"\"\"\n",
    "        logits = self.forward(x)\n",
    "        return np.argmax(logits, axis=1)\n",
    "\n",
    "    def evaluate(self, x, y):\n",
    "        \"\"\"Compute accuracy on a dataset.\"\"\"\n",
    "        preds = self.predict(x)\n",
    "        return np.mean(preds == y)\n",
    "\n",
    "\n",
    "print(\"TinyCNN class defined.\")\n",
    "print(\"Architecture: Conv2D(1->3, 3x3) -> ReLU -> MaxPool(2) -> Flatten -> Dense(27->3)\")\n",
    "print(\"Methods: forward, predict, evaluate\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-16",
   "metadata": {},
   "source": [
    "## 23.5 Layer Summary\n",
    "\n",
    "Let us verify the shapes and count the parameters at each stage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-17",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Layer summary table\n",
    "model = TinyCNN(seed=42)\n",
    "\n",
    "# Run a dummy forward pass to capture shapes\n",
    "dummy = np.zeros((1, 1, 8, 8))\n",
    "shapes = [('Input', dummy.shape)]\n",
    "x = dummy\n",
    "for layer in model.layers:\n",
    "    x = layer.forward(x)\n",
    "    shapes.append((type(layer).__name__, x.shape))\n",
    "\n",
    "print(f'{\"Layer\":<12} {\"Output Shape\":<22} {\"Parameters\":>10}')\n",
    "print('=' * 46)\n",
    "\n",
    "total_params = 0\n",
    "for name, shape in shapes:\n",
    "    if name == 'Conv2D':\n",
    "        p = model.conv.parameter_count\n",
    "    elif name == 'Dense':\n",
    "        p = model.dense.parameter_count\n",
    "    else:\n",
    "        p = 0\n",
    "    total_params += p\n",
    "    shape_str = str(shape)\n",
    "    p_str = str(p) if p > 0 else '-'\n",
    "    print(f'{name:<12} {shape_str:<22} {p_str:>10}')\n",
    "\n",
    "print('=' * 46)\n",
    "print(f'{\"TOTAL\":<12} {\"\":<22} {total_params:>10}')\n",
    "print(f'\\nConv2D:  3 filters x (1 x 3 x 3) + 3 biases = {model.conv.parameter_count}')\n",
    "print(f'Dense:   27 x 3 weights + 3 biases = {model.dense.parameter_count}')\n",
    "print(f'Total trainable parameters: {total_params}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-18",
   "metadata": {},
   "source": [
    "## 23.6 The Synthetic Dataset\n",
    "\n",
    "To test our TinyCNN we create a simple dataset of $8 \\times 8$ grayscale images\n",
    "containing three classes of oriented line patterns:\n",
    "\n",
    "- **vertical** -- a vertical bar through the center,\n",
    "- **horizontal** -- a horizontal bar through the center,\n",
    "- **diagonal** -- a diagonal line from top-left to bottom-right.\n",
    "\n",
    "Each training example is generated by taking the base pattern, applying random\n",
    "vertical/horizontal shifts, and adding Gaussian noise. This mimics the kind of\n",
    "translation variability that CNNs, with their shared sliding filters, handle far\n",
    "more gracefully than fully-connected networks.\n",
    "\n",
    "```{admonition} Why Synthetic Data?\n",
    ":class: tip\n",
    "Using synthetic data lets us control exactly what the network must learn.\n",
    "We *know* the ground truth -- three oriented patterns -- so we can later\n",
    "inspect whether the learned filters match these orientations.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-19",
   "metadata": {},
   "outputs": [],
   "source": [
    "CLASS_NAMES = (\"vertical\", \"horizontal\", \"diagonal\")\n",
    "\n",
    "\n",
    "def _pattern_for_class(class_name, size):\n",
    "    \"\"\"Generate the base pattern for a given class.\"\"\"\n",
    "    pattern = np.zeros((size, size))\n",
    "    center = size // 2\n",
    "    if class_name == \"vertical\":\n",
    "        pattern[:, center - 1:center + 1] = 1.0\n",
    "    elif class_name == \"horizontal\":\n",
    "        pattern[center - 1:center + 1, :] = 1.0\n",
    "    elif class_name == \"diagonal\":\n",
    "        np.fill_diagonal(pattern, 1.0)\n",
    "        pattern += 0.35 * np.eye(size, k=1)\n",
    "        pattern += 0.35 * np.eye(size, k=-1)\n",
    "    return np.clip(pattern, 0.0, 1.0)\n",
    "\n",
    "\n",
    "def make_dataset_bundle(train_per_class=60, val_per_class=30,\n",
    "                         size=8, seed=7, class_names=CLASS_NAMES):\n",
    "    \"\"\"Generate shifted, noisy variants of base patterns.\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    x_train : ndarray, shape (N_train, 1, size, size)\n",
    "    y_train : ndarray of int, shape (N_train,)\n",
    "    x_val : ndarray, shape (N_val, 1, size, size)\n",
    "    y_val : ndarray of int, shape (N_val,)\n",
    "    class_names : tuple of str\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    all_x, all_y = [], []\n",
    "    total_per_class = train_per_class + val_per_class\n",
    "\n",
    "    for cls_idx, cls_name in enumerate(class_names):\n",
    "        base = _pattern_for_class(cls_name, size)\n",
    "        for _ in range(total_per_class):\n",
    "            # Random circular shift in [-2, 2] pixels (np.roll wraps around)\n",
    "            shift_r = rng.integers(-2, 3)\n",
    "            shift_c = rng.integers(-2, 3)\n",
    "            shifted = np.roll(np.roll(base, shift_r, axis=0),\n",
    "                              shift_c, axis=1)\n",
    "            # Add noise\n",
    "            noisy = shifted + rng.normal(0, 0.15, size=(size, size))\n",
    "            noisy = np.clip(noisy, 0.0, 1.0)\n",
    "            all_x.append(noisy)\n",
    "            all_y.append(cls_idx)\n",
    "\n",
    "    all_x = np.array(all_x)[:, None, :, :]   # (N, 1, 8, 8)\n",
    "    all_y = np.array(all_y, dtype=int)\n",
    "\n",
    "    # Shuffle and split\n",
    "    perm = rng.permutation(len(all_y))\n",
    "    all_x, all_y = all_x[perm], all_y[perm]\n",
    "\n",
    "    n_train = train_per_class * len(class_names)\n",
    "    x_train, y_train = all_x[:n_train], all_y[:n_train]\n",
    "    x_val, y_val = all_x[n_train:], all_y[n_train:]\n",
    "\n",
    "    return x_train, y_train, x_val, y_val, class_names\n",
    "\n",
    "\n",
    "# Generate the dataset\n",
    "x_train, y_train, x_val, y_val, class_names = make_dataset_bundle()\n",
    "\n",
    "print(f\"Training set:   {x_train.shape[0]} images, shape {x_train.shape[1:]}\")\n",
    "print(f\"Validation set: {x_val.shape[0]} images, shape {x_val.shape[1:]}\")\n",
    "print(f\"Classes: {class_names}\")\n",
    "for i, name in enumerate(class_names):\n",
    "    n_tr = (y_train == i).sum()\n",
    "    n_va = (y_val == i).sum()\n",
    "    print(f\"  {name}: {n_tr} train, {n_va} val\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-20",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Gallery of example patterns (2 rows x 3 cols)\n",
    "fig, axes = plt.subplots(2, 3, figsize=(10, 6.5))\n",
    "\n",
    "for cls_idx in range(3):\n",
    "    # Show base pattern\n",
    "    base = _pattern_for_class(class_names[cls_idx], 8)\n",
    "    axes[0, cls_idx].imshow(base, cmap='gray_r', vmin=0, vmax=1)\n",
    "    axes[0, cls_idx].set_title(f'{class_names[cls_idx]} (base)',\n",
    "                                fontsize=12, fontweight='bold')\n",
    "    axes[0, cls_idx].axis('off')\n",
    "\n",
    "    # Show a noisy/shifted example\n",
    "    idx = np.where(y_train == cls_idx)[0][0]\n",
    "    axes[1, cls_idx].imshow(x_train[idx, 0], cmap='gray_r',\n",
    "                             vmin=0, vmax=1)\n",
    "    axes[1, cls_idx].set_title(f'{class_names[cls_idx]} (noisy)',\n",
    "                                fontsize=12)\n",
    "    axes[1, cls_idx].axis('off')\n",
    "\n",
    "fig.suptitle('Synthetic Line Patterns: Base vs Noisy Examples',\n",
    "             fontsize=13, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-21",
   "metadata": {},
   "source": [
    "Let us verify that the untrained TinyCNN produces essentially random predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-22",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "model_untrained = TinyCNN(seed=42)\n",
    "acc_train = model_untrained.evaluate(x_train, y_train)\n",
    "acc_val = model_untrained.evaluate(x_val, y_val)\n",
    "print(\"Untrained TinyCNN accuracy:\")\n",
    "print(f\"  Train: {acc_train:.1%}  (chance = {1/3:.1%})\")\n",
    "print(f\"  Val:   {acc_val:.1%}\")\n",
    "print(\"\\nThe network needs training! That is the subject of Chapter 24.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-23",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 23.1.** Compute the output shape of a Conv2D layer with `in_channels=3`,\n",
    "`out_channels=8`, `kernel_size=5` applied to an input of shape `(16, 3, 32, 32)`.\n",
    "How many parameters does this layer have?\n",
    "\n",
    "**Exercise 23.2.** What happens if we remove the ReLU between convolution and pooling?\n",
    "Explain why the network's representational power would be affected, connecting your\n",
    "answer to the XOR impossibility result of Chapter 8.\n",
    "\n",
    "**Exercise 23.3.** Average pooling replaces each window with its *mean* instead of its\n",
    "maximum. Implement an `AvgPool2D` class with `forward` and `backward` methods.\n",
    "What is the backward pass of average pooling?\n",
    "\n",
    "**Exercise 23.4.** Add a fourth class `\"dot\"` (a 2x2 square in the center) to the\n",
    "dataset. What changes are needed in the `TinyCNN` architecture? Modify the code\n",
    "and verify that the shapes are consistent.\n",
    "\n",
    "**Exercise 23.5.** Our TinyCNN has about 114 parameters. A fully-connected network\n",
    "mapping 64 inputs to 3 outputs through a hidden layer of 27 neurons would have\n",
    "$64 \\times 27 + 27 + 27 \\times 3 + 3 = 1839$ parameters. Explain the source of this\n",
    "large difference and discuss the concept of **parameter sharing** in CNNs."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
