
Sketch-RNN Mini

Draw a shape, train, watch the network sketch on its own.

Draw a shape on the left canvas (or pick a preset). The drawing becomes a sequence of pen-stroke deltas — exactly the format Google's Sketch-RNN uses. An LSTM trains on this sequence and generates new sketches in your style.

Your Drawing


Train the LSTM

Loss Curve

Draw or pick a preset, then Train.

Generated Sketch


What the LSTM Sees

Stroke deltas (Δx, Δy) as a 2D scatter. Each point = one step in the sequence.

GMM mixture weights π (last step)

Architecture in use — chapter map

stroke (5-vec) = (Δx, Δy, pdown, pup, pend)  →  LSTM(hidden)  →  Dense(64, ReLU)  →  MDN head  →  [ M Gaussians on (Δx, Δy) | 3 pen-state logits ]
| Component | Role | Course chapter / paper |
| --- | --- | --- |
| Stroke-3 encoding | (Δx, Δy, pen-state) per step | Ch 32 — Sequences and memory; Ha & Eck (2017) |
| LSTM cell | Maintains state across the stroke sequence | Ch 34 — LSTM and gating |
| BPTT | Trains the recurrence end-to-end | Ch 33 — Backpropagation through time |
| Dense + ReLU head | Non-linear projection before output | Ch 17 — Activations |
| Mixture Density Network | 5 bivariate Gaussians for (Δx, Δy); captures the multimodal next stroke | Bishop (1994); loss math = Ch 26 log-likelihood |
| 3-way softmax | Pen state {down, up, end-of-drawing} | Ch 26 — categorical CE |
| NLL loss | −log p(Δx, Δy \| π, μ, Σ) per step | Ch 26 — maximum likelihood |
| Temperature sampling | Sharpens π and shrinks σ at low T | Ch 35 — char-rnn temperature |
| Adam optimiser | Adaptive learning rate; gradient clipping | Ch 27 — Adam |
| Augmentation (rot+scale+shift) | Synthetic data multiplier on a single drawing | Standard ML practice |
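The augmentation row can be sketched in a few lines. Since stroke deltas are translation-invariant, a global shift disappears in delta space, so the sketch below models "shift" as small per-step jitter; the function name, ranges, and that choice are our assumptions, not the applet's exact code:

```python
import math
import random

def augment(deltas, max_rot=0.2, max_scale=0.15, jitter=0.01):
    """Return one synthetic variant of a (dx, dy) delta sequence."""
    theta = random.uniform(-max_rot, max_rot)        # rotation in radians
    scale = 1.0 + random.uniform(-max_scale, max_scale)
    c, s = math.cos(theta), math.sin(theta)
    out = []
    for dx, dy in deltas:
        # Rotate and scale each delta, then add a little jitter.
        rx = scale * (c * dx - s * dy) + random.gauss(0.0, jitter)
        ry = scale * (s * dx + c * dy) + random.gauss(0.0, jitter)
        out.append((rx, ry))
    return out
```

Calling this a few dozen times turns one ~36-point drawing into a small training set.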

Why MDN and not plain regression? A circle's next stroke at any point can go two ways (clockwise or counter-clockwise) — the conditional distribution over (Δx, Δy) is multimodal. Plain MSE regression would average them and produce a degenerate "stay still" prediction. The Gaussian mixture lets the network represent multiple plausible directions and sample one.
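To make that concrete, here is a toy numeric version of the argument (the numbers and names are illustrative, not taken from the applet):

```python
import random

# At some point on a circle, the next stroke is equally likely to
# go clockwise or counter-clockwise.
modes = [(+1.0, 0.0), (-1.0, 0.0)]

# The MSE-optimal single prediction is the mean of the targets:
# a degenerate "stay still" step that matches neither mode.
mse_pred = (sum(dx for dx, _ in modes) / 2, sum(dy for _, dy in modes) / 2)
print(mse_pred)  # (0.0, 0.0)

# A two-component mixture keeps both modes and samples one of them.
pi = [0.5, 0.5]
sampled = random.choices(modes, weights=pi, k=1)[0]
# 'sampled' is always one of the two modes, never their average.
```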

How does this work? (long version)

Each pen stroke is encoded as a 5-vector (Δx, Δy, p_down, p_up, p_end), where the last three form a one-hot pen state. The LSTM (Ch 34) reads the sequence and maintains a hidden state. A small dense head produces 33 numbers per step: for each of the 5 mixture components, a weight π, means (μx, μy), standard deviations (σx, σy), and a correlation ρ (5 × 6 = 30), plus 3 pen-state logits.
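Unpacking those 33 numbers might look like the sketch below. The exp/tanh squashing follows the Sketch-RNN paper's recipe; the ordering and names are our assumptions, not necessarily the applet's:

```python
import math

M = 5  # mixture components, as in the chapter-map table

def split_mdn_params(raw):
    """Split 6*M + 3 = 33 raw head outputs into MDN parameters (layout assumed)."""
    assert len(raw) == 6 * M + 3
    exp_pi = [math.exp(v) for v in raw[0:M]]      # mixture logits -> softmax
    z = sum(exp_pi)
    pis = [v / z for v in exp_pi]
    mux, muy = raw[M:2*M], raw[2*M:3*M]           # component means
    sx = [math.exp(v) for v in raw[3*M:4*M]]      # sigma > 0 via exp
    sy = [math.exp(v) for v in raw[4*M:5*M]]
    rho = [math.tanh(v) for v in raw[5*M:6*M]]    # correlation in (-1, 1)
    pen = raw[6*M:]                               # 3 pen-state logits
    return pis, mux, muy, sx, sy, rho, pen
```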

Loss is the negative log-likelihood of the true next stroke under the mixture (a real likelihood, not MSE), plus categorical cross-entropy for the pen state. This is the loss from the original Sketch-RNN paper (Ha & Eck, 2017), which in turn goes back to Bishop's Mixture Density Networks (1994).
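In code, the per-step loss might be computed as below (a pure-Python sketch; it assumes the pen logits have already been softmaxed into pen_probs, and the 1e-12 floor is our numerical guard, not necessarily the applet's):

```python
import math

def bivariate_pdf(dx, dy, mx, my, sx, sy, rho):
    """Density of a correlated bivariate Gaussian at (dx, dy)."""
    zx, zy = (dx - mx) / sx, (dy - my) / sy
    q = zx * zx - 2 * rho * zx * zy + zy * zy
    norm = 2 * math.pi * sx * sy * math.sqrt(1 - rho * rho)
    return math.exp(-q / (2 * (1 - rho * rho))) / norm

def step_loss(dx, dy, pen_true, pis, mux, muy, sx, sy, rho, pen_probs):
    """NLL of the true (dx, dy) under the mixture, plus pen-state cross-entropy."""
    mix = sum(p * bivariate_pdf(dx, dy, mx, my, a, b, r)
              for p, mx, my, a, b, r in zip(pis, mux, muy, sx, sy, rho))
    return -math.log(mix + 1e-12) - math.log(pen_probs[pen_true] + 1e-12)
```

With a single standard-normal component and a certain pen state, the loss at the mean reduces to log 2π ≈ 1.84.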

Generation: at each step we sample a mixture component i with probability π_i, then sample (Δx, Δy) from that component's bivariate Gaussian (with correlation), and sample a pen state from the softmax. We seed the first 8 steps with the actual input drawing so the LSTM starts from real context rather than zeros.
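One step of that sampler could look like this; temperature handling follows the Sketch-RNN recipe (sharpen π, scale σ by √τ), while the names and exact scaling are our assumptions:

```python
import math
import random

def sample_step(pis, mux, muy, sx, sy, rho, tau=0.4):
    """Sample one (dx, dy) from the mixture at temperature tau."""
    # Sharpen mixture weights: pi_i ** (1/tau), renormalised
    # (equivalent to dividing the logits by tau before the softmax).
    w = [p ** (1.0 / tau) for p in pis]
    total = sum(w)
    i = random.choices(range(len(w)), weights=[v / total for v in w], k=1)[0]
    # Shrink the chosen Gaussian: variance scales with tau.
    a, b = sx[i] * math.sqrt(tau), sy[i] * math.sqrt(tau)
    # Correlated bivariate sample from two independent standard normals.
    u1, u2 = random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)
    dx = mux[i] + a * u1
    dy = muy[i] + b * (rho[i] * u1 + math.sqrt(1.0 - rho[i] ** 2) * u2)
    return dx, dy
```

At low τ the weights concentrate on the largest π and the Gaussians tighten, so samples hug the most likely stroke; at τ = 1 the model's distribution is used as-is.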

Why this is closer to the real thing. The earlier MSE-regression version of this applet produced blurry, drifting reproductions because MSE collapses a multimodal distribution to its mean. The MDN version represents "could go left or right" honestly, and the samples reflect that.

Limitations. Real Sketch-RNN (Ha & Eck 2017) wraps an encoder-decoder VAE around this; we keep just the autoregressive decoder. Training data is also tiny — augmentation helps but a real sketch dataset (Quick, Draw!) has ~50M drawings; a single browser-trained shape has ~36 points.
