{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-0",
   "metadata": {},
   "source": "# Chapter 13: Oja's Rule and Principal Component Analysis\n\n\nIn Chapter 12, we saw that the basic Hebbian rule causes the weight vector to align with\nthe leading eigenvector of the input second-moment matrix (for centered data, the first principal component), but its norm diverges to infinity. In this\nchapter, we study **Oja's rule** (1982), an elegant modification that solves the instability\nproblem while preserving the desirable convergence to the principal component direction."
  },
  {
   "cell_type": "markdown",
   "id": "cell-1",
   "metadata": {},
   "source": [
    "## 13.1 Oja's Rule\n",
    "\n",
    "```{admonition} Definition (Oja's Rule, 1982)\n",
    ":class: note\n",
    "\n",
    "**Oja's rule** modifies the basic Hebbian rule by adding a weight decay term\n",
    "proportional to the square of the output:\n",
    "\n",
    "$$\\Delta \\mathbf{w} = \\eta\\left(y\\mathbf{x} - y^2 \\mathbf{w}\\right)$$\n",
    "\n",
    "where $y = \\mathbf{w}^\\top \\mathbf{x}$ is the linear output and $\\eta > 0$ is the learning rate.\n",
    "\n",
    "Component-wise:\n",
    "\n",
    "$$\\Delta w_i = \\eta(y x_i - y^2 w_i)$$\n",
    "```\n",
    "\n",
    "The first term $y\\mathbf{x}$ is the standard Hebbian term. The second term $-y^2 \\mathbf{w}$\n",
    "is a **weight decay** that prevents the norm from growing without bound.\n",
    "\n",
    "```{tip}\n",
    "**The key insight** -- the $-y^2 w$ term provides automatic weight normalization, preventing the instability of pure Hebbian learning. When $\\|\\mathbf{w}\\|$ grows too large, the output $y$ becomes large, and the decay term $-y^2 \\mathbf{w}$ dominates, pulling the weights back. This creates an elegant self-regulating mechanism.\n",
    "```"
   ]
  },
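  {
   "cell_type": "markdown",
   "id": "cell-1b",
   "metadata": {},
   "source": [
    "The update is simple enough to write in a few lines of NumPy. The following sketch (the function and variable names are illustrative, not taken from any library) applies a handful of Oja updates to random inputs; even without an explicit normalization step, the weight norm stays close to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "def oja_step(w, x, eta=0.01):\n",
    "    \"\"\"One Oja update: w <- w + eta * (y*x - y**2 * w), with y = w @ x.\"\"\"\n",
    "    y = w @ x\n",
    "    return w + eta * (y * x - y**2 * w)\n",
    "\n",
    "# A few updates on random 3D inputs; the norm remains close to 1\n",
    "w = rng.normal(size=3)\n",
    "w /= np.linalg.norm(w)\n",
    "for t in range(5):\n",
    "    x = rng.normal(size=3)\n",
    "    w = oja_step(w, x)\n",
    "    print(f\"step {t+1}: ||w|| = {np.linalg.norm(w):.4f}\")"
   ]
  },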
  {
   "cell_type": "markdown",
   "id": "cell-2",
   "metadata": {},
   "source": [
    "## 13.2 Derivation from Weight Normalization\n",
    "\n",
    "Oja's rule can be derived as an approximation to the \"Hebb + normalize\" procedure. This\n",
    "derivation reveals its deep connection to constrained optimization.\n",
    "\n",
    "```{admonition} Proof\n",
    ":class: dropdown\n",
    "\n",
    "**Derivation of Oja's rule from Hebbian update + normalization.**\n",
    "\n",
    "**Step 1: Hebbian Update.** Start with the basic Hebbian update:\n",
    "\n",
    "$$\\mathbf{w}' = \\mathbf{w} + \\eta \\, y \\, \\mathbf{x}$$\n",
    "\n",
    "**Step 2: Normalize.** After the update, normalize the weight vector:\n",
    "\n",
    "$$\\mathbf{w}_{\\text{new}} = \\frac{\\mathbf{w}'}{\\|\\mathbf{w}'\\|}$$\n",
    "\n",
    "**Step 3: Compute the Norm.**\n",
    "\n",
    "$$\\|\\mathbf{w}'\\|^2 = \\|\\mathbf{w} + \\eta y \\mathbf{x}\\|^2 = \\|\\mathbf{w}\\|^2 + 2\\eta y (\\mathbf{w}^\\top \\mathbf{x}) + \\eta^2 y^2 \\|\\mathbf{x}\\|^2$$\n",
    "\n",
    "Since $y = \\mathbf{w}^\\top \\mathbf{x}$:\n",
    "\n",
    "$$\\|\\mathbf{w}'\\|^2 = \\|\\mathbf{w}\\|^2 + 2\\eta y^2 + \\eta^2 y^2 \\|\\mathbf{x}\\|^2$$\n",
    "\n",
    "Assuming $\\|\\mathbf{w}\\| = 1$ (we maintain normalization) and keeping terms to first order in $\\eta$:\n",
    "\n",
    "$$\\|\\mathbf{w}'\\|^2 \\approx 1 + 2\\eta y^2$$\n",
    "\n",
    "**Step 4: Taylor Expand the Inverse Norm.**\n",
    "\n",
    "$$\\frac{1}{\\|\\mathbf{w}'\\|} = (1 + 2\\eta y^2)^{-1/2} \\approx 1 - \\eta y^2 + O(\\eta^2)$$\n",
    "\n",
    "using $(1 + u)^{-1/2} \\approx 1 - u/2$ for small $u$.\n",
    "\n",
    "**Step 5: Compute the Normalized Update.**\n",
    "\n",
    "$$\\mathbf{w}_{\\text{new}} = \\frac{\\mathbf{w}'}{\\|\\mathbf{w}'\\|} \\approx (\\mathbf{w} + \\eta y \\mathbf{x})(1 - \\eta y^2)$$\n",
    "\n",
    "$$= \\mathbf{w} + \\eta y \\mathbf{x} - \\eta y^2 \\mathbf{w} - \\eta^2 y^3 \\mathbf{x}$$\n",
    "\n",
    "Dropping the $O(\\eta^2)$ term:\n",
    "\n",
    "$$\\mathbf{w}_{\\text{new}} \\approx \\mathbf{w} + \\eta(y\\mathbf{x} - y^2 \\mathbf{w})$$\n",
    "\n",
    "$$\\Delta \\mathbf{w} = \\eta(y\\mathbf{x} - y^2 \\mathbf{w})$$\n",
    "\n",
    "This is precisely Oja's rule. $\\blacksquare$\n",
    "```"
   ]
  },
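  {
   "cell_type": "markdown",
   "id": "cell-2b",
   "metadata": {},
   "source": [
    "A quick numerical sanity check of this derivation (a sketch with arbitrary made-up numbers, not part of the formal argument): starting from a unit-norm $\\mathbf{w}$, one explicit \"Hebb + normalize\" step and one Oja step should differ only at order $\\eta^2$, so the gap should shrink by roughly a factor of 100 each time $\\eta$ shrinks by a factor of 10."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "\n",
    "# Random unit-norm weight vector and a random input (the derivation assumes ||w|| = 1)\n",
    "w = rng.normal(size=4)\n",
    "w /= np.linalg.norm(w)\n",
    "x = rng.normal(size=4)\n",
    "y = w @ x\n",
    "\n",
    "for eta in [1e-1, 1e-2, 1e-3]:\n",
    "    w_hebb = w + eta * y * x\n",
    "    w_hebb_normalized = w_hebb / np.linalg.norm(w_hebb)  # Hebb + normalize\n",
    "    w_oja = w + eta * (y * x - y**2 * w)                 # Oja's rule\n",
    "    gap = np.linalg.norm(w_hebb_normalized - w_oja)\n",
    "    print(f\"eta = {eta:g}: ||Hebb+normalize - Oja|| = {gap:.3e}  (gap / eta^2 = {gap / eta**2:.3f})\")"
   ]
  },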
  {
   "cell_type": "markdown",
   "id": "cell-3",
   "metadata": {},
   "source": "## 13.3 Convergence Theorem\n\n```{admonition} Theorem (Oja's Convergence to First Principal Component)\n:class: important\n\nUnder Oja's rule with stationary inputs having correlation matrix $\\mathbf{C}$ with distinct\neigenvalues $\\lambda_1 > \\lambda_2 > \\cdots > \\lambda_n > 0$, the weight vector converges to\nthe leading eigenvector of $\\mathbf{C}$ (for centered data, the first principal component):\n\n$$\\mathbf{w}(t) \\to \\pm \\mathbf{e}_1 \\quad \\text{as } t \\to \\infty$$\n\nwhere $\\mathbf{e}_1$ is the eigenvector corresponding to $\\lambda_1$, and\n$\\|\\mathbf{w}(t)\\| \\to 1$.\n```\n\n```{tip}\n**Connection to PCA** -- Oja's neuron converges to the eigenvector with the largest eigenvalue of the input correlation matrix $\\mathbf{C} = \\mathbb{E}[\\mathbf{x}\\mathbf{x}^\\top]$. For centered data, this is exactly the first principal component from PCA. The neuron effectively performs online, streaming extraction of the leading eigenvector using only local information.\n```\n\n```{admonition} Proof\n:class: dropdown\n\nWe use a Lyapunov function approach on the unit sphere.\n\n**Step 1: Expected dynamics.** Taking the expectation of Oja's rule over the input\ndistribution (continuous-time approximation):\n\n$$\\frac{d\\mathbf{w}}{dt} = \\eta\\left(\\mathbf{C}\\mathbf{w} - (\\mathbf{w}^\\top \\mathbf{C}\\mathbf{w})\\mathbf{w}\\right)$$\n\nThis follows because $\\mathbb{E}[y\\mathbf{x}] = \\mathbb{E}[\\mathbf{x}\\mathbf{x}^\\top]\\mathbf{w} = \\mathbf{C}\\mathbf{w}$\nand $\\mathbb{E}[y^2] = \\mathbf{w}^\\top \\mathbf{C}\\mathbf{w}$.\n\n**Step 2: The unit sphere is invariant.** Compute $d\\|\\mathbf{w}\\|^2/dt$:\n\n$$\\frac{d\\|\\mathbf{w}\\|^2}{dt} = 2\\mathbf{w}^\\top \\frac{d\\mathbf{w}}{dt} = 2\\eta\\left(\\mathbf{w}^\\top\\mathbf{C}\\mathbf{w} - (\\mathbf{w}^\\top\\mathbf{C}\\mathbf{w})\\|\\mathbf{w}\\|^2\\right) = 2\\eta(\\mathbf{w}^\\top\\mathbf{C}\\mathbf{w})(1 - \\|\\mathbf{w}\\|^2)$$\n\nIf $\\|\\mathbf{w}\\| < 1$, the norm increases; if $\\|\\mathbf{w}\\| > 1$, it decreases.\nThe unit sphere $\\|\\mathbf{w}\\| = 1$ is an attracting invariant set.\n\n**Step 3: Lyapunov function.** On the unit sphere, define:\n\n$$V(\\mathbf{w}) = -\\mathbf{w}^\\top \\mathbf{C} \\mathbf{w}$$\n\nThis is the negative of the variance captured by $\\mathbf{w}$ (the Rayleigh quotient).\n\nCompute $dV/dt$ on the unit sphere:\n\n$$\\frac{dV}{dt} = -2\\mathbf{w}^\\top \\mathbf{C} \\frac{d\\mathbf{w}}{dt} = -2\\eta\\left(\\mathbf{w}^\\top \\mathbf{C}^2 \\mathbf{w} - (\\mathbf{w}^\\top \\mathbf{C} \\mathbf{w})^2\\right)$$\n\nExpand in the eigenbasis. Let $\\mathbf{w} = \\sum_i c_i \\mathbf{e}_i$ with\n$\\sum_i c_i^2 = 1$. Then:\n\n$$\\mathbf{w}^\\top \\mathbf{C}^2 \\mathbf{w} = \\sum_i \\lambda_i^2 c_i^2, \\quad\n(\\mathbf{w}^\\top \\mathbf{C} \\mathbf{w})^2 = \\left(\\sum_i \\lambda_i c_i^2\\right)^2$$\n\nBy the Cauchy-Schwarz inequality (or the variance of $\\lambda_i$ w.r.t. the distribution\n$\\{c_i^2\\}$ on the eigenvalues):\n\n$$\\sum_i \\lambda_i^2 c_i^2 \\geq \\left(\\sum_i \\lambda_i c_i^2\\right)^2$$\n\nwith equality if and only if $\\mathbf{w}$ is an eigenvector of $\\mathbf{C}$.\n\nTherefore $dV/dt \\leq 0$, with equality only at eigenvectors.\n\n**Step 4: Stability analysis.** The critical points on the unit sphere are the eigenvectors\n$\\pm \\mathbf{e}_i$. 
At each critical point, $V(\\pm \\mathbf{e}_i) = -\\lambda_i$.\n\nThe Lyapunov function is minimized at $\\pm \\mathbf{e}_1$ (where $V = -\\lambda_1$), maximized\nat $\\pm \\mathbf{e}_n$ (where $V = -\\lambda_n$).\n\nLinearization around each eigenvector shows:\n- $\\pm \\mathbf{e}_1$ is a **stable** equilibrium (all perturbation eigenvalues negative)\n- All other $\\pm \\mathbf{e}_i$ for $i > 1$ are **unstable** (saddle points)\n\nBy LaSalle's invariance principle, almost all initial conditions converge to $\\pm \\mathbf{e}_1$. $\\blacksquare$\n```\n\n```{warning}\nOja's rule only finds the **FIRST** principal component. To extract multiple principal components, extensions such as **Sanger's rule** (Generalized Hebbian Algorithm, GHA) or **Oja's subspace rule** are needed. These are discussed in Section 13.4.\n```"
  },
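  {
   "cell_type": "markdown",
   "id": "cell-3b",
   "metadata": {},
   "source": [
    "Before simulating the full learning rule, we can check Steps 1-3 of the proof numerically. The sketch below (the matrix, seed, and helper name are illustrative choices) verifies that at each unit eigenvector the expected drift $\\mathbf{C}\\mathbf{w} - (\\mathbf{w}^\\top \\mathbf{C}\\mathbf{w})\\mathbf{w}$ vanishes and $V = -\\lambda_i$, while at a generic unit vector the drift is nonzero and $V$ decreases along the flow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "\n",
    "# An illustrative 3x3 correlation matrix with distinct eigenvalues\n",
    "Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))\n",
    "C = Q @ np.diag([3.0, 1.0, 0.2]) @ Q.T\n",
    "eigvals, eigvecs = np.linalg.eigh(C)\n",
    "order = np.argsort(-eigvals)\n",
    "eigvals, eigvecs = eigvals[order], eigvecs[:, order]\n",
    "\n",
    "def drift(w, C):\n",
    "    \"\"\"Expected Oja update direction: C w - (w^T C w) w (Step 1 of the proof).\"\"\"\n",
    "    return C @ w - (w @ C @ w) * w\n",
    "\n",
    "# Each unit eigenvector is an equilibrium, with V(e_i) = -lambda_i\n",
    "for i in range(3):\n",
    "    e = eigvecs[:, i]\n",
    "    print(f\"e_{i+1}: ||drift|| = {np.linalg.norm(drift(e, C)):.2e}, \"\n",
    "          f\"V = {-(e @ C @ e):.3f}, -lambda_{i+1} = {-eigvals[i]:.3f}\")\n",
    "\n",
    "# At a generic unit vector the drift is nonzero and V decreases along the flow\n",
    "w = rng.normal(size=3)\n",
    "w /= np.linalg.norm(w)\n",
    "dw = drift(w, C)\n",
    "print(f\"generic w: ||drift|| = {np.linalg.norm(dw):.3f}, dV/dt per unit eta = {-2 * (w @ C @ dw):.4f} (should be <= 0)\")"
   ]
  },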
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-4",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# Generate 2D correlated data\n",
    "n_samples = 2000\n",
    "angle = np.pi / 4  # 45 degrees\n",
    "R = np.array([[np.cos(angle), -np.sin(angle)],\n",
    "              [np.sin(angle),  np.cos(angle)]])\n",
    "Lambda_mat = np.diag([3.0, 0.5])\n",
    "C_true = R @ Lambda_mat @ R.T\n",
    "\n",
    "X = np.random.multivariate_normal([0, 0], C_true, n_samples)\n",
    "\n",
    "# True principal component\n",
    "eigenvalues, eigenvectors = np.linalg.eigh(np.cov(X.T))\n",
    "idx = np.argsort(-eigenvalues)  # descending\n",
    "eigenvalues = eigenvalues[idx]\n",
    "eigenvectors = eigenvectors[:, idx]\n",
    "pc1_true = eigenvectors[:, 0]\n",
    "\n",
    "print(f\"True eigenvalues: {eigenvalues}\")\n",
    "print(f\"True PC1 direction: {pc1_true}\")\n",
    "\n",
    "# ---- Oja's Rule ----\n",
    "eta_oja = 0.001\n",
    "w_oja = np.array([1.0, 0.0])  # initial weights\n",
    "w_oja = w_oja / np.linalg.norm(w_oja)\n",
    "\n",
    "oja_norms = [np.linalg.norm(w_oja)]\n",
    "oja_angles = [np.degrees(np.arccos(np.clip(np.abs(w_oja @ pc1_true), -1, 1)))]\n",
    "oja_history = [w_oja.copy()]\n",
    "\n",
    "for epoch in range(3):\n",
    "    for i in range(n_samples):\n",
    "        x = X[i]\n",
    "        y = w_oja @ x\n",
    "        w_oja = w_oja + eta_oja * (y * x - y**2 * w_oja)  # Oja's rule\n",
    "        oja_norms.append(np.linalg.norm(w_oja))\n",
    "        cos_a = np.clip(np.abs(w_oja @ pc1_true) / np.linalg.norm(w_oja), -1, 1)\n",
    "        oja_angles.append(np.degrees(np.arccos(cos_a)))\n",
    "        oja_history.append(w_oja.copy())\n",
    "\n",
    "# Numpy PCA for comparison\n",
    "from numpy.linalg import svd\n",
    "U, S, Vt = svd(X - X.mean(axis=0), full_matrices=False)\n",
    "pca_direction = Vt[0]\n",
    "\n",
    "print(f\"\\nOja final direction: {w_oja / np.linalg.norm(w_oja)}\")\n",
    "print(f\"Numpy PCA direction: {pca_direction}\")\n",
    "print(f\"Oja final norm: {np.linalg.norm(w_oja):.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-5",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# 3-panel visualization: Oja's rule convergence\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
    "\n",
    "# Panel 1: Weight norm over time\n",
    "axes[0].plot(oja_norms, color='blue', linewidth=1)\n",
    "axes[0].axhline(y=1.0, color='red', linestyle='--', label='||w||=1', alpha=0.7)\n",
    "axes[0].set_xlabel('Iteration')\n",
    "axes[0].set_ylabel('||w||')\n",
    "axes[0].set_title('Weight Norm (Stable!)')\n",
    "axes[0].legend()\n",
    "axes[0].grid(True, alpha=0.3)\n",
    "axes[0].set_ylim(0.8, 1.2)\n",
    "\n",
    "# Panel 2: Data with learned direction\n",
    "axes[1].scatter(X[:, 0], X[:, 1], alpha=0.15, s=5, color='gray')\n",
    "w_final = w_oja / np.linalg.norm(w_oja)\n",
    "scale = 4\n",
    "axes[1].annotate('', xy=w_final*scale, xytext=-w_final*scale,\n",
    "                 arrowprops=dict(arrowstyle='->', color='blue', lw=2.5))\n",
    "axes[1].annotate('', xy=pc1_true*scale, xytext=-pc1_true*scale,\n",
    "                 arrowprops=dict(arrowstyle='->', color='red', lw=2.5, linestyle='--'))\n",
    "axes[1].set_xlabel('$x_1$')\n",
    "axes[1].set_ylabel('$x_2$')\n",
    "axes[1].set_title('Data + Learned PC1 Direction')\n",
    "axes[1].set_aspect('equal')\n",
    "axes[1].set_xlim(-6, 6)\n",
    "axes[1].set_ylim(-6, 6)\n",
    "axes[1].legend(['Oja (blue)', 'True PC1 (red)'], loc='upper left')\n",
    "axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "# Panel 3: Angle to true PC1 over time\n",
    "axes[2].plot(oja_angles, color='blue', linewidth=1)\n",
    "axes[2].axhline(y=0, color='red', linestyle='--', alpha=0.5)\n",
    "axes[2].set_xlabel('Iteration')\n",
    "axes[2].set_ylabel('Angle to PC1 (degrees)')\n",
    "axes[2].set_title('Convergence to PC1')\n",
    "axes[2].grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('oja_convergence.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-5b",
   "metadata": {},
   "source": [
    "### Eigenvector Convergence Visualization\n",
    "\n",
    "The following animation-style plot shows the weight vector rotating toward the first\n",
    "principal component over training iterations, superimposed on the 2D data cloud."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-5c",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": "import numpy as np\nimport matplotlib.pyplot as plt\n\nnp.random.seed(42)\n\n# Generate 2D correlated data\nn_samples = 2000\nangle = np.pi / 4\nR = np.array([[np.cos(angle), -np.sin(angle)],\n              [np.sin(angle),  np.cos(angle)]])\nC_true = R @ np.diag([3.0, 0.5]) @ R.T\nX = np.random.multivariate_normal([0, 0], C_true, n_samples)\n\n# True PC1\nevals, evecs = np.linalg.eigh(np.cov(X.T))\npc1 = evecs[:, np.argmax(evals)]\n\n# Run Oja's rule and capture snapshots\neta = 0.001\nw = np.array([0.0, 1.0])  # start pointing \"up\" (away from PC1)\nw = w / np.linalg.norm(w)\n\n# Capture snapshots at specific iterations\nsnapshot_iters = [0, 10, 50, 200, 500, 1000, 2000, 4000, 6000]\nsnapshots = {0: w.copy()}\niteration = 0\n\nfor epoch in range(3):\n    for i in range(n_samples):\n        x = X[i]\n        y = w @ x\n        w = w + eta * (y * x - y**2 * w)\n        iteration += 1\n        if iteration in snapshot_iters:\n            snapshots[iteration] = w.copy()\n\n# Create the convergence visualization\nfig, ax = plt.subplots(figsize=(10, 6))\n\n# Plot data cloud\nax.scatter(X[:, 0], X[:, 1], alpha=0.1, s=3, color='lightgray', zorder=1)\n\n# Plot true PC1 and PC2\npc2 = evecs[:, np.argmin(evals)]\nax.annotate('', xy=pc1*5, xytext=-pc1*5,\n            arrowprops=dict(arrowstyle='->', color='red', lw=3, linestyle='--'),\n            zorder=5)\nax.annotate('PC1 (target)', xy=pc1*4.5, fontsize=11, color='red',\n            fontweight='bold', zorder=5)\n\n# Plot weight vector snapshots with colormap\ncmap = plt.cm.viridis\nsorted_iters = sorted(snapshots.keys())\nn_snaps = len(sorted_iters)\n\nfor j, it in enumerate(sorted_iters):\n    w_snap = snapshots[it]\n    w_dir = w_snap / np.linalg.norm(w_snap)\n    color = cmap(j / max(n_snaps - 1, 1))\n    scale = 3.5 + 0.5 * j / n_snaps  # slightly growing arrow\n    ax.annotate('', xy=w_dir * scale, xytext=[0, 0],\n                arrowprops=dict(arrowstyle='->', color=color, lw=2.5),\n                zorder=10)\n    ax.annotate(f't={it}', xy=w_dir * (scale + 0.3),\n                fontsize=8, color=color, zorder=10)\n\n# Draw the unit circle for reference\ntheta_circle = np.linspace(0, 2*np.pi, 200)\nax.plot(3.5*np.cos(theta_circle), 3.5*np.sin(theta_circle),\n        'k--', alpha=0.15, linewidth=1)\n\nax.set_xlabel('$x_1$', fontsize=12)\nax.set_ylabel('$x_2$', fontsize=12)\nax.set_title(\"Oja's Rule: Weight Vector Rotating Toward PC1\\n\"\n             \"(color: early=dark, late=yellow)\", fontsize=13)\nax.set_aspect('equal')\nax.set_xlim(-6, 6)\nax.set_ylim(-6, 6)\nax.grid(True, alpha=0.3)\n\nplt.tight_layout()\nplt.show()\n\nprint(\"The weight vector (colored arrows) rotates from its initial direction\")\nprint(\"toward the leading eigenvector (red dashed arrow) over training.\")"
  },
  {
   "cell_type": "markdown",
   "id": "cell-5d",
   "metadata": {},
   "source": [
    "### Side-by-Side: Exact PCA vs Online Oja\n",
    "\n",
    "The following comparison demonstrates that Oja's online rule converges to the same\n",
    "result as exact (batch) PCA computed via eigendecomposition, using a higher-dimensional\n",
    "dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-5e",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# Generate 5D correlated data with known principal components\n",
    "n_dim = 5\n",
    "n_samples = 5000\n",
    "\n",
    "# Create a covariance matrix with distinct eigenvalues\n",
    "true_eigenvalues = np.array([5.0, 3.0, 1.5, 0.5, 0.1])\n",
    "# Random orthogonal matrix for eigenvectors\n",
    "Q, _ = np.linalg.qr(np.random.randn(n_dim, n_dim))\n",
    "C_true = Q @ np.diag(true_eigenvalues) @ Q.T\n",
    "\n",
    "X = np.random.multivariate_normal(np.zeros(n_dim), C_true, n_samples)\n",
    "\n",
    "# --- Exact PCA ---\n",
    "C_sample = np.cov(X.T)\n",
    "evals_exact, evecs_exact = np.linalg.eigh(C_sample)\n",
    "idx_sort = np.argsort(-evals_exact)\n",
    "evals_exact = evals_exact[idx_sort]\n",
    "evecs_exact = evecs_exact[:, idx_sort]\n",
    "pc1_exact = evecs_exact[:, 0]\n",
    "\n",
    "# --- Oja's Rule (online) ---\n",
    "eta = 0.0005\n",
    "w_oja = np.random.randn(n_dim)\n",
    "w_oja = w_oja / np.linalg.norm(w_oja)\n",
    "\n",
    "oja_norms = [np.linalg.norm(w_oja)]\n",
    "oja_angles = []\n",
    "cos_a = np.clip(np.abs(w_oja @ pc1_exact) / np.linalg.norm(w_oja), -1, 1)\n",
    "oja_angles.append(np.degrees(np.arccos(cos_a)))\n",
    "\n",
    "for epoch in range(5):\n",
    "    for i in range(n_samples):\n",
    "        x = X[i]\n",
    "        y = w_oja @ x\n",
    "        w_oja = w_oja + eta * (y * x - y**2 * w_oja)\n",
    "        if i % 50 == 0:\n",
    "            oja_norms.append(np.linalg.norm(w_oja))\n",
    "            cos_a = np.clip(np.abs(w_oja @ pc1_exact) / np.linalg.norm(w_oja), -1, 1)\n",
    "            oja_angles.append(np.degrees(np.arccos(cos_a)))\n",
    "\n",
    "w_oja_final = w_oja / np.linalg.norm(w_oja)\n",
    "\n",
    "# --- Projection comparison ---\n",
    "proj_exact = X @ pc1_exact\n",
    "proj_oja = X @ w_oja_final\n",
    "\n",
    "# Align sign (PC direction is arbitrary up to sign)\n",
    "if np.corrcoef(proj_exact, proj_oja)[0, 1] < 0:\n",
    "    proj_oja = -proj_oja\n",
    "    w_oja_final = -w_oja_final\n",
    "\n",
    "# Visualization\n",
    "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
    "\n",
    "# Panel 1: Component comparison\n",
    "x_pos = np.arange(n_dim)\n",
    "width = 0.35\n",
    "axes[0].bar(x_pos - width/2, pc1_exact, width, label='Exact PCA', color='#E91E63', alpha=0.8)\n",
    "axes[0].bar(x_pos + width/2, w_oja_final, width, label='Oja (online)', color='#2196F3', alpha=0.8)\n",
    "axes[0].set_xlabel('Component index', fontsize=11)\n",
    "axes[0].set_ylabel('Weight value', fontsize=11)\n",
    "axes[0].set_title('PC1 Components: Exact vs Oja', fontsize=12)\n",
    "axes[0].set_xticks(x_pos)\n",
    "axes[0].legend(fontsize=10)\n",
    "axes[0].grid(True, alpha=0.3, axis='y')\n",
    "\n",
    "# Panel 2: Projection scatter\n",
    "axes[1].scatter(proj_exact, proj_oja, alpha=0.1, s=3, color='steelblue')\n",
    "lim = max(np.abs(proj_exact).max(), np.abs(proj_oja).max()) * 1.1\n",
    "axes[1].plot([-lim, lim], [-lim, lim], 'r--', linewidth=1, label='y = x')\n",
    "axes[1].set_xlabel('Exact PCA projection', fontsize=11)\n",
    "axes[1].set_ylabel('Oja projection', fontsize=11)\n",
    "axes[1].set_title(f'Projection Comparison (r = {np.corrcoef(proj_exact, proj_oja)[0,1]:.6f})',\n",
    "                  fontsize=12)\n",
    "axes[1].set_aspect('equal')\n",
    "axes[1].legend(fontsize=10)\n",
    "axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "# Panel 3: Convergence angle over time\n",
    "axes[2].plot(oja_angles, color='#2196F3', linewidth=1.5)\n",
    "axes[2].axhline(y=0, color='red', linestyle='--', alpha=0.5)\n",
    "axes[2].set_xlabel('Iteration (x50)', fontsize=11)\n",
    "axes[2].set_ylabel('Angle to true PC1 (degrees)', fontsize=11)\n",
    "axes[2].set_title('Oja Convergence in 5D', fontsize=12)\n",
    "axes[2].grid(True, alpha=0.3)\n",
    "\n",
    "plt.suptitle('Exact PCA vs Online Oja\\'s Rule on 5D Data', fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(f\"Exact PC1:  {pc1_exact}\")\n",
    "print(f\"Oja PC1:    {w_oja_final}\")\n",
    "print(f\"Cosine similarity: {np.abs(pc1_exact @ w_oja_final):.8f}\")\n",
    "print(f\"Oja final norm: {np.linalg.norm(w_oja):.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-5f",
   "metadata": {},
   "source": [
    "### Weight Norm Stability: Oja vs Pure Hebb\n",
    "\n",
    "This visualization directly compares how the weight norm evolves under Oja's rule\n",
    "(stabilized at $\\|\\mathbf{w}\\| \\approx 1$) versus pure Hebbian learning (diverges\n",
    "exponentially)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-5g",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# Generate 2D correlated data\n",
    "n_samples = 2000\n",
    "angle = np.pi / 4\n",
    "R = np.array([[np.cos(angle), -np.sin(angle)],\n",
    "              [np.sin(angle),  np.cos(angle)]])\n",
    "C_true = R @ np.diag([3.0, 0.5]) @ R.T\n",
    "X = np.random.multivariate_normal([0, 0], C_true, n_samples)\n",
    "\n",
    "eta = 0.001\n",
    "w_init = np.array([0.5, 0.5])\n",
    "w_init = w_init / np.linalg.norm(w_init)\n",
    "\n",
    "# --- Pure Hebb ---\n",
    "w_hebb = w_init.copy()\n",
    "hebb_norms = [np.linalg.norm(w_hebb)]\n",
    "\n",
    "for epoch in range(3):\n",
    "    for i in range(n_samples):\n",
    "        x = X[i]\n",
    "        y = w_hebb @ x\n",
    "        w_hebb = w_hebb + eta * y * x\n",
    "        hebb_norms.append(np.linalg.norm(w_hebb))\n",
    "\n",
    "# --- Oja ---\n",
    "w_oja = w_init.copy()\n",
    "oja_norms_compare = [np.linalg.norm(w_oja)]\n",
    "\n",
    "for epoch in range(3):\n",
    "    for i in range(n_samples):\n",
    "        x = X[i]\n",
    "        y = w_oja @ x\n",
    "        w_oja = w_oja + eta * (y * x - y**2 * w_oja)\n",
    "        oja_norms_compare.append(np.linalg.norm(w_oja))\n",
    "\n",
    "# --- Hebb + explicit normalization ---\n",
    "w_hebb_norm = w_init.copy()\n",
    "hebb_norm_norms = [np.linalg.norm(w_hebb_norm)]\n",
    "\n",
    "for epoch in range(3):\n",
    "    for i in range(n_samples):\n",
    "        x = X[i]\n",
    "        y = w_hebb_norm @ x\n",
    "        w_hebb_norm = w_hebb_norm + eta * y * x\n",
    "        w_hebb_norm = w_hebb_norm / np.linalg.norm(w_hebb_norm)  # explicit normalize\n",
    "        hebb_norm_norms.append(np.linalg.norm(w_hebb_norm))\n",
    "\n",
    "# Visualization\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "iterations = np.arange(len(hebb_norms))\n",
    "\n",
    "# Panel 1: Log scale comparison\n",
    "axes[0].plot(iterations, hebb_norms, color='red', linewidth=2, label='Pure Hebb (diverges)')\n",
    "axes[0].plot(iterations, oja_norms_compare, color='blue', linewidth=2, label=\"Oja's Rule (stable)\")\n",
    "axes[0].plot(iterations, hebb_norm_norms, color='green', linewidth=2,\n",
    "             linestyle='--', label='Hebb + normalize')\n",
    "axes[0].set_yscale('log')\n",
    "axes[0].set_xlabel('Iteration', fontsize=12)\n",
    "axes[0].set_ylabel('$||\\\\mathbf{w}||$ (log scale)', fontsize=12)\n",
    "axes[0].set_title('Weight Norm: Three Learning Rules', fontsize=13)\n",
    "axes[0].legend(fontsize=11)\n",
    "axes[0].grid(True, alpha=0.3)\n",
    "\n",
    "# Panel 2: Zoomed in on Oja near ||w||=1\n",
    "axes[1].plot(iterations, oja_norms_compare, color='blue', linewidth=1.5, label=\"Oja's Rule\")\n",
    "axes[1].axhline(y=1.0, color='red', linestyle='--', linewidth=1.5, label='$||w|| = 1$', alpha=0.7)\n",
    "axes[1].set_xlabel('Iteration', fontsize=12)\n",
    "axes[1].set_ylabel('$||\\\\mathbf{w}||$', fontsize=12)\n",
    "axes[1].set_title(\"Oja's Rule: Norm Stabilizes at 1\", fontsize=13)\n",
    "axes[1].set_ylim(0.85, 1.15)\n",
    "axes[1].legend(fontsize=11)\n",
    "axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "plt.suptitle('Weight Norm Over Time: Why Oja Fixes Hebbian Instability',\n",
    "             fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(f\"Final norms: Pure Hebb = {hebb_norms[-1]:.2f}, Oja = {oja_norms_compare[-1]:.4f}, \"\n",
    "      f\"Hebb+normalize = {hebb_norm_norms[-1]:.4f}\")\n",
    "print(\"\\nKey insight: Oja's rule achieves the same effect as 'Hebb + normalize'\")\n",
    "print(\"but through a single elegant update rule with no explicit normalization step.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-6",
   "metadata": {},
   "source": [
    "## 13.4 Sanger's Generalized Hebbian Algorithm (GHA)\n",
    "\n",
    "Oja's rule extracts only the **first** principal component. To extract multiple components,\n",
    "Sanger (1989) proposed the **Generalized Hebbian Algorithm (GHA)**.\n",
    "\n",
    "### Setup\n",
    "\n",
    "Consider $p$ output neurons with weight vectors $\\mathbf{w}_1, \\ldots, \\mathbf{w}_p$,\n",
    "arranged in a weight matrix $\\mathbf{W} \\in \\mathbb{R}^{p \\times n}$ where row $j$ is\n",
    "$\\mathbf{w}_j^\\top$.\n",
    "\n",
    "Outputs: $y_j = \\mathbf{w}_j^\\top \\mathbf{x}$ for $j = 1, \\ldots, p$.\n",
    "\n",
    "### Sanger's Rule\n",
    "\n",
    "$$\\Delta w_{ji} = \\eta \\left( y_j x_i - y_j \\sum_{k=1}^{j} y_k w_{ki} \\right)$$\n",
    "\n",
    "In matrix form, using $\\text{LT}(\\cdot)$ to denote the lower-triangular part (including diagonal):\n",
    "\n",
    "$$\\Delta \\mathbf{W} = \\eta \\left( \\mathbf{y}\\mathbf{x}^\\top - \\text{LT}(\\mathbf{y}\\mathbf{y}^\\top) \\mathbf{W} \\right)$$\n",
    "\n",
    "### Key Idea\n",
    "\n",
    "- The first neuron ($j=1$) follows Oja's rule exactly: $\\Delta \\mathbf{w}_1 = \\eta(y_1 \\mathbf{x} - y_1^2 \\mathbf{w}_1)$\n",
    "- Each subsequent neuron effectively learns from a **deflated** input, with the projections\n",
    "  onto previously extracted components removed.\n",
    "- Result: $\\mathbf{w}_j \\to \\pm \\mathbf{e}_j$ (the $j$-th principal component).\n",
    "\n",
    "### Convergence\n",
    "\n",
    "Under appropriate conditions on the learning rate, GHA converges to the first $p$ principal\n",
    "components in order: $\\mathbf{w}_1 \\to \\pm\\mathbf{e}_1$, $\\mathbf{w}_2 \\to \\pm\\mathbf{e}_2$, etc."
   ]
  },
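  {
   "cell_type": "markdown",
   "id": "cell-6b",
   "metadata": {},
   "source": [
    "The rule is short to implement. The sketch below (the toy covariance, learning rate, and function name are illustrative choices, not prescribed by Sanger's paper) runs the matrix-form GHA update on synthetic 3D data and checks that the two weight rows align with the top two eigenvectors of the sample covariance; Exercise 13.1 asks for a fuller version of this experiment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(3)\n",
    "\n",
    "# Toy 3D data with well-separated eigenvalues\n",
    "Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))\n",
    "C_true = Q @ np.diag([4.0, 1.5, 0.3]) @ Q.T\n",
    "X = rng.multivariate_normal(np.zeros(3), C_true, size=4000)\n",
    "\n",
    "def gha_step(W, x, eta):\n",
    "    \"\"\"One Sanger/GHA update: dW = eta * (y x^T - LT(y y^T) W), with y = W x.\"\"\"\n",
    "    y = W @ x\n",
    "    return W + eta * (np.outer(y, x) - np.tril(np.outer(y, y)) @ W)\n",
    "\n",
    "p = 2  # extract the first two principal components\n",
    "W = rng.normal(size=(p, 3))\n",
    "W /= np.linalg.norm(W, axis=1, keepdims=True)\n",
    "\n",
    "eta = 0.002\n",
    "for epoch in range(10):\n",
    "    for x in X:\n",
    "        W = gha_step(W, x, eta)\n",
    "\n",
    "# Compare each row with the corresponding eigenvector of the sample covariance\n",
    "evals, evecs = np.linalg.eigh(np.cov(X.T))\n",
    "evecs = evecs[:, np.argsort(-evals)]\n",
    "for j in range(p):\n",
    "    w_dir = W[j] / np.linalg.norm(W[j])\n",
    "    print(f\"neuron {j+1}: |cos| to PC{j+1} = {abs(w_dir @ evecs[:, j]):.4f}\")"
   ]
  },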
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-7",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Side-by-side comparison: Basic Hebb vs Oja's Rule\n",
    "\n",
    "np.random.seed(123)\n",
    "\n",
    "# Generate correlated 2D data\n",
    "n_samples = 2000\n",
    "angle = np.pi / 3\n",
    "R = np.array([[np.cos(angle), -np.sin(angle)],\n",
    "              [np.sin(angle),  np.cos(angle)]])\n",
    "C_true = R @ np.diag([3.0, 0.5]) @ R.T\n",
    "X = np.random.multivariate_normal([0, 0], C_true, n_samples)\n",
    "\n",
    "# True PC1\n",
    "evals, evecs = np.linalg.eigh(np.cov(X.T))\n",
    "pc1 = evecs[:, np.argmax(evals)]\n",
    "\n",
    "eta = 0.001\n",
    "w_init = np.array([0.5, 0.5])\n",
    "w_init = w_init / np.linalg.norm(w_init)\n",
    "\n",
    "# Basic Hebb\n",
    "w_hebb = w_init.copy()\n",
    "hebb_norms = [np.linalg.norm(w_hebb)]\n",
    "hebb_angles = []\n",
    "cos_a = np.clip(np.abs(w_hebb @ pc1) / np.linalg.norm(w_hebb), -1, 1)\n",
    "hebb_angles.append(np.degrees(np.arccos(cos_a)))\n",
    "\n",
    "for epoch in range(3):\n",
    "    for i in range(n_samples):\n",
    "        x = X[i]\n",
    "        y = w_hebb @ x\n",
    "        w_hebb = w_hebb + eta * y * x\n",
    "        hebb_norms.append(np.linalg.norm(w_hebb))\n",
    "        cos_a = np.clip(np.abs(w_hebb @ pc1) / np.linalg.norm(w_hebb), -1, 1)\n",
    "        hebb_angles.append(np.degrees(np.arccos(cos_a)))\n",
    "\n",
    "# Oja\n",
    "w_oja = w_init.copy()\n",
    "oja_norms2 = [np.linalg.norm(w_oja)]\n",
    "oja_angles2 = []\n",
    "cos_a = np.clip(np.abs(w_oja @ pc1) / np.linalg.norm(w_oja), -1, 1)\n",
    "oja_angles2.append(np.degrees(np.arccos(cos_a)))\n",
    "\n",
    "for epoch in range(3):\n",
    "    for i in range(n_samples):\n",
    "        x = X[i]\n",
    "        y = w_oja @ x\n",
    "        w_oja = w_oja + eta * (y * x - y**2 * w_oja)\n",
    "        oja_norms2.append(np.linalg.norm(w_oja))\n",
    "        cos_a = np.clip(np.abs(w_oja @ pc1) / np.linalg.norm(w_oja), -1, 1)\n",
    "        oja_angles2.append(np.degrees(np.arccos(cos_a)))\n",
    "\n",
    "# Plot comparison\n",
    "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
    "\n",
    "# Norm comparison\n",
    "axes[0].plot(hebb_norms, label='Basic Hebb', color='red', alpha=0.8)\n",
    "axes[0].plot(oja_norms2, label=\"Oja's Rule\", color='blue', alpha=0.8)\n",
    "axes[0].set_xlabel('Iteration')\n",
    "axes[0].set_ylabel('||w||')\n",
    "axes[0].set_title('Weight Norm Comparison')\n",
    "axes[0].set_yscale('log')\n",
    "axes[0].legend()\n",
    "axes[0].grid(True, alpha=0.3)\n",
    "\n",
    "# Data with both learned directions\n",
    "axes[1].scatter(X[:, 0], X[:, 1], alpha=0.15, s=5, color='gray')\n",
    "s = 4\n",
    "# Hebb direction (normalized)\n",
    "w_h_dir = w_hebb / np.linalg.norm(w_hebb)\n",
    "axes[1].annotate('', xy=w_h_dir*s, xytext=[0,0],\n",
    "                 arrowprops=dict(arrowstyle='->', color='red', lw=2.5))\n",
    "# Oja direction\n",
    "w_o_dir = w_oja / np.linalg.norm(w_oja)\n",
    "axes[1].annotate('', xy=w_o_dir*s, xytext=[0,0],\n",
    "                 arrowprops=dict(arrowstyle='->', color='blue', lw=2.5))\n",
    "# True PC1\n",
    "axes[1].annotate('', xy=pc1*s, xytext=[0,0],\n",
    "                 arrowprops=dict(arrowstyle='->', color='green', lw=2.5, linestyle='--'))\n",
    "axes[1].set_xlabel('$x_1$')\n",
    "axes[1].set_ylabel('$x_2$')\n",
    "axes[1].set_title('Learned Directions')\n",
    "axes[1].legend(['Hebb', 'Oja', 'True PC1'], loc='upper left')\n",
    "axes[1].set_aspect('equal')\n",
    "axes[1].set_xlim(-6, 6)\n",
    "axes[1].set_ylim(-6, 6)\n",
    "axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "# Angle comparison\n",
    "axes[2].plot(hebb_angles, label='Basic Hebb', color='red', alpha=0.8)\n",
    "axes[2].plot(oja_angles2, label=\"Oja's Rule\", color='blue', alpha=0.8)\n",
    "axes[2].axhline(y=0, color='gray', linestyle='--', alpha=0.5)\n",
    "axes[2].set_xlabel('Iteration')\n",
    "axes[2].set_ylabel('Angle to PC1 (degrees)')\n",
    "axes[2].set_title('Direction Convergence')\n",
    "axes[2].legend()\n",
    "axes[2].grid(True, alpha=0.3)\n",
    "\n",
    "plt.suptitle('Basic Hebb vs Oja: Both converge in direction, only Oja is stable in norm',\n",
    "             fontsize=13, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.savefig('hebb_vs_oja.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-8",
   "metadata": {},
   "source": "## Exercises\n\n**Exercise 13.1.** Implement Sanger's GHA to extract the first 3 principal components of a\n5-dimensional dataset. Compare the extracted components with those from `numpy.linalg.eigh`.\n\n**Exercise 13.2.** Apply Oja's rule to the Iris dataset (4 features). Compare the learned\ndirection with the output of `sklearn.decomposition.PCA`. Visualize the data projected onto\nthe leading eigenvector of $\\mathbf{C}$ (for centered data, the first principal component).\n\n**Exercise 13.3.** Investigate the effect of the learning rate $\\eta$ on Oja's rule.\nFor various values of $\\eta$, plot (a) the convergence speed (iterations to reach angle < 1 degree)\nand (b) whether the norm remains stable. What happens for very large $\\eta$?\n\n**Exercise 13.4.** Prove that the equilibrium points of Oja's rule on the unit sphere are exactly\nthe eigenvectors of $\\mathbf{C}$. (Hint: Set $d\\mathbf{w}/dt = 0$ with the constraint\n$\\|\\mathbf{w}\\| = 1$.)\n\n**Exercise 13.5.** Implement a **mini-batch** version of Oja's rule where the weight update\naverages over a batch of $B$ samples. Compare convergence for $B = 1, 10, 50, 200$."
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}