Key idea

Look at everything at once, weighted by relevance. Instead of processing tokens one at a time like an RNN, a transformer looks at every token in the sequence and asks: which other tokens matter for understanding this one? That's attention.


Attention heads in trained transformers learn recognizable patterns: simple positional shifts (such as attending to the previous token), plus richer linguistic relationships like adjective→noun and verb→subject. A real model has many heads per layer (12 in BERT-base, dozens in larger models), all running in parallel, and the next layer can compose their outputs. That composition accounts for much of what makes transformers work.

If an RNN is like reading a book one word at a time while trying to remember the relevant parts, a transformer is like photocopying every page and laying them all out on a table: every word can directly look at every other word. The price is that every token is compared with every other token, so cost grows quadratically with sequence length, but long-range relationships become much easier to capture.

The Transformer was introduced in the 2017 paper "Attention Is All You Need" (Vaswani et al.). Nearly every modern LLM (GPT, Claude, Gemini, Llama) is a transformer, as are many modern vision models (ViT, DiT), speech models, and code models.

Reach for it when

  • Language modelling, translation, summarization
  • Long-range dependencies in any sequence
  • You can afford O(N²) attention or use an efficient variant
  • Pretrained weights exist (and they do, for almost everything)

Skip it when

  • Very small data and no pretrained model
  • Strict memory budget — attention is quadratic in sequence length
  • Highly structured sequential output where an RNN suffices
  • You need true streaming with bounded state

```python
from transformers import AutoModel, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
mdl = AutoModel.from_pretrained("bert-base-uncased")

inputs = tok("Transformers see every token at once.", return_tensors="pt")
outputs = mdl(**inputs)
# outputs.last_hidden_state: (1, seq_len, 768), one contextualized embedding per token
```
Want self-attention math and positional encoding?
Scaled dot-product attention
$$ \text{Attention}(Q, K, V) \;=\; \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V $$
  • Q, K, V: queries, keys, and values; each token is projected to three vectors
  • d_k: dimension of the keys; dividing by √d_k keeps the dot products from growing with d_k and saturating the softmax
  • Row i of the output is a weighted sum of the values, with weights given by how well query i matches each key
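To make the formula concrete, here is a minimal NumPy sketch (NumPy is chosen for readability; the other examples on this page use PyTorch). It computes the weights and output exactly as in the equation, then checks the √d_k point: for vectors with independent unit-variance entries, raw dot products have a standard deviation near √d_k, and dividing by √d_k brings it back toward 1.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability; the result is unchanged
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    """Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v) -> output (n_q, d_v), weights (n_q, n_k)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # how well each query matches each key
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights          # row i = weighted sum of the value vectors

rng = np.random.default_rng(0)
Q = rng.standard_normal((4, 64))
K = rng.standard_normal((6, 64))
V = rng.standard_normal((6, 32))
out, w = scaled_dot_product_attention(Q, K, V)
print(out.shape, w.shape)  # (4, 32) (4, 6)

# Raw scores spread out like sqrt(d_k) = 8; the scaled scores are 8x narrower
raw = (Q @ K.T).std()
scaled = (Q @ K.T / np.sqrt(64)).std()
```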

Self-attention. For each token, compute a query, key, and value vector (all three are linear projections of the token's embedding). Each token's output is a weighted average of all values, with weights given by how well its query matches each key.
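A sketch of self-attention over a single sequence, assuming random matrices stand in for the learned projections (W_q, W_k, W_v are illustrative names, not from any library). The key point is that Q, K, and V are all computed from the same token embeddings X.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """X: (seq_len, d_model) token embeddings; W_*: (d_model, d_k) projection matrices."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v                  # every token gets a query, key, and value
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]))    # (seq_len, seq_len) match scores
    return weights @ V, weights                          # each output row: weighted average of values

rng = np.random.default_rng(1)
seq_len, d_model, d_k = 5, 16, 8
X = rng.standard_normal((seq_len, d_model))
# Hypothetical random weights standing in for learned ones
W_q, W_k, W_v = (rng.standard_normal((d_model, d_k)) / np.sqrt(d_model) for _ in range(3))
out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape)  # (5, 8)
```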

Multi-head attention. Run several attention "heads" in parallel, each with its own Q/K/V projections, then concatenate their outputs and apply a final linear projection. Different heads tend to learn different patterns (syntax, coreference, position, …).
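A minimal NumPy sketch of multi-head attention, assuming random stand-in weights. Projecting with one d_model×d_model matrix and slicing the result into n_heads column blocks is equivalent to giving each head its own smaller projection; the output projection W_o after concatenation follows the original Transformer paper.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """X: (T, d_model); W_q, W_k, W_v, W_o: (d_model, d_model). Returns (T, d_model)."""
    T, d_model = X.shape
    d_head = d_model // n_heads

    def split(M):
        # (T, d_model) -> (n_heads, T, d_head): each head gets its own slice of the projection
        return M.reshape(T, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # (n_heads, T, T): one pattern per head
    heads = softmax(scores) @ V                            # (n_heads, T, d_head)
    concat = heads.transpose(1, 0, 2).reshape(T, d_model)  # concatenate the heads
    return concat @ W_o                                    # final output projection

rng = np.random.default_rng(2)
T, d_model, n_heads = 6, 32, 4
X = rng.standard_normal((T, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
Y = multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads)
print(Y.shape)  # (6, 32)
```

With n_heads=1 and W_o set to the identity, this reduces to plain single-head self-attention.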

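Positional encoding. Self-attention is permutation-equivariant: shuffle the input tokens and the outputs are simply shuffled the same way, so without help the model can't tell "the cat sat on the mat" from "the mat sat on the cat". The fix is to inject position information. The original Transformer added a fixed sinusoidal position vector to each token's embedding; many modern models instead use rotary positional embeddings (RoPE) or ALiBi, which encode position inside the attention computation itself.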
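A sketch of the original fixed sinusoidal encoding, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)), added to the token embeddings. The random embeddings here are placeholders for learned ones.

```python
import numpy as np

def sinusoidal_positions(seq_len, d_model):
    """Fixed encodings from the original Transformer; d_model is assumed to be even."""
    pos = np.arange(seq_len)[:, None]                # (seq_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]        # even dimension indices 2i
    angles = pos / np.power(10000.0, two_i / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)                     # even dimensions: sine
    pe[:, 1::2] = np.cos(angles)                     # odd dimensions: cosine
    return pe

# Position vectors are added (not concatenated) to the embeddings before the first block
rng = np.random.default_rng(3)
embeddings = rng.standard_normal((10, 16))           # placeholder token embeddings
x = embeddings + sinusoidal_positions(10, 16)
```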

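Block structure. In the original Transformer, a block is: multi-head attention → residual + layer norm → MLP (feed-forward) → residual + layer norm. Stack N blocks. Many modern models (e.g. GPT-2 and Llama) use the "pre-norm" ordering instead, normalizing before each sublayer, as in the code below. The MLP's hidden layer is usually 4× the model dimension.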
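Because the MLP is 4× wider, it holds most of a block's weights. A back-of-the-envelope count, ignoring biases and layer-norm parameters (a simplifying assumption):

```python
def block_params(d_model, ff_mult=4):
    """Approximate weight count of one transformer block, ignoring biases and layer norms."""
    attn = 4 * d_model * d_model              # W_q, W_k, W_v, W_o, each d_model x d_model
    mlp = 2 * d_model * (ff_mult * d_model)   # up-projection and down-projection
    return attn + mlp                         # = 12 * d_model**2 when ff_mult = 4

print(block_params(512))   # 3145728
print(block_params(768))   # 7077888
```

With d_model = 768, as in BERT-base, that is about 7.1M weights per block, so BERT-base's 12 blocks account for roughly 85M of its roughly 110M parameters; most of the rest sits in the token-embedding table.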

Encoder vs. decoder. Encoder-only (BERT, ViT): every token attends to every other token; used for understanding tasks. Decoder-only (GPT, Llama): each token attends only to itself and earlier tokens (a causal mask); used for generation. Encoder-decoder (T5, the original Transformer): two stacks for sequence-to-sequence tasks, where the decoder also cross-attends to the encoder's outputs.
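The encoder/decoder difference comes down to the attention mask. A minimal NumPy sketch, assuming random queries and keys: setting future positions' scores to -inf before the softmax gives them exactly zero weight. Cross-attention (queries from the decoder, keys and values from the encoder) is not shown.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_weights(Q, K, causal):
    T = Q.shape[0]
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    if causal:
        # Decoder-style: token i may only see positions j <= i
        future = np.triu(np.ones((T, T), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)   # exp(-inf) = 0 after the softmax
    return softmax(scores)

rng = np.random.default_rng(4)
Q = rng.standard_normal((5, 8))
K = rng.standard_normal((5, 8))
enc = attention_weights(Q, K, causal=False)   # encoder: every token sees every token
dec = attention_weights(Q, K, causal=True)    # decoder: lower-triangular weights
```

The first decoder token can only attend to itself, so its weight on position 0 is exactly 1.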

Reach for it when

  • You can use a pretrained model (almost always the right choice)
  • Long-range dependencies are the bottleneck
  • You need to attend to specific positions (retrieval, alignment)
  • Multimodal — transformers compose across modalities easily

Skip it when

  • Memory budget is tight and sequences are long — try a sparse / linear attention variant
  • You need a small, fast model — distilled / pruned versions exist but consider RNNs / SSMs
  • Very short sequences with strong local structure — CNN may suffice
  • The task doesn't need contextualization (e.g. classifying independent, fixed-size feature vectors)
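```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln1  = nn.LayerNorm(d_model)
        self.ln2  = nn.LayerNorm(d_model)
        self.mlp  = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        # Pre-norm variant: norm before each sublayer, residual around it
        h = self.ln1(x)  # normalize once, reuse as query, key, and value
        a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
        x = x + a
        x = x + self.mlp(self.ln2(x))
        return x

# Usage: a decoder-style (causal) block over 2 sequences of 10 tokens.
# For nn.MultiheadAttention, True in a boolean mask means "may NOT attend".
block = TransformerBlock()
x = torch.randn(2, 10, 512)
causal = torch.triu(torch.ones(10, 10), diagonal=1).bool()
y = block(x, mask=causal)  # (2, 10, 512)
```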
Want the O(N²) tradeoffs, scaling laws, and modern variants?
Computational cost
$$ \text{FLOPs}_{\text{attention}} \sim \mathcal{O}(N^2 \cdot d), \quad \text{FLOPs}_{\text{MLP}} \sim \mathcal{O}(N \cdot d^2) $$
  • Attention scales quadratically in sequence length N
  • MLPs scale linearly in N but quadratically in model width d
  • For long context, attention dominates; for small context with wide models, the MLP does
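A rough per-layer forward-pass estimate of where the crossover lies, counting about 2 FLOPs per multiply-add and ignoring the Q/K/V/output projections, softmax, and norms (simplifying assumptions). Under these assumptions the quadratic attention term equals the MLP term at N = 4d.

```python
def attention_score_flops(N, d):
    # Q K^T and (weights) V: two N x N x d matrix products, ~2 FLOPs per multiply-add
    return 2 * (2 * N * N * d)

def mlp_flops(N, d, ff_mult=4):
    # Up- and down-projection: two N x d x (ff_mult * d) matrix products
    return 2 * (2 * N * d * ff_mult * d)

d = 1024
for N in (512, 4096, 32768):
    ratio = attention_score_flops(N, d) / mlp_flops(N, d)
    print(N, round(ratio, 3))
# Prints:
# 512 0.125
# 4096 1.0
# 32768 8.0
```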

The O(N²) problem. Standard attention computes an N×N score matrix, so memory and compute are both quadratic in N. For 8K (8,192) tokens that's about 67 million attention entries per head per layer. Solutions: sparse attention (Longformer, BigBird), linear-complexity approximations (Performer, Linformer), exact kernels built on tiling and recomputation (FlashAttention: same O(N²) compute, but it never materializes the full matrix, so memory grows only linearly in N), or alternative architectures such as state-space models (SSMs, e.g. Mamba).
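The memory side of that arithmetic, assuming 2 bytes per entry (fp16/bf16) and a hypothetical 32-head layer:

```python
def attention_matrix_bytes(seq_len, n_heads=1, n_layers=1, bytes_per_entry=2):
    """Memory to materialize the full N x N score matrix (fp16/bf16 = 2 bytes per entry)."""
    return seq_len * seq_len * n_heads * n_layers * bytes_per_entry

print(8192 ** 2)                                          # 67108864 entries per head per layer
print(attention_matrix_bytes(8192) / 2**20)               # 128.0 MiB per head per layer
print(attention_matrix_bytes(8192, n_heads=32) / 2**30)   # 4.0 GiB for one 32-head layer
```

This full matrix is what FlashAttention avoids storing; it still performs the same O(N²) arithmetic.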

Scaling laws. Kaplan et al. (2020) and Hoffmann et al. (2022, "Chinchilla") empirically found that loss falls as a power law in compute, parameters, and training tokens. The "Chinchilla" rule of thumb for compute-optimal training: token count ≈ 20× parameter count. By this measure, most pre-Chinchilla models were severely under-trained on data.
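The rule of thumb turns into simple arithmetic. This sketch assumes the common approximation that training compute is C ≈ 6·N·D FLOPs for N parameters and D tokens, combined with D ≈ 20·N; the function names are illustrative.

```python
import math

TOKENS_PER_PARAM = 20        # Chinchilla rule of thumb
FLOPS_PER_PARAM_TOKEN = 6    # common approximation: training compute C ≈ 6 * N * D

def chinchilla_tokens(n_params):
    return TOKENS_PER_PARAM * n_params

def compute_optimal_params(compute_flops):
    # Solve C = 6 * N * (20 * N) for N
    return math.sqrt(compute_flops / (FLOPS_PER_PARAM_TOKEN * TOKENS_PER_PARAM))

print(chinchilla_tokens(70e9))   # 1400000000000.0, i.e. 1.4T tokens for 70B parameters
```

For comparison, GPT-3 (175B parameters) was trained on about 300B tokens, under 2 tokens per parameter, where the rule of thumb would suggest about 3.5T.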

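RoPE (Rotary Positional Embedding). Encodes position by rotating query and key vectors in 2D subspaces, by angles proportional to the token's position. Because rotations compose, the query-key dot product depends only on the relative offset between the two positions, and with position-interpolation tricks RoPE models can be stretched to longer contexts than they were trained on. Used in Llama, GPT-NeoX, and most recent open models.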
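A minimal NumPy sketch of RoPE on a single vector, assuming the base of 10000 from the RoPE paper and the interleaved pairing of dimensions (some libraries pair the first half with the second half instead, which amounts to a different choice of pairs). The check at the end shows the relative-position property: shifting both positions by the same amount leaves the score unchanged.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[2i], x[2i+1]) by angle pos * theta_i; len(x) must be even."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)       # one frequency per 2D subspace
    angle = pos * theta
    cos, sin = np.cos(angle), np.sin(angle)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin            # standard 2D rotation of each pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(5)
q, k = rng.standard_normal(64), rng.standard_normal(64)
# The score depends only on the offset (7 - 3 == 107 - 103)
s1 = rope(q, 7) @ rope(k, 3)
s2 = rope(q, 107) @ rope(k, 103)
```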

Mixture of Experts (MoE). Replace the MLP with multiple "expert" MLPs and a router that picks the top-k for each token. Increases total parameters without proportionally increasing compute. Used in Mixtral, GShard, Switch Transformer.
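A sketch of top-k routing for one MoE layer, under stated assumptions: random stand-in experts and router weights, a softmax over only the chosen k logits (Mixtral-style; routers differ in this detail), and no load-balancing loss, which real training setups add. Each token runs only k of the n_experts MLPs, which is why compute grows much more slowly than parameter count.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def moe_layer(X, W_router, experts, k=2):
    """X: (T, d). W_router: (d, n_experts). experts: callables mapping (d,) -> (d,)."""
    logits = X @ W_router                                      # (T, n_experts) router scores
    top = np.argsort(logits, axis=-1)[:, -k:]                  # the k best experts per token
    gates = softmax(np.take_along_axis(logits, top, axis=-1))  # renormalize over the chosen k
    out = np.zeros_like(X)
    for t in range(X.shape[0]):
        for j in range(k):
            # Only k experts run for this token; the others cost nothing
            out[t] += gates[t, j] * experts[top[t, j]](X[t])
    return out, top, gates

rng = np.random.default_rng(6)
d, n_experts, T = 16, 8, 5

def make_expert():
    # Hypothetical expert: a small random 2-layer ReLU MLP standing in for a trained one
    W1 = rng.standard_normal((d, 4 * d)) / np.sqrt(d)
    W2 = rng.standard_normal((4 * d, d)) / np.sqrt(4 * d)
    return lambda x: np.maximum(x @ W1, 0.0) @ W2

experts = [make_expert() for _ in range(n_experts)]
W_router = rng.standard_normal((d, n_experts))
X = rng.standard_normal((T, d))
Y, top, gates = moe_layer(X, W_router, experts, k=2)
```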

Inference tricks. KV cache (store keys and values from past tokens so they aren't recomputed at every generation step). Quantization (int8 / int4 weights). Speculative decoding (a small model drafts several tokens; the large model verifies them in a single pass). Continuous batching (add and remove requests from the running batch dynamically). Together, these underpin most production LLM serving.
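Quantization is the easiest of these to show in a few lines. A sketch of symmetric per-tensor int8 quantization (absmax scaling); production schemes usually use finer per-channel or per-group scales, and int4 packs two values per byte, neither of which is shown here.

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor absmax quantization: map [-max|w|, max|w|] onto [-127, 127]."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(7)
W = rng.standard_normal((256, 256)).astype(np.float32)
q, scale = quantize_int8(W)
W_hat = dequantize(q, scale)
print(W.nbytes // q.nbytes)          # 4: int8 uses a quarter of float32's memory
max_err = np.abs(W - W_hat).max()    # rounding error, at most about scale / 2
```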

What attention really does. It's a differentiable lookup: queries retrieve from a key-value store, blending the values according to how well each key matches. This view unifies a surprising amount of ML: retrieval, memory networks, set transformers, even some classical algorithms. Anthropic's "A Mathematical Framework for Transformer Circuits" (Elhage et al., 2021) is an excellent modern source for understanding attention mechanistically.
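The lookup view can be shown directly. In this toy sketch (all names and numbers are illustrative), one-hot keys act as dictionary entries and a hypothetical `sharpness` factor plays the role of an inverse temperature: small values blend every stored value, large values approach a hard lookup of the best-matching key, and every step stays differentiable.

```python
import numpy as np

def softmax(x):
    x = x - x.max()
    e = np.exp(x)
    return e / e.sum()

keys = np.eye(3)                           # three stored entries with one-hot keys
values = np.array([10.0, 20.0, 100.0])     # the payload stored under each key
query = np.array([0.1, 0.9, 0.2])          # most similar to key 1

def soft_lookup(query, sharpness):
    weights = softmax(sharpness * (keys @ query))   # similarity -> attention weights
    return weights @ values                         # blend of stored values

print(soft_lookup(query, 1.0))    # a smooth blend of all three values
print(soft_lookup(query, 50.0))   # ≈ 20.0: behaves like a hard lookup of key 1
```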

Reach for it when

  • Anywhere you have sequential data and enough compute
  • Multi-modal fusion — attention generalizes across modalities
  • You need to embed and search (encoder transformer for embeddings)
  • You want a single architecture template for many tasks

Skip it when

  • Strict memory budget for sequences > a few thousand tokens — use efficient attention or SSMs
  • Online / true streaming — recurrent or SSM variants are cheaper
  • Adversarial robustness with strong guarantees — easier in smaller architectures
  • Fully interpretable per-prediction explanations needed
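```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Causal self-attention with a KV cache (the inference-time pattern in LLMs)
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_head  = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, cache=None):
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (B, n_heads, T, d_head)
        reshape = lambda t: t.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        q, k, v = map(reshape, (q, k, v))

        if cache is not None:
            # Prepend keys/values from earlier steps so the new queries can attend to them
            k = torch.cat([cache[0], k], dim=2)
            v = torch.cat([cache[1], v], dim=2)
        new_cache = (k, v)

        # F.scaled_dot_product_attention can dispatch to fused (FlashAttention-style) kernels
        S = k.size(2)  # total key length: cached + new
        if S == T:
            # No cache: the usual square causal mask
            attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # With a cache, query i sits at absolute position S - T + i and may see keys
            # 0..S-T+i (for T == 1, every key). Here True means "may attend".
            mask = torch.ones(T, S, device=x.device).tril(diagonal=S - T).bool()
            attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, T, D)
        return self.out(attn_out), new_cache

# Usage: prefill a 5-token prompt, then decode one token while reusing the cache
attn = CausalSelfAttention(d_model=64, n_heads=4)
y, cache = attn(torch.randn(1, 5, 64))
y_next, cache = attn(torch.randn(1, 1, 64), cache=cache)  # cache now holds 6 positions
```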