Transformer networks were introduced in 2017 by Vaswani et al. in the now-famous paper "Attention Is All You Need". The paper applied them to machine translation, but they have since been used for a wide range of tasks, including image generation, speech recognition, and more. The sudden influx of Large Language Models (LLMs) was made possible by the transformer architecture: transformers capture long-range dependencies and can be trained on vast amounts of data in a way that their predecessors, mostly RNNs, could not. There's a great visualisation here.
Transformers are built on several key concepts that enable them to effectively process sequential data.
The structure of a Transformer consists of:

- An embedding layer that maps input tokens to dense vectors.
- A positional encoding added to the embeddings, since attention on its own is order-agnostic.
- A stack of layers (an encoder and a decoder stack in the original paper; the examples below are encoder-only), each combining multi-head attention and a position-wise feed-forward network with residual connections and layer normalisation.
- An output projection that maps the final representations to classes or vocabulary logits.

The main operations in Transformers include:

- Scaled dot-product attention, which weights the value vectors by the softmax of the query-key similarities (see the sketch after this list).
- Multi-head attention, which runs several attention operations in parallel and concatenates the results.
- Position-wise feed-forward networks applied to each token independently.
- Residual connections and layer normalisation around each sub-layer.
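To make the attention operation concrete, here is a minimal NumPy sketch of scaled dot-product attention. It is not taken from either framework's implementation below; the names q, k, v and the toy shapes are purely for illustration.

import numpy as np

def scaled_dot_product_attention(q, k, v):
    """softmax(q @ k^T / sqrt(d_k)) @ v for a single attention head."""
    d_k = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)          # (batch, seq_q, seq_k)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)    # softmax over the key axis
    return weights @ v                                          # (batch, seq_q, d_v)

# Toy example: a batch of 2 sequences, 4 tokens, 8-dimensional head
q = np.random.randn(2, 4, 8)
k = np.random.randn(2, 4, 8)
v = np.random.randn(2, 4, 8)
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # (2, 4, 8)

Multi-head attention simply applies this operation several times in parallel, each with its own learned projections of q, k and v, and concatenates the results.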
import tensorflow as tf
from tensorflow.keras import layers, models

def create_transformer_model(input_shape, vocab_size, num_heads, dff, num_layers, num_classes):
    # Input layer: a sequence of token ids
    inputs = layers.Input(shape=input_shape)

    # Token embedding plus positional encoding
    # (dff serves as both the embedding/model dimension and the feed-forward width in this simplified version)
    x = layers.Embedding(vocab_size, dff)(inputs)
    x = x + positional_encoding(input_shape[0], dff)

    # Transformer blocks
    for _ in range(num_layers):
        # Multi-head self-attention
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=dff // num_heads
        )(x, x)
        # Residual connection + layer normalisation
        x = layers.LayerNormalization(epsilon=1e-6)(x + attention_output)

        # Position-wise feed-forward network
        ffn_output = layers.Dense(dff, activation='relu')(x)
        ffn_output = layers.Dense(dff)(ffn_output)
        # Residual connection + layer normalisation
        x = layers.LayerNormalization(epsilon=1e-6)(x + ffn_output)

    # Output layer: per-token class probabilities
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    return models.Model(inputs=inputs, outputs=outputs)

def positional_encoding(position, d_model):
    # Sinusoidal positional encoding: sines fill the first d_model/2 dimensions,
    # cosines the rest (d_model must be even)
    angle_rads = tf.range(position, dtype=tf.float32)[:, tf.newaxis] / \
        tf.pow(10000.0, tf.range(0, d_model, 2, dtype=tf.float32) / d_model)
    sines = tf.math.sin(angle_rads)
    cosines = tf.math.cos(angle_rads)
    pos_encoding = tf.concat([sines, cosines], axis=-1)
    pos_encoding = pos_encoding[tf.newaxis, ...]  # shape: (1, position, d_model)
    return tf.cast(pos_encoding, tf.float32)

# Example usage
input_shape = (100,)  # Sequence length of 100
vocab_size = 10000    # Vocabulary size
num_heads = 8
dff = 512
num_layers = 4
num_classes = 10

model = create_transformer_model(
    input_shape, vocab_size, num_heads, dff, num_layers, num_classes
)
model.summary()
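A quick way to sanity-check the Keras model is to push a batch of random token ids through it; assuming a per-token classification task, it can then be compiled with a sparse categorical cross-entropy loss. The dummy data and training setup below are purely illustrative, not part of the original example.

import numpy as np

# Dummy batch of token ids: 32 sequences of length 100
dummy_tokens = np.random.randint(0, vocab_size, size=(32, 100))
print(model(dummy_tokens).shape)  # (32, 100, num_classes)

# Hypothetical training setup for a per-token classification task
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
dummy_labels = np.random.randint(0, num_classes, size=(32, 100))
model.fit(dummy_tokens, dummy_labels, epochs=1, batch_size=8)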
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Precompute the sinusoidal positional encodings once
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]

class TransformerModel(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, num_classes, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        self.decoder = nn.Linear(d_model, num_classes)

    def forward(self, src, src_mask=None):
        # src shape: (batch_size, seq_len)
        x = self.embedding(src)
        x = self.pos_encoder(x)
        if src_mask is None:
            # Default to a causal mask so each position only attends to earlier positions
            src_mask = self.generate_square_subsequent_mask(src.size(1)).to(src.device)
        output = self.transformer_encoder(x, src_mask)
        output = self.decoder(output)  # per-token class scores: (batch_size, seq_len, num_classes)
        return output

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
# Example usage
input_dim = 10000 # Vocabulary size
d_model = 512 # Embedding dimension
nhead = 8 # Number of attention heads
num_encoder_layers = 6
dim_feedforward = 2048
num_classes = 10
dropout = 0.1
# Create model
model = TransformerModel(
    input_dim=input_dim,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    dim_feedforward=dim_feedforward,
    num_classes=num_classes,
    dropout=dropout
)
# Example input
batch_size = 32
seq_length = 100
x = torch.randint(0, input_dim, (batch_size, seq_length))
# Forward pass
output = model(x)
print(f"Output shape: {output.shape}") # Should be (batch_size, seq_length, num_classes)