{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-0",
   "metadata": {},
   "source": [
    "# Chapter 10: Linear Separability --- A Deep Dive\n",
    "\n",
    "## Making the Concept Precise\n",
    "\n",
    "Linear separability is the central concept governing what a single-layer perceptron can and cannot compute. In previous chapters, we encountered it informally through examples like AND (separable) and XOR (not separable). In this chapter, we develop the theory systematically: formal definitions, the convex hull criterion with proof, Cover's function counting theorem, VC dimension, and the powerful idea of lifting data into higher-dimensional spaces where separability is recovered.\n",
    "\n",
    "These ideas form the mathematical backbone connecting perceptrons to support vector machines, kernel methods, and the broader theory of statistical learning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-1",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from itertools import combinations\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'figure.figsize': (8, 6),\n",
    "    'font.size': 12,\n",
    "    'axes.grid': True,\n",
    "    'grid.alpha': 0.3\n",
    "})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-2",
   "metadata": {},
   "source": [
    "## 1. Formal Definition\n",
    "\n",
    "**Definition.** Two finite sets $A, B \\subset \\mathbb{R}^n$ are **linearly separable** if there exist a weight vector $\\mathbf{w} \\in \\mathbb{R}^n$ and a bias $b \\in \\mathbb{R}$ such that:\n",
    "\n",
    "$$\\mathbf{w} \\cdot \\mathbf{x} + b > 0 \\quad \\forall \\mathbf{x} \\in A$$\n",
    "\n",
    "$$\\mathbf{w} \\cdot \\mathbf{x} + b \\leq 0 \\quad \\forall \\mathbf{x} \\in B$$\n",
    "\n",
    "The **separating hyperplane** is the set:\n",
    "\n",
    "$$H = \\{\\mathbf{x} \\in \\mathbb{R}^n : \\mathbf{w} \\cdot \\mathbf{x} + b = 0\\}$$\n",
    "\n",
    "This is an $(n-1)$-dimensional affine subspace of $\\mathbb{R}^n$:\n",
    "- In $\\mathbb{R}^2$: $H$ is a line.\n",
    "- In $\\mathbb{R}^3$: $H$ is a plane.\n",
    "- In $\\mathbb{R}^n$: $H$ is a hyperplane.\n",
    "\n",
    "### Geometric Interpretation\n",
    "\n",
    "The weight vector $\\mathbf{w}$ is **normal** (perpendicular) to the hyperplane $H$. The bias $b$ controls the **offset** of the hyperplane from the origin. Points in $A$ lie on the positive side of $H$ (where $\\mathbf{w} \\cdot \\mathbf{x} + b > 0$), and points in $B$ lie on the negative side.\n",
    "\n",
    "### Equivalence with Perceptron Computation\n",
    "\n",
    "A perceptron with weight vector $\\mathbf{w}$ and bias $b$ computes:\n",
    "\n",
    "$$f(\\mathbf{x}) = \\text{step}(\\mathbf{w} \\cdot \\mathbf{x} + b) = \\begin{cases} 1 & \\text{if } \\mathbf{w} \\cdot \\mathbf{x} + b \\geq 0 \\\\ 0 & \\text{if } \\mathbf{w} \\cdot \\mathbf{x} + b < 0 \\end{cases}$$\n",
    "\n",
    "Therefore, a perceptron can classify $A$ vs. $B$ correctly **if and only if** $A$ and $B$ are linearly separable. (The boundary conventions differ slightly: the definition allows $\\mathbf{w} \\cdot \\mathbf{x} + b = 0$ on $B$, while the step function assigns such points to class 1. Because the sets are finite, decreasing $b$ slightly makes the separation strict on both sides, so the equivalence holds.)"
   ]
  },
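  {
   "cell_type": "markdown",
   "id": "cell-2a-md",
   "metadata": {},
   "source": [
    "As a quick numerical check of this equivalence, the sketch below classifies the AND truth table with a hand-picked perceptron. The weights $\\mathbf{w} = (1, 1)$ and bias $b = -1.5$ are an illustrative choice, not derived above; any hyperplane passing strictly between $(1, 1)$ and the other three corners works equally well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# AND truth table: only (1, 1) carries the label 1\n",
    "X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
    "y_and = np.array([0, 0, 0, 1])\n",
    "\n",
    "# Hand-picked separating hyperplane: x1 + x2 - 1.5 = 0\n",
    "w, b = np.array([1.0, 1.0]), -1.5\n",
    "pred = (X @ w + b >= 0).astype(int)\n",
    "print('predictions:', pred, ' targets:', y_and)"
   ]
  },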
  {
   "cell_type": "markdown",
   "id": "cell-3",
   "metadata": {},
   "source": [
    "## 2. The Convex Hull Criterion\n",
    "\n",
    "The most elegant characterization of linear separability uses the concept of **convex hulls**.\n",
    "\n",
    "**Definition.** The *convex hull* of a finite set $S = \\{\\mathbf{x}_1, \\ldots, \\mathbf{x}_m\\} \\subset \\mathbb{R}^n$ is:\n",
    "\n",
    "$$\\text{conv}(S) = \\left\\{\\sum_{i=1}^{m} \\lambda_i \\mathbf{x}_i : \\lambda_i \\geq 0, \\; \\sum_{i=1}^{m} \\lambda_i = 1\\right\\}$$\n",
    "\n",
    "It is the smallest convex set containing $S$, or equivalently, the set of all convex combinations of points in $S$.\n",
    "\n",
    "```{admonition} Theorem (Convex Hull Separability)\n",
    ":class: important\n",
    "\n",
    "Two finite sets $A, B \\subset \\mathbb{R}^n$ are linearly separable **if and only if** their convex hulls are disjoint:\n",
    "\n",
    "$$A, B \\text{ linearly separable} \\iff \\text{conv}(A) \\cap \\text{conv}(B) = \\emptyset$$\n",
    "\n",
    "This provides a purely geometric characterization: instead of searching over all possible hyperplanes, you only need to check whether two convex bodies overlap.\n",
    "```\n",
    "\n",
    "```{admonition} Proof\n",
    ":class: dropdown\n",
    "\n",
    "$(\\Rightarrow)$ **Separable implies disjoint convex hulls.**\n",
    "\n",
    "Assume $A$ and $B$ are separated by the hyperplane $\\mathbf{w} \\cdot \\mathbf{x} + b = 0$, so:\n",
    "- $\\mathbf{w} \\cdot \\mathbf{a} + b > 0$ for all $\\mathbf{a} \\in A$\n",
    "- $\\mathbf{w} \\cdot \\mathbf{b}' + b \\leq 0$ for all $\\mathbf{b}' \\in B$\n",
    "\n",
    "Let $\\mathbf{p} \\in \\text{conv}(A)$, so $\\mathbf{p} = \\sum_i \\lambda_i \\mathbf{a}_i$ with $\\lambda_i \\geq 0$, $\\sum \\lambda_i = 1$. Then:\n",
    "\n",
    "$$\\mathbf{w} \\cdot \\mathbf{p} + b = \\sum_i \\lambda_i (\\mathbf{w} \\cdot \\mathbf{a}_i + b) > 0$$\n",
    "\n",
    "since each factor $\\mathbf{w} \\cdot \\mathbf{a}_i + b$ is positive, the $\\lambda_i$ are non-negative, and at least one $\\lambda_i$ is strictly positive (they sum to 1).\n",
    "\n",
    "Similarly, let $\\mathbf{q} \\in \\text{conv}(B)$, so $\\mathbf{q} = \\sum_j \\mu_j \\mathbf{b}_j$. Then:\n",
    "\n",
    "$$\\mathbf{w} \\cdot \\mathbf{q} + b = \\sum_j \\mu_j (\\mathbf{w} \\cdot \\mathbf{b}_j + b) \\leq 0$$\n",
    "\n",
    "Therefore $\\mathbf{w} \\cdot \\mathbf{p} + b > 0 \\geq \\mathbf{w} \\cdot \\mathbf{q} + b$, so $\\mathbf{p} \\neq \\mathbf{q}$. Since $\\mathbf{p}$ and $\\mathbf{q}$ were arbitrary, $\\text{conv}(A) \\cap \\text{conv}(B) = \\emptyset$.\n",
    "\n",
    "$(\\Leftarrow)$ **Disjoint convex hulls implies separable.**\n",
    "\n",
    "If $\\text{conv}(A) \\cap \\text{conv}(B) = \\emptyset$, then we have two disjoint compact convex sets (compact because they are convex hulls of finite sets). By the **Separating Hyperplane Theorem** (a fundamental result in convex analysis), there exists a hyperplane $\\mathbf{w} \\cdot \\mathbf{x} + b = 0$ that strictly separates them. In particular, $\\mathbf{w} \\cdot \\mathbf{x} + b > 0$ for all $\\mathbf{x} \\in \\text{conv}(A) \\supseteq A$ and $\\mathbf{w} \\cdot \\mathbf{x} + b < 0$ for all $\\mathbf{x} \\in \\text{conv}(B) \\supseteq B$. $\\blacksquare$\n",
    "```"
   ]
  },
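  {
   "cell_type": "markdown",
   "id": "cell-3a-md",
   "metadata": {},
   "source": [
    "To see the criterion in action, recall the XOR configuration: class $A = \\{(0,0), (1,1)\\}$ and class $B = \\{(0,1), (1,0)\\}$. The point $(0.5, 0.5)$ is a convex combination (with $\\lambda = (\\tfrac{1}{2}, \\tfrac{1}{2})$) of both classes, so the hulls intersect and, by the theorem, no separating line exists. A minimal numeric check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# XOR: the two class hulls are the diagonals of the unit square\n",
    "A = np.array([[0, 0], [1, 1]])\n",
    "B = np.array([[0, 1], [1, 0]])\n",
    "\n",
    "# Midpoints = convex combinations with lambda = (1/2, 1/2)\n",
    "p = A.mean(axis=0)\n",
    "q = B.mean(axis=0)\n",
    "print('midpoint of conv(A):', p, ' midpoint of conv(B):', q)"
   ]
  },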
  {
   "cell_type": "markdown",
   "id": "cell-4",
   "metadata": {},
   "source": [
    "## 3. Convex Hull Visualization\n",
    "\n",
    "Let us visualize the convex hull criterion for several 2D datasets, both separable and non-separable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-5",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Compute convex hull of 2D points (Graham scan)\n",
    "\n",
    "def convex_hull_2d(points):\n",
    "    \"\"\"Compute the convex hull of 2D points using Graham scan.\n",
    "    Returns the hull vertices in counter-clockwise order.\"\"\"\n",
    "    points = np.array(points)\n",
    "    if len(points) <= 1:\n",
    "        return points\n",
    "    if len(points) == 2:\n",
    "        return points\n",
    "    \n",
    "    # Find the point with the lowest y-coordinate (leftmost if tie)\n",
    "    start = np.lexsort((points[:, 0], points[:, 1]))[0]\n",
    "    pivot = points[start]\n",
    "    \n",
    "    # Sort by polar angle with respect to pivot\n",
    "    def angle_key(p):\n",
    "        return np.arctan2(p[1] - pivot[1], p[0] - pivot[0])\n",
    "    \n",
    "    indices = list(range(len(points)))\n",
    "    indices.sort(key=lambda i: (angle_key(points[i]), np.linalg.norm(points[i] - pivot)))\n",
    "    \n",
    "    def cross(o, a, b):\n",
    "        return (a[0]-o[0])*(b[1]-o[1]) - (a[1]-o[1])*(b[0]-o[0])\n",
    "    \n",
    "    hull = []\n",
    "    for idx in indices:\n",
    "        while len(hull) >= 2 and cross(points[hull[-2]], points[hull[-1]], points[idx]) <= 0:\n",
    "            hull.pop()\n",
    "        hull.append(idx)\n",
    "    \n",
    "    return points[hull]\n",
    "\n",
    "\n",
    "def check_hull_intersection(hull_A, hull_B):\n",
    "    \"\"\"Simple check if two convex polygons (given as hull vertices) intersect.\n",
    "    Uses the Separating Axis Theorem (SAT).\"\"\"\n",
    "    def get_edges(hull):\n",
    "        edges = []\n",
    "        n = len(hull)\n",
    "        for i in range(n):\n",
    "            edge = hull[(i+1) % n] - hull[i]\n",
    "            edges.append(edge)\n",
    "        return edges\n",
    "    \n",
    "    def project(hull, axis):\n",
    "        projections = np.dot(hull, axis)\n",
    "        return projections.min(), projections.max()\n",
    "    \n",
    "    for hull in [hull_A, hull_B]:\n",
    "        if len(hull) < 2:\n",
    "            continue\n",
    "        edges = get_edges(hull)\n",
    "        for edge in edges:\n",
    "            # Normal to the edge\n",
    "            normal = np.array([-edge[1], edge[0]])\n",
    "            if np.linalg.norm(normal) < 1e-10:\n",
    "                continue\n",
    "            normal = normal / np.linalg.norm(normal)\n",
    "            \n",
    "            min_A, max_A = project(hull_A, normal)\n",
    "            min_B, max_B = project(hull_B, normal)\n",
    "            \n",
    "            if max_A < min_B - 1e-10 or max_B < min_A - 1e-10:\n",
    "                return False  # Found a separating axis\n",
    "    \n",
    "    return True  # No separating axis found => intersection\n",
    "\n",
    "\n",
    "print(\"Convex hull and intersection checking functions defined.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-6",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Generate and visualize several datasets\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "datasets = [\n",
    "    # (name, class_A, class_B)\n",
    "    (\"Separable: Two Clusters\",\n",
    "     np.random.randn(8, 2) * 0.5 + np.array([-2, 0]),\n",
    "     np.random.randn(8, 2) * 0.5 + np.array([2, 0])),\n",
    "    \n",
    "    (\"Separable: Diagonal\",\n",
    "     np.random.randn(8, 2) * 0.4 + np.array([-1.5, -1.5]),\n",
    "     np.random.randn(8, 2) * 0.4 + np.array([1.5, 1.5])),\n",
    "    \n",
    "    (\"NOT Separable: XOR Pattern\",\n",
    "     np.array([[0, 0], [1, 1]]) + np.random.randn(2, 2) * 0.15,\n",
    "     np.array([[0, 1], [1, 0]]) + np.random.randn(2, 2) * 0.15),\n",
    "    \n",
    "    (\"NOT Separable: Interleaved\",\n",
    "     np.array([[-1, 0], [0, 1], [1, 0], [0, -1.0]]),\n",
    "     np.array([[0, 0], [-0.5, 0.5], [0.5, 0.5], [0, -0.3]]))\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(2, 2, figsize=(14, 12))\n",
    "\n",
    "for idx, (name, A, B) in enumerate(datasets):\n",
    "    ax = axes[idx // 2][idx % 2]\n",
    "    \n",
    "    # Compute convex hulls\n",
    "    hull_A = convex_hull_2d(A)\n",
    "    hull_B = convex_hull_2d(B)\n",
    "    \n",
    "    # Check intersection\n",
    "    intersects = check_hull_intersection(hull_A, hull_B)\n",
    "    separable = not intersects\n",
    "    \n",
    "    # Plot points\n",
    "    ax.scatter(A[:, 0], A[:, 1], c='blue', s=100, zorder=5,\n",
    "               edgecolors='black', linewidths=1.5, marker='o', label='Class A')\n",
    "    ax.scatter(B[:, 0], B[:, 1], c='red', s=100, zorder=5,\n",
    "               edgecolors='black', linewidths=1.5, marker='s', label='Class B')\n",
    "    \n",
    "    # Draw convex hulls\n",
    "    if len(hull_A) >= 3:\n",
    "        hull_closed = np.vstack([hull_A, hull_A[0]])\n",
    "        ax.fill(hull_A[:, 0], hull_A[:, 1], alpha=0.15, color='blue')\n",
    "        ax.plot(hull_closed[:, 0], hull_closed[:, 1], 'b-', linewidth=2, alpha=0.5)\n",
    "    elif len(hull_A) == 2:\n",
    "        ax.plot(hull_A[:, 0], hull_A[:, 1], 'b-', linewidth=2, alpha=0.5)\n",
    "    \n",
    "    if len(hull_B) >= 3:\n",
    "        hull_closed = np.vstack([hull_B, hull_B[0]])\n",
    "        ax.fill(hull_B[:, 0], hull_B[:, 1], alpha=0.15, color='red')\n",
    "        ax.plot(hull_closed[:, 0], hull_closed[:, 1], 'r-', linewidth=2, alpha=0.5)\n",
    "    elif len(hull_B) == 2:\n",
    "        ax.plot(hull_B[:, 0], hull_B[:, 1], 'r-', linewidth=2, alpha=0.5)\n",
    "    \n",
    "    # Title with separability status\n",
    "    color = 'green' if separable else 'red'\n",
    "    status = 'SEPARABLE' if separable else 'NOT SEPARABLE'\n",
    "    ax.set_title(f\"{name}\\nConvex hulls {'disjoint' if separable else 'overlap'} => {status}\",\n",
    "                 fontsize=12, color=color)\n",
    "    ax.legend(fontsize=10)\n",
    "    ax.set_aspect('equal')\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "plt.suptitle('Convex Hull Criterion for Linear Separability', fontsize=15, y=1.01)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-7",
   "metadata": {},
   "source": [
    "## 4. Cover's Function Counting Theorem\n",
    "\n",
    "A fundamental question in the theory of linear separability is: **how many** of the possible labelings of $m$ points can be achieved by a hyperplane in $\\mathbb{R}^d$?\n",
    "\n",
    "### Setting\n",
    "\n",
    "Given $m$ points $\\mathbf{x}_1, \\ldots, \\mathbf{x}_m \\in \\mathbb{R}^d$ in **general position** (meaning no $d+1$ points lie on a common hyperplane), a **dichotomy** is a partition of the points into two classes. There are $2^m$ possible dichotomies.\n",
    "\n",
    "A dichotomy is **linearly realizable** if there exists a hyperplane separating the two classes.\n",
    "\n",
    "```{admonition} Theorem (Cover, 1965)\n",
    ":class: important\n",
    "\n",
    "The number of linearly realizable dichotomies of $m$ points in general position in $\\mathbb{R}^d$ is:\n",
    "\n",
    "$$C(m, d) = 2 \\sum_{k=0}^{d} \\binom{m-1}{k}$$\n",
    "\n",
    "**Properties:**\n",
    "1. When $m \\leq d+1$: $C(m, d) = 2^m$. All dichotomies are realizable (the points can be **shattered**).\n",
    "2. When $m = 2(d+1)$: $C(m, d) = 2^{m-1}$ exactly, since $\\sum_{k=0}^{d} \\binom{2d+1}{k} = 2^{2d}$. Exactly half of all dichotomies are realizable.\n",
    "3. When $m \\gg d$: $C(m, d) \\ll 2^m$. Almost no dichotomies are realizable.\n",
    "\n",
    "There is a sharp **phase transition** around $m = 2(d+1) \\approx 2d$: below this threshold, most dichotomies are separable; above it, most are not.\n",
    "```\n",
    "\n",
    "```{tip}\n",
    "**Why higher dimensions help separability (Cover's theorem).**\n",
    "\n",
    "Cover's theorem reveals a profound principle: **data that is not separable in low dimensions often becomes separable in high dimensions**. The fraction of realizable dichotomies depends on the ratio $m/d$ (points to dimensions). With $m$ fixed, raising the dimension moves you from the regime where almost no dichotomies are realizable ($m \\gg 2d$), through the transition ($m \\approx 2d$), to the regime where every dichotomy is realizable ($m \\leq d + 1$). This is the mathematical foundation for:\n",
    "- **Feature engineering**: adding computed features increases dimension\n",
    "- **Kernel methods**: implicitly mapping to high-dimensional feature spaces\n",
    "- **Hidden layers in neural networks**: the hidden layer creates a high-dimensional representation where the output layer can separate the classes\n",
    "```"
   ]
  },
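  {
   "cell_type": "markdown",
   "id": "cell-7a-md",
   "metadata": {},
   "source": [
    "The tip above can be made concrete with XOR itself. Adding the single computed feature $x_1 x_2$ lifts the four points from $\\mathbb{R}^2$ into $\\mathbb{R}^3$, where a plane separates them. The lifted weights $\\mathbf{w} = (1, 1, -2)$, $b = -0.5$ below are one illustrative choice, not derived in the text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# XOR in R^2: not linearly separable\n",
    "X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
    "y_xor = np.array([0, 1, 1, 0])\n",
    "\n",
    "# Lift each point: (x1, x2) -> (x1, x2, x1 * x2)\n",
    "X_lift = np.column_stack([X, X[:, 0] * X[:, 1]])\n",
    "\n",
    "# In R^3 a single plane now separates the classes\n",
    "w, b = np.array([1.0, 1.0, -2.0]), -0.5\n",
    "pred = (X_lift @ w + b > 0).astype(int)\n",
    "print('lifted predictions:', pred, ' targets:', y_xor)"
   ]
  },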
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-8",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Cover's Function Counting Theorem: computation and visualization\n",
    "\n",
    "from math import comb\n",
    "\n",
    "def cover_count(m, d):\n",
    "    \"\"\"Number of linearly realizable dichotomies of m points in R^d.\"\"\"\n",
    "    return 2 * sum(comb(m - 1, k) for k in range(d + 1))\n",
    "\n",
    "# Table of C(m, d)\n",
    "print(\"Cover's Function C(m, d): Number of realizable dichotomies\")\n",
    "print(\"=\" * 65)\n",
    "print(f\"{'m':>4}\", end=\"\")\n",
    "for d in range(1, 7):\n",
    "    print(f\"{'d='+str(d):>10}\", end=\"\")\n",
    "print(f\"{'2^m':>10}\")\n",
    "print(\"-\" * 65)\n",
    "\n",
    "for m in range(1, 13):\n",
    "    print(f\"{m:>4}\", end=\"\")\n",
    "    for d in range(1, 7):\n",
    "        c = cover_count(m, d)\n",
    "        total = 2**m\n",
    "        if c >= total:\n",
    "            print(f\"{'ALL':>10}\", end=\"\")\n",
    "        else:\n",
    "            print(f\"{c:>10}\", end=\"\")\n",
    "    print(f\"{2**m:>10}\")"
   ]
  },
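  {
   "cell_type": "markdown",
   "id": "cell-8a-md",
   "metadata": {},
   "source": [
    "The sketch below re-verifies properties 1 and 2 of the theorem numerically for a range of small $m$ and $d$ (`cover_count` is redefined so the cell stands alone):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import comb\n",
    "\n",
    "def cover_count(m, d):\n",
    "    \"\"\"Number of linearly realizable dichotomies of m points in R^d.\"\"\"\n",
    "    return 2 * sum(comb(m - 1, k) for k in range(d + 1))\n",
    "\n",
    "# Property 1: for m <= d + 1, every one of the 2^m dichotomies is realizable\n",
    "assert all(cover_count(m, 5) == 2**m for m in range(1, 7))\n",
    "\n",
    "# Property 2: at m = 2(d + 1), exactly half of the 2^m dichotomies are realizable\n",
    "for d in range(1, 8):\n",
    "    m = 2 * (d + 1)\n",
    "    assert cover_count(m, d) == 2**(m - 1)\n",
    "\n",
    "print('Properties 1 and 2 verified for small m and d.')"
   ]
  },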
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-9",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Plot: fraction of realizable dichotomies as m grows\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "# Left: fraction for several dimensions\n",
    "ax1 = axes[0]\n",
    "for d in [1, 2, 3, 5, 10]:\n",
    "    m_vals = np.arange(1, 8 * d + 1)\n",
    "    fractions = [min(cover_count(m, d) / 2**m, 1.0) for m in m_vals]\n",
    "    ax1.plot(m_vals / (2*d), fractions, '-', linewidth=2, label=f'd = {d}')\n",
    "\n",
    "ax1.axvline(x=1, color='black', linestyle='--', alpha=0.5, label='$m = 2d$ (transition)')\n",
    "ax1.set_xlabel('$m / (2d)$', fontsize=13)\n",
    "ax1.set_ylabel('Fraction of realizable dichotomies', fontsize=12)\n",
    "ax1.set_title('Cover\\'s Theorem: Phase Transition at $m \\\\approx 2d$', fontsize=13)\n",
    "ax1.legend(fontsize=10)\n",
    "ax1.grid(True, alpha=0.3)\n",
    "ax1.set_ylim(-0.05, 1.1)\n",
    "\n",
    "# Right: absolute count for d=2\n",
    "ax2 = axes[1]\n",
    "d = 2\n",
    "m_vals = np.arange(1, 20)\n",
    "cover_vals = [cover_count(m, d) for m in m_vals]\n",
    "total_vals = [2**m for m in m_vals]\n",
    "\n",
    "ax2.semilogy(m_vals, total_vals, 'r-o', markersize=5, label=f'$2^m$ (all dichotomies)')\n",
    "ax2.semilogy(m_vals, cover_vals, 'b-s', markersize=5, label=f'$C(m, {d})$ (realizable)')\n",
    "ax2.axvline(x=2*(d+1), color='green', linestyle='--', alpha=0.7,\n",
    "            label=f'$m = 2(d+1) = {2*(d+1)}$ (half-point)')\n",
    "ax2.set_xlabel('$m$ (number of points)', fontsize=13)\n",
    "ax2.set_ylabel('Number of dichotomies (log scale)', fontsize=12)\n",
    "ax2.set_title(f'$d = {d}$: Realizable vs. Total Dichotomies', fontsize=13)\n",
    "ax2.legend(fontsize=10)\n",
    "ax2.grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-9b",
   "metadata": {},
   "source": [
    "## 4a. Graph of Cover's Counting Function\n",
    "\n",
    "Let us plot the number of linearly separable dichotomies $C(m,d)$ versus the number of points $m$ for several dimensions $d$, overlaid with $2^m$ to see where the gap opens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-9c",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from math import comb\n",
    "\n",
    "def cover_count(m, d):\n",
    "    \"\"\"Number of linearly realizable dichotomies of m points in R^d.\"\"\"\n",
    "    return 2 * sum(comb(m - 1, k) for k in range(d + 1))\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
    "\n",
    "# Left panel: C(m,d) for various d on log scale\n",
    "ax1 = axes[0]\n",
    "colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']\n",
    "dims = [1, 2, 3, 5, 10]\n",
    "for i, d in enumerate(dims):\n",
    "    m_vals = np.arange(1, 6 * d + 1)\n",
    "    c_vals = [cover_count(int(m), d) for m in m_vals]\n",
    "    ax1.semilogy(m_vals, c_vals, '-o', color=colors[i], markersize=3,\n",
    "                 linewidth=2, label=f'$C(m, {d})$')\n",
    "\n",
    "# Also plot 2^m\n",
    "m_all = np.arange(1, 61)\n",
    "ax1.semilogy(m_all, [2**int(m) for m in m_all], 'k--', linewidth=2, alpha=0.5, label='$2^m$ (all)')\n",
    "\n",
    "ax1.set_xlabel('$m$ (number of points)', fontsize=13)\n",
    "ax1.set_ylabel('Number of dichotomies (log scale)', fontsize=12)\n",
    "ax1.set_title('Cover\\'s Counting Function $C(m,d)$\\nvs. total dichotomies $2^m$', fontsize=13)\n",
    "ax1.legend(fontsize=10, loc='lower right')\n",
    "ax1.grid(True, alpha=0.3)\n",
    "ax1.set_xlim(1, 60)\n",
    "\n",
    "# Right panel: fraction C(m,d) / 2^m showing the phase transition\n",
    "ax2 = axes[1]\n",
    "for i, d in enumerate(dims):\n",
    "    m_vals = np.arange(1, 8 * d + 1)\n",
    "    fractions = [min(cover_count(int(m), d) / 2**int(m), 1.0) for m in m_vals]\n",
    "    ax2.plot(m_vals, fractions, '-', color=colors[i], linewidth=2, label=f'd = {d}')\n",
    "    # Mark the transition point m = 2d\n",
    "    ax2.axvline(x=2*d, color=colors[i], linestyle=':', alpha=0.3)\n",
    "\n",
    "ax2.set_xlabel('$m$ (number of points)', fontsize=13)\n",
    "ax2.set_ylabel('Fraction $C(m,d) / 2^m$', fontsize=12)\n",
    "ax2.set_title('Phase Transition: Fraction of Separable Dichotomies\\n'\n",
    "              'Dotted lines mark $m = 2d$', fontsize=13)\n",
    "ax2.legend(fontsize=10)\n",
    "ax2.grid(True, alpha=0.3)\n",
    "ax2.set_ylim(-0.05, 1.1)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"Key observation: For each dimension d, the fraction of separable dichotomies\")\n",
    "print(\"drops sharply around m = 2d points. This is the phase transition.\")\n",
    "print(\"Higher dimensions push the transition to more points --- this is why\")\n",
    "print(\"mapping to higher dimensions helps with separability.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-10",
   "metadata": {},
   "source": [
    "## 5. VC Dimension\n",
    "\n",
    "The **Vapnik-Chervonenkis (VC) dimension** formalizes the expressive power of a hypothesis class.\n",
    "\n",
    "```{admonition} Definition (VC Dimension)\n",
    ":class: note\n",
    "\n",
    "**Shattering.** A hypothesis class $\\mathcal{H}$ **shatters** a set of points $S = \\{\\mathbf{x}_1, \\ldots, \\mathbf{x}_m\\}$ if for every possible labeling $(y_1, \\ldots, y_m) \\in \\{0, 1\\}^m$, there exists $h \\in \\mathcal{H}$ such that $h(\\mathbf{x}_i) = y_i$ for all $i$. In other words, $\\mathcal{H}$ can realize all $2^m$ dichotomies of $S$.\n",
    "\n",
    "**VC Dimension.** The VC dimension of $\\mathcal{H}$ is:\n",
    "\n",
    "$$\\text{VCdim}(\\mathcal{H}) = \\max\\{m : \\exists S \\text{ of size } m \\text{ shattered by } \\mathcal{H}\\}$$\n",
    "\n",
    "It is the **largest** number of points that can be arranged so that the hypothesis class can realize every possible labeling. Note the existential quantifier: we only need **one** arrangement of $m$ points that can be shattered, not all arrangements.\n",
    "```\n",
    "\n",
    "### VC Dimension of Perceptrons\n",
    "\n",
    "**Theorem.** The VC dimension of the class of perceptrons (linear classifiers) in $\\mathbb{R}^n$ is $n + 1$.\n",
    "\n",
    "```{admonition} Proof\n",
    ":class: dropdown\n",
    "\n",
    "**Proof sketch:**\n",
    "1. *VC dim $\\geq n+1$*: Take $n+1$ points in general position (e.g., the origin plus the $n$ standard basis vectors). Any labeling can be achieved by a hyperplane. This follows from the fact that $n+1$ points in general position in $\\mathbb{R}^n$ always have $C(n+1, n) = 2^{n+1}$ realizable dichotomies (by Cover's theorem, since $m = n+1 \\leq n+1 = d+1$).\n",
    "\n",
    "2. *VC dim $\\leq n+1$*: By **Radon's theorem**, any set of $n+2$ points in $\\mathbb{R}^n$ can be partitioned into two disjoint subsets $P, Q$ such that $\\text{conv}(P) \\cap \\text{conv}(Q) \\neq \\emptyset$. The labeling that assigns one class to $P$ and the other to $Q$ cannot be realized by any hyperplane (by the Convex Hull Criterion). Therefore no set of $n+2$ points can be shattered. $\\blacksquare$\n",
    "```\n",
    "\n",
    "```{danger}\n",
    "**Curse of dimensionality** --- more dimensions do NOT always help.\n",
    "\n",
    "While Cover's theorem shows that higher dimensions increase the fraction of separable dichotomies, there is a hidden cost: **the curse of dimensionality**. In high-dimensional spaces:\n",
    "- Data becomes **sparse**: the volume of the space grows exponentially, so the same number of data points covers an exponentially smaller fraction.\n",
    "- **Distance concentration**: all pairwise distances become nearly equal, making nearest-neighbor methods unreliable.\n",
    "- **Overfitting**: with enough dimensions, any dataset becomes trivially separable, but the resulting classifier generalizes poorly.\n",
    "\n",
    "The VC dimension gives us a quantitative handle on this tradeoff: a hypothesis class with VC dimension $d$ needs roughly $O(d / \\epsilon^2)$ training examples to generalize well (by the VC theorem). So increasing dimension helps separability but hurts generalization --- unless the training set grows proportionally.\n",
    "```"
   ]
  },
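  {
   "cell_type": "markdown",
   "id": "cell-10a-md",
   "metadata": {},
   "source": [
    "The upper bound can also be probed by brute force. The sketch below sweeps candidate line directions for four points in $\\mathbb{R}^2$ (the coordinates are an arbitrary generic choice) and counts the distinct dichotomies a line can realize. Cover's formula predicts $C(4, 2) = 14 < 16 = 2^4$, so this configuration is not shattered; Radon's theorem guarantees the same for every configuration of four points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-10a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Four points in general position in R^2 (arbitrary generic choice)\n",
    "pts = np.array([[0.1, 0.2], [0.9, 0.3], [0.4, 0.9], [0.7, 0.6]])\n",
    "\n",
    "# Every dichotomy realizable by a line appears, for some normal direction w,\n",
    "# as a threshold between consecutive sorted projections onto w.\n",
    "realizable = set()\n",
    "for angle in np.linspace(0, 2 * np.pi, 3600, endpoint=False):\n",
    "    w = np.array([np.cos(angle), np.sin(angle)])\n",
    "    order = np.argsort(pts @ w)\n",
    "    for cut in range(len(pts) + 1):\n",
    "        labels = np.zeros(len(pts), dtype=int)\n",
    "        labels[order[cut:]] = 1\n",
    "        realizable.add(tuple(labels))\n",
    "\n",
    "print('realizable dichotomies:', len(realizable), 'of', 2**len(pts))"
   ]
  },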
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-11",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# VC Dimension visualization: 3 points in R^2, all 8 dichotomies\n",
    "\n",
    "# Choose 3 points in general position\n",
    "points = np.array([[0, 0], [1, 0], [0.5, 0.87]])\n",
    "\n",
    "fig, axes = plt.subplots(2, 4, figsize=(16, 8))\n",
    "\n",
    "# All 8 labelings of 3 points\n",
    "for idx in range(8):\n",
    "    ax = axes[idx // 4][idx % 4]\n",
    "    labels = [(idx >> i) & 1 for i in range(3)]\n",
    "    \n",
    "    # Plot points\n",
    "    for i, (pt, lbl) in enumerate(zip(points, labels)):\n",
    "        color = 'red' if lbl == 1 else 'blue'\n",
    "        marker = 's' if lbl == 1 else 'o'\n",
    "        ax.scatter(pt[0], pt[1], c=color, s=200, marker=marker,\n",
    "                   edgecolors='black', linewidths=2, zorder=5)\n",
    "    \n",
    "    # Find a separating line (solve for w1*x + w2*y + b = 0)\n",
    "    # Using a simple approach: try to find weights\n",
    "    class_pos = points[np.array(labels) == 1]\n",
    "    class_neg = points[np.array(labels) == 0]\n",
    "    \n",
    "    # Find separating line by optimization\n",
    "    found = False\n",
    "    if len(class_pos) == 0 or len(class_neg) == 0:\n",
    "        # Trivial case: all same class\n",
    "        if len(class_pos) == 0:\n",
    "            # All negative: any line with all points on negative side\n",
    "            w = np.array([0, -1])\n",
    "            b = -1.5\n",
    "        else:\n",
    "            # All positive: line y = -0.1 keeps every point on the positive side\n",
    "            w = np.array([0, 1])\n",
    "            b = 0.1\n",
    "        found = True\n",
    "    else:\n",
    "        # Try many random directions\n",
    "        best_w, best_b = None, None\n",
    "        for angle in np.linspace(0, 2*np.pi, 360):\n",
    "            w_try = np.array([np.cos(angle), np.sin(angle)])\n",
    "            proj_pos = [np.dot(w_try, p) for p in class_pos]\n",
    "            proj_neg = [np.dot(w_try, p) for p in class_neg]\n",
    "            if min(proj_pos) > max(proj_neg):\n",
    "                b_try = -(min(proj_pos) + max(proj_neg)) / 2\n",
    "                best_w = w_try\n",
    "                best_b = b_try\n",
    "                found = True\n",
    "                break\n",
    "        if found:\n",
    "            w = best_w\n",
    "            b = best_b\n",
    "    \n",
    "    # Draw separating line if found\n",
    "    if found:\n",
    "        x_line = np.linspace(-0.5, 1.5, 100)\n",
    "        if abs(w[1]) > 1e-10:\n",
    "            y_line = -(w[0] * x_line + b) / w[1]\n",
    "            valid = (y_line >= -0.5) & (y_line <= 1.5)\n",
    "            ax.plot(x_line[valid], y_line[valid], 'g-', linewidth=2, alpha=0.7)\n",
    "        else:\n",
    "            ax.axvline(x=-b/w[0], color='g', linewidth=2, alpha=0.7)\n",
    "    \n",
    "    label_str = ''.join(str(l) for l in labels)\n",
    "    ax.set_title(f'Labels: {label_str}', fontsize=11)\n",
    "    ax.set_xlim(-0.3, 1.3)\n",
    "    ax.set_ylim(-0.3, 1.2)\n",
    "    ax.set_aspect('equal')\n",
    "    ax.grid(True, alpha=0.2)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "\n",
    "plt.suptitle('VC Dimension: 3 Points in $\\\\mathbb{R}^2$ --- All 8 Dichotomies are Realizable\\n'\n",
    "             'Therefore VCdim(perceptron in $\\\\mathbb{R}^2$) $\\\\geq$ 3',\n",
    "             fontsize=14, y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-11b",
   "metadata": {},
   "source": [
    "## 5a. VC Dimension Shatter Diagrams\n",
    "\n",
    "To develop deeper intuition for shattering, let us explicitly visualize the shatter diagram: for 3 points in $\\mathbb{R}^2$, we show all 8 dichotomies with their separating lines, and then show that for 4 points, at least one dichotomy (the XOR pattern) cannot be realized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-11c",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, axes = plt.subplots(2, 5, figsize=(20, 8))\n",
    "\n",
    "# ---- Top row: 3 points shattered by lines in 2D ----\n",
    "pts3 = np.array([[0.2, 0.1], [0.8, 0.1], [0.5, 0.8]])\n",
    "\n",
    "for idx in range(8):\n",
    "    ax = axes[0][idx] if idx < 5 else axes[1][idx - 5]\n",
    "    labels = [(idx >> i) & 1 for i in range(3)]\n",
    "\n",
    "    for i, (pt, lbl) in enumerate(zip(pts3, labels)):\n",
    "        color = 'red' if lbl == 1 else 'blue'\n",
    "        marker = 's' if lbl == 1 else 'o'\n",
    "        ax.scatter(pt[0], pt[1], c=color, s=250, marker=marker,\n",
    "                   edgecolors='black', linewidths=2, zorder=5)\n",
    "\n",
    "    # Find separating line\n",
    "    class_pos = pts3[np.array(labels) == 1]\n",
    "    class_neg = pts3[np.array(labels) == 0]\n",
    "\n",
    "    found = False\n",
    "    if len(class_pos) == 0 or len(class_neg) == 0:\n",
    "        found = True\n",
    "        w = np.array([0, 1])\n",
    "        b = -1.5 if len(class_pos) == 0 else 0.05\n",
    "    else:\n",
    "        for angle in np.linspace(0, 2 * np.pi, 720):\n",
    "            w_try = np.array([np.cos(angle), np.sin(angle)])\n",
    "            proj_pos = [np.dot(w_try, p) for p in class_pos]\n",
    "            proj_neg = [np.dot(w_try, p) for p in class_neg]\n",
    "            if min(proj_pos) > max(proj_neg) + 1e-6:\n",
    "                b = -(min(proj_pos) + max(proj_neg)) / 2\n",
    "                w = w_try\n",
    "                found = True\n",
    "                break\n",
    "\n",
    "    if found:\n",
    "        x_line = np.linspace(-0.2, 1.2, 200)\n",
    "        if abs(w[1]) > 1e-10:\n",
    "            y_line = -(w[0] * x_line + b) / w[1]\n",
    "            valid = (y_line >= -0.2) & (y_line <= 1.2)\n",
    "            ax.plot(x_line[valid], y_line[valid], 'g-', linewidth=2.5, alpha=0.8)\n",
    "        else:\n",
    "            ax.axvline(x=-b / w[0], color='g', linewidth=2.5, alpha=0.8)\n",
    "\n",
    "    label_str = ''.join(str(l) for l in labels)\n",
    "    ax.set_title(f'{label_str}', fontsize=12, fontweight='bold',\n",
    "                 color='green')\n",
    "    ax.set_xlim(-0.1, 1.1)\n",
    "    ax.set_ylim(-0.1, 1.0)\n",
    "    ax.set_aspect('equal')\n",
    "    ax.grid(True, alpha=0.2)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "\n",
    "# ---- Bottom row remaining: 4 points, show XOR failure ----\n",
    "# Fill remaining bottom-row axes\n",
    "pts4 = np.array([[0.2, 0.2], [0.8, 0.2], [0.2, 0.8], [0.8, 0.8]])\n",
    "\n",
    "# Show a separable dichotomy\n",
    "ax_ok = axes[1][3]\n",
    "labels_ok = [0, 1, 0, 1]\n",
    "for i, (pt, lbl) in enumerate(zip(pts4, labels_ok)):\n",
    "    color = 'red' if lbl == 1 else 'blue'\n",
    "    marker = 's' if lbl == 1 else 'o'\n",
    "    ax_ok.scatter(pt[0], pt[1], c=color, s=250, marker=marker,\n",
    "                  edgecolors='black', linewidths=2, zorder=5)\n",
    "# Draw separating line\n",
    "ax_ok.axvline(x=0.5, color='green', linewidth=2.5)\n",
    "ax_ok.set_title('4pts: 0101\\nSeparable', fontsize=11, color='green', fontweight='bold')\n",
    "ax_ok.set_xlim(-0.1, 1.1)\n",
    "ax_ok.set_ylim(-0.1, 1.0)\n",
    "ax_ok.set_aspect('equal')\n",
    "ax_ok.grid(True, alpha=0.2)\n",
    "ax_ok.set_xticks([])\n",
    "ax_ok.set_yticks([])\n",
    "\n",
    "# Show the XOR failure\n",
    "ax_fail = axes[1][4]\n",
    "labels_fail = [0, 1, 1, 0]  # XOR pattern\n",
    "for i, (pt, lbl) in enumerate(zip(pts4, labels_fail)):\n",
    "    color = 'red' if lbl == 1 else 'blue'\n",
    "    marker = 's' if lbl == 1 else 'o'\n",
    "    ax_fail.scatter(pt[0], pt[1], c=color, s=250, marker=marker,\n",
    "                    edgecolors='black', linewidths=2, zorder=5)\n",
    "# Draw the intersecting convex hulls\n",
    "ax_fail.plot([0.2, 0.8], [0.2, 0.8], 'b--', linewidth=2, alpha=0.5)\n",
    "ax_fail.plot([0.8, 0.2], [0.2, 0.8], 'r--', linewidth=2, alpha=0.5)\n",
    "ax_fail.plot(0.5, 0.5, 'kX', markersize=15, zorder=10)\n",
    "ax_fail.set_title('4pts: 0110 (XOR)\\nNOT Separable!', fontsize=11, color='red', fontweight='bold')\n",
    "ax_fail.set_xlim(-0.1, 1.1)\n",
    "ax_fail.set_ylim(-0.1, 1.0)\n",
    "ax_fail.set_aspect('equal')\n",
    "ax_fail.grid(True, alpha=0.2)\n",
    "ax_fail.set_xticks([])\n",
    "ax_fail.set_yticks([])\n",
    "\n",
    "plt.suptitle('Shatter Diagrams: 3 points can be shattered (all 8 dichotomies work)\\n'\n",
    "             '4 points CANNOT be shattered (XOR dichotomy fails) $\\\\Rightarrow$ VCdim = 3',\n",
    "             fontsize=14, y=1.04)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-12",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Show that 4 points in R^2 CANNOT all be shattered\n",
    "# The XOR configuration is the failing dichotomy\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
    "\n",
    "# 4 points in general position\n",
    "pts4 = np.array([[0, 0], [1, 0], [0, 1], [0.5, 0.5]])\n",
    "\n",
    "# Show a solvable dichotomy\n",
    "ax1 = axes[0]\n",
    "labels_ok = [0, 1, 0, 1]\n",
    "for i, (pt, lbl) in enumerate(zip(pts4, labels_ok)):\n",
    "    color = 'red' if lbl == 1 else 'blue'\n",
    "    marker = 's' if lbl == 1 else 'o'\n",
    "    ax1.scatter(pt[0], pt[1], c=color, s=200, marker=marker,\n",
    "               edgecolors='black', linewidths=2, zorder=5)\n",
    "ax1.set_title('4 points: This dichotomy IS realizable', fontsize=12, color='green')\n",
    "# Draw a separating line\n",
    "x_line = np.linspace(-0.3, 1.3, 100)\n",
    "ax1.plot(x_line, -0.8*x_line + 0.35, 'g-', linewidth=2.5)\n",
    "ax1.set_xlim(-0.3, 1.3)\n",
    "ax1.set_ylim(-0.3, 1.3)\n",
    "ax1.set_aspect('equal')\n",
    "ax1.grid(True, alpha=0.3)\n",
    "\n",
    "# Show the XOR-like failing dichotomy\n",
    "ax2 = axes[1]\n",
    "# Use the standard XOR pattern on a square\n",
    "pts_xor = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])\n",
    "labels_fail = [0, 1, 1, 0]  # XOR pattern\n",
    "for i, (pt, lbl) in enumerate(zip(pts_xor, labels_fail)):\n",
    "    color = 'red' if lbl == 1 else 'blue'\n",
    "    marker = 's' if lbl == 1 else 'o'\n",
    "    ax2.scatter(pt[0], pt[1], c=color, s=200, marker=marker,\n",
    "               edgecolors='black', linewidths=2, zorder=5)\n",
    "\n",
    "# Draw the intersecting convex hulls\n",
    "ax2.plot([0, 1], [0, 1], 'b--', linewidth=2, alpha=0.5, label='conv(Class 0)')\n",
    "ax2.plot([1, 0], [0, 1], 'r--', linewidth=2, alpha=0.5, label='conv(Class 1)')\n",
    "ax2.plot(0.5, 0.5, 'k*', markersize=15, zorder=10)\n",
    "\n",
    "ax2.set_title('4 points: This dichotomy is NOT realizable (XOR)\\n'\n",
    "              'No line can separate blue circles from red squares',\n",
    "              fontsize=12, color='red')\n",
    "ax2.set_xlim(-0.3, 1.3)\n",
    "ax2.set_ylim(-0.3, 1.3)\n",
    "ax2.set_aspect('equal')\n",
    "ax2.grid(True, alpha=0.3)\n",
    "ax2.legend(fontsize=11)\n",
    "\n",
    "plt.suptitle('VCdim(perceptron in $\\\\mathbb{R}^2$) = 3, NOT 4\\n'\n",
    "             '3 points can be shattered, but 4 points cannot', fontsize=14, y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-13",
   "metadata": {},
   "source": [
    "## 6. Higher-Dimensional Embeddings\n",
    "\n",
    "One of the most powerful ideas in machine learning is that data which is **not** linearly separable in its original space may become separable after mapping it to a **higher-dimensional** feature space.\n",
    "\n",
    "### XOR in $\\mathbb{R}^3$\n",
    "\n",
    "Consider the feature map $\\varphi: \\mathbb{R}^2 \\to \\mathbb{R}^3$ defined by:\n",
    "\n",
    "$$\\varphi(x_1, x_2) = (x_1, x_2, x_1 x_2)$$\n",
    "\n",
    "This adds a new dimension: the **product** of the two inputs. Under this mapping:\n",
    "\n",
    "| $(x_1, x_2)$ | $\\varphi(x_1, x_2) = (x_1, x_2, x_1 x_2)$ | XOR |\n",
    "|:------------:|:------------------------------------------:|:---:|\n",
    "| $(0, 0)$ | $(0, 0, 0)$ | 0 |\n",
    "| $(0, 1)$ | $(0, 1, 0)$ | 1 |\n",
    "| $(1, 0)$ | $(1, 0, 0)$ | 1 |\n",
    "| $(1, 1)$ | $(1, 1, 1)$ | 0 |\n",
    "\n",
    "**Claim:** These 4 points are linearly separable in $\\mathbb{R}^3$.\n",
    "\n",
    "**Verification:** We need $w_1 x_1 + w_2 x_2 + w_3 x_1 x_2 + b$ to be positive for XOR=1 and negative for XOR=0.\n",
    "\n",
    "Try $w_1 = 1, w_2 = 1, w_3 = -2, b = -0.5$:\n",
    "- $(0,0,0)$: $0 + 0 + 0 - 0.5 = -0.5 < 0$ (correct: class 0)\n",
    "- $(0,1,0)$: $0 + 1 + 0 - 0.5 = 0.5 > 0$ (correct: class 1)\n",
    "- $(1,0,0)$: $1 + 0 + 0 - 0.5 = 0.5 > 0$ (correct: class 1)\n",
    "- $(1,1,1)$: $1 + 1 - 2 - 0.5 = -0.5 < 0$ (correct: class 0)\n",
    "\n",
    "### Connection to Kernel Methods\n",
    "\n",
    "This idea --- mapping data to a higher-dimensional space where it becomes linearly separable --- is the foundation of **kernel methods**, including **Support Vector Machines (SVMs)**. The \"kernel trick\" allows performing computations in the high-dimensional feature space implicitly, without ever computing the mapping $\\varphi$ explicitly."
   ]
  },
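  {
   "cell_type": "markdown",
   "id": "cell-13b",
   "metadata": {},
   "source": [
    "A quick numerical sketch of the kernel trick (an illustration; the particular map is our choice, not part of the SVM machinery itself): for the quadratic feature map $\\varphi(\\mathbf{x}) = (x_1^2, \\sqrt{2} x_1 x_2, x_2^2)$, the feature-space inner product equals $(\\mathbf{x} \\cdot \\mathbf{y})^2$ in the original space, so it never requires computing $\\varphi$ explicitly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-13c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kernel trick sketch: phi(x) = (x1^2, sqrt(2)*x1*x2, x2^2)\n",
    "# satisfies phi(x).phi(y) = (x.y)^2, so the feature-space inner\n",
    "# product can be computed without ever forming phi.\n",
    "import numpy as np\n",
    "\n",
    "def phi(x):\n",
    "    # Explicit quadratic feature map R^2 -> R^3\n",
    "    return np.array([x[0]**2, np.sqrt(2) * x[0] * x[1], x[1]**2])\n",
    "\n",
    "def quad_kernel(x, y):\n",
    "    # Same inner product, computed entirely in the original space\n",
    "    return np.dot(x, y) ** 2\n",
    "\n",
    "u = np.array([1.0, 2.0])\n",
    "v = np.array([3.0, 0.5])\n",
    "explicit = np.dot(phi(u), phi(v))\n",
    "implicit = quad_kernel(u, v)\n",
    "print(f\"explicit phi(u).phi(v) = {explicit:.6f}\")\n",
    "print(f\"kernel   (u.v)^2       = {implicit:.6f}\")"
   ]
  },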
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-14",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# 3D visualization: XOR becomes separable after lifting\n",
    "\n",
    "fig = plt.figure(figsize=(14, 6))\n",
    "\n",
    "# Left: original 2D space\n",
    "ax1 = fig.add_subplot(121)\n",
    "X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
    "y = np.array([0, 1, 1, 0])\n",
    "\n",
    "colors = ['blue' if yi == 0 else 'red' for yi in y]\n",
    "markers = ['o' if yi == 0 else 's' for yi in y]\n",
    "for i in range(4):\n",
    "    ax1.scatter(X[i, 0], X[i, 1], c=colors[i], s=250, marker=markers[i],\n",
    "               edgecolors='black', linewidths=2, zorder=5)\n",
    "    ax1.annotate(f'XOR={y[i]}', (X[i, 0], X[i, 1]),\n",
    "                xytext=(X[i, 0]+0.05, X[i, 1]+0.08), fontsize=11)\n",
    "\n",
    "ax1.set_xlabel('$x_1$', fontsize=14)\n",
    "ax1.set_ylabel('$x_2$', fontsize=14)\n",
    "ax1.set_title('Original Space $\\\\mathbb{R}^2$: NOT separable', fontsize=13)\n",
    "ax1.set_xlim(-0.3, 1.4)\n",
    "ax1.set_ylim(-0.3, 1.4)\n",
    "ax1.set_aspect('equal')\n",
    "ax1.grid(True, alpha=0.3)\n",
    "\n",
    "# Right: lifted 3D space\n",
    "ax2 = fig.add_subplot(122, projection='3d')\n",
    "\n",
    "# Lift to 3D: phi(x1, x2) = (x1, x2, x1*x2)\n",
    "X_lifted = np.column_stack([X, X[:, 0] * X[:, 1]])\n",
    "\n",
    "for i in range(4):\n",
    "    ax2.scatter(X_lifted[i, 0], X_lifted[i, 1], X_lifted[i, 2],\n",
    "               c=colors[i], s=200, marker=markers[i],\n",
    "               edgecolors='black', linewidths=2, zorder=5)\n",
    "\n",
    "# Draw the separating plane: x1 + x2 - 2*x1*x2 = 0.5\n",
    "# => z3 = (x1 + x2 - 0.5) / 2  (where z3 = x1*x2)\n",
    "xx, yy = np.meshgrid(np.linspace(-0.2, 1.2, 20), np.linspace(-0.2, 1.2, 20))\n",
    "zz = (xx + yy - 0.5) / 2\n",
    "ax2.plot_surface(xx, yy, zz, alpha=0.2, color='green')\n",
    "\n",
    "ax2.set_xlabel('$x_1$', fontsize=12)\n",
    "ax2.set_ylabel('$x_2$', fontsize=12)\n",
    "ax2.set_zlabel('$x_1 x_2$', fontsize=12)\n",
    "ax2.set_title('Lifted Space $\\\\mathbb{R}^3$: SEPARABLE!', fontsize=13)\n",
    "ax2.view_init(elev=25, azim=45)\n",
    "\n",
    "plt.suptitle('Feature Map $\\\\varphi(x_1, x_2) = (x_1, x_2, x_1 x_2)$: '\n",
    "             'XOR Becomes Linearly Separable in 3D', fontsize=14, y=1.0)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Verify the separation\n",
    "w = np.array([1, 1, -2])\n",
    "b = -0.5\n",
    "print(\"Verification of linear separability in R^3:\")\n",
    "print(f\"Separating plane: {w[0]}*x1 + {w[1]}*x2 + ({w[2]})*x1*x2 + ({b}) = 0\")\n",
    "print()\n",
    "for i in range(4):\n",
    "    val = np.dot(w, X_lifted[i]) + b\n",
    "    pred = 1 if val > 0 else 0\n",
    "    print(f\"phi({X[i]}) = {X_lifted[i]} => w.phi + b = {val:.1f} => class {pred} \"\n",
    "          f\"(true: {y[i]}) {'OK' if pred == y[i] else 'FAIL'}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-15",
   "metadata": {},
   "source": [
    "## 7. Margin and Support Vectors\n",
    "\n",
    "When data is linearly separable, there are infinitely many separating hyperplanes. Which one is \"best\"? This question leads to the concept of **margin** and eventually to **Support Vector Machines (SVMs)**.\n",
    "\n",
    "### Geometric Margin\n",
    "\n",
    "Given a separating hyperplane $H: \\mathbf{w} \\cdot \\mathbf{x} + b = 0$ (with $\\|\\mathbf{w}\\| = 1$), the **geometric margin** is the minimum distance from any data point to $H$:\n",
    "\n",
    "$$\\gamma = \\min_i |\\mathbf{w} \\cdot \\mathbf{x}_i + b|$$\n",
    "\n",
    "The distance from a point $\\mathbf{x}_i$ to the hyperplane is:\n",
    "\n",
    "$$d(\\mathbf{x}_i, H) = \\frac{|\\mathbf{w} \\cdot \\mathbf{x}_i + b|}{\\|\\mathbf{w}\\|}$$\n",
    "\n",
    "### Support Vectors\n",
    "\n",
    "The **support vectors** are the data points closest to the separating hyperplane --- those achieving the minimum distance $\\gamma$. These are the \"hardest\" points to classify and the ones that determine the position of the optimal boundary.\n",
    "\n",
    "### Maximum Margin Classifier\n",
    "\n",
    "The **maximum margin hyperplane** is the one that maximizes $\\gamma$. This is the optimal linear classifier in a precise sense: it provides the greatest \"safety buffer\" between the two classes.\n",
    "\n",
    "The optimization problem is:\n",
    "\n",
    "$$\\max_{\\mathbf{w}, b} \\frac{2}{\\|\\mathbf{w}\\|} \\quad \\text{s.t.} \\quad y_i(\\mathbf{w} \\cdot \\mathbf{x}_i + b) \\geq 1 \\; \\forall i$$\n",
    "\n",
    "This is equivalent to minimizing $\\|\\mathbf{w}\\|^2$ subject to the constraints --- a **quadratic program** that can be solved efficiently. This is the foundation of the **Support Vector Machine (SVM)**, developed by Vapnik and Cortes in the 1990s."
   ]
  },
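  {
   "cell_type": "markdown",
   "id": "cell-15b",
   "metadata": {},
   "source": [
    "Before visualizing the margin, here is a hedged numerical sketch of the optimization problem above, using NumPy only: a quadratic *penalty* on constraint violations approximates the hard-margin quadratic program for large $C$. This is a simple stand-in for a proper QP solver; the toy dataset and hyperparameters are illustrative choices, not from the text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-15c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Penalty-method sketch of  min ||w||^2 / 2  s.t.  y_i (w.x_i + b) >= 1:\n",
    "# minimize  ||w||^2 / 2 + C * sum_i max(0, 1 - y_i (w.x_i + b))^2\n",
    "# by gradient descent. For large C this approximates the hard-margin QP.\n",
    "import numpy as np\n",
    "\n",
    "X_toy = np.array([[-2.0, 0.0], [-1.5, 1.0], [2.0, 0.0], [1.5, -1.0]])\n",
    "y_toy = np.array([-1.0, -1.0, 1.0, 1.0])\n",
    "\n",
    "w_qp = np.array([1.0, 0.0])  # feasible start: this line already separates the data\n",
    "b_qp = 0.0\n",
    "C, lr = 100.0, 2e-4\n",
    "for _ in range(100000):\n",
    "    slack = np.maximum(0.0, 1 - y_toy * (X_toy @ w_qp + b_qp))  # violations\n",
    "    w_qp = w_qp - lr * (w_qp - 2 * C * (slack * y_toy) @ X_toy)\n",
    "    b_qp = b_qp - lr * (-2 * C * np.sum(slack * y_toy))\n",
    "\n",
    "print(f\"w = {w_qp}, b = {b_qp:.4f}\")\n",
    "print(f\"margin width 2/||w|| = {2 / np.linalg.norm(w_qp):.4f}\")"
   ]
  },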
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-16",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Visualization: margin and support vectors\n",
    "\n",
    "np.random.seed(12)\n",
    "\n",
    "# Generate linearly separable data\n",
    "A = np.random.randn(15, 2) * 0.6 + np.array([-1.5, 0])\n",
    "B = np.random.randn(15, 2) * 0.6 + np.array([1.5, 0])\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
    "\n",
    "# Left: many valid separating hyperplanes\n",
    "ax1 = axes[0]\n",
    "ax1.scatter(A[:, 0], A[:, 1], c='blue', s=80, edgecolors='black', linewidths=1, marker='o')\n",
    "ax1.scatter(B[:, 0], B[:, 1], c='red', s=80, edgecolors='black', linewidths=1, marker='s')\n",
    "\n",
    "# Draw several valid separating lines\n",
    "x_plot = np.linspace(-3.5, 3.5, 100)\n",
    "for slope, intercept, style in [(0, 0, '-'), (0.5, 0.2, '--'), (-0.3, -0.1, '-.'), (0.8, 0.3, ':')]:\n",
    "    ax1.plot(x_plot, slope * x_plot + intercept, 'gray', linestyle=style, linewidth=1.5, alpha=0.6)\n",
    "\n",
    "ax1.set_title('Many valid separating lines', fontsize=13)\n",
    "ax1.set_xlim(-3.5, 3.5)\n",
    "ax1.set_ylim(-3, 3)\n",
    "ax1.set_xlabel('$x_1$', fontsize=12)\n",
    "ax1.set_ylabel('$x_2$', fontsize=12)\n",
    "ax1.grid(True, alpha=0.3)\n",
    "\n",
    "# Right: maximum margin hyperplane\n",
    "ax2 = axes[1]\n",
    "ax2.scatter(A[:, 0], A[:, 1], c='blue', s=80, edgecolors='black', linewidths=1, marker='o')\n",
    "ax2.scatter(B[:, 0], B[:, 1], c='red', s=80, edgecolors='black', linewidths=1, marker='s')\n",
    "\n",
    "# Find approximate maximum margin (simple approach: midpoint of closest pair)\n",
    "min_dist = np.inf\n",
    "sv_a, sv_b = None, None\n",
    "for a in A:\n",
    "    for b in B:\n",
    "        d = np.linalg.norm(a - b)\n",
    "        if d < min_dist:\n",
    "            min_dist = d\n",
    "            sv_a, sv_b = a, b\n",
    "\n",
    "# The optimal separating line is perpendicular to sv_b - sv_a at the midpoint\n",
    "midpoint = (sv_a + sv_b) / 2\n",
    "w_dir = sv_b - sv_a\n",
    "w_dir = w_dir / np.linalg.norm(w_dir)\n",
    "\n",
    "# Line perpendicular to w_dir through midpoint\n",
    "perp = np.array([-w_dir[1], w_dir[0]])\n",
    "t_vals = np.linspace(-3, 3, 100)\n",
    "line_pts = midpoint + np.outer(t_vals, perp)\n",
    "ax2.plot(line_pts[:, 0], line_pts[:, 1], 'g-', linewidth=3, label='Max-margin boundary')\n",
    "\n",
    "# Draw margin lines\n",
    "margin = min_dist / 2\n",
    "for sign, style in [(1, '--'), (-1, '--')]:\n",
    "    margin_pts = midpoint + sign * margin * w_dir + np.outer(t_vals, perp)\n",
    "    ax2.plot(margin_pts[:, 0], margin_pts[:, 1], 'g', linestyle=style, linewidth=1.5, alpha=0.5)\n",
    "\n",
    "# Highlight support vectors\n",
    "ax2.scatter([sv_a[0]], [sv_a[1]], c='blue', s=250, edgecolors='green', linewidths=3,\n",
    "            marker='o', zorder=6, label='Support vectors')\n",
    "ax2.scatter([sv_b[0]], [sv_b[1]], c='red', s=250, edgecolors='green', linewidths=3,\n",
    "            marker='s', zorder=6)\n",
    "\n",
    "# Draw margin width\n",
    "ax2.annotate('', xy=sv_b, xytext=sv_a,\n",
    "             arrowprops=dict(arrowstyle='<->', color='purple', linewidth=2))\n",
    "ax2.text(midpoint[0] + 0.1, midpoint[1] + 0.3, f'margin = {min_dist:.2f}',\n",
    "         fontsize=12, color='purple', fontweight='bold')\n",
    "\n",
    "ax2.set_title('Maximum margin hyperplane', fontsize=13)\n",
    "ax2.set_xlim(-3.5, 3.5)\n",
    "ax2.set_ylim(-3, 3)\n",
    "ax2.set_xlabel('$x_1$', fontsize=12)\n",
    "ax2.set_ylabel('$x_2$', fontsize=12)\n",
    "ax2.legend(fontsize=10, loc='upper left')\n",
    "ax2.grid(True, alpha=0.3)\n",
    "\n",
    "plt.suptitle('From Separability to Optimal Separability: The Margin Idea', fontsize=14, y=1.01)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-17",
   "metadata": {},
   "source": [
    "## 8. Exercises\n",
    "\n",
    "### Exercise 10.1: Separability by Convex Hull\n",
    "\n",
    "For each of the 5 random 2D datasets below, determine whether the two classes are linearly separable by computing and visualizing the convex hulls. Use the `convex_hull_2d` and `check_hull_intersection` functions defined earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-18",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Exercise 10.1: Generate 5 random datasets and check separability\n",
    "\n",
    "np.random.seed(2024)\n",
    "\n",
    "exercise_datasets = []\n",
    "for i in range(5):\n",
    "    # Randomly generate two clusters\n",
    "    center_A = np.random.randn(2) * 2\n",
    "    center_B = np.random.randn(2) * 2\n",
    "    spread = np.random.uniform(0.3, 1.5)\n",
    "    n_points = np.random.randint(5, 12)\n",
    "    A = np.random.randn(n_points, 2) * spread + center_A\n",
    "    B = np.random.randn(n_points, 2) * spread + center_B\n",
    "    exercise_datasets.append((f'Dataset {i+1}', A, B))\n",
    "\n",
    "print(\"Exercise 10.1: For each dataset, determine separability.\")\n",
    "print(\"Use convex_hull_2d() and check_hull_intersection().\")\n",
    "print()\n",
    "\n",
    "for name, A, B in exercise_datasets:\n",
    "    hull_A = convex_hull_2d(A)\n",
    "    hull_B = convex_hull_2d(B)\n",
    "    intersects = check_hull_intersection(hull_A, hull_B)\n",
    "    status = \"NOT separable\" if intersects else \"SEPARABLE\"\n",
    "    print(f\"{name}: {len(A)} vs {len(B)} points => {status}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-19",
   "metadata": {},
   "source": [
    "### Exercise 10.2: VC Dimension of Circles\n",
    "\n",
    "Consider the hypothesis class of **circles** in $\\mathbb{R}^2$: a point is classified as positive if it lies inside the circle, and negative if outside.\n",
    "\n",
    "$$h_{c,r}(\\mathbf{x}) = \\begin{cases} 1 & \\text{if } \\|\\mathbf{x} - \\mathbf{c}\\| \\leq r \\\\ 0 & \\text{if } \\|\\mathbf{x} - \\mathbf{c}\\| > r \\end{cases}$$\n",
    "\n",
    "**Task:** Show that the VC dimension of this class is 3.\n",
    "\n",
    "1. Show that 3 points can be shattered.\n",
    "2. Show that no set of 4 points can be shattered.\n",
    "\n",
    "*Hint for (1):* Place 3 points on a circle. Show all 8 labelings can be achieved.\n",
    "\n",
    "*Hint for (2):* Consider 4 points. If one is inside the convex hull of the other three, use that to find a labeling that fails.\n",
    "\n",
    "\n",
    "### Exercise 10.3: Feature Map for Concentric Circles\n",
    "\n",
    "Consider a dataset where class 0 points form a ring at distance $\\approx 2$ from the origin, and class 1 points form a cluster near the origin.\n",
    "\n",
    "Find a feature map $\\varphi: \\mathbb{R}^2 \\to \\mathbb{R}^k$ (for some $k$) that makes this dataset linearly separable.\n",
    "\n",
    "*Hint:* What quantity naturally separates points near the origin from points far from it?"
   ]
  },
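  {
   "cell_type": "markdown",
   "id": "cell-19b",
   "metadata": {},
   "source": [
    "As a starting point for Exercise 10.2, part (1), here is a brute-force sketch (evidence, not a proof): randomly search for a disk realizing each of the 8 labelings of 3 points placed on the unit circle. Part (2) is left to you."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-19c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exercise 10.2(1) sketch: random search for a disk realizing each\n",
    "# labeling of 3 points on the unit circle (supports, but does not\n",
    "# prove, that 3 points can be shattered by disks).\n",
    "import numpy as np\n",
    "from itertools import product\n",
    "\n",
    "pts_circle = np.array([[np.cos(a), np.sin(a)]\n",
    "                       for a in [0, 2 * np.pi / 3, 4 * np.pi / 3]])\n",
    "\n",
    "def disk_realizes(labels, n_trials=20000):\n",
    "    # Try random centers and radii until the inside/outside pattern\n",
    "    # of some disk matches the requested labeling.\n",
    "    rng = np.random.default_rng(0)\n",
    "    target = np.asarray(labels)\n",
    "    for _ in range(n_trials):\n",
    "        c = rng.uniform(-3, 3, size=2)\n",
    "        r = rng.uniform(0, 4)\n",
    "        inside = (np.linalg.norm(pts_circle - c, axis=1) <= r).astype(int)\n",
    "        if np.array_equal(inside, target):\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "for labels in product([0, 1], repeat=3):\n",
    "    status = 'realizable' if disk_realizes(labels) else 'not found'\n",
    "    print(f\"labeling {labels}: {status}\")"
   ]
  },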
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-20",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Exercise 10.3: Concentric circles dataset\n",
    "\n",
    "np.random.seed(42)\n",
    "\n",
    "# Generate concentric circles data\n",
    "n_inner = 50\n",
    "n_outer = 80\n",
    "\n",
    "# Inner circle (class 1)\n",
    "theta_inner = np.random.uniform(0, 2*np.pi, n_inner)\n",
    "r_inner = np.random.uniform(0, 0.8, n_inner)\n",
    "X_inner = np.column_stack([r_inner * np.cos(theta_inner), r_inner * np.sin(theta_inner)])\n",
    "\n",
    "# Outer ring (class 0)\n",
    "theta_outer = np.random.uniform(0, 2*np.pi, n_outer)\n",
    "r_outer = np.random.uniform(1.5, 2.5, n_outer)\n",
    "X_outer = np.column_stack([r_outer * np.cos(theta_outer), r_outer * np.sin(theta_outer)])\n",
    "\n",
    "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n",
    "\n",
    "# Original space\n",
    "ax1 = axes[0]\n",
    "ax1.scatter(X_inner[:, 0], X_inner[:, 1], c='red', s=30, label='Class 1 (inner)', alpha=0.7)\n",
    "ax1.scatter(X_outer[:, 0], X_outer[:, 1], c='blue', s=30, label='Class 0 (outer)', alpha=0.7)\n",
    "ax1.set_title('Original Space: NOT linearly separable', fontsize=13)\n",
    "ax1.set_xlabel('$x_1$', fontsize=12)\n",
    "ax1.set_ylabel('$x_2$', fontsize=12)\n",
    "ax1.legend(fontsize=10)\n",
    "ax1.set_aspect('equal')\n",
    "ax1.grid(True, alpha=0.3)\n",
    "\n",
    "# Feature space: phi(x1, x2) = x1^2 + x2^2 (just the radius squared)\n",
    "r2_inner = X_inner[:, 0]**2 + X_inner[:, 1]**2\n",
    "r2_outer = X_outer[:, 0]**2 + X_outer[:, 1]**2\n",
    "\n",
    "ax2 = axes[1]\n",
    "ax2.scatter(r2_inner, np.zeros_like(r2_inner) + np.random.randn(n_inner)*0.05,\n",
    "            c='red', s=30, label='Class 1 (inner)', alpha=0.7)\n",
    "ax2.scatter(r2_outer, np.zeros_like(r2_outer) + np.random.randn(n_outer)*0.05,\n",
    "            c='blue', s=30, label='Class 0 (outer)', alpha=0.7)\n",
    "ax2.axvline(x=1.0, color='green', linewidth=3, linestyle='--',\n",
    "            label='Threshold at $r^2 = 1.0$')\n",
    "ax2.set_title('Feature Space $\\\\varphi(x_1, x_2) = x_1^2 + x_2^2$: SEPARABLE!', fontsize=13)\n",
    "ax2.set_xlabel('$r^2 = x_1^2 + x_2^2$', fontsize=12)\n",
    "ax2.set_ylabel('(jittered for visibility)', fontsize=10)\n",
    "ax2.legend(fontsize=10)\n",
    "ax2.grid(True, alpha=0.3)\n",
    "\n",
    "plt.suptitle('Exercise 10.3: The radius-squared feature makes concentric circles separable',\n",
    "             fontsize=14, y=1.01)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"The feature map phi(x1, x2) = x1^2 + x2^2 maps each point to its squared distance from the origin.\")\n",
    "print(\"In this 1D feature space, a simple threshold separates the two classes.\")\n",
    "print(f\"\\nInner class r^2 range: [{r2_inner.min():.3f}, {r2_inner.max():.3f}]\")\n",
    "print(f\"Outer class r^2 range: [{r2_outer.min():.3f}, {r2_outer.max():.3f}]\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}