{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 15: Gradient Descent Foundations\n",
    "\n",
    "\n",
    "Having established the limitations of Hebbian learning (Part 4), we now turn to\n",
    "**supervised learning**, where the goal is to minimize a loss function that measures\n",
    "the discrepancy between the network's output and a target. This chapter develops\n",
    "the mathematical foundations of **gradient descent**, the optimization algorithm\n",
    "that underlies all of modern deep learning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 15.1 Optimization as a Framework\n",
    "\n",
    "### The Supervised Learning Setup\n",
    "\n",
    "Given:\n",
    "- A training set $\\{(\\mathbf{x}^{(1)}, \\mathbf{y}^{(1)}), \\ldots, (\\mathbf{x}^{(N)}, \\mathbf{y}^{(N)})\\}$\n",
    "- A parameterized model $f(\\mathbf{x}; \\mathbf{w})$ with parameters $\\mathbf{w}$\n",
    "- A loss function $L(\\hat{\\mathbf{y}}, \\mathbf{y})$ measuring prediction quality\n",
    "\n",
    "The **empirical risk minimization** problem is:\n",
    "\n",
    "$$\\min_{\\mathbf{w}} \\; \\mathcal{L}(\\mathbf{w}) = \\frac{1}{N} \\sum_{i=1}^{N} L\\bigl(f(\\mathbf{x}^{(i)}; \\mathbf{w}),\\; \\mathbf{y}^{(i)}\\bigr)$$\n",
    "\n",
    "This frames learning as an **optimization problem**. The model \"learns\" by finding\n",
    "parameters that minimize the average loss over the training data."
   ]
  },
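  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the objective concrete, the following sketch evaluates the empirical risk of a\n",
    "linear model $f(\\mathbf{x}; \\mathbf{w}) = \\mathbf{w}^\\top \\mathbf{x}$ under squared-error loss.\n",
    "The toy data, the linear model, and the choice of loss are illustrative assumptions, not part\n",
    "of the framework itself."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch: empirical risk L(w) = (1/N) sum_i loss(f(x_i; w), y_i) for a linear\n",
    "# model with squared-error loss. The toy data below is an illustrative assumption.\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(100, 3))                     # N = 100 inputs, 3 features\n",
    "w_true = np.array([1.0, -2.0, 0.5])\n",
    "y = X @ w_true + rng.normal(scale=0.1, size=100)  # noisy linear targets\n",
    "\n",
    "def empirical_risk(w, X, y):\n",
    "    \"\"\"Average squared-error loss over the training set.\"\"\"\n",
    "    residuals = X @ w - y\n",
    "    return 0.5 * np.mean(residuals**2)\n",
    "\n",
    "# The risk is small near the generating parameters and larger elsewhere\n",
    "print(f\"L(w_true) = {empirical_risk(w_true, X, y):.4f}\")\n",
    "print(f\"L(zeros)  = {empirical_risk(np.zeros(3), X, y):.4f}\")"
   ]
  },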
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 15.2 The Gradient\n",
    "\n",
    "### Definition\n",
    "\n",
    "For a differentiable function $\\mathcal{L}: \\mathbb{R}^n \\to \\mathbb{R}$, the **gradient**\n",
    "at point $\\mathbf{w}$ is the vector of partial derivatives:\n",
    "\n",
    "$$\\nabla \\mathcal{L}(\\mathbf{w}) = \\begin{pmatrix} \\frac{\\partial \\mathcal{L}}{\\partial w_1} \\\\ \\frac{\\partial \\mathcal{L}}{\\partial w_2} \\\\ \\vdots \\\\ \\frac{\\partial \\mathcal{L}}{\\partial w_n} \\end{pmatrix}$$\n",
    "\n",
    "### Key Property\n",
    "\n",
    "The gradient points in the **direction of steepest ascent** of $\\mathcal{L}$.\n",
    "\n",
    "**Proof**: Consider the directional derivative of $\\mathcal{L}$ in direction $\\mathbf{v}$\n",
    "($\\|\\mathbf{v}\\| = 1$):\n",
    "\n",
    "$$D_{\\mathbf{v}} \\mathcal{L} = \\nabla \\mathcal{L}^\\top \\mathbf{v} = \\|\\nabla \\mathcal{L}\\| \\cos\\theta$$\n",
    "\n",
    "where $\\theta$ is the angle between $\\nabla \\mathcal{L}$ and $\\mathbf{v}$.\n",
    "This is maximized when $\\theta = 0$, i.e., $\\mathbf{v} = \\nabla \\mathcal{L} / \\|\\nabla \\mathcal{L}\\|$. $\\blacksquare$\n",
    "\n",
    "**Corollary**: The direction of steepest **descent** is $-\\nabla \\mathcal{L}$."
   ]
  },
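  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both the gradient formula and the steepest-ascent property can be checked numerically.\n",
    "The sketch below compares the analytic gradient of an arbitrary test function (an\n",
    "illustrative choice, not one of the chapter's loss functions) against central finite\n",
    "differences, then verifies that no random unit direction yields a larger directional\n",
    "derivative than the normalized gradient."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch: check an analytic gradient with central finite differences, then check\n",
    "# the steepest-ascent property. The test function is an illustrative choice.\n",
    "def f(w):\n",
    "    return w[0]**2 + 3*w[1]**2 + np.sin(w[0]*w[1])\n",
    "\n",
    "def grad_f(w):\n",
    "    return np.array([2*w[0] + w[1]*np.cos(w[0]*w[1]),\n",
    "                     6*w[1] + w[0]*np.cos(w[0]*w[1])])\n",
    "\n",
    "def numerical_grad(func, w, eps=1e-6):\n",
    "    g = np.zeros_like(w)\n",
    "    for i in range(len(w)):\n",
    "        e = np.zeros_like(w)\n",
    "        e[i] = eps\n",
    "        g[i] = (func(w + e) - func(w - e)) / (2 * eps)  # central difference\n",
    "    return g\n",
    "\n",
    "w = np.array([0.7, -1.2])\n",
    "print(\"analytic: \", grad_f(w))\n",
    "print(\"numerical:\", numerical_grad(f, w))\n",
    "\n",
    "# D_v f = grad . v is largest along the normalized gradient, where it equals ||grad||\n",
    "rng = np.random.default_rng(0)\n",
    "dirs = rng.normal(size=(1000, 2))\n",
    "dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)  # 1000 random unit vectors\n",
    "g = grad_f(w)\n",
    "print(\"best of 1000 random unit directions:\", (dirs @ g).max())\n",
    "print(\"along the normalized gradient:      \", np.linalg.norm(g))"
   ]
  },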
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 15.3 The Gradient Descent Update Rule\n",
    "\n",
    "To minimize $\\mathcal{L}(\\mathbf{w})$, we iteratively move in the direction of steepest descent:\n",
    "\n",
    "```{admonition} Algorithm (Gradient Descent)\n",
    ":class: important\n",
    "\n",
    "**Input:** Initial parameters $\\mathbf{w}_0$, learning rate $\\eta > 0$, number of iterations $T$\n",
    "\n",
    "**For** $t = 0, 1, \\ldots, T-1$:\n",
    "\n",
    "$$\\boxed{\\mathbf{w}_{t+1} = \\mathbf{w}_t - \\eta \\, \\nabla \\mathcal{L}(\\mathbf{w}_t)}$$\n",
    "\n",
    "**Output:** Final parameters $\\mathbf{w}_T$\n",
    "\n",
    "The learning rate $\\eta > 0$ controls the step size. At each iteration, compute the full gradient over **all** training examples and take a step in the negative gradient direction.\n",
    "```\n",
    "\n",
    "### Intuition\n",
    "\n",
    "Imagine standing on a mountainside in fog. You can feel the slope beneath your feet\n",
    "(the gradient) but cannot see the valley. Gradient descent says: take a step downhill\n",
    "proportional to the steepness of the slope."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "\n",
    "# Visualize gradient descent on two 2D functions\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(16, 7))\n",
    "\n",
    "# ---- Function 1: Quadratic Bowl ----\n",
    "# L(w1, w2) = w1^2 + 5*w2^2\n",
    "def quadratic(w):\n",
    "    return w[0]**2 + 5*w[1]**2\n",
    "\n",
    "def grad_quadratic(w):\n",
    "    return np.array([2*w[0], 10*w[1]])\n",
    "\n",
    "# Grid for contour\n",
    "w1 = np.linspace(-5, 5, 200)\n",
    "w2 = np.linspace(-3, 3, 200)\n",
    "W1, W2 = np.meshgrid(w1, w2)\n",
    "Z = W1**2 + 5*W2**2\n",
    "\n",
    "axes[0].contour(W1, W2, Z, levels=20, cmap='viridis', alpha=0.6)\n",
    "axes[0].set_xlabel('$w_1$', fontsize=12)\n",
    "axes[0].set_ylabel('$w_2$', fontsize=12)\n",
    "axes[0].set_title('Quadratic: $L = w_1^2 + 5w_2^2$', fontsize=13)\n",
    "\n",
    "# Gradient descent\n",
    "eta = 0.08\n",
    "w = np.array([4.0, 2.5])\n",
    "trajectory = [w.copy()]\n",
    "for _ in range(30):\n",
    "    g = grad_quadratic(w)\n",
    "    w = w - eta * g\n",
    "    trajectory.append(w.copy())\n",
    "trajectory = np.array(trajectory)\n",
    "\n",
    "axes[0].plot(trajectory[:, 0], trajectory[:, 1], 'ro-', markersize=4, linewidth=1.5, label=f'GD ($\\\\eta={eta}$)')\n",
    "axes[0].plot(0, 0, 'g*', markersize=15, label='Minimum')\n",
    "axes[0].legend(fontsize=10)\n",
    "axes[0].grid(True, alpha=0.2)\n",
    "\n",
    "# ---- Function 2: Rosenbrock ----\n",
    "# L(w1, w2) = (1 - w1)^2 + 100*(w2 - w1^2)^2\n",
    "def rosenbrock(w):\n",
    "    return (1 - w[0])**2 + 100*(w[1] - w[0]**2)**2\n",
    "\n",
    "def grad_rosenbrock(w):\n",
    "    dw1 = -2*(1 - w[0]) + 200*(w[1] - w[0]**2)*(-2*w[0])\n",
    "    dw2 = 200*(w[1] - w[0]**2)\n",
    "    return np.array([dw1, dw2])\n",
    "\n",
    "w1r = np.linspace(-2, 2, 300)\n",
    "w2r = np.linspace(-1, 3, 300)\n",
    "W1R, W2R = np.meshgrid(w1r, w2r)\n",
    "ZR = (1 - W1R)**2 + 100*(W2R - W1R**2)**2\n",
    "\n",
    "axes[1].contour(W1R, W2R, np.log1p(ZR), levels=30, cmap='viridis', alpha=0.6)\n",
    "axes[1].set_xlabel('$w_1$', fontsize=12)\n",
    "axes[1].set_ylabel('$w_2$', fontsize=12)\n",
    "axes[1].set_title('Rosenbrock: $L = (1-w_1)^2 + 100(w_2-w_1^2)^2$', fontsize=13)\n",
    "\n",
    "# Gradient descent on Rosenbrock\n",
    "eta_r = 0.001\n",
    "w = np.array([-1.5, 2.0])\n",
    "trajectory_r = [w.copy()]\n",
    "for _ in range(5000):\n",
    "    g = grad_rosenbrock(w)\n",
    "    # Clip gradient for stability\n",
    "    g = np.clip(g, -10, 10)\n",
    "    w = w - eta_r * g\n",
    "    trajectory_r.append(w.copy())\n",
    "trajectory_r = np.array(trajectory_r)\n",
    "\n",
    "# Plot every 50th point for clarity\n",
    "axes[1].plot(trajectory_r[::50, 0], trajectory_r[::50, 1], 'ro-', markersize=3, linewidth=1, label=f'GD ($\\\\eta={eta_r}$)')\n",
    "axes[1].plot(1, 1, 'g*', markersize=15, label='Minimum (1,1)')\n",
    "axes[1].legend(fontsize=10)\n",
    "axes[1].grid(True, alpha=0.2)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('gradient_descent_2d.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(f\"Quadratic: final point = ({trajectory[-1, 0]:.4f}, {trajectory[-1, 1]:.4f}), loss = {quadratic(trajectory[-1]):.6f}\")\n",
    "print(f\"Rosenbrock: final point = ({trajectory_r[-1, 0]:.4f}, {trajectory_r[-1, 1]:.4f}), loss = {rosenbrock(trajectory_r[-1]):.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 15.4 Convex vs Non-Convex Optimization\n\n### Convex Functions\n\nA function $\\mathcal{L}$ is **convex** if for all $\\mathbf{w}_1, \\mathbf{w}_2$ and $\\lambda \\in [0,1]$:\n\n$$\\mathcal{L}(\\lambda \\mathbf{w}_1 + (1-\\lambda)\\mathbf{w}_2) \\leq \\lambda \\mathcal{L}(\\mathbf{w}_1) + (1-\\lambda) \\mathcal{L}(\\mathbf{w}_2)$$\n\n**For convex functions**: every local minimum is a global minimum, and gradient descent\nconverges to the global minimum (with appropriate learning rate).\n\n**Examples**: Linear regression with MSE loss is convex.\n\n### Non-Convex Functions\n\nNeural network loss landscapes are almost always **non-convex**:\n- Multiple local minima\n- Saddle points (critical points where some directions curve up and others down)\n- Flat regions (plateaus)\n\nGradient descent on non-convex functions:\n- May converge to local minima (not global)\n- May get stuck at saddle points\n- Solution depends on initialization\n\n**Remarkable empirical finding**: Despite non-convexity, gradient descent works well on\nneural networks in practice. Current understanding suggests that most local minima in\nhigh-dimensional loss landscapes are nearly as good as the global minimum\n(Choromanska et al., 2015).\n\n```{danger}\n❗ **Local minima and saddle points** -- gradient descent finds LOCAL not GLOBAL minima. For non-convex loss surfaces (which neural networks have), there is NO guarantee of finding the optimal solution. The algorithm's final point depends entirely on initialization and the learning rate.\n```\n\n```{tip}\nThe loss landscape of neural networks is surprisingly benign -- most local minima are nearly as good as the global minimum. Recent research (Choromanska et al., 2015; Li et al., 2018) suggests that in high dimensions, saddle points are more problematic than local minima, and most critical points with high loss are saddle points, not local minima.\n```"
  },
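  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The defining inequality can be tested empirically: sample random pairs\n",
    "$\\mathbf{w}_1, \\mathbf{w}_2$ and weights $\\lambda$, and look for violations. A minimal\n",
    "sketch follows; the two test functions are illustrative choices, and the absence of\n",
    "violations is evidence of convexity, not a proof."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch: test L(lam*a + (1-lam)*b) <= lam*L(a) + (1-lam)*L(b) on random pairs.\n",
    "# The two test functions below are illustrative choices.\n",
    "def is_convex_empirically(f, dim=2, trials=10_000, seed=0):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    for _ in range(trials):\n",
    "        a, b = rng.normal(size=dim), rng.normal(size=dim)\n",
    "        lam = rng.uniform()\n",
    "        lhs = f(lam*a + (1 - lam)*b)\n",
    "        rhs = lam*f(a) + (1 - lam)*f(b)\n",
    "        if lhs > rhs + 1e-12:    # tolerance for floating-point noise\n",
    "            return False         # found a violating pair: not convex\n",
    "    return True                  # no violations found (evidence, not proof)\n",
    "\n",
    "quadratic = lambda w: w[0]**2 + 5*w[1]**2            # convex\n",
    "wavy = lambda w: w[0]**2 + np.sin(3*w[0]) * w[1]**2  # non-convex\n",
    "\n",
    "print(\"quadratic passes the convexity test:\", is_convex_empirically(quadratic))\n",
    "print(\"wavy passes the convexity test:     \", is_convex_empirically(wavy))"
   ]
  },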
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 15.5 The Learning Rate\n",
    "\n",
    "The learning rate $\\eta$ is the most critical hyperparameter in gradient descent.\n",
    "\n",
    "### Too Large\n",
    "\n",
    "If $\\eta$ is too large, the updates overshoot the minimum, potentially causing **divergence**:\n",
    "the loss oscillates or increases without bound.\n",
    "\n",
    "### Too Small\n",
    "\n",
    "If $\\eta$ is too small, convergence is extremely slow. The algorithm may take an impractical\n",
    "number of iterations to reach an acceptable solution.\n",
    "\n",
    "### The Goldilocks Zone\n",
    "\n",
    "For a quadratic loss $\\mathcal{L}(\\mathbf{w}) = \\frac{1}{2}\\mathbf{w}^\\top \\mathbf{H} \\mathbf{w}$\n",
    "with Hessian $\\mathbf{H}$ having eigenvalues $\\lambda_{\\min}$ and $\\lambda_{\\max}$:\n",
    "\n",
    "- Convergence requires $\\eta < 2/\\lambda_{\\max}$\n",
    "- Optimal learning rate: $\\eta^* = 2/(\\lambda_{\\min} + \\lambda_{\\max})$\n",
    "- Convergence rate depends on the **condition number** $\\kappa = \\lambda_{\\max}/\\lambda_{\\min}$\n",
    "\n",
    "```{warning}\n",
    "**Learning rate sensitivity** -- Too large: the optimization diverges (loss explodes or oscillates wildly). Too small: convergence is extremely slow, potentially requiring millions of iterations. There is no universal \"best\" learning rate; it depends on the loss landscape and must be tuned carefully.\n",
    "```"
   ]
  },
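  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the ill-conditioned quadratic used in the next cell,\n",
    "$\\mathcal{L} = w_1^2 + 10 w_2^2 = \\frac{1}{2}\\mathbf{w}^\\top \\mathbf{H} \\mathbf{w}$ with\n",
    "$\\mathbf{H} = \\mathrm{diag}(2, 20)$, these bounds are easy to compute, as the short sketch\n",
    "below shows. Note how $\\eta = 0.09$ in the next cell sits just below\n",
    "$\\eta^* = 2/22 \\approx 0.091$, while $\\eta = 0.105$ exceeds the stability bound\n",
    "$2/\\lambda_{\\max} = 0.1$ and diverges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch: stability and optimal learning rate for the quadratic in the next cell,\n",
    "# L(w) = w1^2 + 10*w2^2 = (1/2) w^T H w with Hessian H = diag(2, 20).\n",
    "H = np.diag([2.0, 20.0])\n",
    "eigs = np.linalg.eigvalsh(H)\n",
    "lam_min, lam_max = eigs[0], eigs[-1]\n",
    "\n",
    "print(f\"lambda_min = {lam_min}, lambda_max = {lam_max}\")\n",
    "print(f\"stability bound: eta < 2/lambda_max = {2/lam_max:.4f}\")\n",
    "print(f\"optimal eta* = 2/(lambda_min + lambda_max) = {2/(lam_min + lam_max):.4f}\")\n",
    "print(f\"condition number kappa = {lam_max/lam_min:.1f}\")"
   ]
  },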
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Effect of learning rate: animate trajectory for different eta values\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def quadratic_loss(w):\n",
    "    return w[0]**2 + 10*w[1]**2\n",
    "\n",
    "def quadratic_grad(w):\n",
    "    return np.array([2*w[0], 20*w[1]])\n",
    "\n",
    "etas = [0.005, 0.04, 0.09, 0.105]\n",
    "labels = ['Too small ($\\\\eta=0.005$)', 'Good ($\\\\eta=0.04$)',\n",
    "          'Fast ($\\\\eta=0.09$)', 'Too large ($\\\\eta=0.105$)']\n",
    "colors = ['blue', 'green', 'orange', 'red']\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
    "\n",
    "# Contour plot with trajectories\n",
    "w1 = np.linspace(-5, 5, 200)\n",
    "w2 = np.linspace(-3, 3, 200)\n",
    "W1, W2 = np.meshgrid(w1, w2)\n",
    "Z = W1**2 + 10*W2**2\n",
    "\n",
    "axes[0].contour(W1, W2, Z, levels=20, cmap='gray', alpha=0.4)\n",
    "\n",
    "all_losses = []\n",
    "for eta, label, color in zip(etas, labels, colors):\n",
    "    w = np.array([4.0, 2.5])\n",
    "    traj = [w.copy()]\n",
    "    losses = [quadratic_loss(w)]\n",
    "    for _ in range(50):\n",
    "        g = quadratic_grad(w)\n",
    "        w = w - eta * g\n",
    "        traj.append(w.copy())\n",
    "        losses.append(quadratic_loss(w))\n",
    "        if np.linalg.norm(w) > 100:\n",
    "            break\n",
    "    traj = np.array(traj)\n",
    "    axes[0].plot(traj[:, 0], traj[:, 1], 'o-', color=color, markersize=3,\n",
    "                 linewidth=1.5, label=label, alpha=0.8)\n",
    "    all_losses.append(losses)\n",
    "\n",
    "axes[0].plot(0, 0, 'k*', markersize=15)\n",
    "axes[0].set_xlabel('$w_1$', fontsize=12)\n",
    "axes[0].set_ylabel('$w_2$', fontsize=12)\n",
    "axes[0].set_title('Trajectories for Different Learning Rates', fontsize=13)\n",
    "axes[0].legend(fontsize=9)\n",
    "axes[0].set_xlim(-6, 6)\n",
    "axes[0].set_ylim(-4, 4)\n",
    "axes[0].grid(True, alpha=0.2)\n",
    "\n",
    "# Loss curves\n",
    "for losses, label, color in zip(all_losses, labels, colors):\n",
    "    axes[1].plot(losses, color=color, linewidth=2, label=label)\n",
    "\n",
    "axes[1].set_xlabel('Iteration', fontsize=12)\n",
    "axes[1].set_ylabel('Loss', fontsize=12)\n",
    "axes[1].set_title('Loss vs Iteration', fontsize=13)\n",
    "axes[1].set_yscale('log')\n",
    "axes[1].set_ylim(1e-4, 1e4)\n",
    "axes[1].legend(fontsize=9)\n",
    "axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('learning_rate_effect.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": "## 15.6 Loss Functions\n\n### Mean Squared Error (MSE)\n\nFor regression tasks:\n\n$$L_{\\text{MSE}} = \\frac{1}{2} \\|\\hat{\\mathbf{y}} - \\mathbf{y}\\|^2 = \\frac{1}{2} \\sum_{j} (\\hat{y}_j - y_j)^2$$\n\nGradient with respect to the prediction:\n\n$$\\frac{\\partial L_{\\text{MSE}}}{\\partial \\hat{y}_j} = \\hat{y}_j - y_j$$\n\nThe factor of $1/2$ is a convention that simplifies the gradient.\n\n### Cross-Entropy Loss\n\nFor classification with softmax output:\n\n$$L_{\\text{CE}} = -\\sum_{j} y_j \\log \\hat{y}_j$$\n\nwhere $\\mathbf{y}$ is a one-hot target vector and $\\hat{\\mathbf{y}}$ is the predicted\nprobability distribution.\n\nGradient:\n\n$$\\frac{\\partial L_{\\text{CE}}}{\\partial \\hat{y}_j} = -\\frac{y_j}{\\hat{y}_j}$$\n\nFor the logits $z_j$ before the softmax, the chain rule gives $\\partial L/\\partial z_j = \\hat{y}_j - y_j$. This is different from $\\partial L/\\partial \\hat{y}_j = -y_j/\\hat{y}_j$, so it is important to keep track of which variable is being differentiated.\n\n### Why Cross-Entropy for Classification?\n\n1. **Information-theoretic interpretation**: CE measures the extra bits needed to encode\n   the true distribution using the predicted distribution.\n2. **Stronger gradients**: With MSE + sigmoid, gradients can be very small when the\n   prediction is confidently wrong. CE avoids this.\n3. **Maximum likelihood**: Minimizing CE is equivalent to maximizing the log-likelihood."
  },
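  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The identity $\\partial L/\\partial z_j = \\hat{y}_j - y_j$ is easy to verify numerically.\n",
    "A sketch with central finite differences follows; the logits and the one-hot target are\n",
    "arbitrary illustrative values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch: verify that for softmax + cross-entropy, dL/dz = softmax(z) - y.\n",
    "# The logits and target below are arbitrary illustrative values.\n",
    "def softmax(z):\n",
    "    e = np.exp(z - z.max())  # subtract max for numerical stability\n",
    "    return e / e.sum()\n",
    "\n",
    "def cross_entropy(z, y):\n",
    "    return -np.sum(y * np.log(softmax(z)))\n",
    "\n",
    "z = np.array([2.0, -1.0, 0.5])   # logits\n",
    "y = np.array([0.0, 1.0, 0.0])    # one-hot target\n",
    "\n",
    "analytic = softmax(z) - y        # gradient with respect to the logits\n",
    "\n",
    "eps = 1e-6\n",
    "numeric = np.zeros_like(z)\n",
    "for j in range(len(z)):\n",
    "    e = np.zeros_like(z)\n",
    "    e[j] = eps\n",
    "    numeric[j] = (cross_entropy(z + e, y) - cross_entropy(z - e, y)) / (2 * eps)\n",
    "\n",
    "print(\"analytic: \", analytic)\n",
    "print(\"numerical:\", numeric)"
   ]
  },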
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 15.7 Batch, Stochastic, and Mini-Batch Gradient Descent\n",
    "\n",
    "### Batch (Full) Gradient Descent\n",
    "\n",
    "Compute the gradient using the **entire** training set:\n",
    "\n",
    "$$\\mathbf{w}_{t+1} = \\mathbf{w}_t - \\eta \\frac{1}{N} \\sum_{i=1}^{N} \\nabla L_i(\\mathbf{w}_t)$$\n",
    "\n",
    "**Pros**: Exact gradient, smooth convergence. **Cons**: Expensive for large $N$.\n",
    "\n",
    "```{admonition} Algorithm (Stochastic Gradient Descent)\n",
    ":class: important\n",
    "\n",
    "**Input:** Initial parameters $\\mathbf{w}_0$, learning rate $\\eta > 0$, training set $\\{(\\mathbf{x}^{(i)}, \\mathbf{y}^{(i)})\\}_{i=1}^{N}$\n",
    "\n",
    "**For** $t = 0, 1, 2, \\ldots$:\n",
    "1. Sample $i \\sim \\text{Uniform}(1, N)$\n",
    "2. Compute the gradient on the single sample: $\\mathbf{g}_t = \\nabla L_i(\\mathbf{w}_t)$\n",
    "3. Update:\n",
    "\n",
    "$$\\boxed{\\mathbf{w}_{t+1} = \\mathbf{w}_t - \\eta \\, \\nabla L_i(\\mathbf{w}_t) \\quad \\text{where } i \\sim \\text{Uniform}(1, N)}$$\n",
    "\n",
    "**Key property:** $\\mathbb{E}[\\mathbf{g}_t] = \\nabla \\mathcal{L}(\\mathbf{w}_t)$ -- the stochastic gradient is an **unbiased estimator** of the true gradient.\n",
    "\n",
    "**Pros:** Fast updates, can escape local minima (due to noise).  \n",
    "**Cons:** Noisy gradient, variance in updates.\n",
    "```\n",
    "\n",
    "### Mini-Batch Gradient Descent\n",
    "\n",
    "Use a **subset** (mini-batch) of $B$ samples:\n",
    "\n",
    "$$\\mathbf{w}_{t+1} = \\mathbf{w}_t - \\eta \\frac{1}{B} \\sum_{i \\in \\mathcal{B}} \\nabla L_i(\\mathbf{w}_t)$$\n",
    "\n",
    "where $|\\mathcal{B}| = B$ is the **batch size**.\n",
    "\n",
    "**In practice**: Mini-batch is the standard choice. Typical $B \\in \\{32, 64, 128, 256\\}$.\n",
    "\n",
    "```{tip}\n",
    "**Mini-batch as practical compromise** -- Mini-batch gradient descent combines the best of both worlds: it is more computationally efficient than full batch GD (processes only $B \\ll N$ samples per step), while having much lower variance than pure SGD ($B > 1$ samples reduce noise). It also exploits hardware parallelism on GPUs, making it the default choice in practice.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Compare Batch GD vs SGD vs Mini-batch on linear regression\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# Generate data: y = 3*x + 2 + noise\n",
    "N = 200\n",
    "X_data = np.random.randn(N, 1)\n",
    "y_data = 3 * X_data + 2 + np.random.randn(N, 1) * 0.5\n",
    "\n",
    "# Add bias column\n",
    "X_aug = np.hstack([X_data, np.ones((N, 1))])  # [x, 1]\n",
    "\n",
    "def mse_loss(X, y, w):\n",
    "    pred = X @ w\n",
    "    return 0.5 * np.mean((pred - y)**2)\n",
    "\n",
    "def mse_grad(X, y, w):\n",
    "    pred = X @ w\n",
    "    return X.T @ (pred - y) / len(y)\n",
    "\n",
    "# Batch GD\n",
    "eta_batch = 0.1\n",
    "w_batch = np.array([[0.0], [0.0]])\n",
    "batch_losses = []\n",
    "for epoch in range(100):\n",
    "    batch_losses.append(mse_loss(X_aug, y_data, w_batch))\n",
    "    g = mse_grad(X_aug, y_data, w_batch)\n",
    "    w_batch = w_batch - eta_batch * g\n",
    "\n",
    "# SGD\n",
    "eta_sgd = 0.01\n",
    "w_sgd = np.array([[0.0], [0.0]])\n",
    "sgd_losses = []\n",
    "for epoch in range(100):\n",
    "    sgd_losses.append(mse_loss(X_aug, y_data, w_sgd))\n",
    "    for _ in range(N):\n",
    "        i = np.random.randint(N)\n",
    "        xi = X_aug[i:i+1]\n",
    "        yi = y_data[i:i+1]\n",
    "        g = mse_grad(xi, yi, w_sgd)\n",
    "        w_sgd = w_sgd - eta_sgd * g\n",
    "\n",
    "# Mini-batch GD\n",
    "batch_size = 32\n",
    "eta_mini = 0.05\n",
    "w_mini = np.array([[0.0], [0.0]])\n",
    "mini_losses = []\n",
    "for epoch in range(100):\n",
    "    mini_losses.append(mse_loss(X_aug, y_data, w_mini))\n",
    "    indices = np.random.permutation(N)\n",
    "    for start in range(0, N, batch_size):\n",
    "        batch_idx = indices[start:start+batch_size]\n",
    "        X_b = X_aug[batch_idx]\n",
    "        y_b = y_data[batch_idx]\n",
    "        g = mse_grad(X_b, y_b, w_mini)\n",
    "        w_mini = w_mini - eta_mini * g\n",
    "\n",
    "# Plot\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "# Loss curves\n",
    "axes[0].plot(batch_losses, label=f'Batch GD ($\\\\eta={eta_batch}$)', linewidth=2)\n",
    "axes[0].plot(sgd_losses, label=f'SGD ($\\\\eta={eta_sgd}$)', linewidth=2, alpha=0.8)\n",
    "axes[0].plot(mini_losses, label=f'Mini-batch B={batch_size} ($\\\\eta={eta_mini}$)', linewidth=2, alpha=0.8)\n",
    "axes[0].set_xlabel('Epoch', fontsize=12)\n",
    "axes[0].set_ylabel('MSE Loss', fontsize=12)\n",
    "axes[0].set_title('Convergence Comparison', fontsize=13)\n",
    "axes[0].set_yscale('log')\n",
    "axes[0].legend(fontsize=10)\n",
    "axes[0].grid(True, alpha=0.3)\n",
    "\n",
    "# Final fits\n",
    "x_plot = np.linspace(-3, 3, 100)\n",
    "axes[1].scatter(X_data, y_data, alpha=0.3, s=10, color='gray', label='Data')\n",
    "axes[1].plot(x_plot, w_batch[0]*x_plot + w_batch[1], linewidth=2, label='Batch GD')\n",
    "axes[1].plot(x_plot, w_sgd[0]*x_plot + w_sgd[1], linewidth=2, label='SGD', linestyle='--')\n",
    "axes[1].plot(x_plot, w_mini[0]*x_plot + w_mini[1], linewidth=2, label='Mini-batch', linestyle='-.')\n",
    "axes[1].plot(x_plot, 3*x_plot + 2, 'k--', linewidth=1, alpha=0.5, label='True: $y=3x+2$')\n",
    "axes[1].set_xlabel('x', fontsize=12)\n",
    "axes[1].set_ylabel('y', fontsize=12)\n",
    "axes[1].set_title('Learned Linear Fits', fontsize=13)\n",
    "axes[1].legend(fontsize=10)\n",
    "axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('gd_comparison.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(f\"Batch GD:   w = {w_batch.flatten()} (true: [3, 2])\")\n",
    "print(f\"SGD:        w = {w_sgd.flatten()} (true: [3, 2])\")\n",
    "print(f\"Mini-batch: w = {w_mini.flatten()} (true: [3, 2])\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3D Loss Surface with Gradient Descent Trajectories\n",
    "\n",
    "The following visualization shows a 3D loss surface and gradient descent trajectories for different learning rates, providing geometric intuition for how the optimization proceeds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# Define loss surface: L(w1, w2) = w1^2 + 3*w2^2 + 0.5*sin(3*w1)*sin(3*w2)\n",
    "def loss_surface(w1, w2):\n",
    "    return w1**2 + 3*w2**2 + 0.5 * np.sin(3*w1) * np.sin(3*w2)\n",
    "\n",
    "def loss_grad(w):\n",
    "    w1, w2 = w\n",
    "    dw1 = 2*w1 + 0.5 * 3 * np.cos(3*w1) * np.sin(3*w2)\n",
    "    dw2 = 6*w2 + 0.5 * np.sin(3*w1) * 3 * np.cos(3*w2)\n",
    "    return np.array([dw1, dw2])\n",
    "\n",
    "# Create grid\n",
    "w1_range = np.linspace(-3, 3, 200)\n",
    "w2_range = np.linspace(-2, 2, 200)\n",
    "W1, W2 = np.meshgrid(w1_range, w2_range)\n",
    "Z = loss_surface(W1, W2)\n",
    "\n",
    "# Run GD for different learning rates\n",
    "etas_3d = [0.01, 0.05, 0.15]\n",
    "colors_3d = ['blue', 'green', 'red']\n",
    "labels_3d = ['$\\\\eta=0.01$ (slow)', '$\\\\eta=0.05$ (good)', '$\\\\eta=0.15$ (aggressive)']\n",
    "\n",
    "trajectories_3d = []\n",
    "for eta in etas_3d:\n",
    "    w = np.array([2.5, 1.5])\n",
    "    traj = [w.copy()]\n",
    "    for _ in range(100):\n",
    "        g = loss_grad(w)\n",
    "        w = w - eta * g\n",
    "        traj.append(w.copy())\n",
    "        if np.linalg.norm(w) > 50:\n",
    "            break\n",
    "    trajectories_3d.append(np.array(traj))\n",
    "\n",
    "# Create figure with 3D surface and 2D contour side by side\n",
    "fig = plt.figure(figsize=(16, 7))\n",
    "\n",
    "# 3D surface plot\n",
    "ax1 = fig.add_subplot(121, projection='3d')\n",
    "ax1.plot_surface(W1, W2, Z, cmap='viridis', alpha=0.6, edgecolor='none')\n",
    "\n",
    "for traj, color, label in zip(trajectories_3d, colors_3d, labels_3d):\n",
    "    z_traj = loss_surface(traj[:, 0], traj[:, 1])\n",
    "    ax1.plot(traj[:, 0], traj[:, 1], z_traj, 'o-', color=color, \n",
    "             markersize=3, linewidth=2, label=label, zorder=10)\n",
    "\n",
    "ax1.set_xlabel('$w_1$', fontsize=11)\n",
    "ax1.set_ylabel('$w_2$', fontsize=11)\n",
    "ax1.set_zlabel('Loss', fontsize=11)\n",
    "ax1.set_title('3D Loss Surface with GD Trajectories', fontsize=12)\n",
    "ax1.view_init(elev=35, azim=-60)\n",
    "ax1.legend(fontsize=8, loc='upper right')\n",
    "\n",
    "# 2D contour view\n",
    "ax2 = fig.add_subplot(122)\n",
    "ax2.contour(W1, W2, Z, levels=30, cmap='viridis', alpha=0.6)\n",
    "ax2.contourf(W1, W2, Z, levels=30, cmap='viridis', alpha=0.2)\n",
    "\n",
    "for traj, color, label in zip(trajectories_3d, colors_3d, labels_3d):\n",
    "    ax2.plot(traj[:, 0], traj[:, 1], 'o-', color=color, \n",
    "             markersize=3, linewidth=1.5, label=label, alpha=0.9)\n",
    "\n",
    "ax2.plot(0, 0, 'k*', markersize=15, label='Global minimum')\n",
    "ax2.set_xlabel('$w_1$', fontsize=12)\n",
    "ax2.set_ylabel('$w_2$', fontsize=12)\n",
    "ax2.set_title('Contour View', fontsize=12)\n",
    "ax2.legend(fontsize=9)\n",
    "ax2.grid(True, alpha=0.2)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('gd_3d_surface.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "for traj, label in zip(trajectories_3d, labels_3d):\n",
    "    final_loss = loss_surface(traj[-1, 0], traj[-1, 1])\n",
    "    print(f\"{label}: final point = ({traj[-1, 0]:.4f}, {traj[-1, 1]:.4f}), loss = {final_loss:.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Side-by-Side: Batch GD vs SGD vs Mini-batch Trajectories\n",
    "\n",
    "The following visualization shows the three optimization strategies on the same 2D loss surface, highlighting the key differences in their trajectories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# Loss surface: L(w1, w2) = w1^2 + 5*w2^2\n",
    "def loss_fn(w):\n",
    "    return w[0]**2 + 5*w[1]**2\n",
    "\n",
    "def loss_grad_full(w):\n",
    "    return np.array([2*w[0], 10*w[1]])\n",
    "\n",
    "# Simulate noisy gradient (for SGD and mini-batch)\n",
    "def loss_grad_noisy(w, noise_std=2.0):\n",
    "    \"\"\"Simulate stochastic gradient by adding noise to exact gradient.\"\"\"\n",
    "    g = loss_grad_full(w)\n",
    "    return g + np.random.randn(2) * noise_std\n",
    "\n",
    "def loss_grad_minibatch(w, noise_std=0.7):\n",
    "    \"\"\"Simulate mini-batch gradient (less noise than SGD).\"\"\"\n",
    "    g = loss_grad_full(w)\n",
    "    return g + np.random.randn(2) * noise_std\n",
    "\n",
    "# Starting point\n",
    "w0 = np.array([4.0, 2.0])\n",
    "n_steps = 60\n",
    "\n",
    "# Batch GD\n",
    "w = w0.copy()\n",
    "traj_batch = [w.copy()]\n",
    "for _ in range(n_steps):\n",
    "    w = w - 0.06 * loss_grad_full(w)\n",
    "    traj_batch.append(w.copy())\n",
    "traj_batch = np.array(traj_batch)\n",
    "\n",
    "# SGD\n",
    "np.random.seed(42)\n",
    "w = w0.copy()\n",
    "traj_sgd = [w.copy()]\n",
    "for _ in range(n_steps):\n",
    "    w = w - 0.04 * loss_grad_noisy(w, noise_std=3.0)\n",
    "    traj_sgd.append(w.copy())\n",
    "traj_sgd = np.array(traj_sgd)\n",
    "\n",
    "# Mini-batch\n",
    "np.random.seed(42)\n",
    "w = w0.copy()\n",
    "traj_mini = [w.copy()]\n",
    "for _ in range(n_steps):\n",
    "    w = w - 0.05 * loss_grad_minibatch(w, noise_std=1.0)\n",
    "    traj_mini.append(w.copy())\n",
    "traj_mini = np.array(traj_mini)\n",
    "\n",
    "# Grid for contours\n",
    "w1 = np.linspace(-5, 5, 200)\n",
    "w2 = np.linspace(-3, 3, 200)\n",
    "W1, W2 = np.meshgrid(w1, w2)\n",
    "Z = W1**2 + 5*W2**2\n",
    "\n",
    "# 3-panel plot\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
    "\n",
    "titles = ['Batch GD (exact gradient)', 'SGD (single sample)', 'Mini-batch (B=32)']\n",
    "trajs = [traj_batch, traj_sgd, traj_mini]\n",
    "colors = ['blue', 'red', 'green']\n",
    "\n",
    "for ax, traj, title, color in zip(axes, trajs, titles, colors):\n",
    "    ax.contour(W1, W2, Z, levels=20, cmap='gray', alpha=0.4)\n",
    "    ax.plot(traj[:, 0], traj[:, 1], 'o-', color=color, markersize=3, \n",
    "            linewidth=1.2, alpha=0.8)\n",
    "    ax.plot(traj[0, 0], traj[0, 1], 'ko', markersize=8, label='Start')\n",
    "    ax.plot(0, 0, 'k*', markersize=15, label='Minimum')\n",
    "    ax.set_xlabel('$w_1$', fontsize=12)\n",
    "    ax.set_ylabel('$w_2$', fontsize=12)\n",
    "    ax.set_title(title, fontsize=13)\n",
    "    ax.legend(fontsize=9)\n",
    "    ax.set_xlim(-5.5, 5.5)\n",
    "    ax.set_ylim(-3.5, 3.5)\n",
    "    ax.grid(True, alpha=0.2)\n",
    "    ax.set_aspect('equal')\n",
    "\n",
    "plt.suptitle('Comparison of Optimization Trajectories on $L = w_1^2 + 5w_2^2$',\n",
    "             fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.savefig('gd_sgd_minibatch_comparison.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(\"Notice:\")\n",
    "print(\"  - Batch GD: smooth, predictable path\")\n",
    "print(\"  - SGD: very noisy, zigzagging trajectory\")\n",
    "print(\"  - Mini-batch: moderate noise, good compromise\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning Rate Sweep Visualization\n",
    "\n",
    "The following plot shows loss curves for a range of learning rates on the same problem, clearly demonstrating how the learning rate affects convergence behavior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# Loss function: L(w) = w1^2 + 10*w2^2 (ill-conditioned quadratic)\n",
    "def loss_fn(w):\n",
    "    return w[0]**2 + 10*w[1]**2\n",
    "\n",
    "def loss_grad(w):\n",
    "    return np.array([2*w[0], 20*w[1]])\n",
    "\n",
    "# Learning rate sweep\n",
    "etas_sweep = [0.001, 0.01, 0.1, 1.0, 10.0]\n",
    "w0 = np.array([4.0, 2.5])\n",
    "n_iters = 100\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
    "\n",
    "colors = ['#1f77b4', '#2ca02c', '#ff7f0e', '#d62728', '#9467bd']\n",
    "styles = ['-', '-', '-', '--', ':']\n",
    "\n",
    "for eta, color, style in zip(etas_sweep, colors, styles):\n",
    "    w = w0.copy()\n",
    "    losses = [loss_fn(w)]\n",
    "    diverged = False\n",
    "    for t in range(n_iters):\n",
    "        g = loss_grad(w)\n",
    "        w = w - eta * g\n",
    "        l = loss_fn(w)\n",
    "        if l > 1e10:\n",
    "            losses.append(l)\n",
    "            diverged = True\n",
    "            break\n",
    "        losses.append(l)\n",
    "    \n",
    "    label = f'$\\\\eta = {eta}$'\n",
    "    if diverged:\n",
    "        label += ' (DIVERGES)'\n",
    "    \n",
    "    axes[0].plot(losses, color=color, linewidth=2.5, linestyle=style, label=label)\n",
    "\n",
    "axes[0].set_xlabel('Iteration', fontsize=13)\n",
    "axes[0].set_ylabel('Loss $\\\\mathcal{L}(\\\\mathbf{w})$', fontsize=13)\n",
    "axes[0].set_title('Learning Rate Sweep: Loss Curves', fontsize=14)\n",
    "axes[0].set_yscale('log')\n",
    "axes[0].set_ylim(1e-10, 1e12)\n",
    "axes[0].legend(fontsize=11, loc='upper right')\n",
    "axes[0].grid(True, alpha=0.3)\n",
    "axes[0].axhline(y=1e-6, color='gray', linestyle=':', alpha=0.5, label='Convergence threshold')\n",
    "\n",
    "# Convergence iteration bar chart (or infinity if diverged)\n",
    "convergence_iters = []\n",
    "threshold = 1e-4\n",
    "for eta in etas_sweep:\n",
    "    w = w0.copy()\n",
    "    converged_at = n_iters\n",
    "    for t in range(n_iters):\n",
    "        g = loss_grad(w)\n",
    "        w = w - eta * g\n",
    "        if loss_fn(w) > 1e10:  # diverged\n",
    "            converged_at = -1\n",
    "            break\n",
    "        if loss_fn(w) < threshold:\n",
    "            converged_at = t + 1\n",
    "            break\n",
    "    convergence_iters.append(converged_at)\n",
    "\n",
    "bar_colors = []\n",
    "bar_vals = []\n",
    "for ci, color in zip(convergence_iters, colors):\n",
    "    bar_colors.append(color)\n",
    "    bar_vals.append(ci if ci > 0 else n_iters)\n",
    "\n",
    "bars = axes[1].bar(range(len(etas_sweep)), bar_vals, color=bar_colors, edgecolor='black')\n",
    "\n",
    "# Mark diverged bars\n",
    "for i, ci in enumerate(convergence_iters):\n",
    "    if ci == -1:\n",
    "        axes[1].text(i, bar_vals[i] + 1, 'DIVERGED', ha='center', fontsize=9, \n",
    "                     color='red', fontweight='bold')\n",
    "    elif ci == n_iters:\n",
    "        axes[1].text(i, bar_vals[i] + 1, 'NOT\\nCONVERGED', ha='center', fontsize=8, \n",
    "                     color='orange')\n",
    "    else:\n",
    "        axes[1].text(i, bar_vals[i] + 1, f'{ci} iters', ha='center', fontsize=9)\n",
    "\n",
    "axes[1].set_xticks(range(len(etas_sweep)))\n",
    "axes[1].set_xticklabels([f'$\\\\eta={e}$' for e in etas_sweep], fontsize=10)\n",
    "axes[1].set_xlabel('Learning Rate', fontsize=13)\n",
    "axes[1].set_ylabel(f'Iterations to reach loss < {threshold}', fontsize=12)\n",
    "axes[1].set_title('Convergence Speed vs Learning Rate', fontsize=14)\n",
    "axes[1].grid(True, alpha=0.3, axis='y')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('learning_rate_sweep.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(\"Summary of learning rate sweep:\")\n",
    "for eta, ci in zip(etas_sweep, convergence_iters):\n",
    "    if ci == -1:\n",
    "        print(f\"  eta = {eta:6.3f}: DIVERGED\")\n",
    "    elif ci == n_iters:\n",
    "        print(f\"  eta = {eta:6.3f}: Did not converge in {n_iters} iterations\")\n",
    "    else:\n",
    "        print(f\"  eta = {eta:6.3f}: Converged in {ci} iterations\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 15.8 Convergence Properties\n",
    "\n",
    "### Batch Gradient Descent on Convex Functions\n",
    "\n",
    "For an $L$-smooth convex function with learning rate $\\eta \\leq 1/L$:\n",
    "\n",
    "$$\\mathcal{L}(\\mathbf{w}_t) - \\mathcal{L}(\\mathbf{w}^*) \\leq \\frac{\\|\\mathbf{w}_0 - \\mathbf{w}^*\\|^2}{2\\eta t} = O\\left(\\frac{1}{t}\\right)$$\n",
    "\n",
    "For $\\mu$-strongly convex functions (linear convergence):\n",
    "\n",
    "$$\\mathcal{L}(\\mathbf{w}_t) - \\mathcal{L}(\\mathbf{w}^*) \\leq \\left(1 - \\frac{\\mu}{L}\\right)^t (\\mathcal{L}(\\mathbf{w}_0) - \\mathcal{L}(\\mathbf{w}^*))$$\n",
    "\n",
    "The convergence rate depends on the condition number $\\kappa = L/\\mu$.\n",
    "\n",
    "### SGD Convergence\n",
    "\n",
    "For SGD with diminishing learning rate $\\eta_t = \\eta_0 / \\sqrt{t}$ on convex functions:\n",
    "\n",
    "$$\\mathbb{E}[\\mathcal{L}(\\bar{\\mathbf{w}}_T)] - \\mathcal{L}(\\mathbf{w}^*) = O\\left(\\frac{1}{\\sqrt{T}}\\right)$$\n",
    "\n",
    "where $\\bar{\\mathbf{w}}_T$ is the average iterate. SGD is slower but cheaper per iteration.\n",
    "\n",
    "### Summary\n",
    "\n",
    "| Method | Cost per iter | Convergence (convex) | Convergence (strongly convex) |\n",
    "|--------|--------------|---------------------|------------------------------|\n",
    "| Batch GD | $O(N)$ | $O(1/t)$ | $O(\\rho^t)$ where $\\rho < 1$ |\n",
    "| SGD | $O(1)$ | $O(1/\\sqrt{t})$ | $O(1/t)$ |\n",
    "| Mini-batch | $O(B)$ | Between the two | Between the two |"
   ]
  },
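  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The linear rate for strongly convex functions can be observed directly. Below is a sketch\n",
    "on the quadratic $\\mathcal{L} = w_1^2 + 10 w_2^2$ (so $\\mu = 2$ and $L = 20$), using the\n",
    "step size $\\eta = 1/L$ under which the bound holds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Sketch: check the linear-convergence bound (1 - mu/L)^t on the quadratic\n",
    "# L(w) = w1^2 + 10*w2^2, which has mu = 2, L = 20, and minimum value 0 at w = 0.\n",
    "mu, L_smooth = 2.0, 20.0\n",
    "eta = 1.0 / L_smooth           # step size assumed by the bound\n",
    "rho = 1.0 - mu / L_smooth      # predicted per-step contraction factor\n",
    "\n",
    "def loss(w):\n",
    "    return w[0]**2 + 10*w[1]**2\n",
    "\n",
    "def grad(w):\n",
    "    return np.array([2*w[0], 20*w[1]])\n",
    "\n",
    "w = np.array([4.0, 2.5])\n",
    "gap0 = loss(w)                 # L(w_0) - L(w*), since L(w*) = 0\n",
    "for t in range(1, 51):\n",
    "    w = w - eta * grad(w)\n",
    "    if t % 10 == 0:\n",
    "        print(f\"t={t:3d}: gap = {loss(w):.3e}  <=  bound = {rho**t * gap0:.3e}\")"
   ]
  },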
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 15.1.** Implement gradient descent with **momentum**:\n",
    "$\\mathbf{v}_{t+1} = \\beta \\mathbf{v}_t + \\nabla \\mathcal{L}(\\mathbf{w}_t)$,\n",
    "$\\mathbf{w}_{t+1} = \\mathbf{w}_t - \\eta \\mathbf{v}_{t+1}$.\n",
    "Compare trajectories on the quadratic bowl with and without momentum ($\\beta = 0.9$).\n",
    "\n",
    "**Exercise 15.2.** Prove that gradient descent on the quadratic\n",
    "$\\mathcal{L}(\\mathbf{w}) = \\frac{1}{2}\\mathbf{w}^\\top \\mathbf{A}\\mathbf{w} - \\mathbf{b}^\\top \\mathbf{w}$\n",
    "converges if and only if $0 < \\eta < 2/\\lambda_{\\max}(\\mathbf{A})$.\n",
    "\n",
    "**Exercise 15.3.** Implement a learning rate schedule: start with $\\eta_0 = 0.1$ and\n",
    "decay by a factor of 10 at epochs 50 and 80. Compare with constant learning rate\n",
    "on the MNIST-like spiral dataset.\n",
    "\n",
    "**Exercise 15.4.** Generate a 1D non-convex function with multiple local minima.\n",
    "Run gradient descent from 20 different random initializations and plot where each\n",
    "one converges. What fraction find the global minimum?\n",
    "\n",
    "**Exercise 15.5.** Derive the gradient of the cross-entropy loss combined with softmax.\n",
    "Show that $\\partial L / \\partial z_j = \\hat{y}_j - y_j$ where $z_j$ are the logits."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}