{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-0",
   "metadata": {},
   "source": [
    "# Chapter 14: The Zoo of Learning Rules\n",
    "\n",
    "\n",
    "In the previous chapters, we studied the basic Hebbian rule (Chapter 12) and Oja's\n",
    "stabilized variant (Chapter 13). In this chapter, we survey the broader landscape of\n",
    "biologically-inspired learning rules, with particular attention to the **BCM rule** (Bienenstock,\n",
    "Cooper & Munro, 1982). We then identify the fundamental limitations of Hebbian-family rules\n",
    "and motivate the transition to supervised learning via backpropagation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-1",
   "metadata": {},
   "source": [
    "## 14.1 The BCM Rule\n",
    "\n",
    "### Motivation\n",
    "\n",
    "The BCM (Bienenstock-Cooper-Munro) theory was proposed in 1982 to explain the development\n",
    "of orientation selectivity in visual cortex neurons. Its key innovation is a **sliding\n",
    "threshold** that provides homeostatic stability.\n",
    "\n",
    "```{note}\n",
    "**Historical note** -- BCM theory (1982) predated experimental confirmation by approximately 15 years. The sliding threshold mechanism was a theoretical prediction that was later validated by experiments on synaptic plasticity in visual cortex (Kirkwood, Rioult & Bear, 1996) and hippocampus. This is a remarkable case of theory leading experiment in computational neuroscience.\n",
    "```\n",
    "\n",
    "### Formulation\n",
    "\n",
    "```{admonition} Definition (BCM Rule -- Bienenstock, Cooper, Munro, 1982)\n",
    ":class: note\n",
    "\n",
    "The BCM rule for a single neuron with output $y = \\mathbf{w}^\\top \\mathbf{x}$:\n",
    "\n",
    "$$\\frac{d\\mathbf{w}}{dt} = \\eta \\, \\mathbf{x} \\, y \\, (y - \\theta_M)$$\n",
    "\n",
    "where $\\theta_M$ is the **modification threshold**, defined as a function of recent postsynaptic\n",
    "activity:\n",
    "\n",
    "$$\\theta_M = \\langle y^2 \\rangle$$\n",
    "\n",
    "Here $\\langle \\cdot \\rangle$ denotes a temporal running average.\n",
    "```\n",
    "\n",
    "### Interpretation\n",
    "\n",
    "The term $y(y - \\theta_M)$ creates three regimes:\n",
    "\n",
    "- **$y > \\theta_M$**: Strong postsynaptic activity. The update is positive (LTP).\n",
    "  Active synapses are strengthened.\n",
    "- **$0 < y < \\theta_M$**: Weak postsynaptic activity. The update is negative (LTD).\n",
    "  Weakly active synapses are depressed.\n",
    "- **$y < 0$**: Negative activity (if allowed). The update is positive for $y < 0$ and negative\n",
    "  input, mimicking anti-Hebbian behavior.\n",
    "\n",
    "```{tip}\n",
    "BCM's sliding threshold $\\theta_M$ prevents both runaway potentiation and complete depression. When the neuron is too active, $\\theta_M$ rises, making potentiation harder; when the neuron is too quiet, $\\theta_M$ falls, making potentiation easier. This elegant negative feedback loop ensures long-term stability without any explicit weight normalization.\n",
    "```\n",
    "\n",
    "### Stability Analysis\n",
    "\n",
    "The sliding threshold provides a natural homeostatic mechanism:\n",
    "\n",
    "1. If the neuron becomes **too active** (large $\\langle y^2 \\rangle$), the threshold $\\theta_M$\n",
    "   increases, making it harder for synapses to be potentiated and easier for them to be\n",
    "   depressed. This reduces overall activity.\n",
    "\n",
    "2. If the neuron becomes **too quiet** (small $\\langle y^2 \\rangle$), the threshold $\\theta_M$\n",
    "   decreases, making potentiation easier and depression harder. This increases activity.\n",
    "\n",
    "This creates a **negative feedback loop** that stabilizes the neuron's firing rate.\n",
    "\n",
    "**Formally**: The BCM rule has stable fixed points where the weight vector selects for\n",
    "specific input patterns (orientation selectivity). The fixed points satisfy:\n",
    "\n",
    "$$\\mathbb{E}[\\mathbf{x} \\, y \\, (y - \\theta_M)] = 0$$\n",
    "\n",
    "with $\\theta_M = \\mathbb{E}[y^2]$. Cooper, Intrator & others showed that these fixed points\n",
    "are stable and correspond to directions that maximize a \"selectivity\" objective."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-2",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# BCM Rule Implementation\n",
    "# Generate oriented input patterns (like visual cortex stimuli)\n",
    "\n",
    "n_inputs = 10\n",
    "n_patterns = 4\n",
    "n_samples = 20000\n",
    "\n",
    "# Create oriented patterns\n",
    "patterns = []\n",
    "for k in range(n_patterns):\n",
    "    theta = k * np.pi / n_patterns\n",
    "    p = np.zeros(n_inputs)\n",
    "    center = n_inputs // 2\n",
    "    for i in range(n_inputs):\n",
    "        p[i] = np.exp(-0.5 * ((i - center) * np.cos(theta))**2 / 2.0)\n",
    "    p = p / np.linalg.norm(p)\n",
    "    patterns.append(p)\n",
    "\n",
    "patterns = np.array(patterns)\n",
    "\n",
    "# BCM learning\n",
    "eta = 0.01\n",
    "tau_theta = 100  # time constant for threshold averaging\n",
    "w = np.random.randn(n_inputs) * 0.1\n",
    "theta_M = 0.1  # initial threshold\n",
    "\n",
    "# Track history\n",
    "w_norms = [np.linalg.norm(w)]\n",
    "theta_history = [theta_M]\n",
    "y_history = []\n",
    "selectivity_history = []\n",
    "\n",
    "for t in range(n_samples):\n",
    "    # Randomly select a pattern\n",
    "    idx = np.random.randint(n_patterns)\n",
    "    x = patterns[idx] + np.random.randn(n_inputs) * 0.05  # add noise\n",
    "    \n",
    "    # Output\n",
    "    y = w @ x\n",
    "    y_history.append(y)\n",
    "    \n",
    "    # BCM update\n",
    "    dw = eta * x * y * (y - theta_M)\n",
    "    w = w + dw\n",
    "    \n",
    "    # Update sliding threshold (exponential moving average of y^2)\n",
    "    theta_M = theta_M + (1.0 / tau_theta) * (y**2 - theta_M)\n",
    "    \n",
    "    w_norms.append(np.linalg.norm(w))\n",
    "    theta_history.append(theta_M)\n",
    "    \n",
    "    # Measure selectivity: response to each pattern\n",
    "    if t % 100 == 0:\n",
    "        responses = [w @ p for p in patterns]\n",
    "        selectivity_history.append(responses)\n",
    "\n",
    "selectivity_history = np.array(selectivity_history)\n",
    "\n",
    "# Plot results\n",
    "fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n",
    "\n",
    "# Weight norm\n",
    "axes[0, 0].plot(w_norms)\n",
    "axes[0, 0].set_xlabel('Iteration')\n",
    "axes[0, 0].set_ylabel('||w||')\n",
    "axes[0, 0].set_title('Weight Norm (BCM: Stable!)')\n",
    "axes[0, 0].grid(True, alpha=0.3)\n",
    "\n",
    "# Sliding threshold\n",
    "axes[0, 1].plot(theta_history, color='orange')\n",
    "axes[0, 1].set_xlabel('Iteration')\n",
    "axes[0, 1].set_ylabel(r'$\\theta_M$')\n",
    "axes[0, 1].set_title('Sliding Threshold $\\\\theta_M = \\\\langle y^2 \\\\rangle$')\n",
    "axes[0, 1].grid(True, alpha=0.3)\n",
    "\n",
    "# Selectivity development\n",
    "for k in range(n_patterns):\n",
    "    axes[1, 0].plot(selectivity_history[:, k], label=f'Pattern {k+1}')\n",
    "axes[1, 0].set_xlabel('Time (x100 iterations)')\n",
    "axes[1, 0].set_ylabel('Response')\n",
    "axes[1, 0].set_title('Orientation Selectivity Development')\n",
    "axes[1, 0].legend()\n",
    "axes[1, 0].grid(True, alpha=0.3)\n",
    "\n",
    "# Final weight vector\n",
    "axes[1, 1].bar(range(n_inputs), w, color='steelblue')\n",
    "axes[1, 1].set_xlabel('Input index')\n",
    "axes[1, 1].set_ylabel('Weight')\n",
    "axes[1, 1].set_title('Final Weight Vector (Receptive Field)')\n",
    "axes[1, 1].grid(True, alpha=0.3)\n",
    "\n",
    "plt.suptitle('BCM Rule: Homeostatic Hebbian Learning', fontsize=14, fontweight='bold')\n",
    "plt.tight_layout()\n",
    "plt.savefig('bcm_learning.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(\"\\nKey observation: BCM develops selectivity for one pattern.\")\n",
    "print(\"The sliding threshold ensures stability without explicit normalization.\")"
   ]
  },
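  {
   "cell_type": "markdown",
   "id": "cell-2a",
   "metadata": {},
   "source": [
    "### Checking the Fixed-Point Condition\n",
    "\n",
    "As a quick numerical sanity check of the stability analysis above, the cell below reuses the\n",
    "learned `w`, `patterns`, and `theta_M` from the simulation and estimates\n",
    "$\\mathbb{E}[\\mathbf{x} \\, y \\, (y - \\theta_M)]$ over freshly sampled noisy patterns. If the weights\n",
    "have settled near a BCM fixed point, this average update should be close to zero and the running\n",
    "threshold close to $\\mathbb{E}[y^2]$. This is a rough empirical check under the same input\n",
    "statistics used for training, not a proof of stability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-2a-code",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Empirical check of the BCM fixed-point condition E[x y (y - theta_M)] ~ 0.\n",
    "# Reuses w, patterns, theta_M, n_inputs, n_patterns from the previous cell.\n",
    "n_check = 5000\n",
    "avg_update = np.zeros(n_inputs)\n",
    "avg_y2 = 0.0\n",
    "\n",
    "for _ in range(n_check):\n",
    "    idx = np.random.randint(n_patterns)\n",
    "    x = patterns[idx] + np.random.randn(n_inputs) * 0.05  # same input statistics as training\n",
    "    y = w @ x\n",
    "    avg_update += x * y * (y - theta_M)\n",
    "    avg_y2 += y**2\n",
    "\n",
    "avg_update /= n_check\n",
    "avg_y2 /= n_check\n",
    "\n",
    "print(\"||E[x y (y - theta_M)]|| =\", np.round(np.linalg.norm(avg_update), 4))\n",
    "print(\"E[y^2] =\", np.round(avg_y2, 4), \"  vs  theta_M =\", np.round(theta_M, 4))"
   ]
  },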
  {
   "cell_type": "markdown",
   "id": "cell-2b",
   "metadata": {},
   "source": [
    "### BCM Selectivity Curve\n",
    "\n",
    "The BCM rule's behavior is governed by the **modification function** $\\phi(y, \\theta_M) = y(y - \\theta_M)$.\n",
    "This function determines whether a given level of postsynaptic activity leads to potentiation\n",
    "or depression. The sliding threshold $\\theta_M$ shifts this curve dynamically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-2c",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# BCM selectivity curve: the phi function with sliding threshold\n",
    "\n",
    "y_vals = np.linspace(-1, 4, 500)\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n",
    "\n",
    "# Panel 1: phi(y) = y(y - theta_M) for different theta_M values\n",
    "theta_M_values = [0.5, 1.0, 2.0, 3.0]\n",
    "colors = ['#2196F3', '#4CAF50', '#FF9800', '#E91E63']\n",
    "\n",
    "for theta_M, color in zip(theta_M_values, colors):\n",
    "    phi = y_vals * (y_vals - theta_M)\n",
    "    axes[0].plot(y_vals, phi, color=color, linewidth=2,\n",
    "                label=f'$\\\\theta_M = {theta_M}$')\n",
    "    # Mark the threshold crossing\n",
    "    axes[0].plot(theta_M, 0, 'o', color=color, markersize=8)\n",
    "\n",
    "axes[0].axhline(y=0, color='gray', linewidth=0.5)\n",
    "axes[0].axvline(x=0, color='gray', linewidth=0.5)\n",
    "axes[0].fill_between(y_vals, 0, 0.1, where=(y_vals > 0) & (y_vals < 1.0),\n",
    "                     alpha=0.1, color='red', label='LTD zone ($\\\\theta_M=1$)')\n",
    "axes[0].fill_between(y_vals, 0, 0.1, where=(y_vals > 1.0),\n",
    "                     alpha=0.1, color='green', label='LTP zone ($\\\\theta_M=1$)')\n",
    "axes[0].set_xlabel('Postsynaptic activity $y$', fontsize=11)\n",
    "axes[0].set_ylabel('$\\\\phi(y) = y(y - \\\\theta_M)$', fontsize=11)\n",
    "axes[0].set_title('BCM Modification Function', fontsize=12)\n",
    "axes[0].legend(fontsize=9)\n",
    "axes[0].set_xlim(-1, 4)\n",
    "axes[0].set_ylim(-2, 6)\n",
    "axes[0].grid(True, alpha=0.3)\n",
    "\n",
    "# Panel 2: Sliding threshold dynamics\n",
    "# Simulate theta_M adaptation for different mean activity levels\n",
    "np.random.seed(42)\n",
    "n_steps = 2000\n",
    "tau = 100\n",
    "\n",
    "activity_levels = [0.5, 1.0, 2.0]\n",
    "activity_colors = ['#2196F3', '#4CAF50', '#E91E63']\n",
    "\n",
    "for mu, color in zip(activity_levels, activity_colors):\n",
    "    theta_M = 0.5  # initial\n",
    "    theta_history = [theta_M]\n",
    "    for t in range(n_steps):\n",
    "        y = np.random.exponential(mu)  # random activity\n",
    "        theta_M = theta_M + (1/tau) * (y**2 - theta_M)\n",
    "        theta_history.append(theta_M)\n",
    "    axes[1].plot(theta_history, color=color, linewidth=1.5,\n",
    "                label=f'Mean activity = {mu}')\n",
    "    axes[1].axhline(y=mu**2 + mu**2, color=color, linestyle='--', alpha=0.4)\n",
    "\n",
    "axes[1].set_xlabel('Time step', fontsize=11)\n",
    "axes[1].set_ylabel('$\\\\theta_M$', fontsize=11)\n",
    "axes[1].set_title('Sliding Threshold Adaptation', fontsize=12)\n",
    "axes[1].legend(fontsize=10)\n",
    "axes[1].grid(True, alpha=0.3)\n",
    "\n",
    "# Panel 3: Selectivity -- response to preferred vs non-preferred stimuli\n",
    "y_preferred = np.linspace(0, 4, 200)\n",
    "theta_M_fixed = 1.5\n",
    "phi_preferred = y_preferred * (y_preferred - theta_M_fixed)\n",
    "\n",
    "axes[2].fill_between(y_preferred, 0, phi_preferred,\n",
    "                     where=phi_preferred > 0, alpha=0.3, color='green')\n",
    "axes[2].fill_between(y_preferred, 0, phi_preferred,\n",
    "                     where=phi_preferred < 0, alpha=0.3, color='red')\n",
    "axes[2].plot(y_preferred, phi_preferred, 'k-', linewidth=2)\n",
    "axes[2].axhline(y=0, color='gray', linewidth=0.5)\n",
    "axes[2].axvline(x=theta_M_fixed, color='orange', linewidth=2, linestyle='--',\n",
    "               label=f'$\\\\theta_M = {theta_M_fixed}$')\n",
    "\n",
    "# Annotate\n",
    "axes[2].annotate('Depression\\n(weak stimuli)', xy=(0.7, -0.3),\n",
    "                fontsize=10, color='red', ha='center', fontweight='bold')\n",
    "axes[2].annotate('Potentiation\\n(strong stimuli)', xy=(2.8, 2.5),\n",
    "                fontsize=10, color='green', ha='center', fontweight='bold')\n",
    "\n",
    "axes[2].set_xlabel('Response to stimulus $y$', fontsize=11)\n",
    "axes[2].set_ylabel('Weight change $\\\\phi(y)$', fontsize=11)\n",
    "axes[2].set_title(f'Selectivity: Only Strong Responses\\nAre Reinforced ($\\\\theta_M={theta_M_fixed}$)',\n",
    "                  fontsize=12)\n",
    "axes[2].legend(fontsize=11)\n",
    "axes[2].grid(True, alpha=0.3)\n",
    "\n",
    "plt.suptitle('BCM Rule: The Sliding Threshold Creates Selectivity',\n",
    "             fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"Key insight: The BCM phi function creates a natural threshold between\")\n",
    "print(\"potentiation and depression. Only stimuli that drive the neuron above\")\n",
    "print(\"theta_M are reinforced; weaker stimuli are actively suppressed.\")\n",
    "print(\"This leads to selectivity for a preferred stimulus pattern.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-2d",
   "metadata": {},
   "source": [
    "## 14.1b Spike-Timing-Dependent Plasticity (STDP)\n",
    "\n",
    "```{admonition} Definition (STDP)\n",
    ":class: note\n",
    "\n",
    "**Spike-Timing-Dependent Plasticity (STDP)** is a biologically observed learning rule where the\n",
    "sign and magnitude of synaptic modification depend on the precise timing between pre- and\n",
    "postsynaptic spikes:\n",
    "\n",
    "$$\\Delta w = \\begin{cases}\n",
    "A_+ \\exp\\left(-\\dfrac{\\Delta t}{\\tau_+}\\right) & \\text{if } \\Delta t > 0 \\text{ (pre before post: LTP)} \\\\\n",
    "-A_- \\exp\\left(\\dfrac{\\Delta t}{\\tau_-}\\right) & \\text{if } \\Delta t < 0 \\text{ (post before pre: LTD)}\n",
    "\\end{cases}$$\n",
    "\n",
    "where $\\Delta t = t_{\\text{post}} - t_{\\text{pre}}$, $A_+, A_-$ are amplitude parameters,\n",
    "and $\\tau_+, \\tau_-$ are time constants (typically $\\sim 20$ ms).\n",
    "\n",
    "STDP refines Hebb's postulate by incorporating **temporal causality**: only synapses where\n",
    "presynaptic activity *precedes* postsynaptic firing are strengthened.\n",
    "```\n",
    "\n",
    "```{warning}\n",
    "**Biological plausibility vs mathematical tractability** -- a fundamental tension in computational neuroscience. STDP and BCM are biologically realistic but mathematically complex; Oja's rule and the Perceptron rule are mathematically clean but biologically implausible. No single learning rule currently bridges this gap satisfactorily. This tension drives much of the ongoing research in theoretical neuroscience.\n",
    "```"
   ]
  },
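  {
   "cell_type": "markdown",
   "id": "cell-2e",
   "metadata": {},
   "source": [
    "### Visualizing the STDP Window\n",
    "\n",
    "The exponential window defined above is easy to visualize. The sketch below plots $\\Delta w$ as a\n",
    "function of the spike-timing difference $\\Delta t = t_{\\text{post}} - t_{\\text{pre}}$. The parameter\n",
    "values ($A_+ = 1.0$, $A_- = 0.5$, $\\tau_+ = \\tau_- = 20$ ms) are illustrative choices for display,\n",
    "not values fitted to any particular experiment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-2f",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# STDP window: Delta_w as a function of Delta_t = t_post - t_pre.\n",
    "# Parameter values are illustrative assumptions, not experimental fits.\n",
    "A_plus, A_minus = 1.0, 0.5\n",
    "tau_plus, tau_minus = 20.0, 20.0  # ms\n",
    "\n",
    "dt = np.linspace(-100, 100, 1000)  # ms\n",
    "dw = np.where(dt > 0,\n",
    "              A_plus * np.exp(-dt / tau_plus),    # pre before post: LTP\n",
    "              -A_minus * np.exp(dt / tau_minus))  # post before pre: LTD\n",
    "\n",
    "plt.figure(figsize=(7, 4))\n",
    "plt.plot(dt, dw, 'k-', linewidth=2)\n",
    "plt.axhline(0, color='gray', linewidth=0.5)\n",
    "plt.axvline(0, color='gray', linewidth=0.5)\n",
    "plt.fill_between(dt, 0, dw, where=dw > 0, alpha=0.2, color='green')\n",
    "plt.fill_between(dt, 0, dw, where=dw < 0, alpha=0.2, color='red')\n",
    "plt.xlabel(r'$\\Delta t = t_{post} - t_{pre}$ (ms)')\n",
    "plt.ylabel(r'$\\Delta w$')\n",
    "plt.title('STDP Window: LTP for pre-before-post, LTD for post-before-pre')\n",
    "plt.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },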
  {
   "cell_type": "markdown",
   "id": "cell-3",
   "metadata": {},
   "source": [
    "## 14.2 Comparison Table: The Zoo of Hebbian Learning Rules\n",
    "\n",
    "| Rule | Formula | Stable? | Bio. Plaus. | Key Property |\n",
    "|------|---------|---------|------------|---------------|\n",
    "| **Basic Hebb** | $\\Delta w_i = \\eta x_i y$ | No (diverges) | Moderate | Simplest correlation rule |\n",
    "| **Covariance** | $\\Delta w_i = \\eta(x_i - \\bar{x}_i)(y - \\bar{y})$ | No (diverges) | Moderate | Centered; allows LTD |\n",
    "| **Oja** | $\\Delta w_i = \\eta(y x_i - y^2 w_i)$ | Yes ($\\|w\\| \\to 1$) | Low | Extracts PC1 |\n",
    "| **Sanger (GHA)** | $\\Delta w_{ji} = \\eta(y_j x_i - y_j \\sum_{k \\leq j} y_k w_{ki})$ | Yes | Low | Extracts top $p$ PCs |\n",
    "| **BCM** | $\\Delta w_i = \\eta x_i y(y - \\theta_M)$ | Yes (via $\\theta_M$) | High | Selectivity; homeostasis |\n",
    "\n",
    "### Key Observations\n",
    "\n",
    "1. **Stability requires modification**: The basic Hebbian rule is unstable. Every useful\n",
    "   variant adds some form of normalization or threshold.\n",
    "\n",
    "2. **Biological plausibility vs. mathematical elegance**: Oja and Sanger are mathematically\n",
    "   clean but biologically implausible (they require access to $y^2$, which is non-local in\n",
    "   a biological sense). BCM is more biologically motivated.\n",
    "\n",
    "3. **All are unsupervised**: None of these rules use a target signal. They all extract\n",
    "   statistical structure from the input distribution."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-3b",
   "metadata": {},
   "source": [
    "### Comprehensive Comparison Table (Visualization)\n",
    "\n",
    "The following code creates a detailed matplotlib comparison table of all major learning\n",
    "rules covered in this part of the course."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3c",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Comprehensive comparison table of ALL learning rules\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(14, 6))\n",
    "ax.axis('off')\n",
    "\n",
    "# Table data\n",
    "columns = ['Rule', 'Formula', 'Supervised?', 'Stable?', 'Biological?', 'Key Property']\n",
    "rows = [\n",
    "    ['Hebb',       r'$\\Delta w = \\eta \\, x \\, y$',\n",
    "     'No',  'No',  'Moderate',  'Simplest correlation'],\n",
    "    ['Oja',        r'$\\Delta w = \\eta(xy - y^2 w)$',\n",
    "     'No',  'Yes', 'Low',       'Extracts PC1'],\n",
    "    ['BCM',        r'$\\Delta w = \\eta \\, x \\, y(y-\\theta_M)$',\n",
    "     'No',  'Yes', 'High',      'Selective responses'],\n",
    "    ['Perceptron', r'$\\Delta w = \\eta(t - y) \\, x$',\n",
    "     'Yes', 'Yes', 'Low',       'Linear classification'],\n",
    "    ['STDP',       r'$\\Delta w = f(\\Delta t)$',\n",
    "     'No',  'Conditional', 'Very High', 'Temporal causality'],\n",
    "]\n",
    "\n",
    "# Create the table\n",
    "table = ax.table(\n",
    "    cellText=rows,\n",
    "    colLabels=columns,\n",
    "    cellLoc='center',\n",
    "    loc='center',\n",
    "    colWidths=[0.1, 0.25, 0.1, 0.1, 0.1, 0.2]\n",
    ")\n",
    "\n",
    "# Style the table\n",
    "table.auto_set_font_size(False)\n",
    "table.set_fontsize(11)\n",
    "table.scale(1.0, 2.2)\n",
    "\n",
    "# Header styling\n",
    "for j in range(len(columns)):\n",
    "    cell = table[0, j]\n",
    "    cell.set_facecolor('#1565C0')\n",
    "    cell.set_text_props(color='white', fontweight='bold', fontsize=11)\n",
    "    cell.set_edgecolor('white')\n",
    "\n",
    "# Row coloring and special formatting\n",
    "stability_colors = {\n",
    "    'No': '#FFCDD2',       # light red\n",
    "    'Yes': '#C8E6C9',      # light green\n",
    "    'Conditional': '#FFF9C4'  # light yellow\n",
    "}\n",
    "\n",
    "bio_colors = {\n",
    "    'Low': '#FFCDD2',\n",
    "    'Moderate': '#FFF9C4',\n",
    "    'High': '#C8E6C9',\n",
    "    'Very High': '#81C784'\n",
    "}\n",
    "\n",
    "for i in range(len(rows)):\n",
    "    # Alternate row background\n",
    "    bg_color = '#F5F5F5' if i % 2 == 0 else '#FFFFFF'\n",
    "    for j in range(len(columns)):\n",
    "        cell = table[i + 1, j]\n",
    "        cell.set_facecolor(bg_color)\n",
    "        cell.set_edgecolor('#E0E0E0')\n",
    "    \n",
    "    # Color the Stable? column\n",
    "    stable_val = rows[i][3]\n",
    "    if stable_val in stability_colors:\n",
    "        table[i + 1, 3].set_facecolor(stability_colors[stable_val])\n",
    "    \n",
    "    # Color the Biological? column\n",
    "    bio_val = rows[i][4]\n",
    "    if bio_val in bio_colors:\n",
    "        table[i + 1, 4].set_facecolor(bio_colors[bio_val])\n",
    "    \n",
    "    # Color the Supervised? column\n",
    "    sup_val = rows[i][2]\n",
    "    if sup_val == 'Yes':\n",
    "        table[i + 1, 2].set_facecolor('#BBDEFB')  # light blue for supervised\n",
    "\n",
    "ax.set_title('Comprehensive Comparison of Neural Learning Rules',\n",
    "             fontsize=15, fontweight='bold', pad=20)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"Color coding:\")\n",
    "print(\"  Stable? column: Green = Yes, Red = No, Yellow = Conditional\")\n",
    "print(\"  Biological? column: Dark green = Very High, Light green = High,\")\n",
    "print(\"                      Yellow = Moderate, Red = Low\")\n",
    "print(\"  Supervised? column: Blue = Yes (supervised)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-3d",
   "metadata": {},
   "source": [
    "### Learning Rule Dynamics Comparison\n",
    "\n",
    "The following 4-panel plot shows the weight evolution under each learning rule when\n",
    "presented with the same input data, providing a direct visual comparison of their\n",
    "stability and convergence properties."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3e",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# Learning rule dynamics comparison: same input, different rules\n",
    "\n",
    "# Generate 2D correlated data\n",
    "n_samples = 2000\n",
    "angle = np.pi / 4\n",
    "R = np.array([[np.cos(angle), -np.sin(angle)],\n",
    "              [np.sin(angle),  np.cos(angle)]])\n",
    "C_true = R @ np.diag([3.0, 0.5]) @ R.T\n",
    "X = np.random.multivariate_normal([0, 0], C_true, n_samples)\n",
    "\n",
    "# True PC1\n",
    "evals, evecs = np.linalg.eigh(np.cov(X.T))\n",
    "pc1 = evecs[:, np.argmax(evals)]\n",
    "\n",
    "eta = 0.001\n",
    "n_iters = 3000  # use subset of data\n",
    "w_init = np.array([0.3, 0.7])\n",
    "w_init = w_init / np.linalg.norm(w_init)\n",
    "\n",
    "# ---- Rule 1: Basic Hebb ----\n",
    "w = w_init.copy()\n",
    "hebb_w1 = [w[0]]\n",
    "hebb_w2 = [w[1]]\n",
    "hebb_norms = [np.linalg.norm(w)]\n",
    "\n",
    "for t in range(n_iters):\n",
    "    x = X[t % n_samples]\n",
    "    y = w @ x\n",
    "    w = w + eta * y * x\n",
    "    hebb_w1.append(w[0])\n",
    "    hebb_w2.append(w[1])\n",
    "    hebb_norms.append(np.linalg.norm(w))\n",
    "\n",
    "# ---- Rule 2: Oja ----\n",
    "w = w_init.copy()\n",
    "oja_w1 = [w[0]]\n",
    "oja_w2 = [w[1]]\n",
    "oja_norms = [np.linalg.norm(w)]\n",
    "\n",
    "for t in range(n_iters):\n",
    "    x = X[t % n_samples]\n",
    "    y = w @ x\n",
    "    w = w + eta * (y * x - y**2 * w)\n",
    "    oja_w1.append(w[0])\n",
    "    oja_w2.append(w[1])\n",
    "    oja_norms.append(np.linalg.norm(w))\n",
    "\n",
    "# ---- Rule 3: BCM ----\n",
    "w = w_init.copy()\n",
    "bcm_w1 = [w[0]]\n",
    "bcm_w2 = [w[1]]\n",
    "bcm_norms = [np.linalg.norm(w)]\n",
    "theta_M = 0.1\n",
    "\n",
    "for t in range(n_iters):\n",
    "    x = X[t % n_samples]\n",
    "    y = w @ x\n",
    "    w = w + eta * x * y * (y - theta_M)\n",
    "    theta_M = theta_M + 0.01 * (y**2 - theta_M)\n",
    "    bcm_w1.append(w[0])\n",
    "    bcm_w2.append(w[1])\n",
    "    bcm_norms.append(np.linalg.norm(w))\n",
    "\n",
    "# ---- Rule 4: Perceptron (supervised, using sign of projection as target) ----\n",
    "# Create a simple binary classification target based on PC1\n",
    "targets = (X @ pc1 > 0).astype(float)  # binary target\n",
    "w = w_init.copy()\n",
    "perc_w1 = [w[0]]\n",
    "perc_w2 = [w[1]]\n",
    "perc_norms = [np.linalg.norm(w)]\n",
    "\n",
    "for t in range(n_iters):\n",
    "    idx = t % n_samples\n",
    "    x = X[idx]\n",
    "    y_pred = 1.0 if w @ x > 0 else 0.0\n",
    "    target = targets[idx]\n",
    "    w = w + eta * (target - y_pred) * x\n",
    "    perc_w1.append(w[0])\n",
    "    perc_w2.append(w[1])\n",
    "    perc_norms.append(np.linalg.norm(w))\n",
    "\n",
    "# ---- Visualization: 4-panel plot ----\n",
    "fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n",
    "rules = [\n",
    "    ('Basic Hebb', hebb_w1, hebb_w2, hebb_norms, 'red'),\n",
    "    (\"Oja's Rule\", oja_w1, oja_w2, oja_norms, 'blue'),\n",
    "    ('BCM Rule', bcm_w1, bcm_w2, bcm_norms, 'green'),\n",
    "    ('Perceptron', perc_w1, perc_w2, perc_norms, 'purple'),\n",
    "]\n",
    "\n",
    "for ax, (name, w1_hist, w2_hist, norm_hist, color) in zip(axes.flat, rules):\n",
    "    # Weight trajectory in 2D weight space\n",
    "    w1_arr = np.array(w1_hist)\n",
    "    w2_arr = np.array(w2_hist)\n",
    "    \n",
    "    ax.plot(w1_arr, w2_arr, color=color, alpha=0.5, linewidth=0.5)\n",
    "    ax.plot(w1_arr[0], w2_arr[0], 'ko', markersize=8, label='Start')\n",
    "    ax.plot(w1_arr[-1], w2_arr[-1], 's', color=color, markersize=10, label='End')\n",
    "    \n",
    "    # Show PC1 direction\n",
    "    max_range = max(np.abs(w1_arr).max(), np.abs(w2_arr).max()) * 0.8\n",
    "    ax.annotate('', xy=pc1*max_range, xytext=-pc1*max_range,\n",
    "                arrowprops=dict(arrowstyle='->', color='gray', lw=1.5, linestyle='--'))\n",
    "    \n",
    "    # Draw unit circle\n",
    "    theta_c = np.linspace(0, 2*np.pi, 200)\n",
    "    ax.plot(np.cos(theta_c), np.sin(theta_c), 'k--', alpha=0.2, linewidth=1)\n",
    "    \n",
    "    ax.set_xlabel('$w_1$', fontsize=11)\n",
    "    ax.set_ylabel('$w_2$', fontsize=11)\n",
    "    ax.set_title(f'{name}\\n$||w||_{{final}}$ = {norm_hist[-1]:.3f}', fontsize=12)\n",
    "    ax.legend(fontsize=9, loc='lower right')\n",
    "    ax.set_aspect('equal')\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "plt.suptitle('Weight Trajectories: Four Learning Rules on Same 2D Data\\n'\n",
    "             '(gray dashed = PC1 direction, black dashed circle = unit circle)',\n",
    "             fontsize=14, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Also show norm evolution comparison\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "for name, _, _, norm_hist, color in rules:\n",
    "    ax.plot(norm_hist, color=color, linewidth=1.5, label=name)\n",
    "ax.set_yscale('log')\n",
    "ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='||w||=1')\n",
    "ax.set_xlabel('Iteration', fontsize=12)\n",
    "ax.set_ylabel('$||\\\\mathbf{w}||$ (log scale)', fontsize=12)\n",
    "ax.set_title('Weight Norm Evolution: All Four Rules', fontsize=13)\n",
    "ax.legend(fontsize=11)\n",
    "ax.grid(True, alpha=0.3)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"Observations:\")\n",
    "print(\"  - Hebb: weights spiral outward (divergence)\")\n",
    "print(\"  - Oja: weights converge to unit-norm PC1\")\n",
    "print(\"  - BCM: weights settle at a selective, stable fixed point\")\n",
    "print(\"  - Perceptron: weights converge to a decision boundary (supervised)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-4",
   "metadata": {},
   "source": [
    "## 14.3 From Hebbian to Supervised: The Gap\n",
    "\n",
    "### What Hebbian Learning Cannot Do\n",
    "\n",
    "Consider the XOR problem (from Chapter 8):\n",
    "\n",
    "| $x_1$ | $x_2$ | Target $y$ |\n",
    "|-------|-------|------------|\n",
    "| 0 | 0 | 0 |\n",
    "| 0 | 1 | 1 |\n",
    "| 1 | 0 | 1 |\n",
    "| 1 | 1 | 0 |\n",
    "\n",
    "**Why Hebbian learning fails on XOR**:\n",
    "\n",
    "1. **No target signal**: Hebbian learning does not know what the output *should* be.\n",
    "   It can only learn correlations in the input.\n",
    "\n",
    "2. **Linear projections only**: Even with Oja or Sanger, we can only learn linear projections.\n",
    "   XOR is not linearly separable.\n",
    "\n",
    "3. **No hidden representations**: To solve XOR, we need a hidden layer that creates an\n",
    "   appropriate internal representation. Hebbian learning provides no mechanism for\n",
    "   coordinating learning across layers.\n",
    "\n",
    "### The Credit Assignment Problem\n",
    "\n",
    "Even if we have a multi-layer network, the question remains: **how should hidden layer\n",
    "weights change to reduce output error?**\n",
    "\n",
    "This is the **credit assignment problem** (Minsky, 1961):\n",
    "\n",
    "> Given an error at the output, which internal weights (among potentially millions)\n",
    "> are responsible, and how should they be adjusted?\n",
    "\n",
    "Hebbian learning says: \"Strengthen connections between co-active neurons.\" But this says\n",
    "nothing about whether those activations are *useful* for the task.\n",
    "\n",
    "We need a learning rule that:\n",
    "1. Uses a **target signal** (supervised learning)\n",
    "2. Can propagate error information **through hidden layers**\n",
    "3. Computes the **correct gradient** of the loss with respect to all weights\n",
    "\n",
    "This is exactly what **backpropagation** provides."
   ]
  },
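  {
   "cell_type": "markdown",
   "id": "cell-4b",
   "metadata": {},
   "source": [
    "### A Quick Check: Hebbian Learning on XOR\n",
    "\n",
    "As a minimal illustration of the limitations above (Exercise 14.2 asks for a fuller exploration),\n",
    "the sketch below trains a single linear neuron with the basic Hebbian rule on the four XOR patterns\n",
    "and then sweeps over every possible threshold on the learned projection. The threshold-sweep readout\n",
    "is an illustrative choice; the point is simply that, with no target signal and only a linear\n",
    "projection, at least one of the four patterns is always misclassified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-4c",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "# A single linear neuron trained with the basic Hebbian rule on the XOR patterns.\n",
    "# The rule never sees the targets, so it can only track input correlations.\n",
    "X_xor = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)\n",
    "t_xor = np.array([0, 1, 1, 0], dtype=float)\n",
    "\n",
    "eta = 0.01\n",
    "w = np.random.randn(2) * 0.1\n",
    "\n",
    "for _ in range(200):\n",
    "    for x in X_xor:\n",
    "        y = w @ x\n",
    "        w = w + eta * y * x  # basic Hebbian update: no target signal anywhere\n",
    "\n",
    "scores = X_xor @ w\n",
    "\n",
    "# Sweep every threshold on the learned projection: even the best one must\n",
    "# misclassify at least one XOR pattern, because XOR is not linearly separable.\n",
    "thresholds = np.concatenate(([scores.min() - 1.0], np.sort(scores)))\n",
    "best_acc = max(np.mean((scores > th).astype(float) == t_xor) for th in thresholds)\n",
    "\n",
    "print(\"Learned weights:\", np.round(w, 4))\n",
    "print(\"Responses to (0,0), (0,1), (1,0), (1,1):\", np.round(scores, 4))\n",
    "print(f\"Best accuracy over all threshold readouts: {best_acc:.2f} (perfect would be 1.00)\")"
   ]
  },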
  {
   "cell_type": "markdown",
   "id": "cell-5",
   "metadata": {},
   "source": [
    "## 14.4 Preview: Backpropagation as the Solution\n",
    "\n",
    "In Part 5 (Chapters 15--19), we will develop the theory of **backpropagation**:\n",
    "\n",
    "1. **Chapter 15**: Gradient descent foundations -- optimizing a loss function.\n",
    "2. **Chapter 16**: The complete mathematical derivation of backpropagation.\n",
    "3. **Chapter 17**: Activation functions and the vanishing gradient problem.\n",
    "4. **Chapter 18**: Implementing backpropagation from scratch.\n",
    "5. **Chapter 19**: The Universal Approximation Theorem.\n",
    "\n",
    "Backpropagation solves the credit assignment problem by using the **chain rule** of calculus\n",
    "to compute exact gradients of the loss with respect to every weight in the network,\n",
    "regardless of depth.\n",
    "\n",
    "### The Price of Backpropagation\n",
    "\n",
    "While backpropagation solves the credit assignment problem, it sacrifices biological plausibility:\n",
    "\n",
    "| Property | Hebbian | Backpropagation |\n",
    "|----------|---------|----------------|\n",
    "| Target signal required | No | Yes |\n",
    "| Credit assignment | No | Yes |\n",
    "| Locality | Local | Non-local (weight transport) |\n",
    "| Biological plausibility | High | Low |\n",
    "| Can solve XOR | No | Yes |\n",
    "| Can train deep networks | No | Yes |\n",
    "\n",
    "The tension between biological plausibility and computational power remains an active\n",
    "research area (predictive coding, equilibrium propagation, feedback alignment, etc.)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-6",
   "metadata": {},
   "source": [
    "## Exercises\n",
    "\n",
    "**Exercise 14.1.** Implement the BCM rule with different time constants $\\tau_\\theta$ for\n",
    "the sliding threshold. How does $\\tau_\\theta$ affect (a) the speed of selectivity development\n",
    "and (b) the stability of the final weight vector?\n",
    "\n",
    "**Exercise 14.2.** Attempt to train a single neuron with Hebbian learning on the XOR problem.\n",
    "Show that it fails regardless of the learning rate or number of epochs.\n",
    "\n",
    "**Exercise 14.3.** Prove that the BCM fixed point $\\theta_M = \\langle y^2 \\rangle$ leads to\n",
    "selective responses. Specifically, show that at equilibrium, the neuron responds strongly to\n",
    "at most one input pattern class.\n",
    "\n",
    "**Exercise 14.4.** Compare all five learning rules (Basic Hebb, Covariance, Oja, Sanger, BCM)\n",
    "on the same synthetic dataset. Create a figure with 5 subplots showing the weight evolution\n",
    "for each rule.\n",
    "\n",
    "**Exercise 14.5.** Research and write a brief summary of one modern biologically-plausible\n",
    "alternative to backpropagation: feedback alignment (Lillicrap et al., 2016), predictive\n",
    "coding, or equilibrium propagation (Scellier & Bengio, 2017)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}