{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1000001",
   "metadata": {},
   "source": [
    "# Chapter 22: The Convolution Operation\n",
    "\n",
    "In Chapter 21, we motivated the need for convolutional neural networks by identifying three structural assumptions about image data: locality, weight sharing, and translation equivariance. In this chapter, we make these ideas precise by defining the **convolution operation** mathematically, implementing it in NumPy, and demonstrating its power as a feature extractor on synthetic images."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000002",
   "metadata": {},
   "source": [
    "## 1. Cross-Correlation vs. Convolution\n",
    "\n",
    "In signal processing, **convolution** involves flipping the kernel before sliding it over the input. In deep learning, however, what we call \"convolution\" is actually **cross-correlation**\u2014the kernel is applied without flipping. Since the network learns the kernel weights, flipping is irrelevant: the network simply learns the already-flipped version.\n",
    "\n",
    "```{admonition} Warning: Naming Convention\n",
    ":class: warning\n",
    "\n",
    "Virtually all deep learning frameworks (TensorFlow, PyTorch, JAX) implement **cross-correlation** but call it \"convolution.\" We follow this convention throughout. When we say \"convolution,\" we mean cross-correlation unless explicitly stated otherwise.\n",
    "```\n",
    "\n",
    "### Mathematical Definition\n",
    "\n",
    "```{admonition} Definition (2D Cross-Correlation / \"Convolution\")\n",
    ":class: note\n",
    "\n",
    "Given a 2D input $\\mathbf{X} \\in \\mathbb{R}^{H \\times W}$ and a kernel $\\mathbf{K} \\in \\mathbb{R}^{K_h \\times K_w}$, the 2D cross-correlation (with bias $b$) produces an output $\\mathbf{Y}$ where:\n",
    "\n",
    "$$y_{i,j} = \\sum_{u=0}^{K_h-1}\\sum_{v=0}^{K_w-1} x_{i+u,\\, j+v} \\cdot k_{u,v} + b$$\n",
    "\n",
    "for $i = 0, 1, \\ldots, H - K_h$ and $j = 0, 1, \\ldots, W - K_w$.\n",
    "```\n",
    "\n",
    "In words: place the kernel's top-left corner at position $(i, j)$ in the input, compute the element-wise product of the kernel with the covered patch, sum all products, and add the bias. Repeat for every valid position.\n",
    "\n",
    "For a **true convolution**, the kernel is first flipped both horizontally and vertically:\n",
    "\n",
    "$$y_{i,j}^{\\text{(true conv)}} = \\sum_{u=0}^{K_h-1}\\sum_{v=0}^{K_w-1} x_{i+u,\\, j+v} \\cdot k_{K_h - 1 - u,\\, K_w - 1 - v} + b$$\n",
    "\n",
    "Since learned kernels have no predefined orientation, the distinction is purely academic in the context of neural networks."
   ]
  },
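  {
   "cell_type": "markdown",
   "id": "b1000001",
   "metadata": {},
   "source": [
    "To make the flip relationship concrete, the short sketch below computes a true convolution as a cross-correlation with a doubly-flipped kernel. The helper function and the random input and kernel here are throwaways for this check, not the chapter's `Conv2D` class; it simply shows that `np.flip` on both axes and the slicing idiom `[::-1, ::-1]` perform the same flip:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1000002",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cross_correlate_demo(image, kernel):\n",
    "    # Valid cross-correlation, stride 1, no bias (mirrors the definition above)\n",
    "    kh, kw = kernel.shape\n",
    "    oh = image.shape[0] - kh + 1\n",
    "    ow = image.shape[1] - kw + 1\n",
    "    out = np.zeros((oh, ow))\n",
    "    for i in range(oh):\n",
    "        for j in range(ow):\n",
    "            out[i, j] = np.sum(image[i:i+kh, j:j+kw] * kernel)\n",
    "    return out\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "X_demo = rng.standard_normal((5, 5))\n",
    "K_demo = rng.standard_normal((3, 3))\n",
    "\n",
    "# True convolution = cross-correlation with the kernel flipped on both axes\n",
    "true_conv = cross_correlate_demo(X_demo, np.flip(K_demo))\n",
    "also_true_conv = cross_correlate_demo(X_demo, K_demo[::-1, ::-1])\n",
    "\n",
    "print(np.allclose(true_conv, also_true_conv))  # True: the two flips agree"
   ]
  },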
  {
   "cell_type": "markdown",
   "id": "a1000003",
   "metadata": {},
   "source": [
    "## 2. A 2D Convolution by Hand\n",
    "\n",
    "Let us work through a complete example. Consider a $5 \\times 5$ input and a $3 \\times 3$ kernel:\n",
    "\n",
    "$$\\mathbf{X} = \\begin{pmatrix} 1 & 0 & 2 & 1 & 0 \\\\ 0 & 1 & 1 & 0 & 2 \\\\ 2 & 0 & 0 & 1 & 1 \\\\ 1 & 1 & 2 & 0 & 0 \\\\ 0 & 2 & 1 & 1 & 1 \\end{pmatrix}, \\quad \\mathbf{K} = \\begin{pmatrix} 1 & 0 & -1 \\\\ 1 & 0 & -1 \\\\ 1 & 0 & -1 \\end{pmatrix}, \\quad b = 0$$\n",
    "\n",
    "The output has size $(5 - 3 + 1) \\times (5 - 3 + 1) = 3 \\times 3$.\n",
    "\n",
    "**Position $(0, 0)$:**\n",
    "\n",
    "$$y_{0,0} = 1 \\cdot 1 + 0 \\cdot 0 + 2 \\cdot (-1) + 0 \\cdot 1 + 1 \\cdot 0 + 1 \\cdot (-1) + 2 \\cdot 1 + 0 \\cdot 0 + 0 \\cdot (-1) = 1 - 2 - 1 + 2 = 0$$\n",
    "\n",
    "**Position $(0, 1)$:**\n",
    "\n",
    "$$y_{0,1} = 0 \\cdot 1 + 2 \\cdot 0 + 1 \\cdot (-1) + 1 \\cdot 1 + 1 \\cdot 0 + 0 \\cdot (-1) + 0 \\cdot 1 + 0 \\cdot 0 + 1 \\cdot (-1) = -1 + 1 - 1 = -1$$\n",
    "\n",
    "Continuing for all positions, we can verify the full result with the visualization below."
   ]
  },
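  {
   "cell_type": "markdown",
   "id": "b1000003",
   "metadata": {},
   "source": [
    "Before visualizing the full result, the two hand-computed entries can be double-checked with NumPy slicing: each output value is just the element-wise product of a $3 \\times 3$ patch with the kernel, summed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1000004",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "X_hand = np.array([\n",
    "    [1, 0, 2, 1, 0],\n",
    "    [0, 1, 1, 0, 2],\n",
    "    [2, 0, 0, 1, 1],\n",
    "    [1, 1, 2, 0, 0],\n",
    "    [0, 2, 1, 1, 1]\n",
    "], dtype=float)\n",
    "\n",
    "K_hand = np.array([\n",
    "    [1, 0, -1],\n",
    "    [1, 0, -1],\n",
    "    [1, 0, -1]\n",
    "], dtype=float)\n",
    "\n",
    "y00 = np.sum(X_hand[0:3, 0:3] * K_hand)  # patch with top-left corner at (0, 0)\n",
    "y01 = np.sum(X_hand[0:3, 1:4] * K_hand)  # patch shifted one column right\n",
    "print(y00, y01)  # 0.0 -1.0, matching the hand computation"
   ]
  },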
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1000004",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "\n",
    "# Define input and kernel\n",
    "X = np.array([\n",
    "    [1, 0, 2, 1, 0],\n",
    "    [0, 1, 1, 0, 2],\n",
    "    [2, 0, 0, 1, 1],\n",
    "    [1, 1, 2, 0, 0],\n",
    "    [0, 2, 1, 1, 1]\n",
    "])\n",
    "\n",
    "K = np.array([\n",
    "    [1,  0, -1],\n",
    "    [1,  0, -1],\n",
    "    [1,  0, -1]\n",
    "])\n",
    "\n",
    "# Compute cross-correlation\n",
    "kh, kw = K.shape\n",
    "oh, ow = X.shape[0] - kh + 1, X.shape[1] - kw + 1\n",
    "Y = np.zeros((oh, ow))\n",
    "for i in range(oh):\n",
    "    for j in range(ow):\n",
    "        Y[i, j] = np.sum(X[i:i+kh, j:j+kw] * K)\n",
    "\n",
    "def draw_matrix(ax, mat, title, cmap='Blues', highlight=None, fontsize=14):\n",
    "    \"\"\"Draw a matrix as a coloured grid with values.\"\"\"\n",
    "    h, w = mat.shape\n",
    "    vmax = max(abs(mat.min()), abs(mat.max()), 1)\n",
    "    ax.imshow(mat, cmap=cmap, vmin=-vmax, vmax=vmax, aspect='equal')\n",
    "    for i in range(h):\n",
    "        for j in range(w):\n",
    "            color = 'white' if abs(mat[i, j]) > vmax * 0.6 else 'black'\n",
    "            ax.text(j, i, f'{mat[i,j]:.0f}', ha='center', va='center',\n",
    "                    fontsize=fontsize, fontweight='bold', color=color)\n",
    "    if highlight is not None:\n",
    "        r, c = highlight\n",
    "        rect = mpatches.Rectangle((c - 0.5, r - 0.5), kw, kh,\n",
    "                                   linewidth=3, edgecolor='#dc2626',\n",
    "                                   facecolor='none', linestyle='--')\n",
    "        ax.add_patch(rect)\n",
    "    ax.set_title(title, fontsize=12, fontweight='bold')\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(12, 4))\n",
    "\n",
    "# Highlight position (0,0) in the input\n",
    "draw_matrix(axes[0], X, 'Input X (5x5)', cmap='Blues', highlight=(0, 0))\n",
    "draw_matrix(axes[1], K, 'Kernel K (3x3)', cmap='RdBu_r')\n",
    "draw_matrix(axes[2], Y, 'Output Y (3x3)', cmap='RdBu_r')\n",
    "\n",
    "# Highlight the (0,0) output position\n",
    "rect = mpatches.Rectangle((-0.5, -0.5), 1, 1, linewidth=3,\n",
    "                           edgecolor='#dc2626', facecolor='none', linestyle='--')\n",
    "axes[2].add_patch(rect)\n",
    "\n",
    "# Add operation symbols\n",
    "fig.text(0.355, 0.5, '*', fontsize=28, ha='center', va='center', fontweight='bold')\n",
    "fig.text(0.645, 0.5, '=', fontsize=28, ha='center', va='center', fontweight='bold')\n",
    "\n",
    "plt.suptitle('2D Cross-Correlation (\"Convolution\")', fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000005",
   "metadata": {},
   "source": [
    "Notice that this particular kernel is a **vertical edge detector**: it computes the difference between the left and right columns of each $3 \\times 3$ patch. Positive values in the output indicate left-to-right brightness transitions; negative values indicate right-to-left transitions."
   ]
  },
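  {
   "cell_type": "markdown",
   "id": "b1000005",
   "metadata": {},
   "source": [
    "A minimal check of that sign behaviour, using a hypothetical step image that is bright on the left and dark on the right (not one of the chapter's test patterns):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1000006",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "K_vert_demo = np.array([\n",
    "    [1, 0, -1],\n",
    "    [1, 0, -1],\n",
    "    [1, 0, -1]\n",
    "], dtype=float)\n",
    "\n",
    "step = np.zeros((5, 5))\n",
    "step[:, :2] = 1.0  # bright left half, dark right half\n",
    "\n",
    "patch = step[0:3, 1:4]  # 3x3 patch straddling the bright-to-dark step\n",
    "print(np.sum(patch * K_vert_demo))  # 3.0: positive response at the transition"
   ]
  },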
  {
   "cell_type": "markdown",
   "id": "a1000006",
   "metadata": {},
   "source": [
    "## 3. Output Size Formula\n",
    "\n",
    "The output size depends on three parameters beyond the input and kernel sizes:\n",
    "\n",
    "- **Padding** $P$: the number of zero-valued pixels added around the input border.\n",
    "- **Stride** $S$: the step size when sliding the kernel.\n",
    "\n",
    "```{admonition} Theorem (Output Size Formula)\n",
    ":class: note\n",
    "\n",
    "For an input of spatial size $W$, kernel size $K$, padding $P$, and stride $S$:\n",
    "\n",
    "$$\\text{output\\_size} = \\left\\lfloor \\frac{W - K + 2P}{S} \\right\\rfloor + 1$$\n",
    "\n",
    "The formula applies independently to height and width.\n",
    "```\n",
    "\n",
    "**Special cases:**\n",
    "\n",
    "| Configuration | Padding | Stride | Output Size |\n",
    "|:---:|:---:|:---:|:---:|\n",
    "| Valid (no padding) | $P = 0$ | $S = 1$ | $W - K + 1$ |\n",
    "| Same (preserve size) | $P = \\lfloor K/2 \\rfloor$ | $S = 1$ | $W$ |\n",
    "| Strided | $P = 0$ | $S > 1$ | $\\lfloor(W - K)/S\\rfloor + 1$ |\n",
    "\n",
    "### Worked Examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1000007",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def output_size(W, K, P=0, S=1):\n",
    "    \"\"\"Compute the output size of a convolution.\"\"\"\n",
    "    return (W - K + 2 * P) // S + 1\n",
    "\n",
    "# Example 1: Valid convolution\n",
    "print(\"Example 1: W=28, K=5, P=0, S=1\")\n",
    "print(f\"  Output size: {output_size(28, 5, 0, 1)}\")\n",
    "\n",
    "# Example 2: Same padding\n",
    "print(\"\\nExample 2: W=28, K=5, P=2, S=1  ('same' padding)\")\n",
    "print(f\"  Output size: {output_size(28, 5, 2, 1)}\")\n",
    "\n",
    "# Example 3: Stride 2\n",
    "print(\"\\nExample 3: W=28, K=5, P=0, S=2\")\n",
    "print(f\"  Output size: {output_size(28, 5, 0, 2)}\")\n",
    "\n",
    "# Example 4: Stride 2 with padding\n",
    "print(\"\\nExample 4: W=32, K=3, P=1, S=2\")\n",
    "print(f\"  Output size: {output_size(32, 3, 1, 2)}\")\n",
    "\n",
    "# Example 5: Large kernel\n",
    "print(\"\\nExample 5: W=224, K=7, P=3, S=2  (first layer of ResNet)\")\n",
    "print(f\"  Output size: {output_size(224, 7, 3, 2)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000008",
   "metadata": {},
   "source": [
    "```{admonition} Tip: \"Same\" Padding\n",
    ":class: tip\n",
    "\n",
    "To preserve the spatial dimensions with stride $S=1$, set padding to $P = \\lfloor K/2 \\rfloor$. For example, a $3 \\times 3$ kernel needs $P = 1$; a $5 \\times 5$ kernel needs $P = 2$. This is called **\"same\"** padding because the output has the **same** spatial size as the input.\n",
    "```"
   ]
  },
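  {
   "cell_type": "markdown",
   "id": "b1000007",
   "metadata": {},
   "source": [
    "A quick numerical check of this rule, zero-padding a hypothetical $5 \\times 5$ input with $P=1$ before a valid $3 \\times 3$ cross-correlation (the averaging kernel is arbitrary; any $3 \\times 3$ kernel gives the same shape):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1000008",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "X_pad_demo = np.arange(25, dtype=float).reshape(5, 5)\n",
    "K_avg = np.ones((3, 3)) / 9.0      # 3x3 averaging kernel\n",
    "P = K_avg.shape[0] // 2            # same padding: P = floor(3/2) = 1\n",
    "\n",
    "Xp = np.pad(X_pad_demo, P)         # zero-pad P pixels on every side -> 7x7\n",
    "oh = Xp.shape[0] - 3 + 1\n",
    "ow = Xp.shape[1] - 3 + 1\n",
    "Y_same = np.zeros((oh, ow))\n",
    "for i in range(oh):\n",
    "    for j in range(ow):\n",
    "        Y_same[i, j] = np.sum(Xp[i:i+3, j:j+3] * K_avg)\n",
    "\n",
    "print(Y_same.shape)  # (5, 5): same spatial size as the input"
   ]
  },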
  {
   "cell_type": "markdown",
   "id": "a1000009",
   "metadata": {},
   "source": [
    "## 4. Implementing Conv2D in NumPy\n",
    "\n",
    "We now implement a full `Conv2D` layer that handles multi-channel inputs, multiple output filters, and batched data. The input tensor has shape $(N, C_{\\text{in}}, H, W)$ where $N$ is the batch size and $C_{\\text{in}}$ is the number of input channels (e.g., 3 for RGB).\n",
    "\n",
    "Each output filter has shape $(C_{\\text{in}}, K, K)$\u2014it spans all input channels. With $C_{\\text{out}}$ filters, the weight tensor has shape $(C_{\\text{out}}, C_{\\text{in}}, K, K)$.\n",
    "\n",
    "We use **He initialization** (Chapter 17), which sets the scale of initial weights to $\\sqrt{2 / \\text{fan\\_in}}$ where $\\text{fan\\_in} = C_{\\text{in}} \\cdot K \\cdot K$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1000010",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "class Conv2D:\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, seed=42):\n",
    "        rng = np.random.default_rng(seed)\n",
    "        fan_in = in_channels * kernel_size * kernel_size\n",
    "        scale = np.sqrt(2.0 / fan_in)\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.kernel_size = kernel_size\n",
    "        self.weights = rng.normal(0.0, scale, size=(out_channels, in_channels, kernel_size, kernel_size))\n",
    "        self.bias = np.zeros(out_channels)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        self.last_input = x\n",
    "        batch_size, _, height, width = x.shape\n",
    "        out_h = height - self.kernel_size + 1\n",
    "        out_w = width - self.kernel_size + 1\n",
    "        output = np.zeros((batch_size, self.out_channels, out_h, out_w))\n",
    "        for row in range(out_h):\n",
    "            for col in range(out_w):\n",
    "                patch = x[:, :, row:row+self.kernel_size, col:col+self.kernel_size]\n",
    "                output[:, :, row, col] = np.tensordot(patch, self.weights, axes=([1,2,3],[1,2,3])) + self.bias\n",
    "        return output\n",
    "    \n",
    "    @property\n",
    "    def parameter_count(self):\n",
    "        return int(self.weights.size + self.bias.size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1000011",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify our implementation\n",
    "conv = Conv2D(in_channels=1, out_channels=4, kernel_size=3)\n",
    "print(f\"Weight shape: {conv.weights.shape}\")\n",
    "print(f\"Bias shape:   {conv.bias.shape}\")\n",
    "print(f\"Parameters:   {conv.parameter_count}\")\n",
    "\n",
    "# Test with a batch of 2 grayscale 8x8 images\n",
    "x_test = np.random.default_rng(0).standard_normal((2, 1, 8, 8))\n",
    "y_test = conv.forward(x_test)\n",
    "print(f\"\\nInput shape:  {x_test.shape}\")\n",
    "print(f\"Output shape: {y_test.shape}\")\n",
    "print(f\"Expected:     (2, 4, 6, 6)\")"
   ]
  },
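  {
   "cell_type": "markdown",
   "id": "b1000009",
   "metadata": {},
   "source": [
    "It is also worth convincing ourselves that the `tensordot` call really computes the same thing as an explicit loop over batch elements and filters. The self-contained sketch below re-derives the forward computation on random data (rather than importing the class), comparing the vectorized version against a scalar-at-a-time reference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1000010",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng_td = np.random.default_rng(1)\n",
    "x_td = rng_td.standard_normal((2, 3, 6, 6))   # (N, C_in, H, W)\n",
    "w_td = rng_td.standard_normal((4, 3, 3, 3))   # (C_out, C_in, K, K)\n",
    "b_td = rng_td.standard_normal(4)\n",
    "\n",
    "n_b, c_in, h_in, w_in = x_td.shape\n",
    "c_out, _, ks, _ = w_td.shape\n",
    "oh, ow = h_in - ks + 1, w_in - ks + 1\n",
    "\n",
    "# Vectorized inner product, exactly as in Conv2D.forward\n",
    "y_fast = np.zeros((n_b, c_out, oh, ow))\n",
    "for i in range(oh):\n",
    "    for j in range(ow):\n",
    "        patch = x_td[:, :, i:i+ks, j:j+ks]\n",
    "        y_fast[:, :, i, j] = np.tensordot(patch, w_td, axes=([1, 2, 3], [1, 2, 3])) + b_td\n",
    "\n",
    "# Reference: one output scalar at a time\n",
    "y_slow = np.zeros_like(y_fast)\n",
    "for n in range(n_b):\n",
    "    for c in range(c_out):\n",
    "        for i in range(oh):\n",
    "            for j in range(ow):\n",
    "                y_slow[n, c, i, j] = np.sum(x_td[n, :, i:i+ks, j:j+ks] * w_td[c]) + b_td[c]\n",
    "\n",
    "print(np.allclose(y_fast, y_slow))  # True"
   ]
  },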
  {
   "cell_type": "markdown",
   "id": "a1000012",
   "metadata": {},
   "source": [
    "Let us trace through the key line of the `forward` method:\n",
    "\n",
    "```python\n",
    "output[:, :, row, col] = np.tensordot(patch, self.weights, axes=([1,2,3],[1,2,3])) + self.bias\n",
    "```\n",
    "\n",
    "- `patch` has shape $(N, C_{\\text{in}}, K, K)$\u2014the input region under the kernel for all batch elements.\n",
    "- `self.weights` has shape $(C_{\\text{out}}, C_{\\text{in}}, K, K)$.\n",
    "- `tensordot` with `axes=([1,2,3],[1,2,3])` sums over channels $\\times$ kernel height $\\times$ kernel width, producing shape $(N, C_{\\text{out}})$.\n",
    "- This is stored at spatial position $(\\text{row}, \\text{col})$ for all batch elements and all output channels simultaneously.\n",
    "\n",
    "```{admonition} Note: Backward Pass\n",
    ":class: tip\n",
    "\n",
    "The `Conv2D` class above implements only the forward pass. The backward pass (computing gradients with respect to both the input and the kernel weights) will be derived in Chapter 24.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000013",
   "metadata": {},
   "source": [
    "## 5. Edge Detection on Synthetic Patterns\n",
    "\n",
    "Before we let the network learn its own filters, let us see what hand-crafted kernels can do. We will create simple $8 \\times 8$ synthetic patterns and apply classical edge-detection kernels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1000014",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "\n",
    "# Create synthetic 8x8 patterns\n",
    "def make_vertical_line(pos=3):\n",
    "    img = np.zeros((8, 8))\n",
    "    img[:, pos] = 1.0\n",
    "    img[:, pos+1] = 1.0\n",
    "    return img\n",
    "\n",
    "def make_horizontal_line(pos=3):\n",
    "    img = np.zeros((8, 8))\n",
    "    img[pos, :] = 1.0\n",
    "    img[pos+1, :] = 1.0\n",
    "    return img\n",
    "\n",
    "def make_diagonal():\n",
    "    img = np.zeros((8, 8))\n",
    "    for i in range(8):\n",
    "        img[i, i] = 1.0\n",
    "        if i + 1 < 8:\n",
    "            img[i, i+1] = 0.5\n",
    "        if i - 1 >= 0:\n",
    "            img[i, i-1] = 0.5\n",
    "    return img\n",
    "\n",
    "def make_box():\n",
    "    img = np.zeros((8, 8))\n",
    "    img[2:6, 2:6] = 1.0\n",
    "    return img\n",
    "\n",
    "patterns = {\n",
    "    'Vertical Line': make_vertical_line(),\n",
    "    'Horizontal Line': make_horizontal_line(),\n",
    "    'Diagonal': make_diagonal(),\n",
    "    'Box': make_box(),\n",
    "}\n",
    "\n",
    "fig, axes = plt.subplots(1, 4, figsize=(12, 3))\n",
    "for ax, (name, img) in zip(axes, patterns.items()):\n",
    "    ax.imshow(img, cmap='Blues', vmin=0, vmax=1)\n",
    "    ax.set_title(name, fontsize=11, fontweight='bold')\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    # Show grid lines\n",
    "    for i in range(9):\n",
    "        ax.axhline(i - 0.5, color='#94a3b8', linewidth=0.5)\n",
    "        ax.axvline(i - 0.5, color='#94a3b8', linewidth=0.5)\n",
    "\n",
    "plt.suptitle('Synthetic 8x8 Test Patterns', fontsize=13, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000015",
   "metadata": {},
   "source": [
    "Now we apply two classical edge-detection kernels:\n",
    "\n",
    "**Horizontal edge detector** (detects horizontal boundaries):\n",
    "\n",
    "$$\\mathbf{K}_{\\text{horiz}} = \\begin{pmatrix} -1 & -1 & -1 \\\\ 0 & 0 & 0 \\\\ 1 & 1 & 1 \\end{pmatrix}$$\n",
    "\n",
    "**Vertical edge detector** (detects vertical boundaries):\n",
    "\n",
    "$$\\mathbf{K}_{\\text{vert}} = \\begin{pmatrix} -1 & 0 & 1 \\\\ -1 & 0 & 1 \\\\ -1 & 0 & 1 \\end{pmatrix}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1000016",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "\n",
    "# Define edge detection kernels\n",
    "K_horiz = np.array([[-1, -1, -1],\n",
    "                    [ 0,  0,  0],\n",
    "                    [ 1,  1,  1]], dtype=float)\n",
    "\n",
    "K_vert = np.array([[-1, 0, 1],\n",
    "                   [-1, 0, 1],\n",
    "                   [-1, 0, 1]], dtype=float)\n",
    "\n",
    "def cross_correlate_2d(image, kernel):\n",
    "    \"\"\"Simple 2D cross-correlation (no padding, stride=1).\"\"\"\n",
    "    kh, kw = kernel.shape\n",
    "    oh = image.shape[0] - kh + 1\n",
    "    ow = image.shape[1] - kw + 1\n",
    "    output = np.zeros((oh, ow))\n",
    "    for i in range(oh):\n",
    "        for j in range(ow):\n",
    "            output[i, j] = np.sum(image[i:i+kh, j:j+kw] * kernel)\n",
    "    return output\n",
    "\n",
    "# Patterns to test\n",
    "patterns = {\n",
    "    'Vertical Line': make_vertical_line(),\n",
    "    'Horizontal Line': make_horizontal_line(),\n",
    "    'Diagonal': make_diagonal(),\n",
    "    'Box': make_box(),\n",
    "}\n",
    "\n",
    "kernels = {\n",
    "    'Horizontal Edge': K_horiz,\n",
    "    'Vertical Edge': K_vert,\n",
    "}\n",
    "\n",
    "fig, axes = plt.subplots(len(kernels), len(patterns) + 1, figsize=(14, 6))\n",
    "\n",
    "for ki, (kname, kernel) in enumerate(kernels.items()):\n",
    "    # Show the kernel in the first column\n",
    "    ax = axes[ki, 0]\n",
    "    vmax = max(abs(kernel.min()), abs(kernel.max()))\n",
    "    ax.imshow(kernel, cmap='RdBu_r', vmin=-vmax, vmax=vmax)\n",
    "    for i in range(3):\n",
    "        for j in range(3):\n",
    "            ax.text(j, i, f'{kernel[i,j]:.0f}', ha='center', va='center',\n",
    "                    fontsize=12, fontweight='bold',\n",
    "                    color='white' if abs(kernel[i,j]) > 0.5 else 'black')\n",
    "    ax.set_title(f'Kernel:\\n{kname}', fontsize=10, fontweight='bold')\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "    \n",
    "    # Apply kernel to each pattern\n",
    "    for pi, (pname, pattern) in enumerate(patterns.items()):\n",
    "        result = cross_correlate_2d(pattern, kernel)\n",
    "        ax = axes[ki, pi + 1]\n",
    "        vmax_r = max(abs(result.min()), abs(result.max()), 0.1)\n",
    "        ax.imshow(result, cmap='RdBu_r', vmin=-vmax_r, vmax=vmax_r)\n",
    "        ax.set_title(pname if ki == 0 else '', fontsize=10, fontweight='bold')\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        # Add grid\n",
    "        for i in range(result.shape[0] + 1):\n",
    "            ax.axhline(i - 0.5, color='#94a3b8', linewidth=0.3)\n",
    "        for j in range(result.shape[1] + 1):\n",
    "            ax.axvline(j - 0.5, color='#94a3b8', linewidth=0.3)\n",
    "\n",
    "plt.suptitle('Edge Detection: Kernel Responses to Synthetic Patterns',\n",
    "             fontsize=13, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000017",
   "metadata": {},
   "source": [
    "Observe the results:\n",
    "\n",
    "- The **horizontal edge kernel** produces strong responses on the horizontal line and the top/bottom edges of the box, but gives zero response on the vertical line (which has no horizontal gradients).\n",
    "- The **vertical edge kernel** lights up on the vertical line and the left/right edges of the box, but is blind to horizontal structures.\n",
    "- Both kernels respond to the diagonal, reflecting the fact that a diagonal edge has both horizontal and vertical components.\n",
    "\n",
    "This is the essence of convolutional feature extraction: different kernels are sensitive to different spatial patterns. A CNN learns the right set of kernels for its task automatically through backpropagation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000018",
   "metadata": {},
   "source": [
    "Now let us use our `Conv2D` class with hand-crafted weights to verify it produces the same results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1000019",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use our Conv2D class with hand-crafted kernels\n",
    "conv_edge = Conv2D(in_channels=1, out_channels=2, kernel_size=3)\n",
    "\n",
    "# Manually set the kernels\n",
    "conv_edge.weights[0, 0] = K_horiz  # First filter: horizontal edge\n",
    "conv_edge.weights[1, 0] = K_vert   # Second filter: vertical edge\n",
    "conv_edge.bias[:] = 0.0\n",
    "\n",
    "# Create a batch with our test patterns\n",
    "test_batch = np.stack([\n",
    "    make_vertical_line(),\n",
    "    make_horizontal_line(),\n",
    "    make_box(),\n",
    "])[:, np.newaxis, :, :]  # Shape: (3, 1, 8, 8)\n",
    "\n",
    "output = conv_edge.forward(test_batch)\n",
    "print(f\"Input shape:  {test_batch.shape}  (batch=3, channels=1, 8x8)\")\n",
    "print(f\"Output shape: {output.shape}  (batch=3, filters=2, 6x6)\")\n",
    "print(f\"\\nMax response of horiz-edge filter on horizontal line: {output[1, 0].max():.1f}\")\n",
    "print(f\"Max response of vert-edge filter on vertical line:   {output[0, 1].max():.1f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000020",
   "metadata": {},
   "source": [
    "## 6. Multiple Filters and Channels\n",
    "\n",
    "In practice, a convolutional layer applies **multiple filters** to the input, each producing its own **feature map**. The collection of feature maps forms the layer's output tensor.\n",
    "\n",
    "### Shape Conventions\n",
    "\n",
    "Throughout this course, we use the **channels-first** convention (NCHW):\n",
    "\n",
    "| Tensor | Shape | Description |\n",
    "|:------:|:-----:|:------------|\n",
    "| Input | $(N, C_{\\text{in}}, H, W)$ | Batch of $N$ images, $C_{\\text{in}}$ channels |\n",
    "| Weights | $(C_{\\text{out}}, C_{\\text{in}}, K, K)$ | $C_{\\text{out}}$ filters, each spanning all input channels |\n",
    "| Bias | $(C_{\\text{out}},)$ | One bias per output channel |\n",
    "| Output | $(N, C_{\\text{out}}, H_{\\text{out}}, W_{\\text{out}})$ | $C_{\\text{out}}$ feature maps |\n",
    "\n",
    "### How Multi-Channel Convolution Works\n",
    "\n",
    "Each filter $\\bw_k \\in \\mathbb{R}^{C_{\\text{in}} \\times K \\times K}$ produces one output feature map by:\n",
    "\n",
    "1. Extracting a $(C_{\\text{in}}, K, K)$ patch from the input at each spatial position.\n",
    "2. Computing the dot product of this patch with the filter (summing over all channels and spatial positions within the kernel).\n",
    "3. Adding the bias $b_k$.\n",
    "\n",
    "```{admonition} Definition (Feature Map)\n",
    ":class: note\n",
    "\n",
    "A **feature map** is the 2D output produced by applying a single filter across all spatial positions of the input. If a convolutional layer has $C_{\\text{out}}$ filters, it produces $C_{\\text{out}}$ feature maps, which together form the output tensor.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1000021",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "\n",
    "plt.style.use('seaborn-v0_8-whitegrid')\n",
    "\n",
    "# Demonstrate multi-filter output\n",
    "# Create a more interesting 12x12 pattern\n",
    "rng = np.random.default_rng(42)\n",
    "img = np.zeros((12, 12))\n",
    "img[2:10, 4:8] = 1.0       # Vertical bar\n",
    "img[5:7, 1:11] = 1.0       # Horizontal bar (cross)\n",
    "\n",
    "# Apply 4 different kernels\n",
    "kernels_demo = {\n",
    "    'Horiz. Edge': np.array([[-1,-1,-1],[0,0,0],[1,1,1]], dtype=float),\n",
    "    'Vert. Edge':  np.array([[-1,0,1],[-1,0,1],[-1,0,1]], dtype=float),\n",
    "    'Sharpen':     np.array([[0,-1,0],[-1,5,-1],[0,-1,0]], dtype=float),\n",
    "    'Blur (avg)':  np.ones((3,3), dtype=float) / 9.0,\n",
    "}\n",
    "\n",
    "fig, axes = plt.subplots(1, 5, figsize=(15, 3))\n",
    "\n",
    "# Input\n",
    "axes[0].imshow(img, cmap='Blues', vmin=0, vmax=1)\n",
    "axes[0].set_title('Input (12x12)', fontsize=10, fontweight='bold')\n",
    "axes[0].set_xticks([])\n",
    "axes[0].set_yticks([])\n",
    "\n",
    "# Feature maps\n",
    "for idx, (kname, kernel) in enumerate(kernels_demo.items()):\n",
    "    kh, kw = kernel.shape\n",
    "    oh = img.shape[0] - kh + 1\n",
    "    ow = img.shape[1] - kw + 1\n",
    "    feat_map = np.zeros((oh, ow))\n",
    "    for i in range(oh):\n",
    "        for j in range(ow):\n",
    "            feat_map[i, j] = np.sum(img[i:i+kh, j:j+kw] * kernel)\n",
    "    \n",
    "    ax = axes[idx + 1]\n",
    "    vmax = max(abs(feat_map.min()), abs(feat_map.max()), 0.1)\n",
    "    ax.imshow(feat_map, cmap='RdBu_r', vmin=-vmax, vmax=vmax)\n",
    "    ax.set_title(f'Filter {idx+1}:\\n{kname}', fontsize=9, fontweight='bold')\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "\n",
    "plt.suptitle('One Input, Multiple Feature Maps', fontsize=13, fontweight='bold', y=1.05)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000022",
   "metadata": {},
   "source": [
    "Each filter extracts a different aspect of the input:\n",
    "\n",
    "- The **horizontal edge** filter highlights the top and bottom boundaries of the cross shape.\n",
    "- The **vertical edge** filter highlights the left and right boundaries.\n",
    "- The **sharpening** filter enhances contrast at all edges.\n",
    "- The **averaging** (blur) filter smooths the image, reducing noise.\n",
    "\n",
    "In a trained CNN, the network discovers which filters are useful for the task at hand. Early layers typically learn edge detectors and texture filters; deeper layers learn more complex patterns like corners, shapes, and eventually parts of objects.\n",
    "\n",
    "### Parameter Count\n",
    "\n",
    "The total number of learnable parameters in a `Conv2D` layer is:\n",
    "\n",
    "$$\\text{params} = C_{\\text{out}} \\times (C_{\\text{in}} \\times K \\times K + 1)$$\n",
    "\n",
    "where the $+1$ accounts for one bias per filter. This is **independent of the input spatial dimensions** $H$ and $W$\u2014one of the key advantages of convolution over fully connected layers."
   ]
  },
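  {
   "cell_type": "markdown",
   "id": "b1000011",
   "metadata": {},
   "source": [
    "The formula translates directly into code. The helper below is a small standalone sketch, separate from the `parameter_count` property of our `Conv2D` class; the example channel widths are arbitrary illustrations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1000012",
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_param_count(c_in, c_out, k):\n",
    "    # C_out filters, each with C_in * K * K weights plus one bias\n",
    "    return c_out * (c_in * k * k + 1)\n",
    "\n",
    "# A 3-channel to 64-channel layer with 3x3 kernels, whatever H and W are:\n",
    "print(conv_param_count(3, 64, 3))    # 1792\n",
    "# Widening the channels grows the count; the spatial size never enters:\n",
    "print(conv_param_count(64, 128, 3))  # 73856"
   ]
  },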
  {
   "cell_type": "markdown",
   "id": "a1000023",
   "metadata": {},
   "source": [
    "## 7. Exercises\n",
    "\n",
    "### Exercise 22.1: Convolution by Hand\n",
    "\n",
    "Compute the full output of the following convolution (no padding, stride 1, no bias):\n",
    "\n",
    "$$\\mathbf{X} = \\begin{pmatrix} 2 & 1 & 0 & 3 \\\\ 1 & 0 & 2 & 1 \\\\ 0 & 3 & 1 & 0 \\\\ 1 & 2 & 0 & 1 \\end{pmatrix}, \\quad \\mathbf{K} = \\begin{pmatrix} 1 & -1 \\\\ -1 & 1 \\end{pmatrix}$$\n",
    "\n",
    "Verify that the output has shape $3 \\times 3$ and compute all 9 values.\n",
    "\n",
    "### Exercise 22.2: Output Size Computation\n",
    "\n",
    "For each configuration, compute the output spatial dimensions:\n",
    "\n",
    "**(a)** Input: $32 \\times 32$, Kernel: $5 \\times 5$, Padding: 0, Stride: 1\n",
    "\n",
    "**(b)** Input: $32 \\times 32$, Kernel: $5 \\times 5$, Padding: 2, Stride: 1\n",
    "\n",
    "**(c)** Input: $32 \\times 32$, Kernel: $5 \\times 5$, Padding: 0, Stride: 2\n",
    "\n",
    "**(d)** Input: $224 \\times 224$, Kernel: $11 \\times 11$, Padding: 2, Stride: 4 (AlexNet first layer)\n",
    "\n",
    "**(e)** Input: $7 \\times 7$, Kernel: $3 \\times 3$, Padding: 1, Stride: 1. How many times can you apply this operation while maintaining the same spatial size?\n",
    "\n",
    "### Exercise 22.3: Parameter Comparison\n",
    "\n",
    "Consider processing a $64 \\times 64 \\times 3$ RGB input.\n",
    "\n",
    "**(a)** Compute the number of parameters in a fully connected layer that maps this input to 64 output units.\n",
    "\n",
    "**(b)** Compute the number of parameters in a `Conv2D` layer with 64 filters of size $3 \\times 3$ applied to the same input.\n",
    "\n",
    "**(c)** By what factor does the convolutional layer reduce the parameter count?\n",
    "\n",
    "### Exercise 22.4: Identity Kernel\n",
    "\n",
    "What $3 \\times 3$ kernel acts as the **identity** (output equals input, assuming valid padding)? Test your answer by applying it to a $5 \\times 5$ input matrix.\n",
    "\n",
    "### Exercise 22.5: Implementing Stride\n",
    "\n",
    "Extend the `Conv2D` class to support a `stride` parameter. The constructor should accept `stride=1` by default. Modify the `forward` method so that the kernel moves by `stride` pixels at each step. Verify that:\n",
    "\n",
    "- With stride 1 on an $8 \\times 8$ input with $3 \\times 3$ kernel, the output is $6 \\times 6$.\n",
    "- With stride 2 on an $8 \\times 8$ input with $3 \\times 3$ kernel, the output is $3 \\times 3$.\n",
    "\n",
    "*Hint:* The main change is in the loop bounds and indexing:\n",
    "```python\n",
    "for row in range(out_h):\n",
    "    for col in range(out_w):\n",
    "        patch = x[:, :, row*stride:row*stride+self.kernel_size,\n",
    "                        col*stride:col*stride+self.kernel_size]\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000024",
   "metadata": {},
   "source": [
    "## 8. Summary and Key Takeaways\n",
    "\n",
    "- In deep learning, \"convolution\" actually refers to **cross-correlation**: the kernel is not flipped before sliding.\n",
    "- The output of a 2D convolution at position $(i, j)$ is: $y_{i,j} = \\sum_{u,v} x_{i+u,j+v} \\cdot k_{u,v} + b$.\n",
    "- The **output size formula** is $\\lfloor(W - K + 2P)/S\\rfloor + 1$, where $W$ is input size, $K$ is kernel size, $P$ is padding, and $S$ is stride.\n",
    "- Our `Conv2D` class implements forward propagation with He initialization and supports batched, multi-channel inputs.\n",
    "- Hand-crafted kernels (horizontal/vertical edge detectors) demonstrate that convolution is a natural operation for feature extraction.\n",
    "- Multiple filters produce multiple **feature maps**, each sensitive to different spatial patterns.\n",
    "- Convolutional layers have far fewer parameters than equivalent fully connected layers, and the count is **independent of input spatial size**."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1000025",
   "metadata": {},
   "source": [
    "## 9. References\n",
    "\n",
    "1. Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner, \"Gradient-based learning applied to document recognition,\" *Proceedings of the IEEE*, vol. 86, no. 11, pp. 2278\u20132324, 1998.\n",
    "\n",
    "2. I. Goodfellow, Y. Bengio, and A. Courville, *Deep Learning*, MIT Press, 2016. Chapter 9: Convolutional Networks.\n",
    "\n",
    "3. A. Krizhevsky, I. Sutskever, and G. E. Hinton, \"ImageNet classification with deep convolutional neural networks,\" *Advances in Neural Information Processing Systems*, vol. 25, 2012."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}