Three Separation Lemmas
The mathematical heart of Monico’s elementary UAT proof
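The proof strategy is to show progressively stronger separation results. Each lemma builds on the previous one, adding one hidden layer at each step: point-point separation with a single neuron (Lemma 3.1), point-set separation in \(\mathcal{N}_2\) (Lemma 3.2), and set-set separation in \(\mathcal{N}_3\) (Lemma 3.3).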
This notebook covers the three separation lemmas (Lemmas 3.1–3.3 of Monico, 2024). Each lemma is stated formally, proved step by step with auxiliary commentary, verified numerically, and illustrated with interactive plots. For notation and prerequisites, see IP01: Overview & Notations.
Section 1: Lemma 3.1 — Point Separation
Lemma 3.1 (Point-Point Separation)
Let \(x_0\) and \(x_1\) be distinct real numbers. For each \(\varepsilon > 0\) there exist \(s, t \in \mathbb{R}\) such that
\[\sigma(s + t x_0) < \varepsilon \quad\text{and}\quad \sigma(s + t x_1) > 1 - \varepsilon.\]
If, in addition, \(x_0 < x_1\) and \(\varepsilon < 1/2\), then \(\sigma(s + tx) < \varepsilon\) on the interval \((-\infty, x_0]\) and \(\sigma(s + tx) > 1 - \varepsilon\) on the interval \([x_1, \infty)\).
Intuition
We want to shove \(x_0\) down to \(0\) and \(x_1\) up to \(1\) using \(\sigma\). The trick: an affine pre-map \(s + tx\) lets us place \(\sigma\)’s steep transition zone exactly between \(x_0\) and \(x_1\). By making \(t\) large enough (i.e., making the affine map steep), we can squeeze the transition into an arbitrarily narrow window, achieving any desired \(\varepsilon\)-separation.
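To make the steepness claim concrete, the following sketch fixes \(\varepsilon\) and shrinks the gap \(x_1 - x_0\). It assumes the standard sigmoid; the helper names `logit` and `slope_for_gap` are illustrative and do not come from the paper.

```python
import math

def logit(p):
    """Inverse of the standard sigmoid: the y with 1/(1+exp(-y)) = p."""
    return math.log(p / (1.0 - p))

def slope_for_gap(gap, eps):
    """Slope t sending x0 to sigma^{-1}(eps/2) and x0 + gap to sigma^{-1}(1 - eps/2)."""
    return (logit(1.0 - eps / 2.0) - logit(eps / 2.0)) / gap

eps = 0.05
gaps = (2.0, 1.0, 0.1, 0.01)
slopes = [slope_for_gap(gap, eps) for gap in gaps]
for gap, t in zip(gaps, slopes):
    print(f"gap = {gap:<5} -> t = {t:.2f}")
```

The slope is inversely proportional to the gap, so halving the distance between the two points doubles the required \(t\).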
Proof walkthrough
Step 1: Find the IVT targets.
By the Intermediate Value Theorem, since \(\sigma\) is continuous with \(\lim_{x \to -\infty} \sigma(x) = 0\) and \(\lim_{x \to +\infty} \sigma(x) = 1\), for any target value \(c \in (0,1)\) there exists \(y \in \mathbb{R}\) with \(\sigma(y) = c\).
Choose \(y_0\) with \(\sigma(y_0) = \varepsilon/2\) and \(y_1\) with \(\sigma(y_1) = 1 - \varepsilon/2\).
Auxiliary: Why does \(y_0\) exist?
Because \(\sigma\) is continuous, \(\sigma(x) \to 0\) as \(x \to -\infty\), and \(\sigma(x) \to 1\) as \(x \to +\infty\). Since \(\varepsilon/2 \in (0,1)\), the Intermediate Value Theorem guarantees at least one \(y_0 \in \mathbb{R}\) with \(\sigma(y_0) = \varepsilon/2\). If \(\sigma\) is strictly increasing, as the standard sigmoid is, this \(y_0\) is also unique. Monotonicity in the weak sense, continuity, and distinct limits do not by themselves give uniqueness, but the proof only needs existence. For the standard sigmoid \(\sigma(x) = 1/(1+e^{-x})\), we can compute it explicitly: \(y_0 = \log\bigl(\varepsilon/(2 - \varepsilon)\bigr)\).
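When no closed-form inverse is available, the existence argument can be mirrored numerically by bisection. The sketch below assumes a hypothetical arctan-based squashing function `squash` (chosen only for illustration) and a helper `solve_by_bisection`; it also checks the closed form for the standard sigmoid.

```python
import math

def squash(x):
    """An arctan-based squashing function: continuous, increasing, limits 0 and 1."""
    return 0.5 + math.atan(x) / math.pi

def solve_by_bisection(f, target, lo=-1e6, hi=1e6, iters=200):
    """Find y with f(y) close to target, assuming f is continuous and
    nondecreasing with f(lo) < target < f(hi)."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if f(mid) < target:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

eps = 0.05
y0 = solve_by_bisection(squash, eps / 2)
y1 = solve_by_bisection(squash, 1 - eps / 2)
print(f"arctan squashing: y0 = {y0:.4f}, y1 = {y1:.4f}")

# Closed form for the standard sigmoid: y0 = log(eps / (2 - eps)).
y0_logistic = math.log(eps / (2 - eps))
print(f"standard sigmoid: sigma(y0) = {1 / (1 + math.exp(-y0_logistic)):.6f}")
```

Bisection relies on exactly the hypotheses the Intermediate Value Theorem uses: continuity and values on both sides of the target.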
Step 2: Solve the \(2 \times 2\) linear system.
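We need \(s + t x_0 = y_0\) and \(s + t x_1 = y_1\). In matrix form:
\[\begin{pmatrix} 1 & x_0 \\ 1 & x_1 \end{pmatrix} \begin{pmatrix} s \\ t \end{pmatrix} = \begin{pmatrix} y_0 \\ y_1 \end{pmatrix}.\]
The determinant is \(x_1 - x_0 \neq 0\) (since \(x_0 \neq x_1\)), so the system has a unique solution:
\[t = \frac{y_1 - y_0}{x_1 - x_0}, \qquad s = y_0 - t x_0.\]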
Step 3: Verify the separation.
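Substituting back:
\[\sigma(s + t x_0) = \sigma(y_0) = \varepsilon/2 < \varepsilon, \qquad \sigma(s + t x_1) = \sigma(y_1) = 1 - \varepsilon/2 > 1 - \varepsilon.\]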
This establishes the first claim of the lemma.
Step 4: Extend to intervals (the monotonicity argument).
Suppose in addition that \(x_0 < x_1\) and \(\varepsilon < 1/2\). Then:
\[\sigma(s + t x_0) < \varepsilon < \tfrac{1}{2} < 1 - \varepsilon < \sigma(s + t x_1).\]
Since \(\sigma\) is increasing and \(s + tx\) is affine (hence monotone), the composition \(x \mapsto \sigma(s + tx)\) is monotone. It cannot be nonincreasing: with \(x_0 < x_1\) that would give \(\sigma(s+tx_0) \geq \sigma(s+tx_1)\), contradicting the inequality above. Hence \(t > 0\) and the composition is increasing. Therefore:
For all \(x \leq x_0\): \(\;\sigma(s+tx) \leq \sigma(s+tx_0) < \varepsilon\).
For all \(x \geq x_1\): \(\;\sigma(s+tx) \geq \sigma(s+tx_1) > 1-\varepsilon\). \(\;\square\)
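The interval claim can be spot-checked on sample points. The sketch below assumes the standard sigmoid and the same example values as the numerical verification (\(x_0 = 1\), \(x_1 = 3\), \(\varepsilon = 0.05\)); the finite grids are only stand-ins for the two unbounded rays.

```python
import math

def sigma(x):
    """Standard sigmoid, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def logit(p):
    """Inverse of the standard sigmoid."""
    return math.log(p / (1.0 - p))

x0, x1, eps = 1.0, 3.0, 0.05
t = (logit(1 - eps / 2) - logit(eps / 2)) / (x1 - x0)
s = logit(eps / 2) - t * x0

# Sample points in (-inf, x0] and [x1, inf); both grids include the endpoint.
left = [x0 - 0.01 * k for k in range(2000)]
right = [x1 + 0.01 * k for k in range(2000)]

max_left = max(sigma(s + t * x) for x in left)
min_right = min(sigma(s + t * x) for x in right)
print(f"t > 0: {t > 0}")
print(f"max on left grid  < eps:     {max_left < eps}")
print(f"min on right grid > 1 - eps: {min_right > 1 - eps}")
```

Because the composition is increasing, the extreme values on each ray occur at the endpoints \(x_0\) and \(x_1\).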
Numerical verification
import numpy as np
import matplotlib.pyplot as plt
def sigma(x):
    """Standard sigmoid, numerically stable."""
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

def sigma_inv(y):
    """Inverse of sigmoid: log(y/(1-y))."""
    return np.log(y / (1 - y))
# Example: x0=1, x1=3, eps=0.05
x0, x1, eps = 1.0, 3.0, 0.05
y0 = sigma_inv(eps / 2)
y1 = sigma_inv(1 - eps / 2)
# Solve: s + t*x0 = y0, s + t*x1 = y1
t_val = (y1 - y0) / (x1 - x0)
s_val = y0 - t_val * x0
print(f"Parameters: x\u2080={x0}, x\u2081={x1}, \u03b5={eps}")
print(f"IVT targets: y\u2080=\u03c3\u207b\u00b9({eps/2})={y0:.4f}, y\u2081=\u03c3\u207b\u00b9({1-eps/2})={y1:.4f}")
print(f"Solution: s={s_val:.4f}, t={t_val:.4f}")
print()
val0 = sigma(s_val + t_val * x0)
val1 = sigma(s_val + t_val * x1)
print(f"\u03c3(s+t\u00b7x\u2080) = \u03c3({s_val + t_val*x0:.4f}) = {val0:.6f} < {eps} \u2713" if val0 < eps else f"\u2717")
print(f"\u03c3(s+t\u00b7x\u2081) = \u03c3({s_val + t_val*x1:.4f}) = {val1:.6f} > {1-eps} \u2713" if val1 > 1-eps else f"\u2717")
Parameters: x₀=1.0, x₁=3.0, ε=0.05
IVT targets: y₀=σ⁻¹(0.025)=-3.6636, y₁=σ⁻¹(0.975)=3.6636
Solution: s=-7.3271, t=3.6636
σ(s+t·x₀) = σ(-3.6636) = 0.025000 < 0.05 ✓
σ(s+t·x₁) = σ(3.6636) = 0.975000 > 0.95 ✓
Visual proof
# ============================================================
# Visual proof for Lemma 3.1
# ============================================================
ACCENT = '#4f46e5'
THM = '#059669'
WARN = '#d97706'
DANGER = '#dc2626'
fig, ax = plt.subplots(figsize=(9, 5))
# Plot sigma(s + tx) over a wide range
x_range = np.linspace(x0 - 3, x1 + 3, 1000)
y_range = sigma(s_val + t_val * x_range)
ax.plot(x_range, y_range, color=ACCENT, linewidth=2.5, label=r'$\sigma(s + tx)$', zorder=5)
# Vertical lines at x0 and x1
ax.axvline(x0, color=THM, linestyle='--', linewidth=1.2, alpha=0.8, label=f'$x_0 = {x0}$')
ax.axvline(x1, color=DANGER, linestyle='--', linewidth=1.2, alpha=0.8, label=f'$x_1 = {x1}$')
# Horizontal epsilon-bands
ax.axhspan(0, eps, alpha=0.12, color=THM, zorder=1)
ax.axhspan(1 - eps, 1, alpha=0.12, color=DANGER, zorder=1)
ax.axhline(eps, color=THM, linestyle=':', linewidth=1, alpha=0.6)
ax.axhline(1 - eps, color=DANGER, linestyle=':', linewidth=1, alpha=0.6)
# Shade regions where sigma(s+tx) < eps and sigma(s+tx) > 1-eps
mask_low = y_range < eps
mask_high = y_range > 1 - eps
ax.fill_between(x_range, 0, y_range, where=mask_low, alpha=0.15, color=THM, zorder=2)
ax.fill_between(x_range, y_range, 1, where=mask_high, alpha=0.15, color=DANGER, zorder=2)
# Highlight points (x0, sigma(s+tx0)) and (x1, sigma(s+tx1))
ax.plot(x0, sigma(s_val + t_val * x0), 'o', color=THM, markersize=10, zorder=6,
markeredgecolor='white', markeredgewidth=1.5)
ax.plot(x1, sigma(s_val + t_val * x1), 'o', color=DANGER, markersize=10, zorder=6,
markeredgecolor='white', markeredgewidth=1.5)
# Annotations
ax.annotate(f'$\\sigma(s+tx_0) = {sigma(s_val+t_val*x0):.3f}$',
xy=(x0, sigma(s_val + t_val * x0)),
xytext=(x0 - 1.5, 0.25), fontsize=10,
arrowprops=dict(arrowstyle='->', color=THM, lw=1.5),
color=THM, fontweight='bold')
ax.annotate(f'$\\sigma(s+tx_1) = {sigma(s_val+t_val*x1):.3f}$',
xy=(x1, sigma(s_val + t_val * x1)),
xytext=(x1 + 0.5, 0.7), fontsize=10,
arrowprops=dict(arrowstyle='->', color=DANGER, lw=1.5),
color=DANGER, fontweight='bold')
# Labels
ax.text(x0 - 2.5, eps / 2, f'$< \\varepsilon = {eps}$', fontsize=10, color=THM,
va='center', fontweight='bold')
ax.text(x1 + 1.5, 1 - eps / 2, f'$> 1-\\varepsilon = {1-eps}$', fontsize=10, color=DANGER,
va='center', fontweight='bold')
ax.set_xlabel('$x$', fontsize=13)
ax.set_ylabel(r'$\sigma(s + tx)$', fontsize=13)
ax.set_title('Lemma 3.1: Point-Point Separation', fontsize=14, fontweight='bold')
ax.set_ylim(-0.05, 1.08)
ax.legend(loc='center left', fontsize=10, framealpha=0.9)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()
Try it yourself → Squashing Function Lab — drag \(x_0\) and \(x_1\), adjust \(\varepsilon\), and watch the linear system update in real time.
Section 2: Lemma 3.2 — Point-Set Separation
Lemma 3.2 (Point-Set Separation)
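Let \(B \subset K\) be a closed set, and \(\boldsymbol{x}_0 \in K \setminus B\). For each \(\varepsilon > 0\) there exists \(g \in \mathcal{N}_2\) such that
\[g(\boldsymbol{x}_0) < \varepsilon \quad\text{and}\quad g(\boldsymbol{b}) > 1 - \varepsilon \ \text{ for all } \boldsymbol{b} \in B.\]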
Intuition
Cover \(B\) with finitely many “alarm bells” — each rings near its target \(\boldsymbol{b} \in B\) but stays quiet at \(\boldsymbol{x}_0\). Sum them up: the total alarm is loud on all of \(B\) (at least one bell rings for each point) but quiet at \(\boldsymbol{x}_0\) (each bell contributes only a tiny amount there).
Proof walkthrough
Step 1: Separate each \(\boldsymbol{b}\) from \(\boldsymbol{x}_0\) individually.
WLOG assume \(0 < \varepsilon < 1/3\).
For each \(\boldsymbol{b} \in B\), since \(\boldsymbol{b} \neq \boldsymbol{x}_0\), they differ in at least one coordinate. Define \(f_{\boldsymbol{b}} \in \mathcal{N}_1\) (an affine function) such that
\[f_{\boldsymbol{b}}(\boldsymbol{x}_0) < \varepsilon/2 \quad\text{and}\quad f_{\boldsymbol{b}}(\boldsymbol{b}) > 1 - \varepsilon/2.\]
Auxiliary: constructing \(f_{\boldsymbol{b}}\)
Since \(\boldsymbol{b} \neq \boldsymbol{x}_0\) in \(\mathbb{R}^n\), there exists an index \(i\) with \(b_i \neq (x_0)_i\). The coordinate projection \(\pi_i(\boldsymbol{x}) = x_i\) is an affine function, so \(\pi_i \in \mathcal{N}_1\), and it maps \(\boldsymbol{b}\) and \(\boldsymbol{x}_0\) to distinct real numbers \(b_i\) and \((x_0)_i\).
Now rescale the projection affinely: \(f_{\boldsymbol{b}}(\boldsymbol{x}) = \bigl(x_i - (x_0)_i\bigr) / \bigl(b_i - (x_0)_i\bigr)\). Then \(f_{\boldsymbol{b}} \in \mathcal{N}_1\), with \(f_{\boldsymbol{b}}(\boldsymbol{x}_0) = 0 < \varepsilon/2\) and \(f_{\boldsymbol{b}}(\boldsymbol{b}) = 1 > 1 - \varepsilon/2\), which is what we need.
(Note: composing with \(\sigma\) through Lemma 3.1, that is, taking \(\sigma(s + t \, \pi_i)\) with \(\sigma(s + t \cdot (x_0)_i) < \varepsilon/2\) and \(\sigma(s + t \cdot b_i) > 1 - \varepsilon/2\), satisfies the same two inequalities, and the numerical demonstration below uses this variant. That function lies in \(\mathcal{N}_1^\sigma\) rather than \(\mathcal{N}_1\), so each \(F_j\) in Step 5 would cost one extra layer. The layer counting in this section assumes the affine choice.)
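One way to realise an affine \(f_{\boldsymbol{b}} \in \mathcal{N}_1\) is to rescale a coordinate projection so that it takes the value 0 at \(\boldsymbol{x}_0\) and 1 at \(\boldsymbol{b}\). The sketch below is an assumption made for illustration, not necessarily the construction in the paper; `affine_separator` and `evaluate` are hypothetical helper names, and picking the coordinate with the largest difference is a convenience choice.

```python
def affine_separator(b, x0):
    """Return (i, alpha, beta) so that f(x) = alpha * x[i] + beta
    satisfies f(x0) = 0 and f(b) = 1."""
    diffs = [abs(bi - xi) for bi, xi in zip(b, x0)]
    i = max(range(len(diffs)), key=diffs.__getitem__)
    if diffs[i] == 0.0:
        raise ValueError("b and x0 must be distinct points")
    alpha = 1.0 / (b[i] - x0[i])
    beta = -alpha * x0[i]
    return i, alpha, beta

def evaluate(sep, x):
    """Evaluate the affine separator at the point x."""
    i, alpha, beta = sep
    return alpha * x[i] + beta

x0 = (0.9, 0.9)
b = (0.8, 0.5)
sep = affine_separator(b, x0)
print("coordinate used:", sep[0])
print("f(x0) =", evaluate(sep, x0))
print("f(b)  =", evaluate(sep, b))
```

With these values both required inequalities hold for every \(\varepsilon < 1\), since \(0 < \varepsilon/2\) and \(1 > 1 - \varepsilon/2\).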
Step 2: Build an open cover of \(B\).
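Define
\[U_{\boldsymbol{b}} = \{\boldsymbol{x} \in K : f_{\boldsymbol{b}}(\boldsymbol{x}) > 1 - \varepsilon\}.\]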
Since \(f_{\boldsymbol{b}}\) is continuous, \(U_{\boldsymbol{b}}\) is open. And \(\boldsymbol{b} \in U_{\boldsymbol{b}}\) because \(f_{\boldsymbol{b}}(\boldsymbol{b}) > 1 - \varepsilon/2 > 1 - \varepsilon\).
Step 3: Extract a finite subcover by compactness.
\(\{U_{\boldsymbol{b}}\}_{\boldsymbol{b} \in B}\) is an open cover of \(B\). Since \(B\) is closed in the compact set \(K\), \(B\) itself is compact.
Key: Why compactness matters
Compactness is essential here. \(B\) might have infinitely many points, each needing its own alarm bell. Without compactness, we would need infinitely many alarm bells — and the sum might diverge or fail to stay below \(\varepsilon\) at \(\boldsymbol{x}_0\). Compactness guarantees a finite subcover, which is what makes the epsilon-management in Step 5 work.
Step 4: Finite subcover.
Extract a finite subcover: \(B \subset U_{\boldsymbol{b}_1} \cup \cdots \cup U_{\boldsymbol{b}_N}\).
Step 5: Sharpen with Lemma 3.1 and the \(\varepsilon/N\) trick.
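Apply Lemma 3.1 (with the points \(\varepsilon < 1 - \varepsilon\) and tolerance \(\varepsilon/N\)) to find \(s, t \in \mathbb{R}\) such that
\[\sigma(s + tx) < \varepsilon/N \ \text{ on } (-\infty, \varepsilon] \quad\text{and}\quad \sigma(s + tx) > 1 - \varepsilon \ \text{ on } [1 - \varepsilon, \infty).\]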
Define \(F_j = \sigma(s + t \cdot f_{\boldsymbol{b}_j}) \in \mathcal{N}_1^\sigma\) for each \(1 \leq j \leq N\).
Auxiliary: Why \(\varepsilon/N\)?
Each \(F_j\) contributes less than \(\varepsilon/N\) to the sum at \(\boldsymbol{x}_0\): since \(f_{\boldsymbol{b}_j}(\boldsymbol{x}_0) < \varepsilon/2 < \varepsilon\), the input to \(\sigma(s + t \cdot \ldots)\) falls in \((-\infty, \varepsilon]\), where \(\sigma(s + tx) < \varepsilon/N\). With \(N\) terms:
\[\sum_{j=1}^N F_j(\boldsymbol{x}_0) < N \cdot \frac{\varepsilon}{N} = \varepsilon.\]
This is the epsilon-management trick: divide the budget equally among the finitely many cover elements.
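The budget argument can be checked for several cover sizes. The sketch below assumes the standard sigmoid and picks the sharpening parameters so that \(\sigma\) equals \(\varepsilon/(2N)\) at \(\varepsilon\) and \(1 - \varepsilon/2\) at \(1 - \varepsilon\); these targets are one convenient choice, and `sharpen_params` is a hypothetical helper name.

```python
import math

def sigma(x):
    """Standard sigmoid, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def logit(p):
    """Inverse of the standard sigmoid."""
    return math.log(p / (1.0 - p))

def sharpen_params(eps, n_cover):
    """Solve s + t*eps = logit(eps/(2N)) and s + t*(1-eps) = logit(1 - eps/2)."""
    y_low = logit(eps / (2 * n_cover))
    y_high = logit(1 - eps / 2)
    t = (y_high - y_low) / ((1 - eps) - eps)
    s = y_low - t * eps
    return s, t

eps = 0.1
totals = []
for n_cover in (1, 8, 100):
    s, t = sharpen_params(eps, n_cover)
    per_term = sigma(s + t * eps)   # largest value a single F_j can take at x0
    totals.append(n_cover * per_term)
    print(f"N = {n_cover:3d}: t = {t:.3f}, N * per-term bound = {totals[-1]:.4f}")
```

The total stays below \(\varepsilon\) for every \(N\); only the slope \(t\) grows as the per-term budget shrinks.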
Step 6: Sum and verify.
Let \(g = \sum_{j=1}^N F_j \in \mathcal{N}_2\). Then:
At \(\boldsymbol{x}_0\): \(g(\boldsymbol{x}_0) = \sum_{j=1}^N F_j(\boldsymbol{x}_0) < N \cdot (\varepsilon/N) = \varepsilon\). \(\;\checkmark\)
On \(B\): For any \(\boldsymbol{b} \in B\), there exists some \(k\) with \(\boldsymbol{b} \in U_{\boldsymbol{b}_k}\), so \(f_{\boldsymbol{b}_k}(\boldsymbol{b}) > 1 - \varepsilon\), hence \(F_k(\boldsymbol{b}) = \sigma(s + t \cdot f_{\boldsymbol{b}_k}(\boldsymbol{b})) > 1 - \varepsilon\), hence \(g(\boldsymbol{b}) \geq F_k(\boldsymbol{b}) > 1 - \varepsilon\). \(\;\checkmark\;\square\)
Layer counting
Why is \(g \in \mathcal{N}_2\)?
Each \(f_{\boldsymbol{b}_j} \in \mathcal{N}_1\) (affine function of coordinates).
Each \(F_j = \sigma(s + t \cdot f_{\boldsymbol{b}_j})\). Since \(s + t \cdot f_{\boldsymbol{b}_j} \in \mathcal{N}_1\) (affine combination of affine functions), we have \(F_j \in \mathcal{N}_1^\sigma\).
The sum \(g = \sum F_j\) is an affine combination of elements of \(\mathcal{N}_1^\sigma\), which by definition is \(\mathcal{N}_2\).
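The layer count rests on the fact that an affine map applied to an affine map is again affine, so each \(F_j\) is a single hidden unit. The sketch below folds \(s + t \cdot f_j\) into one weight vector and bias; the separators and the values of `s` and `t` are made-up numbers for illustration, and `hidden_unit` and `g` are hypothetical helper names.

```python
import math

def sigma(x):
    """Standard sigmoid, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def hidden_unit(alpha, beta, s, t):
    """Fold x -> s + t * (alpha . x + beta) into a single affine map (w, c)."""
    w = [t * a for a in alpha]
    c = s + t * beta
    return w, c

def g(x, units):
    """One-hidden-layer network: sum_j sigma(w_j . x + c_j)."""
    return sum(sigma(sum(wi * xi for wi, xi in zip(w, x)) + c) for w, c in units)

# Made-up affine separators f_j(x) = alpha_j . x + beta_j and sharpening (s, t).
separators = [([0.0, -2.5], 2.25), ([-10.0, 0.0], 9.0)]
s, t = -8.0, 12.0
units = [hidden_unit(alpha, beta, s, t) for alpha, beta in separators]

x = (0.3, 0.4)
nested = sum(
    sigma(s + t * (sum(a * xi for a, xi in zip(alpha, x)) + beta))
    for alpha, beta in separators
)
print("folded network:", g(x, units))
print("nested formula:", nested)
```

Both evaluations agree up to rounding, and the folded form has exactly one hidden unit per cover element.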
Numerical demonstration
# ============================================================
# Lemma 3.2: Point-Set Separation in K = [0,1]^2
# B = disk of radius 0.3 centered at (0.5, 0.5)
# x0 = (0.9, 0.9)
# eps = 0.1
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
def sigma(x):
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

def sigma_inv(y):
    return np.log(y / (1 - y))
ACCENT = '#4f46e5'
THM = '#059669'
WARN = '#d97706'
DANGER = '#dc2626'
# Parameters
center_B = np.array([0.5, 0.5])
radius_B = 0.3
x0_pt = np.array([0.9, 0.9])
eps = 0.1
# Choose N=8 evenly-spaced points on B's boundary
N_cover = 8
angles = np.linspace(0, 2 * np.pi, N_cover, endpoint=False)
b_points = np.column_stack([
center_B[0] + radius_B * np.cos(angles),
center_B[1] + radius_B * np.sin(angles)
])
# For each b_j, find the coordinate i where |b_j[i] - x0[i]| is maximized
# Then use Lemma 3.1 on that coordinate
def build_separator(b_pt, x0_pt, eps_target):
    """Build f_b using the coordinate that differs most from x0."""
    diffs = np.abs(b_pt - x0_pt)
    i = np.argmax(diffs)  # coordinate index
    # Apply Lemma 3.1 to separate b_pt[i] from x0_pt[i]
    val_x0 = x0_pt[i]
    val_b = b_pt[i]
    y0_target = sigma_inv(eps_target / 2)
    y1_target = sigma_inv(1 - eps_target / 2)
    # We want sigma(s + t * val_x0) < eps_target and sigma(s + t * val_b) > 1-eps_target
    # So: s + t * val_x0 = y0_target, s + t * val_b = y1_target
    if abs(val_b - val_x0) < 1e-12:
        # Fallback: use the other coordinate
        i = 1 - i
        val_x0 = x0_pt[i]
        val_b = b_pt[i]
    t_sep = (y1_target - y0_target) / (val_b - val_x0)
    s_sep = y0_target - t_sep * val_x0
    return i, s_sep, t_sep
# Build all separators
separators = [build_separator(b_points[j], x0_pt, eps) for j in range(N_cover)]
# Apply Lemma 3.1 for sharpening: sigma(s'+t'x) < eps/N on (-inf, eps)
# and sigma(s'+t'x) > 1-eps on (1-eps, inf)
eps_sharp = eps / N_cover
y0_sharp = sigma_inv(eps_sharp / 2) # target: sigma = eps/(2N)
y1_sharp = sigma_inv(1 - eps / 2) # target: sigma = 1-eps/2
# Solve for the sharpening affine map: s' + t'*eps = y0_sharp, s' + t'*(1-eps) = y1_sharp
# Actually: we want separation at points eps and 1-eps on the real line
val_low = eps # values of f_bj at x0 are below this
val_high = 1 - eps # values of f_bj at b are above this
t_sharp = (y1_sharp - y0_sharp) / (val_high - val_low)
s_sharp = y0_sharp - t_sharp * val_low
# Evaluate g on a grid
resolution = 200
xx = np.linspace(0, 1, resolution)
yy = np.linspace(0, 1, resolution)
XX, YY = np.meshgrid(xx, yy)
points = np.column_stack([XX.ravel(), YY.ravel()])
g_total = np.zeros(resolution * resolution)
for j in range(N_cover):
    coord_idx, s_sep, t_sep = separators[j]
    # f_bj(x) = sigma(s_sep + t_sep * x[coord_idx]) (from Lemma 3.1)
    f_bj = sigma(s_sep + t_sep * points[:, coord_idx])
    # F_j(x) = sigma(s_sharp + t_sharp * f_bj(x)) (sharpening)
    F_j = sigma(s_sharp + t_sharp * f_bj)
    g_total += F_j
G = g_total.reshape(resolution, resolution)
# Verify bounds
# Check g(x0)
g_at_x0 = 0
for j in range(N_cover):
    coord_idx, s_sep, t_sep = separators[j]
    f_val = sigma(s_sep + t_sep * x0_pt[coord_idx])
    F_val = sigma(s_sharp + t_sharp * f_val)
    g_at_x0 += F_val
print(f"Point-Set Separation (Lemma 3.2)")
print(f"================================")
print(f"B = disk of radius {radius_B} at {tuple(center_B)}")
print(f"x\u2080 = {tuple(x0_pt)}, \u03b5 = {eps}, N = {N_cover}")
print(f"")
print(f"g(x\u2080) = {g_at_x0:.6f} < {eps} \u2713" if g_at_x0 < eps else f"g(x\u2080) = {g_at_x0:.6f} \u2717")
# Check on B: sample points on the boundary and interior
test_angles = np.linspace(0, 2*np.pi, 100)
test_radii = np.linspace(0, radius_B, 10)
min_on_B = np.inf
for r in test_radii:
    for a in test_angles:
        pt = center_B + r * np.array([np.cos(a), np.sin(a)])
        if 0 <= pt[0] <= 1 and 0 <= pt[1] <= 1:
            g_val = 0
            for j in range(N_cover):
                coord_idx, s_sep, t_sep = separators[j]
                f_val = sigma(s_sep + t_sep * pt[coord_idx])
                F_val = sigma(s_sharp + t_sharp * f_val)
                g_val += F_val
            min_on_B = min(min_on_B, g_val)
print(f"min(g on B) = {min_on_B:.6f} > {1-eps} \u2713" if min_on_B > 1-eps else f"min(g on B) = {min_on_B:.6f} \u2717")
Point-Set Separation (Lemma 3.2)
================================
B = disk of radius 0.3 at (0.5, 0.5)
x₀ = (0.9, 0.9), ε = 0.1, N = 8
g(x₀) = 0.030376 < 0.1 ✓
min(g on B) = 1.443156 > 0.9 ✓
# ============================================================
# Heatmap of g over K = [0,1]^2
# ============================================================
fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(G, extent=[0, 1, 0, 1], origin='lower', cmap='RdYlGn',
vmin=0, vmax=max(N_cover * 0.5, G.max()), aspect='equal')
# Draw the disk B
theta_circle = np.linspace(0, 2 * np.pi, 200)
circle_x = center_B[0] + radius_B * np.cos(theta_circle)
circle_y = center_B[1] + radius_B * np.sin(theta_circle)
ax.plot(circle_x, circle_y, color='white', linewidth=2.5, linestyle='-', zorder=4)
ax.plot(circle_x, circle_y, color=DANGER, linewidth=1.5, linestyle='-', zorder=5,
label='$B$ (boundary)')
# Mark the point x0
ax.plot(x0_pt[0], x0_pt[1], 'o', color=ACCENT, markersize=12, zorder=6,
markeredgecolor='white', markeredgewidth=2, label=f'$\\mathbf{{x}}_0 = ({x0_pt[0]}, {x0_pt[1]})$')
# Mark the cover points b_j
for j in range(N_cover):
    label = '$\\mathbf{b}_j$ (cover points)' if j == 0 else None
    ax.plot(b_points[j, 0], b_points[j, 1], 's', color=WARN, markersize=7, zorder=6,
            markeredgecolor='white', markeredgewidth=1, label=label)
# Contour lines at eps and 1-eps
contour = ax.contour(XX, YY, G, levels=[eps, 1-eps], colors=['white', 'white'],
linewidths=[1.5, 1.5], linestyles=['--', '-'], zorder=4)
ax.clabel(contour, fmt={eps: f'$g = \\varepsilon$', 1-eps: f'$g = 1-\\varepsilon$'},
fontsize=9)
cbar = fig.colorbar(im, ax=ax, shrink=0.85, label='$g(\\mathbf{x})$')
ax.set_xlabel('$x_1$', fontsize=12)
ax.set_ylabel('$x_2$', fontsize=12)
ax.set_title('Lemma 3.2: Point-Set Separation\n'
f'$g(\\mathbf{{x}}_0) = {g_at_x0:.4f} < \\varepsilon = {eps}$, '
f'$\\min_B g = {min_on_B:.4f} > 1 - \\varepsilon = {1-eps}$',
fontsize=12, fontweight='bold')
ax.legend(loc='lower left', fontsize=9, framealpha=0.9)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
fig.tight_layout()
plt.show()
Try it yourself → Point-Set Separator
Section 3: Lemma 3.3 — Set-Set Separation
Lemma 3.3 (Set-Set Separation)
Let \(A\) and \(B\) be disjoint closed subsets of \(K\). Then for each \(\varepsilon > 0\):
(i) there exists \(h \in \mathcal{N}_3\) such that \(h < \varepsilon\) on \(B\) and \(h > 1 - \varepsilon\) on \(A\);
(ii) there exists \(H \in \mathcal{N}_3^\sigma\) such that \(0 \leq H < \varepsilon\) on \(B\) and \(1 - \varepsilon < H \leq 1\) on \(A\).
Intuition
For each point \(\boldsymbol{a} \in A\), Lemma 3.2 gives a “spotlight” — bright at \(\boldsymbol{a}\) and dark on \(B\). Cover \(A\) with finitely many spotlights, sum them up, then squash with one more application of \(\sigma\): done.
The structure is exactly parallel to Lemma 3.2, but one level higher:
| Lemma 3.2 | Lemma 3.3 |
|---|---|
| Separate each \(\boldsymbol{b}\) from \(\boldsymbol{x}_0\) | Separate each \(\boldsymbol{a}\) from \(B\) |
| Using \(\mathcal{N}_1^\sigma\) elements | Using \(\mathcal{N}_2\) elements |
| Cover \(B\) | Cover \(A\) |
| Sum gives \(g \in \mathcal{N}_2\) | Sum gives \(h \in \mathcal{N}_3\) |
Proof walkthrough
Step 1: Separate each \(\boldsymbol{a}\) from \(B\) using Lemma 3.2.
WLOG assume \(\varepsilon \in (0, 1/3)\).
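For each \(\boldsymbol{a} \in A\), apply Lemma 3.2 (with the roles swapped: separate the single point \(\boldsymbol{a}\) from the closed set \(B\)) to obtain \(\widetilde{g}_{\boldsymbol{a}} \in \mathcal{N}_2\) with
\[\widetilde{g}_{\boldsymbol{a}}(\boldsymbol{a}) < \varepsilon/2 \quad\text{and}\quad \widetilde{g}_{\boldsymbol{a}} > 1 - \varepsilon/2 \ \text{ on } B.\]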
Step 2: Flip to get the right orientation.
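Set \(g_{\boldsymbol{a}} = 1 - \widetilde{g}_{\boldsymbol{a}}\). Then \(g_{\boldsymbol{a}} \in \mathcal{N}_2\) and:
\[g_{\boldsymbol{a}}(\boldsymbol{a}) > 1 - \varepsilon/2 \quad\text{and}\quad g_{\boldsymbol{a}} < \varepsilon/2 \ \text{ on } B.\]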
Now \(g_{\boldsymbol{a}}\) is high at \(\boldsymbol{a}\) and low on \(B\) — the opposite orientation from \(\widetilde{g}_{\boldsymbol{a}}\).
Auxiliary: Why does flipping preserve \(\mathcal{N}_2\)?
If \(\widetilde{g} \in \mathcal{N}_2\) then \(\widetilde{g} = a_0 + \sum a_j \sigma(f_j)\) for some coefficients \(a_j\) and functions \(f_j \in \mathcal{N}_1\). Then
\[1 - \widetilde{g} = (1 - a_0) + \sum (-a_j) \, \sigma(f_j),\]
which is also an affine combination of elements of \(\mathcal{N}_1^\sigma\), hence in \(\mathcal{N}_2\). In general, \(\mathcal{N}_k\) is closed under affine combinations.
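The flip changes only the output coefficients and leaves the hidden units untouched. The sketch below checks this for a made-up element of \(\mathcal{N}_2\) on the real line; the coefficients, hidden units, and the helper names `eval_n2` and `flip` are hypothetical.

```python
import math

def sigma(x):
    """Standard sigmoid (assumed activation for this sketch)."""
    return 1.0 / (1.0 + math.exp(-x))

def eval_n2(a0, coeffs, hidden, x):
    """Evaluate a0 + sum_j a_j * sigma(w_j * x + c_j) for a scalar input x."""
    return a0 + sum(a * sigma(w * x + c) for a, (w, c) in zip(coeffs, hidden))

def flip(a0, coeffs):
    """Output coefficients of 1 - g: constant 1 - a0 and weights -a_j."""
    return 1.0 - a0, [-a for a in coeffs]

# Made-up element of N_2: output coefficients and (weight, bias) hidden units.
a0, coeffs = 0.25, [1.5, -0.5]
hidden = [(2.0, -1.0), (-3.0, 0.5)]
b0, flipped = flip(a0, coeffs)

xs = [-2.0, 0.0, 0.7, 3.0]
gaps = [
    abs((1.0 - eval_n2(a0, coeffs, hidden, x)) - eval_n2(b0, flipped, hidden, x))
    for x in xs
]
print("largest discrepancy:", max(gaps))
```

Since the hidden layer is reused as is, flipping adds no neurons and no depth.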
Step 3: Build an open cover of \(A\).
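Define
\[U_{\boldsymbol{a}} = \{\boldsymbol{x} \in K : g_{\boldsymbol{a}}(\boldsymbol{x}) > 1 - \varepsilon\}.\]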
Since \(g_{\boldsymbol{a}}\) is continuous, \(U_{\boldsymbol{a}}\) is open. And \(\boldsymbol{a} \in U_{\boldsymbol{a}}\) because \(g_{\boldsymbol{a}}(\boldsymbol{a}) > 1 - \varepsilon/2 > 1 - \varepsilon\). So \(\{U_{\boldsymbol{a}}\}_{\boldsymbol{a} \in A}\) is an open cover of the compact set \(A\).
By compactness, extract a finite subcover: \(A \subset U_{\boldsymbol{a}_1} \cup \cdots \cup U_{\boldsymbol{a}_N}\).
Step 4: Sharpen with Lemma 3.1.
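Use Lemma 3.1 to find \(s, t \in \mathbb{R}\) with
\[\sigma(s + tx) < \varepsilon/N \ \text{ on } (-\infty, \varepsilon] \quad\text{and}\quad \sigma(s + tx) > 1 - \varepsilon \ \text{ on } [1 - \varepsilon, \infty).\]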
Step 5: Sum to get \(h \in \mathcal{N}_3\).
Define
\[h = \sum_{j=1}^N \sigma\bigl(s + t \, g_{\boldsymbol{a}_j}\bigr).\]
Layer counting
Each \(g_{\boldsymbol{a}_j} \in \mathcal{N}_2\), so \(s + t \, g_{\boldsymbol{a}_j} \in \mathcal{N}_2\) (affine combination of an \(\mathcal{N}_2\) element). Therefore \(\sigma(s + t \, g_{\boldsymbol{a}_j}) \in \mathcal{N}_2^\sigma\).
The sum \(h = \sum \sigma(s + t \, g_{\boldsymbol{a}_j})\) is an affine combination of elements of \(\mathcal{N}_2^\sigma\), which by definition is \(\mathcal{N}_3\).
Step 6: Verify the bounds.
On \(A\): If \(\boldsymbol{a} \in A\), then \(\boldsymbol{a} \in U_{\boldsymbol{a}_k}\) for some \(k\), so \(g_{\boldsymbol{a}_k}(\boldsymbol{a}) > 1 - \varepsilon\), hence \(\sigma(s + t \, g_{\boldsymbol{a}_k}(\boldsymbol{a})) > 1 - \varepsilon\), hence \(h(\boldsymbol{a}) \geq \sigma(s + t \, g_{\boldsymbol{a}_k}(\boldsymbol{a})) > 1 - \varepsilon\). \(\;\checkmark\)
On \(B\): If \(\boldsymbol{b} \in B\), then \(g_{\boldsymbol{a}_j}(\boldsymbol{b}) < \varepsilon/2 < \varepsilon\) for all \(j\), so each \(\sigma(s + t \, g_{\boldsymbol{a}_j}(\boldsymbol{b})) < \varepsilon/N\), hence \(h(\boldsymbol{b}) < N \cdot (\varepsilon/N) = \varepsilon\). \(\;\checkmark\)
This proves part (i).
Step 7: Part (ii) — one more squash.
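Apply Lemma 3.1 once more to find \(s', t' \in \mathbb{R}\) with
\[\sigma(s' + t'x) < \varepsilon \ \text{ on } (-\infty, \varepsilon] \quad\text{and}\quad \sigma(s' + t'x) > 1 - \varepsilon \ \text{ on } [1 - \varepsilon, \infty).\]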
Set \(H = \sigma(s' + t' h) \in \mathcal{N}_3^\sigma\). Since \(\sigma\) is increasing and \(0 \leq \sigma \leq 1\):
On \(B\): \(h < \varepsilon\), so \(H = \sigma(s' + t'h) < \varepsilon\) and \(H \geq 0\). \(\;\checkmark\)
On \(A\): \(h > 1 - \varepsilon\), so \(H = \sigma(s' + t'h) > 1 - \varepsilon\) and \(H \leq 1\). \(\;\checkmark\;\square\)
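The final squash matters because \(h\) is a sum of up to \(N\) terms and can exceed 1 on \(A\), whereas \(H\) always lands in \([0, 1]\). The sketch below assumes the standard sigmoid and \(\varepsilon = 0.15\); the sample values of \(h\) are illustrative inputs, not results computed from a real cover.

```python
import math

def sigma(x):
    """Standard sigmoid, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def logit(p):
    """Inverse of the standard sigmoid."""
    return math.log(p / (1.0 - p))

eps = 0.15
# Lemma 3.1 applied to the points eps and 1 - eps with tolerance eps.
t2 = (logit(1 - eps / 2) - logit(eps / 2)) / ((1 - eps) - eps)
s2 = logit(eps / 2) - t2 * eps

h_on_B = [0.0, 0.036, 0.149]    # illustrative values with h < eps
h_on_A = [0.851, 1.0, 3.7765]   # illustrative values with h > 1 - eps

H_on_B = [sigma(s2 + t2 * h) for h in h_on_B]
H_on_A = [sigma(s2 + t2 * h) for h in h_on_A]
print("H on B samples:", [round(v, 4) for v in H_on_B])
print("H on A samples:", [round(v, 4) for v in H_on_A])
```

Values of \(h\) well above 1 are mapped close to 1, which is how part (ii) obtains the two-sided bound \(1 - \varepsilon < H \leq 1\) on \(A\).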
The epsilon cascade
Here is how the epsilons propagate through the construction:
Target: ε
│
├── Lemma 3.3, Step 1: call Lemma 3.2 with ε/2
│ │
│ ├── Lemma 3.2, Step 1: call Lemma 3.1 with ε/2
│ │ └── IVT targets: σ(y₀) = ε/4, σ(y₁) = 1 - ε/4
│ │
│ └── Lemma 3.2, Step 5: sharpen with ε/N₁ (N₁ = cover size for B)
│
├── Lemma 3.3, Step 2: flip g̃ → g = 1 - g̃
│
├── Lemma 3.3, Step 4: sharpen with ε/N₂ (N₂ = cover size for A)
│
└── Lemma 3.3, Step 7: one more squash for part (ii)
└── Final H ∈ N₃ᵞ with 0 ≤ H < ε on B, 1-ε < H ≤ 1 on A
Numerical demonstration
# ============================================================
# Lemma 3.3: Set-Set Separation in K = [0,1]^2
# A = rectangle [0, 0.35] x [0.2, 0.8]
# B = rectangle [0.65, 1] x [0.2, 0.8]
# eps = 0.15
# ============================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
def sigma(x):
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

def sigma_inv(y):
    return np.log(y / (1 - y))
ACCENT = '#4f46e5'
THM = '#059669'
WARN = '#d97706'
DANGER = '#dc2626'
# Sets
A_box = [0.0, 0.35, 0.2, 0.8] # [x_min, x_max, y_min, y_max]
B_box = [0.65, 1.0, 0.2, 0.8]
eps33 = 0.15
def in_box(pt, box):
    return box[0] <= pt[0] <= box[1] and box[2] <= pt[1] <= box[3]
# Grid for evaluation
resolution = 150
xx = np.linspace(0, 1, resolution)
yy = np.linspace(0, 1, resolution)
XX, YY = np.meshgrid(xx, yy)
points = np.column_stack([XX.ravel(), YY.ravel()])
# -----------------------------------------------------------
# Step 1-2: For each a in A, build g_a = 1 - g_tilde_a
# g_tilde_a > 1 - eps/2 on B, g_tilde_a(a) < eps/2
# => g_a < eps/2 on B, g_a(a) > 1 - eps/2
#
# We implement a simplified version:
# For each a_j, we separate a_j from B using coordinate
# projections. Since a_j is in A (x <= 0.35) and B has x >= 0.65,
# the x-coordinate alone separates them.
# -----------------------------------------------------------
# Choose cover points a_1, ..., a_N along A
N_A = 6
a_points = []
for i in range(3):
    for j in range(2):
        ax_val = A_box[0] + (i + 0.5) / 3 * (A_box[1] - A_box[0])
        ay_val = A_box[2] + (j + 0.5) / 2 * (A_box[3] - A_box[2])
        a_points.append(np.array([ax_val, ay_val]))
a_points = np.array(a_points)
N_A = len(a_points)
def build_g_a(a_pt, eps_inner):
    """
    Build g_a in N_2 such that g_a(a) > 1-eps_inner/2 and g_a < eps_inner/2 on B.
    Since A is on the left (x < 0.35) and B is on the right (x > 0.65),
    we use the x-coordinate to separate.
    For each point b in B's boundary, we build f_b that is high at b and low at a.
    But since B is a rectangle, we can cover it with a few well-chosen points.
    Then we sum sharpened versions.
    """
    # Simple approach: use Lemma 3.1 on the x-coordinate
    # We want sigma(s+t*a[0]) < eps_inner/2 and sigma(s+t*b[0]) > 1-eps_inner/2
    # for all b in B, i.e., b[0] >= 0.65
    # For g_tilde_a (high on B, low at a):
    # sigma(s + t * a[0]) < eps_inner/2 at a[0]
    # sigma(s + t * 0.65) > 1 - eps_inner/2 (leftmost point of B in x)
    y0_ivt = sigma_inv(eps_inner / 2)
    y1_ivt = sigma_inv(1 - eps_inner / 2)
    t_g = (y1_ivt - y0_ivt) / (0.65 - a_pt[0])
    s_g = y0_ivt - t_g * a_pt[0]
    return s_g, t_g
# Build g_a for each a_j
# g_tilde_a(x) = sigma(s_g + t_g * x[0]) -- high on B, low at a
# g_a(x) = 1 - g_tilde_a(x) -- low on B, high at a
eps_inner = eps33 # using eps for inner calls
g_a_params = [build_g_a(a_points[j], eps_inner) for j in range(N_A)]
# Evaluate individual g_a functions on the grid
def eval_g_a(s_g, t_g, pts):
    """g_a = 1 - sigma(s_g + t_g * x[0])"""
    g_tilde = sigma(s_g + t_g * pts[:, 0])
    return 1 - g_tilde
# Step 4: Sharpening for the sum
# sigma(s' + t'*x) < eps/N on (-inf, eps) and > 1-eps on (1-eps, inf)
eps_sharp_33 = eps33 / N_A
y0_s = sigma_inv(eps_sharp_33 / 2)
y1_s = sigma_inv(1 - eps33 / 2)
val_low_33 = eps33
val_high_33 = 1 - eps33
t_sharp_33 = (y1_s - y0_s) / (val_high_33 - val_low_33)
s_sharp_33 = y0_s - t_sharp_33 * val_low_33
# Step 5: h = sum sigma(s_sharp + t_sharp * g_aj)
h_total = np.zeros(len(points))
for j in range(N_A):
    s_g, t_g = g_a_params[j]
    g_a_vals = eval_g_a(s_g, t_g, points)
    h_total += sigma(s_sharp_33 + t_sharp_33 * g_a_vals)
H_grid = h_total.reshape(resolution, resolution)
# Step 7: H = sigma(s'' + t'' * h) for part (ii)
y0_final = sigma_inv(eps33 / 2)
y1_final = sigma_inv(1 - eps33 / 2)
t_final = (y1_final - y0_final) / (val_high_33 - val_low_33)
s_final = y0_final - t_final * val_low_33
H_sigma = sigma(s_final + t_final * h_total).reshape(resolution, resolution)
# Verify bounds
# On A
mask_A = np.array([in_box(p, A_box) for p in points])
mask_B = np.array([in_box(p, B_box) for p in points])
h_on_A = h_total[mask_A]
h_on_B = h_total[mask_B]
H_on_A = H_sigma.ravel()[mask_A]
H_on_B = H_sigma.ravel()[mask_B]
print(f"Set-Set Separation (Lemma 3.3)")
print(f"==============================")
print(f"A = [{A_box[0]}, {A_box[1]}] \u00d7 [{A_box[2]}, {A_box[3]}]")
print(f"B = [{B_box[0]}, {B_box[1]}] \u00d7 [{B_box[2]}, {B_box[3]}]")
print(f"\u03b5 = {eps33}, N = {N_A}")
print()
print(f"Part (i): h \u2208 N\u2083")
print(f" min(h on A) = {h_on_A.min():.4f} > {1-eps33} \u2713" if h_on_A.min() > 1-eps33 else f" min(h on A) = {h_on_A.min():.4f} \u2717")
print(f" max(h on B) = {h_on_B.max():.4f} < {eps33} \u2713" if h_on_B.max() < eps33 else f" max(h on B) = {h_on_B.max():.4f} \u2717")
print()
print(f"Part (ii): H \u2208 N\u2083\u1d5e")
print(f" min(H on A) = {H_on_A.min():.4f} > {1-eps33} \u2713" if H_on_A.min() > 1-eps33 else f" min(H on A) = {H_on_A.min():.4f} \u2717")
print(f" max(H on B) = {H_on_B.max():.4f} < {eps33} \u2713" if H_on_B.max() < eps33 else f" max(H on B) = {H_on_B.max():.4f} \u2717")
print(f" H range on A: [{H_on_A.min():.4f}, {H_on_A.max():.4f}] \u2282 ({1-eps33}, 1]")
print(f" H range on B: [{H_on_B.min():.4f}, {H_on_B.max():.4f}] \u2282 [0, {eps33})")
Set-Set Separation (Lemma 3.3)
==============================
A = [0.0, 0.35] × [0.2, 0.8]
B = [0.65, 1.0] × [0.2, 0.8]
ε = 0.15, N = 6
Part (i): h ∈ N₃
min(h on A) = 3.7765 > 0.85 ✓
max(h on B) = 0.0358 < 0.15 ✓
Part (ii): H ∈ N₃ᵞ
min(H on A) = 1.0000 > 0.85 ✓
max(H on B) = 0.0345 < 0.15 ✓
H range on A: [1.0000, 1.0000] ⊂ (0.85, 1]
H range on B: [0.0304, 0.0345] ⊂ [0, 0.15)
# ============================================================
# 2x2 subplot grid: individual g_a, h, H
# ============================================================
fig, axes = plt.subplots(2, 2, figsize=(11, 10))
def draw_boxes(ax):
    """Draw rectangles A and B on the given axis."""
    rect_A = Rectangle((A_box[0], A_box[2]), A_box[1]-A_box[0], A_box[3]-A_box[2],
                       linewidth=2, edgecolor=THM, facecolor='none',
                       linestyle='-', label='$A$', zorder=5)
    rect_B = Rectangle((B_box[0], B_box[2]), B_box[1]-B_box[0], B_box[3]-B_box[2],
                       linewidth=2, edgecolor=DANGER, facecolor='none',
                       linestyle='-', label='$B$', zorder=5)
    ax.add_patch(rect_A)
    ax.add_patch(rect_B)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_aspect('equal')
# --- Panel (a): g_a for a single a ---
ax = axes[0, 0]
s_g0, t_g0 = g_a_params[0]
g_a0_vals = eval_g_a(s_g0, t_g0, points).reshape(resolution, resolution)
im0 = ax.imshow(g_a0_vals, extent=[0,1,0,1], origin='lower', cmap='RdYlGn', vmin=0, vmax=1)
draw_boxes(ax)
ax.plot(a_points[0, 0], a_points[0, 1], '*', color=ACCENT, markersize=14,
markeredgecolor='white', markeredgewidth=1, zorder=6)
ax.set_title(f'(a) Single $g_{{\\mathbf{{a}}_1}}$: high at $\\mathbf{{a}}_1$, low on $B$',
fontsize=11, fontweight='bold')
ax.legend(loc='upper right', fontsize=9)
fig.colorbar(im0, ax=ax, shrink=0.8)
# --- Panel (b): g_a for another a ---
ax = axes[0, 1]
s_g3, t_g3 = g_a_params[3]
g_a3_vals = eval_g_a(s_g3, t_g3, points).reshape(resolution, resolution)
im1 = ax.imshow(g_a3_vals, extent=[0,1,0,1], origin='lower', cmap='RdYlGn', vmin=0, vmax=1)
draw_boxes(ax)
ax.plot(a_points[3, 0], a_points[3, 1], '*', color=ACCENT, markersize=14,
markeredgecolor='white', markeredgewidth=1, zorder=6)
ax.set_title(f'(b) Single $g_{{\\mathbf{{a}}_4}}$: high at $\\mathbf{{a}}_4$, low on $B$',
fontsize=11, fontweight='bold')
ax.legend(loc='upper right', fontsize=9)
fig.colorbar(im1, ax=ax, shrink=0.8)
# --- Panel (c): h = sum of sharpened g_a ---
ax = axes[1, 0]
im2 = ax.imshow(H_grid, extent=[0,1,0,1], origin='lower', cmap='RdYlGn',
vmin=0, vmax=max(H_grid.max(), 1.5))
draw_boxes(ax)
for j in range(N_A):
    ax.plot(a_points[j, 0], a_points[j, 1], '*', color=ACCENT, markersize=10,
            markeredgecolor='white', markeredgewidth=0.8, zorder=6)
ax.set_title(f'(c) $h = \\sum \\sigma(s + t \\cdot g_{{\\mathbf{{a}}_j}}) \\in \\mathcal{{N}}_3$',
fontsize=11, fontweight='bold')
ax.legend(loc='upper right', fontsize=9)
fig.colorbar(im2, ax=ax, shrink=0.8)
# --- Panel (d): H = sigma(s' + t'h) ---
ax = axes[1, 1]
im3 = ax.imshow(H_sigma, extent=[0,1,0,1], origin='lower', cmap='RdYlGn', vmin=0, vmax=1)
draw_boxes(ax)
contour = ax.contour(XX, YY, H_sigma, levels=[eps33, 1-eps33],
colors=['white', 'white'], linewidths=[1.5, 1.5],
linestyles=['--', '-'], zorder=4)
ax.clabel(contour, fmt={eps33: f'$H = \\varepsilon$', 1-eps33: f'$H = 1-\\varepsilon$'},
fontsize=9)
ax.set_title(f'(d) $H = \\sigma(s\' + t\'h) \\in \\mathcal{{N}}_3^\\sigma$',
fontsize=11, fontweight='bold')
ax.legend(loc='upper right', fontsize=9)
fig.colorbar(im3, ax=ax, shrink=0.8)
for ax in axes.flat:
    ax.set_xlabel('$x_1$', fontsize=11)
    ax.set_ylabel('$x_2$', fontsize=11)
fig.suptitle('Lemma 3.3: Set-Set Separation', fontsize=14, fontweight='bold', y=1.01)
fig.tight_layout()
plt.show()
Try it yourself → Set-Set Separator
Exercises
Exercise 2.1. For Lemma 3.1 with \(x_0 = 0\) and \(x_1 = 1\), solve the linear system explicitly and express \(s\) and \(t\) as functions of \(\varepsilon\). What happens to \(t\) as \(\varepsilon \to 0\)?
Exercise 2.2. Can you replace \(\sigma\) with ReLU in Lemma 3.1? What goes wrong? (Hint: ReLU is not bounded.)
Exercise 2.3. In Lemma 3.2, why must \(B\) be closed? Construct a counterexample with \(B\) open where the conclusion fails.
Exercise 2.4. How does the number of neurons in the network from Lemma 3.2 depend on \(N\) (the cover size)? Write a formula for the total parameter count.
Exercise 2.5. Trace the epsilon bounds through Lemma 3.3 for \(\varepsilon = 0.1\). How many neurons does the network need? (This is a counting exercise.)
Exercise 2.6. \(\star\) The paper notes that the proof is “wasteful.” Identify specifically where in Lemma 3.3 the argument uses more network capacity than strictly necessary.