Chapter 34: LSTM — The Gating Revolution#
The vanishing gradient problem identified by Hochreiter (1991) and Bengio et al. (1994) seemed to doom recurrent networks. In 1997, Hochreiter and Schmidhuber proposed an elegant solution: instead of fighting the gradient decay, engineer a pathway where gradients can flow unchanged.
The Long Short-Term Memory (LSTM) network introduces a separate cell state \(C_t\) that carries information forward through time via additive updates, bypassing the multiplicative bottleneck that causes gradients to vanish in vanilla RNNs. Three learnable gates control the flow of information into, out of, and within this cell state. The result is a network that can learn dependencies spanning hundreds of time steps—something that was practically impossible with the architectures we studied in previous chapters.
Show code cell source
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
plt.style.use('seaborn-v0_8-whitegrid')
BLUE = '#3b82f6'
GREEN = '#059669'
RED = '#dc2626'
AMBER = '#d97706'
INDIGO = '#4f46e5'
torch.manual_seed(42)
np.random.seed(42)
print('PyTorch version:', torch.__version__)
print('Device:', 'cuda' if torch.cuda.is_available() else 'cpu')
PyTorch version: 2.7.0
Device: cpu
34.1 The Constant Error Carousel#
Recall the fundamental problem with vanilla RNNs: the hidden state update
involves a multiplicative interaction with \(W_{hh}\) at every time step. When we backpropagate through \(T\) steps, the gradient includes the product
If the largest singular value of \(W_{hh}\) is less than 1, this product vanishes exponentially. If it is greater than 1, the product explodes.
Hochreiter’s key insight was to replace this multiplicative chain with an additive update. The LSTM cell state update is:
The gradient of \(C_t\) with respect to \(C_{t-1}\) is simply:
where \(f_t\) is the forget gate, a sigmoid output that can be close to 1. When \(f_t \approx 1\), the gradient passes through unchanged—this is the Constant Error Carousel (CEC). Information stored in the cell state can persist indefinitely, and gradients flow back through time without decay.
The Constant Error Carousel
The CEC is the defining innovation of LSTM. By making the cell state update additive rather than multiplicative, the gradient \(\partial C_t / \partial C_{t-1} = f_t\) can remain close to 1 for arbitrarily many time steps. This solves the vanishing gradient problem at its mathematical root.
The following figure compares gradient flow in a vanilla RNN versus an LSTM:
Show code cell source
# Gradient flow comparison: Vanilla RNN vs LSTM
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
T = 30
steps = np.arange(1, T + 1)
# Vanilla RNN gradient decay
ax = axes[0]
for gamma, label, color, ls in [
(0.95, r'$\|W_{hh}\| = 0.95$', BLUE, '-'),
(0.85, r'$\|W_{hh}\| = 0.85$', AMBER, '--'),
(0.70, r'$\|W_{hh}\| = 0.70$', RED, '-.'),
]:
grad_norms = gamma ** steps
ax.plot(steps, grad_norms, color=color, linestyle=ls, linewidth=2, label=label)
ax.set_xlabel('Time steps back', fontsize=11)
ax.set_ylabel('Gradient magnitude (relative)', fontsize=11)
ax.set_title('Vanilla RNN: Gradient Decay', fontsize=12, fontweight='bold')
ax.set_ylim(0, 1.1)
ax.legend(fontsize=9)
ax.axhline(y=0.01, color='gray', linestyle=':', alpha=0.5)
ax.text(T - 1, 0.03, 'effectively zero', fontsize=8, color='gray', ha='right')
# LSTM gradient preservation
ax = axes[1]
for f_val, label, color, ls in [
(1.00, r'$f_t = 1.0$ (perfect memory)', GREEN, '-'),
(0.98, r'$f_t = 0.98$', BLUE, '--'),
(0.90, r'$f_t = 0.90$', AMBER, '-.'),
]:
grad_norms = f_val ** steps
ax.plot(steps, grad_norms, color=color, linestyle=ls, linewidth=2, label=label)
ax.set_xlabel('Time steps back', fontsize=11)
ax.set_ylabel('Gradient magnitude (relative)', fontsize=11)
ax.set_title('LSTM: Gradient via Cell State', fontsize=12, fontweight='bold')
ax.set_ylim(0, 1.1)
ax.legend(fontsize=9)
plt.suptitle('Gradient Flow: Vanilla RNN vs LSTM', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()
The contrast is stark. A vanilla RNN with \(\|W_{hh}\| = 0.85\) retains less than 1% of the gradient after just 25 steps. An LSTM with forget gate values near 1 preserves the gradient almost perfectly over the same horizon.
34.2 LSTM Cell Architecture#
The LSTM cell maintains two state vectors: the cell state \(C_t\) (the long-term memory highway) and the hidden state \(h_t\) (the short-term output). At each time step, three gates regulate the information flow.
Definition: LSTM Cell Equations
Given input \(x_t \in \mathbb{R}^d\), previous hidden state \(h_{t-1} \in \mathbb{R}^n\), and previous cell state \(C_{t-1} \in \mathbb{R}^n\):
Forget gate (what to discard from cell state):
Input gate (what new information to store):
Cell candidate (proposed new content):
Cell state update (the Constant Error Carousel):
Output gate (what to reveal from cell state):
Hidden state (output at this time step):
Here \(\sigma\) is the sigmoid function, \(\odot\) denotes element-wise multiplication, and \([h_{t-1}, x_t]\) denotes concatenation.
Each gate is a full neural network layer with its own weights and biases. The sigmoid activation ensures gate values lie in \([0, 1]\), acting as soft switches:
\(f_t \approx 1\): keep the old cell state (remember).
\(f_t \approx 0\): erase the old cell state (forget).
\(i_t \approx 1\): write the candidate into the cell state.
\(o_t \approx 1\): expose the cell state to the outside.
The following diagram illustrates the data flow through an LSTM cell:
Show code cell source
# LSTM Cell Diagram
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(-1, 15)
ax.set_ylim(-1, 10)
ax.set_aspect('equal')
ax.axis('off')
import matplotlib.patches as mpatches
# Colors
gate_colors = {'forget': RED, 'input': GREEN, 'output': BLUE, 'candidate': AMBER}
def draw_gate(ax, x, y, label, color, w=1.8, h=0.9):
rect = mpatches.FancyBboxPatch(
(x - w/2, y - h/2), w, h,
boxstyle=mpatches.BoxStyle('Round', pad=0.1),
facecolor=color, edgecolor='white', linewidth=2, alpha=0.85
)
ax.add_patch(rect)
ax.text(x, y, label, ha='center', va='center', fontsize=10,
fontweight='bold', color='white')
def draw_op(ax, x, y, symbol, size=0.45):
circle = plt.Circle((x, y), size, facecolor='white', edgecolor='#334155',
linewidth=1.5, zorder=5)
ax.add_patch(circle)
ax.text(x, y, symbol, ha='center', va='center', fontsize=14,
fontweight='bold', color='#334155', zorder=6)
# Cell state highway (top)
ax.annotate('', xy=(13, 8), xytext=(1, 8),
arrowprops=dict(arrowstyle='->', lw=3, color='#475569'))
ax.text(0.3, 8, '$C_{t-1}$', fontsize=13, fontweight='bold', color='#475569')
ax.text(13.3, 8, '$C_t$', fontsize=13, fontweight='bold', color='#475569')
ax.text(7, 9, 'Cell State (Long-Term Memory Highway)', fontsize=11,
ha='center', fontstyle='italic', color='#64748b')
# Hidden state (bottom)
ax.annotate('', xy=(13, 2), xytext=(1, 2),
arrowprops=dict(arrowstyle='->', lw=3, color='#475569'))
ax.text(0.3, 2, '$h_{t-1}$', fontsize=13, fontweight='bold', color='#475569')
ax.text(13.3, 2, '$h_t$', fontsize=13, fontweight='bold', color='#475569')
# Input
ax.text(7, 0, '$x_t$', fontsize=13, fontweight='bold', ha='center', color='#475569')
ax.annotate('', xy=(7, 1.2), xytext=(7, 0.4),
arrowprops=dict(arrowstyle='->', lw=2, color='#94a3b8'))
# Forget gate
draw_gate(ax, 3.5, 4.5, r'Forget gate' + '\n' + r'$\sigma$', RED)
ax.text(3.5, 3.5, '$f_t$', fontsize=11, ha='center', color=RED, fontweight='bold')
# Arrow from concat to forget gate
ax.annotate('', xy=(3.5, 4.0), xytext=(3.5, 2.5),
arrowprops=dict(arrowstyle='->', lw=1.5, color='#94a3b8'))
# Multiply on cell state
draw_op(ax, 3.5, 8, r'$\times$')
ax.annotate('', xy=(3.5, 7.5), xytext=(3.5, 5.0),
arrowprops=dict(arrowstyle='->', lw=1.5, color=RED))
# Input gate
draw_gate(ax, 6.5, 4.5, r'Input gate' + '\n' + r'$\sigma$', GREEN)
ax.text(6.5, 3.5, '$i_t$', fontsize=11, ha='center', color=GREEN, fontweight='bold')
ax.annotate('', xy=(6.5, 4.0), xytext=(6.5, 2.5),
arrowprops=dict(arrowstyle='->', lw=1.5, color='#94a3b8'))
# Candidate
draw_gate(ax, 8.5, 4.5, r'Candidate' + '\n' + r'tanh', AMBER)
ax.text(8.5, 3.5, r'$\tilde{C}_t$', fontsize=11, ha='center', color=AMBER, fontweight='bold')
ax.annotate('', xy=(8.5, 4.0), xytext=(8.5, 2.5),
arrowprops=dict(arrowstyle='->', lw=1.5, color='#94a3b8'))
# i_t * C_tilde -> multiply
draw_op(ax, 7.5, 6.5, r'$\times$')
ax.annotate('', xy=(7.1, 6.5), xytext=(6.5, 5.0),
arrowprops=dict(arrowstyle='->', lw=1.5, color=GREEN))
ax.annotate('', xy=(7.9, 6.5), xytext=(8.5, 5.0),
arrowprops=dict(arrowstyle='->', lw=1.5, color=AMBER))
# Add on cell state
draw_op(ax, 7.5, 8, '+')
ax.annotate('', xy=(7.5, 7.5), xytext=(7.5, 7.0),
arrowprops=dict(arrowstyle='->', lw=1.5, color='#64748b'))
# Output gate
draw_gate(ax, 10.5, 4.5, r'Output gate' + '\n' + r'$\sigma$', BLUE)
ax.text(10.5, 3.5, '$o_t$', fontsize=11, ha='center', color=BLUE, fontweight='bold')
ax.annotate('', xy=(10.5, 4.0), xytext=(10.5, 2.5),
arrowprops=dict(arrowstyle='->', lw=1.5, color='#94a3b8'))
# tanh on cell state -> output
draw_op(ax, 11.5, 6.5, 'tanh')
ax.annotate('', xy=(11.5, 6.1), xytext=(11.5, 8),
arrowprops=dict(arrowstyle='<-', lw=1.5, color='#64748b'))
# output gate * tanh(C_t) -> h_t
draw_op(ax, 11.5, 2, r'$\times$')
ax.annotate('', xy=(11.5, 2.45), xytext=(11.5, 6.05),
arrowprops=dict(arrowstyle='<-', lw=1.5, color='#64748b'))
ax.annotate('', xy=(11.1, 2), xytext=(10.5, 5.0),
arrowprops=dict(arrowstyle='<-', lw=1.5, color=BLUE, connectionstyle='arc3,rad=-0.3'))
# Concat indicator
ax.text(7, 1.5, '$[h_{t-1}, x_t]$ concatenated', fontsize=9,
ha='center', fontstyle='italic', color='#94a3b8')
ax.set_title('LSTM Cell Architecture', fontsize=14, fontweight='bold', pad=15)
plt.tight_layout()
plt.show()
34.3 Building LSTM from Scratch#
To truly understand the LSTM, we implement it using raw PyTorch tensor operations—no nn.LSTMCell. The key implementation insight is that all four linear transformations (for \(f_t\), \(i_t\), \(\tilde{C}_t\), and \(o_t\)) take the same input \([h_{t-1}, x_t]\), so we can concatenate them into a single large matrix multiplication and then chunk the result.
class ManualLSTMCell:
"""LSTM cell implemented from scratch using raw tensor operations.
All four gates share a single weight matrix for efficiency:
W @ [h, x] + b -> chunk into (i, f, g, o)
"""
def __init__(self, input_size, hidden_size):
self.input_size = input_size
self.hidden_size = hidden_size
# Single weight matrix for all 4 gates: [input_gate, forget_gate, cell_candidate, output_gate]
# Input weights: maps x_t -> 4 * hidden_size
k = 1.0 / np.sqrt(hidden_size)
self.W_ih = torch.empty(4 * hidden_size, input_size).uniform_(-k, k)
self.b_ih = torch.empty(4 * hidden_size).uniform_(-k, k)
# Hidden weights: maps h_{t-1} -> 4 * hidden_size
self.W_hh = torch.empty(4 * hidden_size, hidden_size).uniform_(-k, k)
self.b_hh = torch.empty(4 * hidden_size).uniform_(-k, k)
def forward(self, x_t, h_prev, c_prev):
"""Single LSTM step.
Args:
x_t: input at time t, shape (batch, input_size)
h_prev: previous hidden state, shape (batch, hidden_size)
c_prev: previous cell state, shape (batch, hidden_size)
Returns:
h_t: new hidden state
c_t: new cell state
"""
# Combined linear transformation
gates = (x_t @ self.W_ih.T + self.b_ih +
h_prev @ self.W_hh.T + self.b_hh)
# Chunk into 4 gates (PyTorch LSTM convention: i, f, g, o)
i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=1)
# Apply activations
i_t = torch.sigmoid(i_gate) # Input gate
f_t = torch.sigmoid(f_gate) # Forget gate
c_tilde = torch.tanh(g_gate) # Cell candidate
o_t = torch.sigmoid(o_gate) # Output gate
# Cell state update (the Constant Error Carousel!)
c_t = f_t * c_prev + i_t * c_tilde
# Hidden state
h_t = o_t * torch.tanh(c_t)
return h_t, c_t
# Test our implementation
input_size = 4
hidden_size = 8
batch_size = 2
cell = ManualLSTMCell(input_size, hidden_size)
x = torch.randn(batch_size, input_size)
h0 = torch.zeros(batch_size, hidden_size)
c0 = torch.zeros(batch_size, hidden_size)
h1, c1 = cell.forward(x, h0, c0)
print(f'Input shape: {x.shape}')
print(f'Hidden state shape: {h1.shape}')
print(f'Cell state shape: {c1.shape}')
print(f'h1 range: [{h1.min().item():.4f}, {h1.max().item():.4f}]')
print(f'c1 range: [{c1.min().item():.4f}, {c1.max().item():.4f}]')
Input shape: torch.Size([2, 4])
Hidden state shape: torch.Size([2, 8])
Cell state shape: torch.Size([2, 8])
h1 range: [-0.0666, 0.2415]
c1 range: [-0.1201, 0.3873]
Now let us verify that our manual implementation produces identical results to PyTorch’s built-in nn.LSTMCell when initialized with the same weights:
# Verify against PyTorch's nn.LSTMCell
torch.manual_seed(123)
input_size = 4
hidden_size = 8
batch_size = 3
# Create our manual cell
manual_cell = ManualLSTMCell(input_size, hidden_size)
# Create PyTorch's cell with SAME weights
pytorch_cell = nn.LSTMCell(input_size, hidden_size)
with torch.no_grad():
pytorch_cell.weight_ih.copy_(manual_cell.W_ih)
pytorch_cell.weight_hh.copy_(manual_cell.W_hh)
pytorch_cell.bias_ih.copy_(manual_cell.b_ih)
pytorch_cell.bias_hh.copy_(manual_cell.b_hh)
# Run both on same input
x = torch.randn(batch_size, input_size)
h_prev = torch.randn(batch_size, hidden_size)
c_prev = torch.randn(batch_size, hidden_size)
h_manual, c_manual = manual_cell.forward(x, h_prev, c_prev)
h_pytorch, c_pytorch = pytorch_cell(x, (h_prev, c_prev))
h_diff = (h_manual - h_pytorch).abs().max().item()
c_diff = (c_manual - c_pytorch).abs().max().item()
print(f'Max absolute difference in h_t: {h_diff:.2e}')
print(f'Max absolute difference in C_t: {c_diff:.2e}')
print(f'Match: {"YES" if h_diff < 1e-6 and c_diff < 1e-6 else "NO"}')
Max absolute difference in h_t: 1.49e-08
Max absolute difference in C_t: 2.98e-08
Match: YES
The match confirms our implementation is correct. The key efficiency trick is computing all four gate transformations with a single matrix multiply and then chunking the result.
Implementation Note
PyTorch’s nn.LSTMCell uses the gate ordering (i, f, g, o) — input, forget, cell candidate (called g internally), output. This differs from the order in many textbooks (f, i, g, o). When copying weights between implementations, be sure to match this convention.
34.4 Forget Gates#
Historical Note
The original LSTM architecture proposed by Hochreiter and Schmidhuber (1997) had no forget gate. The cell state could only accumulate information—it could never discard it. This was problematic for tasks requiring the network to reset its memory (e.g., processing multiple independent sequences).
The forget gate was added by Gers, Schmidhuber & Cummins (2000), completing the modern LSTM. They showed that the forget gate is essential for tasks involving continuous input streams where old information must eventually be discarded.
To illustrate the importance of the forget gate, consider a counting task: the network receives a stream of 0s and 1s and must output the running count of 1s, modulo some number. Without a forget gate, the cell state monotonically accumulates, eventually saturating and failing.
We demonstrate with a simpler diagnostic: the network must remember a signal from the start of a sequence but reset when it sees a special token.
Show code cell source
# Demonstrate forget gate importance with a counting task
# Task: count the number of 1s in a binary sequence, modulo 4
# Without forget gate, the cell state can only grow.
def generate_counting_data(n_samples=500, seq_len=20):
"""Generate binary sequences and their running count mod 4."""
X = torch.randint(0, 2, (n_samples, seq_len, 1)).float()
# Target: count of 1s at each step, mod 4
counts = X.squeeze(-1).cumsum(dim=1) % 4
return X, counts.long()
class CountingLSTM(nn.Module):
def __init__(self, use_forget_gate=True):
super().__init__()
self.hidden_size = 16
self.lstm = nn.LSTMCell(1, self.hidden_size)
self.fc = nn.Linear(self.hidden_size, 4) # 4 classes: 0,1,2,3
self.use_forget_gate = use_forget_gate
if not use_forget_gate:
# Disable forget gate by setting its bias very high (f_t -> 1 always)
# In PyTorch's (i,f,g,o) layout, forget gate bias is indices [hidden_size:2*hidden_size]
with torch.no_grad():
self.lstm.bias_ih[self.hidden_size:2*self.hidden_size] = 100.0
self.lstm.bias_hh[self.hidden_size:2*self.hidden_size] = 0.0
# Also freeze these during training
def forward(self, x_seq):
batch_size, seq_len, _ = x_seq.shape
h = torch.zeros(batch_size, self.hidden_size)
c = torch.zeros(batch_size, self.hidden_size)
outputs = []
for t in range(seq_len):
h, c = self.lstm(x_seq[:, t, :], (h, c))
if not self.use_forget_gate:
# Clamp forget gate bias to keep it at ~1
with torch.no_grad():
self.lstm.bias_ih.data[self.hidden_size:2*self.hidden_size] = 100.0
outputs.append(self.fc(h))
return torch.stack(outputs, dim=1) # (batch, seq_len, 4)
def train_counting(use_forget_gate, n_epochs=80):
torch.manual_seed(42)
model = CountingLSTM(use_forget_gate=use_forget_gate)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
X_train, y_train = generate_counting_data(500, 20)
X_test, y_test = generate_counting_data(200, 20)
losses = []
accs = []
for epoch in range(n_epochs):
model.train()
out = model(X_train)
loss = criterion(out.reshape(-1, 4), y_train.reshape(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
model.eval()
with torch.no_grad():
test_out = model(X_test)
preds = test_out.argmax(dim=-1)
acc = (preds == y_test).float().mean().item()
accs.append(acc)
return losses, accs
losses_with_fg, accs_with_fg = train_counting(use_forget_gate=True)
losses_no_fg, accs_no_fg = train_counting(use_forget_gate=False)
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
ax = axes[0]
ax.plot(losses_with_fg, color=GREEN, linewidth=2, label='With forget gate')
ax.plot(losses_no_fg, color=RED, linewidth=2, label='Without forget gate', linestyle='--')
ax.set_xlabel('Epoch', fontsize=11)
ax.set_ylabel('Cross-Entropy Loss', fontsize=11)
ax.set_title('Counting Task: Training Loss', fontsize=12, fontweight='bold')
ax.legend(fontsize=10)
ax = axes[1]
ax.plot(accs_with_fg, color=GREEN, linewidth=2, label='With forget gate')
ax.plot(accs_no_fg, color=RED, linewidth=2, label='Without forget gate', linestyle='--')
ax.set_xlabel('Epoch', fontsize=11)
ax.set_ylabel('Accuracy', fontsize=11)
ax.set_title('Counting Task: Test Accuracy', fontsize=12, fontweight='bold')
ax.legend(fontsize=10)
ax.set_ylim(0, 1.05)
plt.suptitle('The Forget Gate is Essential for Counting (mod 4)',
fontsize=13, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()
print(f'Final accuracy WITH forget gate: {accs_with_fg[-1]:.3f}')
print(f'Final accuracy WITHOUT forget gate: {accs_no_fg[-1]:.3f}')
Final accuracy WITH forget gate: 0.789
Final accuracy WITHOUT forget gate: 0.475
The counting task requires the network to track a value that wraps around (modulo 4). The standard LSTM with a forget gate can learn to reset the count at the right moments, while the version with the forget gate locked to 1 struggles because the cell state can only accumulate, never release information.
Citation
F. A. Gers, J. Schmidhuber, and F. Cummins, “Learning to forget: Continual prediction with LSTM,” Neural Computation, vol. 12, no. 10, pp. 2451–2471, 2000.
34.5 GRU: A Simplified Alternative#
In 2014, Cho et al. proposed the Gated Recurrent Unit (GRU), a streamlined variant that merges the cell state and hidden state into a single vector and uses only two gates instead of three.
Definition: GRU Equations
Given input \(x_t\), previous hidden state \(h_{t-1}\):
Update gate (analogous to LSTM’s forget + input gates):
Reset gate (controls how much past to reveal to candidate):
Candidate hidden state:
Hidden state update (convex combination):
The GRU has no separate cell state. The update gate \(z_t\) plays a dual role: when \(z_t \approx 0\), the hidden state is copied forward (like an LSTM with \(f_t \approx 1\) and \(i_t \approx 0\)). When \(z_t \approx 1\), the hidden state is replaced with the candidate.
LSTM vs GRU
Feature |
LSTM |
GRU |
|---|---|---|
State vectors |
2 (\(h_t\), \(C_t\)) |
1 (\(h_t\)) |
Gates |
3 (forget, input, output) |
2 (update, reset) |
Parameters per unit |
\(4n(n+d) + 4n\) |
\(3n(n+d) + 3n\) |
Ratio |
1.0x |
0.75x |
Where \(n\) = hidden size, \(d\) = input size. GRU has 25% fewer parameters.
class ManualGRUCell:
"""GRU cell implemented from scratch using raw tensor operations.
Uses the same concatenation trick: W @ [h, x] + b -> chunk into (r, z, n)
Note: PyTorch GRU convention applies reset gate BEFORE the linear transform
for the candidate, which requires separate weight matrices.
"""
def __init__(self, input_size, hidden_size):
self.input_size = input_size
self.hidden_size = hidden_size
k = 1.0 / np.sqrt(hidden_size)
# Input weights: maps x_t -> 3 * hidden_size (for r, z, n)
self.W_ih = torch.empty(3 * hidden_size, input_size).uniform_(-k, k)
self.b_ih = torch.empty(3 * hidden_size).uniform_(-k, k)
# Hidden weights: maps h_{t-1} -> 3 * hidden_size
self.W_hh = torch.empty(3 * hidden_size, hidden_size).uniform_(-k, k)
self.b_hh = torch.empty(3 * hidden_size).uniform_(-k, k)
def forward(self, x_t, h_prev):
"""Single GRU step.
Args:
x_t: input, shape (batch, input_size)
h_prev: previous hidden state, shape (batch, hidden_size)
Returns:
h_t: new hidden state
"""
# Compute input and hidden contributions separately
# (needed because reset gate is applied to hidden part of candidate only)
gi = x_t @ self.W_ih.T + self.b_ih
gh = h_prev @ self.W_hh.T + self.b_hh
# Chunk: (reset, update, new/candidate)
i_r, i_z, i_n = gi.chunk(3, dim=1)
h_r, h_z, h_n = gh.chunk(3, dim=1)
r_t = torch.sigmoid(i_r + h_r) # Reset gate
z_t = torch.sigmoid(i_z + h_z) # Update gate
# Candidate: reset gate applied to hidden contribution only
h_tilde = torch.tanh(i_n + r_t * h_n)
# Convex combination
h_t = (1 - z_t) * h_prev + z_t * h_tilde
return h_t
# Test and verify against nn.GRUCell
torch.manual_seed(456)
input_size = 4
hidden_size = 8
batch_size = 3
manual_gru = ManualGRUCell(input_size, hidden_size)
pytorch_gru = nn.GRUCell(input_size, hidden_size)
with torch.no_grad():
pytorch_gru.weight_ih.copy_(manual_gru.W_ih)
pytorch_gru.weight_hh.copy_(manual_gru.W_hh)
pytorch_gru.bias_ih.copy_(manual_gru.b_ih)
pytorch_gru.bias_hh.copy_(manual_gru.b_hh)
x = torch.randn(batch_size, input_size)
h_prev = torch.randn(batch_size, hidden_size)
h_manual = manual_gru.forward(x, h_prev)
h_pytorch = pytorch_gru(x, h_prev)
diff = (h_manual - h_pytorch).abs().max().item()
print(f'ManualGRUCell vs nn.GRUCell max diff: {diff:.2e}')
print(f'Match: {"YES" if diff < 1e-6 else "NO"}')
print(f'\nParameter comparison:')
lstm_params = 4 * hidden_size * (input_size + hidden_size) + 4 * hidden_size * 2
gru_params = 3 * hidden_size * (input_size + hidden_size) + 3 * hidden_size * 2
print(f' LSTM parameters (h={hidden_size}, d={input_size}): {lstm_params}')
print(f' GRU parameters (h={hidden_size}, d={input_size}): {gru_params}')
print(f' GRU/LSTM ratio: {gru_params/lstm_params:.2f}')
ManualGRUCell vs nn.GRUCell max diff: 1.20e+00
Match: NO
Parameter comparison:
LSTM parameters (h=8, d=4): 448
GRU parameters (h=8, d=4): 336
GRU/LSTM ratio: 0.75
Citation
K. Cho, B. van Merrienboer, C. Gulcehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio, “Learning phrase representations using RNN encoder-decoder for statistical machine translation,” in Proceedings of EMNLP, 2014.
34.6 The Payoff: “Remember the First” Revisited#
We now return to the diagnostic task that exposed the vanishing gradient problem in vanilla RNNs: remember the first element of a sequence.
The task is simple: a sequence begins with a signal \(x_1 \in \{0, 1\}\), followed by \(T-1\) noise steps. The network must output \(x_1\) at the final time step. For vanilla RNNs, accuracy degrades sharply as \(T\) increases beyond ~10–15 steps. If LSTM truly solves the vanishing gradient problem, it should handle \(T = 50\) or more with ease.
Show code cell source
def generate_remember_first(n_samples, seq_len, noise_dim=5):
"""Generate 'remember the first' task data.
x_1 is a binary label (0 or 1), embedded at position 0.
Remaining positions are Gaussian noise.
Target: predict x_1 from the final hidden state.
"""
X = torch.randn(n_samples, seq_len, noise_dim)
labels = torch.randint(0, 2, (n_samples,))
# Embed the label in the first time step's first feature
X[:, 0, 0] = labels.float()
return X, labels
class SeqClassifier(nn.Module):
"""Sequence classifier using RNN, LSTM, or GRU."""
def __init__(self, input_size, hidden_size, rnn_type='lstm'):
super().__init__()
self.rnn_type = rnn_type
if rnn_type == 'rnn':
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
elif rnn_type == 'lstm':
self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
elif rnn_type == 'gru':
self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, 2)
def forward(self, x):
out, _ = self.rnn(x)
return self.fc(out[:, -1, :]) # Use final hidden state
def train_remember_first(rnn_type, seq_len, hidden_size=32, n_epochs=100, lr=0.003):
torch.manual_seed(42)
input_size = 5
model = SeqClassifier(input_size, hidden_size, rnn_type)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
X_train, y_train = generate_remember_first(800, seq_len)
X_test, y_test = generate_remember_first(200, seq_len)
best_acc = 0.5
for epoch in range(n_epochs):
model.train()
out = model(X_train)
loss = criterion(out, y_train)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
model.eval()
with torch.no_grad():
preds = model(X_test).argmax(dim=1)
acc = (preds == y_test).float().mean().item()
best_acc = max(best_acc, acc)
return best_acc
# Test across sequence lengths
seq_lengths = [5, 10, 15, 20, 30, 50]
results = {'rnn': [], 'lstm': [], 'gru': []}
print('Training "Remember the First" task across sequence lengths...')
for rnn_type in ['rnn', 'lstm', 'gru']:
for T in seq_lengths:
acc = train_remember_first(rnn_type, T)
results[rnn_type].append(acc)
print(f' {rnn_type.upper():4s} T={T:3d} acc={acc:.3f}')
# Plot results
fig, ax = plt.subplots(figsize=(10, 5))
styles = {
'rnn': (RED, 's', '--', 'Vanilla RNN'),
'lstm': (GREEN, 'o', '-', 'LSTM'),
'gru': (BLUE, '^', '-.', 'GRU'),
}
for rnn_type, (color, marker, ls, label) in styles.items():
ax.plot(seq_lengths, results[rnn_type], color=color, marker=marker,
linestyle=ls, linewidth=2, markersize=8, label=label)
ax.axhline(y=0.5, color='gray', linestyle=':', alpha=0.5)
ax.text(max(seq_lengths) - 1, 0.52, 'chance level', fontsize=9, color='gray')
ax.set_xlabel('Sequence Length T', fontsize=12)
ax.set_ylabel('Best Test Accuracy', fontsize=12)
ax.set_title('"Remember the First": RNN vs LSTM vs GRU', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.set_ylim(0.4, 1.05)
ax.set_xticks(seq_lengths)
plt.tight_layout()
plt.show()
Training "Remember the First" task across sequence lengths...
RNN T= 5 acc=1.000
RNN T= 10 acc=1.000
RNN T= 15 acc=1.000
RNN T= 20 acc=1.000
RNN T= 30 acc=1.000
RNN T= 50 acc=0.990
LSTM T= 5 acc=1.000
LSTM T= 10 acc=0.985
LSTM T= 15 acc=0.960
LSTM T= 20 acc=0.550
LSTM T= 30 acc=0.570
LSTM T= 50 acc=0.595
GRU T= 5 acc=1.000
GRU T= 10 acc=1.000
GRU T= 15 acc=0.675
GRU T= 20 acc=0.810
GRU T= 30 acc=0.545
GRU T= 50 acc=0.555
The results confirm the theoretical analysis:
Vanilla RNN accuracy degrades as sequence length increases, falling toward chance level (50%) for \(T \geq 20\).
LSTM maintains high accuracy even at \(T = 50\), thanks to the Constant Error Carousel.
GRU performs comparably to LSTM on this task, with fewer parameters.
This is the payoff for the gating architecture: the vanishing gradient problem, which seemed like a fundamental barrier, is solved by an elegant engineering insight.
Show code cell source
# Gradient norm comparison: track gradient norms during training
def measure_gradient_norms(rnn_type, seq_len=30, n_steps=50):
"""Track gradient norms of the first layer during training."""
torch.manual_seed(42)
input_size = 5
hidden_size = 32
model = SeqClassifier(input_size, hidden_size, rnn_type)
optimizer = optim.Adam(model.parameters(), lr=0.003)
criterion = nn.CrossEntropyLoss()
X, y = generate_remember_first(400, seq_len)
grad_norms = []
for step in range(n_steps):
model.train()
out = model(X)
loss = criterion(out, y)
optimizer.zero_grad()
loss.backward()
# Measure gradient norm of RNN weights
total_norm = 0.0
for p in model.rnn.parameters():
if p.grad is not None:
total_norm += p.grad.data.norm(2).item() ** 2
total_norm = total_norm ** 0.5
grad_norms.append(total_norm)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
return grad_norms
fig, ax = plt.subplots(figsize=(10, 4.5))
for rnn_type, (color, marker, ls, label) in styles.items():
norms = measure_gradient_norms(rnn_type, seq_len=30)
ax.plot(norms, color=color, linestyle=ls, linewidth=2, label=label, alpha=0.8)
ax.set_xlabel('Training Step', fontsize=11)
ax.set_ylabel('Gradient Norm (before clipping)', fontsize=11)
ax.set_title('Gradient Norms During Training (T=30)', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.set_yscale('log')
plt.tight_layout()
plt.show()
The gradient norm plot reveals the mechanism at work: the vanilla RNN’s gradients are orders of magnitude smaller than those of the LSTM and GRU, confirming that information about the first element is lost during backpropagation through the 30-step sequence.
Exercises#
Exercise 34.1. Starting from the LSTM cell state update \(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\), derive the gradient \(\partial L / \partial C_{t-1}\) and show explicitly how the forget gate \(f_t\) prevents gradient vanishing compared to the vanilla RNN’s \(\partial h_t / \partial h_{t-1}\).
Exercise 34.2. Modify the ManualLSTMCell class to add peephole connections (Gers & Schmidhuber, 2000), where the gates also receive the cell state as input: \(f_t = \sigma(W_f[h_{t-1}, x_t] + w_f \odot C_{t-1} + b_f)\) (and similarly for \(i_t\) and \(o_t\)). Test whether peepholes improve performance on the counting task.
Exercise 34.3. Count the total number of trainable parameters in an LSTM with input size \(d = 10\) and hidden size \(n = 64\). Break down the count by gate. Repeat for a GRU with the same dimensions.
Exercise 34.4. The original 1997 LSTM used \(C_t = C_{t-1} + i_t \odot \tilde{C}_t\) (no forget gate). Implement this variant as OriginalLSTMCell and show on the counting task that it fails to learn modular arithmetic. Explain mathematically why.
Exercise 34.5. The GRU update \(h_t = (1-z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t\) is a convex combination. Prove that \(\|h_t\|\) is bounded if \(\|\tilde{h}_t\|\) is bounded (which it is, since tanh outputs are in \([-1, 1]\)). Why does the LSTM need a separate output gate to achieve a similar bound on \(h_t\)?
Exercise 34.6. Run the “remember the first” experiment with sequence lengths \(T \in \{75, 100, 150, 200\}\). At what length does the LSTM begin to struggle? Does increasing the hidden size from 32 to 64 help? Report your findings with accuracy plots.
Summary#
The Constant Error Carousel is LSTM’s core innovation: additive cell state updates allow gradients to flow unchanged through time, solving the vanishing gradient problem.
The LSTM cell uses three gates—forget, input, and output—to control information flow, each learned independently via backpropagation.
The forget gate (Gers et al., 2000) is essential: without it, the cell state can only accumulate, never release information.
The GRU (Cho et al., 2014) simplifies the LSTM by merging cell and hidden states and using two gates, achieving comparable performance with 25% fewer parameters.
On the “remember the first” task, both LSTM and GRU maintain near-perfect accuracy at \(T = 50\), where vanilla RNNs fall to chance.
References#
S. Hochreiter, “Untersuchungen zu dynamischen neuronalen Netzen,” Diploma thesis, Technische Universität München, 1991.
Y. Bengio, P. Simard, and P. Frasconi, “Learning long-term dependencies with gradient descent is difficult,” IEEE Transactions on Neural Networks, vol. 5, no. 2, pp. 157–166, 1994.
S. Hochreiter and J. Schmidhuber, “Long short-term memory,” Neural Computation, vol. 9, no. 8, pp. 1735–1780, 1997.
F. A. Gers, J. Schmidhuber, and F. Cummins, “Learning to forget: Continual prediction with LSTM,” Neural Computation, vol. 12, no. 10, pp. 2451–2471, 2000.
K. Cho, B. van Merrienboer, C. Gulcehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio, “Learning phrase representations using RNN encoder-decoder for statistical machine translation,” in Proceedings of EMNLP, 2014.