Transformer Networks are a deep learning architecture that uses self-attention mechanisms to process sequential data. They have become the foundation for state-of-the-art models in natural language processing, enabling breakthroughs in machine translation, text generation, and other language tasks. Transformers excel at capturing long-range dependencies and at processing sequences in parallel.

Transformer Networks

The Transformer was introduced in 2017 by Vaswani et al. in the now-famous paper "Attention Is All You Need". That paper applied the architecture to machine translation, but it has since been used for a wide range of tasks, including image generation, speech recognition, and more. The sudden influx of Large Language Models (LLMs) was made possible by the transformer architecture: transformers capture long-range dependencies and can be trained on vast amounts of data in parallel, in a way that their predecessors, mostly RNNs, could not. There's a great visualisation here.

Core Concepts

Transformers are built on several key concepts that enable them to effectively process sequential data.

  • Network Architecture

    The structure of a Transformer consists of:

    • Encoder and decoder stacks
    • Multi-head attention layers
    • Position-wise feed-forward networks
    • Positional encoding

  • Key Operations

    The main operations in Transformers include:

    • Self-attention computation (sketched below)
    • Multi-head attention
    • Position-wise feed-forward
    • Layer normalization
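
The scaled dot-product attention at the heart of these operations can be written in a few lines. Below is a minimal sketch in PyTorch; the function name and tensor shapes are assumptions for illustration, not taken from a specific library:

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, heads, seq_len, d_k)
    d_k = q.size(-1)
    # Similarity scores between queries and keys, scaled by sqrt(d_k)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 are excluded from attention
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Attention weights sum to 1 over the key dimension
    weights = torch.softmax(scores, dim=-1)
    # Weighted sum of values
    return weights @ v, weights

Multi-head attention runs this computation in parallel over several lower-dimensional projections of the queries, keys, and values, then concatenates the results.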

Key Components

  • Architecture components

    • Attention layers
    • Feed-forward networks
    • Layer normalization
    • Positional encoding
    • Residual connections

  • Training considerations

    • Loss functions
    • Optimizers
    • Learning rate scheduling (see the schedule sketch below)
    • Masking strategies
    • Gradient clipping

  • Variants and extensions

    • BERT and variants
    • GPT models
    • Cross-attention
    • Efficient attention
    • Long-range transformers
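
Training a Transformer is sensitive to the learning-rate schedule. The original paper uses a warmup-then-decay rule, lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5). Below is a minimal sketch of that schedule in Python; the function name and the LambdaLR hookup are illustrative assumptions:

def transformer_lr(step, d_model=512, warmup_steps=4000):
    # Linear warmup for the first warmup_steps, then inverse-square-root decay
    step = max(step, 1)  # avoid division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

# Example: plug the schedule into a PyTorch optimizer via LambdaLR.
# With a base lr of 1.0, the lambda value becomes the effective learning rate.
# optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=transformer_lr)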

Implementation Examples

Transformer with TensorFlow/Keras

import tensorflow as tf
from tensorflow.keras import layers, models

def create_transformer_model(seq_len, vocab_size, num_heads, dff, num_layers, num_classes):
    # Input layer: a sequence of integer token IDs
    inputs = layers.Input(shape=(seq_len,))

    # Token embedding plus fixed sinusoidal positional encoding
    x = layers.Embedding(vocab_size, dff)(inputs)
    x = x + positional_encoding(seq_len, dff)
    
    # Transformer blocks
    for _ in range(num_layers):
        # Multi-head attention
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=dff//num_heads
        )(x, x)
        x = layers.LayerNormalization(epsilon=1e-6)(x + attention_output)
        
        # Feed-forward network
        ffn_output = layers.Dense(dff, activation='relu')(x)
        ffn_output = layers.Dense(dff)(ffn_output)
        x = layers.LayerNormalization(epsilon=1e-6)(x + ffn_output)
    
    # Per-position classification head: (batch, seq_len, num_classes)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    
    return models.Model(inputs=inputs, outputs=outputs)

def positional_encoding(position, d_model):
    # Angle rates follow the 1 / 10000^(2i/d_model) pattern from the original paper
    angle_rads = tf.range(position, dtype=tf.float32)[:, tf.newaxis] / \
                 tf.pow(10000.0, tf.range(0, d_model, 2, dtype=tf.float32) / d_model)
    
    sines = tf.math.sin(angle_rads)
    cosines = tf.math.cos(angle_rads)
    
    # Concatenate (rather than interleave) the sine and cosine channels;
    # a common simplification of the original encoding
    pos_encoding = tf.concat([sines, cosines], axis=-1)
    pos_encoding = pos_encoding[tf.newaxis, ...]  # shape: (1, position, d_model)
    
    return tf.cast(pos_encoding, tf.float32)

# Example usage
seq_len = 100        # Sequence length
vocab_size = 10000   # Vocabulary size
num_heads = 8
dff = 512            # Model / feed-forward dimension
num_layers = 4
num_classes = 10

model = create_transformer_model(
    seq_len, vocab_size, num_heads, dff, num_layers, num_classes
)
model.summary()
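
To make the example end-to-end, here is a minimal training sketch on random data (the dummy arrays and hyperparameters are placeholders, not a recommended setup):

import numpy as np

# Dummy token IDs and per-position labels, matching the shapes expected above
x_dummy = np.random.randint(0, vocab_size, size=(32, seq_len))
y_dummy = np.random.randint(0, num_classes, size=(32, seq_len))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_dummy, y_dummy, epochs=1, batch_size=8)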

Transformer with PyTorch

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class TransformerModel(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, num_classes, dropout=0.1):
        super().__init__()
        
        self.embedding = nn.Embedding(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        self.decoder = nn.Linear(d_model, num_classes)
        
    def forward(self, src, src_mask=None):
        # src shape: (batch_size, seq_len)
        x = self.embedding(src)
        x = self.pos_encoder(x)
        
        # By default, apply a causal (look-ahead) mask so that each position can
        # only attend to earlier positions (a decoder-only, GPT-style setup)
        if src_mask is None:
            src_mask = self.generate_square_subsequent_mask(src.size(1)).to(src.device)
        
        output = self.transformer_encoder(x, src_mask)
        output = self.decoder(output)  # per-position logits: (batch_size, seq_len, num_classes)
        return output
    
    def generate_square_subsequent_mask(self, sz):
        # Upper-triangular (future) positions are filled with -inf, all others with 0.0
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

# Example usage
input_dim = 10000  # Vocabulary size
d_model = 512      # Embedding dimension
nhead = 8          # Number of attention heads
num_encoder_layers = 6
dim_feedforward = 2048
num_classes = 10
dropout = 0.1

# Create model
model = TransformerModel(
    input_dim=input_dim,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    dim_feedforward=dim_feedforward,
    num_classes=num_classes,
    dropout=dropout
)

# Example input
batch_size = 32
seq_length = 100
x = torch.randint(0, input_dim, (batch_size, seq_length))

# Forward pass
output = model(x)
print(f"Output shape: {output.shape}")  # Should be (batch_size, seq_length, num_classes)