1. The Transformer Architecture
1.1 "Attention Is All You Need"
Published in June 2017 by Vaswani et al. at Google, this paper introduced the Transformer architecture. It is arguably the most impactful machine learning paper ever written. Every major language model since — GPT, BERT, T5, Llama, Claude, Gemini — is based on the Transformer.
1.2 Full Architecture Overview
The original Transformer has an encoder-decoder structure, designed for sequence-to-sequence tasks like machine translation.
┌─────────────────────────────────────────────────────────────────────┐
│ THE TRANSFORMER ARCHITECTURE │
│ │
│ ENCODER (left side) DECODER (right side) │
│ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ │ │ │ │
│ │ ┌───────────────┐ │ x N │ ┌───────────────┐ │ x N │
│ │ │ Feed-Forward │ │ │ │ Feed-Forward │ │ │
│ │ │ Network (FFN) │ │ │ │ Network (FFN) │ │ │
│ │ └───────┬───────┘ │ │ └───────┬───────┘ │ │
│ │ Add & Norm │ │ Add & Norm │ │
│ │ ┌───────────────┐ │ │ ┌───────────────┐ │ │
│ │ │ Multi-Head │ │ │ │ Cross- │ │ │
│ │ │ Self-Attention│ │ ──────►│ │ Attention │ │ │
│ │ └───────┬───────┘ │ (K, V) │ │ (Enc → Dec) │ │ │
│ │ Add & Norm │ │ └───────┬───────┘ │ │
│ │ │ │ Add & Norm │ │
│ └─────────────────────┘ │ ┌───────────────┐ │ │
│ │ │ Masked Multi- │ │ │
│ │ │ Head Self- │ │ │
│ │ │ Attention │ │ │
│ │ └───────┬───────┘ │ │
│ │ Add & Norm │ │
│ └─────────────────────┘ │
│ │
│ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ Input Embedding │ │ Output Embedding │ │
│ │ + Positional Enc. │ │ + Positional Enc. │ │
│ └─────────────────────┘ └─────────────────────┘ │
│ │
│ INPUT TOKENS OUTPUT TOKENS │
│ "The cat sat on the mat" "Le chat ..." (shifted right) │
└─────────────────────────────────────────────────────────────────────┘
| Architecture Variant | Structure | Examples | Best For |
|---|---|---|---|
| Encoder-Only | Only the left side; bidirectional attention | BERT, RoBERTa, DeBERTa | Classification, NER, sentence embeddings |
| Decoder-Only | Only the right side; causal (left-to-right) attention | GPT-1/2/3/4, Llama, Claude, DeepSeek | Text generation, chatbots, code generation |
| Encoder-Decoder | Both sides; cross-attention between them | T5, BART, mBART, Whisper | Translation, summarization, speech-to-text |
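The practical difference between the encoder-style (bidirectional) and decoder-style (causal) rows in the table comes down to the attention mask. A minimal PyTorch sketch:

```python
import torch

seq_len = 5

# Encoder-style (bidirectional): every token may attend to every other token.
bidirectional_mask = torch.ones(seq_len, seq_len)

# Decoder-style (causal): token i may attend only to positions <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

print(causal_mask)
# Row i has ones up to column i and zeros after. Positions where the mask
# is 0 get their attention scores set to -inf before the softmax, so no
# probability mass ever flows from "future" tokens to "past" ones.
```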
(Diagram: the same encoder-decoder flow as the figure above, rendered as a flowchart; it additionally shows the decoder output passing through a final Linear + Softmax layer to produce the output probabilities.)
Why decoder-only models came to dominate:
- Simpler architecture (one stack instead of two)
- Scales better with more parameters and data
- Autoregressive generation is natural and flexible
- Can handle any task framed as text generation
- More efficient to train (single forward pass per token)
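The "autoregressive generation is natural" point can be sketched as a toy greedy decoding loop. The "model" here is just a random embedding plus linear head over a made-up 100-token vocabulary, purely illustrative, standing in for a real Transformer stack:

```python
import torch

vocab_size, d = 100, 32
torch.manual_seed(0)
emb = torch.nn.Embedding(vocab_size, d)   # stand-in token embeddings
head = torch.nn.Linear(d, vocab_size)     # stand-in output projection

tokens = [1]  # start token id
for _ in range(5):
    # A real model would run the full Transformer here; we just pool embeddings.
    h = emb(torch.tensor(tokens)).mean(dim=0)
    logits = head(h)                       # scores over the vocabulary
    next_token = int(logits.argmax())      # greedy: pick the most likely token
    tokens.append(next_token)              # feed it back in -- autoregression

print(tokens)  # the start token plus 5 generated token ids
```

The key property is the feedback loop: each generated token is appended to the input for the next step, which is exactly how decoder-only models handle any task framed as text generation.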
1.3 Why Transformers Replaced RNNs/LSTMs
| Aspect | RNNs/LSTMs | Transformers |
|---|---|---|
| Parallelization | Sequential processing (token by token). Cannot parallelize within a sequence. | Fully parallel (all tokens processed simultaneously). Massive GPU utilization. |
| Long-range dependencies | Information degrades over long distances (vanishing gradients). LSTM helps but doesn't solve. | Direct attention between ANY two tokens, regardless of distance. O(1) path length. |
| Training speed | Slow (sequential bottleneck) | Fast (parallel computation on GPUs) |
| Memory | O(1) per step (hidden state) | O(n²) for attention (quadratic in sequence length) |
| Scalability | Difficult to scale beyond ~1B parameters | Scales to hundreds of billions (proven by GPT-3/4, Llama, etc.) |
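The parallelization row of the table can be seen directly in code. This is a toy sketch, not a real RNN cell or attention layer: the point is only that the recurrence forces a Python-level loop, while attention is one batched matrix multiply over all token pairs:

```python
import torch

seq_len, d = 8, 16
x = torch.randn(seq_len, d)
W = torch.randn(d, d) * 0.1

# RNN-style: inherently sequential -- step t cannot start until t-1 finishes.
h = torch.zeros(d)
for t in range(seq_len):
    h = torch.tanh(x[t] @ W + h)

# Attention-style: all pairwise interactions computed in one shot.
scores = x @ x.T / d ** 0.5            # (seq_len, seq_len) in a single op
weights = torch.softmax(scores, dim=-1)
out = weights @ x                      # every position computed simultaneously

print(out.shape)  # torch.Size([8, 16])
```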
2. Multi-Head Attention
2.1 Why Multiple Heads?
A single attention head can only focus on one type of relationship at a time. But language has many simultaneous relationships: syntactic (subject-verb agreement), semantic (word meaning), coreference ("it" refers to "cat"), positional patterns, etc.
Multi-head attention runs several attention operations in parallel, each learning to focus on different aspects of the input. It's like having multiple "experts" each looking at the data from a different angle.
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
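The formula can be traced numerically in a few lines of NumPy. The dimensions here (seq_len=4, d_model=8, h=2 heads) are tiny illustrative choices, not from any real model:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, h = 4, 8, 2
d_k = d_model // h                     # 4 dims per head

X = rng.standard_normal((seq_len, d_model))

def attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)   # row-wise softmax
    return w @ V

heads = []
for i in range(h):
    # Per-head projections W_i^Q, W_i^K, W_i^V, each (d_model, d_k)
    Wq, Wk, Wv = (rng.standard_normal((d_model, d_k)) for _ in range(3))
    heads.append(attention(X @ Wq, X @ Wk, X @ Wv))

W_O = rng.standard_normal((d_model, d_model))
out = np.concatenate(heads, axis=-1) @ W_O   # Concat(head_1, ..., head_h) W^O
print(out.shape)  # (4, 8) -- same shape as the input
```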
(Diagram: an input of shape seq_len × d_model is split into h heads, each running attention on d_k dimensions; the head outputs are concatenated back to d_model and passed through the output projection W_O.)
2.2 How Heads Are Split and Concatenated
"""
Multi-Head Attention - How it works internally
================================================
Key insight: Instead of doing attention with d_model dimensions,
we split into h heads, each operating on d_model/h dimensions.
Total computation is the same, but we get h different attention patterns.
"""
import numpy as np
# Example configuration (like BERT-base)
d_model = 768 # Model dimension
n_heads = 12 # Number of attention heads
d_k = d_model // n_heads # = 64 per head
d_v = d_model // n_heads # = 64 per head
print(f"d_model = {d_model}")
print(f"n_heads = {n_heads}")
print(f"d_k per head = {d_k}")
print(f"d_v per head = {d_v}")
# The process:
# 1. Project input X (seq_len, 768) to Q, K, V using large matrices
# 2. Split Q, K, V into 12 heads of dimension 64 each
# 3. Run attention independently on each head
# 4. Concatenate the 12 outputs back to dimension 768
# 5. Apply final linear projection W_O
# Step-by-step shapes:
seq_len = 10
batch_size = 1
print(f"\nShape trace:")
print(f"Input X: ({batch_size}, {seq_len}, {d_model})")
print(f"After W_Q: ({batch_size}, {seq_len}, {d_model})")
print(f"Reshape to heads: ({batch_size}, {n_heads}, {seq_len}, {d_k})")
print(f"Attention per head:({batch_size}, {n_heads}, {seq_len}, {d_v})")
print(f"Concat heads: ({batch_size}, {seq_len}, {d_model})")
print(f"After W_O: ({batch_size}, {seq_len}, {d_model})")
2.3 PRACTICAL: Multi-Head Attention in PyTorch
"""
Complete Multi-Head Attention Implementation in PyTorch
========================================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
"""
Multi-Head Self-Attention mechanism.
This is the core component of every Transformer layer.
"""
def __init__(self, d_model, n_heads, dropout=0.1):
"""
Args:
d_model: Total model dimension (e.g., 768)
n_heads: Number of attention heads (e.g., 12)
dropout: Dropout rate for attention weights
"""
super().__init__()
assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads # Dimension per head
# Linear projections for Q, K, V
# Each is (d_model, d_model) - we project all heads at once
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
# Output projection
self.W_O = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x):
"""
Split the last dimension into (n_heads, d_k).
Input: (batch, seq_len, d_model)
Output: (batch, n_heads, seq_len, d_k)
"""
batch_size, seq_len, _ = x.shape
# Reshape: (batch, seq_len, n_heads, d_k)
x = x.view(batch_size, seq_len, self.n_heads, self.d_k)
# Transpose: (batch, n_heads, seq_len, d_k)
return x.transpose(1, 2)
def combine_heads(self, x):
"""
Reverse of split_heads.
Input: (batch, n_heads, seq_len, d_k)
Output: (batch, seq_len, d_model)
"""
batch_size, _, seq_len, _ = x.shape
# Transpose: (batch, seq_len, n_heads, d_k)
x = x.transpose(1, 2)
# Reshape: (batch, seq_len, d_model)
return x.contiguous().view(batch_size, seq_len, self.d_model)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""
Compute scaled dot-product attention.
Args:
Q: (batch, n_heads, seq_len, d_k)
K: (batch, n_heads, seq_len, d_k)
V: (batch, n_heads, seq_len, d_k)
mask: (1, 1, seq_len, seq_len) or None
Returns:
context: (batch, n_heads, seq_len, d_k)
attention_weights: (batch, n_heads, seq_len, seq_len)
"""
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) # (batch, heads, seq, seq)
scores = scores / math.sqrt(self.d_k)
# Apply mask (for causal attention or padding)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax to get attention weights
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# Weighted sum of values
context = torch.matmul(attention_weights, V)
return context, attention_weights
def forward(self, x, mask=None):
"""
Forward pass for multi-head self-attention.
Args:
x: Input tensor (batch, seq_len, d_model)
mask: Optional attention mask
Returns:
output: (batch, seq_len, d_model)
attention_weights: (batch, n_heads, seq_len, seq_len)
"""
# Step 1: Linear projections
Q = self.W_Q(x) # (batch, seq_len, d_model)
K = self.W_K(x)
V = self.W_V(x)
# Step 2: Split into multiple heads
Q = self.split_heads(Q) # (batch, n_heads, seq_len, d_k)
K = self.split_heads(K)
V = self.split_heads(V)
# Step 3: Scaled dot-product attention (parallel for all heads)
context, attention_weights = self.scaled_dot_product_attention(
Q, K, V, mask
)
# Step 4: Combine heads
context = self.combine_heads(context) # (batch, seq_len, d_model)
# Step 5: Final linear projection
output = self.W_O(context) # (batch, seq_len, d_model)
return output, attention_weights
# ------- Test the implementation -------
torch.manual_seed(42)
batch_size = 2
seq_len = 10
d_model = 64
n_heads = 8
x = torch.randn(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads, dropout=0.0)
output, weights = mha(x)
print(f"Input shape: {x.shape}") # (2, 10, 64)
print(f"Output shape: {output.shape}") # (2, 10, 64)
print(f"Weights shape: {weights.shape}") # (2, 8, 10, 10)
# Verify each head has different attention patterns
print(f"\nAttention weights for head 0 (first example, first 5 tokens):")
print(weights[0, 0, :5, :5].detach().numpy().round(3))
print(f"\nAttention weights for head 3 (same positions):")
print(weights[0, 3, :5, :5].detach().numpy().round(3))
# These should be DIFFERENT - each head learns different patterns!
# Parameter count
total_params = sum(p.numel() for p in mha.parameters())
print(f"\nTotal parameters: {total_params:,}")
print(f" W_Q: {d_model * d_model:,}")
print(f" W_K: {d_model * d_model:,}")
print(f" W_V: {d_model * d_model:,}")
print(f" W_O: {d_model * d_model:,}")
print(f" Total: 4 * d_model^2 = {4 * d_model * d_model:,}")
3. Q, K, V Matrices Deep Dive
3.1 How Q, K, V Are Computed
"""
Q, K, V Computation - Detailed Breakdown
==========================================
"""
import torch
import torch.nn as nn
# Configuration
d_model = 512 # Model dimension
n_heads = 8 # Number of heads
d_k = d_model // n_heads # = 64 per head
# Input: sequence of token embeddings
# Shape: (batch_size, seq_len, d_model)
batch_size = 1
seq_len = 4
x = torch.randn(batch_size, seq_len, d_model)
# Weight matrices
# Shape: (d_model, d_model) = (512, 512) for each of Q, K, V
W_Q = nn.Linear(d_model, d_model, bias=False)
W_K = nn.Linear(d_model, d_model, bias=False)
W_V = nn.Linear(d_model, d_model, bias=False)
print("COMPUTING Q, K, V:")
print(f" Input x shape: {x.shape}")
print(f" W_Q weight shape: {W_Q.weight.shape}")
# Matrix multiplication: (1, 4, 512) @ (512, 512) = (1, 4, 512)
Q = W_Q(x) # Each token's embedding is projected to a Query vector
K = W_K(x) # Each token's embedding is projected to a Key vector
V = W_V(x) # Each token's embedding is projected to a Value vector
print(f"\n Q shape: {Q.shape}") # (1, 4, 512)
print(f" K shape: {K.shape}")
print(f" V shape: {V.shape}")
# Now reshape for multi-head attention
# (1, 4, 512) -> (1, 4, 8, 64) -> (1, 8, 4, 64)
Q_heads = Q.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)
K_heads = K.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)
V_heads = V.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)
print(f"\n Q_heads shape: {Q_heads.shape}") # (1, 8, 4, 64)
print(f" K_heads shape: {K_heads.shape}")
print(f" V_heads shape: {V_heads.shape}")
# Each head operates on 64-dimensional slices independently
# Head 0 uses Q_heads[:, 0, :, :] which is (1, 4, 64)
# Head 1 uses Q_heads[:, 1, :, :] which is (1, 4, 64)
# ... etc.
print(f"\n Head 0 Q slice shape: {Q_heads[:, 0, :, :].shape}") # (1, 4, 64)
print(f" Head 0 K slice shape: {K_heads[:, 0, :, :].shape}")
print(f" Head 0 V slice shape: {V_heads[:, 0, :, :].shape}")
(Diagram: the input X (seq_len × d_model) is multiplied by W_Q, W_K, and W_V to produce Q ("what am I looking for?"), K ("what do I contain?"), and V ("what information do I carry?"), all of which feed into scaled dot-product attention.)
3.2 Intuition for What Q, K, V Represent
Think of it as a database query analogy:
- Query (Q): "I am looking for context about [this concept]." Each token generates a query vector representing what information it needs.
- Key (K): "I contain information about [this concept]." Each token generates a key vector advertising what information it can provide.
- Value (V): "Here is the actual information I carry." The content that gets passed along when attention is paid.
The dot product Q · K measures relevance (how well the query matches the key). The result weights how much of each V to include in the output.
Crucially: Q, K, and V are all derived from the same input (in self-attention) but through DIFFERENT learned projections. This means the model can learn that the same word plays different roles as a query vs. a key vs. a value.
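The database analogy can be made concrete with a toy example. The 2-D vectors below are hand-picked for illustration so that key 0 "matches" the query:

```python
import numpy as np

q = np.array([1.0, 0.0])                  # query: "looking for" direction [1, 0]
K = np.array([[0.9, 0.1],                 # key 0: advertises almost exactly [1, 0]
              [0.0, 1.0],                 # key 1: advertises the other direction
              [-1.0, 0.0]])               # key 2: advertises the opposite
V = np.array([[10.0], [20.0], [30.0]])    # values: the actual payloads

scores = K @ q                            # dot-product relevance: [0.9, 0.0, -1.0]
w = np.exp(scores) / np.exp(scores).sum() # softmax -> attention weights
out = w @ V                               # weighted mix of the values

print(w.round(3))   # most weight lands on the matching key 0
print(out)          # output is pulled toward key 0's value, 10.0
```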
3.3 Worked Example with Tensor Shapes
"""
Complete Shape Trace Through Multi-Head Attention
===================================================
Model: d_model=768, n_heads=12, d_k=64
Sequence: 5 tokens
"""
import torch
# Configuration (BERT-base / GPT-2 scale)
batch = 2
seq = 5
d_model = 768
n_heads = 12
d_k = d_model // n_heads # 64
print("=" * 60)
print("COMPLETE SHAPE TRACE")
print("=" * 60)
# Input embeddings (after token embedding + positional encoding)
x = torch.randn(batch, seq, d_model)
print(f"\n1. Input X: {x.shape}")
print(f" Meaning: {batch} examples, {seq} tokens each, {d_model}-dim embeddings")
# Linear projections
W_Q = torch.randn(d_model, d_model)
Q = x @ W_Q # (2, 5, 768) @ (768, 768) = (2, 5, 768)
print(f"\n2. Q = X @ W_Q: {Q.shape}")
print(f" Similarly K, V: ({batch}, {seq}, {d_model})")
# Split into heads
Q_heads = Q.view(batch, seq, n_heads, d_k).transpose(1, 2)
print(f"\n3. Split into {n_heads} heads:")
print(f" Q_heads: {Q_heads.shape}")
print(f" Meaning: {batch} examples, {n_heads} heads, {seq} tokens, {d_k}-dim per head")
# Attention scores for each head
K_heads = torch.randn(batch, n_heads, seq, d_k) # placeholder
scores = Q_heads @ K_heads.transpose(-2, -1)
print(f"\n4. Attention scores = Q_heads @ K_heads^T:")
print(f" scores: {scores.shape}")
print(f" Meaning: for each head, a {seq}x{seq} attention matrix")
# After softmax
weights = torch.softmax(scores / (d_k ** 0.5), dim=-1)
print(f"\n5. Attention weights (after softmax): {weights.shape}")
# Weighted sum of values
V_heads = torch.randn(batch, n_heads, seq, d_k)
context = weights @ V_heads
print(f"\n6. Context = weights @ V_heads: {context.shape}")
print(f" Meaning: weighted combination of values for each head")
# Combine heads
combined = context.transpose(1, 2).contiguous().view(batch, seq, d_model)
print(f"\n7. Combine heads: {combined.shape}")
print(f" Back to original dimensions!")
# Output projection
W_O = torch.randn(d_model, d_model)
output = combined @ W_O
print(f"\n8. Output = combined @ W_O: {output.shape}")
print(f" Same shape as input! This allows residual connections.")
4. Cross-Attention
4.1 How the Decoder Attends to Encoder Output
In encoder-decoder models (T5, BART, translation models), the decoder needs information from the encoder. Cross-attention is how this happens:
- Queries (Q) come from the decoder (what the decoder is looking for)
- Keys (K) and Values (V) come from the encoder output (what the source has)
"""
Cross-Attention Implementation
================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class CrossAttention(nn.Module):
"""
Cross-attention: decoder attends to encoder output.
Difference from self-attention:
- In self-attention: Q, K, V all come from the same source
- In cross-attention: Q comes from decoder, K and V from encoder
"""
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_Q = nn.Linear(d_model, d_model, bias=False) # Projects decoder states
self.W_K = nn.Linear(d_model, d_model, bias=False) # Projects encoder output
self.W_V = nn.Linear(d_model, d_model, bias=False) # Projects encoder output
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, decoder_state, encoder_output):
"""
Args:
decoder_state: (batch, decoder_seq_len, d_model) - from previous decoder layer
encoder_output: (batch, encoder_seq_len, d_model) - final encoder output
Returns:
output: (batch, decoder_seq_len, d_model)
"""
batch_size = decoder_state.size(0)
dec_len = decoder_state.size(1)
enc_len = encoder_output.size(1)
# Q from decoder, K and V from encoder
Q = self.W_Q(decoder_state) # (batch, dec_len, d_model)
K = self.W_K(encoder_output) # (batch, enc_len, d_model)
V = self.W_V(encoder_output) # (batch, enc_len, d_model)
# Split heads
Q = Q.view(batch_size, dec_len, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, enc_len, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, enc_len, self.n_heads, self.d_k).transpose(1, 2)
# Attention: (batch, heads, dec_len, d_k) @ (batch, heads, d_k, enc_len)
scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
# Shape: (batch, heads, dec_len, enc_len)
# This tells us: for each decoder position, how much to attend to
# each encoder position
weights = F.softmax(scores, dim=-1)
context = weights @ V # (batch, heads, dec_len, d_k)
# Combine heads
context = context.transpose(1, 2).contiguous()
context = context.view(batch_size, dec_len, -1)
return self.W_O(context)
# Example: French to English translation
d_model = 256
n_heads = 8
cross_attn = CrossAttention(d_model, n_heads)
# Encoder processes: "Le chat est sur le tapis" (6 tokens)
encoder_output = torch.randn(1, 6, d_model)
# Decoder is generating: "The cat is" (3 tokens so far)
decoder_state = torch.randn(1, 3, d_model)
output = cross_attn(decoder_state, encoder_output)
print(f"Encoder output shape: {encoder_output.shape}") # (1, 6, 256)
print(f"Decoder state shape: {decoder_state.shape}") # (1, 3, 256)
print(f"Cross-attention output: {output.shape}") # (1, 3, 256)
# The decoder token "is" can now attend to all French source tokens
# to figure out what English word comes next
5. Feed-Forward Neural Networks in Transformers
5.1 Position-wise FFN
Original Transformer FFN:
FFN(x) = max(0, xW1 + b1)W2 + b2
Or equivalently: FFN(x) = ReLU(xW1 + b1)W2 + b2
The FFN is applied to each position independently (same weights for all positions).
Why FFNs are needed: Attention is great at mixing information between positions, but it's fundamentally a weighted average — a linear operation. The FFN provides non-linear transformation that allows the model to compute complex functions of the attention output.
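A quick sketch of why the non-linearity matters: two stacked linear maps always collapse into a single linear map, while a ReLU in between does not. (Random matrices here are purely illustrative.)

```python
import numpy as np

rng = np.random.default_rng(1)
W1 = rng.standard_normal((4, 16))
W2 = rng.standard_normal((16, 4))
x = rng.standard_normal((3, 4))

# Without an activation, applying W1 then W2 is just ONE linear map, W1 @ W2.
linear_stack = (x @ W1) @ W2
collapsed = x @ (W1 @ W2)
print(np.allclose(linear_stack, collapsed))  # True -- no expressive power gained

# With ReLU in between, the composition is no longer linear,
# so it cannot be rewritten as a single matrix multiply.
ffn = np.maximum(0, x @ W1) @ W2
print(np.allclose(ffn, collapsed))           # False
```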
5.2 Hidden Dimension Expansion
"""
Feed-Forward Network in Transformers
======================================
The FFN typically EXPANDS the dimension by 4x, then projects back.
"""
import torch
import torch.nn as nn
class FeedForwardNetwork(nn.Module):
"""
Position-wise Feed-Forward Network.
FFN(x) = activation(x @ W1 + b1) @ W2 + b2
    The hidden dimension is typically 4x the model dimension.
    This expand-then-contract shape gives each position a wide hidden
    layer to compute features in before projecting back to d_model.
"""
def __init__(self, d_model, d_ff=None, dropout=0.1, activation='relu'):
"""
Args:
d_model: Input/output dimension (e.g., 768)
d_ff: Hidden dimension (default: 4 * d_model = 3072)
dropout: Dropout rate
            activation: 'relu' or 'gelu' (SwiGLU is covered separately below)
"""
super().__init__()
self.d_ff = d_ff or 4 * d_model
self.W1 = nn.Linear(d_model, self.d_ff)
self.W2 = nn.Linear(self.d_ff, d_model)
self.dropout = nn.Dropout(dropout)
if activation == 'relu':
self.activation = nn.ReLU()
elif activation == 'gelu':
self.activation = nn.GELU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")
def forward(self, x):
"""
Args:
x: (batch, seq_len, d_model)
Returns:
output: (batch, seq_len, d_model)
"""
# Expand: (batch, seq, 768) -> (batch, seq, 3072)
hidden = self.W1(x)
hidden = self.activation(hidden)
hidden = self.dropout(hidden)
# Contract: (batch, seq, 3072) -> (batch, seq, 768)
output = self.W2(hidden)
return output
# Example
d_model = 768
ffn = FeedForwardNetwork(d_model, activation='gelu')
x = torch.randn(2, 10, d_model)
output = ffn(x)
print(f"Input: {x.shape}") # (2, 10, 768)
print(f"Output: {output.shape}") # (2, 10, 768)
# Parameter count
params = sum(p.numel() for p in ffn.parameters())
print(f"\nFFN parameters: {params:,}")
print(f" W1: {d_model} x {4*d_model} + bias = {d_model * 4*d_model + 4*d_model:,}")
print(f" W2: {4*d_model} x {d_model} + bias = {4*d_model * d_model + d_model:,}")
# FFN has MORE parameters than the attention layer!
# W1: 768 * 3072 = 2,359,296
# W2: 3072 * 768 = 2,359,296
# Total: ~4.7M (vs ~2.4M for attention)
5.3 SwiGLU Activation (Modern Models)
"""
SwiGLU Activation - Used in Llama, Mistral, and most modern LLMs
=================================================================
SwiGLU combines the Swish activation with a Gated Linear Unit.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwiGLUFeedForward(nn.Module):
"""
FFN with SwiGLU activation (used in Llama 2/3, Mistral, etc.)
SwiGLU(x, W, V, b, c) = (Swish(xW + b)) * (xV + c)
Key differences from standard FFN:
1. Uses Swish (SiLU) instead of ReLU/GELU
2. Has a "gate" mechanism (element-wise multiplication)
3. Uses 3 weight matrices instead of 2
4. Hidden dim is typically (2/3) * 4 * d_model to keep param count similar
"""
def __init__(self, d_model, d_ff=None, dropout=0.1):
super().__init__()
# Llama convention: hidden_dim = (2/3) * 4 * d_model
# rounded to nearest multiple of 256 for efficiency
self.d_ff = d_ff or int(2 * (4 * d_model) / 3)
# Round to multiple of 256
self.d_ff = 256 * ((self.d_ff + 255) // 256)
# Three projections instead of two
self.w1 = nn.Linear(d_model, self.d_ff, bias=False) # Gate projection
self.w2 = nn.Linear(self.d_ff, d_model, bias=False) # Down projection
self.w3 = nn.Linear(d_model, self.d_ff, bias=False) # Up projection
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
SwiGLU(x) = (SiLU(xW1)) * (xW3) then down-project with W2
SiLU(x) = x * sigmoid(x) (also called Swish)
"""
# Gate: apply SiLU to first projection
gate = F.silu(self.w1(x)) # (batch, seq, d_ff)
# Up projection (no activation)
up = self.w3(x) # (batch, seq, d_ff)
# Element-wise multiply (gating mechanism)
hidden = gate * up # (batch, seq, d_ff)
hidden = self.dropout(hidden)
# Down projection
output = self.w2(hidden) # (batch, seq, d_model)
return output
# Compare parameter counts
d_model = 4096 # Llama 3.1 8B
standard_ffn = FeedForwardNetwork(d_model, d_ff=4*d_model)
swiglu_ffn = SwiGLUFeedForward(d_model)
standard_params = sum(p.numel() for p in standard_ffn.parameters())
swiglu_params = sum(p.numel() for p in swiglu_ffn.parameters())
print(f"Standard FFN (d_ff=4*d_model={4*d_model}):")
print(f" Parameters: {standard_params:,}")
print(f"\nSwiGLU FFN (d_ff={swiglu_ffn.d_ff}):")
print(f" Parameters: {swiglu_params:,}")
print(f"\nRatio: {swiglu_params/standard_params:.2f}x")
# With the (2/3) scaling, SwiGLU's three matrices come out to roughly the
# SAME total parameter count as the standard FFN's two: 3 * (8/3) ≈ 2 * 4,
# in units of d_model^2. That is exactly the point of the convention.
6. Layer Normalization
6.1 Why Normalization?
During training, the distribution of each layer's inputs shifts as the parameters below it are updated (often called "internal covariate shift", though the precise mechanism is still debated). Normalization stabilizes training by giving each layer inputs with consistent statistics (mean and variance).
6.2 LayerNorm vs BatchNorm
"""
LayerNorm vs BatchNorm
========================
"""
import torch
import torch.nn as nn
import numpy as np
# Example tensor: (batch=3, seq_len=4, d_model=6)
x = torch.tensor([
[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
[2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
[3.0, 4.0, 5.0, 6.0, 7.0, 8.0]],
[[1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
[0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
[2.5, 3.5, 4.5, 5.5, 6.5, 7.5],
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]],
[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
[3.0, 3.0, 3.0, 3.0, 3.0, 3.0]],
])
print(f"Input shape: {x.shape}") # (3, 4, 6)
# ---- BatchNorm ----
# Normalizes across the BATCH dimension (for each feature independently)
# Computes mean and variance across batch dimension
# Used in CNNs, NOT suitable for transformers (variable sequence lengths)
# ---- LayerNorm ----
# Normalizes across the FEATURE dimension (for each sample independently)
# Computes mean and variance across the last dimension (d_model)
# Used in transformers
layer_norm = nn.LayerNorm(6) # Normalize across last dimension (d_model=6)
# For each position in each example:
# 1. Compute mean across 6 dimensions
# 2. Compute variance across 6 dimensions
# 3. Normalize: (x - mean) / sqrt(variance + epsilon)
# 4. Scale and shift: gamma * normalized + beta (learned parameters)
output = layer_norm(x)
print(f"\nLayerNorm output shape: {output.shape}")
# Check: each position should have mean≈0 and std≈1
print(f"\nFirst position, first example:")
print(f" Before: mean={x[0,0].mean():.2f}, std={x[0,0].std():.2f}")
print(f" After: mean={output[0,0].mean():.4f}, std={output[0,0].std():.4f}")
6.3 Pre-Norm vs Post-Norm
"""
Pre-Norm vs Post-Norm
======================
The placement of LayerNorm relative to attention/FFN matters a lot!
"""
import torch
import torch.nn as nn
class PostNormTransformerBlock(nn.Module):
"""
ORIGINAL Transformer (2017): Post-Norm
x -> Attention -> Add(x) -> LayerNorm -> FFN -> Add -> LayerNorm
Issues:
- Harder to train at large scale
- Requires careful learning rate warmup
- Can be unstable with deep models
"""
def __init__(self, d_model, n_heads, d_ff):
super().__init__()
self.attn = MultiHeadAttention(d_model, n_heads)
self.ffn = FeedForwardNetwork(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# Attention with residual, THEN normalize
attn_out, _ = self.attn(x, mask)
x = self.norm1(x + attn_out)
# FFN with residual, THEN normalize
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
return x
class PreNormTransformerBlock(nn.Module):
"""
MODERN approach (GPT-2+, Llama, etc.): Pre-Norm
x -> LayerNorm -> Attention -> Add(x) -> LayerNorm -> FFN -> Add(x)
Benefits:
- More stable training (gradients flow through residual path unchanged)
- Works with larger learning rates
- Easier to scale to very deep models
- Used by GPT-2, GPT-3, Llama, most modern LLMs
"""
def __init__(self, d_model, n_heads, d_ff):
super().__init__()
self.attn = MultiHeadAttention(d_model, n_heads)
self.ffn = FeedForwardNetwork(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x, mask=None):
# Normalize FIRST, then attention, then residual
attn_out, _ = self.attn(self.norm1(x), mask)
x = x + attn_out # Residual connection (gradient highway)
# Normalize FIRST, then FFN, then residual
ffn_out = self.ffn(self.norm2(x))
x = x + ffn_out
return x
# Why Pre-Norm is better for deep models:
# In Post-Norm, gradients must flow THROUGH the LayerNorm
# In Pre-Norm, there's a "clean" residual path where gradients
# flow directly from output to input without any transformation.
# This is critical for training 100+ layer models.
6.4 RMSNorm (Used in Llama)
"""
RMSNorm (Root Mean Square Layer Normalization)
================================================
Used in Llama 2, Llama 3, Mistral, and most modern open LLMs.
Simpler and faster than standard LayerNorm:
- No mean subtraction (skip the centering step)
- Only divides by RMS (root mean square)
- Fewer operations, similar or better performance
"""
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization.
Standard LayerNorm: (x - mean) / sqrt(var + eps) * gamma + beta
RMSNorm: x / sqrt(mean(x^2) + eps) * gamma
Key differences:
1. No mean subtraction (no centering)
2. No beta (bias) parameter
3. Divides by RMS instead of standard deviation
4. ~10-15% faster than LayerNorm
"""
def __init__(self, d_model, eps=1e-6):
super().__init__()
self.eps = eps
self.gamma = nn.Parameter(torch.ones(d_model))
def forward(self, x):
# Compute RMS
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
# Normalize and scale
return (x / rms) * self.gamma
# Compare
d_model = 768
x = torch.randn(2, 10, d_model)
layer_norm = nn.LayerNorm(d_model)
rms_norm = RMSNorm(d_model)
ln_out = layer_norm(x)
rms_out = rms_norm(x)
print(f"Input shape: {x.shape}")
print(f"LayerNorm output: mean={ln_out.mean():.6f}, std={ln_out.std():.4f}")
print(f"RMSNorm output: mean={rms_out.mean():.6f}, std={rms_out.std():.4f}")
# RMSNorm parameters (no bias):
ln_params = sum(p.numel() for p in layer_norm.parameters())
rms_params = sum(p.numel() for p in rms_norm.parameters())
print(f"\nLayerNorm parameters: {ln_params:,} (gamma + beta)")
print(f"RMSNorm parameters: {rms_params:,} (gamma only)")
7. Residual Connections
7.1 Skip Connections Explained
A residual (skip) connection adds the input of a layer directly to its output:
output = x + F(x)
Where F(x) is the transformation (attention or FFN) and x is the original input.
Instead of learning the full transformation, the layer only needs to learn the residual (the difference from the identity). This is much easier!
7.2 Why They Help with Gradient Flow
"""
Residual Connections and Gradient Flow
========================================
Without residual connections:
output = F_N(F_{N-1}(...F_2(F_1(x))...))
Gradient through N layers:
dL/dx = dL/dF_N * dF_N/dF_{N-1} * ... * dF_1/dx
If each dF_i/dF_{i-1} < 1, gradient VANISHES exponentially!
If each dF_i/dF_{i-1} > 1, gradient EXPLODES exponentially!
With residual connections:
output = x + F_N(x + F_{N-1}(x + ...))
Gradient includes a DIRECT path through the addition:
dL/dx = dL/dout * (1 + dF/dx)
The "1" ensures gradient ALWAYS flows, even if dF/dx is tiny.
This is like a highway for gradients through the entire network.
"""
import torch
import torch.nn as nn
# Demonstrate gradient flow
class LayerWithoutResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.tanh(self.linear(x))  # No residual

class LayerWithResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return x + torch.tanh(self.linear(x))  # With residual
# Stack 50 layers and check gradient magnitude
dim = 64
# Without residuals
x = torch.randn(1, dim, requires_grad=True)
out = x
layers_no_res = [LayerWithoutResidual(dim) for _ in range(50)]
for layer in layers_no_res:
    out = layer(out)
loss = out.sum()
loss.backward()
print(f"Without residual connections (50 layers):")
print(f" Gradient magnitude: {x.grad.norm():.6e}")
# Will be extremely small (vanishing gradient)
# With residuals
x2 = torch.randn(1, dim, requires_grad=True)
out2 = x2
layers_res = [LayerWithResidual(dim) for _ in range(50)]
for layer in layers_res:
    out2 = layer(out2)
loss2 = out2.sum()
loss2.backward()
print(f"\nWith residual connections (50 layers):")
print(f" Gradient magnitude: {x2.grad.norm():.6e}")
# Will be much healthier
8. Logits and Output
8.1 What Are Logits?
Logits are the raw, unnormalized scores output by the final layer of the model, before applying softmax. For a language model with vocabulary size V, the logits are a vector of V numbers, one per possible next token.
# For a model with vocab_size = 50,257 (GPT-2)
# Input: "The cat sat on the"
# Output logits: [2.1, -0.3, 0.8, ..., 1.5] (50,257 numbers)
# ^ "the" ^ "a" ^ "and" ^ "mat"
# Higher logit = model thinks this token is more likely next
# But logits aren't probabilities! They can be any real number.
# We need softmax to convert them to a probability distribution.
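A quick sketch of that conversion with a toy 5-token vocabulary (the logit values are made up for illustration):

```python
import numpy as np

# Toy logits for a 5-token vocabulary
logits = np.array([2.1, -0.3, 0.8, -1.2, 1.5])

# Softmax: subtract the max for numerical stability, exponentiate, normalize
exp_logits = np.exp(logits - logits.max())
probs = exp_logits / exp_logits.sum()

print(probs)          # all non-negative, sums to 1 (up to rounding)
print(probs.argmax()) # 0 -- the largest logit gets the largest probability
```

Note that softmax preserves the ranking of the logits; it only rescales them into a valid probability distribution.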
8.2 Temperature Parameter
P(token_i) = softmax(logit_i / T)
Where T is the temperature parameter (default: 1.0)
"""
Temperature Sampling
=====================
Temperature controls the "randomness" of generation.
"""
import numpy as np
def softmax_with_temperature(logits, temperature=1.0):
    """Apply temperature scaling before softmax."""
    scaled = logits / temperature
    exp_scaled = np.exp(scaled - np.max(scaled))
    return exp_scaled / exp_scaled.sum()
# Example logits for next token
logits = np.array([2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0])
tokens = ["the", "a", "my", "his", "some", "that", "one"]
print("Temperature Effects on Token Probabilities:")
print(f"{'Token':>8}", end="")
temperatures = [0.1, 0.5, 1.0, 1.5, 2.0]
for temp in temperatures:
    print(f" T={temp}", end="")
print()
for i, token in enumerate(tokens):
    print(f"{token:>8}", end="")
    for temp in temperatures:
        p = softmax_with_temperature(logits, temp)
        print(f" {p[i]:.3f}", end="")
    print()
# T=0.1: Almost deterministic (highest logit gets ~100%)
# T=0.5: More focused (top tokens dominate)
# T=1.0: Standard (balanced distribution)
# T=1.5: More random (flatter distribution)
# T=2.0: Very random (nearly uniform)
# T -> 0: Greedy decoding (always pick the most likely token)
# T -> inf: Random sampling (uniform distribution)
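The two limiting cases can be checked numerically. A standalone sketch (the `softmax_t` helper is ours, redefined here so the snippet runs on its own):

```python
import numpy as np

def softmax_t(logits, t):
    """Softmax with temperature t."""
    z = logits / t
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0])

for t in [0.05, 1.0, 20.0]:
    p = softmax_t(logits, t)
    print(f"T={t:>5}: max prob = {p.max():.3f}")
# As T -> 0, the max probability approaches 1 (greedy);
# as T -> inf, it flattens toward uniform (1/7, about 0.143).
```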
8.3 Top-k and Top-p (Nucleus) Sampling
"""
Top-k and Top-p Sampling
==========================
These are FILTERING strategies applied before sampling.
"""
import numpy as np
def top_k_sampling(logits, k=5, temperature=1.0):
    """
    Top-k sampling: only consider the k most likely tokens.

    1. Sort tokens by probability
    2. Keep only the top k
    3. Set all others to -infinity (zero probability)
    4. Re-normalize and sample

    Example: k=5 means only the 5 most likely tokens are candidates.
    """
    # Apply temperature
    scaled_logits = logits / temperature
    # Find top-k indices
    top_k_indices = np.argsort(scaled_logits)[-k:]
    # Create filtered logits (everything else = -inf)
    filtered = np.full_like(scaled_logits, -np.inf)
    filtered[top_k_indices] = scaled_logits[top_k_indices]
    # Softmax to get probabilities
    probs = np.exp(filtered - np.max(filtered))
    probs = probs / probs.sum()
    # Sample
    chosen = np.random.choice(len(logits), p=probs)
    return chosen, probs
def top_p_sampling(logits, p=0.9, temperature=1.0):
    """
    Top-p (nucleus) sampling: keep the smallest set of tokens whose
    cumulative probability exceeds p.

    1. Sort tokens by probability (descending)
    2. Compute cumulative sum
    3. Find cutoff where cumsum >= p
    4. Keep all tokens before the cutoff
    5. Re-normalize and sample

    Example: p=0.9 means keep tokens until their cumulative probability
    reaches 90%. This ADAPTS the number of candidates:
    - If model is confident: might only keep 2-3 tokens
    - If model is uncertain: might keep 50+ tokens
    """
    # Apply temperature and softmax
    scaled_logits = logits / temperature
    probs = np.exp(scaled_logits - np.max(scaled_logits))
    probs = probs / probs.sum()
    # Sort in descending order
    sorted_indices = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_indices]
    # Compute cumulative probabilities
    cumulative_probs = np.cumsum(sorted_probs)
    # Find cutoff index (first index where cumsum >= p)
    cutoff_idx = np.searchsorted(cumulative_probs, p) + 1
    # Keep only tokens within the nucleus
    nucleus_indices = sorted_indices[:cutoff_idx]
    nucleus_probs = probs[nucleus_indices]
    nucleus_probs = nucleus_probs / nucleus_probs.sum()  # Re-normalize
    # Sample
    chosen_idx = np.random.choice(len(nucleus_indices), p=nucleus_probs)
    chosen = nucleus_indices[chosen_idx]
    return chosen, cutoff_idx, nucleus_probs
# Demonstrate
np.random.seed(42)
# Scenario 1: Model is confident
confident_logits = np.array([5.0, 2.0, 1.0, 0.5, 0.1, -1.0, -2.0, -3.0])
tokens = ["Paris", "London", "Berlin", "Tokyo", "Rome", "Madrid", "Oslo", "Lisbon"]
print("SCENARIO 1: Model is confident")
print("Prompt: 'The capital of France is'")
probs = np.exp(confident_logits - np.max(confident_logits))
probs = probs / probs.sum()
for i, (token, prob) in enumerate(zip(tokens, probs)):
    print(f" {token:10s}: {prob:.4f} {'<--- dominant' if i == 0 else ''}")
_, n_nucleus, _ = top_p_sampling(confident_logits, p=0.9)
print(f"Top-p=0.9 keeps {n_nucleus} tokens (very focused)")
# Scenario 2: Model is uncertain
uncertain_logits = np.array([1.5, 1.4, 1.3, 1.2, 1.1, 1.0, 0.9, 0.8])
print("\nSCENARIO 2: Model is uncertain")
print("Prompt: 'I want to eat'")
probs = np.exp(uncertain_logits - np.max(uncertain_logits))
probs = probs / probs.sum()
food_tokens = ["pizza", "sushi", "pasta", "tacos", "salad", "steak", "curry", "soup"]
for token, prob in zip(food_tokens, probs):
    print(f" {token:10s}: {prob:.4f}")
_, n_nucleus, _ = top_p_sampling(uncertain_logits, p=0.9)
print(f"Top-p=0.9 keeps {n_nucleus} tokens (more diverse)")
# Top-p ADAPTS: fewer candidates when confident, more when uncertain
# This is why top-p is generally preferred over top-k in practice
"""
Complete Sampling Function
============================
Putting it all together: temperature + top-k + top-p
"""
import numpy as np
def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0):
    """
    Sample the next token from logits with various strategies.

    This is what happens at each step of LLM text generation.

    Args:
        logits: Raw model output (vocab_size,)
        temperature: Scaling factor (higher = more random)
        top_k: If > 0, only keep top-k tokens (0 = disabled)
        top_p: If < 1.0, nucleus sampling threshold (1.0 = disabled)

    Returns:
        selected_token: Index of the chosen token
    """
    # Step 1: Apply temperature
    if temperature == 0:
        # Greedy decoding: always pick the most likely
        return np.argmax(logits)
    scaled_logits = logits / temperature

    # Step 2: Apply top-k filtering
    if top_k > 0:
        top_k_indices = np.argsort(scaled_logits)[-top_k:]
        mask = np.full_like(scaled_logits, -np.inf)
        mask[top_k_indices] = scaled_logits[top_k_indices]
        scaled_logits = mask

    # Step 3: Convert to probabilities
    probs = np.exp(scaled_logits - np.max(scaled_logits))
    probs = probs / probs.sum()

    # Step 4: Apply top-p (nucleus) filtering
    if top_p < 1.0:
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumulative = np.cumsum(sorted_probs)
        # Find cutoff (first index where the cumulative sum reaches top_p)
        cutoff = np.searchsorted(cumulative, top_p) + 1
        # Zero out tokens beyond the nucleus
        removed_indices = sorted_indices[cutoff:]
        probs[removed_indices] = 0
        # Re-normalize
        probs = probs / probs.sum()

    # Step 5: Sample from the distribution
    selected = np.random.choice(len(probs), p=probs)
    return selected
# Demonstrate different sampling strategies
np.random.seed(42)
vocab_size = 10
logits = np.random.randn(vocab_size) * 2 # Random logits
token_names = [f"tok_{i}" for i in range(vocab_size)]
print("Sampling Strategy Comparison (1,000 samples each):")
print("=" * 60)
strategies = [
    {"name": "Greedy (T=0)", "temperature": 0, "top_k": 0, "top_p": 1.0},
    {"name": "Low temp (T=0.3)", "temperature": 0.3, "top_k": 0, "top_p": 1.0},
    {"name": "Standard (T=1.0)", "temperature": 1.0, "top_k": 0, "top_p": 1.0},
    {"name": "Top-k=3, T=1.0", "temperature": 1.0, "top_k": 3, "top_p": 1.0},
    {"name": "Top-p=0.9, T=1.0", "temperature": 1.0, "top_k": 0, "top_p": 0.9},
    {"name": "Top-k=5, Top-p=0.9", "temperature": 1.0, "top_k": 5, "top_p": 0.9},
]
for strategy in strategies:
    counts = np.zeros(vocab_size)
    for _ in range(1000):
        token = sample_next_token(
            logits,
            temperature=strategy["temperature"],
            top_k=strategy["top_k"],
            top_p=strategy["top_p"]
        )
        counts[token] += 1
    # Show distribution
    probs = counts / counts.sum()
    n_unique = (counts > 0).sum()
    top_token = token_names[np.argmax(counts)]
    print(f"\n{strategy['name']:30s} | Unique tokens: {n_unique} | "
          f"Top: {top_token} ({probs[np.argmax(counts)]:.1%})")
Week 3 Summary
Key Takeaways
- The Transformer has encoder-decoder structure, but decoder-only variants (GPT-style) dominate modern LLMs.
- Multi-head attention splits attention into parallel heads, each learning different relationship types. Output = Concat(heads) W_O.
- Q, K, V are learned linear projections of the input. Q asks "what am I looking for?", K says "what do I have?", V provides the content.
- Cross-attention lets the decoder attend to encoder output: Q from decoder, K/V from encoder.
- FFNs provide non-linear transformation after attention. Modern models use SwiGLU activation.
- Layer normalization stabilizes training. Pre-norm (norm before attention/FFN) is preferred. RMSNorm is the modern standard.
- Residual connections enable gradient flow through deep networks by providing shortcut paths.
- Temperature, top-k, top-p control the diversity of generated text.
Exercises
- Implement a complete Transformer encoder block (multi-head attention + FFN + LayerNorm + residual connections) in PyTorch.
- Experiment with different numbers of attention heads. How does model performance change?
- Implement RMSNorm and compare its speed to standard LayerNorm using PyTorch benchmarking.
- Write a function that generates text using temperature + top-p sampling. Compare outputs at different temperatures.
- Calculate the total parameters for a Transformer with d_model=512, n_heads=8, d_ff=2048, n_layers=6.