Networks with internal state for sequence data — LSTMs, GRUs, and the long-standing default for language before transformers.
Key idea
The same little network applied at every step, carrying a memory of what it saw before. Read one token, update memory, move on. Read another token, update memory again. After the whole sequence, the memory summarizes everything you've seen.
Interactive figure: the hidden state evolving as each token is read.
The same RNN cell, with the same weights, runs at every timestep, and the hidden state is passed forward from one step to the next. On a long sequence, a vanilla RNN's memory of the early tokens has faded to nearly zero by the time it reaches a late token such as "was". The same repeated shrinking affects gradients during training, which is the classic vanishing gradient problem. An LSTM-like cell keeps the memory alive much further.
RNNs were the default for sequences (language, speech, time series) for a decade until transformers came along. They have one key trick: the network's output at step t depends on a "hidden state" that was computed at step t−1. The hidden state is the model's memory.
Vanilla RNNs are simple but forget quickly. LSTMs and GRUs add gating mechanisms — small networks that decide what to remember and what to forget — and can carry information across hundreds of timesteps.
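The read-a-token, update-memory loop can be sketched in a few lines of NumPy. This is a minimal illustration that assumes a tanh update rule; the names `rnn_step`, `W_xh`, `W_hh`, `b_h` and all sizes are made up for the example.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    """One vanilla RNN step: mix the current input with the previous memory."""
    return np.tanh(x_t @ W_xh + h_prev @ W_hh + b_h)

rng = np.random.default_rng(0)
input_size, hidden_size, T = 3, 4, 5   # illustrative sizes

# One set of weights, reused at every timestep
W_xh = rng.normal(scale=0.5, size=(input_size, hidden_size))
W_hh = rng.normal(scale=0.5, size=(hidden_size, hidden_size))
b_h = np.zeros(hidden_size)

xs = rng.normal(size=(T, input_size))  # a toy sequence of T input vectors
h = np.zeros(hidden_size)              # memory starts empty
states = []
for x_t in xs:
    h = rnn_step(x_t, h, W_xh, W_hh, b_h)  # read one token, update memory
    states.append(h)

print(np.round(states[-1], 3))  # final hidden state: a summary of the sequence
```

Note that the loop body never changes: only `h` carries information from one step to the next.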
Reach for it when
Short to medium-length sequences with limited compute
On-device inference where a transformer is too heavy
Streaming data where you process one step at a time
Time-series forecasting with strong temporal structure
Skip it when
Long sequences with long-range dependencies — transformers win
You have enough data + compute for a transformer
You need every position to use left and right context jointly (a bidirectional RNN only concatenates two one-directional passes)
The task is in language / vision — pretrained transformers dominate
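import torch
import torch.nn as nn

# LSTM with 2 layers, hidden size 128
model = nn.LSTM(
    input_size=10,
    hidden_size=128,
    num_layers=2,
    batch_first=True,
)

# Dummy batch for illustration: B=4 sequences, T=20 steps, 10 features per step
inputs = torch.randn(4, 20, 10)

# Inputs: (B, T, input_size); outputs: (B, T, hidden_size) + final (h, c),
# where h and c each have shape (num_layers, B, hidden_size)
outputs, (h_final, c_final) = model(inputs)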
W_xh, W_hh: input-to-hidden and hidden-to-hidden weights, shared across timesteps
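The vanishing gradient problem. Vanilla RNNs struggle with long sequences. Gradients propagated backward through time are multiplied at every step by the same recurrent weight matrix (and by the activation's derivative), so over many steps they either shrink toward zero (vanishing) or blow up (exploding). When they vanish, early inputs stop influencing the weight updates and long-range dependencies go unlearned.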
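The repeated multiplication can be seen numerically. The sketch below is a simplification: it ignores the activation's derivative and uses a scaled orthogonal matrix as the recurrent weights so the growth rate is easy to read off. The function name `backprop_norm` and all sizes are assumptions for illustration.

```python
import numpy as np

def backprop_norm(scale, steps=50, size=8, seed=0):
    """Norm of a unit gradient after `steps` multiplications by W_hh^T."""
    rng = np.random.default_rng(seed)
    # Random orthogonal matrix: multiplying by it preserves vector norms
    Q, _ = np.linalg.qr(rng.normal(size=(size, size)))
    W_hh = scale * Q                       # every step rescales the norm by `scale`
    grad = np.ones(size) / np.sqrt(size)   # unit-norm gradient at the last timestep
    for _ in range(steps):
        grad = W_hh.T @ grad               # one step backward through time
    return np.linalg.norm(grad)

vanishing = backprop_norm(scale=0.9)   # about 0.9 ** 50
exploding = backprop_norm(scale=1.1)   # about 1.1 ** 50
print(f"scale 0.9 -> {vanishing:.5f}, scale 1.1 -> {exploding:.1f}")
```

A factor only slightly below or above 1 is enough: after 50 steps the gradient has shrunk by more than two orders of magnitude in one case and grown by more than two in the other.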
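LSTM (Hochreiter & Schmidhuber, 1997). The standard modern LSTM has three gates (forget, input, output; the forget gate was added later by Gers et al., 2000) and a separate "cell state" that travels along its own additive path. The additive path is the key trick: the cell state is updated by adding to it rather than by repeatedly passing it through a weight matrix and nonlinearity, so gradients can flow back over many steps as long as the forget gate stays close to 1.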
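One step of an LSTM cell can be written out by hand to make the additive path visible. This is a sketch of the common formulation, not any library's internals: the stacked weight layout and the gate order (f, i, o, candidate) are assumptions chosen for readability, and PyTorch orders its gates differently.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM step. W: (input, 4H), U: (H, 4H), b: (4H,); blocks ordered f, i, o, g."""
    H = h_prev.shape[-1]
    z = x_t @ W + h_prev @ U + b
    f = sigmoid(z[..., 0:H])          # forget gate: how much old cell state to keep
    i = sigmoid(z[..., H:2 * H])      # input gate: how much new content to write
    o = sigmoid(z[..., 2 * H:3 * H])  # output gate: how much of the cell to expose
    g = np.tanh(z[..., 3 * H:4 * H])  # candidate content
    c_t = f * c_prev + i * g          # additive update of the cell state
    h_t = o * np.tanh(c_t)            # hidden state read out from the cell
    return h_t, c_t

rng = np.random.default_rng(0)
input_size, H = 3, 4                  # illustrative sizes
W = rng.normal(scale=0.5, size=(input_size, 4 * H))
U = rng.normal(scale=0.5, size=(H, 4 * H))
b = np.zeros(4 * H)

h, c = np.zeros(H), np.zeros(H)
for x_t in rng.normal(size=(6, input_size)):
    h, c = lstm_step(x_t, h, c, W, U, b)
print(h.shape, c.shape)
```

When `f` is close to 1 and `i` is close to 0, `c_t` is almost a copy of `c_prev`, which is how the cell carries information unchanged across many steps.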
GRU (Cho et al., 2014). Simpler than LSTM: two gates (update and reset) and a single state. Performance is similar on many tasks, with fewer parameters per unit, so it is often somewhat faster to train.
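The size difference is easy to check with PyTorch's built-in layers. The helper `n_params` and the layer sizes below are illustrative choices.

```python
import torch.nn as nn

def n_params(module):
    """Total number of learnable parameters in a module."""
    return sum(p.numel() for p in module.parameters())

# Same input and hidden sizes for a fair comparison
lstm = nn.LSTM(input_size=10, hidden_size=128, batch_first=True)
gru = nn.GRU(input_size=10, hidden_size=128, batch_first=True)

print(n_params(lstm), n_params(gru))  # 71680 53760
```

The LSTM learns four weight blocks per layer (three gates plus the candidate) and the GRU three (two gates plus the candidate), so at equal sizes the GRU has three quarters of the parameters.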
Bidirectional RNNs. Run one RNN left-to-right and another right-to-left; concatenate hidden states. Captures both past and future context — useful when you can see the whole sequence at once (e.g. translation, NER).
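A short shape check shows how the two directions are laid out in PyTorch's `nn.LSTM` with `bidirectional=True`. The sizes are arbitrary; the point is where each direction's final state lives.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, I, H = 2, 7, 5, 16   # illustrative batch, length, input and hidden sizes
rnn = nn.LSTM(I, H, bidirectional=True, batch_first=True)
x = torch.randn(B, T, I)

out, (h, c) = rnn(x)
print(out.shape)  # torch.Size([2, 7, 32]): forward and backward states concatenated
print(h.shape)    # torch.Size([2, 2, 16]): one final state per direction

# The forward pass finishes at the last token, the backward pass at the first
fwd_final = out[:, -1, :H]
bwd_final = out[:, 0, H:]
print(torch.allclose(fwd_final, h[0]), torch.allclose(bwd_final, h[1]))
```

This is why taking `out[:, -1]` alone is a poor sequence summary for a bidirectional model: its backward half has only seen the last token.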
Sequence-to-sequence. Use one RNN as an encoder (consume the input, produce a final hidden state) and another as a decoder (start from that hidden state, generate the output). The bottleneck of the final hidden state motivated attention — and attention motivated transformers.
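The encoder-decoder layout might look like the following minimal sketch. It assumes GRUs, teacher forcing, and made-up vocabulary sizes; the class name `Seq2Seq` and its arguments are hypothetical, and a real system would add padding, masking, and a decoding loop.

```python
import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, embed_dim=32, hidden_dim=64):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab, embed_dim)
        self.tgt_embed = nn.Embedding(tgt_vocab, embed_dim)
        self.encoder = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.decoder = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, tgt_vocab)

    def forward(self, src, tgt_in):
        # Encoder: the whole source is squeezed into one final hidden state
        _, h = self.encoder(self.src_embed(src))          # h: (1, B, hidden_dim)
        # Decoder: starts from that state and reads the shifted target tokens
        dec_out, _ = self.decoder(self.tgt_embed(tgt_in), h)
        return self.out(dec_out)                          # (B, T_tgt, tgt_vocab)

model = Seq2Seq(src_vocab=100, tgt_vocab=120)
src = torch.randint(0, 100, (4, 9))      # 4 source sequences of 9 tokens
tgt_in = torch.randint(0, 120, (4, 6))   # 4 target prefixes of 6 tokens
logits = model(src, tgt_in)
print(logits.shape)  # torch.Size([4, 6, 120])
```

Everything the decoder knows about the source passes through `h`, a single vector per sequence. That is the bottleneck the paragraph above refers to.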
Reach for it when
Time-series forecasting with limited compute
Online / streaming classification (process one step at a time)
Small-data sequence tasks where a transformer overfits
Edge inference — LSTMs are much smaller than transformers
Skip it when
You can throw compute at it — a transformer will dominate
Very long sequences (>1000 tokens) with non-local structure
You want to parallelize across timesteps (RNNs are sequential)
Pretrained transformer encoders exist for your domain
import torch
import torch.nn as nn

class SeqClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, n_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=2,
                            dropout=0.3, bidirectional=True, batch_first=True)
        self.head = nn.Linear(2 * hidden_dim, n_classes)

    def forward(self, x):
        x = self.embed(x)
        _, (h, _) = self.lstm(x)
        # Concat the last hidden states from both directions of the top layer
        last = torch.cat([h[-2], h[-1]], dim=1)
        return self.head(last)
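The streaming use case from the lists above comes down to carrying the state between calls. The sketch below uses a unidirectional `nn.LSTM` with arbitrary sizes and checks that feeding one step at a time gives the same outputs as processing the whole sequence at once.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=10, hidden_size=32, batch_first=True)
lstm.eval()
stream = torch.randn(1, 50, 10)   # one sequence of 50 steps, 10 features each

with torch.no_grad():
    full_out, _ = lstm(stream)            # offline: the whole sequence in one call

    state = None                          # None means "start from zeros"
    step_outs = []
    for t in range(stream.size(1)):
        x_t = stream[:, t:t + 1, :]       # a single timestep, shape (1, 1, 10)
        out_t, state = lstm(x_t, state)   # only (h, c) is kept between steps
        step_outs.append(out_t)
    streamed_out = torch.cat(step_outs, dim=1)

print(torch.allclose(full_out, streamed_out, atol=1e-5))
```

Memory per step is constant: the model keeps `state` and nothing else. A bidirectional model such as the classifier above cannot be streamed this way, because its backward pass needs the end of the sequence first.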
f, i, o: forget, input, output gates (each in [0, 1])
c_t: cell state, the long-term memory traveling along an additive path
h_t: hidden state, the short-term output
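BPTT (Backpropagation Through Time). Unroll the RNN over T timesteps into a feedforward network of depth T; apply standard backprop. Memory cost grows linearly with T because every step's activations must be kept. For long sequences, use truncated BPTT: carry the hidden state forward across the whole sequence, but backpropagate only within a fixed window, so gradients never flow past the window boundary.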
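The effect of truncation can be checked directly: detaching the hidden state between two windows stops gradients from reaching the earlier window while the forward values still carry over. The sketch uses a small `nn.RNN` with arbitrary sizes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
x = torch.randn(1, 6, 4, requires_grad=True)   # 6 timesteps, tracked for gradients
first, second = x[:, :3], x[:, 3:]             # two windows of 3 steps each

_, h = rnn(first)                  # forward through window 1
out, _ = rnn(second, h.detach())   # window 2 starts from h, but the graph is cut
out.sum().backward()

grad_first = x.grad[:, :3].abs().sum().item()
grad_second = x.grad[:, 3:].abs().sum().item()
print(grad_first)        # 0.0: no gradient reaches the first window
print(grad_second > 0)   # True: the second window is trained as usual
```

Removing `.detach()` would let the loss on the second window send gradients into the first, at the cost of keeping the first window's activations in memory.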
Exploding gradients. Even with LSTMs, gradients can blow up. The standard fix is gradient clipping: if the gradient's norm exceeds a threshold, rescale the gradient so its norm equals the threshold, keeping its direction. For plain SGD this bounds the size of each update.
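The rescaling can be written by hand in a few lines. This is a sketch of the idea behind norm-based clipping, not PyTorch's own implementation; the function name `clip_by_global_norm` and the toy gradients are assumptions for illustration.

```python
import torch

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale gradient tensors so their combined L2 norm is at most max_norm."""
    total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads)).item()
    # Shrink only when the norm is too large; never scale gradients up
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm

# Toy gradients for two parameter tensors; combined norm is sqrt(84)
grads = [torch.full((3,), 4.0), torch.full((4,), 3.0)]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = torch.sqrt(sum((g ** 2).sum() for g in clipped)).item()
print(round(norm_before, 3), round(norm_after, 3))
```

All tensors are scaled by the same factor, so the direction of the overall gradient is unchanged. In practice `torch.nn.utils.clip_grad_norm_` does this in place on a model's parameters.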
Attention as a fix. Bahdanau et al. (2014) added attention to seq2seq RNNs: instead of compressing the entire input into a single hidden state, let the decoder peek back at all encoder states with learned weights. This was so much better that it eventually replaced the RNN entirely (Vaswani et al., 2017).
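A rough NumPy sketch of one attention step in the additive style follows. The projection matrices `W_dec` and `W_enc`, the vector `v`, and all sizes are illustrative stand-ins for learned parameters, here filled with random values.

```python
import numpy as np

def additive_attention(dec_state, enc_states, W_dec, W_enc, v):
    """Score every encoder state against the decoder state; return their weighted mix."""
    # dec_state: (H,), enc_states: (T, H), W_dec and W_enc: (H, A), v: (A,)
    scores = np.tanh(dec_state @ W_dec + enc_states @ W_enc) @ v   # (T,)
    weights = np.exp(scores - scores.max())   # softmax, shifted for numerical stability
    weights = weights / weights.sum()
    context = weights @ enc_states            # (H,), a weighted average of encoder states
    return context, weights

rng = np.random.default_rng(0)
T, H, A = 6, 8, 5                     # source length, state size, attention size
enc_states = rng.normal(size=(T, H))  # one encoder hidden state per source token
dec_state = rng.normal(size=H)        # current decoder hidden state
W_dec = rng.normal(size=(H, A))
W_enc = rng.normal(size=(H, A))
v = rng.normal(size=A)

context, weights = additive_attention(dec_state, enc_states, W_dec, W_enc, v)
print(np.round(weights, 3), context.shape)
```

The decoder recomputes `weights` at every output step, so it can focus on different source positions each time instead of relying on one fixed summary vector.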
Modern RNN-like models. State Space Models (S4, Mamba) revisit RNN-style sequential processing but with carefully chosen state dynamics that enable parallel training and long-context handling. They are competitive with transformers on some tasks, and because the state has a fixed size, the cost of generating each token does not grow with context length.
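Why linear state dynamics allow parallel training can be shown with a scalar toy recurrence. This is only an illustration of the principle, not the parameterization used by S4 or Mamba; the coefficients `a` and `b` are arbitrary.

```python
import numpy as np

a, b = 0.9, 0.5                  # toy state decay and input scale
rng = np.random.default_rng(0)
x = rng.normal(size=32)          # a toy input sequence
T = len(x)

# Sequential view: an RNN-style loop, one step at a time
h_seq = np.zeros(T)
h = 0.0
for t in range(T):
    h = a * h + b * x[t]
    h_seq[t] = h

# Parallel view: with no nonlinearity in the recurrence, every h_t is a
# weighted sum of past inputs, i.e. a causal convolution with a fixed kernel
kernel = b * a ** np.arange(T)   # weight given to an input j steps in the past
h_par = np.convolve(x, kernel)[:T]

print(np.allclose(h_seq, h_par))  # True
```

A tanh inside the loop, as in a vanilla RNN, would break this equivalence, which is why classic RNNs have to be trained step by step.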
Why not just train a deep transformer? Three reasons RNNs still appear: (1) streaming inference is natural; (2) small models for edge devices; (3) when you can't afford the O(N²) attention. For most other settings, transformers won.
Reach for it when
Strict memory budget for very long sequences (state is bounded, attention isn't)
Causal real-time inference where parallel decoding doesn't help
Theoretical / interpretability work on dynamical systems
You're prototyping with modern SSMs (Mamba) and want a baseline
Skip it when
You can parallelize: transformers process all timesteps at once during training, which is typically much faster on GPUs
You have pretrained transformer weights for your domain
Need to attend to specific positions — explicit attention is better than implicit memory
Long-context with global dependencies
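import torch
import torch.nn as nn
import torch.nn.utils as utils

# Training loop with gradient clipping and truncated BPTT
model = nn.LSTM(input_size=128, hidden_size=512, num_layers=2, batch_first=True)
head = nn.Linear(512, 1)  # assumed per-step regression head; swap for your task
loss_fn = nn.MSELoss()    # assumed loss; swap for your task
params = list(model.parameters()) + list(head.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)
clip_norm = 1.0
truncate = 200  # max BPTT length

# Assumed: `long_sequences` yields (sequence, targets) pairs with shapes
# (B, T, 128) and (B, T, 1)
for sequence, targets in long_sequences:
    # Split into chunks of `truncate` steps; carry the state across chunks
    state = None
    chunks = zip(sequence.split(truncate, dim=1), targets.split(truncate, dim=1))
    for chunk, chunk_targets in chunks:
        out, state = model(chunk, state)
        # Detach to cut the BPTT graph here:
        state = tuple(s.detach() for s in state)
        loss = loss_fn(head(out), chunk_targets)
        optimizer.zero_grad()
        loss.backward()
        utils.clip_grad_norm_(params, clip_norm)
        optimizer.step()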