{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 29: From Micrograd to PyTorch -- Tensors, Autograd, and nn.Module\n",
    "\n",
    "In the previous chapter, we built a reverse-mode automatic differentiation engine from scratch.\n",
    "Our `Value` class operated on scalars -- each number was individually tracked through a\n",
    "computational graph, and gradients flowed backward one scalar at a time. This was\n",
    "conceptually illuminating but computationally impractical: real neural networks have\n",
    "millions of parameters, and operating on them one-by-one would be absurdly slow.\n",
    "\n",
    "PyTorch extends the same idea to **multi-dimensional arrays** -- tensors -- with GPU\n",
    "acceleration. The gradient tape we built by hand in Chapter 28 is precisely what\n",
    "`torch.autograd` does under the hood, but on tensors of arbitrary shape, with\n",
    "hundreds of optimized backward kernels, and with optional CUDA parallelism.\n",
    "\n",
    "This chapter bridges the gap between our educational micrograd engine and the\n",
    "industrial-strength framework we will use for the rest of the course."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Consistent style for all plots\n",
    "plt.rcParams.update({\n",
    "    'figure.dpi': 100,\n",
    "    'font.size': 11,\n",
    "    'axes.titlesize': 13,\n",
    "    'axes.labelsize': 12\n",
    "})\n",
    "\n",
    "# Standard color palette\n",
    "BLUE = '#3b82f6'\n",
    "GREEN = '#059669'\n",
    "RED = '#dc2626'\n",
    "AMBER = '#d97706'\n",
    "INDIGO = '#4f46e5'\n",
    "\n",
    "print('PyTorch version:', torch.__version__)\n",
    "print('CUDA available:', torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 29.1 Historical Context: The Rise of Deep Learning Frameworks\n",
    "\n",
    "The history of neural network software follows a clear trajectory from manual\n",
    "gradient derivation toward fully automatic, hardware-accelerated differentiation.\n",
    "\n",
    "**Theano (2010).** Developed at the Montreal Institute for Learning Algorithms (MILA)\n",
    "under Yoshua Bengio, Theano was the first widely adopted framework to combine symbolic\n",
    "differentiation with GPU compilation. Users defined computation graphs symbolically,\n",
    "then Theano compiled them into optimized CUDA code. The landmark paper by\n",
    "Bergstra et al. (2010) introduced the paradigm of \"define-then-run\" that would\n",
    "dominate for years.\n",
    "\n",
    "**Caffe (2014).** Yangqing Jia's framework from Berkeley emphasized speed and\n",
    "modularity for convolutional networks. Caffe's `prototxt` configuration files\n",
    "made it easy to define standard architectures without writing code, but this\n",
    "rigidity made experimentation with novel architectures difficult.\n",
    "\n",
    "**TensorFlow (2015).** Google's framework, described by Abadi et al. (2016),\n",
    "adopted Theano's define-then-run paradigm with industrial-scale engineering.\n",
    "Its static graph approach offered deployment advantages but made debugging\n",
    "notoriously painful -- Python served merely as a graph-construction language,\n",
    "with actual execution happening in a separate C++ runtime.\n",
    "\n",
    "**PyTorch (2017).** Paszke et al. introduced a radically different approach:\n",
    "**define-by-run** (also called \"eager execution\"). Instead of building a static\n",
    "graph and then executing it, PyTorch builds the computational graph dynamically\n",
    "as operations execute. This means standard Python control flow (`if`, `for`,\n",
    "`while`) works naturally inside models -- a crucial advantage for research.\n",
    "The framework descended from Torch7 (a Lua-based system) and drew on the\n",
    "ideas of Chainer (2015), which pioneered define-by-run in Python.\n",
    "\n",
    "**JAX (2018).** Google's response to PyTorch, led by Bradbury et al., combined\n",
    "NumPy-compatible syntax with functional transformations (`jit`, `grad`, `vmap`).\n",
    "JAX compiles Python+NumPy programs via XLA, offering both eager and compiled modes.\n",
    "\n",
    "```{admonition} The Convergence\n",
    ":class: note\n",
    "By 2019, TensorFlow added eager execution (TF 2.0), and PyTorch added\n",
    "compilation (`torch.jit`). The frameworks converged toward a common design:\n",
    "eager by default for development, with optional compilation for deployment.\n",
    "As of 2024, PyTorch dominates research (>80% of ML papers) while TensorFlow\n",
    "retains a significant deployment footprint.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Timeline of deep learning frameworks\n",
    "fig, ax = plt.subplots(figsize=(12, 4))\n",
    "\n",
    "frameworks = [\n",
    "    (2010, 'Theano', 'Bergstra et al.\\n(MILA)', BLUE),\n",
    "    (2014, 'Caffe', 'Jia et al.\\n(Berkeley)', GREEN),\n",
    "    (2015, 'TensorFlow', 'Abadi et al.\\n(Google)', RED),\n",
    "    (2015.3, 'Chainer', 'Tokui et al.\\n(Preferred Networks)', AMBER),\n",
    "    (2017, 'PyTorch', 'Paszke et al.\\n(Facebook AI)', INDIGO),\n",
    "    (2018, 'JAX', 'Bradbury et al.\\n(Google)', GREEN),\n",
    "]\n",
    "\n",
    "for i, (year, name, authors, color) in enumerate(frameworks):\n",
    "    ypos = 0.6 if i % 2 == 0 else 0.2\n",
    "    ax.scatter(year, 0.4, s=120, color=color, zorder=5)\n",
    "    ax.annotate(f'{name}\\n({int(year)})',\n",
    "                xy=(year, 0.4), xytext=(year, ypos),\n",
    "                ha='center', va='center', fontsize=10, fontweight='bold',\n",
    "                color=color,\n",
    "                arrowprops=dict(arrowstyle='->', color=color, lw=1.5))\n",
    "    ax.text(year, ypos - 0.12, authors, ha='center', va='top',\n",
    "            fontsize=7, color='gray')\n",
    "\n",
    "ax.axhline(y=0.4, color='lightgray', linewidth=2, zorder=1)\n",
    "ax.set_xlim(2009, 2019.5)\n",
    "ax.set_ylim(-0.1, 0.95)\n",
    "ax.set_title('Timeline of Deep Learning Frameworks', fontsize=14, fontweight='bold')\n",
    "ax.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
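  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make PyTorch's define-by-run idea concrete before moving on, here is a minimal\n",
    "sketch (ours, not taken from the papers above). Because the graph is recorded as the\n",
    "code executes, data-dependent Python control flow differentiates with no special\n",
    "graph operators:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "x = torch.tensor(1.5, requires_grad=True)\n",
    "y = x\n",
    "while y < 10:        # ordinary, data-dependent Python control flow\n",
    "    y = y * 2        # each multiplication is recorded as it runs\n",
    "y.backward()\n",
    "print(x.grad)        # tensor(8.) -- the loop happened to double x three times\n",
    "```"
   ]
  },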
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 29.2 Tensors as Generalized Arrays\n",
    "\n",
    "A **tensor** in PyTorch is a multi-dimensional array, conceptually identical to\n",
    "NumPy's `ndarray` but with two critical additions:\n",
    "\n",
    "1. **Automatic differentiation support** -- tensors can track operations for gradient computation.\n",
    "2. **Device placement** -- tensors can live on CPU or GPU, enabling hardware acceleration.\n",
    "\n",
    "The mathematical terminology is precise: a scalar is a 0-dimensional tensor,\n",
    "a vector is 1-dimensional, a matrix is 2-dimensional, and higher-rank objects\n",
    "are simply called tensors. PyTorch uses the same convention.\n",
    "\n",
    "### Creating Tensors"
   ]
  },
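  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sketch of the rank vocabulary just introduced, `ndim` reports a tensor's\n",
    "rank and `shape` its extent along each dimension (the values below are arbitrary):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "scalar = torch.tensor(3.14)         # 0-dimensional tensor\n",
    "vector = torch.tensor([1.0, 2.0])   # 1-dimensional tensor\n",
    "matrix = torch.zeros(2, 3)          # 2-dimensional tensor\n",
    "print(scalar.ndim, vector.ndim, matrix.ndim)  # 0 1 2\n",
    "print(matrix.shape)                           # torch.Size([2, 3])\n",
    "```"
   ]
  },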
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Tensor creation ---\n",
    "\n",
    "# From Python lists\n",
    "t1 = torch.tensor([1.0, 2.0, 3.0])\n",
    "print(f'From list:    {t1}, dtype={t1.dtype}, shape={t1.shape}')\n",
    "\n",
    "# From NumPy (shares memory -- no copy!)\n",
    "arr = np.array([[1, 2], [3, 4]], dtype=np.float32)\n",
    "t2 = torch.from_numpy(arr)\n",
    "print(f'From NumPy:   {t2.shape}, dtype={t2.dtype}')\n",
    "\n",
    "# Back to NumPy\n",
    "arr_back = t2.numpy()\n",
    "print(f'Back to NumPy: same object? {np.shares_memory(arr, arr_back)}')\n",
    "\n",
    "# Standard constructors\n",
    "t_zeros = torch.zeros(2, 3)\n",
    "t_ones = torch.ones(2, 3)\n",
    "t_rand = torch.randn(2, 3)  # standard normal\n",
    "t_eye = torch.eye(3)\n",
    "\n",
    "print(f'\\nzeros(2,3):\\n{t_zeros}')\n",
    "print(f'\\neye(3):\\n{t_eye}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Types and Device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Data types ---\n",
    "t_int = torch.tensor([1, 2, 3])          # default: int64\n",
    "t_float = torch.tensor([1.0, 2.0, 3.0])  # default: float32\n",
    "t_double = torch.tensor([1.0, 2.0], dtype=torch.float64)\n",
    "\n",
    "print(f'Integer tensor: dtype={t_int.dtype}')\n",
    "print(f'Float tensor:   dtype={t_float.dtype}')\n",
    "print(f'Double tensor:  dtype={t_double.dtype}')\n",
    "\n",
    "# Type casting\n",
    "t_cast = t_int.float()  # int64 -> float32\n",
    "print(f'After .float(): dtype={t_cast.dtype}')\n",
    "\n",
    "# Device (CPU by default, GPU if available)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f'\\nUsing device: {device}')\n",
    "t_device = t_float.to(device)\n",
    "print(f'Tensor on {t_device.device}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Tensor Operations\n",
    "\n",
    "PyTorch tensors support the same broadcasting and vectorized operations as NumPy.\n",
    "Every operation builds a node in the computational graph (when `requires_grad=True`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Basic operations ---\n",
    "a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])\n",
    "b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])\n",
    "\n",
    "print('Element-wise addition:')\n",
    "print(a + b)\n",
    "\n",
    "print('\\nMatrix multiplication:')\n",
    "print(a @ b)\n",
    "\n",
    "print('\\nBroadcasting (matrix + scalar):')\n",
    "print(a + 10)\n",
    "\n",
    "print('\\nReduction (sum along axis 1):')\n",
    "print(a.sum(dim=1))\n",
    "\n",
    "print('\\nReshape:')\n",
    "print(a.view(4))  # flatten\n",
    "print(a.view(1, 4))  # row vector"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 29.3 Autograd on Tensors\n",
    "\n",
    "In Chapter 28, we built a `Value` class that tracked operations on scalars and\n",
    "accumulated gradients via reverse-mode AD. PyTorch's `autograd` does exactly\n",
    "the same thing, but on tensors.\n",
    "\n",
    "The key API:\n",
    "- Set `requires_grad=True` on a tensor to start tracking operations.\n",
    "- Call `.backward()` on a scalar loss to compute all gradients.\n",
    "- Access gradients via the `.grad` attribute.\n",
    "\n",
    "```{admonition} Connection to Chapter 28\n",
    ":class: important\n",
    "Recall that our micrograd `Value` stored `self.grad` and `self._backward`\n",
    "for each node. PyTorch tensors have the same structure: each tensor with\n",
    "`requires_grad=True` stores a `.grad` tensor and a `.grad_fn` pointing to\n",
    "the backward function of the operation that created it.\n",
    "```\n",
    "\n",
    "### Replicating the Micrograd Example\n",
    "\n",
    "In Chapter 28, we computed gradients of $f(x, y) = (x + y) \\cdot y$ at\n",
    "$x = 2, y = 3$. Let us verify that PyTorch produces the same result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Replicating ch28's micrograd example ---\n",
    "\n",
    "# In ch28, we computed: f(x,y) = (x + y) * y\n",
    "# df/dx = y = 3, df/dy = x + 2y = 2 + 6 = 8\n",
    "\n",
    "x = torch.tensor(2.0, requires_grad=True)\n",
    "y = torch.tensor(3.0, requires_grad=True)\n",
    "\n",
    "# Forward pass\n",
    "f = (x + y) * y\n",
    "print(f'f(2, 3) = {f.item():.1f}')\n",
    "\n",
    "# Backward pass\n",
    "f.backward()\n",
    "\n",
    "print(f'df/dx = {x.grad.item():.1f}  (expected: 3.0)')\n",
    "print(f'df/dy = {y.grad.item():.1f}  (expected: 8.0)')\n",
    "\n",
    "# Verify: the grad_fn shows the last operation\n",
    "print(f'\\nf.grad_fn = {f.grad_fn}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{admonition} Exact Match\n",
    ":class: tip\n",
    "The gradients $\\frac{\\partial f}{\\partial x} = 3$ and $\\frac{\\partial f}{\\partial y} = 8$\n",
    "match exactly what our hand-built `Value` class computed in Chapter 28. This is not a coincidence --\n",
    "both implement the same reverse-mode AD algorithm. The difference is that PyTorch's\n",
    "implementation is written in C++ and operates on tensors of arbitrary shape.\n",
    "```\n",
    "\n",
    "### Autograd with Tensors (not just scalars)\n",
    "\n",
    "The real power of PyTorch emerges when we compute gradients of tensor expressions.\n",
    "Consider a simple linear regression loss:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Autograd on tensor operations ---\n",
    "torch.manual_seed(42)\n",
    "\n",
    "# Parameters\n",
    "W = torch.randn(3, 2, requires_grad=True)\n",
    "b = torch.randn(2, requires_grad=True)\n",
    "\n",
    "# Input batch (4 samples, 3 features)\n",
    "X = torch.randn(4, 3)\n",
    "y_true = torch.randn(4, 2)\n",
    "\n",
    "# Forward pass: linear transformation + MSE loss\n",
    "y_pred = X @ W + b             # (4, 2)\n",
    "loss = ((y_pred - y_true)**2).mean()\n",
    "\n",
    "print(f'Loss: {loss.item():.4f}')\n",
    "print(f'W.grad before backward: {W.grad}')\n",
    "\n",
    "# Backward pass\n",
    "loss.backward()\n",
    "\n",
    "print(f'\\nW.grad after backward (shape {W.grad.shape}):')\n",
    "print(W.grad)\n",
    "print(f'\\nb.grad after backward: {b.grad}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{admonition} Important: Zero Gradients\n",
    ":class: warning\n",
    "PyTorch **accumulates** gradients by default. If you call `.backward()` twice\n",
    "without zeroing, gradients will be summed. This is occasionally useful (e.g.,\n",
    "gradient accumulation across mini-batches) but usually a source of bugs.\n",
    "Always call `optimizer.zero_grad()` or manually set `.grad = None` before\n",
    "each backward pass.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Demonstration: gradient accumulation trap ---\n",
    "p = torch.tensor(2.0, requires_grad=True)\n",
    "\n",
    "# First backward\n",
    "loss1 = p ** 2\n",
    "loss1.backward()\n",
    "print(f'After 1st backward: p.grad = {p.grad.item()}')\n",
    "\n",
    "# Second backward WITHOUT zeroing\n",
    "loss2 = p ** 2\n",
    "loss2.backward()\n",
    "print(f'After 2nd backward (accumulated!): p.grad = {p.grad.item()}')\n",
    "\n",
    "# Fix: zero the gradient\n",
    "p.grad = None\n",
    "loss3 = p ** 2\n",
    "loss3.backward()\n",
    "print(f'After zeroing + 3rd backward: p.grad = {p.grad.item()}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 29.4 nn.Module: Building Neural Networks\n",
    "\n",
    "In Chapter 28, we built a `Neuron` class from `Value` objects, then composed\n",
    "neurons into `Layer` and `MLP` classes. PyTorch provides an analogous but more\n",
    "powerful abstraction: `nn.Module`.\n",
    "\n",
    "An `nn.Module` is any differentiable building block:\n",
    "- It has **parameters** (learnable tensors registered via `nn.Parameter`).\n",
    "- It defines a `forward()` method that computes the output.\n",
    "- It automatically collects parameters from sub-modules.\n",
    "\n",
    "```{admonition} The Module Hierarchy\n",
    ":class: note\n",
    "Just as our micrograd `MLP` contained `Layer` objects which contained `Neuron`\n",
    "objects, a PyTorch `nn.Module` can contain other `nn.Module` instances as\n",
    "attributes. The framework automatically discovers all parameters recursively\n",
    "via `model.parameters()`.\n",
    "```\n",
    "\n",
    "### XOR Network as nn.Module\n",
    "\n",
    "Let us solve the XOR problem -- the same task from Chapters 8 and 28 -- using\n",
    "`nn.Module`. We use the same 2-2-1 architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class XORNet(nn.Module):\n",
    "    \"\"\"A 2-2-1 network for XOR, matching the ch28 micrograd architecture.\"\"\"\n",
    "    \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden = nn.Linear(2, 2)   # 2 inputs -> 2 hidden\n",
    "        self.output = nn.Linear(2, 1)   # 2 hidden -> 1 output\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = torch.tanh(self.hidden(x))\n",
    "        x = torch.tanh(self.output(x))\n",
    "        return x\n",
    "\n",
    "# Inspect the model\n",
    "torch.manual_seed(42)\n",
    "model = XORNet()\n",
    "print(model)\n",
    "print(f'\\nTotal parameters: {sum(p.numel() for p in model.parameters())}')\n",
    "print('\\nParameter details:')\n",
    "for name, param in model.named_parameters():\n",
    "    print(f'  {name}: shape={param.shape}, requires_grad={param.requires_grad}')"
   ]
  },
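  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Those `weight` and `bias` entries are `nn.Parameter` tensors that `nn.Linear` created\n",
    "and registered for us. Here is a sketch of the same mechanism by hand -- the class\n",
    "`MyLinear` is ours, for illustration only; assigning an `nn.Parameter` as an attribute\n",
    "is what makes it visible to `parameters()`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class MyLinear(nn.Module):\n",
    "    # A hand-rolled affine layer y = x W^T + b, mirroring nn.Linear.\n",
    "    def __init__(self, in_features, out_features):\n",
    "        super().__init__()\n",
    "        # nn.Parameter marks these tensors as learnable; attribute assignment\n",
    "        # registers them so parameters() can discover them recursively.\n",
    "        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)\n",
    "        self.bias = nn.Parameter(torch.zeros(out_features))\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x @ self.weight.T + self.bias\n",
    "\n",
    "layer = MyLinear(3, 2)\n",
    "print([p.shape for p in layer.parameters()])  # weight (2, 3), then bias (2,)\n",
    "print(layer(torch.randn(4, 3)).shape)         # torch.Size([4, 2])\n",
    "```"
   ]
  },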
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Train XOR network ---\n",
    "torch.manual_seed(42)\n",
    "model = XORNet()\n",
    "\n",
    "# XOR dataset (matching ch28)\n",
    "X_xor = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])\n",
    "Y_xor = torch.tensor([[-1.0], [1.0], [1.0], [-1.0]])  # tanh targets\n",
    "\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.5)\n",
    "losses = []\n",
    "\n",
    "for epoch in range(500):\n",
    "    # Forward pass\n",
    "    pred = model(X_xor)\n",
    "    loss = ((pred - Y_xor) ** 2).mean()\n",
    "    \n",
    "    # Backward pass\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    losses.append(loss.item())\n",
    "\n",
    "# Final predictions\n",
    "with torch.no_grad():\n",
    "    final_pred = model(X_xor)\n",
    "\n",
    "print('XOR Training Results:')\n",
    "print(f'Final loss: {losses[-1]:.6f}')\n",
    "print()\n",
    "for i in range(4):\n",
    "    x_str = f'({X_xor[i, 0]:.0f}, {X_xor[i, 1]:.0f})'\n",
    "    print(f'  Input {x_str} -> pred={final_pred[i, 0]:+.4f}, target={Y_xor[i, 0]:+.0f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Plot XOR training curve ---\n",
    "fig, ax = plt.subplots(figsize=(8, 4))\n",
    "ax.plot(losses, color=INDIGO, linewidth=1.5)\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('MSE Loss')\n",
    "ax.set_title('XOR Training with nn.Module (cf. ch28 micrograd)', fontweight='bold')\n",
    "ax.set_yscale('log')\n",
    "ax.grid(True, alpha=0.3)\n",
    "ax.set_xlim(0, 500)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{admonition} Micrograd vs. PyTorch: Same Algorithm, Different Scale\n",
    ":class: important\n",
    "Compare the training loop above with the one in Chapter 28. The structure is\n",
    "identical: forward pass, loss computation, backward pass, parameter update.\n",
    "The only differences are:\n",
    "1. We operate on **batched tensors** instead of individual `Value` scalars.\n",
    "2. `optimizer.zero_grad()` replaces our manual `p.grad = 0` loop.\n",
    "3. `optimizer.step()` replaces our manual `p.data -= lr * p.grad` loop.\n",
    "4. `torch.no_grad()` context manager replaces our careful avoidance of\n",
    "   gradient tracking during evaluation.\n",
    "```"
   ]
  },
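  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how thin the optimizer abstraction is, here is a sketch of one training step\n",
    "written without `torch.optim`, in the style of the Chapter 28 loop. The toy model and\n",
    "data are placeholders, not the XOR setup above:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "model = nn.Linear(2, 1)                  # placeholder model\n",
    "X, Y = torch.randn(8, 2), torch.randn(8, 1)\n",
    "lr = 0.1\n",
    "\n",
    "loss = ((model(X) - Y) ** 2).mean()      # forward pass\n",
    "loss.backward()                          # backward pass fills .grad\n",
    "\n",
    "with torch.no_grad():                    # the update itself must not be tracked\n",
    "    for p in model.parameters():\n",
    "        p -= lr * p.grad                 # what optimizer.step() does for plain SGD\n",
    "        p.grad = None                    # what optimizer.zero_grad() does\n",
    "```"
   ]
  },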
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 29.5 Common Modules: nn.Linear, nn.ReLU, nn.Sequential\n",
    "\n",
    "PyTorch provides a rich library of pre-built modules. The most fundamental are:\n",
    "\n",
    "| Module | Description | Parameters |\n",
    "|--------|-------------|------------|\n",
    "| `nn.Linear(in, out)` | Affine transformation $y = xW^T + b$ | $W \\in \\mathbb{R}^{\\text{out} \\times \\text{in}}$, $b \\in \\mathbb{R}^{\\text{out}}$ |\n",
    "| `nn.ReLU()` | Rectified linear unit $\\max(0, x)$ | None |\n",
    "| `nn.Tanh()` | Hyperbolic tangent | None |\n",
    "| `nn.Sigmoid()` | Logistic function $\\sigma(x) = \\frac{1}{1 + e^{-x}}$ | None |\n",
    "| `nn.Sequential(...)` | Chain modules in order | Inherited |\n",
    "\n",
    "### XOR with nn.Sequential\n",
    "\n",
    "For simple feed-forward architectures, `nn.Sequential` eliminates the need\n",
    "to write a custom `forward()` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- XOR with nn.Sequential ---\n",
    "torch.manual_seed(42)\n",
    "\n",
    "model_seq = nn.Sequential(\n",
    "    nn.Linear(2, 8),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(8, 1),\n",
    ")\n",
    "\n",
    "print(model_seq)\n",
    "print(f'\\nTotal parameters: {sum(p.numel() for p in model_seq.parameters())}')\n",
    "\n",
    "# XOR data (0/1 targets for ReLU-based network)\n",
    "X_xor = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])\n",
    "Y_xor = torch.tensor([[0.0], [1.0], [1.0], [0.0]])\n",
    "\n",
    "optimizer = torch.optim.Adam(model_seq.parameters(), lr=0.01)\n",
    "loss_fn = nn.MSELoss()\n",
    "\n",
    "for epoch in range(1000):\n",
    "    pred = model_seq(X_xor)\n",
    "    loss = loss_fn(pred, Y_xor)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "# Results\n",
    "with torch.no_grad():\n",
    "    final_pred = model_seq(X_xor)\n",
    "\n",
    "print('\\nXOR with Sequential + ReLU:')\n",
    "for i in range(4):\n",
    "    x_str = f'({X_xor[i, 0]:.0f}, {X_xor[i, 1]:.0f})'\n",
    "    print(f'  Input {x_str} -> pred={final_pred[i, 0]:.4f}, target={Y_xor[i, 0]:.0f}')"
   ]
  },
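  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the first row in the table above, `nn.Linear` really does\n",
    "compute $y = xW^T + b$. A short sketch using the layer's `.weight` and `.bias`\n",
    "attributes:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "lin = nn.Linear(3, 2)\n",
    "x = torch.randn(4, 3)\n",
    "\n",
    "manual = x @ lin.weight.T + lin.bias    # y = x W^T + b, as in the table\n",
    "print(torch.allclose(lin(x), manual))   # True\n",
    "```"
   ]
  },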
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom Module vs. Sequential: When to Use Which\n",
    "\n",
    "```{admonition} Rule of Thumb\n",
    ":class: tip\n",
    "Use `nn.Sequential` for straightforward feed-forward architectures where\n",
    "data flows linearly through layers. Write a custom `nn.Module` when you need:\n",
    "- Skip connections (ResNet)\n",
    "- Multiple inputs or outputs\n",
    "- Conditional computation\n",
    "- Custom logic in the forward pass\n",
    "```\n",
    "\n",
    "### The Forward-Backward Duality\n",
    "\n",
    "Every `nn.Module` defines a `forward()` method. PyTorch's autograd automatically\n",
    "provides the corresponding backward computation. This is the industrial realization\n",
    "of the principle we explored in Chapter 28: if you can evaluate a function, you\n",
    "can differentiate it.\n",
    "\n",
    "The following diagram shows the correspondence between our micrograd building blocks\n",
    "and their PyTorch equivalents:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# --- Comparison table ---\n",
    "fig, ax = plt.subplots(figsize=(10, 4))\n",
    "ax.axis('off')\n",
    "\n",
    "table_data = [\n",
    "    ['Concept', 'Micrograd (Ch. 28)', 'PyTorch'],\n",
    "    ['Differentiable value', 'Value(data)', 'torch.tensor(data, requires_grad=True)'],\n",
    "    ['Gradient storage', 'value.grad', 'tensor.grad'],\n",
    "    ['Backward function', 'value._backward()', 'tensor.grad_fn'],\n",
    "    ['Neuron', 'Neuron(nin)', 'nn.Linear(nin, 1)'],\n",
    "    ['Layer', 'Layer(nin, nout)', 'nn.Linear(nin, nout)'],\n",
    "    ['Network', 'MLP(nin, nouts)', 'nn.Sequential(...)'],\n",
    "    ['Training step', 'p.data -= lr * p.grad', 'optimizer.step()'],\n",
    "    ['Zero gradients', 'p.grad = 0', 'optimizer.zero_grad()'],\n",
    "]\n",
    "\n",
    "table = ax.table(cellText=table_data[1:], colLabels=table_data[0],\n",
    "                 cellLoc='left', loc='center',\n",
    "                 colWidths=[0.22, 0.35, 0.43])\n",
    "table.auto_set_font_size(False)\n",
    "table.set_fontsize(9)\n",
    "table.scale(1.0, 1.6)\n",
    "\n",
    "# Style header\n",
    "for j in range(3):\n",
    "    table[0, j].set_facecolor(INDIGO)\n",
    "    table[0, j].set_text_props(color='white', fontweight='bold')\n",
    "\n",
    "# Alternate row colors\n",
    "for i in range(1, len(table_data)):\n",
    "    color = '#f0f0ff' if i % 2 == 0 else 'white'\n",
    "    for j in range(3):\n",
    "        table[i, j].set_facecolor(color)\n",
    "\n",
    "ax.set_title('Micrograd to PyTorch: Concept Mapping', fontsize=13, fontweight='bold', pad=20)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
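  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a closing illustration of the forward-backward duality, `torch.autograd.grad` can\n",
    "differentiate any Python function built from tensor operations, with no hand-derived\n",
    "formula (the function `f` below is an arbitrary example of ours):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def f(x):\n",
    "    # an arbitrary composite function; we never write its derivative\n",
    "    return (x.sin() + x ** 2).exp().sum()\n",
    "\n",
    "x = torch.randn(5, requires_grad=True)\n",
    "(grad,) = torch.autograd.grad(f(x), x)   # reverse-mode AD does the work\n",
    "print(grad.shape)                        # torch.Size([5])\n",
    "```"
   ]
  },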
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 29.1.** Create a 3D tensor of shape $(2, 3, 4)$ filled with random integers\n",
    "between 0 and 9. Print its `shape`, `dtype`, `device`, and the total number of elements.\n",
    "Convert it to `float32` and verify the dtype changed.\n",
    "\n",
    "**Exercise 29.2.** Using PyTorch autograd, compute the gradient of\n",
    "$f(x) = \\frac{e^{2x}}{(1 + e^{2x})^2}$ at $x = 1$. Verify your answer by comparing\n",
    "with the analytical derivative (note that $f(x) = \\sigma'(2x) \\cdot 2$ where $\\sigma$\n",
    "is the sigmoid function).\n",
    "\n",
    "**Exercise 29.3.** Build a custom `nn.Module` for a network with **skip connections**:\n",
    "the architecture should be $y = \\text{ReLU}(W_2 \\cdot \\text{ReLU}(W_1 x + b_1) + b_2) + x$.\n",
    "This cannot be expressed as `nn.Sequential`. Test it on a random input of shape $(4, 8)$.\n",
    "\n",
    "**Exercise 29.4.** Demonstrate the gradient accumulation issue: create a parameter $\\theta = 3.0$,\n",
    "compute $\\nabla_\\theta (\\theta^3)$ twice without zeroing, and show that the gradient\n",
    "doubles. Then fix it with `param.grad = None`.\n",
    "\n",
    "**Exercise 29.5.** Rewrite the `XORNet` class to use `nn.ReLU` activations instead of\n",
    "`torch.tanh`, with $[0, 1]$ targets instead of $[-1, 1]$. How many hidden units\n",
    "are needed for reliable convergence? Experiment with hidden sizes 2, 4, 8, and 16."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "**References.**\n",
    "\n",
    "- Paszke, A., Gross, S., Massa, F., et al. (2019). \"PyTorch: An Imperative Style, High-Performance Deep Learning Library.\" *Advances in Neural Information Processing Systems 32*.\n",
    "- Bergstra, J., Breuleux, O., Bastien, F., et al. (2010). \"Theano: A CPU and GPU Math Compiler in Python.\" *Proc. SciPy 2010*.\n",
    "- Abadi, M., Barham, P., Chen, J., et al. (2016). \"TensorFlow: A System for Large-Scale Machine Learning.\" *12th USENIX Symposium on OSDI*.\n",
    "- Bradbury, J., Frostig, R., Hawkins, P., et al. (2018). \"JAX: Composable transformations of Python+NumPy programs.\" *GitHub repository*.\n",
    "- Karpathy, A. (2020). \"micrograd: A tiny scalar-valued autograd engine.\" *GitHub repository*."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}