{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cell-0",
   "metadata": {},
   "source": [
    "# Chapter 7: What Perceptrons Can Compute\n",
    "\n",
    "\n",
    "## Part 2: The Perceptron"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-1",
   "metadata": {},
   "source": [
    "## 7.1 Introduction\n",
    "\n",
    "Having established the perceptron model (Chapter 4) and its learning algorithm (Chapter 5), we now ask a fundamental question: **what functions can a single perceptron compute?**\n",
    "\n",
    "Since a perceptron is a binary classifier that partitions its input space with a hyperplane, this question reduces to: **which Boolean functions are linearly separable?**\n",
    "\n",
    "This chapter provides a systematic exploration:\n",
    "1. We enumerate all 16 two-input Boolean functions and determine which are perceptron-computable.\n",
    "2. We visualize the decision boundaries for all separable functions and demonstrate the impossibility for the non-separable ones.\n",
    "3. We count how the fraction of linearly separable functions changes with the number of inputs.\n",
    "4. We introduce the **convex hull criterion** for linear separability.\n",
    "5. We define the **VC dimension** of the perceptron.\n",
    "\n",
    "These results set the stage for the Minsky-Papert critique and the motivation for multi-layer networks."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-2",
   "metadata": {},
   "source": [
    "## 7.2 All 16 Two-Input Boolean Functions\n",
    "\n",
    "A Boolean function of two variables $f: \\{0,1\\}^2 \\to \\{0,1\\}$ is completely determined by its values on the four inputs $(0,0), (0,1), (1,0), (1,1)$. Since each output can be 0 or 1, there are $2^4 = 16$ possible functions.\n",
    "\n",
    "We can list them systematically. If we write the outputs in the order $(f(0,0), f(0,1), f(1,0), f(1,1))$ as a 4-bit binary number, we get the functions numbered 0 through 15:\n",
    "\n",
    "| # | $(0,0)$ | $(0,1)$ | $(1,0)$ | $(1,1)$ | Name | Linearly Separable? |\n",
    "|---|---------|---------|---------|---------|------|-----------------|\n",
    "| 0 | 0 | 0 | 0 | 0 | FALSE (contradiction) | Yes (trivially) |\n",
    "| 1 | 0 | 0 | 0 | 1 | AND ($x_1 \\wedge x_2$) | Yes |\n",
    "| 2 | 0 | 0 | 1 | 0 | $x_1 \\wedge \\neg x_2$ | Yes |\n",
    "| 3 | 0 | 0 | 1 | 1 | $x_1$ (projection) | Yes |\n",
    "| 4 | 0 | 1 | 0 | 0 | $\\neg x_1 \\wedge x_2$ | Yes |\n",
    "| 5 | 0 | 1 | 0 | 1 | $x_2$ (projection) | Yes |\n",
    "| 6 | 0 | 1 | 1 | 0 | **XOR** ($x_1 \\oplus x_2$) | **No** |\n",
    "| 7 | 0 | 1 | 1 | 1 | OR ($x_1 \\vee x_2$) | Yes |\n",
    "| 8 | 1 | 0 | 0 | 0 | NOR ($\\neg(x_1 \\vee x_2)$) | Yes |\n",
    "| 9 | 1 | 0 | 0 | 1 | **XNOR** ($\\neg(x_1 \\oplus x_2)$) | **No** |\n",
    "| 10 | 1 | 0 | 1 | 0 | $\\neg x_2$ | Yes |\n",
    "| 11 | 1 | 0 | 1 | 1 | $x_1 \\vee \\neg x_2$ | Yes |\n",
    "| 12 | 1 | 1 | 0 | 0 | $\\neg x_1$ | Yes |\n",
    "| 13 | 1 | 1 | 0 | 1 | $\\neg x_1 \\vee x_2$ | Yes |\n",
    "| 14 | 1 | 1 | 1 | 0 | NAND ($\\neg(x_1 \\wedge x_2)$) | Yes |\n",
    "| 15 | 1 | 1 | 1 | 1 | TRUE (tautology) | Yes (trivially) |\n",
    "\n",
    "```{danger}\n",
    "**Most Boolean functions are NOT linearly separable!** While 14 out of 16 two-input functions are separable (a fraction of 87.5%), this is deceptively optimistic. The fraction drops **super-exponentially** with the number of inputs:\n",
    "\n",
    "- $n = 2$: 14 / 16 = **87.5%** separable\n",
    "- $n = 3$: 104 / 256 = **40.6%** separable\n",
    "- $n = 4$: 1,882 / 65,536 = **2.87%** separable\n",
    "- $n = 5$: 94,572 / 4,294,967,296 = **0.0022%** separable\n",
    "\n",
    "For even moderate $n$, virtually NO Boolean function can be computed by a single perceptron. This is the fundamental limitation that Minsky and Papert exposed in 1969.\n",
    "```\n",
    "\n",
    "Note that XNOR is the complement of XOR, so their non-separability is related: if XOR were separable, flipping all labels would give XNOR, which would also be separable (just negate the weights and bias)."
   ]
  },
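  {
   "cell_type": "markdown",
   "id": "cell-2a",
   "metadata": {},
   "source": [
    "The label-flip argument can be checked on a pair we already know to be separable: negating both the weights and the bias of a perceptron complements its output (provided no input lies exactly on the decision boundary), so a perceptron for OR turns into one for NOR. A minimal sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-2b",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Negating (w, b) complements a perceptron's output, provided no input\n",
    "# sits exactly on the boundary w.x + b = 0.  Demo: OR -> NOR.\n",
    "corners = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
    "w_or, b_or = np.array([1, 1]), -0.5\n",
    "\n",
    "or_out = (corners @ w_or + b_or >= 0).astype(int)\n",
    "nor_out = (corners @ (-w_or) + (-b_or) >= 0).astype(int)\n",
    "\n",
    "print('OR :', or_out)\n",
    "print('NOR:', nor_out)\n",
    "print('NOR is the complement of OR:', np.array_equal(nor_out, 1 - or_out))"
   ]
  },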
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Polygon\n",
    "from scipy.spatial import ConvexHull\n",
    "\n",
    "# Apply the style first, then our rcParams overrides (calling\n",
    "# plt.style.use afterwards would clobber settings like grid.alpha)\n",
    "try:\n",
    "    plt.style.use('seaborn-v0_8-whitegrid')\n",
    "except OSError:\n",
    "    pass\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'figure.figsize': (8, 6),\n",
    "    'font.size': 12,\n",
    "    'axes.grid': True,\n",
    "    'grid.alpha': 0.3\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-3a",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Complete table of all 16 two-input Boolean functions with separability status\n",
    "fig, ax = plt.subplots(figsize=(16, 10))\n",
    "ax.axis('off')\n",
    "\n",
    "# Headers\n",
    "headers = ['#', '(0,0)', '(0,1)', '(1,0)', '(1,1)', 'Name', 'Separable?', 'Weights (w1,w2,b)']\n",
    "\n",
    "# Data rows\n",
    "rows = [\n",
    "    ['0',  '0', '0', '0', '0', 'FALSE',                       'Yes (trivial)', 'w=(0,0), b=-1'],\n",
    "    ['1',  '0', '0', '0', '1', 'AND',                         'Yes',           'w=(1,1), b=-1.5'],\n",
    "    ['2',  '0', '0', '1', '0', '$x_1 \\\\wedge \\\\neg x_2$',     'Yes',           'w=(1,-1), b=-0.5'],\n",
    "    ['3',  '0', '0', '1', '1', '$x_1$',                       'Yes',           'w=(1,0), b=-0.5'],\n",
    "    ['4',  '0', '1', '0', '0', '$\\\\neg x_1 \\\\wedge x_2$',     'Yes',           'w=(-1,1), b=-0.5'],\n",
    "    ['5',  '0', '1', '0', '1', '$x_2$',                       'Yes',           'w=(0,1), b=-0.5'],\n",
    "    ['6',  '0', '1', '1', '0', 'XOR',                         'NO',            '---'],\n",
    "    ['7',  '0', '1', '1', '1', 'OR',                          'Yes',           'w=(1,1), b=-0.5'],\n",
    "    ['8',  '1', '0', '0', '0', 'NOR',                         'Yes',           'w=(-1,-1), b=0.5'],\n",
    "    ['9',  '1', '0', '0', '1', 'XNOR',                        'NO',            '---'],\n",
    "    ['10', '1', '0', '1', '0', '$\\\\neg x_2$',                  'Yes',           'w=(0,-1), b=0.5'],\n",
    "    ['11', '1', '0', '1', '1', '$x_1 \\\\vee \\\\neg x_2$',       'Yes',           'w=(1,-1), b=0.5'],\n",
    "    ['12', '1', '1', '0', '0', '$\\\\neg x_1$',                  'Yes',           'w=(-1,0), b=0.5'],\n",
    "    ['13', '1', '1', '0', '1', '$\\\\neg x_1 \\\\vee x_2$',       'Yes',           'w=(-1,1), b=0.5'],\n",
    "    ['14', '1', '1', '1', '0', 'NAND',                        'Yes',           'w=(-1,-1), b=1.5'],\n",
    "    ['15', '1', '1', '1', '1', 'TRUE',                        'Yes (trivial)', 'w=(0,0), b=1'],\n",
    "]\n",
    "\n",
    "# Create table\n",
    "table = ax.table(\n",
    "    cellText=rows,\n",
    "    colLabels=headers,\n",
    "    cellLoc='center',\n",
    "    loc='center',\n",
    "    colWidths=[0.04, 0.06, 0.06, 0.06, 0.06, 0.18, 0.12, 0.18]\n",
    ")\n",
    "\n",
    "table.auto_set_font_size(False)\n",
    "table.set_fontsize(9)\n",
    "table.scale(1.0, 1.8)\n",
    "\n",
    "# Style headers\n",
    "for j in range(len(headers)):\n",
    "    cell = table[0, j]\n",
    "    cell.set_facecolor('#2C3E50')\n",
    "    cell.set_text_props(color='white', fontweight='bold', fontsize=10)\n",
    "\n",
    "# Style rows\n",
    "for i in range(1, len(rows) + 1):\n",
    "    row_data = rows[i-1]\n",
    "    is_nonsep = row_data[6] == 'NO'\n",
    "    \n",
    "    for j in range(len(headers)):\n",
    "        cell = table[i, j]\n",
    "        if is_nonsep:\n",
    "            cell.set_facecolor('#FADBD8')\n",
    "            cell.set_text_props(fontweight='bold', color='#C0392B')\n",
    "        else:\n",
    "            color = '#EBF5FB' if i % 2 == 1 else '#FDFEFE'\n",
    "            cell.set_facecolor(color)\n",
    "\n",
    "ax.set_title('Complete Table of All 16 Two-Input Boolean Functions\\n'\n",
    "             'with Linear Separability Status and Perceptron Weights',\n",
    "             fontsize=14, fontweight='bold', pad=20)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-4",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Define all 16 two-input Boolean functions\n",
    "# Input: (0,0), (0,1), (1,0), (1,1)\n",
    "inputs_2 = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
    "\n",
    "boolean_functions = {\n",
    "    0:  {'name': 'FALSE',                'outputs': [0,0,0,0], 'separable': True},\n",
    "    1:  {'name': 'AND',                  'outputs': [0,0,0,1], 'separable': True},\n",
    "    2:  {'name': '$x_1 \\\\wedge \\\\neg x_2$', 'outputs': [0,0,1,0], 'separable': True},\n",
    "    3:  {'name': '$x_1$',                'outputs': [0,0,1,1], 'separable': True},\n",
    "    4:  {'name': '$\\\\neg x_1 \\\\wedge x_2$', 'outputs': [0,1,0,0], 'separable': True},\n",
    "    5:  {'name': '$x_2$',                'outputs': [0,1,0,1], 'separable': True},\n",
    "    6:  {'name': 'XOR',                  'outputs': [0,1,1,0], 'separable': False},\n",
    "    7:  {'name': 'OR',                   'outputs': [0,1,1,1], 'separable': True},\n",
    "    8:  {'name': 'NOR',                  'outputs': [1,0,0,0], 'separable': True},\n",
    "    9:  {'name': 'XNOR',                 'outputs': [1,0,0,1], 'separable': False},\n",
    "    10: {'name': '$\\\\neg x_2$',           'outputs': [1,0,1,0], 'separable': True},\n",
    "    11: {'name': '$x_1 \\\\vee \\\\neg x_2$',  'outputs': [1,0,1,1], 'separable': True},\n",
    "    12: {'name': '$\\\\neg x_1$',           'outputs': [1,1,0,0], 'separable': True},\n",
    "    13: {'name': '$\\\\neg x_1 \\\\vee x_2$',  'outputs': [1,1,0,1], 'separable': True},\n",
    "    14: {'name': 'NAND',                 'outputs': [1,1,1,0], 'separable': True},\n",
    "    15: {'name': 'TRUE',                 'outputs': [1,1,1,1], 'separable': True},\n",
    "}\n",
    "\n",
    "# Weights and biases for the separable functions (found analytically or by perceptron)\n",
    "# For the constant functions (FALSE, TRUE), we use a degenerate boundary\n",
    "separating_params = {\n",
    "    0:  (np.array([0, 0]), -1),       # FALSE: always 0 (bias < 0)\n",
    "    1:  (np.array([1, 1]), -1.5),     # AND\n",
    "    2:  (np.array([1, -1]), -0.5),    # x1 AND NOT x2\n",
    "    3:  (np.array([1, 0]), -0.5),     # x1\n",
    "    4:  (np.array([-1, 1]), -0.5),    # NOT x1 AND x2\n",
    "    5:  (np.array([0, 1]), -0.5),     # x2\n",
    "    7:  (np.array([1, 1]), -0.5),     # OR\n",
    "    8:  (np.array([-1, -1]), 0.5),    # NOR\n",
    "    10: (np.array([0, -1]), 0.5),     # NOT x2\n",
    "    11: (np.array([1, -1]), 0.5),     # x1 OR NOT x2\n",
    "    12: (np.array([-1, 0]), 0.5),     # NOT x1\n",
    "    13: (np.array([-1, 1]), 0.5),     # NOT x1 OR x2\n",
    "    14: (np.array([-1, -1]), 1.5),    # NAND\n",
    "    15: (np.array([0, 0]), 1),        # TRUE: always 1 (bias > 0)\n",
    "}\n",
    "\n",
    "# Verify all separable functions\n",
    "print(\"Verification of perceptron weights for all separable functions:\")\n",
    "print(\"=\" * 65)\n",
    "all_correct = True\n",
    "for idx, info in boolean_functions.items():\n",
    "    if info['separable'] and idx in separating_params:\n",
    "        w, b = separating_params[idx]\n",
    "        outputs = info['outputs']\n",
    "        predictions = [(inputs_2[i] @ w + b >= 0).astype(int) for i in range(4)]\n",
    "        correct = all(p == o for p, o in zip(predictions, outputs))\n",
    "        if not correct:\n",
    "            all_correct = False\n",
    "        status = 'OK' if correct else 'FAIL'\n",
    "        print(f\"  {idx:>2d}. {info['name']:>25s}: w={w}, b={b:+.1f} [{status}]\")\n",
    "\n",
    "print(f\"\\nAll correct: {all_correct}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-5",
   "metadata": {},
   "source": [
    "## 7.3 Decision Boundary Gallery"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-6",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, axes = plt.subplots(4, 4, figsize=(20, 20))\n",
    "\n",
    "for idx in range(16):\n",
    "    row, col = divmod(idx, 4)\n",
    "    ax = axes[row, col]\n",
    "    \n",
    "    info = boolean_functions[idx]\n",
    "    outputs = np.array(info['outputs'])\n",
    "    \n",
    "    # Plot background\n",
    "    xx, yy = np.meshgrid(np.linspace(-0.5, 1.5, 200),\n",
    "                         np.linspace(-0.5, 1.5, 200))\n",
    "    \n",
    "    if info['separable'] and idx in separating_params:\n",
    "        w, b = separating_params[idx]\n",
    "        Z = xx * w[0] + yy * w[1] + b\n",
    "        \n",
    "        ax.contourf(xx, yy, Z, levels=[-1e10, 0, 1e10],\n",
    "                    colors=['#FFCCCC', '#CCCCFF'], alpha=0.4)\n",
    "        \n",
    "        # Draw boundary only if weights are nonzero\n",
    "        if np.linalg.norm(w) > 0:\n",
    "            ax.contour(xx, yy, Z, levels=[0], colors='black', linewidths=2)\n",
    "    else:\n",
    "        # Not separable: show diagonal hatching / grey background\n",
    "        ax.set_facecolor('#FFE0E0')\n",
    "        ax.text(0.5, -0.35, 'NOT SEPARABLE', ha='center', fontsize=10,\n",
    "                color='red', fontweight='bold')\n",
    "    \n",
    "    # Plot points\n",
    "    for i in range(4):\n",
    "        color = 'blue' if outputs[i] == 1 else 'red'\n",
    "        marker = 's' if outputs[i] == 1 else 'o'\n",
    "        ax.scatter(inputs_2[i, 0], inputs_2[i, 1], c=color, marker=marker,\n",
    "                   s=200, edgecolors='black', zorder=5, linewidths=1.5)\n",
    "    \n",
    "    ax.set_xlim(-0.5, 1.5)\n",
    "    ax.set_ylim(-0.5, 1.5)\n",
    "    ax.set_aspect('equal')\n",
    "    ax.set_title(f'#{idx}: {info[\"name\"]}', fontsize=11, fontweight='bold')\n",
    "    ax.set_xticks([0, 1])\n",
    "    ax.set_yticks([0, 1])\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "fig.suptitle('All 16 Two-Input Boolean Functions\\n'\n",
    "             '(Blue squares = output 1, Red circles = output 0)',\n",
    "             fontsize=16, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-7",
   "metadata": {},
   "source": [
    "## 7.4 Counting Linearly Separable Functions\n",
    "\n",
    "As the number of inputs $n$ increases, the total number of Boolean functions $2^{2^n}$ grows super-exponentially. How many of these are linearly separable?\n",
    "\n",
    "This is a deep combinatorial question. The answer depends on the number of **threshold functions** (Boolean functions computable by a single perceptron). Let $T(n)$ denote the number of linearly separable Boolean functions of $n$ variables.\n",
    "\n",
    "```{admonition} Theorem (Counting Linearly Separable Boolean Functions)\n",
    ":class: important\n",
    "\n",
    "Let $T(n)$ denote the number of linearly separable Boolean functions of $n$ variables. The known values are:\n",
    "\n",
    "| $n$ | Total functions $2^{2^n}$ | Linearly separable $T(n)$ | Fraction |\n",
    "|-----|--------------------------|--------------------------|----------|\n",
    "| 1   | 4                        | 4                        | 1.000    |\n",
    "| 2   | 16                       | 14                       | 0.875    |\n",
    "| 3   | 256                      | 104                      | 0.406    |\n",
    "| 4   | 65,536                   | 1,882                    | 0.0287   |\n",
    "| 5   | ~4.3 billion             | 94,572                   | ~0.0000220 |\n",
    "| 6   | ~1.8 x 10^19            | 15,028,134               | ~8.1 x 10^-13 |\n",
    "\n",
    "The fraction $T(n) / 2^{2^n}$ goes to zero **super-exponentially** as $n$ grows.\n",
    "```\n",
    "\n",
    "```{tip}\n",
    "**The Combinatorial Explosion of Non-Separable Functions**\n",
    "\n",
    "The total number of Boolean functions grows as $2^{2^n}$ -- a **double exponential**. This is incomprehensibly fast:\n",
    "- At $n = 5$, there are about 4.3 billion functions.\n",
    "- At $n = 6$, there are about $1.8 \\times 10^{19}$ functions -- more than the number of grains of sand on Earth.\n",
    "- At $n = 10$, there are about $10^{308}$ functions -- vastly more than the number of atoms in the observable universe ($\\sim 10^{80}$).\n",
    "\n",
    "The number of linearly separable functions $T(n)$ also grows, but at a much slower rate. As a result, the fraction of functions that a single perceptron can compute shrinks to effectively zero. This is the mathematical underpinning of the perceptron's limitation: it can only express a vanishingly small fraction of all possible input-output mappings.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-8",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Table and plot of linearly separable function counts\n",
    "n_values = [1, 2, 3, 4, 5, 6]\n",
    "total_functions = [2**(2**n) for n in n_values]\n",
    "lin_sep_counts = [4, 14, 104, 1882, 94572, 15028134]\n",
    "fractions = [ls / tot for ls, tot in zip(lin_sep_counts, total_functions)]\n",
    "total_functions_float = [float(x) for x in total_functions]  # convert for matplotlib (2**64 overflows int64)\n",
    "\n",
    "print(\"Linearly Separable Boolean Functions\")\n",
    "print(\"=\" * 70)\n",
    "print(f\"{'n':>4} | {'Total 2^(2^n)':>18} | {'Lin. Sep. T(n)':>16} | {'Fraction':>14}\")\n",
    "print(\"-\" * 70)\n",
    "for n, tot, ls, frac in zip(n_values, total_functions, lin_sep_counts, fractions):\n",
    "    print(f\"{n:>4d} | {tot:>18,} | {ls:>16,} | {frac:>14.6e}\")\n",
    "\n",
    "# Plot\n",
    "fig, axes = plt.subplots(1, 2, figsize=(16, 6))\n",
    "\n",
    "# Panel 1: Log scale of counts\n",
    "ax = axes[0]\n",
    "ax.semilogy(n_values, total_functions_float, 'ro-', linewidth=2, markersize=8,\n",
    "            label=r'Total: $2^{2^n}$')\n",
    "ax.semilogy(n_values, lin_sep_counts, 'bs-', linewidth=2, markersize=8,\n",
    "            label='Linearly separable: $T(n)$')\n",
    "\n",
    "# Shade the gap\n",
    "ax.fill_between(n_values, lin_sep_counts, total_functions_float, alpha=0.15, color='red',\n",
    "                label='Non-separable functions')\n",
    "\n",
    "ax.set_xlabel('Number of inputs $n$', fontsize=13)\n",
    "ax.set_ylabel('Count (log scale)', fontsize=13)\n",
    "ax.set_title('Boolean Functions vs. Linearly Separable Ones',\n",
    "             fontsize=13, fontweight='bold')\n",
    "ax.legend(fontsize=11)\n",
    "ax.grid(True, alpha=0.3, which='both')\n",
    "ax.set_xticks(n_values)\n",
    "\n",
    "# Panel 2: Fraction (log scale)\n",
    "ax = axes[1]\n",
    "ax.semilogy(n_values, fractions, '^-', linewidth=2, markersize=10,\n",
    "            color='darkgreen')\n",
    "ax.set_xlabel('Number of inputs $n$', fontsize=13)\n",
    "ax.set_ylabel('Fraction (log scale)', fontsize=13)\n",
    "ax.set_title('Fraction of Boolean Functions That Are\\nLinearly Separable',\n",
    "             fontsize=13, fontweight='bold')\n",
    "ax.grid(True, alpha=0.3, which='both')\n",
    "ax.set_xticks(n_values)\n",
    "\n",
    "# Annotate\n",
    "for i, (n, frac) in enumerate(zip(n_values, fractions)):\n",
    "    ax.annotate(f'{frac:.2e}', (n, frac), textcoords='offset points',\n",
    "                xytext=(10, 5), fontsize=9)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "print(\"\\nConclusion: As n grows, the fraction of linearly separable\")\n",
    "print(\"functions goes to zero super-exponentially. Most Boolean\")\n",
    "print(\"functions are NOT computable by a single perceptron.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-8a",
   "metadata": {},
   "source": [
    "### Pie Charts: Separable vs. Non-Separable\n",
    "\n",
    "The following visualization shows the proportion of linearly separable functions for $n = 2, 3, 4, 5$ inputs using pie charts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-8b",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": "import numpy as np\nimport matplotlib.pyplot as plt\n\n# Pie chart showing separable vs non-separable for n=2,3,4,5\nn_vals_pie = [2, 3, 4, 5]\ntotal_fns = [2**(2**n) for n in n_vals_pie]\nsep_counts = [14, 104, 1882, 94572]\nnonsep_counts = [t - s for t, s in zip(total_fns, sep_counts)]\n\nfig, axes = plt.subplots(1, 4, figsize=(20, 5))\n\ncolors_pie = ['#2ecc71', '#e74c3c']\n\nfor i, (n, total, sep, nonsep) in enumerate(zip(n_vals_pie, total_fns, sep_counts, nonsep_counts)):\n    ax = axes[i]\n    frac_sep = sep / total\n    frac_nonsep = nonsep / total\n    \n    sizes = [frac_sep, frac_nonsep]\n    \n    # For n>=4, the separable slice is too thin -- use explode\n    explode = (0.05, 0) if frac_sep > 0.01 else (0.15, 0)\n    \n    # autopct=None returns only (wedges, texts), not 3 values\n    wedges, texts = ax.pie(\n        sizes, labels=None, colors=colors_pie,\n        explode=explode,\n        startangle=90, wedgeprops={'edgecolor': 'black', 'linewidth': 1.5}\n    )\n    \n    ax.set_title(f'$n = {n}$\\n$2^{{2^{n}}} = {total:,}$ total',\n                 fontsize=12, fontweight='bold')\n    \n    # Add text below\n    ax.text(0, -1.4, f'Separable: {sep:,} ({frac_sep:.2%})',\n            ha='center', fontsize=10, color='#27ae60', fontweight='bold')\n    ax.text(0, -1.65, f'Non-sep: {nonsep:,} ({frac_nonsep:.2%})',\n            ha='center', fontsize=10, color='#c0392b', fontweight='bold')\n\n# Add a shared legend\nfig.legend(['Linearly Separable', 'Non-Separable'],\n           loc='upper center', ncol=2, fontsize=12,\n           bbox_to_anchor=(0.5, 1.05),\n           facecolor='white', edgecolor='black')\n\nfig.suptitle('Proportion of Linearly Separable Boolean Functions',\n             fontsize=15, fontweight='bold', y=1.12)\nplt.tight_layout()\nplt.show()"
  },
  {
   "cell_type": "markdown",
   "id": "cell-9",
   "metadata": {},
   "source": [
    "## 7.5 The Convex Hull Criterion\n",
    "\n",
    "We now present a beautiful geometric characterization of linear separability.\n",
    "\n",
    "### Theorem (Convex Hull Separability)\n",
    "\n",
    "> Two finite sets $S_0, S_1 \\subset \\mathbb{R}^n$ are **linearly separable** if and only if their convex hulls are **disjoint**:\n",
    ">\n",
    "> $$\\text{conv}(S_0) \\cap \\text{conv}(S_1) = \\emptyset$$\n",
    "\n",
    "### Proof Sketch\n",
    "\n",
    "$(\\Rightarrow)$ **If separable, then convex hulls are disjoint.**\n",
    "\n",
    "Suppose $\\mathbf{w} \\cdot \\mathbf{x} + b > 0$ for all $\\mathbf{x} \\in S_1$ and $\\mathbf{w} \\cdot \\mathbf{x} + b < 0$ for all $\\mathbf{x} \\in S_0$. Let $\\mathbf{p} \\in \\text{conv}(S_1)$. Then $\\mathbf{p} = \\sum_i \\lambda_i \\mathbf{x}_i$ with $\\mathbf{x}_i \\in S_1$, $\\lambda_i \\geq 0$, $\\sum \\lambda_i = 1$. Hence:\n",
    "\n",
    "$$\\mathbf{w} \\cdot \\mathbf{p} + b = \\sum_i \\lambda_i(\\mathbf{w} \\cdot \\mathbf{x}_i + b) > 0$$\n",
    "\n",
    "Similarly, for any $\\mathbf{q} \\in \\text{conv}(S_0)$ we get $\\mathbf{w} \\cdot \\mathbf{q} + b < 0$, hence $\\mathbf{p} \\neq \\mathbf{q}$, and the convex hulls are disjoint.\n",
    "\n",
    "$(\\Leftarrow)$ **If convex hulls are disjoint, then separable.**\n",
    "\n",
    "This follows from the **Separating Hyperplane Theorem** (in finite dimensions an elementary result, and in general a consequence of the Hahn-Banach theorem): if two compact convex sets in $\\mathbb{R}^n$ are disjoint, there is a hyperplane strictly separating them. Since $\\text{conv}(S_0)$ and $\\text{conv}(S_1)$ are convex and compact (being convex hulls of finite sets), a separating hyperplane exists. $\\blacksquare$\n",
    "\n",
    "### Why This Matters\n",
    "\n",
    "The convex hull criterion gives us an intuitive way to determine linear separability:\n",
    "- **AND**: The class-1 convex hull is the single point $\\{(1,1)\\}$. The class-0 convex hull is the triangle with vertices $(0,0), (0,1), (1,0)$. These are disjoint, so AND is separable.\n",
    "- **XOR**: Class-1 points are $\\{(0,1), (1,0)\\}$ and class-0 points are $\\{(0,0), (1,1)\\}$. The convex hull of class 1 is the line segment from $(0,1)$ to $(1,0)$, and the convex hull of class 0 is the line segment from $(0,0)$ to $(1,1)$. These two segments **intersect** at $(0.5, 0.5)$, so XOR is NOT separable."
   ]
  },
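  {
   "cell_type": "markdown",
   "id": "cell-9a",
   "metadata": {},
   "source": [
    "The criterion is also algorithmically useful: deciding strict linear separability of two finite point sets is a linear-programming feasibility problem. The sketch below asks for any $(\\mathbf{w}, b)$ with $\\mathbf{w} \\cdot \\mathbf{x}_i + b \\geq 1$ on class 1 and $\\leq -1$ on class 0 (by rescaling, such a pair exists exactly when the sets are strictly separable); the helper `is_linearly_separable` is our own, not a library function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-9b",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import linprog\n",
    "\n",
    "def is_linearly_separable(X, y):\n",
    "    \"\"\"LP feasibility test: find (w, b) with w.x + b >= +1 on class 1\n",
    "    and w.x + b <= -1 on class 0. Returns True iff such (w, b) exists.\"\"\"\n",
    "    n = X.shape[1]\n",
    "    # Flip the class-1 rows so every constraint reads A @ [w, b] <= -1\n",
    "    signs = np.where(np.asarray(y) == 1, -1.0, 1.0)\n",
    "    A_ub = signs[:, None] * np.hstack([X, np.ones((len(X), 1))])\n",
    "    res = linprog(c=np.zeros(n + 1), A_ub=A_ub, b_ub=-np.ones(len(X)),\n",
    "                  bounds=[(None, None)] * (n + 1))\n",
    "    return bool(res.success)\n",
    "\n",
    "corners = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
    "print('AND separable?', is_linearly_separable(corners, [0, 0, 0, 1]))\n",
    "print('XOR separable?', is_linearly_separable(corners, [0, 1, 1, 0]))"
   ]
  },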
  {
   "cell_type": "markdown",
   "id": "cell-10",
   "metadata": {},
   "source": [
    "## 7.6 Convex Hull Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-11",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.patches import Polygon\n",
    "from scipy.spatial import ConvexHull\n",
    "\n",
    "def plot_convex_hulls(X, y, ax, title, function_name):\n",
    "    \"\"\"Plot data points and convex hulls for each class.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray of shape (n, 2)\n",
    "    y : np.ndarray of shape (n,)\n",
    "    ax : matplotlib axes\n",
    "    title : str\n",
    "    function_name : str\n",
    "    \"\"\"\n",
    "    class_0 = X[y == 0]\n",
    "    class_1 = X[y == 1]\n",
    "    \n",
    "    # Plot convex hulls\n",
    "    for cls_data, color, label in [(class_0, 'red', 'Class 0'),\n",
    "                                    (class_1, 'blue', 'Class 1')]:\n",
    "        if len(cls_data) >= 3:\n",
    "            hull = ConvexHull(cls_data)\n",
    "            hull_vertices = cls_data[hull.vertices]\n",
    "            hull_vertices = np.vstack([hull_vertices, hull_vertices[0]])\n",
    "            polygon = Polygon(hull_vertices, alpha=0.2, facecolor=color,\n",
    "                            edgecolor=color, linewidth=2)\n",
    "            ax.add_patch(polygon)\n",
    "        elif len(cls_data) == 2:\n",
    "            # Line segment\n",
    "            ax.plot(cls_data[:, 0], cls_data[:, 1], '-', color=color,\n",
    "                    linewidth=3, alpha=0.4)\n",
    "        # Single point: just the scatter will show it\n",
    "    \n",
    "    # Plot data points\n",
    "    if len(class_0) > 0:\n",
    "        ax.scatter(class_0[:, 0], class_0[:, 1], c='red', marker='o',\n",
    "                   s=200, edgecolors='black', zorder=5, linewidths=2,\n",
    "                   label='Class 0')\n",
    "    if len(class_1) > 0:\n",
    "        ax.scatter(class_1[:, 0], class_1[:, 1], c='blue', marker='s',\n",
    "                   s=200, edgecolors='black', zorder=5, linewidths=2,\n",
    "                   label='Class 1')\n",
    "    \n",
    "    ax.set_xlim(-0.3, 1.3)\n",
    "    ax.set_ylim(-0.3, 1.3)\n",
    "    ax.set_aspect('equal')\n",
    "    ax.set_title(title, fontsize=13, fontweight='bold')\n",
    "    ax.legend(fontsize=10, loc='upper left')\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(2, 2, figsize=(14, 14))\n",
    "\n",
    "# AND: separable (convex hulls disjoint)\n",
    "plot_convex_hulls(inputs_2, np.array([0,0,0,1]), axes[0,0],\n",
    "                  'AND: Convex Hulls are DISJOINT\\n(Linearly Separable)',\n",
    "                  'AND')\n",
    "axes[0,0].set_title('AND: Convex Hulls are DISJOINT\\n(Linearly Separable)',\n",
    "                     fontsize=13, fontweight='bold', color='green')\n",
    "\n",
    "# OR: separable\n",
    "plot_convex_hulls(inputs_2, np.array([0,1,1,1]), axes[0,1],\n",
    "                  'OR: Convex Hulls are DISJOINT\\n(Linearly Separable)',\n",
    "                  'OR')\n",
    "axes[0,1].set_title('OR: Convex Hulls are DISJOINT\\n(Linearly Separable)',\n",
    "                     fontsize=13, fontweight='bold', color='green')\n",
    "\n",
    "# XOR: NOT separable (convex hulls intersect)\n",
    "plot_convex_hulls(inputs_2, np.array([0,1,1,0]), axes[1,0],\n",
    "                  'XOR: Convex Hulls INTERSECT\\n(NOT Linearly Separable)',\n",
    "                  'XOR')\n",
    "axes[1,0].set_title('XOR: Convex Hulls INTERSECT\\n(NOT Linearly Separable)',\n",
    "                     fontsize=13, fontweight='bold', color='red')\n",
    "# Mark intersection point\n",
    "axes[1,0].plot(0.5, 0.5, 'kx', markersize=15, markeredgewidth=3, zorder=10)\n",
    "axes[1,0].annotate('Intersection\\n(0.5, 0.5)', xy=(0.5, 0.5),\n",
    "                    xytext=(0.75, 0.15), fontsize=11,\n",
    "                    arrowprops=dict(arrowstyle='->', color='black'),\n",
    "                    fontweight='bold')\n",
    "\n",
    "# XNOR: NOT separable (convex hulls intersect)\n",
    "plot_convex_hulls(inputs_2, np.array([1,0,0,1]), axes[1,1],\n",
    "                  'XNOR: Convex Hulls INTERSECT\\n(NOT Linearly Separable)',\n",
    "                  'XNOR')\n",
    "axes[1,1].set_title('XNOR: Convex Hulls INTERSECT\\n(NOT Linearly Separable)',\n",
    "                     fontsize=13, fontweight='bold', color='red')\n",
    "axes[1,1].plot(0.5, 0.5, 'kx', markersize=15, markeredgewidth=3, zorder=10)\n",
    "axes[1,1].annotate('Intersection\\n(0.5, 0.5)', xy=(0.5, 0.5),\n",
    "                    xytext=(0.75, 0.15), fontsize=11,\n",
    "                    arrowprops=dict(arrowstyle='->', color='black'),\n",
    "                    fontweight='bold')\n",
    "\n",
    "fig.suptitle('Convex Hull Criterion for Linear Separability',\n",
    "             fontsize=16, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-12",
   "metadata": {},
   "source": [
    "In the XOR and XNOR plots, the convex hulls of the two classes (line segments connecting the respective corners of the unit square) cross at the point $(0.5, 0.5)$. This intersection proves that no hyperplane (line) can separate the two classes."
   ]
  },
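  {
   "cell_type": "markdown",
   "id": "cell-12a",
   "metadata": {},
   "source": [
    "As a quick numeric check, $(0.5, 0.5)$ can be written explicitly as a convex combination of each class's points -- exactly the overlap the criterion forbids:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-12b",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# The intersection point as a convex combination in BOTH hulls:\n",
    "#   class 1 (XOR = 1): midpoint of (0,1) and (1,0)\n",
    "#   class 0 (XOR = 0): midpoint of (0,0) and (1,1)\n",
    "p1 = 0.5 * np.array([0, 1]) + 0.5 * np.array([1, 0])\n",
    "p0 = 0.5 * np.array([0, 0]) + 0.5 * np.array([1, 1])\n",
    "\n",
    "print('In conv(class 1):', p1)\n",
    "print('In conv(class 0):', p0)\n",
    "print('Same point:', np.allclose(p1, p0))"
   ]
  },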
  {
   "cell_type": "markdown",
   "id": "cell-13",
   "metadata": {},
   "source": [
    "## 7.7 VC Dimension\n",
    "\n",
    "The **Vapnik-Chervonenkis (VC) dimension** is a fundamental measure of the capacity of a class of binary classifiers. It tells us the largest number of points that the classifier can \"shatter\" -- that is, classify correctly for every possible labeling.\n",
    "\n",
    "### Definition: Shattering\n",
    "\n",
    "> A set of points $S = \\{\\mathbf{x}_1, \\ldots, \\mathbf{x}_m\\} \\subset \\mathbb{R}^n$ is **shattered** by a class of classifiers $\\mathcal{H}$ if, for every possible labeling $(y_1, \\ldots, y_m) \\in \\{0, 1\\}^m$, there exists some $h \\in \\mathcal{H}$ that correctly classifies all points:\n",
    ">\n",
    "> $$h(\\mathbf{x}_i) = y_i \\quad \\text{for all } i = 1, \\ldots, m$$\n",
    "\n",
    "In other words, $\\mathcal{H}$ shatters $S$ if it can realize all $2^m$ possible dichotomies of $S$.\n",
    "\n",
    "### Definition: VC Dimension\n",
    "\n",
    "> The **VC dimension** of a hypothesis class $\\mathcal{H}$, denoted $\\text{VCdim}(\\mathcal{H})$, is the size of the largest set that $\\mathcal{H}$ can shatter:\n",
    ">\n",
    "> $$\\text{VCdim}(\\mathcal{H}) = \\max\\{m : \\exists S \\text{ of size } m \\text{ that } \\mathcal{H} \\text{ shatters}\\}$$\n",
    "\n",
    "### Theorem: VC Dimension of the Perceptron\n",
    "\n",
    "> The VC dimension of the perceptron (half-spaces) in $\\mathbb{R}^n$ is exactly $n + 1$.\n",
    "\n",
    "This means:\n",
    "- In $\\mathbb{R}^2$ (a line classifier): VCdim = 3. Any 3 points in general position (non-collinear) can be shattered; no set of 4 points can be shattered.\n",
    "- In $\\mathbb{R}^n$: VCdim = $n+1$.\n",
    "\n",
    "### Example: 3 Points in $\\mathbb{R}^2$\n",
    "\n",
    "Consider three non-collinear points in $\\mathbb{R}^2$. There are $2^3 = 8$ possible labelings. We must show that for **each** labeling, there exists a line that separates the two classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-14",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Demonstrate shattering of 3 points in R^2\n",
    "# Three non-collinear points\n",
    "points_3 = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]])\n",
    "\n",
    "# All 8 possible labelings\n",
    "all_labelings = []\n",
    "for i in range(8):\n",
    "    labeling = [(i >> bit) & 1 for bit in range(3)]\n",
    "    all_labelings.append(labeling)\n",
    "\n",
    "def find_separating_line(X, y):\n",
    "    \"\"\"Find separating weights for a small dataset using the perceptron.\"\"\"\n",
    "    y_arr = np.array(y)\n",
    "    \n",
    "    # Handle trivial cases (all same label)\n",
    "    if np.all(y_arr == 0):\n",
    "        return np.array([0.0, 0.0]), -1.0\n",
    "    if np.all(y_arr == 1):\n",
    "        return np.array([0.0, 0.0]), 1.0\n",
    "    \n",
    "    # Run perceptron\n",
    "    w = np.zeros(2)\n",
    "    b = 0.0\n",
    "    for epoch in range(1000):\n",
    "        errors = 0\n",
    "        for i in range(len(X)):\n",
    "            y_hat = int(w @ X[i] + b >= 0)\n",
    "            if y_hat != y_arr[i]:\n",
    "                update = y_arr[i] - y_hat\n",
    "                w = w + update * X[i]\n",
    "                b = b + update\n",
    "                errors += 1\n",
    "        if errors == 0:\n",
    "            return w, b\n",
    "    return w, b  # May not have converged\n",
    "\n",
    "\n",
    "fig, axes = plt.subplots(2, 4, figsize=(20, 10))\n",
    "\n",
    "for idx, labeling in enumerate(all_labelings):\n",
    "    row, col = divmod(idx, 4)\n",
    "    ax = axes[row, col]\n",
    "    \n",
    "    y_label = np.array(labeling)\n",
    "    \n",
    "    # Find separating line\n",
    "    w, b = find_separating_line(points_3, y_label)\n",
    "    \n",
    "    # Plot decision regions\n",
    "    xx, yy = np.meshgrid(np.linspace(-0.5, 1.5, 200),\n",
    "                         np.linspace(-0.5, 1.5, 200))\n",
    "    Z = xx * w[0] + yy * w[1] + b\n",
    "    \n",
    "    ax.contourf(xx, yy, Z, levels=[-1e10, 0, 1e10],\n",
    "                colors=['#FFCCCC', '#CCCCFF'], alpha=0.4)\n",
    "    if np.linalg.norm(w) > 0:\n",
    "        ax.contour(xx, yy, Z, levels=[0], colors='black', linewidths=2)\n",
    "    \n",
    "    # Plot points\n",
    "    for i in range(3):\n",
    "        color = 'blue' if y_label[i] == 1 else 'red'\n",
    "        marker = 's' if y_label[i] == 1 else 'o'\n",
    "        ax.scatter(points_3[i, 0], points_3[i, 1], c=color, marker=marker,\n",
    "                   s=200, edgecolors='black', zorder=5, linewidths=2)\n",
    "    \n",
    "    label_str = ''.join(str(l) for l in labeling)\n",
    "    ax.set_title(f'Labeling: ({label_str})', fontsize=11, fontweight='bold')\n",
    "    ax.set_xlim(-0.5, 1.5)\n",
    "    ax.set_ylim(-0.5, 1.5)\n",
    "    ax.set_aspect('equal')\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "fig.suptitle('Shattering 3 Points in $\\\\mathbb{R}^2$: All $2^3 = 8$ Labelings\\n'\n",
    "             '(Each has a separating line $\\\\Rightarrow$ VCdim $\\\\geq$ 3)',\n",
    "             fontsize=15, fontweight='bold', y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-15",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Show that 4 points in general position in R^2 CANNOT be shattered\n",
    "# Take 4 points: the vertices of a square\n",
    "points_4 = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])\n",
    "\n",
    "# There are 2^4 = 16 labelings. We need to find one that is NOT separable.\n",
    "# The XOR labeling (0,1,1,0) is not separable, as we've seen.\n",
    "\n",
    "# Check all 16 labelings\n",
    "unseparable_count = 0\n",
    "unseparable_examples = []\n",
    "\n",
    "for i in range(16):\n",
    "    labeling = np.array([(i >> bit) & 1 for bit in range(4)])\n",
    "    \n",
    "    # Try to find separating line using perceptron (many epochs)\n",
    "    w = np.zeros(2)\n",
    "    b = 0.0\n",
    "    converged = False\n",
    "    \n",
    "    # Handle trivial cases\n",
    "    if np.all(labeling == 0) or np.all(labeling == 1):\n",
    "        converged = True\n",
    "    else:\n",
    "        for epoch in range(500):\n",
    "            errors = 0\n",
    "            for j in range(4):\n",
    "                y_hat = int(w @ points_4[j] + b >= 0)\n",
    "                if y_hat != labeling[j]:\n",
    "                    update = labeling[j] - y_hat\n",
    "                    w = w + update * points_4[j]\n",
    "                    b = b + update\n",
    "                    errors += 1\n",
    "            if errors == 0:\n",
    "                converged = True\n",
    "                break\n",
    "    \n",
    "    if not converged:\n",
    "        unseparable_count += 1\n",
    "        unseparable_examples.append(labeling)\n",
    "\n",
    "print(f\"Out of 16 labelings of 4 points, {unseparable_count} are NOT linearly separable.\")\n",
    "print(\"\\nNon-separable labelings:\")\n",
    "for lab in unseparable_examples:\n",
    "    label_str = ''.join(str(l) for l in lab)\n",
    "    print(f\"  ({label_str})\")\n",
    "\n",
    "print(f\"\\nSince not all {2**4} labelings can be realized,\")\n",
    "print(\"these 4 points CANNOT be shattered. (Radon's theorem, stated\")\n",
    "print(\"below, rules out every other 4-point set as well.)\")\n",
    "print(\"Therefore VCdim(perceptron in R^2) = 3 = n+1 (with n=2).\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-16",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Visualize the unseparable labelings of 4 points\n",
    "n_unsep = len(unseparable_examples)\n",
    "fig, axes = plt.subplots(1, n_unsep, figsize=(6 * n_unsep, 5))\n",
    "if n_unsep == 1:\n",
    "    axes = [axes]\n",
    "\n",
    "for ax, labeling in zip(axes, unseparable_examples):\n",
    "    for i in range(4):\n",
    "        color = 'blue' if labeling[i] == 1 else 'red'\n",
    "        marker = 's' if labeling[i] == 1 else 'o'\n",
    "        ax.scatter(points_4[i, 0], points_4[i, 1], c=color, marker=marker,\n",
    "                   s=200, edgecolors='black', zorder=5, linewidths=2)\n",
    "    \n",
    "    # Sanity check: sweep a coarse grid of candidate lines (normal angle\n",
    "    # and offset) and record whether any of them realizes this labeling\n",
    "    x_line = np.linspace(-0.5, 1.5, 100)\n",
    "    found = False\n",
    "    for angle in np.linspace(0, 180, 12):\n",
    "        rad = np.radians(angle)\n",
    "        w_try = np.array([np.cos(rad), np.sin(rad)])\n",
    "        for b_try in np.linspace(-2, 2, 10):\n",
    "            preds = np.array([int(w_try @ points_4[j] + b_try >= 0) for j in range(4)])\n",
    "            found = found or np.array_equal(preds, labeling)\n",
    "    assert not found, 'a separating line was found unexpectedly'\n",
    "    \n",
    "    # Draw several failed attempts as gray dashed lines\n",
    "    for angle in [30, 60, 90, 120, 150]:\n",
    "        rad = np.radians(angle)\n",
    "        slope = -np.cos(rad) / (np.sin(rad) + 1e-10)\n",
    "        y_line = slope * (x_line - 0.5) + 0.5\n",
    "        ax.plot(x_line, y_line, '--', color='gray', alpha=0.3, linewidth=1)\n",
    "    \n",
    "    label_str = ''.join(str(l) for l in labeling)\n",
    "    ax.set_title(f'Labeling ({label_str}): NOT separable',\n",
    "                 fontsize=12, fontweight='bold', color='red')\n",
    "    ax.set_xlim(-0.3, 1.3)\n",
    "    ax.set_ylim(-0.3, 1.3)\n",
    "    ax.set_aspect('equal')\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "fig.suptitle('4 Points in $\\\\mathbb{R}^2$: Labelings That Cannot Be Separated\\n'\n",
    "             '(so this 4-point set cannot be shattered)',\n",
    "             fontsize=14, fontweight='bold', y=1.05)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cell-17",
   "metadata": {},
   "source": [
    "### Proof that VCdim = $n+1$ for Perceptrons in $\\mathbb{R}^n$\n",
    "\n",
    "**VCdim $\\geq n+1$** (Lower bound):\n",
    "\n",
    "Consider the $n+1$ points in $\\mathbb{R}^n$ given by the origin and the $n$ standard basis vectors: $\\{\\mathbf{0}, \\mathbf{e}_1, \\mathbf{e}_2, \\ldots, \\mathbf{e}_n\\}$. These $n+1$ points can be shattered by half-spaces. For any labeling $y_0, y_1, \\ldots, y_n \\in \\{0, 1\\}$, take $w_i = 2y_i - 1$ and $b = y_0 - \\tfrac{1}{2}$: then $\\mathbf{w} \\cdot \\mathbf{0} + b = y_0 - \\tfrac{1}{2}$ is nonnegative exactly when $y_0 = 1$, and $\\mathbf{w} \\cdot \\mathbf{e}_i + b = 2y_i - 1 + y_0 - \\tfrac{1}{2}$ is nonnegative exactly when $y_i = 1$, so this choice achieves the labeling.\n",
    "\n",
    "**VCdim $\\leq n+1$** (Upper bound, via Radon's theorem):\n",
    "\n",
    "**Radon's theorem** states that any set of $n+2$ points in $\\mathbb{R}^n$ can be partitioned into two subsets whose convex hulls intersect. But if the convex hulls of the two classes intersect, the labeling is not linearly separable (by the convex hull criterion). Therefore, no set of $n+2$ points can be shattered.\n",
    "\n",
    "Combining: VCdim = $n+1$. $\\blacksquare$"
   ]
  },
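  {
   "cell_type": "markdown",
   "id": "cell-17b",
   "metadata": {},
   "source": [
    "The lower bound can be checked exhaustively for small $n$. The sketch below (assuming the explicit choice $w_i = 2y_i - 1$ and $b = y_0 - \\tfrac{1}{2}$, one of several weight settings that work) verifies that the origin plus the standard basis vectors in $\\mathbb{R}^3$ realize all $2^4 = 16$ labelings.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from itertools import product\n",
    "\n",
    "n = 3\n",
    "points = np.vstack([np.zeros(n), np.eye(n)])  # origin + e_1, ..., e_n\n",
    "\n",
    "for labels in product([0, 1], repeat=n + 1):\n",
    "    y = np.array(labels)\n",
    "    w = 2 * y[1:] - 1   # weight i from the label of e_i\n",
    "    b = y[0] - 0.5      # bias from the label of the origin\n",
    "    preds = (points @ w + b >= 0).astype(int)\n",
    "    assert np.array_equal(preds, y), labels\n",
    "\n",
    "print(f'All {2 ** (n + 1)} labelings realized: these {n + 1} points are shattered.')\n",
    "```\n",
    "\n",
    "Running the same loop on any $n+2$ points would fail for at least one labeling, in line with Radon's theorem."
   ]
  },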
  {
   "cell_type": "markdown",
   "id": "cell-18",
   "metadata": {},
   "source": [
    "## 7.8 Exercises\n",
    "\n",
    "### Exercise 7.1: Symmetric Truth Tables\n",
    "\n",
    "A Boolean function $f(x_1, x_2)$ has a **symmetric truth table** if swapping the inputs does not change the output: $f(x_1, x_2) = f(x_2, x_1)$ for all inputs.\n",
    "\n",
    "1. Which of the 16 two-input Boolean functions have symmetric truth tables? (List them.)\n",
    "2. Among the symmetric ones, which are linearly separable?\n",
    "3. Verify that AND, OR, NAND, NOR are the only symmetric, linearly separable, non-trivial gates.\n",
    "\n",
    "```{hint}\n",
    ":class: dropdown\n",
    "A function is symmetric if $f(0,1) = f(1,0)$, i.e., the second and third entries in the output column are equal. Check each of the 16 functions for this condition. You should find that functions #0, 1, 6, 7, 8, 9, 14, 15 are symmetric. Among these, #6 (XOR) and #9 (XNOR) are not separable, and #0 (FALSE) and #15 (TRUE) are trivial.\n",
    "```\n",
    "\n",
    "### Exercise 7.2: The 3-Input Majority Function\n",
    "\n",
    "The **majority function** of 3 inputs outputs 1 if and only if at least 2 of the 3 inputs are 1:\n",
    "\n",
    "$$\\text{MAJ}(x_1, x_2, x_3) = \\begin{cases} 1 & \\text{if } x_1 + x_2 + x_3 \\geq 2 \\\\ 0 & \\text{otherwise} \\end{cases}$$\n",
    "\n",
    "1. Write out the complete truth table.\n",
    "2. Find perceptron weights $\\mathbf{w}$ and bias $b$ that compute this function.\n",
    "3. Verify your answer. Is the majority function linearly separable?\n",
    "\n",
    "```{hint}\n",
    ":class: dropdown\n",
    "Since the majority function fires when $x_1 + x_2 + x_3 \\geq 2$, try $\\mathbf{w} = (1, 1, 1)$ and $b = -1.5$. Then $\\mathbf{w} \\cdot \\mathbf{x} + b = x_1 + x_2 + x_3 - 1.5$, which is $\\geq 0$ exactly when the sum is $\\geq 2$.\n",
    "```\n",
    "\n",
    "### Exercise 7.3: Unshattering\n",
    "\n",
    "Find a set of 3 points in $\\mathbb{R}^2$ that **cannot** be shattered by any line. Why does this not contradict the fact that VCdim = 3?\n",
    "\n",
    "```{hint}\n",
    ":class: dropdown\n",
    "Consider three **collinear** points, e.g., $(0,0)$, $(1,1)$, $(2,2)$. Try to separate the labeling where the middle point has a different label from the outer two (e.g., labels = (0, 1, 0)). You will find that no line can isolate the middle point. This does NOT contradict VCdim = 3 because the VC dimension only requires that SOME set of 3 points can be shattered -- not ALL sets. Points in \"degenerate position\" (collinear) cannot be shattered.\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-19",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Exercise 7.1 verification\n",
    "print(\"=\" * 60)\n",
    "print(\"Exercise 7.1: Symmetric Boolean Functions\")\n",
    "print(\"=\" * 60)\n",
    "print()\n",
    "print(\"A function f is symmetric if f(x1,x2) = f(x2,x1),\")\n",
    "print(\"i.e., f(0,1) = f(1,0).\")\n",
    "print()\n",
    "\n",
    "symmetric_functions = []\n",
    "for idx, info in boolean_functions.items():\n",
    "    outputs = info['outputs']\n",
    "    # Check if f(0,1) == f(1,0), i.e., outputs[1] == outputs[2]\n",
    "    if outputs[1] == outputs[2]:\n",
    "        symmetric_functions.append(idx)\n",
    "        sep_str = 'Yes' if info['separable'] else 'No'\n",
    "        trivial = '(trivial)' if idx in [0, 15] else ''\n",
    "        print(f\"  #{idx:>2d}: {info['name']:>25s} | Separable: {sep_str} {trivial}\")\n",
    "\n",
    "print(f\"\\nTotal symmetric functions: {len(symmetric_functions)}\")\n",
    "print(\"Non-trivial symmetric and separable: AND, OR, NOR, NAND\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-20",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Exercise 7.2: 3-input majority function\n",
    "print(\"=\" * 60)\n",
    "print(\"Exercise 7.2: 3-Input Majority Function\")\n",
    "print(\"=\" * 60)\n",
    "print()\n",
    "\n",
    "# Truth table\n",
    "inputs_3 = np.array([[0,0,0], [0,0,1], [0,1,0], [0,1,1],\n",
    "                      [1,0,0], [1,0,1], [1,1,0], [1,1,1]])\n",
    "y_majority = (inputs_3.sum(axis=1) >= 2).astype(int)\n",
    "\n",
    "print(\"Truth table:\")\n",
    "print(f\"  {'x1':>3} {'x2':>3} {'x3':>3} | {'MAJ':>4}\")\n",
    "print(\"  \" + \"-\" * 20)\n",
    "for x, y in zip(inputs_3, y_majority):\n",
    "    print(f\"  {x[0]:>3} {x[1]:>3} {x[2]:>3} | {y:>4}\")\n",
    "\n",
    "# Find weights: w = (1,1,1), b = -1.5\n",
    "w_maj = np.array([1, 1, 1])\n",
    "b_maj = -1.5\n",
    "\n",
    "print(f\"\\nProposed weights: w = {w_maj}, b = {b_maj}\")\n",
    "print(\"\\nVerification:\")\n",
    "for x, y in zip(inputs_3, y_majority):\n",
    "    z = w_maj @ x + b_maj\n",
    "    y_hat = int(z >= 0)\n",
    "    status = 'OK' if y_hat == y else 'FAIL'\n",
    "    print(f\"  x={x}, z={z:+.1f}, y_hat={y_hat}, y={y} [{status}]\")\n",
    "\n",
    "print(\"\\nThe 3-input majority function IS linearly separable.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cell-21",
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Exercise 7.3: Three collinear points cannot be shattered\n",
    "print(\"=\" * 60)\n",
    "print(\"Exercise 7.3: Three Collinear Points\")\n",
    "print(\"=\" * 60)\n",
    "print()\n",
    "\n",
    "# Three collinear points\n",
    "collinear_points = np.array([[0, 0], [1, 1], [2, 2]])\n",
    "\n",
    "print(\"Points: (0,0), (1,1), (2,2) -- all on the line y = x\")\n",
    "print()\n",
    "\n",
    "# Try all 8 labelings\n",
    "print(\"Trying all 8 labelings:\")\n",
    "fail_count = 0\n",
    "\n",
    "for i in range(8):\n",
    "    labeling = np.array([(i >> bit) & 1 for bit in range(3)])\n",
    "    label_str = ''.join(str(l) for l in labeling)\n",
    "    \n",
    "    # Try perceptron\n",
    "    w = np.zeros(2)\n",
    "    b = 0.0\n",
    "    converged = False\n",
    "    \n",
    "    if np.all(labeling == 0) or np.all(labeling == 1):\n",
    "        converged = True\n",
    "    else:\n",
    "        for epoch in range(500):\n",
    "            errors = 0\n",
    "            for j in range(3):\n",
    "                y_hat = int(w @ collinear_points[j] + b >= 0)\n",
    "                if y_hat != labeling[j]:\n",
    "                    update = labeling[j] - y_hat\n",
    "                    w = w + update * collinear_points[j]\n",
    "                    b = b + update\n",
    "                    errors += 1\n",
    "            if errors == 0:\n",
    "                converged = True\n",
    "                break\n",
    "    \n",
    "    status = 'separable' if converged else 'NOT separable'\n",
    "    if not converged:\n",
    "        fail_count += 1\n",
    "    print(f\"  ({label_str}): {status}\")\n",
    "\n",
    "print(f\"\\n{fail_count} labeling(s) are not separable.\")\n",
    "print(\"Therefore, these 3 collinear points CANNOT be shattered.\")\n",
    "print()\n",
    "print(\"This does NOT contradict VCdim = 3, because the VC dimension\")\n",
    "print(\"only requires that SOME set of 3 points can be shattered,\")\n",
    "print(\"not ALL sets. Collinear points are in 'degenerate position.'\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}