Key idea

Each node updates itself based on its neighbours, like a social network of nodes passing notes. Every round, each node reads its neighbours' messages, combines them with its own state, and updates. After a few rounds, information from far-away nodes has rippled through.

After K rounds of message passing, each node's representation reflects information from nodes up to K hops away. That's why GNNs typically use only 2–4 layers. Beyond a few hops, the node representations start to over-smooth and every node ends up looking the same.
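Here is a concrete version of the K-hop claim. This minimal pure-Python sketch has no learned weights, and the function names are illustrative. It tracks which nodes a source node's signal has reached after each round, where one round moves the signal exactly one hop.

```python
from collections import defaultdict

def build_adjacency(edges):
    """Undirected adjacency sets from (u, v) pairs."""
    adj = defaultdict(set)
    for u, v in edges:
        adj[u].add(v)
        adj[v].add(u)
    return adj

def reach_per_round(edges, source, rounds):
    """Entry k of the result is the set of nodes reached after k rounds."""
    adj = build_adjacency(edges)
    reached = {source}
    history = [set(reached)]
    for _ in range(rounds):
        # Every node that already holds the signal passes it to all its neighbours.
        reached = reached | {n for u in reached for n in adj[u]}
        history.append(set(reached))
    return history

# Path graph 0 - 1 - 2 - 3 - 4: node 4 is 4 hops from node 0.
path = [(0, 1), (1, 2), (2, 3), (3, 4)]
for k, nodes in enumerate(reach_per_round(path, source=0, rounds=4)):
    print(k, sorted(nodes))
```

Node 4 only appears in round 4. A 2-layer GNN on this graph therefore never lets node 0 see node 4 at all.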

A graph is just nodes and edges. Many real datasets are graphs: atoms connected by bonds (molecules), people connected by friendships (social networks), papers connected by citations (academic graphs), even code where functions call each other.
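The PyTorch Geometric code on this page receives a graph as `edge_index`, a 2 × E list of (source, target) pairs. The sketch below shows that layout in plain Python. It assumes that an undirected graph is stored by listing each edge in both directions, so that messages can flow either way. The helper name `to_edge_index` is hypothetical.

```python
def to_edge_index(undirected_edges):
    """Turn undirected (u, v) pairs into [sources, targets], with each edge
    listed once in each direction."""
    sources, targets = [], []
    for u, v in undirected_edges:
        sources += [u, v]
        targets += [v, u]
    return [sources, targets]

# Water as a toy molecular graph: atom 0 = O, atoms 1 and 2 = H, two O-H bonds.
bonds = [(0, 1), (0, 2)]
edge_index = to_edge_index(bonds)
print(edge_index)  # [[0, 1, 0, 2], [1, 0, 2, 0]]
```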

GNNs let each node have a feature vector, then repeatedly update each node by mixing in the features of its neighbours. After L rounds, every node's representation reflects information from nodes up to L hops away.

Reach for it when

  • Molecular property prediction (drug discovery, materials)
  • Recommendation on bipartite user-item graphs
  • Knowledge graph completion
  • Any task where the structure of relationships matters

Skip it when

  • Your data isn't naturally a graph (don't force it)
  • The graph is huge and the structure is incidental
  • You only have node features, no informative edges
  • Long-range dependencies dominate — try a graph transformer
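```python
import torch.nn as nn
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, in_dim, hidden, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden)   # round 1: 1-hop information
        self.conv2 = GCNConv(hidden, out_dim)  # round 2: 2-hop receptive field

    def forward(self, x, edge_index):
        # x: [num_nodes, in_dim] node features; edge_index: [2, num_edges]
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)       # one output vector per node
```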
The message-passing math
Message passing $$ \mathbf{h}_v^{(\ell+1)} \;=\; \phi\!\left(\mathbf{h}_v^{(\ell)},\; \bigoplus_{u \in \mathcal{N}(v)} \psi(\mathbf{h}_u^{(\ell)}, \mathbf{h}_v^{(\ell)}, \mathbf{e}_{uv})\right) $$
  • $\mathbf{h}_v^{(\ell)}$: feature vector of node v at layer ℓ
  • $\mathcal{N}(v)$: neighbours of node v
  • $\psi$, $\phi$: learned message and update functions
  • $\bigoplus$: permutation-invariant aggregation (sum, mean, max)

The recipe. Every node sends a message to its neighbours (the message is a function of sender, receiver, and edge features). Each node aggregates incoming messages via a permutation-invariant operation (because graphs have no node ordering). The aggregated message is combined with the node's own state to produce its new representation.
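The recipe maps onto three plug-in functions: ψ builds messages, ⊕ aggregates them, and φ performs the update. This is a minimal pure-Python sketch. Hand-written ψ and φ stand in for the learned ones, and edge features are omitted for brevity. All names are illustrative.

```python
def message_passing_layer(h, edges, psi, aggregate, phi):
    """h: dict node -> feature list; edges: directed (u, v) pairs, u sends to v."""
    inbox = {v: [] for v in h}
    for u, v in edges:
        inbox[v].append(psi(h[u], h[v]))            # message depends on sender and receiver
    return {v: phi(h[v], aggregate(msgs, len(h[v]))) for v, msgs in inbox.items()}

def sum_aggregate(messages, dim):
    """Element-wise sum: the result does not depend on the order of the messages."""
    out = [0.0] * dim
    for m in messages:
        out = [a + b for a, b in zip(out, m)]
    return out

def psi(h_u, h_v):
    return h_u                                       # message = the sender's features

def phi(h_v, agg):
    return [a + b for a, b in zip(h_v, agg)]         # update = own state + aggregate

h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [2.0, 2.0]}
edges = [(1, 0), (2, 0), (0, 1), (0, 2)]            # star centred on node 0
out = message_passing_layer(h, edges, psi, sum_aggregate, phi)
print(out)  # {0: [3.0, 3.0], 1: [1.0, 1.0], 2: [3.0, 2.0]}
```

Because the aggregation is a sum, shuffling the edge list leaves the output unchanged. That is the permutation invariance the recipe requires.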

Variants. GCN (Kipf & Welling) uses degree-normalized sum aggregation over neighbours plus a self-loop, followed by a single linear transform. GraphSAGE samples a fixed number of neighbours for scalability. GAT uses attention to weight neighbour contributions. GIN uses sum aggregation plus an MLP. This makes it as expressive as the 1-WL test, which is the ceiling for standard message-passing GNNs.
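Kipf & Welling's normalization is symmetric. After self-loops are added, a message from u to v is scaled by 1/sqrt(deg(u)·deg(v)). The pure-Python sketch below shows only this propagation step. It leaves out the weight matrix and the nonlinearity, and the function name is illustrative.

```python
import math

def gcn_propagate(h, undirected_edges):
    """h'_v = sum over u in N(v) ∪ {v} of h_u / sqrt(deg(u) * deg(v)),
    where each degree counts the added self-loop. No weights, no activation."""
    neighbours = {v: {v} for v in h}                 # self-loops: the "+ I" in A + I
    for u, v in undirected_edges:
        neighbours[u].add(v)
        neighbours[v].add(u)
    deg = {v: len(ns) for v, ns in neighbours.items()}
    out = {}
    for v, ns in neighbours.items():
        acc = [0.0] * len(h[v])
        for u in ns:
            w = 1.0 / math.sqrt(deg[u] * deg[v])     # symmetric normalization weight
            acc = [a + w * x for a, x in zip(acc, h[u])]
        out[v] = acc
    return out

# Path 0 - 1 - 2 with a one-dimensional feature that is nonzero only on node 0.
h = {0: [1.0], 1: [0.0], 2: [0.0]}
out = gcn_propagate(h, [(0, 1), (1, 2)])
print(out)
```

Node 0 keeps 1/2 of its own value, because its degree with the self-loop is 2. Node 1 receives 1/sqrt(6). Node 2 receives nothing yet, since it is two hops away.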

Tasks. Node classification (predict a label per node), link prediction (predict whether an edge should exist), graph classification (one prediction for the whole graph, usually via pooling), graph generation.

Readout / pooling. For graph-level prediction, aggregate all node embeddings into one vector (mean, sum, max, or attention-based). Hierarchical pooling (DiffPool) clusters nodes into super-nodes and recurses.
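In a mini-batch, PyG-style code records which graph each node belongs to in a `batch` vector. A readout pools nodes that share the same index. Here is a minimal pure-Python sketch of mean, sum, and max readout under that convention. The function name `pool_per_graph` is hypothetical.

```python
def pool_per_graph(x, batch, reduce="mean"):
    """x: list of node feature lists; batch[i]: graph index of node i.
    Returns one pooled vector per graph, ordered by graph index."""
    groups = {}
    for feat, g in zip(x, batch):
        groups.setdefault(g, []).append(feat)
    pooled = []
    for g in sorted(groups):
        cols = list(zip(*groups[g]))                 # one tuple per feature dimension
        if reduce == "sum":
            pooled.append([sum(c) for c in cols])
        elif reduce == "mean":
            pooled.append([sum(c) / len(c) for c in cols])
        elif reduce == "max":
            pooled.append([max(c) for c in cols])
        else:
            raise ValueError(f"unknown reduce: {reduce}")
    return pooled

# Two graphs in one batch: nodes 0-2 form graph 0, nodes 3-4 form graph 1.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 0.0], [20.0, 2.0]]
batch = [0, 0, 0, 1, 1]
print(pool_per_graph(x, batch, "mean"))  # [[3.0, 4.0], [15.0, 1.0]]
```

Sum pooling keeps information about graph size, while mean pooling discards it. Which one fits depends on whether size matters for the target.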

Over-smoothing. After many message-passing rounds, node representations converge toward the same value, because every round averages each node with its neighbours. Standard fix: use only 2–4 layers. More recent fixes include skip connections, jumping knowledge networks (which combine the outputs of all layers), or graph transformers, which mitigate the issue.
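The effect is easy to reproduce. This minimal pure-Python sketch assumes the simplest possible layer: a plain mean over the node itself and its neighbours, with no weights and no nonlinearity. On a connected graph, the gap between the largest and smallest node value never grows, and it shrinks toward zero.

```python
def mean_smooth(h, adj):
    """One round of averaging each node with its neighbours (self included)."""
    return {v: sum(h[u] for u in adj[v] | {v}) / (len(adj[v]) + 1) for v in h}

# Connected toy graph: triangle 0-1-2 with a tail 2-3; node values start far apart.
adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
h = {0: 1.0, 1: 0.0, 2: 0.0, 3: -1.0}

spreads = []                                         # max - min before each round
for layer in range(30):
    spreads.append(max(h.values()) - min(h.values()))
    h = mean_smooth(h, adj)
print(f"spread before round 0: {spreads[0]}, before round 29: {spreads[-1]:.2e}")
```

Each new value is a weighted average of old values, so the maximum can only fall and the minimum can only rise. After a few dozen rounds, the nodes are nearly indistinguishable.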

Reach for it when

  • Molecule / protein property prediction with chemical structure
  • Recommendation systems with user-item bipartite graphs
  • Citation / co-authorship analysis
  • Traffic / road-network forecasting

Skip it when

  • The graph adds no signal — try a baseline that ignores it first
  • Very long-range dependencies — vanilla GNNs over-smooth, use a graph transformer
  • The graph changes faster than you can train — consider streaming methods
  • You need probabilistic uncertainty on outputs (use a Bayesian GNN)
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, global_mean_pool

class MolPropertyGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden, out_dim, heads=4):
        super().__init__()
        # Heads are concatenated by default, so conv1 outputs hidden * heads features
        self.conv1 = GATv2Conv(in_dim, hidden, heads=heads)
        self.conv2 = GATv2Conv(hidden * heads, hidden, heads=1)
        self.head  = torch.nn.Linear(hidden, out_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.elu(self.conv1(x, edge_index))
        x = F.elu(self.conv2(x, edge_index))
        # Pool per graph (batch holds graph index for each node)
        x = global_mean_pool(x, batch)
        return self.head(x)  # one prediction vector per molecule
```
The Weisfeiler-Leman story and graph transformers
WL-1 expressiveness ceiling $$ \text{Standard GNNs} \;\preceq\; \text{1-WL test}\;<\; \text{graph isomorphism} $$
  • Standard message-passing GNNs can only distinguish graphs that the 1-dimensional Weisfeiler-Leman colour-refinement test can distinguish
  • Some non-isomorphic graphs share WL-1 colourings — no standard GNN can tell them apart

Expressiveness limits. Xu et al. (2019) showed that standard message-passing GNNs are at most as powerful as the 1-WL graph isomorphism test. Some simple non-isomorphic graphs cannot be distinguished by any GNN of this form. 1-WL assigns identical colours to any two regular graphs with the same degree and number of nodes, such as a 6-cycle and two disjoint triangles. GIN (Graph Isomorphism Network) reaches the WL-1 upper bound via sum aggregation plus an MLP.
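Here is a minimal pure-Python sketch of 1-WL colour refinement. It uses nested tuples as colours, so histograms from different graphs can be compared directly; a production implementation would hash or relabel the colours instead. The sketch checks that a 6-cycle and two disjoint triangles, both 2-regular on six nodes, end up with identical colour histograms. A 6-node path does not.

```python
from collections import Counter

def wl_histogram(adj, rounds=3):
    """1-WL colour refinement. adj: dict node -> list of neighbours.
    All nodes start with the same colour; each round, a node's new colour is
    (its old colour, sorted multiset of its neighbours' colours)."""
    colours = {v: () for v in adj}
    for _ in range(rounds):
        colours = {
            v: (colours[v], tuple(sorted(colours[u] for u in adj[v])))
            for v in adj
        }
    return Counter(colours.values())

cycle6 = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
path6 = {i: [j for j in (i - 1, i + 1) if 0 <= j < 6] for i in range(6)}

print(wl_histogram(cycle6) == wl_histogram(two_triangles))  # True: 1-WL is fooled
print(wl_histogram(cycle6) == wl_histogram(path6))          # False: endpoints differ
```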

Beyond WL-1. Higher-order GNNs use tuples or sets of k nodes as their basic unit, gaining expressiveness at $O(N^k)$ cost. For molecules, E(n)-equivariant GNNs such as EGNN keep predicted properties invariant, and updated coordinates equivariant, under rotations, translations, and reflections of the 3D coordinates. Subgraph-aware methods (SUN, GNN-AK) recover lost expressiveness by considering subgraphs around each node.
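EGNN-style layers build their messages from pairwise squared distances ||x_i − x_j||². These distances do not change when the whole structure is rotated or translated. This minimal pure-Python sketch checks that invariance. The rotation angle, shift, and toy 3-atom geometry are arbitrary illustrative values.

```python
import math

def squared_distances(coords):
    """All pairwise squared distances for a list of 3D points."""
    n = len(coords)
    return [[sum((a - b) ** 2 for a, b in zip(coords[i], coords[j])) for j in range(n)]
            for i in range(n)]

def rotate_z_and_shift(p, theta, shift):
    """Rotate point p about the z-axis by theta radians, then translate by shift."""
    x, y, z = p
    c, s = math.cos(theta), math.sin(theta)
    rotated = (c * x - s * y, s * x + c * y, z)
    return tuple(r + t for r, t in zip(rotated, shift))

mol = [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]
moved = [rotate_z_and_shift(p, theta=0.7, shift=(5.0, -2.0, 1.0)) for p in mol]

d_before, d_after = squared_distances(mol), squared_distances(moved)
# Coordinates change, but the distances (and so any message built from them) do not.
invariant = all(math.isclose(a, b, abs_tol=1e-9)
                for ra, rb in zip(d_before, d_after) for a, b in zip(ra, rb))
print(invariant)  # True
```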

Graph transformers. Treat the graph as fully connected for attention, and reinject structure through inductive biases: distance encodings, edge encodings, or positional encodings from the Laplacian eigenvectors. They often outperform vanilla GNNs on tasks with long-range dependencies, but full attention costs time quadratic in the number of nodes. Examples: GraphGPS, Graphormer.
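In Graphormer-style distance encoding, each attention logit gets a bias indexed by the shortest-path distance between the two nodes. Here is a minimal pure-Python sketch. The bias table holds made-up illustrative numbers standing in for learned parameters. The raw attention scores are all zero, which isolates the effect of the bias. Both function names are hypothetical.

```python
from collections import deque

def shortest_path_distances(adj):
    """All-pairs hop distances via BFS from every node; unreachable pairs map to None."""
    spd = {}
    for s in adj:
        dist = {s: 0}
        queue = deque([s])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if v not in dist:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        for t in adj:
            spd[(s, t)] = dist.get(t)
    return spd

def biased_attention_logits(scores, spd, bias_by_distance, far_bias):
    """Add a per-distance bias to each raw attention score. Distances missing from
    the table (including unreachable pairs) fall back to far_bias."""
    return {pair: s + bias_by_distance.get(spd[pair], far_bias) for pair, s in scores.items()}

adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}        # path 0 - 1 - 2 - 3
spd = shortest_path_distances(adj)
scores = {pair: 0.0 for pair in spd}                 # every node attends to every node
logits = biased_attention_logits(scores, spd, {0: 1.0, 1: 0.5, 2: 0.0}, far_bias=-1.0)
print(logits[(0, 1)], logits[(0, 3)])  # 0.5 -1.0
```

The attention itself stays global, and the bias lets the model learn how much graph distance should matter.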

Over-smoothing & over-squashing. Over-smoothing: node features become indistinguishable after many layers. Over-squashing: information from a long-range node has to flow through narrow bottlenecks and gets crushed. Both motivate sparse attention or graph rewiring (adding extra long-range edges).
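Here is the idea behind rewiring, sketched in a few lines of pure Python. The rewiring rule is deliberately naive (connect every node to its 2-hop neighbours) and is not any specific published method. It shortens the paths that information must travel, so fewer layers are needed to connect distant nodes. The price is extra edges.

```python
def diameter(adj):
    """Longest shortest path (in hops) over all node pairs of a connected graph."""
    best = 0
    for s in adj:
        dist, frontier = {s: 0}, [s]
        while frontier:
            nxt = []
            for u in frontier:
                for v in adj[u]:
                    if v not in dist:
                        dist[v] = dist[u] + 1
                        nxt.append(v)
            frontier = nxt
        best = max(best, max(dist.values()))
    return best

def add_two_hop_edges(adj):
    """Naive rewiring: also connect every node to its 2-hop neighbours."""
    return {u: set(adj[u]) | {w for v in adj[u] for w in adj[v] if w != u} for u in adj}

path = {i: {j for j in (i - 1, i + 1) if 0 <= j < 8} for i in range(8)}
print(diameter(path), diameter(add_two_hop_edges(path)))  # 7 4
```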

Scaling. For massive graphs (billions of edges), full-batch training is infeasible. Sampling-based methods (GraphSAGE, Cluster-GCN, GraphSAINT) train on subgraphs. Each samples differently — Cluster-GCN partitions; GraphSAINT samples nodes/edges; SAGE samples local neighbourhoods.
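Here is a minimal pure-Python sketch of GraphSAGE-style neighbour sampling with a fixed fan-out per layer. The function name and the toy "wheel" graph are illustrative, and real libraries perform this step on tensors. The key property is that the work per seed node is bounded by the product of the fan-outs (here 5 × 3), however large the hub's degree.

```python
import random

def sample_neighbourhood(adj, seeds, fanouts, rng):
    """For each layer, keep at most `fanout` random neighbours of every frontier node.
    Returns the sampled (src, dst) edges, one list per layer."""
    frontier = list(seeds)
    layers = []
    for fanout in fanouts:
        edges, next_frontier = [], set()
        for dst in frontier:
            nbrs = sorted(adj[dst])                  # rng.sample needs a sequence
            for src in rng.sample(nbrs, min(fanout, len(nbrs))):
                edges.append((src, dst))
                next_frontier.add(src)
        layers.append(edges)
        frontier = sorted(next_frontier)
    return layers

# Wheel graph: hub 0 linked to leaves 1..100, which also form a ring among themselves.
adj = {0: set(range(1, 101))}
for i in range(1, 101):
    prev_leaf = 100 if i == 1 else i - 1
    next_leaf = 1 if i == 100 else i + 1
    adj[i] = {0, prev_leaf, next_leaf}

layers = sample_neighbourhood(adj, seeds=[0], fanouts=[5, 3], rng=random.Random(0))
print(len(layers[0]), len(layers[1]))  # 5 15
```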

Reach for it when

  • EGNN: 3D structure-aware tasks (molecules, proteins, physics)
  • Graph transformers: long-range dependencies on small/medium graphs
  • Cluster-GCN / SAGE: web-scale graphs that don't fit in memory
  • Subgraph methods when WL-1 isn't enough

Skip it when

  • You're chasing the WL hierarchy but the dataset doesn't actually need it
  • Tabular / time-series — don't dress it up as a graph
  • You need bounded inference latency on a growing graph
  • You can solve the task with structural features + an MLP
```python
import torch
import torch.nn as nn
from torch_geometric.utils import scatter

# A bare-bones GIN-style layer, written by hand instead of using PyG's GINConv
class GINLayer(nn.Module):
    def __init__(self, d_in, d_out, eps=0.0, learn_eps=True):
        super().__init__()
        # eps weights the node's own features against the neighbour sum
        self.eps = nn.Parameter(torch.tensor(eps)) if learn_eps else eps
        self.mlp = nn.Sequential(
            nn.Linear(d_in, d_out), nn.ReLU(),
            nn.Linear(d_out, d_out),
        )

    def forward(self, x, edge_index):
        # Messages flow from edge_index[0] (source) to edge_index[1] (target)
        src, dst = edge_index
        # Sum-aggregate neighbours, then MLP: matches the 1-WL bound when the MLP is injective
        agg = scatter(x[src], dst, dim=0, dim_size=x.size(0), reduce="sum")
        return self.mlp((1 + self.eps) * x + agg)
```