Networks that process graph-structured data — molecules, social networks, knowledge graphs, code.
Key idea
Each node updates itself based on its neighbours. Like a social network of nodes passing notes — every round, each node reads its neighbours' messages, combines them with its own state, and updates. Do this a few rounds and even far-away nodes' information has rippled through.
After K rounds of message-passing, each node's representation reflects information from nodes up to K hops away. That's why GNNs typically use 2–4 layers: past a few hops, node representations start to over-smooth and every node ends up looking the same.
A graph is just nodes and edges. Many real datasets are graphs: atoms connected by bonds (molecules), people connected by friendships (social networks), papers connected by citations (academic graphs), even code where functions call each other.
GNNs let each node have a feature vector, then repeatedly update each node by mixing in the features of its neighbours. After L rounds, every node's representation reflects information from nodes up to L hops away.
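The "L rounds reach L hops" claim can be checked with a minimal sketch in plain Python. It assumes a hypothetical 5-node path graph and tracks, for each node, the set of node ids it has heard from. The names `adj` and `propagate` are illustrative and not taken from any library.

```python
# Hypothetical path graph 0 - 1 - 2 - 3 - 4, stored as an adjacency list.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}

def propagate(adj, rounds):
    """Track which source nodes each node has received information from."""
    seen = {v: {v} for v in adj}  # round 0: every node knows only itself
    for _ in range(rounds):
        # Build the new state from the old one so all nodes update synchronously.
        seen = {v: seen[v].union(*(seen[u] for u in adj[v])) for v in adj}
    return seen

for k in range(5):
    # Node 0's receptive field grows by exactly one hop per round.
    print(k, sorted(propagate(adj, k)[0]))
```

On this graph, node 0 only hears from node 4 once four rounds have run, since node 4 is four hops away.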
The recipe. Every node sends a message to its neighbours (the message is a function of sender, receiver, and edge features). Each node aggregates incoming messages via a permutation-invariant operation (because graphs have no node ordering). The aggregated message is combined with the node's own state to produce its new representation.
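The three steps of the recipe (message, aggregate, update) can be sketched in plain Python as follows. This is an illustration under stated assumptions, not a library API: `message_passing_round`, the `message` and `update` functions, and the toy graph are all hypothetical, and sum is used as the permutation-invariant aggregator.

```python
def message_passing_round(features, edges, message, update):
    """One synchronous round: message -> sum-aggregate -> update.

    features: dict mapping node -> list of floats
    edges: list of (sender, receiver, edge_feature) triples
    """
    dim = len(next(iter(features.values())))
    inbox = {v: [0.0] * dim for v in features}
    for src, dst, e in edges:
        m = message(features[src], features[dst], e)
        # Summation does not depend on the order in which edges are visited.
        inbox[dst] = [a + b for a, b in zip(inbox[dst], m)]
    return {v: update(features[v], inbox[v]) for v in features}

# Illustrative choices (assumptions): scale the sender's features by the edge
# weight, then average the node's own state with the aggregated message.
message = lambda h_src, h_dst, w: [w * x for x in h_src]
update = lambda h, agg: [0.5 * (a + b) for a, b in zip(h, agg)]

features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [2.0, 2.0]}
edges = [(0, 2, 1.0), (1, 2, 0.5), (2, 0, 1.0)]

out = message_passing_round(features, edges, message, update)
shuffled = message_passing_round(features, list(reversed(edges)), message, update)
print(out)
print(out == shuffled)  # edge order does not matter for this example
```

In a trained GNN the `message` and `update` functions would be small neural networks with learned weights; the fixed lambdas here only make the data flow visible.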
Variants. GCN (Kipf & Welling) uses a degree-normalized sum over neighbours followed by a single linear transform. GraphSAGE samples a fixed number of neighbours for scalability. GAT uses attention to weight neighbour contributions. GIN uses sum aggregation + an MLP, which makes it provably as expressive as the 1-dimensional Weisfeiler-Leman (WL-1) test, the ceiling for standard message-passing variants.
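To make the difference between two of these aggregators concrete, here is a small sketch with scalar node features. It assumes the usual GCN normalisation (self-loops added, each term divided by the square root of the product of the two degrees) and deliberately leaves out the learned linear transform and nonlinearity. The function names and the star graph are hypothetical.

```python
import math

def gcn_aggregate(adj, h):
    """Symmetric-normalised sum over neighbours plus a self-loop (scalar features)."""
    deg = {v: len(adj[v]) + 1 for v in adj}  # +1 for the added self-loop
    out = {}
    for v in adj:
        total = h[v] / deg[v]  # self-loop term: 1 / sqrt(d_v * d_v)
        for u in adj[v]:
            total += h[u] / math.sqrt(deg[u] * deg[v])
        out[v] = total
    return out

def sum_aggregate(adj, h, eps=0.0):
    """GIN-style aggregation: (1 + eps) * own feature + plain sum of neighbours."""
    return {v: (1 + eps) * h[v] + sum(h[u] for u in adj[v]) for v in adj}

# Hypothetical star graph: node 0 is the hub, nodes 1..3 are leaves.
adj = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
h = {v: 1.0 for v in adj}

print(gcn_aggregate(adj, h))  # hub and leaf values end up close together
print(sum_aggregate(adj, h))  # hub value grows with its degree
```

With identical input features, the plain sum still separates the hub from the leaves by degree, while the normalised sum damps that difference. This is one intuition for why sum aggregation preserves more structural information.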
Tasks. Node classification (predict a label per node), link prediction (predict whether an edge should exist), graph classification (one prediction for the whole graph, usually via pooling), graph generation.
Readout / pooling. For graph-level prediction, aggregate all node embeddings into one vector (mean, sum, max, or attention-based). Hierarchical pooling (DiffPool) clusters nodes into super-nodes and recurses.
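A minimal sketch of mean, sum, and max readout in plain Python. It assumes the common batching convention where a `batch` list gives the graph index of every node; the function name `readout` and the toy values are illustrative only.

```python
def readout(node_embeddings, batch, reduce="mean"):
    """Pool node embeddings into one vector per graph.

    node_embeddings: list of equal-length float lists, one per node
    batch: list of graph indices; batch[i] is the graph that node i belongs to
    """
    groups = {}
    for emb, g in zip(node_embeddings, batch):
        groups.setdefault(g, []).append(emb)
    pooled = {}
    for g, embs in groups.items():
        columns = list(zip(*embs))  # one tuple per feature dimension
        if reduce == "sum":
            pooled[g] = [sum(c) for c in columns]
        elif reduce == "max":
            pooled[g] = [max(c) for c in columns]
        elif reduce == "mean":
            pooled[g] = [sum(c) / len(c) for c in columns]
        else:
            raise ValueError(f"unknown reduce: {reduce}")
    return pooled

# Two graphs in one batch: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1.
x = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [6.0, 3.0], [3.0, 6.0]]
batch = [0, 0, 1, 1, 1]

for mode in ("mean", "sum", "max"):
    print(mode, readout(x, batch, reduce=mode))
```

Sum readout keeps information about graph size, mean readout discards it, and max readout keeps only the strongest activation per feature, so the choice depends on whether the target scales with the number of nodes.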
Over-smoothing. After many message-passing rounds, node representations converge toward the same value, because each round averages a node with its neighbours. Standard fix: use only 2–4 layers. More recent remedies: skip connections, jumping knowledge networks, or graph transformers, which mitigate the issue.
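The averaging effect can be isolated in a toy sketch. It assumes a hypothetical 6-node ring, scalar features, and pure mean aggregation with no learned weights or nonlinearity, so it shows the mechanism rather than the behaviour of any particular trained model.

```python
def mean_round(adj, h):
    """Replace each node's value by the mean over itself and its neighbours."""
    return {v: (h[v] + sum(h[u] for u in adj[v])) / (len(adj[v]) + 1) for v in adj}

def spread(h):
    """Gap between the largest and smallest node value."""
    return max(h.values()) - min(h.values())

# Hypothetical 6-node ring with a single "hot" node.
n = 6
adj = {v: [(v - 1) % n, (v + 1) % n] for v in range(n)}
h = {v: 0.0 for v in range(n)}
h[0] = 1.0

history = [spread(h)]
for _ in range(16):
    h = mean_round(adj, h)
    history.append(spread(h))

# The spread shrinks round after round: the nodes become indistinguishable.
print([round(s, 4) for s in history])
```

After a handful of rounds every node sits near the global average, which is the over-smoothing regime: the features no longer say anything about where a node is in the graph.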
Reach for it when
Molecule / protein property prediction with chemical structure
Recommendation systems with user-item bipartite graphs
Citation / co-authorship analysis
Traffic / road-network forecasting
Skip it when
The graph adds no signal — try a baseline that ignores it first
Very long-range dependencies — vanilla GNNs over-smooth, use a graph transformer
The graph changes faster than you can train — consider streaming methods
You need probabilistic uncertainty on outputs (use a Bayesian GNN)
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, global_mean_pool


class MolPropertyGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden, out_dim, heads=4):
        super().__init__()
        # conv1 concatenates its attention heads, so it outputs hidden * heads features
        self.conv1 = GATv2Conv(in_dim, hidden, heads=heads)
        self.conv2 = GATv2Conv(hidden * heads, hidden, heads=1)
        self.head = torch.nn.Linear(hidden, out_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.elu(self.conv1(x, edge_index))
        x = F.elu(self.conv2(x, edge_index))
        # Pool per graph (batch holds graph index for each node)
        x = global_mean_pool(x, batch)
        return self.head(x)
Want the Weisfeiler-Leman story and graph transformers?
Standard message-passing GNNs can only distinguish graphs that the 1-dimensional Weisfeiler-Leman colour-refinement test can distinguish
Some non-isomorphic graphs share WL-1 colourings — no standard GNN can tell them apart
Expressiveness limits. Xu et al. (2019) showed that standard message-passing GNNs are at most as powerful as the 1-WL (WL-1) graph isomorphism test. There are simple non-isomorphic graphs that no GNN of this form can distinguish, for example two regular graphs with the same number of nodes and the same degree, such as a 6-cycle and two disjoint triangles. GIN (Graph Isomorphism Network) achieves the WL-1 upper bound via sum aggregation + an MLP.
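The WL-1 limit is easy to reproduce with a short colour-refinement sketch in plain Python. It assumes graphs without node features (every node starts with the same colour) and keeps colours as nested tuples so they are directly comparable across graphs. The function name `wl_colours` and the three toy graphs are illustrative.

```python
from collections import Counter

def wl_colours(adj, rounds=3):
    """1-WL colour refinement; returns the multiset of final node colours."""
    colour = {v: () for v in adj}  # uniform initial colour: no node features
    for _ in range(rounds):
        # New colour = (own colour, sorted multiset of neighbour colours).
        colour = {
            v: (colour[v], tuple(sorted(colour[u] for u in adj[v])))
            for v in adj
        }
    return Counter(colour.values())

# Two non-isomorphic 2-regular graphs on six nodes.
cycle6 = {v: [(v - 1) % 6, (v + 1) % 6] for v in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
# A 6-node path, included as a contrast that WL-1 does separate.
path6 = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

print(wl_colours(cycle6) == wl_colours(two_triangles))  # same colour histogram
print(wl_colours(cycle6) == wl_colours(path6))          # different histograms
```

Because every node in both regular graphs sees the same multiset of neighbour colours at every round, the refinement never splits them, and a message-passing GNN that is bounded by WL-1 inherits the same blind spot.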
Beyond WL-1. Higher-order GNNs use sets of k nodes as their basic unit, gaining expressiveness at O(N^k) cost. Equivariant GNNs such as EGNN (the E(n)-equivariant GNN) for molecules keep scalar predictions invariant, and vector-valued outputs equivariant, under rotations / translations of 3D coordinates. Subgraph-aware methods (SUN, GNN-AK) recover lost expressiveness by considering subgraphs around each node.
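One way to see where the invariance comes from: EGNN-style models feed squared pairwise distances, rather than raw coordinates, into their messages, and those distances do not change under rigid motions. The sketch below checks that property numerically in plain Python. The three points are arbitrary (not real molecular data), and the helper names are hypothetical.

```python
import math

def pairwise_sq_dists(coords):
    """Matrix of squared Euclidean distances between all pairs of points."""
    return [[sum((a - b) ** 2 for a, b in zip(p, q)) for q in coords] for p in coords]

def rotate_z_and_shift(coords, angle, shift):
    """Rotate 3D points by `angle` radians about the z-axis, then translate by `shift`."""
    c, s = math.cos(angle), math.sin(angle)
    return [
        (c * x - s * y + shift[0], s * x + c * y + shift[1], z + shift[2])
        for x, y, z in coords
    ]

# Arbitrary illustrative points, not taken from any dataset.
coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 1.0)]
moved = rotate_z_and_shift(coords, angle=0.7, shift=(3.0, -1.0, 2.5))

before = pairwise_sq_dists(coords)
after = pairwise_sq_dists(moved)
max_diff = max(abs(a - b) for ra, rb in zip(before, after) for a, b in zip(ra, rb))
print(max_diff)  # tiny: only floating-point rounding remains
```

Any prediction computed purely from these distances (plus node and edge features) is therefore unchanged when the whole structure is rotated or translated.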
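Graph transformers. Treat the graph as a fully-connected attention graph with structural inductive biases: distance encoding, edge encoding, or positional encoding from the Laplacian eigenvectors. They often beat vanilla GNNs on tasks with long-range dependencies, but full attention costs O(N^2) in the number of nodes, which limits them to small/medium graphs unless the attention is sparsified. Examples: GraphGPS, Graphormer.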
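As a rough sketch of a distance-based structural bias, the code below computes hop distances with BFS and adds a per-distance bias to otherwise uniform attention scores before the softmax. It is written in plain Python under stated assumptions: the bias function `-d` is a hypothetical stand-in for a learned per-distance scalar, and the function names are illustrative rather than any library's API.

```python
import math
from collections import deque

def shortest_path_matrix(adj):
    """All-pairs hop distances via one BFS per node (math.inf if unreachable)."""
    nodes = sorted(adj)
    dist = {s: {t: math.inf for t in nodes} for s in nodes}
    for s in nodes:
        dist[s][s] = 0
        queue = deque([s])
        while queue:
            v = queue.popleft()
            for u in adj[v]:
                if dist[s][u] == math.inf:
                    dist[s][u] = dist[s][v] + 1
                    queue.append(u)
    return dist

def biased_attention(scores, dist, bias_per_hop):
    """Softmax over all nodes after adding a distance-dependent bias to each score."""
    weights = {}
    for i in scores:
        logits = {j: scores[i][j] + bias_per_hop(dist[i][j]) for j in scores[i]}
        m = max(logits.values())  # subtract the max for numerical stability
        exp = {j: math.exp(l - m) for j, l in logits.items()}
        z = sum(exp.values())
        weights[i] = {j: e / z for j, e in exp.items()}
    return weights

# Hypothetical path graph 0 - 1 - 2 - 3 with uniform content scores.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
dist = shortest_path_matrix(adj)
scores = {i: {j: 0.0 for j in adj} for i in adj}
attn = biased_attention(scores, dist, bias_per_hop=lambda d: -float(d))

print(dist[0])
print(attn[0])  # node 0 attends to every node, but less to distant ones
```

Every node can attend to every other node in a single layer, which is what removes the hop limit of message passing, while the bias keeps the graph structure visible to the attention.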
Over-smoothing & over-squashing. Over-smoothing: node features become indistinguishable after many layers. Over-squashing: information from a long-range node has to flow through narrow bottlenecks and gets crushed. Both motivate sparse attention or graph rewiring (adding extra long-range edges).
Scaling. For massive graphs (billions of edges), full-batch training is infeasible. Sampling-based methods (GraphSAGE, Cluster-GCN, GraphSAINT) train on subgraphs. Each samples differently — Cluster-GCN partitions; GraphSAINT samples nodes/edges; SAGE samples local neighbourhoods.
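The local-neighbourhood sampling idea can be sketched in plain Python. This is a simplified variant under stated assumptions: it keeps all neighbours when a node has fewer than the fanout (sampling without replacement), and the function name, the `fanouts` parameter, and the hub graph are hypothetical.

```python
import random

def sample_neighbourhood(adj, seeds, fanouts, rng):
    """GraphSAGE-style sampling: keep at most `fanout` neighbours per node at each hop."""
    frontier = set(seeds)
    nodes = set(seeds)
    edges = set()
    for fanout in fanouts:
        next_frontier = set()
        for v in sorted(frontier):  # sorted for reproducibility with a seeded rng
            neighbours = adj[v]
            if len(neighbours) <= fanout:
                picked = neighbours
            else:
                picked = rng.sample(neighbours, fanout)
            for u in picked:
                edges.add((u, v))  # message flows from sampled neighbour u to v
                next_frontier.add(u)
        nodes |= next_frontier
        frontier = next_frontier
    return nodes, edges

# Hypothetical hub graph: node 0 has 20 neighbours, each leaf only knows the hub.
adj = {0: list(range(1, 21))}
adj.update({leaf: [0] for leaf in range(1, 21)})

nodes, edges = sample_neighbourhood(adj, seeds=[0], fanouts=[5, 5], rng=random.Random(0))
print(len(nodes), len(edges))
```

The point of the fanout is that the cost of a mini-batch is bounded by the product of the fanouts rather than by the true neighbourhood size, so a node with millions of neighbours costs the same as one with five.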
Reach for it when
EGNN: 3D structure-aware tasks (molecules, proteins, physics)
Graph transformers: long-range dependencies on small/medium graphs
Cluster-GCN / SAGE: web-scale graphs that don't fit in memory
Subgraph methods when WL-1 isn't enough
Skip it when
You're chasing the WL hierarchy but the dataset doesn't actually need it
Tabular / time-series — don't dress it up as a graph
You need bounded inference latency on a growing graph
You can solve the task with structural features + an MLP
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import scatter


# A bare-bones expressive GNN layer (GIN-style), implemented from scratch
class GINLayer(nn.Module):
    def __init__(self, d_in, d_out, eps=0.0, learn_eps=True):
        super().__init__()
        # eps weights the node's own features; learnable by default
        self.eps = nn.Parameter(torch.tensor(eps)) if learn_eps else eps
        self.mlp = nn.Sequential(
            nn.Linear(d_in, d_out), nn.ReLU(),
            nn.Linear(d_out, d_out),
        )

    def forward(self, x, edge_index):
        # Sum-aggregate neighbours, then MLP (the GIN update, which can match WL-1)
        src, dst = edge_index  # row 0: source nodes, row 1: target nodes
        agg = scatter(x[src], dst, dim=0, dim_size=x.size(0), reduce="sum")
        return self.mlp((1 + self.eps) * x + agg)