{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 30: DataLoaders, Training Loops, and Evaluation\n",
    "\n",
    "In the previous chapters, we assembled all the ingredients for training neural networks:\n",
    "loss functions that measure prediction quality (Chapter 26), optimizers that update\n",
    "parameters intelligently (Chapter 27), automatic differentiation that computes gradients\n",
    "effortlessly (Chapter 28), and PyTorch's tensor abstractions that bring these pieces\n",
    "together at scale (Chapter 29).\n",
    "\n",
    "Now we combine them into the **training loop** -- the central algorithm of deep learning.\n",
    "Every neural network, from a 9-parameter XOR solver to a 175-billion-parameter GPT,\n",
    "is trained by the same iterative procedure: forward pass, loss computation, backward\n",
    "pass, parameter update. The details change; the structure does not.\n",
    "\n",
    "This chapter makes the training loop explicit, introduces PyTorch's data loading\n",
    "infrastructure, and applies everything to the MNIST handwritten digit recognition\n",
    "benchmark -- the \"hello world\" of deep learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset, random_split\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "# Consistent style for all plots\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 100,\n",
    "    'font.size': 11,\n",
    "    'axes.titlesize': 13,\n",
    "    'axes.labelsize': 12\n",
    "})\n",
    "\n",
    "# Standard color palette\n",
    "BLUE = '#3b82f6'\n",
    "GREEN = '#059669'\n",
    "RED = '#dc2626'\n",
    "AMBER = '#d97706'\n",
    "INDIGO = '#4f46e5'\n",
    "\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "\n",
    "print('PyTorch version:', torch.__version__)\n",
    "print('torchvision version:', torchvision.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 30.1 Training Loop Anatomy\n",
    "\n",
    "Every neural network training procedure follows the same pattern. We state it\n",
    "explicitly as an algorithm, mapping each step to the chapter where we derived\n",
    "the underlying mathematics.\n",
    "\n",
    "```{admonition} Algorithm: The Standard Training Loop\n",
    ":class: important\n",
    "\n",
    "**Input:** Model $f_\\theta$, training data $\\mathcal{D}$, loss function $\\mathcal{L}$,\n",
    "optimizer, number of epochs $E$, batch size $B$.\n",
    "\n",
    "1. **For** epoch $= 1, \\ldots, E$:\n",
    "   1. **Shuffle** $\\mathcal{D}$ and partition into mini-batches of size $B$.\n",
    "   2. **For** each mini-batch $(X_b, y_b)$:\n",
    "      1. **Forward pass:** Compute predictions $\\hat{y}_b = f_\\theta(X_b)$.  *(Ch. 29)*\n",
    "      2. **Loss:** Compute $L = \\mathcal{L}(\\hat{y}_b, y_b)$.  *(Ch. 26)*\n",
    "      3. **Zero gradients:** Set $\\nabla_\\theta = 0$.  *(Ch. 28, 29)*\n",
    "      4. **Backward pass:** Compute $\\nabla_\\theta L$ via autograd.  *(Ch. 28)*\n",
    "      5. **Update:** $\\theta \\leftarrow \\theta - \\eta \\cdot g(\\nabla_\\theta L)$ where $g$ is the optimizer rule.  *(Ch. 27)*\n",
    "   3. **Evaluate** on validation set (optional).\n",
    "2. **Return** trained model $f_\\theta$.\n",
    "```\n",
    "\n",
    "In PyTorch, this translates directly into code:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- The canonical PyTorch training loop (pseudocode made concrete) ---\n",
    "\n",
    "def train_one_epoch(model, dataloader, loss_fn, optimizer):\n",
    "    \"\"\"One pass through the entire training set.\"\"\"\n",
    "    model.train()                                # Set training mode\n",
    "    total_loss = 0.0\n",
    "    n_batches = 0\n",
    "    \n",
    "    for X_batch, y_batch in dataloader:           # Step 1.2: iterate mini-batches\n",
    "        pred = model(X_batch)                     # Step 1.2.1: forward pass\n",
    "        loss = loss_fn(pred, y_batch)             # Step 1.2.2: compute loss\n",
    "        \n",
    "        optimizer.zero_grad()                     # Step 1.2.3: zero gradients\n",
    "        loss.backward()                           # Step 1.2.4: backward pass\n",
    "        optimizer.step()                          # Step 1.2.5: update parameters\n",
    "        \n",
    "        total_loss += loss.item()\n",
    "        n_batches += 1\n",
    "    \n",
    "    return total_loss / n_batches\n",
    "\n",
    "print('train_one_epoch() defined -- maps directly to the algorithm above.')\n",
    "print('Each line corresponds to a step we derived from first principles.')"
   ]
  },
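  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Step 1.3 of the algorithm -- evaluation -- deserves a matching helper. The sketch below\n",
    "is one reasonable way to write it (the name `evaluate` is our own choice, not a PyTorch\n",
    "convention); Section 30.5 explains why `model.eval()` and `torch.no_grad()` matter here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Evaluation counterpart to train_one_epoch() (algorithm step 1.3) ---\n",
    "\n",
    "def evaluate(model, dataloader, loss_fn):\n",
    "    \"\"\"Average loss and accuracy over a dataset, without updating parameters.\"\"\"\n",
    "    model.eval()                                  # Evaluation mode (affects dropout/batchnorm)\n",
    "    total_loss, correct, total = 0.0, 0, 0\n",
    "    with torch.no_grad():                         # No gradients needed for evaluation\n",
    "        for X_batch, y_batch in dataloader:\n",
    "            pred = model(X_batch)                 # Forward pass only\n",
    "            total_loss += loss_fn(pred, y_batch).item() * y_batch.size(0)\n",
    "            correct += (pred.argmax(dim=1) == y_batch).sum().item()\n",
    "            total += y_batch.size(0)\n",
    "    return total_loss / total, correct / total    # (mean loss, accuracy)\n",
    "\n",
    "print('evaluate() defined -- the mirror image of train_one_epoch().')"
   ]
  },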
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 30.2 Dataset and DataLoader\n",
    "\n",
    "Real datasets are too large to fit in a single tensor multiplication. PyTorch's\n",
    "`torch.utils.data` module provides two key abstractions:\n",
    "\n",
    "- **`Dataset`**: Stores samples and their labels. Implements `__len__()` and `__getitem__()`.\n",
    "- **`DataLoader`**: Wraps a `Dataset` to provide iteration, batching, shuffling,\n",
    "  and parallel data loading.\n",
    "\n",
    "```{admonition} Why Mini-Batches?\n",
    ":class: note\n",
    "Recall from Chapter 27 that **stochastic gradient descent** uses a subset of the\n",
    "training data to estimate gradients. Mini-batches provide a favorable trade-off:\n",
    "- **Batch size 1** (pure SGD): very noisy gradients, slow convergence.\n",
    "- **Full batch** (GD): exact gradients, but one step requires processing all data.\n",
    "- **Mini-batch** (typical: 32-256): gradient noise provides implicit regularization,\n",
    "  and matrix operations are efficiently parallelized on modern hardware.\n",
    "```\n",
    "\n",
    "### A Simple Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- DataLoader basics ---\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Create a toy dataset: 100 samples, 5 features\n",
    "rng = np.random.default_rng(42)\n",
    "X_toy = torch.randn(100, 5)\n",
    "y_toy = (X_toy[:, 0] + X_toy[:, 1] > 0).long()  # binary classification\n",
    "\n",
    "dataset = TensorDataset(X_toy, y_toy)\n",
    "print(f'Dataset size: {len(dataset)}')\n",
    "print(f'One sample: X.shape={dataset[0][0].shape}, y={dataset[0][1]}')\n",
    "\n",
    "# DataLoader: batching + shuffling\n",
    "loader = DataLoader(dataset, batch_size=16, shuffle=True)\n",
    "\n",
    "print(f'\\nNumber of batches: {len(loader)}')\n",
    "for i, (X_b, y_b) in enumerate(loader):\n",
    "    if i < 3:\n",
    "        print(f'  Batch {i}: X.shape={X_b.shape}, y.shape={y_b.shape}')\n",
    "    else:\n",
    "        break"
   ]
  },
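  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TensorDataset` covers data that already lives in tensors. For anything else (images on\n",
    "disk, rows in a database), you subclass `Dataset` yourself and implement `__len__()` and\n",
    "`__getitem__()`. A minimal sketch, assuming the data sits in memory as NumPy arrays (the\n",
    "class name `NumpyDataset` is our own, not a PyTorch class):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- A minimal custom Dataset (illustrative sketch) ---\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "class NumpyDataset(Dataset):\n",
    "    \"\"\"Wraps NumPy feature/label arrays in the Dataset interface.\"\"\"\n",
    "    \n",
    "    def __init__(self, X, y):\n",
    "        self.X = torch.as_tensor(X, dtype=torch.float32)\n",
    "        self.y = torch.as_tensor(y, dtype=torch.long)\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.X)                  # number of samples\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]     # one (features, label) pair\n",
    "\n",
    "# It plugs into DataLoader exactly like TensorDataset did above\n",
    "rng = np.random.default_rng(0)\n",
    "custom_ds = NumpyDataset(rng.normal(size=(100, 5)), rng.integers(0, 2, size=100))\n",
    "custom_loader = DataLoader(custom_ds, batch_size=16, shuffle=True)\n",
    "X_b, y_b = next(iter(custom_loader))\n",
    "print(f'Custom dataset batch: X.shape={X_b.shape}, y.shape={y_b.shape}')"
   ]
  },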
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 30.3 Loss Functions and Optimizers\n",
    "\n",
    "PyTorch implements all the loss functions and optimizers we derived in Chapters 26-27\n",
    "as ready-to-use classes.\n",
    "\n",
    "### Loss Functions (Chapter 26 Revisited)\n",
    "\n",
    "| PyTorch Class | Mathematical Form | Use Case |\n",
    "|---------------|-------------------|----------|\n",
    "| `nn.MSELoss()` | $\\frac{1}{n}\\sum(y_i - \\hat{y}_i)^2$ | Regression |\n",
    "| `nn.CrossEntropyLoss()` | $-\\sum y_k \\log \\hat{y}_k$ | Multi-class classification |\n",
    "| `nn.BCEWithLogitsLoss()` | $-[y\\log\\sigma(z) + (1-y)\\log(1-\\sigma(z))]$ | Binary classification |\n",
    "\n",
    "```{admonition} CrossEntropyLoss = LogSoftmax + NLLLoss\n",
    ":class: warning\n",
    "PyTorch's `nn.CrossEntropyLoss` expects **raw logits** (unnormalized scores),\n",
    "not probabilities. It internally applies log-softmax for numerical stability.\n",
    "Do **not** apply softmax before `CrossEntropyLoss` -- this is the most common\n",
    "PyTorch beginner mistake.\n",
    "```\n",
    "\n",
    "### Optimizers (Chapter 27 Revisited)\n",
    "\n",
    "| PyTorch Class | Algorithm | Key Parameters |\n",
    "|---------------|-----------|----------------|\n",
    "| `optim.SGD` | (Stochastic) Gradient Descent | `lr`, `momentum` |\n",
    "| `optim.Adam` | Adaptive Moment Estimation | `lr`, `betas`, `eps` |\n",
    "| `optim.RMSprop` | Root Mean Square Propagation | `lr`, `alpha` |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Loss function demonstration ---\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# CrossEntropyLoss expects raw logits, NOT probabilities\n",
    "logits = torch.tensor([[2.0, 1.0, 0.1],    # sample 1: class 0 has highest score\n",
    "                        [0.5, 2.5, 0.3]])   # sample 2: class 1 has highest score\n",
    "targets = torch.tensor([0, 1])               # correct classes\n",
    "\n",
    "ce_loss = nn.CrossEntropyLoss()\n",
    "loss = ce_loss(logits, targets)\n",
    "print(f'CrossEntropyLoss: {loss.item():.4f}')\n",
    "\n",
    "# Manual verification (Chapter 26 formula)\n",
    "import torch.nn.functional as F\n",
    "log_probs = F.log_softmax(logits, dim=1)\n",
    "manual_loss = -log_probs[0, 0] - log_probs[1, 1]  # negative log-prob of correct classes\n",
    "manual_loss = manual_loss / 2  # mean over batch\n",
    "print(f'Manual computation: {manual_loss.item():.4f}')\n",
    "print(f'Match: {torch.isclose(loss, manual_loss).item()}')"
   ]
  },
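  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The optimizers can be sanity-checked the same way. A minimal sketch (our own toy example,\n",
    "not from the PyTorch docs) comparing one `optim.SGD` step, without momentum, against the\n",
    "manual update rule $\\theta \\leftarrow \\theta - \\eta \\nabla_\\theta L$ from Chapter 27:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Verify that optim.SGD applies theta <- theta - lr * grad ---\n",
    "torch.manual_seed(0)\n",
    "\n",
    "w = torch.randn(3, requires_grad=True)       # a tiny 'model': three parameters\n",
    "lr = 0.1\n",
    "sgd = optim.SGD([w], lr=lr)\n",
    "\n",
    "loss = (w ** 2).sum()                        # simple quadratic loss\n",
    "loss.backward()                              # d(loss)/dw = 2w\n",
    "\n",
    "expected = w.detach() - lr * w.grad          # manual update, computed before the step\n",
    "sgd.step()                                   # optimizer update (in place)\n",
    "\n",
    "print('manual   :', expected.numpy())\n",
    "print('optimizer:', w.detach().numpy())\n",
    "print('match    :', torch.allclose(w.detach(), expected))"
   ]
  },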
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 30.4 MNIST MLP\n",
    "\n",
    "The MNIST dataset (LeCun et al., 1998) consists of 70,000 handwritten digit images,\n",
    "each $28 \\times 28$ pixels in grayscale. It is the standard benchmark for introducing\n",
    "neural network training on real data.\n",
    "\n",
    "```{admonition} Historical Note\n",
    ":class: note\n",
    "MNIST was created by Yann LeCun, Corinna Cortes, and Christopher J.C. Burges\n",
    "at AT&T Bell Labs. The dataset was derived from NIST Special Database 3 (Census\n",
    "Bureau employees) and Special Database 1 (high school students). LeCun used\n",
    "MNIST to demonstrate the effectiveness of convolutional networks in his landmark\n",
    "1998 paper -- we will replicate this in Chapter 31. For now, we use a simple\n",
    "multi-layer perceptron that treats each image as a flat 784-dimensional vector.\n",
    "```\n",
    "\n",
    "### Loading the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Load MNIST ---\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),           # Convert PIL image to tensor [0, 1]\n",
    "    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std\n",
    "])\n",
    "\n",
    "train_dataset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=True, download=True, transform=transform\n",
    ")\n",
    "test_dataset = torchvision.datasets.MNIST(\n",
    "    root='./data', train=False, download=True, transform=transform\n",
    ")\n",
    "\n",
    "print(f'Training samples: {len(train_dataset)}')\n",
    "print(f'Test samples:     {len(test_dataset)}')\n",
    "print(f'Image shape:      {train_dataset[0][0].shape}')\n",
    "print(f'Label example:    {train_dataset[0][1]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Visualize sample digits ---\n",
    "fig, axes = plt.subplots(2, 8, figsize=(12, 3.5))\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    img, label = train_dataset[i]\n",
    "    ax.imshow(img.squeeze(), cmap='gray')\n",
    "    ax.set_title(str(label), fontsize=11)\n",
    "    ax.axis('off')\n",
    "fig.suptitle('MNIST Sample Images', fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Create DataLoaders ---\n",
    "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n",
    "\n",
    "print(f'Training batches per epoch: {len(train_loader)}')\n",
    "print(f'Test batches: {len(test_loader)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the MLP\n",
    "\n",
    "Our architecture: $784 \\to 128 \\to 64 \\to 10$. Each hidden layer uses ReLU\n",
    "activation (Chapter 17). The output layer produces raw logits for 10 classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Define the MLP ---\n",
    "class MNISTMLP(nn.Module):\n",
    "    \"\"\"Multi-layer perceptron for MNIST digit classification.\"\"\"\n",
    "    \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()          # (B, 1, 28, 28) -> (B, 784)\n",
    "        self.layers = nn.Sequential(\n",
    "            nn.Linear(784, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 10),               # 10 classes, raw logits\n",
    "        )\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.flatten(x)\n",
    "        return self.layers(x)\n",
    "\n",
    "torch.manual_seed(42)\n",
    "model = MNISTMLP()\n",
    "print(model)\n",
    "\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "print(f'\\nTotal trainable parameters: {n_params:,}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Train the MLP ---\n",
    "torch.manual_seed(42)\n",
    "model = MNISTMLP()\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "train_losses = []\n",
    "test_accuracies = []\n",
    "n_epochs = 5\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    # Training\n",
    "    model.train()\n",
    "    epoch_loss = 0.0\n",
    "    n_batches = 0\n",
    "    \n",
    "    for X_batch, y_batch in train_loader:\n",
    "        pred = model(X_batch)\n",
    "        loss = loss_fn(pred, y_batch)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        n_batches += 1\n",
    "    \n",
    "    avg_loss = epoch_loss / n_batches\n",
    "    train_losses.append(avg_loss)\n",
    "    \n",
    "    # Evaluation\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for X_batch, y_batch in test_loader:\n",
    "            pred = model(X_batch)\n",
    "            _, predicted = torch.max(pred, 1)\n",
    "            total += y_batch.size(0)\n",
    "            correct += (predicted == y_batch).sum().item()\n",
    "    \n",
    "    accuracy = 100.0 * correct / total\n",
    "    test_accuracies.append(accuracy)\n",
    "    \n",
    "    print(f'Epoch {epoch+1}/{n_epochs} -- '\n",
    "          f'Train Loss: {avg_loss:.4f}, '\n",
    "          f'Test Accuracy: {accuracy:.2f}%')\n",
    "\n",
    "print(f'\\nFinal test accuracy: {test_accuracies[-1]:.2f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Plot training curves ---\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "# Loss curve\n",
    "ax1.plot(range(1, n_epochs + 1), train_losses, 'o-', color=INDIGO, linewidth=2, markersize=8)\n",
    "ax1.set_xlabel('Epoch')\n",
    "ax1.set_ylabel('Training Loss (Cross-Entropy)')\n",
    "ax1.set_title('Training Loss', fontweight='bold')\n",
    "ax1.grid(True, alpha=0.3)\n",
    "ax1.set_xticks(range(1, n_epochs + 1))\n",
    "\n",
    "# Accuracy curve\n",
    "ax2.plot(range(1, n_epochs + 1), test_accuracies, 'o-', color=GREEN, linewidth=2, markersize=8)\n",
    "ax2.set_xlabel('Epoch')\n",
    "ax2.set_ylabel('Test Accuracy (%)')\n",
    "ax2.set_title('Test Accuracy', fontweight='bold')\n",
    "ax2.grid(True, alpha=0.3)\n",
    "ax2.set_xticks(range(1, n_epochs + 1))\n",
    "ax2.set_ylim(90, 100)\n",
    "ax2.axhline(y=97, color=RED, linestyle='--', alpha=0.5, label='97% target')\n",
    "ax2.legend()\n",
    "\n",
    "fig.suptitle('MNIST MLP Training (784-128-64-10)', fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 30.5 Evaluation Best Practices\n",
    "\n",
    "Proper evaluation requires care to avoid subtle bugs and misleading metrics.\n",
    "\n",
    "### model.train() vs. model.eval()\n",
    "\n",
    "Some layers behave differently during training and evaluation:\n",
    "- **Dropout** (Chapter 33, upcoming): randomly zeros activations during training,\n",
    "  scales outputs during evaluation.\n",
    "- **BatchNorm**: uses batch statistics during training, running averages during evaluation.\n",
    "\n",
    "Always call `model.eval()` before evaluation and `model.train()` before resuming training.\n",
    "\n",
    "### torch.no_grad()\n",
    "\n",
    "During evaluation, we do not need gradients. The `torch.no_grad()` context manager\n",
    "disables gradient computation, saving memory and computation.\n",
    "\n",
    "```{admonition} Common Mistake\n",
    ":class: danger\n",
    "Forgetting `model.eval()` or `torch.no_grad()` during evaluation does not cause\n",
    "errors -- it silently produces incorrect results (if the model uses dropout or\n",
    "batchnorm) or wastes memory. Always use both.\n",
    "```\n",
    "\n",
    "### Confusion Matrix\n",
    "\n",
    "A confusion matrix reveals per-class performance -- essential for understanding\n",
    "which digits the model confuses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Compute confusion matrix ---\n",
    "model.eval()\n",
    "all_preds = []\n",
    "all_labels = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for X_batch, y_batch in test_loader:\n",
    "        pred = model(X_batch)\n",
    "        _, predicted = torch.max(pred, 1)\n",
    "        all_preds.extend(predicted.numpy())\n",
    "        all_labels.extend(y_batch.numpy())\n",
    "\n",
    "all_preds = np.array(all_preds)\n",
    "all_labels = np.array(all_labels)\n",
    "\n",
    "# Build confusion matrix manually (no sklearn dependency)\n",
    "n_classes = 10\n",
    "conf_matrix = np.zeros((n_classes, n_classes), dtype=int)\n",
    "for true, pred in zip(all_labels, all_preds):\n",
    "    conf_matrix[true, pred] += 1\n",
    "\n",
    "# Per-class accuracy\n",
    "print('Per-class accuracy:')\n",
    "for digit in range(10):\n",
    "    total = conf_matrix[digit].sum()\n",
    "    correct = conf_matrix[digit, digit]\n",
    "    print(f'  Digit {digit}: {correct}/{total} = {100*correct/total:.1f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Plot confusion matrix ---\n",
    "fig, ax = plt.subplots(figsize=(8, 7))\n",
    "im = ax.imshow(conf_matrix, interpolation='nearest', cmap='Blues')\n",
    "ax.set_title('Confusion Matrix -- MNIST MLP', fontweight='bold', fontsize=14)\n",
    "ax.set_xlabel('Predicted Label')\n",
    "ax.set_ylabel('True Label')\n",
    "\n",
    "# Add text annotations\n",
    "thresh = conf_matrix.max() / 2.0\n",
    "for i in range(n_classes):\n",
    "    for j in range(n_classes):\n",
    "        color = 'white' if conf_matrix[i, j] > thresh else 'black'\n",
    "        ax.text(j, i, str(conf_matrix[i, j]),\n",
    "                ha='center', va='center', color=color, fontsize=8)\n",
    "\n",
    "ax.set_xticks(range(10))\n",
    "ax.set_yticks(range(10))\n",
    "fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Show some misclassified examples ---\n",
    "misclassified_idx = np.where(all_preds != all_labels)[0]\n",
    "n_show = min(12, len(misclassified_idx))\n",
    "\n",
    "fig, axes = plt.subplots(2, 6, figsize=(12, 4))\n",
    "for i, ax in enumerate(axes.flat[:n_show]):\n",
    "    idx = misclassified_idx[i]\n",
    "    img, _ = test_dataset[idx]\n",
    "    ax.imshow(img.squeeze(), cmap='gray')\n",
    "    ax.set_title(f'True: {all_labels[idx]}\\nPred: {all_preds[idx]}',\n",
    "                 fontsize=9, color=RED)\n",
    "    ax.axis('off')\n",
    "\n",
    "fig.suptitle('Misclassified Examples', fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 30.1.** Modify the training loop to record the training loss **per batch**\n",
    "(not per epoch). Plot the loss curve for all batches across all 5 epochs. You should\n",
    "see a noisy but decreasing trend. How does this compare to the per-epoch curve?\n",
    "\n",
    "**Exercise 30.2.** Replace `optim.Adam` with `optim.SGD` (lr=0.01, momentum=0.9)\n",
    "and retrain the MNIST MLP. Compare the final accuracy and the shape of the loss\n",
    "curve with the Adam version. Which optimizer converges faster in terms of epochs?\n",
    "\n",
    "**Exercise 30.3.** Implement a **validation split**: use 50,000 samples for training\n",
    "and 10,000 for validation (from the original 60,000 training set). Use\n",
    "`torch.utils.data.random_split()`. Plot both training and validation loss curves\n",
    "on the same axes. Do you observe any signs of overfitting?\n",
    "\n",
    "**Exercise 30.4.** Experiment with the architecture: try (a) a single hidden layer\n",
    "with 256 units, (b) three hidden layers with 128-64-32 units, and (c) a very\n",
    "shallow network with one hidden layer of 32 units. Report the test accuracy\n",
    "for each. How does depth vs. width affect performance on MNIST?\n",
    "\n",
    "**Exercise 30.5.** Add **dropout** (`nn.Dropout(p=0.2)`) after each ReLU in the MLP.\n",
    "Train for 10 epochs instead of 5. Compare test accuracy with and without dropout.\n",
    "Remember to verify that `model.eval()` disables dropout during evaluation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**References.**\n",
    "\n",
    "- LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. (1998). \"Gradient-Based Learning Applied to Document Recognition.\" *Proceedings of the IEEE*, 86(11), 2278-2324.\n",
    "- Paszke, A., Gross, S., Massa, F., et al. (2019). \"PyTorch: An Imperative Style, High-Performance Deep Learning Library.\" *NeurIPS 2019*.\n",
    "- Kingma, D. P. and Ba, J. (2015). \"Adam: A Method for Stochastic Optimization.\" *ICLR 2015*."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}