Key idea

The same little network applied at every step, carrying a memory of what it saw before. Read one token, update memory, move on. Read another token, update memory again. After the whole sequence, the memory summarizes everything you've seen.

Figure (interactive in the original): the hidden state evolving as each token is read.

The same RNN cell runs at every timestep; note the "same weights" label on the first cell. The horizontal arrows are the hidden state being passed forward. Try the "long sequence" with Vanilla RNN: by the time we reach "was", the contribution of the early tokens to the hidden state has decayed nearly to zero. The same repeated multiplication shrinks gradients during training, which is the classic vanishing gradient problem. Switch to LSTM-like and the memory survives much further.

RNNs were the default for sequences (language, speech, time series) for a decade until transformers came along. They have one key trick: the network's output at step t depends on a "hidden state" that was computed at step t−1. The hidden state is the model's memory.
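
To make the "hidden state is the memory" point concrete, here is a minimal sketch, assuming PyTorch's `nn.RNN` and random toy data (the sizes are arbitrary). It feeds a sequence in one call, then feeds the same sequence one step at a time while carrying the hidden state forward, and the two final states should agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
x = torch.randn(1, 6, 4)  # (batch, time, features), toy data

# Whole sequence in a single call
_, h_full = rnn(x)

# Same sequence, one timestep at a time, passing the hidden state along
h = None  # None means "start from a zero hidden state"
for t in range(x.size(1)):
    _, h = rnn(x[:, t:t + 1, :], h)

same = torch.allclose(h_full, h, atol=1e-5)
```

Everything the network knows about earlier tokens at step t has to be stored in `h`; nothing else is carried over.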

Vanilla RNNs are simple but forget quickly. LSTMs and GRUs add gating mechanisms — small networks that decide what to remember and what to forget — and can carry information across hundreds of timesteps.

Reach for it when

  • Short to medium-length sequences with limited compute
  • On-device inference where a transformer is too heavy
  • Streaming data where you process one step at a time
  • Time-series forecasting with strong temporal structure

Skip it when

  • Long sequences with long-range dependencies — transformers win
  • You have enough data + compute for a transformer
  • You need bidirectional context simultaneously
  • The task is in language / vision — pretrained transformers dominate
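import torch
import torch.nn as nn

# LSTM with 2 layers, hidden size 128
model = nn.LSTM(
    input_size=10,
    hidden_size=128,
    num_layers=2,
    batch_first=True,
)

# Placeholder batch: B=4 sequences, T=20 steps, 10 features per step
inputs = torch.randn(4, 20, 10)

# Inputs: (B, T, input_size); outputs: (B, T, hidden_size)
# Final h and c: (num_layers, B, hidden_size) each
outputs, (h_final, c_final) = model(inputs)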
Vanilla RNN recursion

$$ \mathbf{h}_t \;=\; \tanh\!\left(W_{xh}\, \mathbf{x}_t + W_{hh}\, \mathbf{h}_{t-1} + \mathbf{b}_h\right) $$

  • h_t: hidden state at time t
  • x_t: input at time t
  • W_xh, W_hh: input-to-hidden and hidden-to-hidden weights, shared across timesteps
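
The recursion can be checked by hand against PyTorch. This is a sketch assuming `nn.RNNCell` with its default tanh nonlinearity. PyTorch stores two bias vectors, so their sum plays the role of b_h here, and the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.RNNCell(input_size=4, hidden_size=8)  # tanh by default
x_t = torch.randn(3, 4)     # batch of 3 inputs at time t
h_prev = torch.randn(3, 8)  # hidden state from time t-1

# Map PyTorch's parameters onto the symbols in the formula
W_xh, W_hh = cell.weight_ih, cell.weight_hh
b_h = cell.bias_ih + cell.bias_hh

# h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h), written for row-vector batches
h_manual = torch.tanh(x_t @ W_xh.T + h_prev @ W_hh.T + b_h)
h_torch = cell(x_t, h_prev)
```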

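The vanishing gradient problem. Vanilla RNNs struggle with long sequences. Gradients propagated backward through time are multiplied at every step by the same recurrent matrix W_hh (and by the tanh derivative, which is at most 1). Depending on the size of that matrix, they either shrink toward zero (vanishing) or blow up (exploding). Either way, the training signal for long-range dependencies is lost.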
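
A small numerical sketch of the vanishing case. It assumes a recurrent matrix whose singular values are all 0.5 (an illustrative choice, not a typical trained value) and drops the input term to isolate the recurrence. It measures how much gradient reaches the initial state after T steps.

```python
import torch

torch.manual_seed(0)
H = 16
# Orthogonal matrix scaled by 0.5, so every singular value of W_hh is 0.5
Q, _ = torch.linalg.qr(torch.randn(H, H))
W_hh = 0.5 * Q

def grad_norm_wrt_h0(T):
    """Norm of d(sum of h_T)/d(h_0) after T recurrent steps."""
    h0 = torch.randn(H, requires_grad=True)
    h = h0
    for _ in range(T):
        h = torch.tanh(W_hh @ h)  # input term omitted on purpose
    (g,) = torch.autograd.grad(h.sum(), h0)
    return g.norm().item()

norms = {T: grad_norm_wrt_h0(T) for T in (1, 5, 20, 50)}
for T, n in norms.items():
    print(f"T={T:3d}  grad norm = {n:.3e}")
```

Each backward step multiplies the gradient by a Jacobian of norm at most 0.5, so the norm falls at least geometrically with T. Scaling the matrix above 1 tends to produce the exploding case instead.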

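LSTM (Hochreiter & Schmidhuber, 1997). The now-standard form adds three gates (forget, input, output; the forget gate was a later addition by Gers et al., 2000) and a separate "cell state" that travels along its own additive path. The additive path is the key trick: gradients flowing back along it are scaled only elementwise by the forget gate, so they avoid the repeated matrix multiplication that collapses them in a vanilla RNN.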

GRU (Cho et al., 2014). Simpler than LSTM: two gates (update and reset) and a single state. Performance is often comparable to LSTM, with fewer parameters per unit and therefore typically faster training.
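
One way to see the size difference is to count parameters. This sketch assumes PyTorch's single-layer `nn.RNN`, `nn.GRU`, and `nn.LSTM` with arbitrary sizes. An LSTM has four weight blocks (three gates plus the candidate), a GRU has three (two gates plus the candidate), and a vanilla RNN has one.

```python
import torch.nn as nn

def n_params(module):
    """Total number of trainable scalars in a module."""
    return sum(p.numel() for p in module.parameters())

sizes = dict(input_size=64, hidden_size=128)  # arbitrary illustrative sizes
counts = {
    "rnn": n_params(nn.RNN(**sizes)),
    "gru": n_params(nn.GRU(**sizes)),
    "lstm": n_params(nn.LSTM(**sizes)),
}
for name, count in counts.items():
    print(f"{name:5s} {count:7d} parameters")
```

With identical sizes, the GRU ends up at three times the vanilla RNN's parameter count and the LSTM at four times.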

Bidirectional RNNs. Run one RNN left-to-right and another right-to-left; concatenate hidden states. Captures both past and future context — useful when you can see the whole sequence at once (e.g. translation, NER).
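
The concatenation is easiest to see from shapes. This is a sketch assuming `nn.GRU` with `bidirectional=True` and toy sizes. The forward direction has read the whole sequence at the last timestep, and the backward direction has read it at the first timestep.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 8
gru = nn.GRU(input_size=4, hidden_size=H, bidirectional=True, batch_first=True)
x = torch.randn(2, 5, 4)  # (batch, time, features), toy data

out, h_n = gru(x)  # out: (2, 5, 2*H); h_n: (2 directions, batch, H)

# Per-timestep outputs are [forward half | backward half]
fwd_final = out[:, -1, :H]  # forward pass after reading steps 0..4
bwd_final = out[:, 0, H:]   # backward pass after reading steps 4..0
```

Here `fwd_final` matches `h_n[0]` and `bwd_final` matches `h_n[1]`, which is why a classifier typically concatenates the two final states rather than taking `out[:, -1]` alone.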

Sequence-to-sequence. Use one RNN as an encoder (consume the input, produce a final hidden state) and another as a decoder (start from that hidden state, generate the output). The bottleneck of the final hidden state motivated attention — and attention motivated transformers.
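
A minimal encoder-decoder sketch, assuming GRUs, teacher forcing (the decoder is fed the target tokens), and made-up vocabulary sizes. The class name `Seq2Seq` and all dimensions are illustrative, not taken from any particular system.

```python
import torch
import torch.nn as nn

class Seq2Seq(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, dim=32):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab, dim)
        self.tgt_embed = nn.Embedding(tgt_vocab, dim)
        self.encoder = nn.GRU(dim, dim, batch_first=True)
        self.decoder = nn.GRU(dim, dim, batch_first=True)
        self.out = nn.Linear(dim, tgt_vocab)

    def forward(self, src, tgt_in):
        # Encoder compresses the whole input into h: (1, B, dim), the bottleneck
        _, h = self.encoder(self.src_embed(src))
        # Decoder starts from that state and reads the shifted target tokens
        dec_out, _ = self.decoder(self.tgt_embed(tgt_in), h)
        return self.out(dec_out)  # logits: (B, T_tgt, tgt_vocab)

model = Seq2Seq(src_vocab=50, tgt_vocab=60)
src = torch.randint(0, 50, (4, 7))     # 4 source sequences of length 7
tgt_in = torch.randint(0, 60, (4, 5))  # 4 target prefixes of length 5
logits = model(src, tgt_in)
```

Whatever the source length, the decoder only ever sees the single vector `h`, which is the bottleneck described above.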

Reach for it when

  • Time-series forecasting with limited compute
  • Online / streaming classification (process one step at a time)
  • Small-data sequence tasks where a transformer overfits
  • Edge inference — LSTMs are much smaller than transformers

Skip it when

  • You can throw compute at it — a transformer will dominate
  • Very long sequences (>1000 tokens) with non-local structure
  • You want to parallelize across timesteps (RNNs are sequential)
  • Pretrained transformer encoders exist for your domain
import torch
import torch.nn as nn

class SeqClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, n_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm  = nn.LSTM(embed_dim, hidden_dim, num_layers=2,
                             dropout=0.3, bidirectional=True, batch_first=True)
        self.head  = nn.Linear(2 * hidden_dim, n_classes)

    def forward(self, x):
        x = self.embed(x)
        _, (h, _) = self.lstm(x)
        # Concat the last hidden states from both directions of the top layer
        last = torch.cat([h[-2], h[-1]], dim=1)
        return self.head(last)
LSTM cell

$$ \begin{aligned} \mathbf{f}_t &= \sigma(W_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f) \\ \mathbf{i}_t &= \sigma(W_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i) \\ \mathbf{o}_t &= \sigma(W_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o) \\ \tilde{\mathbf{c}}_t &= \tanh(W_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c) \\ \mathbf{c}_t &= \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t \\ \mathbf{h}_t &= \mathbf{o}_t \odot \tanh(\mathbf{c}_t) \end{aligned} $$

  • f, i, o: forget, input, output gates (each entry in [0, 1])
  • c_t: cell state, the long-term memory traveling along an additive path
  • h_t: hidden state, the short-term output
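
The cell equations can be written out by hand and compared with PyTorch. This sketch assumes `nn.LSTMCell`, which stacks the four weight blocks in the order input, forget, candidate, output and keeps separate input and hidden weights instead of one matrix over the concatenation [h, x]. The two forms are equivalent.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 8
cell = nn.LSTMCell(input_size=4, hidden_size=H)
x_t = torch.randn(3, 4)
h_prev, c_prev = torch.randn(3, H), torch.randn(3, H)

# All four pre-activations at once: W [h, x] + b, split into input/hidden parts
pre = (x_t @ cell.weight_ih.T + cell.bias_ih
       + h_prev @ cell.weight_hh.T + cell.bias_hh)
i_pre, f_pre, c_pre, o_pre = pre.chunk(4, dim=1)  # PyTorch's block order

f_t = torch.sigmoid(f_pre)          # forget gate
i_t = torch.sigmoid(i_pre)          # input gate
o_t = torch.sigmoid(o_pre)          # output gate
c_tilde = torch.tanh(c_pre)         # candidate cell state
c_t = f_t * c_prev + i_t * c_tilde  # additive cell-state update
h_t = o_t * torch.tanh(c_t)         # exposed hidden state

h_ref, c_ref = cell(x_t, (h_prev, c_prev))
```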

BPTT (Backpropagation Through Time). Unroll the RNN over T timesteps into a feedforward network of depth T; apply standard backprop. Memory cost grows linearly with T. For long sequences, use truncated BPTT: carry the hidden state forward across the whole sequence, but backpropagate only within a fixed-length window, so gradients never flow past the window boundary.

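Exploding gradients. Even with LSTMs, gradients can blow up. The standard fix is gradient clipping: if the gradient norm exceeds a threshold, rescale the gradient so that its norm equals the threshold. The direction is preserved; for plain SGD this bounds the step size at learning rate × threshold.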
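
A short sketch of norm clipping, assuming PyTorch's `clip_grad_norm_` and a toy linear model with a deliberately inflated loss. The helper `total_grad_norm` is defined here for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 8)
loss = (model(torch.randn(4, 8)) * 100.0).pow(2).sum()  # inflated on purpose
loss.backward()

def total_grad_norm(params):
    """L2 norm of all gradients treated as one long vector."""
    return torch.sqrt(sum(p.grad.pow(2).sum() for p in params))

params = list(model.parameters())
before = total_grad_norm(params).item()
# Rescales gradients in place; returns the norm measured before clipping
returned = nn.utils.clip_grad_norm_(params, max_norm=1.0)
after = total_grad_norm(params).item()
print(f"before = {before:.2f}, after = {after:.4f}")
```

All gradients are scaled by the same factor, so the update direction is unchanged and only its length shrinks.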

Attention as a fix. Bahdanau et al. (2014) added attention to seq2seq RNNs: instead of compressing the entire input into a single hidden state, let the decoder peek back at all encoder states with learned weights. This was so much better that it eventually replaced the RNN entirely (Vaswani et al., 2017).
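
The "peek back at all encoder states" step can be sketched in a few lines. For brevity this uses dot-product scoring, which is a substitution: Bahdanau et al. scored each position with a small additive network. Tensor sizes and names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T_src, H = 2, 6, 16
encoder_states = torch.randn(B, T_src, H)  # one vector per source position
decoder_state = torch.randn(B, H)          # current decoder hidden state

# Score every source position against the decoder state: (B, T_src)
scores = torch.bmm(encoder_states, decoder_state.unsqueeze(2)).squeeze(2)
# Normalize the scores into attention weights that sum to 1 per example
weights = F.softmax(scores, dim=1)
# Context vector: weighted average of the encoder states, (B, H)
context = torch.bmm(weights.unsqueeze(1), encoder_states).squeeze(1)
```

The decoder then conditions on `context` at each output step instead of relying on a single final encoder state.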

Modern RNN-like models. State Space Models (S4, Mamba) revisit RNN-style sequential processing but with carefully chosen state dynamics that enable parallel training and long-context handling. They are competitive with transformers on some tasks, and at inference their per-token cost stays constant in sequence length because the state has a fixed size.
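
Why can such models train in parallel? A toy illustration, assuming a scalar linear recurrence h_t = a·h_{t-1} + b·x_t with made-up constants. This is not S4 or Mamba, only the underlying idea: without a nonlinearity between steps, the recurrence unrolls into a weighted sum that can be computed for all timesteps at once.

```python
import torch

torch.manual_seed(0)
T = 32
a, b = 0.9, 0.5  # illustrative state-transition and input coefficients
x = torch.randn(T)

# Sequential form: one step after another, like an RNN
h = torch.zeros(())
steps = []
for t in range(T):
    h = a * h + b * x[t]
    steps.append(h)
seq = torch.stack(steps)

# Parallel form: h_t = sum over s <= t of b * a^(t-s) * x_s
k = torch.arange(T)
kernel = b * torch.tensor(a) ** k.float()  # kernel[j] = b * a^j
lag = k.unsqueeze(1) - k.unsqueeze(0)      # lag[t, s] = t - s
M = torch.where(lag >= 0, kernel[lag.clamp(min=0)], torch.zeros(()))
par = M @ x                                # every timestep in one matmul
```

Both forms produce the same states. The sequential one suits streaming inference and the parallel one suits training on whole sequences.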

Why not just train a deep transformer? Three reasons RNNs still appear: (1) streaming inference is natural; (2) small models for edge devices; (3) when you can't afford the O(N²) attention. For most other settings, transformers won.

Reach for it when

  • Strict memory budget for very long sequences (state is bounded, attention isn't)
  • Causal real-time inference where parallel decoding doesn't help
  • Theoretical / interpretability work on dynamical systems
  • You're prototyping with modern SSMs (Mamba) and want a baseline

Skip it when

  • You can parallelize: transformers process all timesteps of a sequence at once during training, which is usually much faster on GPUs
  • You have pretrained transformer weights for your domain
  • Need to attend to specific positions — explicit attention is better than implicit memory
  • Long-context with global dependencies
import torch
import torch.nn as nn
import torch.nn.utils as utils

# Training loop with gradient clipping and truncated BPTT.
# Placeholder data: each item is (inputs (B, T, 128), targets (B, T, 512)).
long_sequences = [(torch.randn(2, 1000, 128), torch.randn(2, 1000, 512))]

model     = nn.LSTM(input_size=128, hidden_size=512, num_layers=2, batch_first=True)
loss_fn   = nn.MSELoss()   # placeholder loss; swap in whatever fits the task
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
clip_norm = 1.0
truncate  = 200    # max BPTT length

for sequence, targets in long_sequences:
    # Split into chunks of `truncate` steps; carry the state across chunks
    state = None
    chunks = zip(sequence.split(truncate, dim=1), targets.split(truncate, dim=1))
    for chunk, target in chunks:
        out, state = model(chunk, state)
        # Detach so gradients stop at the chunk boundary:
        state = tuple(s.detach() for s in state)
        loss  = loss_fn(out, target)
        optimizer.zero_grad()
        loss.backward()
        utils.clip_grad_norm_(model.parameters(), clip_norm)
        optimizer.step()