Attention-based architectures — the backbone of modern LLMs, vision models, and almost everything else.
Key idea
Look at everything at once, weighted by relevance. Instead of processing tokens one at a time like an RNN, a transformer looks at every token in the sequence and asks: which other tokens matter for understanding this one? That's attention.
[Interactive demo: pick an attention head, then click any token to make it the query.]
The four heads above show the kinds of patterns real transformers actually learn: simple positional shifts, plus richer linguistic relationships like adjective→noun and verb→subject. A real model has dozens of these per layer, all running in parallel, and the next layer can compose their outputs. That composition is most of what makes transformers work. The explainers below run real models — go deeper there.
If an RNN is reading a book one word at a time and trying to remember the relevant parts, a transformer is photocopying every page and laying them out on a table: every word can directly look at every other word. This costs more per token, because every word is compared against every other with no shortcut, but it is much better at capturing long-range relationships.
The Transformer was introduced in the 2017 paper "Attention Is All You Need" (Vaswani et al.). Nearly every modern LLM (GPT, Claude, Gemini, Llama) is a transformer. So are most modern vision models (ViT, DiT), speech models, and code models.
Reach for it when
Language modelling, translation, summarization
Long-range dependencies in any sequence
You can afford O(N²) attention or use an efficient variant
Pretrained weights exist (and they do, for almost everything)
Skip it when
Very small data and no pretrained model
Strict memory budget — attention is quadratic in sequence length
Highly structured sequential output where an RNN suffices
You need true streaming with bounded state
from transformers import AutoModel, AutoTokenizer
tok = AutoTokenizer.from_pretrained("bert-base-uncased")
mdl = AutoModel.from_pretrained("bert-base-uncased")
inputs = tok("Transformers see every token at once.", return_tensors="pt")
outputs = mdl(**inputs)
# outputs.last_hidden_state: (1, seq_len, 768) — contextualized embeddings
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
Q, K, V: queries, keys, values; each token is projected to three vectors
d_k: dimension of the keys; the √d_k scaling prevents softmax saturation
Row i of the output = weighted sum of values, with weights from how much query i matches each key
Self-attention. For each token, compute a query, key, and value vector (all three are linear projections of the token's embedding). Each token's output is a weighted average of all values, with weights given by how well its query matches each key.
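The description above can be sketched in a few lines of NumPy. This is a minimal single-head illustration, not any library's implementation: the function names, the random projection matrices, and the sizes (5 tokens, d_model = 16, d_k = 8) are all assumptions chosen for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability; the result is unchanged.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention over x of shape (seq_len, d_model)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v        # three linear projections per token
    scores = q @ k.T / np.sqrt(k.shape[-1])    # (seq_len, seq_len) query-key match
    weights = softmax(scores, axis=-1)         # each row sums to 1
    return weights @ v, weights                # row i = weighted average of values

# Hypothetical sizes and random weights, for illustration only.
rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 16, 8
x = rng.normal(size=(seq_len, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.sum(axis=-1))
```

Each row of `weights` is a probability distribution over the sequence, which is what makes the output a weighted average rather than an arbitrary linear mix.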
Multi-head attention. Run several attention "heads" in parallel — each with its own Q/K/V projections — then concatenate. Different heads learn to attend to different patterns (syntax, coreference, position, …).
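A rough NumPy sketch of the head split and concatenation follows. It assumes d_model divides evenly by the number of heads and applies a final output projection `w_o` after concatenation, as in the original Transformer paper; the function name, sizes, and random weights are illustrative only.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, w_q, w_k, w_v, w_o, n_heads):
    """x: (seq_len, d_model); every weight matrix: (d_model, d_model)."""
    seq_len, d_model = x.shape
    d_head = d_model // n_heads

    def split(t):
        # (seq_len, d_model) -> (n_heads, seq_len, d_head)
        return t.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)

    q, k, v = split(x @ w_q), split(x @ w_k), split(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)   # (n_heads, seq_len, seq_len)
    weights = softmax(scores, axis=-1)                     # one attention map per head
    heads = weights @ v                                    # (n_heads, seq_len, d_head)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ w_o, weights

# Hypothetical sizes: 6 tokens, d_model = 32, 4 heads of width 8.
rng = np.random.default_rng(0)
x = rng.normal(size=(6, 32))
w_q, w_k, w_v, w_o = (rng.normal(size=(32, 32)) * 0.1 for _ in range(4))
out, weights = multi_head_attention(x, w_q, w_k, w_v, w_o, n_heads=4)
print(out.shape, weights.shape)
```

Slicing one wide projection into heads is equivalent to giving each head its own smaller Q/K/V projection, and lets all heads run as one batched matrix multiply.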
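Positional encoding. Self-attention on its own is permutation-equivariant: shuffle the input tokens and the outputs are shuffled the same way, so without help the model can't tell "the cat sat on the mat" from "the mat sat on the cat". The classic fix is to add a position vector to each token's embedding; the original Transformer used fixed sinusoidal encodings. Modern models usually inject position inside attention instead, with rotary positional embeddings (RoPE) or ALiBi.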
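To make the ordering problem concrete, the sketch below reverses a sequence and checks whether attention notices. It uses attention with identity projections (a simplification made for this example) and the fixed sinusoidal encoding from the original Transformer paper; `tokens` are random stand-ins for embeddings, and d_model is assumed to be even.

```python
import numpy as np

def sinusoidal_positions(seq_len, d_model):
    """Fixed sinusoidal encoding: sin on even dimensions, cos on odd ones."""
    pos = np.arange(seq_len)[:, None]          # (seq_len, 1)
    i = np.arange(0, d_model, 2)[None, :]      # even dimension indices
    angles = pos / (10000 ** (i / d_model))
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

def attend(x):
    # Self-attention with identity projections, enough to show the symmetry.
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ x

rng = np.random.default_rng(0)
tokens = rng.normal(size=(6, 16))       # stand-in token embeddings
perm = np.array([5, 4, 3, 2, 1, 0])     # reverse the word order
pe = sinusoidal_positions(6, 16)

# Without positions, reordering the input only reorders the output.
no_pos_same = np.allclose(attend(tokens[perm]), attend(tokens)[perm])
# With positions added, the reordered sentence produces different outputs.
with_pos_same = np.allclose(attend(tokens[perm] + pe), attend(tokens + pe)[perm])
print(no_pos_same, with_pos_same)
```

The first check passes because every token sees the same set of keys and values regardless of order; the second fails because each token's representation now depends on where it sits.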
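Block structure. In the original (post-norm) layout, a transformer "block" is: multi-head attention → residual + layer norm → MLP (feed-forward) → residual + layer norm. Many modern models use the pre-norm variant instead, applying layer norm to the input of each sublayer and leaving the residual path untouched. Stack N blocks. The MLP's hidden layer is usually 4× wider than the model dimension.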
Encoder vs. decoder. Encoder-only (BERT, ViT): every token attends to every other; used for understanding tasks. Decoder-only (GPT, Llama): each token attends only to itself and earlier tokens; used for generation. Encoder-decoder (T5, original Transformer): two stacks for seq-to-seq tasks, with the decoder also attending to the encoder's output.
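In practice the encoder/decoder distinction mostly comes down to a mask on the attention scores. A minimal sketch, assuming random scores and a hypothetical helper name: setting future positions to negative infinity before the softmax gives them exactly zero weight.

```python
import numpy as np

def attention_weights(scores, causal=False):
    """Softmax over keys; with causal=True, position i cannot see positions > i."""
    if causal:
        n = scores.shape[-1]
        future = np.triu(np.ones((n, n), dtype=bool), k=1)   # True above the diagonal
        scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)                                       # exp(-inf) is exactly 0
    return w / w.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))
enc = attention_weights(scores)                # encoder-style: full matrix
dec = attention_weights(scores, causal=True)   # decoder-style: lower triangular
print(np.round(dec, 2))
```

The first token in the causal case can only attend to itself, so its weight row is a single 1; later rows spread their weight over a growing prefix.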
Reach for it when
You can use a pretrained model (almost always the right choice)
Long-range dependencies are the bottleneck
You need to attend to specific positions (retrieval, alignment)
Multimodal — transformers compose across modalities easily
Skip it when
Memory budget is tight and sequences are long — try a sparse / linear attention variant
You need a small, fast model — distilled / pruned versions exist but consider RNNs / SSMs
Very short sequences with strong local structure — CNN may suffice
The task doesn't need contextualization (e.g. independent classification)
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        # Pre-norm variant: norm before sublayer, residual around it
        h = self.ln1(x)  # normalize once, reuse as query, key, and value
        a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + a
        x = x + self.mlp(self.ln2(x))
        return x
Want the O(N²) tradeoffs, scaling laws, and modern variants?
Attention scales quadratically in sequence length N
MLPs scale linearly in N but quadratically in model width d
For long context, attention dominates; for small context with wide models, the MLP does
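A back-of-the-envelope FLOP count makes the crossover visible. The sketch below assumes standard multi-head attention, an MLP with a 4× hidden layer, and one multiply-add counted as 2 FLOPs; it ignores softmax, layer norm, and biases. The function name and the example sizes are hypothetical.

```python
def layer_flops(n_tokens, d_model, mlp_ratio=4):
    """Rough forward-pass FLOPs for one transformer block (multiply-add = 2 FLOPs)."""
    attn_matrix = 4 * n_tokens**2 * d_model          # Q @ K^T and weights @ V
    projections = 8 * n_tokens * d_model**2          # Q, K, V and output projections
    mlp = 4 * mlp_ratio * n_tokens * d_model**2      # two linear layers: d -> 4d -> d
    return {"attn_matrix": attn_matrix, "projections": projections, "mlp": mlp}

d_model = 4096
for n_tokens in (512, 4_096, 16_384, 131_072):
    cost = layer_flops(n_tokens, d_model)
    ratio = cost["attn_matrix"] / cost["mlp"]
    print(f"N={n_tokens:>7}: attention-matrix FLOPs / MLP FLOPs = {ratio:.2f}")
```

Under these assumptions the quadratic term overtakes the MLP once N exceeds about 4× the model width, which is why short-context wide models are MLP-bound and long-context models are attention-bound.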
The O(N²) problem. Standard attention computes an N×N matrix, so memory and compute are both quadratic in N. For 8K (8,192) tokens that's about 67M attention entries per head per layer. Solutions: sparse attention (Longformer, BigBird), linear-complexity approximations (Performer, Linformer), IO-aware exact attention (FlashAttention: compute is still O(N²), but the N×N matrix is never materialized, so memory is linear in N), or alternative architectures (state-space models such as Mamba).
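The memory side of the problem is simple arithmetic. A small sketch, assuming the full attention matrix is materialized and stored in fp16 (2 bytes per entry), and taking 8K to mean 8,192 tokens; the helper name is made up for this example.

```python
def attention_matrix_bytes(n_tokens, n_heads=1, bytes_per_entry=2):
    """Memory to materialize the N x N attention weights for one layer."""
    return n_tokens * n_tokens * n_heads * bytes_per_entry

for n_tokens in (2_048, 8_192, 32_768):
    entries = n_tokens * n_tokens
    mib = attention_matrix_bytes(n_tokens) / 2**20
    print(f"N={n_tokens:>6}: {entries:>13,} entries per head, {mib:,.0f} MiB in fp16")
```

Doubling the context quadruples this figure, and it is multiplied again by the number of heads and layers, which is why methods that avoid storing the matrix matter so much at long context.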
Scaling laws. Kaplan et al. (2020) and Hoffmann et al. (2022, "Chinchilla") empirically found that loss falls as a power law in compute, parameters, and tokens. The Chinchilla rule of thumb: for a fixed compute budget, the optimal token count is roughly 20× the parameter count. By that measure, most large pre-Chinchilla models were trained on far too few tokens for their size.
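The 20× rule can be turned into a quick budget calculator. This sketch combines it with the common approximation that training costs about 6 FLOPs per parameter per token (C ≈ 6·N·D), which is an assumption not stated in the text above; both are rules of thumb, not exact laws, and the function name is hypothetical.

```python
import math

def chinchilla_optimal(compute_flops, tokens_per_param=20.0):
    """Split a training budget using C ~ 6*N*D and D ~ 20*N (both rules of thumb)."""
    # Substituting D = r*N into C = 6*N*D gives N = sqrt(C / (6*r)).
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

# Budget equivalent to training a 70B-parameter model on 1.4T tokens.
budget = 6 * 70e9 * 1.4e12
n_params, n_tokens = chinchilla_optimal(budget)
print(f"params ~ {n_params:.3g}, tokens ~ {n_tokens:.3g}")
```

Because both parameters and tokens scale with the square root of compute under this rule, a 100× larger budget suggests a model about 10× bigger trained on about 10× more data.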
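RoPE (Rotary Positional Embedding). Encodes position by rotating query and key vectors in 2D subspaces, by an angle proportional to the token's position. Because the query-key dot product then depends only on the difference between the two positions, it handles relative position naturally. It tends to extend to longer contexts better than additive sinusoidal encodings, and is used in Llama, GPT-NeoX, and most other modern models.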
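The rotation idea fits in a short NumPy sketch. This version pairs adjacent dimensions (0,1), (2,3), and so on, and uses a base of 10000; real implementations differ in how they pair dimensions, so treat the layout, the function name, and the sizes here as assumptions for illustration.

```python
import numpy as np

def rope(x, positions, base=10000.0):
    """Rotate consecutive pairs of dimensions of x by position-dependent angles.

    x: (seq_len, d) with d even; positions: (seq_len,) integer positions.
    """
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)        # one frequency per 2D pair
    angles = positions[:, None] * freqs[None, :]     # (seq_len, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin        # standard 2D rotation
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q = rng.normal(size=(1, 8))
k = rng.normal(size=(1, 8))

# Same relative offset (3) at different absolute positions.
score_a = rope(q, np.array([5])) @ rope(k, np.array([2])).T
score_b = rope(q, np.array([105])) @ rope(k, np.array([102])).T
# A different relative offset (1).
score_c = rope(q, np.array([5])) @ rope(k, np.array([4])).T
print(np.allclose(score_a, score_b), np.allclose(score_a, score_c))
```

Since rotations preserve length, RoPE changes only the angle between query and key, and the attention score ends up as a function of how far apart the two tokens are rather than where they sit.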
Mixture of Experts (MoE). Replace the MLP with multiple "expert" MLPs and a router that picks the top-k for each token. Increases total parameters without proportionally increasing compute. Used in Mixtral, GShard, Switch Transformer.
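A toy version of top-k routing is sketched below. It assumes a linear router, tiny ReLU MLPs as experts, and gate weights computed by a softmax over only the selected experts' logits (one common choice; others exist). It loops over tokens for clarity instead of batching, and omits load-balancing losses entirely. All names and sizes are hypothetical.

```python
import numpy as np

def moe_layer(x, router_w, experts, top_k=2):
    """x: (n_tokens, d); router_w: (d, n_experts); experts: list of callables."""
    logits = x @ router_w                               # (n_tokens, n_experts)
    top = np.argsort(logits, axis=-1)[:, -top_k:]       # indices of the k best experts
    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        chosen = logits[t, top[t]]
        gates = np.exp(chosen - chosen.max())
        gates = gates / gates.sum()                     # softmax over the chosen k only
        for gate, e in zip(gates, top[t]):
            out[t] += gate * experts[e](x[t])           # only k experts run per token
    return out, top

def make_expert(w_in, w_out):
    # A tiny two-layer MLP with ReLU; the closure captures this expert's weights.
    return lambda h: np.maximum(h @ w_in, 0.0) @ w_out

rng = np.random.default_rng(0)
d, d_ff, n_experts = 8, 32, 4
experts = [make_expert(rng.normal(size=(d, d_ff)), rng.normal(size=(d_ff, d)))
           for _ in range(n_experts)]
x = rng.normal(size=(5, d))
router_w = rng.normal(size=(d, n_experts))
out, top = moe_layer(x, router_w, experts)
print(out.shape, top)
```

With 4 experts and top-2 routing, the layer holds four MLPs' worth of parameters but each token pays for only two, which is the parameter/compute decoupling described above.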
Inference tricks. KV cache (store keys/values from past tokens to skip recomputing). Quantization (int8 / int4 weights). Speculative decoding (run a small model to draft, large to verify). Continuous batching (group requests dynamically). These are how production LLM serving works.
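Of these, the KV cache is the easiest to show in code. The sketch below is a single-head, single-layer toy under assumed sizes, with a hypothetical `KVCache` class: each decoding step projects only the new token, appends its key and value, and attends over the stored ones. It then checks the result against recomputing causal attention over the whole prefix.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(x, w_q, w_k, w_v):
    """Reference: recompute attention over the full sequence with a causal mask."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    n = len(x)
    scores = np.where(np.triu(np.ones((n, n), dtype=bool), k=1), -np.inf, scores)
    return softmax(scores) @ v

class KVCache:
    """Keeps keys and values of past tokens so each step only projects one token."""
    def __init__(self, w_q, w_k, w_v):
        self.w_q, self.w_k, self.w_v = w_q, w_k, w_v
        self.keys, self.values = [], []

    def step(self, x_t):
        self.keys.append(x_t @ self.w_k)                    # new key, stored for later steps
        self.values.append(x_t @ self.w_v)
        k, v = np.stack(self.keys), np.stack(self.values)   # (t, d_k)
        q = x_t @ self.w_q
        return softmax(q @ k.T / np.sqrt(k.shape[-1])) @ v

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 16))
w_q, w_k, w_v = (rng.normal(size=(16, 8)) for _ in range(3))
cache = KVCache(w_q, w_k, w_v)
incremental = np.stack([cache.step(x_t) for x_t in x])
full = causal_attention(x, w_q, w_k, w_v)
print(np.allclose(incremental, full))
```

The cache trades memory for compute: it grows linearly with the number of generated tokens, which is why long-context serving is often limited by KV memory rather than by FLOPs.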
What attention really does. It's a differentiable lookup: queries retrieve from a key-value store. This unifies a surprising amount of ML: retrieval, memory networks, set transformers, even some classical algorithms. Anthropic's "A Mathematical Framework for Transformer Circuits" is one of the best modern sources for understanding attention mechanistically.
Reach for it when
Anywhere you have sequential data and enough compute
Multi-modal fusion — attention generalizes across modalities
You need to embed and search (encoder transformer for embeddings)
You want a single architecture template for many tasks
Skip it when
Strict memory budget for sequences > a few thousand tokens — use efficient attention or SSMs
Online / true streaming — recurrent or SSM variants are cheaper
Adversarial robustness with strong guarantees — easier in smaller architectures