Transformers have become a fundamental component of many state-of-the-art natural language processing (NLP) systems. In this post, we will walk through how to implement a Transformer model from scratch using PyTorch.
Introduction
The Transformer architecture was first introduced in the paper Attention Is All You Need by Vaswani et al. in 2017. It has since been widely adopted and is now the model of choice for many NLP tasks such as machine translation, text summarization, question answering and more. The key innovations of the Transformer are:
Reliance entirely on attention mechanisms, eliminating recurrence and convolutions
Multi-head self-attention, which allows the model to jointly attend to information from different representation subspaces
Positional encodings, which provide the model with information about the position of each token in the sequence
In this tutorial we will use PyTorch to implement the Transformer from scratch, learning about the components that make up this powerful model.
Imports and Settings
We'll start by importing PyTorch and defining some model hyperparameters:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
# Model hyperparameters
d_model = 512 # Embedding size
nhead = 8 # Number of attention heads
num_encoder_layers = 6 # Number of encoder layers
num_decoder_layers = 6 # Number of decoder layers
dim_feedforward = 2048 # Inner layer dimensionality in feedforward network
dropout = 0.1
Positional Encoding
Since the Transformer has no recurrence or convolution, we must inject some information about the position of tokens in the sequence. This is done by adding positional encodings, built from sine and cosine functions of different frequencies, to the input embeddings.
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute the sinusoidal table once: one row per position
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        # Shape (1, max_len, d_model) so it broadcasts over batch-first inputs
        pe = pe.unsqueeze(0)
        # A buffer moves with the module across devices but is not trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model), matching the attention module below
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
This adds positional information to the input embeddings before they are passed to the model.
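For reference, the sinusoidal encoding from the original paper, which the code above appears to follow, can be written out as below. Here pos is the token position and i indexes each sine/cosine pair of dimensions. The last line is only a restatement of how div_term in the code relates to the denominator, not something stated in the post itself.

```latex
\begin{aligned}
PE_{(pos,\,2i)}   &= \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \\
PE_{(pos,\,2i+1)} &= \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \\
\texttt{div\_term}_i &= \exp\!\left(-\frac{2i\,\ln 10000}{d_{\text{model}}}\right) = 10000^{-2i/d_{\text{model}}}
\end{aligned}
```

Computing the denominator through exp and log is a common way to avoid raising 10000 to a fractional power directly.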
Multi-Head Attention
A core component of the Transformer is multi-head attention, which allows the model to jointly attend to information from different representation subspaces at different positions. The multi-head attention consists of splitting the query, key and value vectors into multiple heads, and then computing scaled dot-product attention for each head. The attention outputs of each head are then concatenated and projected.
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_model = d_model
        self.d_k = d_model // h
        self.h = h
        # Layers to project input features to q, k, v vectors
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        # Scaled dot-product scores with shape bs * h * sl_q * sl_k
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Add a head dimension so the mask broadcasts across all heads
            mask = mask.unsqueeze(1)
            # Masked positions get a large negative score, so softmax gives them ~0 weight
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # Perform linear projection and split into h heads
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        # Transpose to get dimensions bs * h * sl * d_k
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # Calculate attention using function we defined above
        scores = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # Concatenate heads and project back to original dimension
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output
This allows the model to jointly attend to information at different positions, an essential component for processing language.
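To look at the masking step in isolation, here is a minimal sketch of scaled dot-product attention, softmax(QK^T / sqrt(d_k))V, as a standalone function. The function name, the toy tensor sizes and the causal (lower-triangular) mask are illustrative assumptions, not part of the class above.

```python
import math
import torch
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, mask=None):
    """Illustrative helper: returns (output, attention weights)."""
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 are excluded from the softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, v), weights


torch.manual_seed(0)
# Toy input with shape (batch, heads, seq_len, d_k)
q = k = v = torch.randn(1, 2, 4, 8)
# Causal mask: position t may only attend to positions <= t
causal = torch.tril(torch.ones(4, 4))
out, weights = scaled_dot_product_attention(q, k, v, causal)
print(out.shape)
print(weights[0, 0])
```

Each row of the printed weights sums to 1, and the entries above the diagonal are zero because those positions were masked out before the softmax.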
Feed Forward Network
Each layer also contains a two-layer, position-wise feedforward network applied after the self-attention sublayer. It consists of two linear transformations with a ReLU activation in between:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedforwardLayer, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = self.dropout(F.relu(self.linear1(x)))
        x = self.linear2(x)
        return x
The FFN is applied to each position independently, transforming the attention output before it is passed to the next layer.
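As a small check of what "position-wise" means in practice, the sketch below applies a feedforward network to a whole sequence at once and then to one position at a time, and compares the two results. The nn.Sequential stand-in (without dropout) and the toy sizes are assumptions chosen to keep the example short.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 32

# Stand-in for the feedforward block: Linear -> ReLU -> Linear
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)

# Apply to the full sequence in one call
full = ffn(x)

# Apply to each position separately, then stack along the sequence dimension
per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)

same = torch.allclose(full, per_position, atol=1e-5)
print(same)
```

Because nn.Linear acts only on the last dimension, no information flows between positions inside this block; mixing across positions happens only in the attention sublayer.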
Encoder Layer
With the attention and feedforward blocks defined, we can now build the full encoder layer. This consists of multi-head self-attention followed by the feedforward network, with a residual connection around each block and layer normalization applied to each block's input (the pre-norm variant; the original paper applies normalization after the residual addition):
class EncoderLayer(nn.Module):
    def __init__(self, d_model, heads, d_ff=2048, dropout=0.1):
        super(EncoderLayer, self).__init__()
        # Layer normalization applied before each sublayer (pre-norm)
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadedAttention(heads, d_model, dropout)
        self.ff = PositionwiseFeedforwardLayer(d_model, d_ff, dropout)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Self-attention sublayer with residual connection
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn(x2, x2, x2, mask))
        # Feedforward sublayer with residual connection
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x
We can then stack N of these encoder layers to form the full encoder. The decoder layers are similar, except with an extra multi-head attention block attending to the encoder outputs.
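The stacking step could look roughly like the sketch below: embed the tokens, add positional encodings, run N layers, and finish with a layer normalization (needed because pre-norm layers leave their final output unnormalized). To keep the block self-contained it uses a placeholder layer built on PyTorch's nn.MultiheadAttention; StandInEncoderLayer, sinusoidal_table and the constructor argument order are assumptions for illustration, and the EncoderLayer defined above would take the placeholder's place.

```python
import math
import torch
import torch.nn as nn


def sinusoidal_table(max_len, d_model):
    """Fixed sine/cosine position table of shape (1, max_len, d_model)."""
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)
    div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(pos * div)
    pe[:, 1::2] = torch.cos(pos * div)
    return pe.unsqueeze(0)


class StandInEncoderLayer(nn.Module):
    """Placeholder pre-norm layer; swap in the post's EncoderLayer."""

    def __init__(self, d_model, d_ff, h, dropout):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, h, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        # mask: (batch, 1, src_len), nonzero for real tokens.
        # nn.MultiheadAttention expects the opposite convention: True = ignore.
        pad = (mask.squeeze(1) == 0) if mask is not None else None
        x2 = self.norm_1(x)
        attn_out, _ = self.attn(x2, x2, x2, key_padding_mask=pad, need_weights=False)
        x = x + self.drop(attn_out)
        x = x + self.drop(self.ff(self.norm_2(x)))
        return x


class Encoder(nn.Module):
    """Sketch of an encoder stack: embedding -> positions -> N layers -> norm."""

    def __init__(self, vocab_size, N, d_model, d_ff, h, dropout, max_len=5000):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.register_buffer("pe", sinusoidal_table(max_len, d_model))
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList(
            [StandInEncoderLayer(d_model, d_ff, h, dropout) for _ in range(N)]
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, src, mask):
        x = self.embed(src)                            # (batch, src_len, d_model)
        x = self.dropout(x + self.pe[:, :x.size(1)])   # add position information
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
```

A decoder stack would follow the same pattern, with each layer adding a masked self-attention over the target and an attention block over the encoder outputs.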
Full Transformer
With the components defined, we can now implement the full Transformer model. The encoder consists of an embedding layer followed by positional encodings and N encoder layers. The decoder is similar but includes an extra multi-head attention block attending to the encoder outputs:
class Transformer(nn.Module):
    def __init__(self, n_src_vocab, n_tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
        super(Transformer, self).__init__()
        # Encoder and Decoder are the layer stacks described above; the argument
        # order (vocab_size, N, d_model, d_ff, h, dropout) is an assumption
        self.encoder = Encoder(n_src_vocab, N, d_model, d_ff, h, dropout)
        self.decoder = Decoder(n_tgt_vocab, N, d_model, d_ff, h, dropout)
        # Project decoder states to target-vocabulary logits
        self.out = nn.Linear(d_model, n_tgt_vocab)

    def forward(self, src, tgt, src_mask, tgt_mask):
        e_outputs = self.encoder(src, src_mask)
        d_output = self.decoder(tgt, e_outputs, src_mask, tgt_mask)
        output = self.out(d_output)
        return output
This gives us the full Transformer model powered entirely by attention.
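The forward pass takes src_mask and tgt_mask, so here is one way those masks might be built. This is a sketch under assumptions: make_masks is a hypothetical helper, padding is assumed to use token id 0, and the shapes are chosen to broadcast with the mask.unsqueeze(1) call in the attention code above.

```python
import torch


def make_masks(src, tgt, pad_idx=0):
    """Hypothetical helper: build padding and causal masks from token ids.

    src: (batch, src_len) and tgt: (batch, tgt_len) integer tensors.
    Returns boolean masks where True marks positions that may be attended to.
    """
    # Hide padding in the source: (batch, 1, src_len)
    src_mask = (src != pad_idx).unsqueeze(-2)

    # Hide padding in the target: (batch, 1, tgt_len)
    tgt_pad = (tgt != pad_idx).unsqueeze(-2)

    # Lower-triangular matrix so position t sees only positions <= t
    tgt_len = tgt.size(1)
    causal = torch.tril(torch.ones(tgt_len, tgt_len, device=tgt.device)).bool()

    # Combine both constraints: (batch, tgt_len, tgt_len)
    tgt_mask = tgt_pad & causal
    return src_mask, tgt_mask


src = torch.tensor([[5, 6, 0]])
tgt = torch.tensor([[1, 7, 0]])
src_mask, tgt_mask = make_masks(src, tgt)
print(src_mask.shape, tgt_mask.shape)
print(tgt_mask[0])
```

The target mask combines two constraints: the decoder must not look at padding, and it must not look ahead at tokens it has yet to predict.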
Training the Transformer Model
To train the model, we simply need to define an optimizer and criterion then write a typical training loop. For example:
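# Define optimizer and criterion
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
# model, src, tgt, gold, src_mask and tgt_mask are assumed to be prepared beforehand
for epoch in range(100):
    optimizer.zero_grad()
    # The forward pass also takes the source and target masks
    outputs = model(src, tgt, src_mask, tgt_mask)  # (batch, tgt_len, vocab)
    # CrossEntropyLoss expects (N, vocab) logits and (N,) class indices
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), gold.reshape(-1))
    loss.backward()
    optimizer.step()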
This will optimize the model parameters to minimize the cross-entropy loss using backpropagation and the Adam optimizer. The same forward pass and loss computation can be used to evaluate the model on a validation set, with dropout and gradient updates disabled.
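A validation pass might be sketched as follows. The evaluate function and the batch layout (src, tgt, gold, src_mask, tgt_mask) are assumptions made for illustration; the model is expected to have the same forward signature as the Transformer class above.

```python
import torch
import torch.nn as nn


def evaluate(model, batches, criterion):
    """Hypothetical helper: average per-batch loss without updating weights.

    batches is an iterable of (src, tgt, gold, src_mask, tgt_mask) tuples.
    """
    model.eval()                      # switch off dropout
    total, count = 0.0, 0
    with torch.no_grad():             # gradients are not needed for evaluation
        for src, tgt, gold, src_mask, tgt_mask in batches:
            logits = model(src, tgt, src_mask, tgt_mask)  # (batch, tgt_len, vocab)
            loss = criterion(logits.reshape(-1, logits.size(-1)), gold.reshape(-1))
            total += loss.item()
            count += 1
    model.train()                     # restore training mode for the caller
    return total / max(count, 1)
```

Calling this once per epoch on held-out batches gives a number to compare against the training loss, which is a simple way to spot overfitting.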
Conclusion
And that's it! We've built a Transformer model from scratch using the building blocks of multi-head attention, positional encodings, feedforward layers, and residual connections. Transformers have led to huge advances in NLP and this tutorial provided insight into how they actually work under the hood. To leverage pretrained models like BERT and GPT-2, we can use the 🤗 Transformers library by Hugging Face.