{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 31: PyTorch CNN on MNIST\n",
    "\n",
    "In Chapters 23-25 we built a convolutional neural network from scratch in NumPy\n",
    "on $8 \\times 8$ synthetic patterns. Our `TinyCNN` class implemented convolution,\n",
    "ReLU, max pooling, and a fully connected classifier -- all with hand-written\n",
    "forward and backward passes. The exercise was invaluable for understanding the\n",
    "mechanics of CNNs, but the implementation was slow and limited to tiny images.\n",
    "\n",
    "Now we rebuild the same architecture in PyTorch and train it on the full MNIST\n",
    "dataset -- 60,000 training images of $28 \\times 28$ pixels. Where our NumPy CNN took\n",
    "minutes to train on 200 synthetic samples, PyTorch will sweep through all 60,000\n",
    "images several times over, with automatic differentiation, optimized BLAS routines,\n",
    "and (optionally) GPU acceleration.\n",
    "\n",
    "The conceptual leap is small -- we already understand every layer -- but the\n",
    "practical leap is enormous."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# Consistent style for all plots\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 100,\n",
    "    'font.size': 11,\n",
    "    'axes.titlesize': 13,\n",
    "    'axes.labelsize': 12\n",
    "})\n",
    "\n",
    "# Standard color palette\n",
    "BLUE = '#3b82f6'\n",
    "GREEN = '#059669'\n",
    "RED = '#dc2626'\n",
    "AMBER = '#d97706'\n",
    "INDIGO = '#4f46e5'\n",
    "\n",
    "torch.manual_seed(42)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f'PyTorch version: {torch.__version__}')\n",
    "print(f'Device: {device}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 31.1 From TinyCNN to PyTorch\n",
    "\n",
    "Recall the architecture of our NumPy `TinyCNN` from Chapter 23:\n",
    "1. **Conv2D**: 1 input channel, 4 filters of size $3 \\times 3$\n",
    "2. **ReLU**: element-wise activation\n",
    "3. **MaxPool**: $2 \\times 2$ pooling with stride 2\n",
    "4. **Flatten**: reshape to a vector\n",
    "5. **Dense**: fully connected to 2 output classes\n",
    "\n",
    "Each of these required 50-100 lines of careful NumPy code for both forward\n",
    "and backward passes. In PyTorch, the same architecture is a few lines.\n",
    "\n",
    "The following table shows the exact correspondence:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Side-by-side comparison table ---\n",
    "fig, ax = plt.subplots(figsize=(11, 4.5))\n",
    "ax.axis('off')\n",
    "\n",
    "table_data = [\n",
    "    ['Layer', 'NumPy TinyCNN (Ch. 23)', 'PyTorch CNN (Ch. 31)'],\n",
    "    ['Convolution', 'Conv2D(n_in, n_out, k)\\n+ hand-written backward', 'nn.Conv2d(n_in, n_out, k)'],\n",
    "    ['Activation', 'np.maximum(0, x)\\n+ manual gradient mask', 'nn.ReLU()'],\n",
    "    ['Pooling', 'MaxPool2D(size)\\n+ argmax index tracking', 'nn.MaxPool2d(size)'],\n",
    "    ['Flatten', 'x.reshape(batch, -1)', 'nn.Flatten()'],\n",
    "    ['Dense', 'DenseLayer(n_in, n_out)\\n+ hand-written backward', 'nn.Linear(n_in, n_out)'],\n",
    "    ['Backward pass', '~150 lines of manual code', 'loss.backward()  # 1 line'],\n",
    "    ['Update', 'param -= lr * grad', 'optimizer.step()'],\n",
    "]\n",
    "\n",
    "table = ax.table(cellText=table_data[1:], colLabels=table_data[0],\n",
    "                 cellLoc='left', loc='center',\n",
    "                 colWidths=[0.14, 0.40, 0.35])\n",
    "table.auto_set_font_size(False)\n",
    "table.set_fontsize(9)\n",
    "table.scale(1.0, 1.8)\n",
    "\n",
    "# Style header\n",
    "for j in range(3):\n",
    "    table[0, j].set_facecolor(INDIGO)\n",
    "    table[0, j].set_text_props(color='white', fontweight='bold')\n",
    "\n",
    "# Alternate row colors\n",
    "for i in range(1, len(table_data)):\n",
    "    color = '#f0f0ff' if i % 2 == 0 else 'white'\n",
    "    for j in range(3):\n",
    "        table[i, j].set_facecolor(color)\n",
    "\n",
    "ax.set_title('NumPy TinyCNN vs. PyTorch CNN: Layer-by-Layer Comparison',\n",
    "             fontsize=13, fontweight='bold', pad=20)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{admonition} The Key Insight\n",
    ":class: important\n",
    "The conceptual content is identical -- both implementations perform the same\n",
    "mathematical operations. But PyTorch's autograd eliminates the need to manually\n",
    "derive and implement backward passes for each layer. This is exactly the shift\n",
    "from Chapter 28 (manual backprop) to Chapter 29 (autograd): the mathematics\n",
    "stays the same, but the engineering burden drops dramatically.\n",
    "```\n",
    "\n",
    "## 31.2 Training on Full MNIST\n",
    "\n",
    "We scale up from the TinyCNN's 4 filters on $8 \\times 8$ images to a proper\n",
    "architecture for $28 \\times 28$ MNIST digits:\n",
    "\n",
    "$$\\text{Conv}(1 \\to 16, 3{\\times}3) \\to \\text{ReLU} \\to \\text{Pool}(2{\\times}2) \\to \\text{Conv}(16 \\to 32, 3{\\times}3) \\to \\text{ReLU} \\to \\text{Pool}(2{\\times}2) \\to \\text{Linear}(800 \\to 10)$$\n",
    "\n",
    "Let us trace the dimensions through the network:\n",
    "1. Input: $(B, 1, 28, 28)$\n",
    "2. After Conv1 ($3 \\times 3$, 16 filters): $(B, 16, 26, 26)$\n",
    "3. After Pool1 ($2 \\times 2$): $(B, 16, 13, 13)$\n",
    "4. After Conv2 ($3 \\times 3$, 32 filters): $(B, 32, 11, 11)$\n",
    "5. After Pool2 ($2 \\times 2$): $(B, 32, 5, 5)$\n",
    "6. After Flatten: $(B, 800)$\n",
    "7. After Linear: $(B, 10)$"
   ]
  },
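  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic behind this trace is the valid-convolution size formula\n",
    "$o = \\lfloor (n - k)/s \\rfloor + 1$, which covers the pooling layers as well\n",
    "(a $2 \\times 2$ pool with stride 2 is the case $k = s = 2$). As a quick sketch\n",
    "-- the helper `conv_out` below is our own, not a PyTorch API -- we can verify\n",
    "the trace before writing the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify the shape trace with the valid-convolution size formula.\n",
    "# conv_out is our own helper for this sketch, not a PyTorch API.\n",
    "def conv_out(n, k, s=1):\n",
    "    \"\"\"Output size along one n-pixel axis for kernel size k and stride s.\"\"\"\n",
    "    return (n - k) // s + 1\n",
    "\n",
    "n = 28\n",
    "n = conv_out(n, 3)         # Conv1: 28 -> 26\n",
    "n = conv_out(n, 2, s=2)    # Pool1: 26 -> 13\n",
    "n = conv_out(n, 3)         # Conv2: 13 -> 11\n",
    "n = conv_out(n, 2, s=2)    # Pool2: 11 -> 5\n",
    "print(f'Final spatial size: {n}x{n}, flattened: {32 * n * n}')"
   ]
  },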
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Define the CNN ---\n",
    "class MNISTCNN(nn.Module):\n",
    "    \"\"\"Two-layer CNN for MNIST, extending TinyCNN (Ch. 23) to full scale.\"\"\"\n",
    "    \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Feature extraction (cf. TinyCNN's conv + pool)\n",
    "        self.conv1 = nn.Conv2d(1, 16, kernel_size=3)    # 28x28 -> 26x26\n",
    "        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)   # 13x13 -> 11x11\n",
    "        self.pool = nn.MaxPool2d(2, 2)                   # halve spatial dims\n",
    "        \n",
    "        # Classifier (cf. TinyCNN's dense layer)\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.fc = nn.Linear(32 * 5 * 5, 10)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))   # (B,1,28,28) -> (B,16,13,13)\n",
    "        x = self.pool(F.relu(self.conv2(x)))   # (B,16,13,13) -> (B,32,5,5)\n",
    "        x = self.flatten(x)                     # (B,32,5,5) -> (B,800)\n",
    "        x = self.fc(x)                          # (B,800) -> (B,10)\n",
    "        return x\n",
    "\n",
    "torch.manual_seed(42)\n",
    "cnn_model = MNISTCNN().to(device)\n",
    "print(cnn_model)\n",
    "\n",
    "n_params = sum(p.numel() for p in cnn_model.parameters())\n",
    "print(f'\\nTotal parameters: {n_params:,}')\n",
    "\n",
    "# Verify dimensions with a dummy input\n",
    "dummy = torch.randn(1, 1, 28, 28).to(device)\n",
    "out = cnn_model(dummy)\n",
    "print(f'Input shape:  {dummy.shape}')\n",
    "print(f'Output shape: {out.shape}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Load MNIST data ---\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,))\n",
    "])\n",
    "\n",
    "train_dataset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=transform\n",
    ")\n",
    "test_dataset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=False, download=True, transform=transform\n",
    ")\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n",
    "\n",
    "print(f'Training batches: {len(train_loader)}')\n",
    "print(f'Test batches: {len(test_loader)}')"
   ]
  },
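  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The constants `0.1307` and `0.3081` are the mean and standard deviation of the\n",
    "MNIST training pixels after scaling to $[0, 1]$ -- conventional values, not magic\n",
    "numbers. As a sanity check, we can recompute them from `train_dataset.data`,\n",
    "which holds the un-normalized `uint8` images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute the normalization constants from the raw training images.\n",
    "# train_dataset.data has shape (60000, 28, 28) with uint8 pixel values.\n",
    "raw = train_dataset.data.float() / 255.0\n",
    "print(f'Pixel mean: {raw.mean().item():.4f}')  # expect ~0.1307\n",
    "print(f'Pixel std:  {raw.std().item():.4f}')   # expect ~0.3081"
   ]
  },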
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Train the CNN ---\n",
    "torch.manual_seed(42)\n",
    "cnn_model = MNISTCNN().to(device)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)\n",
    "\n",
    "cnn_train_losses = []\n",
    "cnn_test_accuracies = []\n",
    "n_epochs = 5\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    # Training\n",
    "    cnn_model.train()\n",
    "    epoch_loss = 0.0\n",
    "    n_batches = 0\n",
    "    \n",
    "    for X_batch, y_batch in train_loader:\n",
    "        X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
    "        \n",
    "        pred = cnn_model(X_batch)\n",
    "        loss = loss_fn(pred, y_batch)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        n_batches += 1\n",
    "    \n",
    "    avg_loss = epoch_loss / n_batches\n",
    "    cnn_train_losses.append(avg_loss)\n",
    "    \n",
    "    # Evaluation\n",
    "    cnn_model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for X_batch, y_batch in test_loader:\n",
    "            X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
    "            pred = cnn_model(X_batch)\n",
    "            _, predicted = torch.max(pred, 1)\n",
    "            total += y_batch.size(0)\n",
    "            correct += (predicted == y_batch).sum().item()\n",
    "    \n",
    "    accuracy = 100.0 * correct / total\n",
    "    cnn_test_accuracies.append(accuracy)\n",
    "    \n",
    "    print(f'Epoch {epoch+1}/{n_epochs} -- '\n",
    "          f'Train Loss: {avg_loss:.4f}, '\n",
    "          f'Test Accuracy: {accuracy:.2f}%')\n",
    "\n",
    "print(f'\\nFinal CNN test accuracy: {cnn_test_accuracies[-1]:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Plot training curves ---\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "# Loss curve\n",
    "ax1.plot(range(1, n_epochs + 1), cnn_train_losses, 'o-', color=INDIGO,\n",
    "         linewidth=2, markersize=8)\n",
    "ax1.set_xlabel('Epoch')\n",
    "ax1.set_ylabel('Training Loss')\n",
    "ax1.set_title('CNN Training Loss', fontweight='bold')\n",
    "ax1.grid(True, alpha=0.3)\n",
    "ax1.set_xticks(range(1, n_epochs + 1))\n",
    "\n",
    "# Accuracy curve\n",
    "ax2.plot(range(1, n_epochs + 1), cnn_test_accuracies, 'o-', color=GREEN,\n",
    "         linewidth=2, markersize=8)\n",
    "ax2.set_xlabel('Epoch')\n",
    "ax2.set_ylabel('Test Accuracy (%)')\n",
    "ax2.set_title('CNN Test Accuracy', fontweight='bold')\n",
    "ax2.grid(True, alpha=0.3)\n",
    "ax2.set_xticks(range(1, n_epochs + 1))\n",
    "ax2.set_ylim(95, 100)\n",
    "ax2.axhline(y=98, color=RED, linestyle='--', alpha=0.5, label='98% target')\n",
    "ax2.legend()\n",
    "\n",
    "fig.suptitle('MNIST CNN Training (Conv16-Conv32-FC10)',\n",
    "             fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{admonition} CNN vs. MLP\n",
    ":class: tip\n",
    "Compare the CNN's accuracy (~98-99%) with the MLP's (~97%) from Chapter 30.\n",
    "The CNN achieves better performance with fewer parameters because convolutional\n",
    "layers exploit the spatial structure of images -- exactly the motivation we\n",
    "discussed in Chapter 21 (translation invariance, local connectivity, weight sharing).\n",
    "```\n",
    "\n",
    "## 31.3 Learned Filters\n",
    "\n",
    "In Chapter 25, we visualized the learned filters of our NumPy TinyCNN and observed\n",
    "that they resembled edge detectors. Let us perform the same analysis on our\n",
    "PyTorch CNN's first convolutional layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Extract and inspect first-layer filters ---\n",
    "filters = cnn_model.conv1.weight.data.cpu().numpy()\n",
    "print(f'First-layer filter shape: {filters.shape}')  # (16, 1, 3, 3)\n",
    "print(f'Number of filters: {filters.shape[0]}')\n",
    "print(f'Filter size: {filters.shape[2]}x{filters.shape[3]}')\n",
    "print(f'Value range: [{filters.min():.3f}, {filters.max():.3f}]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Visualize all 16 first-layer filters ---\n",
    "fig, axes = plt.subplots(2, 8, figsize=(14, 4))\n",
    "\n",
    "vmax = max(abs(filters.min()), abs(filters.max()))\n",
    "\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    kernel = filters[i, 0]  # shape: (3, 3)\n",
    "    im = ax.imshow(kernel, cmap='RdBu_r', vmin=-vmax, vmax=vmax,\n",
    "                   interpolation='nearest')\n",
    "    ax.set_title(f'Filter {i}', fontsize=9)\n",
    "    ax.axis('off')\n",
    "\n",
    "fig.suptitle('Learned First-Layer Convolution Filters (cf. Ch. 25 TinyCNN)',\n",
    "             fontsize=14, fontweight='bold')\n",
    "fig.colorbar(im, ax=axes, fraction=0.02, pad=0.04, label='Weight')\n",
    "# Note: tight_layout is incompatible with a figure-level colorbar, so we skip it\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{admonition} Interpreting the Filters\n",
    ":class: note\n",
    "The first-layer filters learn to detect simple visual features: horizontal edges,\n",
    "vertical edges, diagonal edges, and simple gradients. This matches what we observed\n",
    "in Chapter 25 with TinyCNN and aligns with the classical findings of Hubel and\n",
    "Wiesel (1962) on simple cells in the cat visual cortex. The network has independently\n",
    "discovered edge detection as the optimal first processing step for digit recognition.\n",
    "```"
   ]
  },
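  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the edge-detector reading concrete, we can compare each learned kernel\n",
    "against classical Sobel operators using cosine similarity. This is our own\n",
    "diagnostic sketch (`cosine_sim` is a helper we define, not a library function);\n",
    "a similarity near $\\pm 1$ suggests the filter behaves like an edge detector,\n",
    "while values near 0 indicate some other feature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare each learned 3x3 kernel to the classical Sobel edge detectors.\n",
    "# cosine_sim is our own helper for this diagnostic, not a library function.\n",
    "sobel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=float)\n",
    "sobel_y = sobel_x.T  # transposed Sobel kernel\n",
    "\n",
    "def cosine_sim(a, b):\n",
    "    a, b = a.ravel(), b.ravel()\n",
    "    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))\n",
    "\n",
    "for i in range(filters.shape[0]):\n",
    "    kernel = filters[i, 0]\n",
    "    sx = cosine_sim(kernel, sobel_x)\n",
    "    sy = cosine_sim(kernel, sobel_y)\n",
    "    print(f'Filter {i:2d}: sim(Sobel-x) = {sx:+.2f}, sim(Sobel-y) = {sy:+.2f}')"
   ]
  },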
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Visualize feature maps for a sample digit ---\n",
    "sample_img, sample_label = test_dataset[0]\n",
    "sample_img = sample_img.unsqueeze(0).to(device)  # (1, 1, 28, 28)\n",
    "\n",
    "# Get activations after first conv layer\n",
    "cnn_model.eval()\n",
    "with torch.no_grad():\n",
    "    conv1_out = F.relu(cnn_model.conv1(sample_img))  # (1, 16, 26, 26)\n",
    "\n",
    "conv1_maps = conv1_out[0].cpu().numpy()  # (16, 26, 26)\n",
    "\n",
    "fig, axes = plt.subplots(2, 9, figsize=(14, 3.5))\n",
    "\n",
    "# Original image\n",
    "axes[0, 0].imshow(sample_img[0, 0].cpu(), cmap='gray')\n",
    "axes[0, 0].set_title(f'Input (digit {sample_label})', fontsize=9)\n",
    "axes[0, 0].axis('off')\n",
    "axes[1, 0].axis('off')\n",
    "\n",
    "# Feature maps\n",
    "for i in range(16):\n",
    "    row = i // 8\n",
    "    col = i % 8 + 1\n",
    "    axes[row, col].imshow(conv1_maps[i], cmap='viridis')\n",
    "    axes[row, col].set_title(f'Map {i}', fontsize=8)\n",
    "    axes[row, col].axis('off')\n",
    "\n",
    "fig.suptitle('Feature Maps After First Convolution Layer',\n",
    "             fontsize=13, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 31.4 Pretrained Models Preview\n",
    "\n",
    "Our MNIST CNN has ~13,000 parameters and achieves ~98-99% accuracy on a simple\n",
    "benchmark. In practice, modern computer vision uses much larger architectures\n",
    "pretrained on millions of images.\n",
    "\n",
    "PyTorch's `torchvision.models` provides ready-to-use architectures:\n",
    "\n",
    "| Model | Year | Parameters | ImageNet Top-1 |\n",
    "|-------|------|------------|----------------|\n",
    "| AlexNet | 2012 | 61M | 56.5% |\n",
    "| VGG-16 | 2014 | 138M | 71.6% |\n",
    "| ResNet-50 | 2015 | 25M | 76.1% |\n",
    "| EfficientNet-B0 | 2019 | 5.3M | 77.1% |\n",
    "| ViT-B/16 | 2020 | 86M | 77.9% |\n",
    "\n",
    "```{admonition} Transfer Learning\n",
    ":class: note\n",
    "A pretrained model's early layers learn universal visual features (edges, textures,\n",
    "shapes) that transfer across tasks. **Fine-tuning** -- replacing the final classification\n",
    "layer and training on a new dataset -- often achieves excellent results with very\n",
    "little data. We will explore transfer learning in detail in a later chapter.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Preview: listing available pretrained models ---\n",
    "# We only inspect the API here; downloading pretrained weights is deferred\n",
    "# to the transfer learning chapter.\n",
    "\n",
    "print('Selected torchvision.models architectures:')\n",
    "selected_models = ['resnet18', 'resnet50', 'vgg16', 'mobilenet_v2', 'efficientnet_b0']\n",
    "for name in selected_models:\n",
    "    model_fn = getattr(torchvision.models, name)\n",
    "    m = model_fn(weights=None)  # no pretrained weights\n",
    "    n_params = sum(p.numel() for p in m.parameters())\n",
    "    print(f'  {name:25s} -- {n_params:>12,} parameters')"
   ]
  },
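  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fine-tuning recipe from the admonition above can be sketched in a few lines.\n",
    "We keep `weights=None` so nothing is downloaded here -- in real transfer learning\n",
    "you would load pretrained weights instead -- and the 10-class head is an\n",
    "illustrative choice for this sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fine-tuning sketch (structure only): freeze the backbone, replace the head.\n",
    "# weights=None keeps this offline; real transfer learning loads pretrained weights.\n",
    "backbone = torchvision.models.resnet18(weights=None)\n",
    "\n",
    "for p in backbone.parameters():\n",
    "    p.requires_grad = False  # freeze every backbone layer\n",
    "\n",
    "# Replace the final classifier with a fresh, trainable 10-class head\n",
    "backbone.fc = nn.Linear(backbone.fc.in_features, 10)\n",
    "\n",
    "trainable = sum(p.numel() for p in backbone.parameters() if p.requires_grad)\n",
    "total = sum(p.numel() for p in backbone.parameters())\n",
    "print(f'Trainable parameters: {trainable:,} of {total:,}')"
   ]
  },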
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 31.5 Framework Corner: Same CNN in Other Frameworks\n",
    "\n",
    "````{admonition} The Same Architecture in TensorFlow/Keras and JAX/Flax\n",
    ":class: tip dropdown\n",
    "\n",
    "**TensorFlow/Keras:**\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Conv2D(16, 3, activation='relu', input_shape=(28, 28, 1)),\n",
    "    tf.keras.layers.MaxPooling2D(2),\n",
    "    tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
    "    tf.keras.layers.MaxPooling2D(2),\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(10),\n",
    "])\n",
    "model.compile(optimizer='adam',\n",
    "              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    "              metrics=['accuracy'])\n",
    "model.fit(x_train, y_train, epochs=5, batch_size=64)\n",
    "```\n",
    "\n",
    "Note: Keras uses channels-last format `(H, W, C)` by default, while PyTorch uses channels-first `(C, H, W)`.\n",
    "\n",
    "**JAX/Flax:**\n",
    "```python\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from flax import linen as fnn\n",
    "\n",
    "class MNISTCNN(fnn.Module):\n",
    "    @fnn.compact\n",
    "    def __call__(self, x):\n",
    "        x = fnn.Conv(16, (3, 3))(x)\n",
    "        x = fnn.relu(x)\n",
    "        x = fnn.max_pool(x, (2, 2), strides=(2, 2))\n",
    "        x = fnn.Conv(32, (3, 3))(x)\n",
    "        x = fnn.relu(x)\n",
    "        x = fnn.max_pool(x, (2, 2), strides=(2, 2))\n",
    "        x = x.reshape((x.shape[0], -1))\n",
    "        x = fnn.Dense(10)(x)\n",
    "        return x\n",
    "```\n",
    "\n",
    "Flax follows a functional paradigm: parameters are passed explicitly rather than stored in the model object. This makes JAX models pure functions, enabling `jit`, `grad`, and `vmap` transformations.\n",
    "\n",
    "All three frameworks implement the same mathematical operations. The choice between them is primarily about API preference and ecosystem:\n",
    "- **PyTorch**: dominant in research, imperative style\n",
    "- **TensorFlow/Keras**: strong in deployment (TF Lite, TF Serving)\n",
    "- **JAX/Flax**: functional, composable transformations, Google TPU integration\n",
    "````"
   ]
  },
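  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layout difference noted in the Framework Corner is a one-line conversion in\n",
    "practice: a single `permute` turns a PyTorch channels-first batch into the\n",
    "channels-last layout that Keras and Flax expect, and back again:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert between PyTorch's channels-first (N, C, H, W) layout and the\n",
    "# channels-last (N, H, W, C) layout that Keras and Flax use by default.\n",
    "x_nchw = torch.randn(64, 1, 28, 28)\n",
    "x_nhwc = x_nchw.permute(0, 2, 3, 1)   # (64, 28, 28, 1)\n",
    "x_back = x_nhwc.permute(0, 3, 1, 2)   # (64, 1, 28, 28)\n",
    "\n",
    "print(f'{tuple(x_nchw.shape)} -> {tuple(x_nhwc.shape)}')\n",
    "assert torch.equal(x_back, x_nchw)  # the round trip is lossless"
   ]
  },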
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 31.1.** Add a third convolutional layer `Conv2d(32, 64, 3)` with ReLU and\n",
    "max pooling between `conv2` and the fully connected layer. Compute the new flattened\n",
    "dimension by tracing shapes through the network. Does the additional layer improve\n",
    "test accuracy? How many additional parameters does it add?\n",
    "\n",
    "**Exercise 31.2.** Replace `nn.MaxPool2d` with `nn.AvgPool2d` (average pooling) in\n",
    "the CNN. Train for 5 epochs and compare accuracy. Relate the difference to the\n",
    "discussion of pooling strategies in Chapter 23.\n",
    "\n",
    "**Exercise 31.3.** Visualize the **second-layer** feature maps (after `conv2`) for the\n",
    "same sample digit used in Section 31.3. The 32 feature maps of size $11 \\times 11$\n",
    "should show more abstract, higher-level features than the first layer. Display them\n",
    "in a $4 \\times 8$ grid.\n",
    "\n",
    "**Exercise 31.4.** Implement a function `count_parameters(model)` that prints the\n",
    "name, shape, and number of parameters for each layer. Apply it to both the MLP\n",
    "from Chapter 30 and the CNN from this chapter. Which architecture is more\n",
    "parameter-efficient, and why? (Hint: consider weight sharing in convolutional layers.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**References.**\n",
    "\n",
    "- LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. (1998). \"Gradient-Based Learning Applied to Document Recognition.\" *Proceedings of the IEEE*, 86(11), 2278-2324.\n",
    "- Paszke, A., Gross, S., Massa, F., et al. (2019). \"PyTorch: An Imperative Style, High-Performance Deep Learning Library.\" *NeurIPS 2019*.\n",
    "- He, K., Zhang, X., Ren, S., and Sun, J. (2016). \"Deep Residual Learning for Image Recognition.\" *CVPR 2016*.\n",
    "- Hubel, D. H. and Wiesel, T. N. (1962). \"Receptive fields, binocular interaction and functional architecture in the cat's visual cortex.\" *Journal of Physiology*, 160(1), 106-154.\n",
    "- Krizhevsky, A., Sutskever, I., and Hinton, G. E. (2012). \"ImageNet Classification with Deep Convolutional Neural Networks.\" *NeurIPS 2012*."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}