Chapter 47 — Practical Exercises

Companion sheet for Chapter 47, Hash-Based Signatures. Five coding exercises build from a one-line verifier to a state-reuse forgery; a sixth is an open-ended research task on SLH-DSA parameters.


Exercise 47.E1 — ★ Lamport verification

Goal. Fill in lamport_verify(pk, msg, sig) so it returns True iff the signature is valid.

Theory. §47.2 — Lamport one-time signatures.
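Before filling in the verifier, it can help to see how the message hash selects one secret per position. The following is a minimal sketch on a toy 8-position key; the name `toy_sk` and the single byte standing in for `H(msg)` are illustrative assumptions, not part of the exercise.

```python
def bits(x):
    """Expand bytes into a list of bits, most significant bit first."""
    return [(byte >> (7 - j)) & 1 for byte in x for j in range(8)]

# Toy key with 8 positions; readable strings stand in for the random secrets.
toy_sk = [[f's{i},0', f's{i},1'] for i in range(8)]

digest = bytes([0xA5])                 # stand-in for H(msg); 0xA5 = 0b10100101
selected = bits(digest)                # one bit per key position
revealed = [toy_sk[i][b] for i, b in enumerate(selected)]

print(selected)
print(revealed[:3])
```

Position `i` of the signature reveals the secret indexed by bit `i` of the digest, so the verifier only ever needs the matching half of each public-key pair.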

import hashlib

def H(x): return hashlib.sha256(x).digest()
def bits(x):
    out = []
    for byte in x:
        for j in range(8):
            out.append((byte >> (7 - j)) & 1)
    return out


def lamport_verify(pk, msg, sig):
    '''pk[i] is a 2-element list [H(s_{i,0}), H(s_{i,1})]; sig[i] is the revealed s.'''
    # TODO: hash each sig[i] and check that it matches pk[i][h_i] where
    # h = bits(H(msg)).  Return True if EVERY position checks out.
    raise NotImplementedError('your turn')


# Self-test.
import os
sk = [[os.urandom(32), os.urandom(32)] for _ in range(256)]
pk = [[H(a), H(b)] for a, b in sk]
msg = b'hello world'
sig = [sk[i][b] for i, b in enumerate(bits(H(msg)))]
# assert lamport_verify(pk, msg, sig) is True
# assert lamport_verify(pk, b'hello WORLD', sig) is False
# print('E1 OK')


# Solution.
#
# Verification is symmetric to signing: we hash each revealed secret string and
# compare against the precomputed public commitment for the bit selected by
# the message hash.  All 256 positions must agree -- a single mismatch means
# the signature is invalid.

def lamport_verify(pk, msg, sig):
    h = bits(H(msg))
    return all(H(sig[i]) == pk[i][b] for i, b in enumerate(h))

assert lamport_verify(pk, msg, sig) is True
assert lamport_verify(pk, b'hello WORLD', sig) is False
print('E1 OK')
E1 OK

Exercise 47.E2 — ★★ Build a Merkle authentication path

Goal. Given the level-by-level hashes of a binary Merkle tree levels[0] = leaves, levels[1] = parents, ..., levels[-1] = [root], write auth_path(levels, j) returning the \(h\) sibling hashes from leaf \(j\) up to the root.

Theory. §47.4 — Merkle trees: one public key, many signatures.
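As a warm-up, the index arithmetic alone can be traced without any hashing. This is a small sketch, assuming a perfect binary tree; the helper name `sibling_indices` is hypothetical and only prints which node is visited at each level.

```python
def sibling_indices(j, height):
    """Return (level, node index, sibling index) for each level below the root."""
    steps = []
    for level in range(height):
        steps.append((level, j, j ^ 1))   # j ^ 1 flips the low bit
        j //= 2                           # index of the parent one level up
    return steps

for level, node, sib in sibling_indices(5, 3):
    side = 'left' if node % 2 == 0 else 'right'
    print(f'level {level}: node {node} is a {side} child, sibling is {sib}')
```

For leaf 5 in a height-3 tree the walk visits siblings 4, 3 and 0, which are exactly the positions the authentication path must read from `levels`.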

def build_merkle(leaves):
    levels = [leaves[:]]
    while len(levels[-1]) > 1:
        cur = levels[-1]
        nxt = [H(cur[i] + cur[i + 1]) for i in range(0, len(cur), 2)]
        levels.append(nxt)
    return levels


def auth_path(levels, j):
    '''Return the list of sibling hashes from leaf j up to the root.'''
    # TODO: at each level (except the root), append the SIBLING hash to the
    # path and update j -> j // 2.  Use j ^ 1 to flip the last bit (= sibling
    # index).
    raise NotImplementedError('your turn')


def verify_merkle_path(root, leaf, j, path):
    cur = leaf
    for sib in path:
        cur = H(cur + sib) if j % 2 == 0 else H(sib + cur)
        j //= 2
    return cur == root


# Test.
import os
leaves = [H(os.urandom(8)) for _ in range(8)]
levels = build_merkle(leaves)
root   = levels[-1][0]
# for j in range(8):
#     p = auth_path(levels, j)
#     assert verify_merkle_path(root, leaves[j], j, p), f'leaf {j} failed'
# print('E2 OK')


# Solution.
#
# The auth path is just the sibling hash at each level on the way to the root.
# The bit-trick j ^ 1 flips the low bit (i.e. yields the sibling index of j),
# and j //= 2 walks us up one level in the tree.
#
# The verify routine recombines: at each step, hash (current, sibling) in the
# correct order depending on whether j was the LEFT or RIGHT child.  This is
# §47.4 of the chapter.

def auth_path(levels, j):
    path = []
    for lv in range(len(levels) - 1):
        path.append(levels[lv][j ^ 1])
        j //= 2
    return path

for j in range(8):
    p = auth_path(levels, j)
    assert verify_merkle_path(root, leaves[j], j, p), f'leaf {j} failed'
print('E2 OK')
E2 OK

Exercise 47.E3 — ★★★ The WOTS+ checksum

Goal. Compute the Winternitz checksum digits given the message-hash digits.

Theory. §47.3 — Winternitz one-time signatures (WOTS+).

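Why this matters. Without the checksum, an attacker who sees one signature can forge by increasing a message digit: moving a chain value from digit \(d\) to a larger digit only requires hashing it forward. The checksum \(\sum_i (W - 1 - d_i)\) strictly decreases whenever any message digit increases, so at least one checksum digit must decrease. Signing a smaller digit means walking a chain backward, which requires inverting the hash.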
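The argument above can be tried on a single chain. This is a simplified sketch, assuming a plain SHA-256 hash chain without the WOTS+ bitmasks and with a fixed toy secret; `chain`, `verify_digit` and `csum` are illustrative names.

```python
import hashlib

W = 16  # digit alphabet size, matching the exercise

def chain(x, steps):
    """Apply SHA-256 `steps` times (plain chain, no WOTS+ bitmasks)."""
    for _ in range(steps):
        x = hashlib.sha256(x).digest()
    return x

secret = b'\x01' * 32                 # fixed toy secret so the run is reproducible
pk = chain(secret, W - 1)             # public end of the chain

def verify_digit(pk, digit, sig):
    """Walk the remaining W - 1 - digit steps and compare with the public key."""
    return chain(sig, W - 1 - digit) == pk

sig_5 = chain(secret, 5)              # honest signature element for digit 5
forged_9 = chain(sig_5, 4)            # hash forward 5 -> 9; no secret needed
assert verify_digit(pk, 5, sig_5)
assert verify_digit(pk, 9, forged_9)

def csum(digits):
    """Unpacked checksum value over a list of base-W digits."""
    return sum((W - 1) - d for d in digits)

original = [5, 0, 12]
raised = [9, 0, 12]                   # the digit the attacker incremented
print(csum(original), csum(raised))   # the second value is smaller
```

Raising a digit is free for the attacker, but it lowers the checksum value, and the lowered checksum digits cannot be signed by hashing forward.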

import math

W = 16                   # Winternitz digit alphabet size (4 bits per digit)
WBITS = 4
N = 32                   # SHA-256 output bytes
LEN1 = (8 * N + WBITS - 1) // WBITS                  # = 64 message-hash digits
LEN2 = (math.floor(math.log2(LEN1 * (W - 1))) // WBITS) + 1   # = 3 checksum digits


def msg_to_digits(msg_hash):
    '''Split bytes into base-W digits, most significant first (32 bytes -> 64 digits).'''
    out, buf, bits_in_buf = [], 0, 0
    for byte in msg_hash:
        buf = (buf << 8) | byte; bits_in_buf += 8
        while bits_in_buf >= WBITS:
            bits_in_buf -= WBITS
            out.append((buf >> bits_in_buf) & (W - 1))
    return out


def wots_checksum(msg_digits):
    '''Return the LEN2 checksum digits given the LEN1 message digits.'''
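    # TODO step 1: compute csum = sum_i (W - 1 - d_i) over message digits.
    # TODO step 2: pack csum into LEN2 base-W digits, most significant first.
    #              If you go through bytes and msg_to_digits, left-align csum
    #              first (LEN2 * WBITS = 12 bits is not a whole number of
    #              bytes), so that truncating to LEN2 digits drops padding
    #              rather than the lowest checksum digit.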
    raise NotImplementedError('your turn')


# Sanity test: each message digit is W-1 -> checksum is zero.
# m = [W - 1] * LEN1
# c = wots_checksum(m)
# assert c[:LEN2] == [0] * LEN2, c
# # And: each digit zero -> checksum is its maximum.
# m0 = [0] * LEN1
# c0 = wots_checksum(m0)
# assert sum(c0[:LEN2]) > 0
# print('E3 OK')


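# Solution.
#
# The checksum is sum_i (W - 1 - d_i), written as LEN2 base-W digits, most
# significant digit first.  LEN2 * WBITS = 12 bits do not fill a whole number
# of bytes, so csum is first shifted left by 4 bits.  Without the shift the
# padding sits at the top of the byte string, and truncating to LEN2 digits
# would drop the LOWEST checksum digit.  RFC 8391 and FIPS 205 use the same
# left-alignment step.  See Hülsing 2013 and chapter §47.3.

def wots_checksum(msg_digits):
    csum = sum((W - 1) - d for d in msg_digits)
    csum <<= (8 - (LEN2 * WBITS) % 8) % 8          # left-align within the bytes
    csum_bytes = csum.to_bytes((LEN2 * WBITS + 7) // 8, 'big')
    return msg_to_digits(csum_bytes)[:LEN2]

m = [W - 1] * LEN1
assert wots_checksum(m) == [0] * LEN2

m0 = [0] * LEN1
c0 = wots_checksum(m0)
assert sum(c0) > 0
assert c0 == [3, 12, 0]        # maximum checksum 64 * 15 = 960 = 0x3C0
print('E3 OK')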
E3 OK

Exercise 47.E4 — ★★★ Build a height-3 Merkle MSS

Goal. Combine the Lamport scheme from E1 with the Merkle tree from E2 into a height-3 many-time signature scheme. Sign a message with the \(j\)-th of the 8 leaf keys and produce a single signature blob (ots_sig, ots_pk, auth_path) that the verifier can check using only the published root.

Theory. §47.4 — Merkle trees: one public key, many signatures.

def lamport_keygen():
    sk = [[os.urandom(32), os.urandom(32)] for _ in range(256)]
    pk = [[H(a), H(b)] for a, b in sk]
    return sk, pk


def lamport_sign(sk, msg):
    return [sk[i][b] for i, b in enumerate(bits(H(msg)))]


def serialize_pk(pk):
    return b''.join(b''.join(pair) for pair in pk)


def mss_keygen(h):
    n = 1 << h
    leaf_keys = [lamport_keygen() for _ in range(n)]
    leaves = [H(serialize_pk(pk)) for _, pk in leaf_keys]
    levels = build_merkle(leaves)
    return leaf_keys, levels[-1][0], levels    # secret leaf-keys + root + levels


def mss_sign(leaf_keys, levels, j, msg):
    # TODO: produce the triple (ots_sig, ots_pk, path).
    raise NotImplementedError('your turn')


def mss_verify(root, j, msg, blob):
    ots_sig, ots_pk, path = blob
    if not lamport_verify(ots_pk, msg, ots_sig): return False
    leaf_hash = H(serialize_pk(ots_pk))
    return verify_merkle_path(root, leaf_hash, j, path)


# Test on a height-3 tree, 8 leaves, sign leaf 5.
# leaf_keys, root, levels = mss_keygen(3)
# blob = mss_sign(leaf_keys, levels, 5, b'pay 100 PLN')
# assert mss_verify(root, 5, b'pay 100 PLN', blob)
# assert not mss_verify(root, 5, b'pay 999 PLN', blob)
# print('E4 OK')


# Solution.
#
# Putting it together: at signing time we (a) extract the OTS sig with
# lamport_sign on leaf j's secret key, (b) include leaf j's OTS public key
# (the verifier needs to recompute the leaf hash), and (c) include the auth
# path computed in E2.

def mss_sign(leaf_keys, levels, j, msg):
    sk_j, pk_j = leaf_keys[j]
    ots_sig    = lamport_sign(sk_j, msg)
    path       = auth_path(levels, j)
    return ots_sig, pk_j, path

leaf_keys, root, levels = mss_keygen(3)
blob = mss_sign(leaf_keys, levels, 5, b'pay 100 PLN')
assert mss_verify(root, 5, b'pay 100 PLN', blob)
assert not mss_verify(root, 5, b'pay 999 PLN', blob)
print('E4 OK')
E4 OK
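The blob is large because every Lamport position contributes 32-byte values. The following back-of-envelope sketch counts its bytes, assuming the 32-byte secrets and SHA-256 commitments used in the cells above; the constant names are illustrative.

```python
HASH_BYTES = 32          # SHA-256 output, and the length of each secret
POSITIONS = 256          # one Lamport position per message-hash bit
HEIGHT = 3               # tree height in this exercise

ots_sig_bytes = POSITIONS * HASH_BYTES          # one revealed secret per position
ots_pk_bytes = POSITIONS * 2 * HASH_BYTES       # both commitments per position
path_bytes = HEIGHT * HASH_BYTES                # one sibling hash per level
total = ots_sig_bytes + ots_pk_bytes + path_bytes

print(ots_sig_bytes, ots_pk_bytes, path_bytes, total)
```

The authentication path is a negligible share of the total; almost all of the size comes from the one-time signature and its public key.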

Exercise 47.E5 — ★★★★ State-reuse forgery (24-bit hash)

Goal. Demonstrate the state-reuse catastrophe with a 24-bit truncated hash. Sign two distinct messages with the same Lamport key, then forge a valid signature on a third message that the key owner never signed, using only the public key and the two observed signatures.

Theory. §47.7 — The state-reuse catastrophe.

DEMO_BITS = 24

def H_demo(x): return hashlib.sha256(x).digest()[: DEMO_BITS // 8]


def demo_keygen():
    sk = [[os.urandom(32), os.urandom(32)] for _ in range(DEMO_BITS)]
    pk = [[H(a), H(b)] for a, b in sk]
    return sk, pk


def demo_sign(sk, msg):
    h_bits = []
    for byte in H_demo(msg):
        for j in range(8): h_bits.append((byte >> (7 - j)) & 1)
    return [sk[i][b] for i, b in enumerate(h_bits)]


def demo_verify(pk, msg, sig):
    h_bits = []
    for byte in H_demo(msg):
        for j in range(8): h_bits.append((byte >> (7 - j)) & 1)
    return all(H(sig[i]) == pk[i][b] for i, b in enumerate(h_bits))


def attack_after_two_sigs(pk, msg1, msg2, sig1, sig2):
    '''Find a third message with a forged valid signature.'''
    # TODO step 1: from (sig1, msg1) and (sig2, msg2) build a dictionary
    #              `known[(i, b)] = secret` for every revealed secret string.
    # TODO step 2: brute-force a salted message until every position of its
    #              hash lands on a key that exists in `known`.  Then assemble
    #              the forged signature.
    raise NotImplementedError('your turn')


# sk, pk = demo_keygen()
# m1 = b'Pay 100 PLN to Alice.'
# m2 = b'Pay 999 PLN to Bob.'
# s1 = demo_sign(sk, m1); s2 = demo_sign(sk, m2)
# m3, s3 = attack_after_two_sigs(pk, m1, m2, s1, s2)
# print('forged for:', m3)
# assert demo_verify(pk, m3, s3)
# print('E5 OK')


# Solution.
#
# The attacker uses the two observed signatures to populate a partial secret
# key.  Then they brute-force salted candidate messages until the message
# hash happens to fall entirely on positions whose secrets are known.
#
# If the 24-bit hashes of m1 and m2 differ in k positions, both secrets are
# revealed at those k positions and only one at each of the other 24 - k.  A
# random candidate therefore succeeds with probability 2^-(24 - k): 1/4096 for
# the typical k = 12.  Averaged over k the probability is (3/4)^24 ~= 1/1000,
# so we expect a forgery within a few thousand tries.

def attack_after_two_sigs(pk, msg1, msg2, sig1, sig2):
    bits1, bits2 = [], []
    for byte in H_demo(msg1):
        for j in range(8): bits1.append((byte >> (7 - j)) & 1)
    for byte in H_demo(msg2):
        for j in range(8): bits2.append((byte >> (7 - j)) & 1)
    known = {}
    for i, (a, b) in enumerate(zip(bits1, bits2)):
        known[(i, a)] = sig1[i]
        known[(i, b)] = sig2[i]
    # Brute-force.
    for trial in range(2_000_000):
        cand = f'Pay Mallory salt={trial}'.encode()
        h_bits = []
        for byte in H_demo(cand):
            for j in range(8): h_bits.append((byte >> (7 - j)) & 1)
        if all((i, b) in known for i, b in enumerate(h_bits)):
            return cand, [known[(i, b)] for i, b in enumerate(h_bits)]
    raise RuntimeError('no forgery found')

sk, pk = demo_keygen()
m1 = b'Pay 100 PLN to Alice.'
m2 = b'Pay 999 PLN to Bob.'
s1 = demo_sign(sk, m1); s2 = demo_sign(sk, m2)
m3, s3 = attack_after_two_sigs(pk, m1, m2, s1, s2)
print('forged for:', m3)
assert demo_verify(pk, m3, s3)
print('E5 OK')
forged for: b'Pay Mallory salt=212'
E5 OK
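The probability estimate in the solution comment can be checked numerically. This is a sketch under the assumption that the two signed hashes behave like independent uniform 24-bit strings; `success_probability` is an illustrative name.

```python
from math import comb

POSITIONS = 24

def success_probability(k):
    """Per-trial success chance when the two signed hashes differ in k positions."""
    # Both secrets are known at the k differing positions, so any bit works
    # there.  At each remaining position the candidate's bit must equal the
    # single revealed bit, which happens with probability 1/2.
    return 0.5 ** (POSITIONS - k)

# Average over k ~ Binomial(24, 1/2), i.e. two independent uniform hashes.
average = sum(comb(POSITIONS, k) * 0.5 ** POSITIONS * success_probability(k)
              for k in range(POSITIONS + 1))

print(f'average over k : {average:.6f}')
print(f'(3/4)^24       : {0.75 ** POSITIONS:.6f}')
print(f'k = 12         : {success_probability(12):.6f}  (1 in {2 ** 12})')
```

The average matches \((3/4)^{24}\), while a specific pair of messages has a fixed \(k\), so the actual search length in a given run depends on how many positions happen to differ.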

Exercise 47.E6 — ★★★★★ Research: explore SLH-DSA-128s parameters

Goal. Read FIPS 205 (NIST 2024) Algorithm 17 (slh_sign) and Algorithm 18 (slh_verify). Write a one-page memo explaining how the SLH-DSA-SHA2-128s parameter set arrives at its 7856-byte signature size: how many WOTS+ signatures, FORS trees and hyper-tree layers it contains, and how many bytes each contributes.

Then implement the outermost hyper-tree layer: a single Merkle tree of height \(h' = h / d\) that signs the root of the layer below using WOTS+. Verify your byte counts against the spec.

Theory. §47.6 — SLH-DSA / SPHINCS+.

Parameter set SLH-DSA-SHA2-128s (FIPS 205 §10)

\(n = 16\), \(h = 63\), \(d = 7\), \(h' = 9\), \(a = 12\), \(k = 14\), \(w = 16\).

# (Open-ended.) Sketch your reasoning here, then build the outer hyper-tree
# layer.  The exercise is intentionally underspecified -- design choices
# (recursive vs. iterative tree construction, in-memory layout, hash chain
# parametrization) are part of the project.


# Reference solution -- partial.
#
# A full SLH-DSA implementation is several hundred lines (see the SPHINCS+
# reference at https://github.com/sphincs/sphincsplus).  Here we sketch the
# size accounting for SLH-DSA-SHA2-128s and verify it against FIPS 205.
#
# - n = 16 bytes (output of internal hash)
# - h = 63 (total hyper-tree height) split into d = 7 layers of h' = 9 each
# - one WOTS+ key pair per subtree leaf: w = 16 -> len = 35 chains of n bytes each
# - FORS leaves at the bottom: k = 14 trees, each of height a = 12 with n-byte leaves
#
# Signature shape (per FIPS 205 Algorithm 17):
#   1.  R           (n bytes)                              = 16
#   2.  FORS sig    : k * (1 + a) * n  = 14 * 13 * 16     = 2912
#   3.  HT sig      : d * (len + h') * n = 7 * (35+9) * 16= 4928
#  Total                                                   = 7856  ✓
import math

n, h, d, hp, a, k, w = 16, 63, 7, 9, 12, 14, 16
assert h == d * hp       # 7 layers of height 9
def wots_len_from_w(w, n):
    len1 = math.ceil(8 * n / int(math.log2(w)))
    len2 = math.floor(math.log2(len1 * (w - 1)) / math.log2(w)) + 1
    return len1, len2, len1 + len2
len1, len2, length = wots_len_from_w(w, n)
print(f'WOTS+ chain count = {length}  (={len1} + {len2})')
total = n + k * (1 + a) * n + d * (length + hp) * n
print(f'Predicted SLH-DSA-128s signature size = {total} bytes')
assert total == 7856, total
print('E6 OK -- size matches FIPS 205 spec.')
WOTS+ chain count = 35  (=32 + 3)
Predicted SLH-DSA-128s signature size = 7856 bytes
E6 OK -- size matches FIPS 205 spec.
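For the implementation part, one possible starting point is a single height-9 tree with 16-byte nodes. This is only a sketch: it uses SHA-256 truncated to \(n\) bytes in place of the tweakable, address-dependent hash functions of FIPS 205, and the leaves are placeholders rather than compressed WOTS+ public keys. The names `Hn` and `root_from_path` are illustrative.

```python
import hashlib

N = 16         # hash output bytes for SLH-DSA-SHA2-128s
HP = 9         # height of one hyper-tree layer (h' = h / d)
WOTS_LEN = 35  # WOTS+ chains per signature (len1 + len2 = 32 + 3)

def Hn(x):
    """SHA-256 truncated to N bytes; a stand-in for the FIPS 205 hashes."""
    return hashlib.sha256(x).digest()[:N]

# Placeholder leaves; in the real scheme each is a compressed WOTS+ public key.
leaves = [Hn(i.to_bytes(4, 'big')) for i in range(1 << HP)]

levels = [leaves]
while len(levels[-1]) > 1:
    cur = levels[-1]
    levels.append([Hn(cur[i] + cur[i + 1]) for i in range(0, len(cur), 2)])
root = levels[-1][0]

def auth_path(levels, j):
    """Sibling hash at every level below the root."""
    path = []
    for level in levels[:-1]:
        path.append(level[j ^ 1])
        j //= 2
    return path

def root_from_path(leaf, j, path):
    """Recompute the root from a leaf and its authentication path."""
    cur = leaf
    for sib in path:
        cur = Hn(cur + sib) if j % 2 == 0 else Hn(sib + cur)
        j //= 2
    return cur

j = 300
path = auth_path(levels, j)
assert root_from_path(leaves[j], j, path) == root

path_bytes = len(path) * N               # authentication path of one layer
layer_bytes = (WOTS_LEN + HP) * N        # WOTS+ signature plus path
print(path_bytes, layer_bytes)
```

Seven such layers account for the hyper-tree share of the signature, which is the `d * (len + h') * n` term in the size formula above.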