Draw a shape on the left canvas (or pick a preset). The drawing becomes a sequence of pen-stroke deltas — exactly the format Google's Sketch-RNN uses. An LSTM trains on this sequence and generates new sketches in your style.
Your Drawing
Train the LSTM
Loss Curve
Generated Sketch
What the LSTM Sees
Stroke deltas (Δx, Δy) as a 2D scatter. Each point = one step in the sequence.
GMM mixture weights π (last step)
Architecture in use — chapter map
| Component | Role | Course chapter / paper |
|---|---|---|
| Stroke-5 encoding | (Δx, Δy, p_down, p_up, p_end) per step | Ch 32 — Sequences and memory; Ha & Eck (2017) |
| LSTM cell | Maintains state across the stroke sequence | Ch 34 — LSTM and gating |
| BPTT | Trains the recurrence end-to-end | Ch 33 — Backpropagation through time |
| Dense + ReLU head | Non-linear projection before output | Ch 17 — Activations |
| Mixture Density Network | 5 bivariate Gaussians for (Δx, Δy) — captures multimodal next-stroke | Bishop (1994); loss math = Ch 26 log-likelihood |
| 3-way softmax | Pen state {down, up, end-of-drawing} | Ch 26 — categorical CE |
| NLL loss | −log p(Δx, Δy \| π, μ, Σ) per step | Ch 26 — maximum likelihood |
| Temperature sampling | Sharpen π and shrink σ at low T | Ch 35 — char-rnn temperature |
| Adam optimiser | Adaptive lr; gradient clipping | Ch 27 — Adam |
| Augmentation (rot+scale+shift) | Synthetic data multiplier on a single drawing | standard ML practice |
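The augmentation row can be sketched as follows (a hypothetical helper, not the applet's actual code). Because the drawing is stored as stroke deltas, rotation and scale act directly on each (Δx, Δy), while a shift only moves the absolute start point and leaves the deltas untouched:

```python
import math
import random

def augment(deltas, max_rot=0.3, scale_range=(0.9, 1.1)):
    """Make one synthetic variant of a drawing stored as (dx, dy, pen) deltas.

    Rotation and uniform scale transform each delta like a vector; a
    translation of the whole sketch does not change the deltas at all,
    so the "shift" part of the augmentation only affects where the pen
    starts on the canvas.
    """
    theta = random.uniform(-max_rot, max_rot)   # small random rotation (rad)
    s = random.uniform(*scale_range)            # random uniform scale
    c, sn = math.cos(theta), math.sin(theta)
    return [(s * (c * dx - sn * dy), s * (sn * dx + c * dy), pen)
            for dx, dy, pen in deltas]
```

Applied repeatedly with fresh random draws, this turns one ~36-point drawing into many slightly different training sequences.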
Why MDN and not plain regression? A circle's next stroke at any point can go two ways (clockwise or counter-clockwise) — the conditional distribution over (Δx, Δy) is multimodal. Plain MSE regression would average them and produce a degenerate "stay still" prediction. The Gaussian mixture lets the network represent multiple plausible directions and sample one.
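The collapse can be shown with a toy calculation (the two direction vectors are hypothetical stand-ins for the circle's two tangents):

```python
import random

# Two equally likely next-stroke directions at some point on a circle:
# clockwise vs. counter-clockwise tangents (illustrative values).
modes = [(+1.0, 0.0), (-1.0, 0.0)]

# MSE regression converges to the conditional mean of its targets,
# which here is the degenerate "stay still" prediction.
mse_prediction = (
    sum(dx for dx, _ in modes) / len(modes),
    sum(dy for _, dy in modes) / len(modes),
)
print(mse_prediction)  # (0.0, 0.0) -- the pen doesn't move

# A mixture model instead commits to one of the plausible directions.
sampled = random.choice(modes)
print(sampled in modes)  # True -- a real direction, not their average
```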
How does this work? (long version)
Each pen stroke is encoded as a 5-vector (Δx, Δy, p_down, p_up, p_end) where the last three are a one-hot pen-state. The LSTM (Ch 34) reads the sequence and maintains a hidden state. A small dense head produces 33 numbers per step:
- 30 numbers (= 5 mixture components × 6 parameters) describe a Gaussian mixture over the next (Δx, Δy): mixture weight π, means μx, μy, scales σx, σy, correlation ρ.
- 3 logits for the next pen-state (down / up / end).
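The 33-way split can be sketched like this. The output ordering is an assumption (any fixed layout works), and the squashing functions are the standard MDN choices: softmax so the weights sum to 1, exp so the scales stay positive, tanh so the correlation stays in (−1, 1):

```python
import math

K = 5  # mixture components

def split_mdn_params(raw):
    """Split the 33 raw head outputs into MDN + pen-state parameters.

    Assumed layout: K values each of pi, mu_x, mu_y, sigma_x, sigma_y,
    rho, followed by 3 pen-state logits.
    """
    assert len(raw) == 6 * K + 3
    pi_logits = raw[0:K]
    mu_x, mu_y = raw[K:2*K], raw[2*K:3*K]
    s_x, s_y = raw[3*K:4*K], raw[4*K:5*K]
    rho_raw = raw[5*K:6*K]
    pen_logits = raw[6*K:]

    # Constrain each parameter to its valid range.
    z = max(pi_logits)
    e = [math.exp(p - z) for p in pi_logits]
    total = sum(e)
    pi = [v / total for v in e]                # softmax: weights sum to 1
    sigma_x = [math.exp(s) for s in s_x]       # exp: strictly positive
    sigma_y = [math.exp(s) for s in s_y]
    rho = [math.tanh(r) for r in rho_raw]      # tanh: inside (-1, 1)
    return pi, mu_x, mu_y, sigma_x, sigma_y, rho, pen_logits
```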
Loss is the negative log-likelihood of the true next stroke under the mixture (a real likelihood, not MSE) plus categorical cross-entropy for the pen state. This is the loss from the original Sketch-RNN paper (Ha & Eck 2017), in turn going back to Bishop's Mixture Density Networks (1994).
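The position term of that loss, written out per step (a minimal plain-Python sketch; a real implementation would vectorise and work in log space for stability):

```python
import math

def bivariate_normal_pdf(dx, dy, mx, my, sx, sy, rho):
    # Density of a correlated 2-D Gaussian, as in the Sketch-RNN loss.
    zx, zy = (dx - mx) / sx, (dy - my) / sy
    z = zx * zx + zy * zy - 2 * rho * zx * zy
    denom = 2 * math.pi * sx * sy * math.sqrt(1 - rho * rho)
    return math.exp(-z / (2 * (1 - rho * rho))) / denom

def stroke_nll(dx, dy, pi, mu_x, mu_y, sigma_x, sigma_y, rho):
    # -log sum_i pi_i * N(dx, dy | component i): one term of the sequence loss.
    likelihood = sum(
        p * bivariate_normal_pdf(dx, dy, mx, my, sx, sy, r)
        for p, mx, my, sx, sy, r in zip(pi, mu_x, mu_y, sigma_x, sigma_y, rho)
    )
    return -math.log(likelihood + 1e-12)  # epsilon guards against log(0)
```

Sanity check: with a single standard-normal component, the NLL of the true stroke (0, 0) is log 2π ≈ 1.838; the categorical cross-entropy on the pen state is simply added to this.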
Generation: at each step we sample a mixture component i with probability π_i, then sample (Δx, Δy) from that bivariate Gaussian (with correlation), and sample a pen state from the softmax. We seed the first 8 steps with the actual input drawing so the LSTM starts with real context (rather than zeros).
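That sampling loop can be sketched in Python. The √temp scaling of σ and the 1/temp sharpening of π follow the Sketch-RNN / char-rnn convention; the exact knobs of this applet are assumptions:

```python
import math
import random

def sample_step(pi, mu_x, mu_y, sigma_x, sigma_y, rho, pen_logits, temp=0.65):
    """One generation step: pick a component, draw (dx, dy), pick a pen state.

    Low temperature sharpens the mixture weights and shrinks the Gaussians,
    giving cleaner, more conservative strokes.
    """
    def tempered_sample(weights):
        # Re-temper a weight vector by 1/temp and draw an index from it.
        logits = [math.log(w + 1e-12) / temp for w in weights]
        z = max(logits)
        e = [math.exp(l - z) for l in logits]
        total = sum(e)
        r, acc = random.random(), 0.0
        for i, v in enumerate(e):
            acc += v / total
            if r < acc:
                return i
        return len(e) - 1

    i = tempered_sample(pi)                      # mixture component i ~ pi
    sx = sigma_x[i] * math.sqrt(temp)            # shrink scales at low temp
    sy = sigma_y[i] * math.sqrt(temp)
    n1, n2 = random.gauss(0, 1), random.gauss(0, 1)
    dx = mu_x[i] + sx * n1                       # correlated bivariate draw
    dy = mu_y[i] + sy * (rho[i] * n1 + math.sqrt(1 - rho[i] ** 2) * n2)
    pen = tempered_sample([math.exp(l) for l in pen_logits])  # 3-way pen state
    return dx, dy, pen
```

The correlated (Δx, Δy) draw uses the standard trick of combining two independent standard normals via ρ and √(1 − ρ²).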
Why this is closer to the real thing. The earlier MSE-regression version of this applet always produced blurry, drift-y reproductions because MSE collapses multimodal distributions to their mean. The MDN version represents "could go left or right" honestly and the samples reflect that.
Limitations. Real Sketch-RNN (Ha & Eck 2017) wraps an encoder-decoder VAE around this; we keep just the autoregressive decoder. Training data is also tiny — augmentation helps but a real sketch dataset (Quick, Draw!) has ~50M drawings; a single browser-trained shape has ~36 points.