Chapter 41: ML-KEM (Kyber) – Design and Implementation#

ML-KEM (Module-Lattice-Based Key-Encapsulation Mechanism), formerly known as CRYSTALS-Kyber, is the first post-quantum key encapsulation mechanism standardised by NIST as FIPS 203 (August 2024). Its security rests on the hardness of the Module Learning With Errors (MLWE) problem over polynomial rings.

In this chapter we build a simplified but fully functional ML-KEM implementation from scratch using only NumPy, explore the Number Theoretic Transform (NTT) that makes it efficient, and run experiments comparing the three parameter sets (ML-KEM-512, ML-KEM-768, ML-KEM-1024).

Chapter scope

This is a pedagogical implementation designed for clarity, not production use. A real deployment must use constant-time code, validated randomness, and pass the NIST Known Answer Tests (KATs). For production, use a vetted library such as liboqs or the PQClean reference implementation.

41.1 Historical Context: From Kyber to ML-KEM#

The NIST Post-Quantum Cryptography Competition#

In 2016, NIST launched its Post-Quantum Cryptography (PQC) Standardisation Process in anticipation of large-scale quantum computers that would break RSA, Diffie-Hellman, and elliptic curve cryptography via Shor’s algorithm.

The competition received 82 submissions in Round 1 (2017). After three rounds of public evaluation:

| Round | Year | Candidates |
|---------|-----------|------------|
| Round 1 | 2017–2019 | 82 \(\to\) 26 |
| Round 2 | 2019–2020 | 26 \(\to\) 7 finalists + 8 alternates |
| Round 3 | 2020–2022 | 4 selected for standardisation |

In July 2022, NIST selected CRYSTALS-Kyber as the sole KEM standard and CRYSTALS-Dilithium, FALCON, and SPHINCS+ for digital signatures.

CRYSTALS-Kyber \(\to\) ML-KEM (FIPS 203)#

CRYSTALS-Kyber was designed by a team including Roberto Avanzi, Joppe Bos, Léo Ducas, Eike Kiltz, Tancrède Lepoint, Vadim Lyubashevsky, John Schanck, Peter Schwabe, Gregor Seiler, and Damien Stehlé.

In August 2024, NIST published the final standard as FIPS 203: Module-Lattice-Based Key-Encapsulation Mechanism Standard, officially renaming it ML-KEM.

Why lattices?

Lattice-based cryptography offers:

  • Quantum resistance: no known quantum algorithm solves lattice problems in polynomial time.

  • Efficiency: key and ciphertext sizes are moderate (hundreds of bytes to a few KB), and operations are fast thanks to the NTT.

  • Worst-case to average-case reductions: the security of random instances can be related to worst-case lattice problems (via results of Regev, Lyubashevsky, Peikert, and others).

41.2 Mathematical Foundations#

The Polynomial Ring \(R_q\)#

ML-KEM operates over the polynomial quotient ring:

\[ R_q = \mathbb{Z}_q[x] / (x^{256} + 1)\]

where \(q = 3329\) is a prime chosen so that \(q \equiv 1 \pmod{256}\) (since \(3329 = 13 \times 256 + 1\)), enabling efficient negacyclic NTT-based multiplication.

Elements of \(R_q\) are polynomials of degree at most 255 with coefficients in \(\{0, 1, \ldots, q-1\}\). Addition is coefficient-wise modulo \(q\). Multiplication is polynomial multiplication modulo both \(q\) and \(x^{256}+1\).
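The wrap-around rule \(x^{256} \equiv -1\) is easiest to see in a tiny instance. The sketch below (toy degree \(n = 4\), same \(q\); purely illustrative) multiplies two polynomials by schoolbook convolution and reduces using \(x^n = -1\):

```python
q = 3329
n = 4  # toy degree for illustration; ML-KEM uses n = 256

def negacyclic_mul(f, g, n, q):
    """Schoolbook multiply in Z_q[x]/(x^n + 1): x^n wraps around to -1."""
    h = [0] * n
    for i in range(n):
        for j in range(n):
            if i + j < n:
                h[i + j] = (h[i + j] + f[i] * g[j]) % q
            else:
                # x^(i+j) = x^(i+j-n) * x^n = -x^(i+j-n)
                h[i + j - n] = (h[i + j - n] - f[i] * g[j]) % q
    return h

# x * x^3 = x^4 = -1, i.e. coefficient -1 = q-1 in the constant term
print(negacyclic_mul([0, 1, 0, 0], [0, 0, 0, 1], n, q))  # [3328, 0, 0, 0]
```

The negative sign in the wrap-around branch is exactly what distinguishes this *negacyclic* ring from ordinary cyclic convolution modulo \(x^n - 1\).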

Module-LWE (MLWE) Problem#

Given a uniformly random matrix \(\mathbf{A} \in R_q^{k \times k}\) and a secret vector \(\mathbf{s} \in R_q^k\) with small coefficients, the MLWE problem asks to distinguish:

\[ (\mathbf{A},\; \mathbf{A} \cdot \mathbf{s} + \mathbf{e}) \quad \text{from} \quad (\mathbf{A},\; \mathbf{u})\]

where \(\mathbf{e}\) is a small error vector and \(\mathbf{u}\) is uniformly random.
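To make the two distributions concrete, here is a toy sketch over plain integers mod \(q\) (rank-\(k\) vectors of scalars instead of ring elements; the noise range here is illustrative, not a standardised parameter):

```python
import numpy as np

q, k = 3329, 2
rng = np.random.default_rng(0)

A = rng.integers(0, q, size=(k, k))      # public uniform matrix
s = rng.integers(-2, 3, size=k)          # small secret
e = rng.integers(-2, 3, size=k)          # small error

b_lwe = (A @ s + e) % q                  # "LWE side" of the distinguishing game
b_uniform = rng.integers(0, q, size=k)   # uniform side

print('A       =', A.tolist())
print('b_lwe   =', b_lwe.tolist())
print('b_unif  =', b_uniform.tolist())
# Deciding which of b_lwe / b_uniform was built as A.s + e is the decision
# LWE problem; MLWE replaces each scalar with a polynomial in R_q.
```

In the real scheme every entry of \(\mathbf{A}\), \(\mathbf{s}\), and \(\mathbf{e}\) is a degree-255 polynomial, which is what makes the problem "module" LWE.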

The parameter \(k\) determines the security level:

| Variant | \(k\) | NIST Level | Classical Security |
|-------------|-------|--------------|---------------------|
| ML-KEM-512 | 2 | 1 (AES-128) | \(\sim 118\) bits |
| ML-KEM-768 | 3 | 3 (AES-192) | \(\sim 182\) bits |
| ML-KEM-1024 | 4 | 5 (AES-256) | \(\sim 256\) bits |

Centered Binomial Distribution (CBD)#

Secret and error polynomials are sampled from the centered binomial distribution \(\text{CBD}_\eta\): for each coefficient, sample \(2\eta\) bits \(a_1, \ldots, a_\eta, b_1, \ldots, b_\eta\) and compute:

\[ c = \sum_{i=1}^{\eta} a_i - \sum_{i=1}^{\eta} b_i\]

This produces integers in \(\{-\eta, \ldots, \eta\}\) with a binomial distribution centered at zero.

Why CBD instead of discrete Gaussian?

Earlier lattice schemes used discrete Gaussian sampling, which is harder to implement in constant time. The CBD is trivially constant-time (just bit operations) and provides comparable security after parameter adjustment.

41.3 The Number Theoretic Transform (NTT)#

The NTT is the finite-field analogue of the FFT. It converts polynomial multiplication from \(O(n^2)\) to \(O(n \log n)\) operations.

Primitive Root of Unity#

For \(q = 3329\) and \(n = 256\), ML-KEM uses a negacyclic NTT that exploits the factorisation of \(x^{256}+1\) over \(\mathbb{Z}_q\).

The key requirement is a primitive \(256\)th root of unity \(\zeta \in \mathbb{Z}_q\), i.e., an element satisfying \(\zeta^{256} \equiv 1 \pmod{q}\) and \(\zeta^{128} \equiv -1 \pmod{q}\).

The value \(\zeta = 17\) serves this purpose for ML-KEM: \(17^{128} \equiv 3328 \equiv -1 \pmod{3329}\), and therefore \(17^{256} \equiv 1 \pmod{3329}\).

Because \(q \equiv 1 \pmod{256}\), the multiplicative group \(\mathbb{Z}_q^*\) contains elements of order 256, and 17 is one such element.

NTT in ML-KEM#

ML-KEM uses a negacyclic NTT that maps a polynomial \(f(x) \in R_q\) to its NTT representation \(\hat{f}\). Because \(\mathbb{Z}_q\) contains a 256th but no 512th root of unity, \(x^{256}+1\) factors into 128 irreducible degree-2 polynomials \(x^2 - \zeta^{2i+1}\) for \(i = 0, \ldots, 127\) (taken in bit-reversed order), so the NTT representation consists of 128 pairs of coefficients (one pair per degree-2 factor) and the transform has seven layers rather than eight.

The forward NTT uses the Cooley-Tukey butterfly:

\[\begin{split} \begin{aligned} t &\leftarrow \zeta \cdot f[j + \ell] \\ f[j + \ell] &\leftarrow f[j] - t \\ f[j] &\leftarrow f[j] + t \end{aligned}\end{split}\]

The inverse NTT uses the Gentleman-Sande butterfly, consuming the same twiddle factors in reverse order:

\[\begin{split} \begin{aligned} t &\leftarrow f[j] \\ f[j] &\leftarrow t + f[j + \ell] \\ f[j + \ell] &\leftarrow \zeta \cdot (f[j + \ell] - t) \end{aligned}\end{split}\]

After all layers, each coefficient is multiplied by \(128^{-1} \bmod q = 3303\) (the ML-KEM NTT has seven layers acting on 128 pairs, not the eight of a full size-256 transform).

Pointwise Multiplication#

In NTT domain, polynomial multiplication reduces to 128 base-case multiplications of degree-1 polynomials, making it very efficient.

Computational advantage

Schoolbook polynomial multiplication over \(R_q\) requires \(256^2 = 65{,}536\) multiplications modulo \(q\). NTT-based multiplication needs three transforms (two forward, one inverse) of \(7 \times 128 = 896\) butterflies each, plus 128 base-case multiplications – on the order of a \(20\times\) reduction in multiplication count. Note that the measured wall-clock speedup of our pure-Python NTT over schoolbook multiplication may differ from this theoretical ratio due to interpreter overhead.

import numpy as np

# --- Verify fundamental NTT parameters for ML-KEM ---
q = 3329   # The Kyber prime
n = 256    # Polynomial degree
zeta = 17  # Primitive 256th root of unity mod q

# Check that q is prime
def is_prime(p):
    if p < 2:
        return False
    for d in range(2, int(p**0.5) + 1):
        if p % d == 0:
            return False
    return True

print(f'q = {q}')
print(f'q is prime: {is_prime(q)}')
print(f'q mod 256 = {q % 256}  (need q ≡ 1 mod 256 for negacyclic NTT)')
print(f'3329 = 13 * 256 + 1 = {13 * 256 + 1}')
print()

# Verify zeta = 17 is a primitive 256th root of unity
print(f'zeta = {zeta}')
print(f'zeta^128 mod q = {pow(zeta, 128, q)}  (should be {q-1} = -1 mod q)')
print(f'zeta^256 mod q = {pow(zeta, 256, q)}  (should be 1)')
print()

# Check it is primitive (order is exactly 256, not a proper divisor)
is_primitive = True
for d in [1, 2, 4, 8, 16, 32, 64, 128]:
    if pow(zeta, d, q) == 1:
        is_primitive = False
        print(f'  NOT primitive: zeta^{d} = 1')
        break
if is_primitive:
    print(f'zeta = {zeta} is a primitive 256th root of unity mod {q}')
    print(f'  (17^128 ≡ {pow(17, 128, q)} ≡ -1 mod {q}  ✓)')

# Compute n^{-1} mod q.  (A full 8-layer inverse NTT would rescale by this;
# ML-KEM's 7-layer NTT rescales by 128^{-1} mod q = 3303 instead.)
n_inv = pow(n, -1, q)
print(f'\nn^{{-1}} mod q = {n_inv}  (verify: {n} * {n_inv} mod {q} = {(n * n_inv) % q})')
q = 3329
q is prime: True
q mod 256 = 1  (need q ≡ 1 mod 256 for negacyclic NTT)
3329 = 13 * 256 + 1 = 3329

zeta = 17
zeta^128 mod q = 3328  (should be 3328 = -1 mod q)
zeta^256 mod q = 1  (should be 1)

zeta = 17 is a primitive 256th root of unity mod 3329
  (17^128 ≡ 3328 ≡ -1 mod 3329  ✓)

n^{-1} mod q = 3316  (verify: 256 * 3316 mod 3329 = 1)

41.4 Implementing the NTT#

We implement the NTT class following the ML-KEM specification (FIPS 203, Algorithm 9 and Algorithm 10). The key design choices:

  • Bit-reversal of the twiddle factor indices, precomputed as a table.

  • In-place butterfly operations for memory efficiency.

  • Base-case multiplication for pointwise products in NTT domain.

import numpy as np

class NTT:
    """Number Theoretic Transform for ML-KEM (Kyber).

    Operates in Z_q[x]/(x^256 + 1) with q=3329, zeta=17 (primitive 256th root of unity).
    Adapted from the CRYSTALS-Kyber reference implementation."""

    def __init__(self, zeta=17, q=3329):
        self.zeta = zeta
        self.q = q
        self.n = 256
        # f = 128^{-1} mod q, used in inverse NTT
        self.ntt_f = pow(128, -1, self.q)
        # Precompute twiddle factors with bit-reversed indices
        self.ntt_zetas = [
            pow(self.zeta, self._bitrev(i, 7), self.q) for i in range(128)
        ]

    @staticmethod
    def _bitrev(i, k):
        """Reverse the lowest k bits of integer i."""
        return int(bin(i & ((1 << k) - 1))[2:].zfill(k)[::-1], 2)

    def forward(self, f):
        """Forward NTT (Cooley-Tukey butterfly, FIPS 203 Algorithm 9)."""
        f_ = list(f)
        i = 1
        length = 128
        while length >= 2:
            start = 0
            while start < 256:
                z = self.ntt_zetas[i]
                i += 1
                for j in range(start, start + length):
                    t = (z * f_[j + length]) % self.q
                    f_[j + length] = (f_[j] - t) % self.q
                    f_[j] = (f_[j] + t) % self.q
                start += 2 * length
            length >>= 1
        return f_

    def inverse(self, f):
        """Inverse NTT (Gentleman-Sande butterfly, FIPS 203 Algorithm 10)."""
        f_ = list(f)
        length = 2
        i = 127
        while length <= 128:
            start = 0
            while start < 256:
                z = self.ntt_zetas[i]
                i -= 1
                for j in range(start, start + length):
                    t = f_[j]
                    f_[j] = (t + f_[j + length]) % self.q
                    f_[j + length] = (z * (f_[j + length] - t)) % self.q
                start += 2 * length
            length <<= 1
        for j in range(256):
            f_[j] = (f_[j] * self.ntt_f) % self.q
        return f_

    def basemul(self, a0, a1, b0, b1, gamma):
        """Base-case multiply of two degree-1 polys mod (x^2 - gamma)."""
        c0 = (a0 * b0 + a1 * b1 * gamma) % self.q
        c1 = (a0 * b1 + a1 * b0) % self.q
        return c0, c1

    def pointwise_multiply(self, f_hat, g_hat):
        """Pointwise multiplication in NTT domain (128 base-case muls)."""
        h = [0] * 256
        for i in range(64):
            z = self.ntt_zetas[64 + i]
            h[4*i], h[4*i+1] = self.basemul(
                f_hat[4*i], f_hat[4*i+1],
                g_hat[4*i], g_hat[4*i+1], z)
            h[4*i+2], h[4*i+3] = self.basemul(
                f_hat[4*i+2], f_hat[4*i+3],
                g_hat[4*i+2], g_hat[4*i+3], (-z) % self.q)
        return h

    def ntt_vector(self, vec):
        """Apply forward NTT to each polynomial in a vector."""
        return [self.forward(poly) for poly in vec]

    def intt_vector(self, vec):
        """Apply inverse NTT to each polynomial in a vector."""
        return [self.inverse(poly) for poly in vec]

# Quick test: forward then inverse should be identity
ntt = NTT()
rng = np.random.default_rng(42)
test_poly = [int(x) for x in rng.integers(0, 3329, size=256)]

ntt_result = ntt.forward(test_poly)
recovered = ntt.inverse(ntt_result)

print('NTT round-trip test:')
print(f'  Original[0:8]:  {test_poly[:8]}')
print(f'  NTT[0:8]:       {ntt_result[:8]}')
print(f'  Recovered[0:8]: {recovered[:8]}')
print(f'  Match: {test_poly == recovered}')
NTT round-trip test:
  Original[0:8]:  [297, 2576, 2179, 1461, 1441, 2858, 286, 2321]
  NTT[0:8]:       [3184, 1311, 322, 728, 858, 2422, 1987, 896]
  Recovered[0:8]: [297, 2576, 2179, 1461, 1441, 2858, 286, 2321]
  Match: True

NTT Butterfly Diagram#

A full size-\(n\) NTT consists of \(\log_2(n)\) layers of butterfly operations (ML-KEM's seven-layer transform stops one layer short of the full eight). At each layer, pairs of elements are combined using a twiddle factor \(\zeta^{\text{br}(i)}\).

Let us visualise the butterfly structure for a small example (\(n = 8\)) to build intuition before working with the full \(n = 256\) case.

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(14, 7))

n_small = 8
n_layers = int(np.log2(n_small))

# Draw the butterfly network for n=8 NTT
for layer in range(n_layers):
    half_block = n_small >> (layer + 1)
    block_size = 2 * half_block
    x_pos = layer * 2.5

    for block_start in range(0, n_small, block_size):
        for j in range(half_block):
            top = block_start + j
            bot = block_start + j + half_block

            # Lines from previous layer to this butterfly
            if layer > 0:
                ax.plot([x_pos - 2.5, x_pos], [top, top], 'b-', linewidth=0.8)
                ax.plot([x_pos - 2.5, x_pos], [bot, bot], 'b-', linewidth=0.8)

            # Butterfly cross-connections
            ax.annotate('', xy=(x_pos + 1.5, top), xytext=(x_pos, top),
                        arrowprops=dict(arrowstyle='->', color='darkblue', lw=1.5))
            ax.annotate('', xy=(x_pos + 1.5, bot), xytext=(x_pos, bot),
                        arrowprops=dict(arrowstyle='->', color='darkblue', lw=1.5))

            # Cross connections (the butterfly)
            ax.plot([x_pos, x_pos + 1.5], [top, bot], 'r-', linewidth=0.8, alpha=0.6)
            ax.plot([x_pos, x_pos + 1.5], [bot, top], 'g-', linewidth=0.8, alpha=0.6)

            # Twiddle factor label
            br_idx = block_start // block_size * half_block + j
            ax.text(x_pos + 0.75, (top + bot) / 2, f'z{br_idx}',
                    fontsize=7, ha='center', va='center',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='lightyellow',
                              edgecolor='orange', alpha=0.9))

# Draw input/output nodes
for i in range(n_small):
    ax.plot(-0.3, i, 'ko', markersize=8)
    ax.text(-0.8, i, f'f[{i}]', fontsize=10, ha='right', va='center')
    ax.plot(n_layers * 2.5 - 1.0 + 1.5, i, 'ko', markersize=8)
    ax.text(n_layers * 2.5 - 1.0 + 2.0, i, f'f^[{i}]', fontsize=10,
            ha='left', va='center')

# Layer labels
for layer in range(n_layers):
    x_pos = layer * 2.5 + 0.75
    ax.text(x_pos, -1.0, f'Layer {layer+1}\n(len={n_small >> (layer+1)})',
            fontsize=9, ha='center', va='top', fontweight='bold')

ax.set_xlim(-1.5, n_layers * 2.5 + 2.5)
ax.set_ylim(-1.8, n_small - 0.3)
ax.invert_yaxis()
ax.set_title('NTT Butterfly Network (n = 8, Cooley-Tukey)', fontsize=14)
ax.axis('off')

# Legend
ax.plot([], [], 'r-', linewidth=1.2, label='+ branch (add)')
ax.plot([], [], 'g-', linewidth=1.2, label='- branch (subtract)')
ax.legend(loc='lower right', fontsize=10)

plt.tight_layout()
plt.savefig('ntt_butterfly_n8.png', dpi=150, bbox_inches='tight')
plt.show()

Polynomial Multiplication via NTT#

To multiply two polynomials \(f, g \in R_q\):

  1. Compute \(\hat{f} = \text{NTT}(f)\) and \(\hat{g} = \text{NTT}(g)\).

  2. Compute \(\hat{h} = \hat{f} \circ \hat{g}\) (pointwise base-case multiply).

  3. Compute \(h = \text{NTT}^{-1}(\hat{h})\).

The result is \(h = f \cdot g \bmod (x^{256}+1, q)\).

Let us verify this against naive (schoolbook) multiplication.

import numpy as np

def schoolbook_multiply(f, g, q=3329, n=256):
    """Naive O(n^2) polynomial multiplication in Z_q[x]/(x^n + 1)."""
    h = [0] * n
    for i in range(n):
        for j in range(n):
            idx = i + j
            if idx < n:
                h[idx] = (h[idx] + f[i] * g[j]) % q
            else:
                # x^n = -1 in the quotient ring
                h[idx - n] = (h[idx - n] - f[i] * g[j]) % q
    return h

ntt = NTT()
rng = np.random.default_rng(123)

# Random polynomials with small coefficients (typical of Kyber)
f = [int(x) % 3329 for x in rng.integers(0, 100, size=256)]
g = [int(x) % 3329 for x in rng.integers(0, 100, size=256)]

# Method 1: Schoolbook
h_school = schoolbook_multiply(f, g)

# Method 2: NTT-based
f_hat = ntt.forward(f)
g_hat = ntt.forward(g)
h_hat = ntt.pointwise_multiply(f_hat, g_hat)
h_ntt = ntt.inverse(h_hat)

print('Polynomial multiplication verification:')
print(f'  Schoolbook result[0:8]: {h_school[:8]}')
print(f'  NTT result[0:8]:       {h_ntt[:8]}')
print(f'  Match: {h_school == h_ntt}')
Polynomial multiplication verification:
  Schoolbook result[0:8]: [779, 424, 1528, 219, 141, 911, 1332, 667]
  NTT result[0:8]:       [779, 424, 1528, 219, 141, 911, 1332, 667]
  Match: True

Timing Comparison: Schoolbook vs NTT#

Let us measure the wall-clock time of both multiplication methods.

import numpy as np
import matplotlib.pyplot as plt
import time

ntt = NTT()
rng = np.random.default_rng(999)

n_trials = 20
school_times = []
ntt_times = []

for _ in range(n_trials):
    f = [int(x) for x in rng.integers(0, 3329, size=256)]
    g = [int(x) for x in rng.integers(0, 3329, size=256)]

    t0 = time.perf_counter()
    _ = schoolbook_multiply(f, g)
    t1 = time.perf_counter()
    school_times.append(t1 - t0)

    t0 = time.perf_counter()
    fh = ntt.forward(f)
    gh = ntt.forward(g)
    hh = ntt.pointwise_multiply(fh, gh)
    _ = ntt.inverse(hh)
    t1 = time.perf_counter()
    ntt_times.append(t1 - t0)

school_mean = np.mean(school_times) * 1000
ntt_mean = np.mean(ntt_times) * 1000
speedup = school_mean / ntt_mean

fig, ax = plt.subplots(figsize=(8, 5))
categories = ['Schoolbook\n$O(n^2)$', 'NTT-based\n$O(n \\log n)$']
means = [school_mean, ntt_mean]
colors = ['#e74c3c', '#27ae60']
bars = ax.bar(categories, means, color=colors, edgecolor='black', width=0.5)

for bar, val in zip(bars, means):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
            f'{float(val):.2f} ms', ha='center', va='bottom', fontsize=12, fontweight='bold')

ax.set_ylabel('Time per multiplication (ms)', fontsize=12)
ax.set_title(f'Polynomial Multiplication in $R_q$ (n=256, q=3329)\nNTT speedup: {float(speedup):.1f}x',
             fontsize=13)
ax.set_ylim(0, max(means) * 1.3)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('ntt_vs_schoolbook_timing.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'Schoolbook: {float(school_mean):.2f} ms (mean of {n_trials} trials)')
print(f'NTT-based:  {float(ntt_mean):.2f} ms (mean of {n_trials} trials)')
print(f'Speedup:    {float(speedup):.1f}x')
Schoolbook: 5.97 ms (mean of 20 trials)
NTT-based:  0.53 ms (mean of 20 trials)
Speedup:    11.4x

41.5 Centered Binomial Distribution#

Sampling

ML-KEM samples secret and error polynomials from the centered binomial distribution \(\text{CBD}_\eta\). For each coefficient:

\[ c = \left(\sum_{i=0}^{\eta-1} a_i\right) - \left(\sum_{i=0}^{\eta-1} b_i\right), \quad a_i, b_i \in \{0, 1\}\]

This gives \(c \in \{-\eta, \ldots, \eta\}\) with probability \(\Pr[c = j] = \binom{2\eta}{\eta+j} / 2^{2\eta}\).

  • ML-KEM-512 uses \(\eta_1 = 3, \eta_2 = 2\).

  • ML-KEM-768 uses \(\eta_1 = 2, \eta_2 = 2\).

  • ML-KEM-1024 uses \(\eta_1 = 2, \eta_2 = 2\).
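As a quick sanity check of the PMF formula above, we can tabulate \(\text{CBD}_2\) exactly and confirm that its variance is \(\eta/2\):

```python
from math import comb

eta = 2
# Pr[c = j] = C(2*eta, eta + j) / 2^(2*eta)
pmf = {j: comb(2 * eta, eta + j) / 4 ** eta for j in range(-eta, eta + 1)}
print(pmf)                # {-2: 0.0625, -1: 0.25, 0: 0.375, 1: 0.25, 2: 0.0625}
print(sum(pmf.values()))  # 1.0

# Variance of CBD_eta is eta/2 (sum of 2*eta independent +/- 1/2-variance bits)
var = sum(j * j * p for j, p in pmf.items())
print(var)                # 1.0  (= eta / 2)
```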

import numpy as np
import matplotlib.pyplot as plt
import math

def sample_cbd(eta, num_coeffs=256, rng=None):
    """Sample polynomial coefficients from CBD_eta."""
    if rng is None:
        rng = np.random.default_rng()
    bits = rng.integers(0, 2, size=(num_coeffs, 2 * eta))
    a_sum = bits[:, :eta].sum(axis=1)
    b_sum = bits[:, eta:].sum(axis=1)
    return (a_sum - b_sum).tolist()

def cbd_from_bytes(input_bytes, eta, q=3329):
    """Sample CBD_eta from a byte array (deterministic, as in FIPS 203).
    Expects exactly 64*eta bytes.  Returns 256 coefficients mod q."""
    if len(input_bytes) != 64 * eta:
        raise ValueError(f'Need {64*eta} bytes, got {len(input_bytes)}')
    b_int = int.from_bytes(input_bytes, 'little')
    coeffs = []
    mask = (1 << eta) - 1
    for _ in range(256):
        x = b_int & ((1 << (2 * eta)) - 1)
        a = bin(x & mask).count('1')
        b = bin((x >> eta) & mask).count('1')
        coeffs.append((a - b) % q)
        b_int >>= (2 * eta)
    return coeffs

# --- Visualise the CBD for different eta values ---
rng = np.random.default_rng(42)
n_samples = 10000

fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))

for idx, eta in enumerate([2, 3, 4]):
    samples = []
    for _ in range(n_samples // 256 + 1):
        samples.extend(sample_cbd(eta, 256, rng))
    samples = np.array(samples[:n_samples])

    ax = axes[idx]
    bins = np.arange(-eta - 1, eta + 2) - 0.5
    ax.hist(samples, bins=bins, density=True,
            color=f'C{idx}', edgecolor='black', alpha=0.8)

    # Overlay theoretical PMF
    x_vals = np.arange(-eta, eta + 1)
    pmf = [math.comb(2 * eta, eta + j) / (2 ** (2 * eta)) for j in x_vals]
    ax.plot(x_vals, pmf, 'ko-', markersize=6, linewidth=1.5, label='Theoretical')

    ax.set_xlabel('Coefficient value', fontsize=10)
    ax.set_ylabel('Probability' if idx == 0 else '', fontsize=10)
    ax.set_title(f'$\\mathrm{{CBD}}_{{{eta}}}$\n'
                 f'Range: [{-eta}, {eta}], Var={float(eta/2):.1f}', fontsize=11)
    ax.legend(fontsize=9)
    ax.set_xticks(range(-eta, eta + 1))

fig.suptitle('Centered Binomial Distribution for ML-KEM', fontsize=13, y=1.02)
plt.tight_layout()
plt.savefig('cbd_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

41.6 Polynomial Arithmetic Helpers#

Before building the full ML-KEM, we need helper functions for polynomial addition, subtraction, and matrix-vector multiplication over \(R_q\).

import numpy as np

def poly_add(a, b, q=3329):
    """Add two polynomials in R_q (coefficient-wise mod q)."""
    return [(ai + bi) % q for ai, bi in zip(a, b)]

def poly_sub(a, b, q=3329):
    """Subtract two polynomials in R_q."""
    return [(ai - bi) % q for ai, bi in zip(a, b)]

def mat_vec_mul_ntt(A_hat, s_hat, ntt_obj):
    """Multiply matrix A_hat by vector s_hat in NTT domain."""
    k = len(A_hat)
    result = []
    for i in range(k):
        acc = [0] * 256
        for j in range(k):
            prod = ntt_obj.pointwise_multiply(A_hat[i][j], s_hat[j])
            acc = poly_add(acc, prod)
        result.append(acc)
    return result

def vec_dot_ntt(a_hat, b_hat, ntt_obj):
    """Inner product of two vectors in NTT domain."""
    k = len(a_hat)
    acc = [0] * 256
    for i in range(k):
        prod = ntt_obj.pointwise_multiply(a_hat[i], b_hat[i])
        acc = poly_add(acc, prod)
    return acc

# Quick test
ntt = NTT()
a = [1] + [0] * 255   # polynomial '1'
b = [0, 1] + [0] * 254  # polynomial 'x'

a_hat = ntt.forward(a)
b_hat = ntt.forward(b)
c_hat = ntt.pointwise_multiply(a_hat, b_hat)
c = ntt.inverse(c_hat)

# 1 * x = x, so result should be [0, 1, 0, 0, ...]
print('Product of 1 and x:')
print(f'  Result[0:5] = {c[:5]}  (expected [0, 1, 0, 0, 0])')
print(f'  Correct: {c[:5] == [0, 1, 0, 0, 0]}')
Product of 1 and x:
  Result[0:5] = [0, 1, 0, 0, 0]  (expected [0, 1, 0, 0, 0])
  Correct: True

41.7 The ML-KEM Protocol#

ML-KEM is a Key Encapsulation Mechanism (KEM) consisting of three algorithms:

  1. KeyGen() \(\to\) \((ek, dk)\): Generate an encapsulation key and decapsulation key.

  2. Encaps(ek) \(\to\) \((K, c)\): Produce a shared secret \(K\) and ciphertext \(c\).

  3. Decaps(dk, c) \(\to\) \(K\): Recover the shared secret from the ciphertext.

Internal K-PKE (Public Key Encryption)#

Internally, ML-KEM wraps an IND-CPA secure public-key encryption scheme called K-PKE with the Fujisaki-Okamoto (FO) transform to achieve IND-CCA2 security.

K-PKE.KeyGen():

  1. Sample random seed \(d\).

  2. Derive matrix \(\hat{\mathbf{A}}\) from seed \(\rho\) (via XOF).

  3. Sample secret \(\mathbf{s}\) and error \(\mathbf{e}\) from \(\text{CBD}_{\eta_1}\).

  4. Compute \(\hat{\mathbf{t}} = \hat{\mathbf{A}} \cdot \hat{\mathbf{s}} + \hat{\mathbf{e}}\) (in NTT domain).

  5. Public key: \((\hat{\mathbf{t}}, \rho)\). Secret key: \(\hat{\mathbf{s}}\).

K-PKE.Encrypt(pk, m, r):

  1. Parse \(pk = (\hat{\mathbf{t}}, \rho)\). Derive \(\hat{\mathbf{A}}\) from \(\rho\).

  2. Sample \(\mathbf{y}\) from \(\text{CBD}_{\eta_1}\) and \(\mathbf{e}_1, e_2\) from \(\text{CBD}_{\eta_2}\).

  3. Compute \(\mathbf{u} = \text{NTT}^{-1}(\hat{\mathbf{A}}^T \cdot \hat{\mathbf{y}}) + \mathbf{e}_1\).

  4. Compute \(v = \text{NTT}^{-1}(\hat{\mathbf{t}}^T \cdot \hat{\mathbf{y}}) + e_2 + \lceil q/2 \rfloor \cdot m\).

  5. Ciphertext: \((\mathbf{u}, v)\) after compression.

Why does decryption work?

The key insight is that \(v - \mathbf{s}^T \mathbf{u} = \mathbf{e}^T \mathbf{y} - \mathbf{s}^T \mathbf{e}_1 + e_2 + \lceil q/2 \rfloor m\) (up to additional rounding noise introduced by compression).

The error term \(\mathbf{e}^T \mathbf{y} - \mathbf{s}^T \mathbf{e}_1 + e_2\) is small (all vectors have small coefficients), so rounding recovers \(m\).
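A scalar toy makes the rounding step concrete. The `compress1` helper below is a simplified 1-bit compressor written for this example (the chapter's own `compress` function appears in the next cell); it recovers the message bit whenever the accumulated noise stays below \(q/4\):

```python
q = 3329
q_half = (q + 1) // 2  # 1665, the encoding of message bit 1

def compress1(x, q=3329):
    """Recover one bit: round x to the nearest multiple of q/2."""
    return round((2 / q) * x) % 2

# Any noise with |noise| < q/4 ~ 832 leaves the bit intact
for m in (0, 1):
    for noise in (-80, -5, 0, 7, 80):
        w = (m * q_half + noise) % q
        assert compress1(w) == m
print('bit recovered for every tested noise value')
```

The ML-KEM parameters are chosen so that the total error term exceeds \(q/4\) in some coefficient only with negligible probability (a decryption-failure rate below \(2^{-138}\) for ML-KEM-512 according to the Kyber security analysis).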

import numpy as np
import math

def compress(x, d, q=3329):
    """Compress: Z_q -> Z_{2^d}. Round (2^d / q) * x."""
    if isinstance(x, (list, np.ndarray)):
        return [compress(xi, d, q) for xi in x]
    return int(math.floor(((1 << d) / q) * int(x) + 0.5)) % (1 << d)

def decompress(y, d, q=3329):
    """Decompress: Z_{2^d} -> Z_q. Round (q / 2^d) * y."""
    if isinstance(y, (list, np.ndarray)):
        return [decompress(yi, d, q) for yi in y]
    return int(math.floor((q / (1 << d)) * int(y) + 0.5))

# Verify compress/decompress is lossy but close
q = 3329
for d in [1, 4, 10]:
    errors = []
    for x in range(q):
        y = compress(x, d, q)
        x_rec = decompress(y, d, q)
        err = abs(x - x_rec)
        if err > q // 2:
            err = q - err
        errors.append(err)
    print(f'd={int(d):2d}: max error = {int(max(errors)):4d}, '
          f'mean error = {float(np.mean(errors)):.1f}, '
          f'compression ratio = {float(12/d):.1f}x')
d= 1: max error =  832, mean error = 416.1, compression ratio = 12.0x
d= 4: max error =  104, mean error = 52.0, compression ratio = 3.0x
d=10: max error =    2, mean error = 0.8, compression ratio = 1.2x
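The compression parameters also fix the standard's encoded sizes: the ciphertext packs \(k\) polynomials at \(d_u\) bits per coefficient plus one polynomial at \(d_v\) bits, and the encapsulation key packs \(k\) polynomials at 12 bits per coefficient plus the 32-byte seed \(\rho\). Computing these from the parameter sets reproduces the FIPS 203 sizes (ek/ciphertext: 800/768, 1184/1088, and 1568/1568 bytes):

```python
# FIPS 203 byte sizes implied by the compression parameters (seed rho = 32 B)
params = {
    'ML-KEM-512':  {'k': 2, 'du': 10, 'dv': 4},
    'ML-KEM-768':  {'k': 3, 'du': 10, 'dv': 4},
    'ML-KEM-1024': {'k': 4, 'du': 11, 'dv': 5},
}

sizes = {}
for name, p in params.items():
    ek_bytes = p['k'] * 256 * 12 // 8 + 32                        # t_hat + rho
    ct_bytes = p['k'] * 256 * p['du'] // 8 + 256 * p['dv'] // 8   # u + v
    sizes[name] = (ek_bytes, ct_bytes)
    print(f'{name}: ek = {ek_bytes} bytes, ciphertext = {ct_bytes} bytes')
```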

41.8 Simplified ML-KEM Implementation#

We now implement a simplified ML-KEM class. For pedagogical clarity, we:

  • Use NumPy’s random generator instead of SHAKE-128/256 for matrix and noise generation.

  • Skip byte encoding/decoding (work directly with integer coefficient arrays).

  • Implement the core mathematical operations faithfully.

Simplification

A production ML-KEM must derive the matrix \(\mathbf{A}\) and noise vectors deterministically from seeds using SHAKE-128/256 (XOF/PRF). Our simplified version uses numpy.random for sampling, which preserves the mathematical structure but not the byte-level compatibility with FIPS 203.
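To show what seed-derived generation looks like, here is a sketch in the style of FIPS 203's SampleNTT (Algorithm 7): SHAKE-128 expands the seed \(\rho\) with two domain-separation bytes \((j, i)\), and 12-bit chunks are rejection-sampled until 256 values below \(q\) remain. The fixed-size output buffer is a simplification (the standard squeezes the XOF incrementally), so treat this as a sketch rather than a byte-exact reimplementation:

```python
import hashlib

def sample_ntt(rho, i, j, q=3329, n=256):
    """Rejection-sample a uniform polynomial from SHAKE-128(rho || j || i).

    Schematic version of FIPS 203 SampleNTT: every 3 bytes yield two
    12-bit candidates; candidates >= q are rejected."""
    # Domain separation: FIPS 203 appends column index j, then row index i
    stream = hashlib.shake_128(rho + bytes([j, i])).digest(6 * n)  # ample buffer
    coeffs, pos = [], 0
    while len(coeffs) < n:
        b0, b1, b2 = stream[pos:pos + 3]
        pos += 3
        d1 = b0 + 256 * (b1 % 16)   # low 12 bits of the 3-byte group
        d2 = (b1 >> 4) + 16 * b2    # high 12 bits
        for d in (d1, d2):
            if d < q and len(coeffs) < n:
                coeffs.append(d)
    return coeffs

rho = b'\x00' * 32
poly = sample_ntt(rho, 0, 0)
print(len(poly), max(poly) < 3329, poly == sample_ntt(rho, 0, 0))  # 256 True True
```

Because the matrix is regenerated from \(\rho\) on demand, the public key only needs to carry the 32-byte seed rather than \(k^2\) full polynomials.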

import numpy as np
import math

class MLKEM:
    """Simplified ML-KEM (Kyber) implementation.

    Implements the core mathematical operations of ML-KEM for
    educational purposes.  Uses numpy.random for sampling instead
    of SHAKE XOF for simplicity."""

    PARAMS = {
        'ML-KEM-512':  {'k': 2, 'eta1': 3, 'eta2': 2, 'du': 10, 'dv': 4},
        'ML-KEM-768':  {'k': 3, 'eta1': 2, 'eta2': 2, 'du': 10, 'dv': 4},
        'ML-KEM-1024': {'k': 4, 'eta1': 2, 'eta2': 2, 'du': 11, 'dv': 5},
    }

    def __init__(self, variant='ML-KEM-512', seed=None):
        self.q = 3329
        self.n = 256
        params = self.PARAMS[variant]
        self.k = params['k']
        self.eta1 = params['eta1']
        self.eta2 = params['eta2']
        self.du = params['du']
        self.dv = params['dv']
        self.variant = variant
        self.ntt = NTT(zeta=17, q=self.q)
        self.rng = np.random.default_rng(seed)

    def _sample_cbd(self, eta):
        """Sample a polynomial from CBD_eta."""
        bits = self.rng.integers(0, 2, size=(self.n, 2 * eta))
        a_sum = bits[:, :eta].sum(axis=1)
        b_sum = bits[:, eta:].sum(axis=1)
        return [(int(a) - int(b)) % self.q for a, b in zip(a_sum, b_sum)]

    def _sample_uniform(self):
        """Sample a uniformly random polynomial in R_q."""
        return [int(x) for x in self.rng.integers(0, self.q, size=self.n)]

    def _compress(self, x, d):
        return compress(x, d, self.q)

    def _decompress(self, y, d):
        return decompress(y, d, self.q)

    def keygen(self):
        """Generate ML-KEM key pair.  Returns (ek, dk)."""
        k = self.k
        # In ML-KEM, matrix A is generated directly in NTT domain
        A_hat = [[self._sample_uniform() for _ in range(k)] for _ in range(k)]
        s = [self._sample_cbd(self.eta1) for _ in range(k)]
        s_hat = self.ntt.ntt_vector(s)
        e = [self._sample_cbd(self.eta1) for _ in range(k)]
        e_hat = self.ntt.ntt_vector(e)
        As_hat = mat_vec_mul_ntt(A_hat, s_hat, self.ntt)
        t_hat = [poly_add(As_hat[i], e_hat[i]) for i in range(k)]
        ek = (A_hat, t_hat)
        dk = (s_hat, ek)
        return ek, dk

    def encaps(self, ek):
        """Encapsulate: produce shared secret and ciphertext."""
        A_hat, t_hat = ek
        k = self.k
        m = [int(x) for x in self.rng.integers(0, 2, size=self.n)]
        y = [self._sample_cbd(self.eta1) for _ in range(k)]
        y_hat = self.ntt.ntt_vector(y)
        e1 = [self._sample_cbd(self.eta2) for _ in range(k)]
        e2 = self._sample_cbd(self.eta2)
        A_hat_T = [[A_hat[j][i] for j in range(k)] for i in range(k)]
        Ay_hat = mat_vec_mul_ntt(A_hat_T, y_hat, self.ntt)
        Ay = self.ntt.intt_vector(Ay_hat)
        u = [poly_add(Ay[i], e1[i]) for i in range(k)]
        ty_hat = vec_dot_ntt(t_hat, y_hat, self.ntt)
        ty = self.ntt.inverse(ty_hat)
        q_half = (self.q + 1) // 2
        m_scaled = [(mi * q_half) % self.q for mi in m]
        v = poly_add(poly_add(ty, e2), m_scaled)
        u_comp = [self._compress(ui, self.du) for ui in u]
        v_comp = self._compress(v, self.dv)
        ciphertext = (u_comp, v_comp)
        shared_secret = m
        return shared_secret, ciphertext

    def decaps(self, dk, ciphertext):
        """Decapsulate: recover shared secret from ciphertext."""
        s_hat, ek = dk
        u_comp, v_comp = ciphertext
        k = self.k
        u = [self._decompress(ui, self.du) for ui in u_comp]
        v = self._decompress(v_comp, self.dv)
        u_hat = self.ntt.ntt_vector(u)
        su_hat = vec_dot_ntt(s_hat, u_hat, self.ntt)
        su = self.ntt.inverse(su_hat)
        w = poly_sub(v, su)
        m_recovered = self._compress(w, 1)
        return m_recovered

print('MLKEM class defined with variants:', list(MLKEM.PARAMS.keys()))
MLKEM class defined with variants: ['ML-KEM-512', 'ML-KEM-768', 'ML-KEM-1024']

Pedagogical Simplification

This implementation demonstrates the core mathematical structure of ML-KEM but omits several components required by FIPS 203: the Fujisaki-Okamoto transform, the shared-secret KDF (SHA3-256), seed-derived matrix generation (XOF), byte encoding/decoding, and KAT compatibility. It should not be used for any security-sensitive application.
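The omitted Fujisaki-Okamoto transform deserves at least a sketch. In ML-KEM, Encaps hashes the random message together with a hash of the encapsulation key to derive both the shared secret and the encryption randomness; Decaps decrypts, deterministically re-encrypts, and on mismatch returns an "implicit rejection" key \(J(z \| c)\) rather than an error. The cipher below is an insecure toy stand-in for K-PKE (it is symmetric, with `pk == sk`, and exists only so the control flow runs); the hash roles \(G\), \(H\), \(J\) mirror FIPS 203, but the byte layouts are not the standard's:

```python
import hashlib, os

def G(b):
    d = hashlib.sha3_512(b).digest()
    return d[:32], d[32:]                    # (shared secret K, randomness r)

def H(b): return hashlib.sha3_256(b).digest()
def J(b): return hashlib.shake_256(b).digest(32)
def xof(b): return hashlib.shake_256(b).digest(32)
def bxor(a, b): return bytes(x ^ y for x, y in zip(a, b))

# Toy stand-in for K-PKE: deterministic given r, invertible with sk.
# INSECURE (pk equals sk) -- only the FO wrapper structure is the point.
def toy_enc(pk, m, r):
    return bxor(r, xof(pk + b'pad')) + bxor(m, xof(r))

def toy_dec(sk, c):
    r = bxor(c[:32], xof(sk + b'pad'))
    return bxor(c[32:], xof(r))

def fo_keygen():
    sk = os.urandom(32); ek = sk             # toy key pair
    z = os.urandom(32)                       # implicit-rejection secret
    return ek, (sk, ek, z)

def fo_encaps(ek):
    m = os.urandom(32)
    K, r = G(m + H(ek))                      # derive secret AND randomness
    return K, toy_enc(ek, m, r)

def fo_decaps(dk, c):
    sk, ek, z = dk
    m2 = toy_dec(sk, c)
    K2, r2 = G(m2 + H(ek))
    if toy_enc(ek, m2, r2) == c:             # re-encrypt and compare
        return K2
    return J(z + c)                          # implicit rejection

ek, dk = fo_keygen()
K, c = fo_encaps(ek)
print(fo_decaps(dk, c) == K)                 # True
c_bad = bytes([c[0] ^ 1]) + c[1:]
print(fo_decaps(dk, c_bad) == K)             # False: tampering yields a
                                             # random-looking key, not an error
```

The re-encryption check is what lifts an IND-CPA scheme to IND-CCA2: a tampered ciphertext cannot pass the comparison, and the attacker learns nothing from the pseudorandom rejection key.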

41.9 Full Key Exchange Demonstration#

Let us run the complete ML-KEM key exchange protocol: Alice generates a key pair, Bob encapsulates a shared secret, and Alice decapsulates to recover it.

Shared secret in FIPS 203

In this simplified demo, the “shared secret” is the raw 256-bit message vector. The actual FIPS 203 standard derives a 32-byte shared secret by applying a KDF (SHA3-256) as part of the Fujisaki-Okamoto transform. Our implementation omits the FO transform and returns the plaintext message directly.

import numpy as np

print('=' * 65)
print('         ML-KEM Key Exchange Demonstration')
print('=' * 65)

for variant in ['ML-KEM-512', 'ML-KEM-768', 'ML-KEM-1024']:
    kem = MLKEM(variant=variant, seed=2024)
    ek, dk = kem.keygen()
    K_bob, ciphertext = kem.encaps(ek)
    K_alice = kem.decaps(dk, ciphertext)
    match = (K_bob == K_alice)
    n_bits = sum(1 for a, b in zip(K_bob, K_alice) if a == b)

    print(f'\n--- {variant} (k={kem.k}) ---')
    print(f'  Shared secret length: {len(K_bob)} bits')
    print(f'  Bob   K[0:16]: {K_bob[:16]}')
    print(f'  Alice K[0:16]: {K_alice[:16]}')
    print(f'  Matching bits: {n_bits} / {len(K_bob)}')
    print(f'  Keys match:    {match}')

print('\n' + '=' * 65)
=================================================================
         ML-KEM Key Exchange Demonstration
=================================================================

--- ML-KEM-512 (k=2) ---
  Shared secret length: 256 bits
  Bob   K[0:16]: [0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0]
  Alice K[0:16]: [0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0]
  Matching bits: 256 / 256
  Keys match:    True

--- ML-KEM-768 (k=3) ---
  Shared secret length: 256 bits
  Bob   K[0:16]: [0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1]
  Alice K[0:16]: [0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1]
  Matching bits: 256 / 256
  Keys match:    True

--- ML-KEM-1024 (k=4) ---
  Shared secret length: 256 bits
  Bob   K[0:16]: [1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1]
  Alice K[0:16]: [1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1]
  Matching bits: 256 / 256
  Keys match:    True

=================================================================

Decryption Noise Analysis#

The correctness of ML-KEM depends on the decryption noise being small enough that rounding recovers the message. Let us measure the actual noise distribution.

import numpy as np
import matplotlib.pyplot as plt

def measure_decryption_noise(variant='ML-KEM-512', seed=42):
    """Measure decryption noise (without compression) for one exchange."""
    kem = MLKEM(variant=variant, seed=seed)
    k = kem.k
    q = kem.q
    ntt_obj = kem.ntt
    A_hat = [[kem._sample_uniform() for _ in range(k)] for _ in range(k)]
    s = [kem._sample_cbd(kem.eta1) for _ in range(k)]
    s_hat = ntt_obj.ntt_vector(s)
    e = [kem._sample_cbd(kem.eta1) for _ in range(k)]
    e_hat = ntt_obj.ntt_vector(e)
    As_hat = mat_vec_mul_ntt(A_hat, s_hat, ntt_obj)
    t_hat = [poly_add(As_hat[i], e_hat[i]) for i in range(k)]
    m = [int(x) for x in kem.rng.integers(0, 2, size=256)]
    y = [kem._sample_cbd(kem.eta1) for _ in range(k)]
    y_hat = ntt_obj.ntt_vector(y)
    e1 = [kem._sample_cbd(kem.eta2) for _ in range(k)]
    e2 = kem._sample_cbd(kem.eta2)
    A_hat_T = [[A_hat[j][i] for j in range(k)] for i in range(k)]
    Ay_hat = mat_vec_mul_ntt(A_hat_T, y_hat, ntt_obj)
    Ay = ntt_obj.intt_vector(Ay_hat)
    u = [poly_add(Ay[i], e1[i]) for i in range(k)]
    ty_hat = vec_dot_ntt(t_hat, y_hat, ntt_obj)
    ty = ntt_obj.inverse(ty_hat)
    q_half = (q + 1) // 2
    m_scaled = [(mi * q_half) % q for mi in m]
    v = poly_add(poly_add(ty, e2), m_scaled)
    u_hat = ntt_obj.ntt_vector(u)
    su_hat = vec_dot_ntt(s_hat, u_hat, ntt_obj)
    su = ntt_obj.inverse(su_hat)
    w = poly_sub(v, su)
    noise = []
    for wi, mi in zip(w, m_scaled):
        diff = (wi - mi) % q
        if diff > q // 2:
            diff -= q
        noise.append(diff)
    return noise, m

fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))
variants = ['ML-KEM-512', 'ML-KEM-768', 'ML-KEM-1024']
colors = ['#3498db', '#e67e22', '#2ecc71']

for idx, (variant, color) in enumerate(zip(variants, colors)):
    all_noise = []
    for trial_seed in range(50):
        noise, _ = measure_decryption_noise(variant, seed=trial_seed)
        all_noise.extend(noise)
    all_noise = np.array(all_noise)
    ax = axes[idx]
    ax.hist(all_noise, bins=80, color=color, edgecolor='black',
            alpha=0.8, density=True)
    ax.axvline(x=0, color='red', linestyle='--', linewidth=1)
    threshold = 3329 // 4
    ax.axvline(x=-threshold, color='darkred', linestyle=':', linewidth=1.5,
               label=f'Threshold (+/-{threshold})')
    ax.axvline(x=threshold, color='darkred', linestyle=':', linewidth=1.5)
    ax.set_xlabel('Noise value', fontsize=10)
    ax.set_ylabel('Density' if idx == 0 else '', fontsize=10)
    ax.set_title(f'{variant} (k={MLKEM.PARAMS[variant]["k"]})\n'
                 f'std={float(all_noise.std()):.1f}, max|noise|={np.max(np.abs(all_noise))}',
                 fontsize=11)
    ax.legend(fontsize=8)

fig.suptitle('ML-KEM Decryption Noise Distribution (no compression)', fontsize=13, y=1.02)
plt.tight_layout()
plt.savefig('mlkem_noise_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

print('If |noise| < q/4 = 832 for all coefficients, decryption succeeds.')
[Figure: ML-KEM decryption noise distribution for the three variants]
If |noise| < q/4 = 832 for all coefficients, decryption succeeds.

41.10 Parameter Comparison: ML-KEM-512/768/1024#

| Parameter | ML-KEM-512 | ML-KEM-768 | ML-KEM-1024 |
|---|---|---|---|
| \(k\) (module rank) | 2 | 3 | 4 |
| \(q\) (modulus) | 3329 | 3329 | 3329 |
| \(\eta_1\) | 3 | 2 | 2 |
| \(\eta_2\) | 2 | 2 | 2 |
| \((d_u, d_v)\) | (10, 4) | (10, 4) | (11, 5) |
| Public key (bytes) | 800 | 1184 | 1568 |
| Secret key (bytes) | 1632 | 2400 | 3168 |
| Ciphertext (bytes) | 768 | 1088 | 1568 |
| Shared secret (bytes) | 32 | 32 | 32 |
| NIST security level | 1 | 3 | 5 |

import numpy as np
import matplotlib.pyplot as plt
import time

variants = ['ML-KEM-512', 'ML-KEM-768', 'ML-KEM-1024']
pk_sizes = [800, 1184, 1568]
sk_sizes = [1632, 2400, 3168]
ct_sizes = [768, 1088, 1568]
nist_levels = [1, 3, 5]

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Plot 1: Key and ciphertext sizes
x = np.arange(len(variants))
w = 0.25
bars1 = axes[0].bar(x - w, pk_sizes, w, label='Public key', color='#3498db',
                     edgecolor='black')
bars2 = axes[0].bar(x, sk_sizes, w, label='Secret key', color='#e74c3c',
                     edgecolor='black')
bars3 = axes[0].bar(x + w, ct_sizes, w, label='Ciphertext', color='#2ecc71',
                     edgecolor='black')
axes[0].set_xticks(x)
axes[0].set_xticklabels(['512', '768', '1024'])
axes[0].set_xlabel('ML-KEM Variant', fontsize=11)
axes[0].set_ylabel('Size (bytes)', fontsize=11)
axes[0].set_title('Key and Ciphertext Sizes', fontsize=12)
axes[0].legend(fontsize=9)
axes[0].grid(axis='y', alpha=0.3)
for bars in [bars1, bars2, bars3]:
    for bar in bars:
        h = bar.get_height()
        axes[0].text(bar.get_x() + bar.get_width()/2, h + 30,
                     f'{int(h)}', ha='center', va='bottom', fontsize=8)

# Plot 2: Timing comparison
keygen_times = []
encaps_times = []
decaps_times = []
for variant in variants:
    kg_t, en_t, de_t = [], [], []
    for trial in range(5):
        kem = MLKEM(variant=variant, seed=trial)
        t0 = time.perf_counter()
        ek, dk = kem.keygen()
        t1 = time.perf_counter()
        kg_t.append(t1 - t0)
        t0 = time.perf_counter()
        K, ct = kem.encaps(ek)
        t1 = time.perf_counter()
        en_t.append(t1 - t0)
        t0 = time.perf_counter()
        K2 = kem.decaps(dk, ct)
        t1 = time.perf_counter()
        de_t.append(t1 - t0)
    keygen_times.append(np.mean(kg_t) * 1000)
    encaps_times.append(np.mean(en_t) * 1000)
    decaps_times.append(np.mean(de_t) * 1000)

bars1 = axes[1].bar(x - w, keygen_times, w, label='KeyGen', color='#9b59b6',
                     edgecolor='black')
bars2 = axes[1].bar(x, encaps_times, w, label='Encaps', color='#f39c12',
                     edgecolor='black')
bars3 = axes[1].bar(x + w, decaps_times, w, label='Decaps', color='#1abc9c',
                     edgecolor='black')
axes[1].set_xticks(x)
axes[1].set_xticklabels(['512', '768', '1024'])
axes[1].set_xlabel('ML-KEM Variant', fontsize=11)
axes[1].set_ylabel('Time (ms)', fontsize=11)
axes[1].set_title('Operation Timing (Python, educational)', fontsize=12)
axes[1].legend(fontsize=9)
axes[1].grid(axis='y', alpha=0.3)

# Plot 3: Security level vs total communication
total_comm = [pk + ct for pk, ct in zip(pk_sizes, ct_sizes)]
axes[2].plot(nist_levels, total_comm, 's-', color='darkblue', linewidth=2,
             markersize=10)
for i, (nl, tc, var) in enumerate(zip(nist_levels, total_comm, variants)):
    axes[2].annotate(f'{var}\n{tc} bytes', xy=(nl, tc),
                     xytext=(nl + 0.3, tc), fontsize=9, va='center')
axes[2].set_xlabel('NIST Security Level', fontsize=11)
axes[2].set_ylabel('Total communication (bytes)', fontsize=11)
axes[2].set_title('Security vs Communication Cost', fontsize=12)
axes[2].set_xticks([1, 3, 5])
axes[2].grid(True, alpha=0.3)

fig.suptitle('ML-KEM Parameter Comparison', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('mlkem_parameter_comparison.png', dpi=150, bbox_inches='tight')
plt.show()
[Figure: ML-KEM parameter comparison — key/ciphertext sizes, operation timing, security vs communication cost]

41.11 Visualising the NTT#

NTT Spectrum of Secret Polynomials#

The NTT transforms small-coefficient polynomials (from CBD) into the frequency domain. Let us visualise what CBD samples look like before and after NTT.

import numpy as np
import matplotlib.pyplot as plt

ntt = NTT()
rng = np.random.default_rng(42)

fig, axes = plt.subplots(2, 3, figsize=(16, 8))

for col, eta in enumerate([2, 3, 4]):
    bits = rng.integers(0, 2, size=(256, 2 * eta))
    a_sum = bits[:, :eta].sum(axis=1)
    b_sum = bits[:, eta:].sum(axis=1)
    poly = [(int(a) - int(b)) % 3329 for a, b in zip(a_sum, b_sum)]
    poly_centered = [c if c <= 3329 // 2 else c - 3329 for c in poly]
    poly_ntt = ntt.forward(poly)

    axes[0, col].stem(range(256), poly_centered, linefmt=f'C{col}-',
                       markerfmt=f'C{col}o', basefmt='k-')
    axes[0, col].set_title(f'$\\mathrm{{CBD}}_{{{eta}}}$ coefficients', fontsize=11)
    axes[0, col].set_xlabel('Index', fontsize=10)
    axes[0, col].set_ylabel('Value' if col == 0 else '', fontsize=10)
    axes[0, col].set_ylim(-eta - 1, eta + 1)

    axes[1, col].scatter(range(256), poly_ntt, s=4, c=f'C{col}', alpha=0.7)
    axes[1, col].set_title(f'After NTT (in $\\mathbb{{Z}}_{{3329}}$)', fontsize=11)
    axes[1, col].set_xlabel('Index', fontsize=10)
    axes[1, col].set_ylabel('Value' if col == 0 else '', fontsize=10)
    axes[1, col].set_ylim(0, 3329)

fig.suptitle('CBD Polynomials: Coefficient vs NTT Domain', fontsize=13, y=1.01)
plt.tight_layout()
plt.savefig('ntt_spectrum_cbd.png', dpi=150, bbox_inches='tight')
plt.show()
[Figure: CBD polynomials — coefficient domain vs NTT domain]

41.12 Decryption Failure Probability#

ML-KEM has a non-zero but negligible decryption failure probability. A failure occurs when the decryption noise exceeds \(\lfloor q/4 \rfloor = 832\) for some coefficient. The failure probability is designed to be \(< 2^{-139}\) for ML-KEM-512.
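A crude back-of-envelope for the tail: model each noise coefficient as Gaussian with the standard deviation seen in our uncompressed noise measurements (roughly 45 for ML-KEM-512 — an assumed figure). This is much looser than the FIPS 203 analysis, which also accounts for the noise added by ciphertext compression; that extra noise is why the official bound (\(2^{-139}\)) is far less favourable than this toy estimate.

```python
import math

# Back-of-envelope failure estimate (ML-KEM-512, NO compression noise).
# Assumption: each noise coefficient ~ N(0, sigma^2) with sigma ~ 45,
# roughly what our measurements suggest; the official FIPS 203 analysis
# is tighter and also accounts for compression error.
q = 3329
threshold = q // 4          # decoding fails if |noise| > 832
sigma = 45.0                # assumed per-coefficient std

p_coeff = math.erfc(threshold / (sigma * math.sqrt(2.0)))
p_exchange = 256 * p_coeff  # union bound over 256 message bits
print(f'P(one coefficient fails) ~ 2^{math.log2(p_coeff):.0f}')
print(f'P(exchange fails)        ~ 2^{math.log2(p_exchange):.0f}')
```

Either way, the failure probability is so far below anything observable that an empirical experiment should see zero failures.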

Let us empirically estimate the failure rate by running many key exchanges.

import numpy as np
import matplotlib.pyplot as plt

n_trials = 200
variants = ['ML-KEM-512', 'ML-KEM-768', 'ML-KEM-1024']
results = {}

for variant in variants:
    successes = 0
    bit_errors_list = []
    for trial in range(n_trials):
        kem = MLKEM(variant=variant, seed=trial * 1000 + 7)
        ek, dk = kem.keygen()
        K_send, ct = kem.encaps(ek)
        K_recv = kem.decaps(dk, ct)
        n_correct = sum(1 for a, b in zip(K_send, K_recv) if a == b)
        bit_errors_list.append(256 - n_correct)
        if K_send == K_recv:
            successes += 1
    results[variant] = {'success_rate': successes / n_trials,
                        'bit_errors': bit_errors_list}
    print(f'{variant}: {successes}/{n_trials} successful '
          f'({float(100*successes/n_trials):.1f}%)')

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Success rates
variant_names = ['512', '768', '1024']
rates = [results[v]['success_rate'] * 100 for v in variants]
colors = ['#3498db', '#e67e22', '#2ecc71']
bars = axes[0].bar(variant_names, rates, color=colors, edgecolor='black')
for bar, rate in zip(bars, rates):
    axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
                 f'{float(rate):.1f}%', ha='center', va='bottom', fontsize=12,
                 fontweight='bold')
axes[0].set_xlabel('ML-KEM Variant', fontsize=11)
axes[0].set_ylabel('Success Rate (%)', fontsize=11)
axes[0].set_title(f'Key Exchange Success Rate\n({n_trials} trials per variant)',
                  fontsize=12)
axes[0].set_ylim(0, 110)
axes[0].grid(axis='y', alpha=0.3)

# Bit error distribution
for i, (variant, color) in enumerate(zip(variants, colors)):
    errs = results[variant]['bit_errors']
    axes[1].hist(errs, bins=range(0, max(max(errs) + 2, 5)),
                 alpha=0.6, color=color, edgecolor='black', label=f'{variant}')
axes[1].set_xlabel('Number of bit errors per exchange', fontsize=11)
axes[1].set_ylabel('Frequency', fontsize=11)
axes[1].set_title('Distribution of Bit Errors', fontsize=12)
axes[1].legend(fontsize=10)
axes[1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('mlkem_success_rate.png', dpi=150, bbox_inches='tight')
plt.show()
ML-KEM-512: 200/200 successful (100.0%)
ML-KEM-768: 200/200 successful (100.0%)
ML-KEM-1024: 200/200 successful (100.0%)
[Figure: key exchange success rate and bit-error distribution]

41.13 Visualising Twiddle Factors#

The NTT twiddle factors \(\zeta^{\text{br}(i)}\) for \(i = 0, \ldots, 127\) form a specific pattern in \(\mathbb{Z}_q\). Let us visualise their structure.
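The table can be regenerated directly from the definition (\(q = 3329\), \(\zeta = 17\), 7-bit bit reversal), which makes a useful sanity check before plotting:

```python
# Regenerate the twiddle table zeta^br(i) mod q for i = 0..127,
# where br() reverses the 7 low-order bits of the index.
q, zeta = 3329, 17

def bit_reverse_7(i):
    """Reverse the 7 low bits of i (0 <= i < 128)."""
    return int(f'{i:07b}'[::-1], 2)

zetas = [pow(zeta, bit_reverse_7(i), q) for i in range(128)]
print('first 10:', zetas[:10])
# -> first 10: [1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919]
```

Because \(\zeta = 17\) has order 256 modulo 3329 and the bit reversal permutes the exponents \(0, \ldots, 127\), all 128 twiddle factors are distinct.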

import numpy as np
import matplotlib.pyplot as plt

ntt = NTT()
zetas = ntt.ntt_zetas

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].scatter(range(128), zetas, c=range(128), cmap='viridis',
                s=25, edgecolors='black', linewidths=0.5)
axes[0].set_xlabel('Index $i$', fontsize=11)
axes[0].set_ylabel('$\\zeta^{\\mathrm{br}(i)}$ mod $q$', fontsize=11)
axes[0].set_title('NTT Twiddle Factors ($q = 3329$, $\\zeta = 17$)', fontsize=12)
axes[0].grid(True, alpha=0.3)

# Map Z_q values to angles on unit circle analogy
angles = [2 * np.pi * z / 3329 for z in zetas]
axes[1].scatter([np.cos(a) for a in angles],
                [np.sin(a) for a in angles],
                c=range(128), cmap='plasma', s=30,
                edgecolors='black', linewidths=0.5)
circle = np.linspace(0, 2 * np.pi, 200)
axes[1].plot(np.cos(circle), np.sin(circle), 'k-', alpha=0.2)
axes[1].set_xlabel('$\\cos(2\\pi \\zeta / q)$', fontsize=11)
axes[1].set_ylabel('$\\sin(2\\pi \\zeta / q)$', fontsize=11)
axes[1].set_title('Twiddle Factors Mapped to Unit Circle', fontsize=12)
axes[1].set_aspect('equal')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('ntt_twiddle_factors.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'First 10 twiddle factors: {zetas[:10]}')
print(f'zeta[0] = {zetas[0]}  (should be 17^0 = 1)')
print(f'zeta[1] = {zetas[1]}  (should be 17^64 mod 3329 = {pow(17,64,3329)})')
[Figure: NTT twiddle factors in Z_q and mapped to the unit circle]
First 10 twiddle factors: [1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919]
zeta[0] = 1  (should be 17^0 = 1)
zeta[1] = 1729  (should be 17^64 mod 3329 = 1729)

41.14 The Module Structure of ML-KEM#

The “Module” in Module-LWE refers to working with vectors and matrices over \(R_q\) rather than individual ring elements. This provides a flexible way to scale security by adjusting the module rank \(k\).

Ring-LWE vs Module-LWE vs LWE

  • LWE (Regev, 2005): vectors over \(\mathbb{Z}_q\). Large keys.

  • Ring-LWE (Lyubashevsky-Peikert-Regev, 2010): single elements of \(R_q\). Compact keys but security relies on ring structure.

  • Module-LWE (Langlois-Stehle, 2015): vectors of \(R_q\) elements. Intermediate key sizes with tuneable security via \(k\). ML-KEM uses this approach.
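One consequence of the module structure worth quantifying: the \(k \times k\) matrix \(\hat{\mathbf{A}}\) is never transmitted in ML-KEM — it is expanded from a 32-byte seed with an XOF (a component our implementation omits). A quick sketch of what that saves, assuming the standard 12-bit encoding of coefficients modulo 3329:

```python
# Size of A-hat if it were sent in full, vs the 32-byte seed that
# ML-KEM actually transmits (12 bits per coefficient since q = 3329).
n, q_bits, seed_bytes = 256, 12, 32
for k in [2, 3, 4]:
    full_bytes = k * k * n * q_bits // 8
    print(f'k={k}: full A-hat = {full_bytes} bytes vs {seed_bytes}-byte seed')
```

The seed trick is why the public key only grows linearly in \(k\) (it carries \(\mathbf{t}\) plus the seed), even though \(\hat{\mathbf{A}}\) grows quadratically.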

import numpy as np
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(16, 5))
rng = np.random.default_rng(42)

for idx, (k, variant) in enumerate([(2, '512'), (3, '768'), (4, '1024')]):
    A = rng.integers(0, 3329, size=(k * 8, k * 8))
    ax = axes[idx]
    im = ax.imshow(A, cmap='Blues', aspect='equal')
    for i in range(1, k):
        ax.axhline(y=i * 8 - 0.5, color='red', linewidth=2)
        ax.axvline(x=i * 8 - 0.5, color='red', linewidth=2)
    ax.set_title(f'ML-KEM-{variant}\n$\\hat{{A}} \\in R_q^{{{k} \\times {k}}}$\n'
                 f'(showing first 8 coeffs per poly)', fontsize=11)
    ax.set_xlabel('Column index', fontsize=10)
    ax.set_ylabel('Row index' if idx == 0 else '', fontsize=10)
    for bi in range(k):
        for bj in range(k):
            ax.text(bj * 8 + 3.5, bi * 8 + 3.5,
                    f'$\\hat{{a}}_{{{bi},{bj}}}$',
                    ha='center', va='center', fontsize=9,
                    color='white', fontweight='bold',
                    bbox=dict(boxstyle='round', facecolor='darkblue', alpha=0.7))

fig.suptitle('Module Structure of $\\hat{\\mathbf{A}}$ in ML-KEM', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('mlkem_module_structure.png', dpi=150, bbox_inches='tight')
plt.show()
[Figure: module structure of A-hat for k = 2, 3, 4]

41.15 Why Decryption Works: A Visual Proof#

During encryption, Bob computes:

\[ \mathbf{u} = \mathbf{A}^T \mathbf{y} + \mathbf{e}_1, \quad v = \mathbf{t}^T \mathbf{y} + e_2 + \lceil q/2 \rfloor \cdot m\]

During decryption, Alice computes:

\[ v - \mathbf{s}^T \mathbf{u} = \mathbf{e}^T \mathbf{y} - \mathbf{s}^T \mathbf{e}_1 + e_2 + \lceil q/2 \rfloor m\]

The noise \(\mathbf{e}^T \mathbf{y} - \mathbf{s}^T \mathbf{e}_1 + e_2\) is small since all these vectors have small coefficients.
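A rough sanity check on "small": for ML-KEM-512 (\(k=2\), \(\eta_1=3\), \(\eta_2=2\)) the worst-case noise bound far exceeds \(q/4\), but the variance of a sum of many independent small products is what governs typical behaviour. The sketch below ignores compression noise and uses \(\mathrm{Var}(XY) = \mathrm{Var}(X)\,\mathrm{Var}(Y)\) for independent zero-mean \(X, Y\):

```python
import math

# ML-KEM-512 noise budget sketch (no compression noise).
k, n, q = 2, 256, 3329
eta1, eta2 = 3, 2

# Each coefficient of e^T y (resp. s^T e1) is a signed sum of k*n
# products of independent CBD samples; CBD_eta has variance eta/2.
worst = k * n * eta1 * eta1 + k * n * eta1 * eta2 + eta2
var = (k * n * (eta1 / 2) * (eta1 / 2)
       + k * n * (eta1 / 2) * (eta2 / 2)
       + eta2 / 2)
print(f'worst-case |noise| bound: {worst}  (threshold q/4 = {q // 4})')
print(f'typical std estimate:     {math.sqrt(var):.1f}')
```

A standard deviation around 44 puts typical noise maxima (roughly three standard deviations over 256 coefficients) far below the threshold of 832, consistent with the measurements below.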

import numpy as np
import matplotlib.pyplot as plt

kem = MLKEM(variant='ML-KEM-512', seed=77)
q = kem.q
k = kem.k
ntt_obj = kem.ntt

# KeyGen
A_hat = [[kem._sample_uniform() for _ in range(k)] for _ in range(k)]
s = [kem._sample_cbd(kem.eta1) for _ in range(k)]
s_hat = ntt_obj.ntt_vector(s)
e = [kem._sample_cbd(kem.eta1) for _ in range(k)]
e_hat = ntt_obj.ntt_vector(e)
As_hat = mat_vec_mul_ntt(A_hat, s_hat, ntt_obj)
t_hat = [poly_add(As_hat[i], e_hat[i]) for i in range(k)]

# Encrypt (no compression)
m = [int(x) for x in kem.rng.integers(0, 2, size=256)]
y = [kem._sample_cbd(kem.eta1) for _ in range(k)]
y_hat = ntt_obj.ntt_vector(y)
e1 = [kem._sample_cbd(kem.eta2) for _ in range(k)]
e2 = kem._sample_cbd(kem.eta2)
A_hat_T = [[A_hat[j][i] for j in range(k)] for i in range(k)]
Ay_hat = mat_vec_mul_ntt(A_hat_T, y_hat, ntt_obj)
Ay = ntt_obj.intt_vector(Ay_hat)
u = [poly_add(Ay[i], e1[i]) for i in range(k)]
ty_hat = vec_dot_ntt(t_hat, y_hat, ntt_obj)
ty = ntt_obj.inverse(ty_hat)
q_half = (q + 1) // 2
m_scaled = [(mi * q_half) % q for mi in m]
v = poly_add(poly_add(ty, e2), m_scaled)

# Decrypt
u_hat = ntt_obj.ntt_vector(u)
su_hat = vec_dot_ntt(s_hat, u_hat, ntt_obj)
su = ntt_obj.inverse(su_hat)
w = poly_sub(v, su)

fig, axes = plt.subplots(2, 2, figsize=(14, 9))

# Plot 1: Original message (scaled)
axes[0, 0].stem(range(256), m_scaled, linefmt='b-', markerfmt='bo',
                basefmt='k-')
axes[0, 0].axhline(y=q_half, color='red', linestyle='--', alpha=0.5,
                    label=f'q/2 = {q_half}')
axes[0, 0].set_title('Encoded message $\\lceil q/2 \\rfloor \\cdot m$', fontsize=12)
axes[0, 0].set_ylabel('Value', fontsize=10)
axes[0, 0].legend(fontsize=9)

# Plot 2: Decrypted w (raw)
colors_pts = ['red' if mi == 1 else 'blue' for mi in m]
axes[0, 1].scatter(range(256), w, s=5, c=colors_pts, alpha=0.7)
axes[0, 1].axhline(y=q_half, color='green', linestyle='--', alpha=0.5)
axes[0, 1].axhline(y=q // 4, color='orange', linestyle=':', alpha=0.5)
axes[0, 1].axhline(y=3 * q // 4, color='orange', linestyle=':', alpha=0.5)
axes[0, 1].set_title('Decrypted $w = v - \\mathbf{s}^T \\mathbf{u}$\n'
                      '(red=bit 1, blue=bit 0)', fontsize=12)
axes[0, 1].set_ylabel('Value in $\\mathbb{Z}_q$', fontsize=10)

# Plot 3: Noise (w - m_scaled)
noise = [(wi - mi) % q for wi, mi in zip(w, m_scaled)]
noise_centered = [n if n <= q // 2 else n - q for n in noise]
axes[1, 0].stem(range(256), noise_centered, linefmt='r-', markerfmt='ro',
                basefmt='k-')
axes[1, 0].axhline(y=q // 4, color='darkred', linestyle=':', linewidth=1.5,
                    label=f'Threshold (+/-{q//4})')
axes[1, 0].axhline(y=-q // 4, color='darkred', linestyle=':', linewidth=1.5)
axes[1, 0].set_title('Decryption noise $w - \\lceil q/2 \\rfloor \\cdot m$', fontsize=12)
axes[1, 0].set_xlabel('Coefficient index', fontsize=10)
axes[1, 0].set_ylabel('Noise value', fontsize=10)
axes[1, 0].legend(fontsize=9)

# Plot 4: Decision regions
axes[1, 1].scatter(range(256), w, s=5, c=colors_pts, alpha=0.7)
axes[1, 1].axhspan(0, q // 4, alpha=0.1, color='blue', label='Decode as 0')
axes[1, 1].axhspan(q // 4, 3 * q // 4, alpha=0.1, color='red', label='Decode as 1')
axes[1, 1].axhspan(3 * q // 4, q, alpha=0.1, color='blue')
axes[1, 1].set_title('Decision Regions for Message Recovery', fontsize=12)
axes[1, 1].set_xlabel('Coefficient index', fontsize=10)
axes[1, 1].set_ylabel('Value in $\\mathbb{Z}_q$', fontsize=10)
axes[1, 1].legend(fontsize=9)

fig.suptitle('ML-KEM-512: Encryption, Decryption, and Noise', fontsize=14, y=1.01)
plt.tight_layout()
plt.savefig('mlkem_correctness_visual.png', dpi=150, bbox_inches='tight')
plt.show()

max_noise = max(abs(n) for n in noise_centered)
print(f'Maximum |noise| = {max_noise} (threshold = {q//4})')
print(f'Noise safely below threshold: {max_noise < q//4}')
[Figure: ML-KEM-512 encryption, decryption, and noise — encoded message, decrypted w, noise, decision regions]
Maximum |noise| = 127 (threshold = 832)
Noise safely below threshold: True

41.16 ML-KEM vs Classical Key Exchange#

| Scheme | Public Key | Ciphertext | Security | Quantum Safe? |
|---|---|---|---|---|
| RSA-2048 | 256 B | 256 B | ~112 bits | No |
| ECDH P-256 | 64 B | 64 B | ~128 bits | No |
| ML-KEM-512 | 800 B | 768 B | ~118 bits | Yes |
| ML-KEM-768 | 1184 B | 1088 B | ~182 bits | Yes |
| ML-KEM-1024 | 1568 B | 1568 B | ~256 bits | Yes |

Practical impact

While ML-KEM public keys are roughly 12–25x larger than ECDH P-256 keys (and 3–6x larger than RSA-2048 keys), they are still small enough for most applications. The computational cost is comparable to or faster than RSA. ML-KEM is already being deployed in TLS 1.3 (as a hybrid with X25519) by major browsers and cloud providers.

import numpy as np
import matplotlib.pyplot as plt

schemes = ['RSA\n2048', 'ECDH\nP-256', 'ML-KEM\n512', 'ML-KEM\n768', 'ML-KEM\n1024']
pk_sizes = [256, 64, 800, 1184, 1568]
ct_sizes = [256, 64, 768, 1088, 1568]
quantum_safe = [False, False, True, True, True]

fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(len(schemes))
w = 0.35
colors_pk = ['#e74c3c' if not qs else '#27ae60' for qs in quantum_safe]
colors_ct = ['#c0392b' if not qs else '#229954' for qs in quantum_safe]

bars1 = ax.bar(x - w/2, pk_sizes, w, label='Public key', color=colors_pk,
               edgecolor='black', alpha=0.85)
bars2 = ax.bar(x + w/2, ct_sizes, w, label='Ciphertext', color=colors_ct,
               edgecolor='black', alpha=0.6)

for bar, val in zip(bars1, pk_sizes):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 20,
            f'{val}', ha='center', va='bottom', fontsize=9)
for bar, val in zip(bars2, ct_sizes):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 20,
            f'{val}', ha='center', va='bottom', fontsize=9)

ax.set_xticks(x)
ax.set_xticklabels(schemes)
ax.set_ylabel('Size (bytes)', fontsize=12)
ax.set_title('Key Exchange Sizes: Classical vs Post-Quantum\n'
             '(Green = quantum-safe, Red = quantum-vulnerable)', fontsize=13)
ax.legend(fontsize=11)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('mlkem_vs_classical_sizes.png', dpi=150, bbox_inches='tight')
plt.show()
[Figure: key exchange sizes, classical vs post-quantum]

41.17 Exercises#

Exercise 41.1: NTT Round-Trip

Verify that the inverse NTT exactly undoes the forward transform. For a random polynomial \(f \in R_q\), compute \(\hat{f} = \text{NTT}(f)\) and \(f' = \text{NTT}^{-1}(\hat{f})\). Show that \(f' = f\) for 100 random polynomials.

Then measure: what is the maximum intermediate value that appears during NTT computation? How does this relate to the choice of \(q\)?

Exercise 41.2: CBD Variance

Prove analytically that the variance of \(\text{CBD}_\eta\) is \(\eta/2\). Then verify experimentally by sampling 100,000 coefficients for \(\eta \in \{1, 2, 3, 4, 5\}\) and comparing the empirical variance to \(\eta/2\).
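A starter sketch for the experimental half; the sampling here mirrors the CBD construction (sum of \(\eta\) bits minus sum of \(\eta\) bits) used in the chapter:

```python
import numpy as np

# Empirical check that Var(CBD_eta) = eta / 2: each of the 2*eta bits
# contributes variance 1/4, and the bits are independent.
rng = np.random.default_rng(0)
for eta in [1, 2, 3, 4, 5]:
    bits = rng.integers(0, 2, size=(100_000, 2 * eta))
    samples = bits[:, :eta].sum(axis=1) - bits[:, eta:].sum(axis=1)
    print(f'eta={eta}: empirical var = {samples.var():.3f}, theory = {eta / 2}')
```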

Exercise 41.3: Compression Error

For each compression parameter \(d \in \{1, 2, \ldots, 11\}\), compute the maximum compression error \(\max_x |x - \text{Decompress}(\text{Compress}(x, d), d)|\) over all \(x \in \{0, \ldots, q-1\}\).

Plot the maximum error vs \(d\). Why do ML-KEM-512 and ML-KEM-768 use \((d_u, d_v) = (10, 4)\) while ML-KEM-1024 uses \((11, 5)\)?

Exercise 41.4: Schoolbook vs NTT Scaling

Measure the time for polynomial multiplication using both schoolbook and NTT methods for ring dimensions \(n \in \{32, 64, 128, 256\}\).

Plot the timing ratio (schoolbook/NTT) as a function of \(n\). At what \(n\) does the NTT become faster than schoolbook?

Exercise 41.5: Lattice Attack Simulation

For very small parameters (\(n = 4\), \(q = 97\), \(k = 1\)), implement the MLWE problem directly. Generate random \(a\), secret \(s\) with small coefficients, and error \(e\). Compute \(b = a \cdot s + e\).

Then try to recover \(s\) by brute-force search. Estimate how long a brute-force attack would take for real ML-KEM-512 parameters.
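A self-contained starter for the toy instance. Restricting the secret and error to \(\{-1, 0, 1\}\) is an assumption made here to keep the search space at \(3^4 = 81\) candidates:

```python
import itertools
import numpy as np

# Toy Ring-LWE instance: n = 4, q = 97, secret/error in {-1, 0, 1}.
n, q = 4, 97
rng = np.random.default_rng(1)

def ring_mul(f, g):
    """Schoolbook multiplication in Z_q[x] / (x^n + 1)."""
    out = [0] * n
    for i in range(n):
        for j in range(n):
            if i + j < n:
                out[i + j] = (out[i + j] + f[i] * g[j]) % q
            else:  # x^n = -1 wraps with a sign flip
                out[i + j - n] = (out[i + j - n] - f[i] * g[j]) % q
    return out

a = [int(x) for x in rng.integers(0, q, size=n)]
s = [int(x) for x in rng.integers(-1, 2, size=n)]
e = [int(x) for x in rng.integers(-1, 2, size=n)]
b = [(c + ei) % q for c, ei in zip(ring_mul(a, s), e)]

# Brute force: accept candidates whose residual b - a*cand has every
# coefficient within distance 1 of zero (the error range).
matches = [cand for cand in itertools.product([-1, 0, 1], repeat=n)
           if all(min(r, q - r) <= 1
                  for r in ((bi - ci) % q
                            for bi, ci in zip(b, ring_mul(a, list(cand)))))]
print('candidates found:', matches)
print('true secret:     ', tuple(s))
```

Scaling this to ML-KEM-512 (\(k=2\), \(n=256\), coefficients from \(\text{CBD}_3\)) gives a search space of about \(7^{512}\), which is the estimate the exercise asks you to make precise.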

import numpy as np
import matplotlib.pyplot as plt

# --- Starter code for Exercise 41.3: Compression Error ---

q = 3329
max_errors = []
d_values = list(range(1, 12))

for d in d_values:
    max_err = 0
    for x in range(q):
        y = compress(x, d, q)
        x_rec = decompress(y, d, q)
        err = abs(x - x_rec)
        if err > q // 2:
            err = q - err
        if err > max_err:
            max_err = err
    max_errors.append(max_err)

# Theoretical bound
theoretical = [int(np.ceil(q / (2 ** (d + 1)))) for d in d_values]

fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(d_values, max_errors, 'bo-', linewidth=2, markersize=8,
        label='Empirical max error')
ax.plot(d_values, theoretical, 'r^--', linewidth=1.5, markersize=7,
        label='Theoretical bound $\\lceil q/2^{d+1} \\rceil$')

# Mark ML-KEM parameters
ax.axvline(x=4, color='green', linestyle=':', alpha=0.7, label='$d_v$ = 4 (512/768)')
ax.axvline(x=5, color='orange', linestyle=':', alpha=0.7, label='$d_v$ = 5 (1024)')
ax.axvline(x=10, color='purple', linestyle=':', alpha=0.7, label='$d_u$ = 10 (512/768)')
ax.axvline(x=11, color='brown', linestyle=':', alpha=0.7, label='$d_u$ = 11 (1024)')

ax.set_xlabel('Compression parameter $d$', fontsize=12)
ax.set_ylabel('Maximum error', fontsize=12)
ax.set_title('Compression Error in ML-KEM', fontsize=13)
ax.legend(fontsize=9)
ax.set_xticks(d_values)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('mlkem_compression_error.png', dpi=150, bbox_inches='tight')
plt.show()

print('Max compression errors:')
for d, me, th in zip(d_values, max_errors, theoretical):
    print(f'  d={int(d):2d}: max_error={int(me):4d}, bound={int(th):4d}')
[Figure: compression error in ML-KEM — empirical max error vs theoretical bound for d = 1..11]
Max compression errors:
  d= 1: max_error= 832, bound= 833
  d= 2: max_error= 416, bound= 417
  d= 3: max_error= 208, bound= 209
  d= 4: max_error= 104, bound= 105
  d= 5: max_error=  52, bound=  53
  d= 6: max_error=  26, bound=  27
  d= 7: max_error=  13, bound=  14
  d= 8: max_error=   7, bound=   7
  d= 9: max_error=   3, bound=   4
  d=10: max_error=   2, bound=   2
  d=11: max_error=   1, bound=   1

41.18 Summary#

In this chapter we built a complete educational implementation of ML-KEM (Kyber), the first NIST-standardised post-quantum key encapsulation mechanism:

  1. Mathematical foundations: the polynomial ring \(R_q = \mathbb{Z}_{3329}[x]/(x^{256}+1)\), the Module-LWE problem, and the centered binomial distribution.

  2. Number Theoretic Transform (NTT): forward and inverse transforms using Cooley-Tukey and Gentleman-Sande butterflies, achieving \(O(n \log n)\) polynomial multiplication with a measured speedup over schoolbook.

  3. Full ML-KEM protocol: KeyGen, Encaps, and Decaps operations, demonstrating successful shared secret agreement across all three parameter sets.

  4. Experimental analysis:

    • Decryption noise stays well within the correctness threshold.

    • All parameter sets succeeded in every one of our 200 trials per variant (the theoretical failure probability is \(< 2^{-139}\) for ML-KEM-512).

    • Key and ciphertext sizes range from 800+768 bytes (ML-KEM-512) to 1568+1568 bytes (ML-KEM-1024).

  5. Comparison with classical schemes: ML-KEM keys are larger than ECDH but still practical, and the scheme is already being deployed in TLS 1.3.

Key Takeaways#

  • ML-KEM’s security rests on the hardness of Module-LWE, which is believed to resist both classical and quantum attacks.

  • The NTT is the key algorithmic enabler, making polynomial ring operations fast.

  • The centered binomial distribution provides efficient, constant-time noise sampling.

  • Parameter selection balances security level, key sizes, and failure probability.

  • ML-KEM is already being adopted as a hybrid KEM (alongside X25519) in major internet protocols.

References#

  1. NIST (2024). FIPS 203: Module-Lattice-Based Key-Encapsulation Mechanism Standard. https://csrc.nist.gov/pubs/fips/203/final

  2. Avanzi, R., Bos, J., Ducas, L., Kiltz, E., Lepoint, T., Lyubashevsky, V., Schanck, J.M., Schwabe, P., Seiler, G., Stehle, D. (2021). CRYSTALS-Kyber: Algorithm Specifications and Supporting Documentation (version 3.02).

  3. Regev, O. (2005). On Lattices, Learning with Errors, Random Linear Codes, and Cryptography. STOC 2005.

  4. Lyubashevsky, V., Peikert, C., Regev, O. (2010). On Ideal Lattices and Learning with Errors Over Rings. EUROCRYPT 2010.

  5. Langlois, A., Stehle, D. (2015). Worst-Case to Average-Case Reductions for Module Lattices. Designs, Codes and Cryptography.

  6. Peikert, C. (2016). A Decade of Lattice Cryptography. Foundations and Trends in Theoretical Computer Science.

  7. Bernstein, D.J., Lange, T. (2017). Post-quantum cryptography. Nature 549, 188–194.