Graph Neural Networks (GNNs) are a class of deep learning models designed to process data represented as graphs. They are well suited to learning from structured data with complex relationships, which makes them powerful tools for social network analysis, molecular structure prediction, and recommendation systems. GNNs work by message passing: each node updates its representation by aggregating features from its neighbors. Stacking several such layers lets information flow across multi-hop neighborhoods, so the model can capture both local and broader graph structure.
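To make message passing concrete, here is a minimal sketch in plain PyTorch of a single aggregation step. It assumes mean aggregation and a COO `edge_index` of shape `[2, num_edges]`, with source nodes in row 0 and target nodes in row 1. This is the same edge layout that `GCNConv` consumes in the model code later in this section. The function name `mean_aggregate` is hypothetical. Real GNN layers also apply learned weights and combine the aggregate with the node's own state, while this sketch only averages neighbor features.

```python
import torch


def mean_aggregate(x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """One round of message passing (hypothetical helper, mean aggregation only).

    x:          [num_nodes, num_features] node features
    edge_index: [2, num_edges] long tensor; row 0 = source, row 1 = target
    """
    src, dst = edge_index
    num_nodes = x.size(0)
    # Each edge carries the source node's features to the target node; sum them per target.
    summed = torch.zeros_like(x).index_add_(0, dst, x[src])
    # Count incoming edges per node so the sum can be turned into a mean.
    ones = torch.ones(dst.size(0), dtype=x.dtype)
    deg = torch.zeros(num_nodes, dtype=x.dtype).index_add_(0, dst, ones)
    # clamp(min=1) leaves isolated nodes (no incoming edges) at zero instead of NaN.
    return summed / deg.clamp(min=1).unsqueeze(-1)


# Undirected path graph 0 - 1 - 2, stored as two directed edges per undirected edge.
x = torch.tensor([[1.0], [2.0], [4.0]])
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
h1 = mean_aggregate(x, edge_index)   # one hop
h2 = mean_aggregate(h1, edge_index)  # two hops
print(h1.squeeze(-1).tolist(), h2.squeeze(-1).tolist())
```

After one step, node 0 only sees node 1's features. After the second step, its value also depends on node 2, which is how stacking layers widens each node's receptive field.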
GNNs are built on several key concepts that enable them to effectively process graph data.
The structure of a GNN consists of:
The main operations in GNNs include:
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool


class GNN(torch.nn.Module):
    """Three-layer GCN followed by mean pooling for graph-level classification."""

    def __init__(self, num_node_features, num_classes, hidden_channels=64):
        super().__init__()
        # Graph convolutions: each layer aggregates features from neighboring nodes
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        # Maps the pooled graph embedding to class scores
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch):
        # Node feature transformation via message passing
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(self.conv3(x, edge_index))
        # Graph-level pooling: `batch` maps each node to its graph index,
        # so node embeddings are averaged separately for every graph
        x = global_mean_pool(x, batch)
        # Classification: log-probabilities, suitable for F.nll_loss
        x = self.classifier(x)
        return F.log_softmax(x, dim=1)


# Example usage
num_node_features = 16  # Number of features per node
num_classes = 2  # Binary classification
model = GNN(num_node_features, num_classes)
print(model)
```
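`GCNConv` and `global_mean_pool` hide the computation they perform. As a rough illustration, the dense sketch below applies the normalization from the standard GCN formulation, D^-1/2 (A + I) D^-1/2 X W, where D is the degree matrix of A + I. It then applies a per-graph mean readout comparable to `global_mean_pool`. The helper names `gcn_propagate` and `mean_pool` are hypothetical. PyTorch Geometric's actual implementation works on sparse edge lists and adds a learnable bias, so treat this as a conceptual sketch for small graphs, not a drop-in replacement.

```python
import torch


def gcn_propagate(x, edge_index, weight):
    """Dense sketch of one GCN layer without bias (hypothetical helper).

    Assumes no duplicate edges; a dense adjacency matrix collapses them anyway.
    """
    n = x.size(0)
    a = torch.zeros(n, n, dtype=x.dtype)
    a[edge_index[1], edge_index[0]] = 1.0      # a[target, source] = 1
    a_hat = a + torch.eye(n, dtype=x.dtype)    # add self-loops: A + I
    deg = a_hat.sum(dim=1)                     # degree including the self-loop
    d_inv_sqrt = deg.pow(-0.5)
    # Symmetric normalization: entry (i, j) is scaled by 1 / sqrt(deg_i * deg_j)
    norm = d_inv_sqrt.unsqueeze(1) * a_hat * d_inv_sqrt.unsqueeze(0)
    return norm @ (x @ weight)


def mean_pool(x, batch, num_graphs):
    """Average node embeddings per graph (hypothetical helper)."""
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype).index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return sums / counts.unsqueeze(1)


# Path graph 0 - 1 - 2 with an identity weight, so only the propagation is visible.
x = torch.tensor([[1.0], [2.0], [4.0]])
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
out = gcn_propagate(x, edge_index, torch.eye(1))
# Treat nodes 0 and 1 as graph 0 and node 2 as graph 1.
pooled = mean_pool(out, torch.tensor([0, 0, 1]), num_graphs=2)
print(out.squeeze(-1).tolist(), pooled.squeeze(-1).tolist())
```

The self-loop keeps each node's own features in the update. The symmetric scaling prevents high-degree nodes from dominating the sum, which is one reason the GCN formulation normalizes this way.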