Week 3: Transformer Architecture Internals

A deep dive into every component of the Transformer — the architecture behind GPT, Claude, Llama, and every major LLM. Understand it inside and out.

Difficulty: Intermediate

1. The Transformer Architecture

1.1 "Attention Is All You Need"

Published in June 2017 by Vaswani et al. at Google, this paper introduced the Transformer architecture. It is arguably the most impactful machine learning paper ever written. Every major language model since — GPT, BERT, T5, Llama, Claude, Gemini — is based on the Transformer.

1.2 Full Architecture Overview

The original Transformer has an encoder-decoder structure, designed for sequence-to-sequence tasks like machine translation.


┌──────────────────────────────────────────────────────────────────┐
│                   THE TRANSFORMER ARCHITECTURE                   │
│                                                                  │
│   ENCODER (left side)              DECODER (right side)          │
│   ┌────────────────────┐ x N       ┌────────────────────┐ x N    │
│   │  ┌──────────────┐  │           │  ┌──────────────┐  │        │
│   │  │ Feed-Forward │  │           │  │ Feed-Forward │  │        │
│   │  │ Network (FFN)│  │           │  │ Network (FFN)│  │        │
│   │  └──────┬───────┘  │           │  └──────┬───────┘  │        │
│   │     Add & Norm     │           │     Add & Norm     │        │
│   │  ┌──────────────┐  │           │  ┌──────────────┐  │        │
│   │  │ Multi-Head   │  │ ────────► │  │ Cross-       │  │        │
│   │  │ Self-Attn    │  │  (K, V)   │  │ Attention    │  │        │
│   │  └──────┬───────┘  │           │  └──────┬───────┘  │        │
│   │     Add & Norm     │           │     Add & Norm     │        │
│   └────────────────────┘           │  ┌──────────────┐  │        │
│                                    │  │ Masked Multi-│  │        │
│                                    │  │ Head Self-   │  │        │
│                                    │  │ Attention    │  │        │
│                                    │  └──────┬───────┘  │        │
│                                    │     Add & Norm     │        │
│                                    └────────────────────┘        │
│                                                                  │
│   ┌────────────────────┐           ┌────────────────────┐        │
│   │ Input Embedding    │           │ Output Embedding   │        │
│   │ + Positional Enc.  │           │ + Positional Enc.  │        │
│   └────────────────────┘           └────────────────────┘        │
│                                                                  │
│        INPUT TOKENS                    OUTPUT TOKENS             │
│   "The cat sat on the mat"       "Le chat ..." (shifted right)   │
└──────────────────────────────────────────────────────────────────┘
| Architecture Variant | Structure | Examples | Best For |
| --- | --- | --- | --- |
| Encoder-Only | Only the left side; bidirectional attention | BERT, RoBERTa, DeBERTa | Classification, NER, sentence embeddings |
| Decoder-Only | Only the right side; causal (left-to-right) attention | GPT-1/2/3/4, Llama, Claude, DeepSeek | Text generation, chatbots, code generation |
| Encoder-Decoder | Both sides; cross-attention between them | T5, BART, mBART, Whisper | Translation, summarization, speech-to-text |
Transformer Architecture Overview
graph TD
    subgraph Encoder
        EI["Input Embedding + Positional Encoding"] --> ESA["Multi-Head Self-Attention"]
        ESA --> EAN1["Add and Norm"]
        EAN1 --> EFFN["Feed-Forward Network"]
        EFFN --> EAN2["Add and Norm"]
    end
    subgraph Decoder
        DI["Output Embedding + Positional Encoding"] --> DSA["Masked Multi-Head Self-Attention"]
        DSA --> DAN1["Add and Norm"]
        DAN1 --> DCA["Cross-Attention: Q from Decoder, K V from Encoder"]
        DCA --> DAN2["Add and Norm"]
        DAN2 --> DFFN["Feed-Forward Network"]
        DFFN --> DAN3["Add and Norm"]
    end
    EAN2 --> DCA
    DAN3 --> LIN["Linear + Softmax"]
    LIN --> OUT["Output Probabilities"]
    style EI fill:#d5e8d4,stroke:#333
    style DI fill:#dae8fc,stroke:#333
    style ESA fill:#fff2cc,stroke:#333
    style DSA fill:#fff2cc,stroke:#333
    style DCA fill:#e1d5e7,stroke:#333
    style OUT fill:#f8cecc,stroke:#333
Why Decoder-Only Won: As of 2025-2026, virtually all frontier LLMs use the decoder-only architecture. Why?
  • Simpler architecture (one stack instead of two)
  • Scales better with more parameters and data
  • Autoregressive generation is natural and flexible
  • Can handle any task framed as text generation
  • More efficient to train (single forward pass per token)

1.3 Why Transformers Replaced RNNs/LSTMs

| Aspect | RNNs/LSTMs | Transformers |
| --- | --- | --- |
| Parallelization | Sequential processing (token by token); cannot parallelize within a sequence | Fully parallel (all tokens processed simultaneously); massive GPU utilization |
| Long-range dependencies | Information degrades over long distances (vanishing gradients); LSTMs help but don't solve it | Direct attention between ANY two tokens, regardless of distance; O(1) path length |
| Training speed | Slow (sequential bottleneck) | Fast (parallel computation on GPUs) |
| Memory | O(1) per step (hidden state) | O(n²) for attention (quadratic in sequence length) |
| Scalability | Difficult to scale beyond ~1B parameters | Scales to hundreds of billions (proven by GPT-3/4, Llama, etc.) |
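The quadratic memory term is easy to see concretely. A small sketch (illustrative numbers only: one head, one layer, fp16 scores) of how many attention scores must be materialized for a given sequence length:

```python
# Illustrative only: each attention head holds an n x n score matrix.
# Assume fp16 storage (2 bytes per entry), one head, one layer.
for n in [1_024, 8_192, 131_072]:
    entries = n * n                   # attention matrix is n x n
    mib = entries * 2 / 2**20         # bytes -> MiB
    print(f"seq_len={n:>7,}: {entries:>17,} scores ≈ {mib:>8,.0f} MiB")
# At 131,072 tokens a single head's score matrix is 32,768 MiB (32 GiB)
# per layer -- which is why exact attention at long context is expensive.
```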

2. Multi-Head Attention

2.1 Why Multiple Heads?

A single attention head can only focus on one type of relationship at a time. But language has many simultaneous relationships: syntactic (subject-verb agreement), semantic (word meaning), coreference ("it" refers to "cat"), positional patterns, etc.

Multi-head attention runs several attention operations in parallel, each learning to focus on different aspects of the input. It's like having multiple "experts" each looking at the data from a different angle.

MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

Multi-Head Attention Mechanism
graph TD
    X["Input X: seq_len x d_model"] --> S1["Split into h Heads"]
    S1 --> H1["Head 1: Attention on d_k dims"]
    S1 --> H2["Head 2: Attention on d_k dims"]
    S1 --> H3["..."]
    S1 --> Hh["Head h: Attention on d_k dims"]
    H1 --> CAT["Concatenate All Heads back to d_model dims"]
    H2 --> CAT
    H3 --> CAT
    Hh --> CAT
    CAT --> WO["Output Projection W_O (Linear layer)"]
    WO --> OUT["Multi-Head Output"]
    style X fill:#e8f4f8,stroke:#333
    style H1 fill:#dae8fc,stroke:#333
    style H2 fill:#d5e8d4,stroke:#333
    style H3 fill:#fff2cc,stroke:#333
    style Hh fill:#e1d5e7,stroke:#333
    style CAT fill:#f8cecc,stroke:#333
    style OUT fill:#d5e8d4,stroke:#333

2.2 How Heads Are Split and Concatenated

"""
Multi-Head Attention - How it works internally
================================================

Key insight: Instead of doing attention with d_model dimensions,
we split into h heads, each operating on d_model/h dimensions.
Total computation is the same, but we get h different attention patterns.
"""

import numpy as np

# Example configuration (like BERT-base)
d_model = 768    # Model dimension
n_heads = 12     # Number of attention heads
d_k = d_model // n_heads  # = 64 per head
d_v = d_model // n_heads  # = 64 per head

print(f"d_model = {d_model}")
print(f"n_heads = {n_heads}")
print(f"d_k per head = {d_k}")
print(f"d_v per head = {d_v}")

# The process:
# 1. Project input X (seq_len, 768) to Q, K, V using large matrices
# 2. Split Q, K, V into 12 heads of dimension 64 each
# 3. Run attention independently on each head
# 4. Concatenate the 12 outputs back to dimension 768
# 5. Apply final linear projection W_O

# Step-by-step shapes:
seq_len = 10
batch_size = 1

print(f"\nShape trace:")
print(f"Input X:          ({batch_size}, {seq_len}, {d_model})")
print(f"After W_Q:        ({batch_size}, {seq_len}, {d_model})")
print(f"Reshape to heads: ({batch_size}, {n_heads}, {seq_len}, {d_k})")
print(f"Attn per head:    ({batch_size}, {n_heads}, {seq_len}, {d_v})")
print(f"Concat heads:     ({batch_size}, {seq_len}, {d_model})")
print(f"After W_O:        ({batch_size}, {seq_len}, {d_model})")

2.3 PRACTICAL: Multi-Head Attention in PyTorch

"""
Complete Multi-Head Attention Implementation in PyTorch
========================================================
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Self-Attention mechanism.

    This is the core component of every Transformer layer.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        """
        Args:
            d_model: Total model dimension (e.g., 768)
            n_heads: Number of attention heads (e.g., 12)
            dropout: Dropout rate for attention weights
        """
        super().__init__()

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # Dimension per head

        # Linear projections for Q, K, V
        # Each is (d_model, d_model) - we project all heads at once
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)

        # Output projection
        self.W_O = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        """
        Split the last dimension into (n_heads, d_k).

        Input:  (batch, seq_len, d_model)
        Output: (batch, n_heads, seq_len, d_k)
        """
        batch_size, seq_len, _ = x.shape
        # Reshape: (batch, seq_len, n_heads, d_k)
        x = x.view(batch_size, seq_len, self.n_heads, self.d_k)
        # Transpose: (batch, n_heads, seq_len, d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x):
        """
        Reverse of split_heads.

        Input:  (batch, n_heads, seq_len, d_k)
        Output: (batch, seq_len, d_model)
        """
        batch_size, _, seq_len, _ = x.shape
        # Transpose: (batch, seq_len, n_heads, d_k)
        x = x.transpose(1, 2)
        # Reshape: (batch, seq_len, d_model)
        return x.contiguous().view(batch_size, seq_len, self.d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """
        Compute scaled dot-product attention.

        Args:
            Q: (batch, n_heads, seq_len, d_k)
            K: (batch, n_heads, seq_len, d_k)
            V: (batch, n_heads, seq_len, d_k)
            mask: (1, 1, seq_len, seq_len) or None

        Returns:
            context: (batch, n_heads, seq_len, d_k)
            attention_weights: (batch, n_heads, seq_len, seq_len)
        """
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, heads, seq, seq)
        scores = scores / math.sqrt(self.d_k)

        # Apply mask (for causal attention or padding)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax to get attention weights
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        context = torch.matmul(attention_weights, V)

        return context, attention_weights

    def forward(self, x, mask=None):
        """
        Forward pass for multi-head self-attention.

        Args:
            x: Input tensor (batch, seq_len, d_model)
            mask: Optional attention mask

        Returns:
            output: (batch, seq_len, d_model)
            attention_weights: (batch, n_heads, seq_len, seq_len)
        """
        # Step 1: Linear projections
        Q = self.W_Q(x)  # (batch, seq_len, d_model)
        K = self.W_K(x)
        V = self.W_V(x)

        # Step 2: Split into multiple heads
        Q = self.split_heads(Q)  # (batch, n_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # Step 3: Scaled dot-product attention (parallel for all heads)
        context, attention_weights = self.scaled_dot_product_attention(
            Q, K, V, mask
        )

        # Step 4: Combine heads
        context = self.combine_heads(context)  # (batch, seq_len, d_model)

        # Step 5: Final linear projection
        output = self.W_O(context)  # (batch, seq_len, d_model)

        return output, attention_weights


# ------- Test the implementation -------
torch.manual_seed(42)

batch_size = 2
seq_len = 10
d_model = 64
n_heads = 8

x = torch.randn(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads, dropout=0.0)

output, weights = mha(x)

print(f"Input shape:       {x.shape}")         # (2, 10, 64)
print(f"Output shape:      {output.shape}")     # (2, 10, 64)
print(f"Weights shape:     {weights.shape}")    # (2, 8, 10, 10)

# Verify each head has different attention patterns
print(f"\nAttention weights for head 0 (first example, first 5 tokens):")
print(weights[0, 0, :5, :5].detach().numpy().round(3))

print(f"\nAttention weights for head 3 (same positions):")
print(weights[0, 3, :5, :5].detach().numpy().round(3))

# These should be DIFFERENT - each head learns different patterns!

# Parameter count
total_params = sum(p.numel() for p in mha.parameters())
print(f"\nTotal parameters: {total_params:,}")
print(f"  W_Q: {d_model * d_model:,}")
print(f"  W_K: {d_model * d_model:,}")
print(f"  W_V: {d_model * d_model:,}")
print(f"  W_O: {d_model * d_model:,}")
print(f"  Total: 4 * d_model^2 = {4 * d_model * d_model:,}")
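The `mask` argument above is what turns this into causal (decoder-style) attention. A minimal standalone sketch of building a lower-triangular causal mask and applying it to raw scores, mirroring the `masked_fill` logic inside the class:

```python
import torch
import torch.nn.functional as F

seq_len = 5
# Lower-triangular matrix: position i may attend only to positions <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len))   # (5, 5)

scores = torch.randn(seq_len, seq_len)                   # raw attention scores
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

print(weights)
# Everything above the diagonal is exactly 0: future tokens are invisible.
# Row 0 can only attend to token 0, so its entire weight (1.0) lands there.
```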

3. Q, K, V Matrices Deep Dive

3.1 How Q, K, V Are Computed

"""
Q, K, V Computation - Detailed Breakdown
==========================================
"""

import torch
import torch.nn as nn

# Configuration
d_model = 512   # Model dimension
n_heads = 8     # Number of heads
d_k = d_model // n_heads  # = 64 per head

# Input: sequence of token embeddings
# Shape: (batch_size, seq_len, d_model)
batch_size = 1
seq_len = 4
x = torch.randn(batch_size, seq_len, d_model)

# Weight matrices
# Shape: (d_model, d_model) = (512, 512) for each of Q, K, V
W_Q = nn.Linear(d_model, d_model, bias=False)
W_K = nn.Linear(d_model, d_model, bias=False)
W_V = nn.Linear(d_model, d_model, bias=False)

print("COMPUTING Q, K, V:")
print(f"  Input x shape: {x.shape}")
print(f"  W_Q weight shape: {W_Q.weight.shape}")

# Matrix multiplication: (1, 4, 512) @ (512, 512) = (1, 4, 512)
Q = W_Q(x)  # Each token's embedding is projected to a Query vector
K = W_K(x)  # Each token's embedding is projected to a Key vector
V = W_V(x)  # Each token's embedding is projected to a Value vector

print(f"\n  Q shape: {Q.shape}")  # (1, 4, 512)
print(f"  K shape: {K.shape}")
print(f"  V shape: {V.shape}")

# Now reshape for multi-head attention
# (1, 4, 512) -> (1, 4, 8, 64) -> (1, 8, 4, 64)
Q_heads = Q.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)
K_heads = K.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)
V_heads = V.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)

print(f"\n  Q_heads shape: {Q_heads.shape}")  # (1, 8, 4, 64)
print(f"  K_heads shape: {K_heads.shape}")
print(f"  V_heads shape: {V_heads.shape}")

# Each head operates on 64-dimensional slices independently
# Head 0 uses Q_heads[:, 0, :, :] which is (1, 4, 64)
# Head 1 uses Q_heads[:, 1, :, :] which is (1, 4, 64)
# ... etc.

print(f"\n  Head 0 Q slice shape: {Q_heads[:, 0, :, :].shape}")  # (1, 4, 64)
print(f"  Head 0 K slice shape: {K_heads[:, 0, :, :].shape}")
print(f"  Head 0 V slice shape: {V_heads[:, 0, :, :].shape}")
Q, K, V Matrix Computation Flow
graph LR
    X["Input X: seq_len x d_model"] --> WQ["W_Q Weight Matrix: d_model x d_k"]
    X --> WK["W_K Weight Matrix: d_model x d_k"]
    X --> WV["W_V Weight Matrix: d_model x d_v"]
    WQ --> Q["Q = X times W_Q: What am I looking for"]
    WK --> K["K = X times W_K: What do I contain"]
    WV --> V["V = X times W_V: What info do I carry"]
    Q --> ATT["Scaled Dot-Product Attention"]
    K --> ATT
    V --> ATT
    ATT --> O["Attention Output"]
    style X fill:#e8f4f8,stroke:#333
    style Q fill:#dae8fc,stroke:#333
    style K fill:#d5e8d4,stroke:#333
    style V fill:#fff2cc,stroke:#333
    style ATT fill:#e1d5e7,stroke:#333
    style O fill:#f8cecc,stroke:#333

3.2 Intuition for What Q, K, V Represent

Think of it as a database query analogy:

  • Query (Q): "I am looking for context about [this concept]." Each token generates a query vector representing what information it needs.
  • Key (K): "I contain information about [this concept]." Each token generates a key vector advertising what information it can provide.
  • Value (V): "Here is the actual information I carry." The content that gets passed along when attention is paid.

The dot product Q · K measures relevance (how well the query matches the key). The result weights how much of each V to include in the output.

Crucially: Q, K, and V are all derived from the same input (in self-attention) but through DIFFERENT learned projections. This means the model can learn that the same word plays different roles as a query vs. a key vs. a value.
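The lookup analogy can be made concrete with a toy example (all numbers invented for illustration, not learned weights): one query scored against three keys, with the softmax converting relevance into mixing weights over the values:

```python
import numpy as np

# One query, three keys/values, d_k = 4 (toy numbers for illustration)
q = np.array([1.0, 0.0, 1.0, 0.0])           # "what am I looking for"
K = np.array([[1.0, 0.0, 1.0, 0.0],          # key 0: matches q exactly
              [0.0, 1.0, 0.0, 1.0],          # key 1: orthogonal to q
              [0.5, 0.0, 0.5, 0.0]])         # key 2: partial match
V = np.array([[10.0, 0.0],                   # value carried by token 0
              [0.0, 10.0],                   # value carried by token 1
              [5.0, 5.0]])                   # value carried by token 2

scores = K @ q / np.sqrt(len(q))             # scaled dot products: [1.0, 0.0, 0.5]
weights = np.exp(scores) / np.exp(scores).sum()  # softmax over keys
output = weights @ V                         # weighted mix of the values

print("scores: ", scores.round(3))
print("weights:", weights.round(3))          # key 0 gets the largest weight
print("output: ", output.round(3))           # pulled toward V[0] = [10, 0]
```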

3.3 Worked Example with Tensor Shapes

"""
Complete Shape Trace Through Multi-Head Attention
===================================================
Model: d_model=768, n_heads=12, d_k=64
Sequence: 5 tokens
"""

import torch

# Configuration (BERT-base / GPT-2 scale)
batch = 2
seq = 5
d_model = 768
n_heads = 12
d_k = d_model // n_heads  # 64

print("=" * 60)
print("COMPLETE SHAPE TRACE")
print("=" * 60)

# Input embeddings (after token embedding + positional encoding)
x = torch.randn(batch, seq, d_model)
print(f"\n1. Input X: {x.shape}")
print(f"   Meaning: {batch} examples, {seq} tokens each, {d_model}-dim embeddings")

# Linear projections
W_Q = torch.randn(d_model, d_model)
Q = x @ W_Q  # (2, 5, 768) @ (768, 768) = (2, 5, 768)
print(f"\n2. Q = X @ W_Q: {Q.shape}")
print(f"   Similarly K, V: ({batch}, {seq}, {d_model})")

# Split into heads
Q_heads = Q.view(batch, seq, n_heads, d_k).transpose(1, 2)
print(f"\n3. Split into {n_heads} heads:")
print(f"   Q_heads: {Q_heads.shape}")
print(f"   Meaning: {batch} examples, {n_heads} heads, {seq} tokens, {d_k}-dim per head")

# Attention scores for each head
K_heads = torch.randn(batch, n_heads, seq, d_k)  # placeholder
scores = Q_heads @ K_heads.transpose(-2, -1)
print(f"\n4. Attention scores = Q_heads @ K_heads^T:")
print(f"   scores: {scores.shape}")
print(f"   Meaning: for each head, a {seq}x{seq} attention matrix")

# After softmax
weights = torch.softmax(scores / (d_k ** 0.5), dim=-1)
print(f"\n5. Attention weights (after softmax): {weights.shape}")

# Weighted sum of values
V_heads = torch.randn(batch, n_heads, seq, d_k)
context = weights @ V_heads
print(f"\n6. Context = weights @ V_heads: {context.shape}")
print(f"   Meaning: weighted combination of values for each head")

# Combine heads
combined = context.transpose(1, 2).contiguous().view(batch, seq, d_model)
print(f"\n7. Combine heads: {combined.shape}")
print(f"   Back to original dimensions!")

# Output projection
W_O = torch.randn(d_model, d_model)
output = combined @ W_O
print(f"\n8. Output = combined @ W_O: {output.shape}")
print(f"   Same shape as input! This allows residual connections.")

4. Cross-Attention

4.1 How the Decoder Attends to Encoder Output

In encoder-decoder models (T5, BART, translation models), the decoder needs information from the encoder. Cross-attention is how this happens:

  • Queries (Q) come from the decoder (what the decoder is looking for)
  • Keys (K) and Values (V) come from the encoder output (what the source has)

"""
Cross-Attention Implementation
================================
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CrossAttention(nn.Module):
    """
    Cross-attention: decoder attends to encoder output.

    Difference from self-attention:
    - In self-attention: Q, K, V all come from the same source
    - In cross-attention: Q comes from decoder, K and V from encoder
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)  # Projects decoder states
        self.W_K = nn.Linear(d_model, d_model, bias=False)  # Projects encoder output
        self.W_V = nn.Linear(d_model, d_model, bias=False)  # Projects encoder output
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, decoder_state, encoder_output):
        """
        Args:
            decoder_state:  (batch, decoder_seq_len, d_model) - from previous decoder layer
            encoder_output: (batch, encoder_seq_len, d_model) - final encoder output

        Returns:
            output: (batch, decoder_seq_len, d_model)
        """
        batch_size = decoder_state.size(0)
        dec_len = decoder_state.size(1)
        enc_len = encoder_output.size(1)

        # Q from decoder, K and V from encoder
        Q = self.W_Q(decoder_state)     # (batch, dec_len, d_model)
        K = self.W_K(encoder_output)    # (batch, enc_len, d_model)
        V = self.W_V(encoder_output)    # (batch, enc_len, d_model)

        # Split heads
        Q = Q.view(batch_size, dec_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, enc_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, enc_len, self.n_heads, self.d_k).transpose(1, 2)

        # Attention: (batch, heads, dec_len, d_k) @ (batch, heads, d_k, enc_len)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        # Shape: (batch, heads, dec_len, enc_len)
        # This tells us: for each decoder position, how much to attend to
        # each encoder position

        weights = F.softmax(scores, dim=-1)
        context = weights @ V  # (batch, heads, dec_len, d_k)

        # Combine heads
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, dec_len, -1)

        return self.W_O(context)


# Example: French to English translation
d_model = 256
n_heads = 8

cross_attn = CrossAttention(d_model, n_heads)

# Encoder processes: "Le chat est sur le tapis" (6 tokens)
encoder_output = torch.randn(1, 6, d_model)

# Decoder is generating: "The cat is" (3 tokens so far)
decoder_state = torch.randn(1, 3, d_model)

output = cross_attn(decoder_state, encoder_output)
print(f"Encoder output shape: {encoder_output.shape}")  # (1, 6, 256)
print(f"Decoder state shape:  {decoder_state.shape}")   # (1, 3, 256)
print(f"Cross-attention output: {output.shape}")         # (1, 3, 256)

# The decoder token "is" can now attend to all French source tokens
# to figure out what English word comes next

5. Feed-Forward Neural Networks in Transformers

5.1 Position-wise FFN

Original Transformer FFN:

FFN(x) = max(0, x W_1 + b_1) W_2 + b_2

Or equivalently: FFN(x) = ReLU(x W_1 + b_1) W_2 + b_2

The FFN is applied to each position independently (same weights for all positions).

Why FFNs are needed: Attention is great at mixing information between positions, but its output is just a weighted average of value vectors (linear in V once the attention weights are fixed). The FFN supplies the per-position non-linear transformation that lets the model compute complex functions of the attention output.
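"Position-wise" means the same two-layer MLP is applied to every position independently, so running it on a whole sequence is identical to running it on each position separately. A quick sketch checking that (using a plain `nn.Sequential` stand-in for the FFN):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff, seq_len = 8, 32, 5
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(1, seq_len, d_model)     # (batch, seq_len, d_model)
full = ffn(x)                            # whole sequence in one call

# Apply to one position at a time, then stack the results back together
per_pos = torch.stack([ffn(x[:, t]) for t in range(seq_len)], dim=1)

print(torch.allclose(full, per_pos))     # True: no mixing across positions
```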

5.2 Hidden Dimension Expansion

"""
Feed-Forward Network in Transformers
======================================
The FFN typically EXPANDS the dimension by 4x, then projects back.
"""

import torch
import torch.nn as nn

class FeedForwardNetwork(nn.Module):
    """
    Position-wise Feed-Forward Network.

    FFN(x) = activation(x @ W1 + b1) @ W2 + b2

    The hidden dimension is typically 4x the model dimension.
    This expand-then-contract design gives the network a wider
    intermediate space in which to compute non-linear feature
    combinations before projecting back down to d_model.
    """

    def __init__(self, d_model, d_ff=None, dropout=0.1, activation='relu'):
        """
        Args:
            d_model: Input/output dimension (e.g., 768)
            d_ff: Hidden dimension (default: 4 * d_model = 3072)
            dropout: Dropout rate
            activation: 'relu' or 'gelu' or 'swiglu'
        """
        super().__init__()

        self.d_ff = d_ff or 4 * d_model

        self.W1 = nn.Linear(d_model, self.d_ff)
        self.W2 = nn.Linear(self.d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        else:
            self.activation = nn.ReLU()

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            output: (batch, seq_len, d_model)
        """
        # Expand: (batch, seq, 768) -> (batch, seq, 3072)
        hidden = self.W1(x)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)

        # Contract: (batch, seq, 3072) -> (batch, seq, 768)
        output = self.W2(hidden)
        return output


# Example
d_model = 768
ffn = FeedForwardNetwork(d_model, activation='gelu')

x = torch.randn(2, 10, d_model)
output = ffn(x)

print(f"Input:  {x.shape}")       # (2, 10, 768)
print(f"Output: {output.shape}")   # (2, 10, 768)

# Parameter count
params = sum(p.numel() for p in ffn.parameters())
print(f"\nFFN parameters: {params:,}")
print(f"  W1: {d_model} x {4*d_model} + bias = {d_model * 4*d_model + 4*d_model:,}")
print(f"  W2: {4*d_model} x {d_model} + bias = {4*d_model * d_model + d_model:,}")
# FFN has MORE parameters than the attention layer!
# W1: 768 * 3072 = 2,359,296
# W2: 3072 * 768 = 2,359,296
# Total: ~4.7M (vs ~2.4M for attention)

5.3 SwiGLU Activation (Modern Models)

"""
SwiGLU Activation - Used in Llama, Mistral, and most modern LLMs
=================================================================
SwiGLU combines the Swish activation with a Gated Linear Unit.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """
    FFN with SwiGLU activation (used in Llama 2/3, Mistral, etc.)

    SwiGLU(x, W, V, b, c) = (Swish(xW + b)) * (xV + c)

    Key differences from standard FFN:
    1. Uses Swish (SiLU) instead of ReLU/GELU
    2. Has a "gate" mechanism (element-wise multiplication)
    3. Uses 3 weight matrices instead of 2
    4. Hidden dim is typically (2/3) * 4 * d_model to keep param count similar
    """

    def __init__(self, d_model, d_ff=None, dropout=0.1):
        super().__init__()

        # Llama convention: hidden_dim = (2/3) * 4 * d_model
        # rounded to nearest multiple of 256 for efficiency
        self.d_ff = d_ff or int(2 * (4 * d_model) / 3)
        # Round to multiple of 256
        self.d_ff = 256 * ((self.d_ff + 255) // 256)

        # Three projections instead of two
        self.w1 = nn.Linear(d_model, self.d_ff, bias=False)  # Gate projection
        self.w2 = nn.Linear(self.d_ff, d_model, bias=False)  # Down projection
        self.w3 = nn.Linear(d_model, self.d_ff, bias=False)  # Up projection

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        SwiGLU(x) = (SiLU(xW1)) * (xW3) then down-project with W2

        SiLU(x) = x * sigmoid(x)  (also called Swish)
        """
        # Gate: apply SiLU to first projection
        gate = F.silu(self.w1(x))  # (batch, seq, d_ff)

        # Up projection (no activation)
        up = self.w3(x)            # (batch, seq, d_ff)

        # Element-wise multiply (gating mechanism)
        hidden = gate * up         # (batch, seq, d_ff)
        hidden = self.dropout(hidden)

        # Down projection
        output = self.w2(hidden)   # (batch, seq, d_model)

        return output


# Compare parameter counts
d_model = 4096  # Llama 3.1 8B

standard_ffn = FeedForwardNetwork(d_model, d_ff=4*d_model)
swiglu_ffn = SwiGLUFeedForward(d_model)

standard_params = sum(p.numel() for p in standard_ffn.parameters())
swiglu_params = sum(p.numel() for p in swiglu_ffn.parameters())

print(f"Standard FFN (d_ff=4*d_model={4*d_model}):")
print(f"  Parameters: {standard_params:,}")

print(f"\nSwiGLU FFN (d_ff={swiglu_ffn.d_ff}):")
print(f"  Parameters: {swiglu_params:,}")

print(f"\nRatio: {swiglu_params/standard_params:.2f}x")
# With the (2/3) * 4 * d_model hidden size, the totals are roughly equal:
# standard FFN ≈ 2 * (d_model * 4*d_model) = 8*d_model^2
# SwiGLU      ≈ 3 * (d_model * (8/3)*d_model) ≈ 8*d_model^2
# The third matrix is paid for by the smaller hidden dimension.

6. Layer Normalization

6.1 Why Normalization?

During training, the distribution of each layer's inputs shifts as the parameters of earlier layers are updated (a phenomenon originally dubbed "internal covariate shift"). Normalization stabilizes training by keeping every layer's inputs at consistent statistics (mean and variance).

6.2 LayerNorm vs BatchNorm

"""
LayerNorm vs BatchNorm
========================
"""

import torch
import torch.nn as nn
import numpy as np

# Example tensor: (batch=3, seq_len=4, d_model=6)
x = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
     [2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
     [0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
     [3.0, 4.0, 5.0, 6.0, 7.0, 8.0]],

    [[1.5, 2.5, 3.5, 4.5, 5.5, 6.5],
     [0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
     [2.5, 3.5, 4.5, 5.5, 6.5, 7.5],
     [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]],

    [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
     [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
     [2.0, 2.0, 2.0, 2.0, 2.0, 2.0],
     [3.0, 3.0, 3.0, 3.0, 3.0, 3.0]],
])

print(f"Input shape: {x.shape}")  # (3, 4, 6)

# ---- BatchNorm ----
# Normalizes across the BATCH dimension (for each feature independently)
# Computes mean and variance across batch dimension
# Used in CNNs, NOT suitable for transformers (variable sequence lengths)

# ---- LayerNorm ----
# Normalizes across the FEATURE dimension (for each sample independently)
# Computes mean and variance across the last dimension (d_model)
# Used in transformers

layer_norm = nn.LayerNorm(6)  # Normalize across last dimension (d_model=6)

# For each position in each example:
# 1. Compute mean across 6 dimensions
# 2. Compute variance across 6 dimensions
# 3. Normalize: (x - mean) / sqrt(variance + epsilon)
# 4. Scale and shift: gamma * normalized + beta (learned parameters)

output = layer_norm(x)
print(f"\nLayerNorm output shape: {output.shape}")

# Check: each position should have mean≈0 and std≈1
print(f"\nFirst position, first example:")
print(f"  Before: mean={x[0,0].mean():.2f}, std={x[0,0].std():.2f}")
print(f"  After:  mean={output[0,0].mean():.4f}, std={output[0,0].std():.4f}")
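The BatchNorm contrast described in the comments above can be shown on a tensor of the same shape. `nn.BatchNorm1d` expects (batch, channels, length), so the sequence and feature axes must be swapped first; note that it normalizes each feature across the whole batch, not each token across its features:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(3, 4, 6)  # (batch, seq_len, d_model)

batch_norm = nn.BatchNorm1d(6)  # one mean/variance per feature channel
# BatchNorm1d wants (batch, channels, length) -> transpose in and out
bn_out = batch_norm(x.transpose(1, 2)).transpose(1, 2)  # back to (3, 4, 6)

# Each FEATURE is zero-mean across all batch*seq positions...
print(bn_out[:, :, 0].mean())   # ~0
# ...but a single token's features are generally NOT zero-mean,
# which is exactly what LayerNorm would give you instead:
print(bn_out[0, 0].mean())
```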

6.3 Pre-Norm vs Post-Norm

"""
Pre-Norm vs Post-Norm
======================
The placement of LayerNorm relative to attention/FFN matters a lot!
"""

import torch
import torch.nn as nn

class PostNormTransformerBlock(nn.Module):
    """
    ORIGINAL Transformer (2017): Post-Norm
    x -> Attention -> Add(x) -> LayerNorm -> FFN -> Add -> LayerNorm

    Issues:
    - Harder to train at large scale
    - Requires careful learning rate warmup
    - Can be unstable with deep models
    """
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Attention with residual, THEN normalize
        attn_out, _ = self.attn(x, mask)
        x = self.norm1(x + attn_out)

        # FFN with residual, THEN normalize
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)

        return x


class PreNormTransformerBlock(nn.Module):
    """
    MODERN approach (GPT-2+, Llama, etc.): Pre-Norm
    x -> LayerNorm -> Attention -> Add(x) -> LayerNorm -> FFN -> Add(x)

    Benefits:
    - More stable training (gradients flow through residual path unchanged)
    - Works with larger learning rates
    - Easier to scale to very deep models
    - Used by GPT-2, GPT-3, Llama, most modern LLMs
    """
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForwardNetwork(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Normalize FIRST, then attention, then residual
        attn_out, _ = self.attn(self.norm1(x), mask)
        x = x + attn_out  # Residual connection (gradient highway)

        # Normalize FIRST, then FFN, then residual
        ffn_out = self.ffn(self.norm2(x))
        x = x + ffn_out

        return x

# Why Pre-Norm is better for deep models:
# In Post-Norm, gradients must flow THROUGH the LayerNorm
# In Pre-Norm, there's a "clean" residual path where gradients
# flow directly from output to input without any transformation.
# This is critical for training 100+ layer models.

6.4 RMSNorm (Used in Llama)

"""
RMSNorm (Root Mean Square Layer Normalization)
================================================
Used in Llama 2, Llama 3, Mistral, and most modern open LLMs.

Simpler and faster than standard LayerNorm:
- No mean subtraction (skip the centering step)
- Only divides by RMS (root mean square)
- Fewer operations, similar or better performance
"""

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Standard LayerNorm: (x - mean) / sqrt(var + eps) * gamma + beta
    RMSNorm:           x / sqrt(mean(x^2) + eps) * gamma

    Key differences:
    1. No mean subtraction (no centering)
    2. No beta (bias) parameter
    3. Divides by RMS instead of standard deviation
    4. ~10-15% faster than LayerNorm
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # Compute RMS
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)

        # Normalize and scale
        return (x / rms) * self.gamma

# Compare
d_model = 768
x = torch.randn(2, 10, d_model)

layer_norm = nn.LayerNorm(d_model)
rms_norm = RMSNorm(d_model)

ln_out = layer_norm(x)
rms_out = rms_norm(x)

print(f"Input shape: {x.shape}")
print(f"LayerNorm output: mean={ln_out.mean():.6f}, std={ln_out.std():.4f}")
print(f"RMSNorm output:   mean={rms_out.mean():.6f}, std={rms_out.std():.4f}")

# RMSNorm parameters (no bias):
ln_params = sum(p.numel() for p in layer_norm.parameters())
rms_params = sum(p.numel() for p in rms_norm.parameters())
print(f"\nLayerNorm parameters: {ln_params:,} (gamma + beta)")
print(f"RMSNorm parameters:  {rms_params:,} (gamma only)")
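An aside that makes the relationship between the two norms precise (not from the original text): for inputs whose per-feature mean is already zero, the population variance equals mean(x^2), so LayerNorm (with gamma=1, beta=0) and RMSNorm produce identical outputs. A quick numpy check:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, eps = 16, 1e-6

x = rng.normal(size=(4, d_model))
x = x - x.mean(axis=-1, keepdims=True)  # center each row so mean(x) = 0

# LayerNorm (gamma=1, beta=0): (x - mean) / sqrt(var + eps)
ln = (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

# RMSNorm (gamma=1): x / sqrt(mean(x^2) + eps)
rms = x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

# With zero-mean rows, var(x) == mean(x^2), so the normalizations coincide
print(np.allclose(ln, rms))  # True
```

This is one intuition for why dropping the centering step costs so little: after residual streams and earlier norms, activations are often close to zero-mean anyway.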

7. Residual Connections

7.1 Skip Connections Explained

A residual (skip) connection adds the input of a layer directly to its output:

output = x + F(x)

Where F(x) is the transformation (attention or FFN) and x is the original input.

Instead of learning the full transformation, the layer only needs to learn the residual (the difference from the identity), which is a much easier optimization target.
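One way to see the identity intuition directly: if F's parameters start at zero, the residual block is exactly the identity map, and training only has to learn a correction on top of it. A minimal PyTorch sketch (the Linear sublayer here is illustrative):

```python
import torch
import torch.nn as nn

block = nn.Linear(8, 8)
nn.init.zeros_(block.weight)  # F(x) = 0 when weights and bias are zero
nn.init.zeros_(block.bias)

x = torch.randn(2, 8)
out = x + block(x)  # output = x + F(x)

print(torch.equal(out, x))  # True: the block starts out as the identity
```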

7.2 Why They Help with Gradient Flow

"""
Residual Connections and Gradient Flow
========================================

Without residual connections:
  output = F_N(F_{N-1}(...F_2(F_1(x))...))

  Gradient through N layers:
  dL/dx = dL/dF_N * dF_N/dF_{N-1} * ... * dF_1/dx

  If each dF_i/dF_{i-1} < 1, gradient VANISHES exponentially!
  If each dF_i/dF_{i-1} > 1, gradient EXPLODES exponentially!


With residual connections:
  output = x + F_N(x + F_{N-1}(x + ...))

  Gradient includes a DIRECT path through the addition:
  dL/dx = dL/d_out * (1 + dF/dx)

  The "1" ensures gradient ALWAYS flows, even if dF/dx is tiny.
  This is like a highway for gradients through the entire network.
"""

import torch
import torch.nn as nn

# Demonstrate gradient flow
class LayerWithoutResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.tanh(self.linear(x))  # No residual

class LayerWithResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return x + torch.tanh(self.linear(x))  # With residual

# Stack 50 layers and check gradient magnitude
dim = 64

# Without residuals
x = torch.randn(1, dim, requires_grad=True)
out = x
layers_no_res = [LayerWithoutResidual(dim) for _ in range(50)]
for layer in layers_no_res:
    out = layer(out)
loss = out.sum()
loss.backward()
print(f"Without residual connections (50 layers):")
print(f"  Gradient magnitude: {x.grad.norm():.6e}")
# Will be extremely small (vanishing gradient)

# With residuals
x2 = torch.randn(1, dim, requires_grad=True)
out2 = x2
layers_res = [LayerWithResidual(dim) for _ in range(50)]
for layer in layers_res:
    out2 = layer(out2)
loss2 = out2.sum()
loss2.backward()
print(f"\nWith residual connections (50 layers):")
print(f"  Gradient magnitude: {x2.grad.norm():.6e}")
# Will be much healthier

8. Logits and Output

8.1 What Are Logits?

Logits are the raw, unnormalized scores output by the final layer of the model, before applying softmax. For a language model with vocabulary size V, the logits are a vector of V numbers, one per possible next token.

# For a model with vocab_size = 50,257 (GPT-2)
# Input: "The cat sat on the"
# Output logits: [2.1, -0.3, 0.8, ..., 1.5]  (50,257 numbers)
#                ^ "the"  ^ "a"   ^ "and"      ^ "mat"

# Higher logit = model thinks this token is more likely next
# But logits aren't probabilities! They can be any real number.
# We need softmax to convert them to a probability distribution.
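The logits-to-probabilities step can be sketched with a tiny made-up vocabulary (the token names and logit values below are illustrative, not real GPT-2 outputs):

```python
import numpy as np

logits = np.array([2.1, -0.3, 0.8, 1.5])  # raw scores: any real numbers
tokens = ["the", "a", "and", "mat"]       # illustrative vocabulary slice

# Softmax: subtract the max for numerical stability, exponentiate, normalize
probs = np.exp(logits - logits.max())
probs = probs / probs.sum()

for tok, p in zip(tokens, probs):
    print(f"{tok:>4}: {p:.3f}")

print(probs.sum())  # probabilities sum to 1.0
```

The ordering of the logits is preserved (the largest logit gets the largest probability), but the softmax turns arbitrary scores into a valid distribution to sample from.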

8.2 Temperature Parameter

P(token_i) = exp(logit_i / T) / Σ_j exp(logit_j / T)

Where T is the temperature parameter (default: 1.0)

"""
Temperature Sampling
=====================
Temperature controls the "randomness" of generation.
"""

import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Apply temperature scaling before softmax."""
    scaled = logits / temperature
    exp_scaled = np.exp(scaled - np.max(scaled))
    return exp_scaled / exp_scaled.sum()

# Example logits for next token
logits = np.array([2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0])
tokens = ["the", "a", "my", "his", "some", "that", "one"]

print("Temperature Effects on Token Probabilities:")
print(f"{'Token':>8}", end="")
for temp in [0.1, 0.5, 1.0, 1.5, 2.0]:
    print(f"  T={temp}", end="")
print()

for i, token in enumerate(tokens):
    print(f"{token:>8}", end="")
    for temp in [0.1, 0.5, 1.0, 1.5, 2.0]:
        p = softmax_with_temperature(logits, temp)
        print(f" {p[i]:.3f}", end="")
    print()

# T=0.1: Almost deterministic (highest logit gets ~100%)
# T=0.5: More focused (top tokens dominate)
# T=1.0: Standard (balanced distribution)
# T=1.5: More random (flatter distribution)
# T=2.0: Very random (nearly uniform)

# T -> 0: Greedy decoding (always pick the most likely token)
# T -> inf: Random sampling (uniform distribution)
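Both limits can be checked numerically (softmax_with_temperature is restated here so the snippet runs on its own):

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    scaled = logits / temperature
    exp_scaled = np.exp(scaled - np.max(scaled))
    return exp_scaled / exp_scaled.sum()

logits = np.array([2.0, 1.5, 1.0, 0.5, 0.0, -0.5, -1.0])

cold = softmax_with_temperature(logits, temperature=0.01)  # T -> 0
hot = softmax_with_temperature(logits, temperature=100.0)  # T -> inf

print(f"T=0.01: top probability = {cold.max():.4f}")       # approaches 1.0 (greedy)
print(f"T=100:  max - min prob  = {hot.max() - hot.min():.4f}")  # approaches 0 (uniform)
```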

8.3 Top-k and Top-p (Nucleus) Sampling

"""
Top-k and Top-p Sampling
==========================
These are FILTERING strategies applied before sampling.
"""

import numpy as np

def top_k_sampling(logits, k=5, temperature=1.0):
    """
    Top-k sampling: only consider the k most likely tokens.

    1. Sort tokens by probability
    2. Keep only the top k
    3. Set all others to -infinity (zero probability)
    4. Re-normalize and sample

    Example: k=5 means only the 5 most likely tokens are candidates.
    """
    # Apply temperature
    scaled_logits = logits / temperature

    # Find top-k indices
    top_k_indices = np.argsort(scaled_logits)[-k:]

    # Create filtered logits (everything else = -inf)
    filtered = np.full_like(scaled_logits, -np.inf)
    filtered[top_k_indices] = scaled_logits[top_k_indices]

    # Softmax to get probabilities
    probs = np.exp(filtered - np.max(filtered))
    probs = probs / probs.sum()

    # Sample
    chosen = np.random.choice(len(logits), p=probs)
    return chosen, probs

def top_p_sampling(logits, p=0.9, temperature=1.0):
    """
    Top-p (nucleus) sampling: keep the smallest set of tokens whose
    cumulative probability exceeds p.

    1. Sort tokens by probability (descending)
    2. Compute cumulative sum
    3. Find cutoff where cumsum >= p
    4. Keep all tokens before the cutoff
    5. Re-normalize and sample

    Example: p=0.9 means keep tokens until their cumulative probability
    reaches 90%. This ADAPTS the number of candidates:
    - If model is confident: might only keep 2-3 tokens
    - If model is uncertain: might keep 50+ tokens
    """
    # Apply temperature and softmax
    scaled_logits = logits / temperature
    probs = np.exp(scaled_logits - np.max(scaled_logits))
    probs = probs / probs.sum()

    # Sort in descending order
    sorted_indices = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_indices]

    # Compute cumulative probabilities
    cumulative_probs = np.cumsum(sorted_probs)

    # Find cutoff index (first index where cumsum >= p)
    cutoff_idx = np.searchsorted(cumulative_probs, p) + 1

    # Keep only tokens within the nucleus
    nucleus_indices = sorted_indices[:cutoff_idx]
    nucleus_probs = probs[nucleus_indices]
    nucleus_probs = nucleus_probs / nucleus_probs.sum()  # Re-normalize

    # Sample
    chosen_idx = np.random.choice(len(nucleus_indices), p=nucleus_probs)
    chosen = nucleus_indices[chosen_idx]

    return chosen, cutoff_idx, nucleus_probs

# Demonstrate
np.random.seed(42)

# Scenario 1: Model is confident
confident_logits = np.array([5.0, 2.0, 1.0, 0.5, 0.1, -1.0, -2.0, -3.0])
tokens = ["Paris", "London", "Berlin", "Tokyo", "Rome", "Madrid", "Oslo", "Lisbon"]

print("SCENARIO 1: Model is confident")
print("Prompt: 'The capital of France is'")
probs = np.exp(confident_logits - np.max(confident_logits))
probs = probs / probs.sum()
for i, (token, prob) in enumerate(zip(tokens, probs)):
    print(f"  {token:10s}: {prob:.4f} {'<--- dominant' if i == 0 else ''}")

_, n_nucleus, _ = top_p_sampling(confident_logits, p=0.9)
print(f"Top-p=0.9 keeps {n_nucleus} tokens (very focused)")

# Scenario 2: Model is uncertain
uncertain_logits = np.array([1.5, 1.4, 1.3, 1.2, 1.1, 1.0, 0.9, 0.8])

print("\nSCENARIO 2: Model is uncertain")
print("Prompt: 'I want to eat'")
probs = np.exp(uncertain_logits - np.max(uncertain_logits))
probs = probs / probs.sum()
food_tokens = ["pizza", "sushi", "pasta", "tacos", "salad", "steak", "curry", "soup"]
for token, prob in zip(food_tokens, probs):
    print(f"  {token:10s}: {prob:.4f}")

_, n_nucleus, _ = top_p_sampling(uncertain_logits, p=0.9)
print(f"Top-p=0.9 keeps {n_nucleus} tokens (more diverse)")

# Top-p ADAPTS: fewer candidates when confident, more when uncertain
# This is why top-p is generally preferred over top-k in practice

8.4 Complete Sampling Function

"""
Complete Sampling Function
============================
Putting it all together: temperature + top-k + top-p
"""

import numpy as np

def sample_next_token(logits, temperature=1.0, top_k=0, top_p=1.0):
    """
    Sample the next token from logits with various strategies.

    This is what happens at each step of LLM text generation.

    Args:
        logits: Raw model output (vocab_size,)
        temperature: Scaling factor (higher = more random)
        top_k: If > 0, only keep top-k tokens (0 = disabled)
        top_p: If < 1.0, nucleus sampling threshold (1.0 = disabled)

    Returns:
        selected_token: Index of the chosen token
    """
    # Step 1: Apply temperature
    if temperature == 0:
        # Greedy decoding: always pick the most likely
        return np.argmax(logits)

    scaled_logits = logits / temperature

    # Step 2: Apply top-k filtering
    if top_k > 0:
        top_k_indices = np.argsort(scaled_logits)[-top_k:]
        mask = np.full_like(scaled_logits, -np.inf)
        mask[top_k_indices] = scaled_logits[top_k_indices]
        scaled_logits = mask

    # Step 3: Convert to probabilities
    probs = np.exp(scaled_logits - np.max(scaled_logits))
    probs = probs / probs.sum()

    # Step 4: Apply top-p (nucleus) filtering
    if top_p < 1.0:
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumulative = np.cumsum(sorted_probs)

        # Find cutoff
        cutoff = np.searchsorted(cumulative, top_p) + 1

        # Zero out tokens beyond the nucleus
        removed_indices = sorted_indices[cutoff:]
        probs[removed_indices] = 0

        # Re-normalize
        probs = probs / probs.sum()

    # Step 5: Sample from the distribution
    selected = np.random.choice(len(probs), p=probs)

    return selected


# Demonstrate different sampling strategies
np.random.seed(42)
vocab_size = 10
logits = np.random.randn(vocab_size) * 2  # Random logits
token_names = [f"tok_{i}" for i in range(vocab_size)]

print("Sampling Strategy Comparison (1000 samples each):")
print("=" * 60)

strategies = [
    {"name": "Greedy (T=0)",           "temperature": 0,   "top_k": 0, "top_p": 1.0},
    {"name": "Low temp (T=0.3)",       "temperature": 0.3, "top_k": 0, "top_p": 1.0},
    {"name": "Standard (T=1.0)",       "temperature": 1.0, "top_k": 0, "top_p": 1.0},
    {"name": "Top-k=3, T=1.0",        "temperature": 1.0, "top_k": 3, "top_p": 1.0},
    {"name": "Top-p=0.9, T=1.0",      "temperature": 1.0, "top_k": 0, "top_p": 0.9},
    {"name": "Top-k=5, Top-p=0.9",    "temperature": 1.0, "top_k": 5, "top_p": 0.9},
]

for strategy in strategies:
    counts = np.zeros(vocab_size)
    for _ in range(1000):
        token = sample_next_token(
            logits,
            temperature=strategy["temperature"],
            top_k=strategy["top_k"],
            top_p=strategy["top_p"]
        )
        counts[token] += 1

    # Show distribution
    probs = counts / counts.sum()
    n_unique = (counts > 0).sum()
    top_token = token_names[np.argmax(counts)]

    print(f"\n{strategy['name']:30s} | Unique tokens: {n_unique} | "
          f"Top: {top_token} ({probs[np.argmax(counts)]:.1%})")

Week 3 Summary

Key Takeaways

  1. The Transformer has encoder-decoder structure, but decoder-only variants (GPT-style) dominate modern LLMs.
  2. Multi-head attention splits attention into parallel heads, each learning different relationship types. Output = Concat(heads) W_O.
  3. Q, K, V are learned linear projections of the input. Q asks "what am I looking for?", K says "what do I have?", V provides the content.
  4. Cross-attention lets the decoder attend to encoder output: Q from decoder, K/V from encoder.
  5. FFNs provide non-linear transformation after attention. Modern models use SwiGLU activation.
  6. Layer normalization stabilizes training. Pre-norm (norm before attention/FFN) is preferred. RMSNorm is the modern standard.
  7. Residual connections enable gradient flow through deep networks by providing shortcut paths.
  8. Temperature, top-k, top-p control the diversity of generated text.

Exercises

  1. Implement a complete Transformer encoder block (multi-head attention + FFN + LayerNorm + residual connections) in PyTorch.
  2. Experiment with different numbers of attention heads. How does model performance change?
  3. Implement RMSNorm and compare its speed to standard LayerNorm using PyTorch benchmarking.
  4. Write a function that generates text using temperature + top-p sampling. Compare outputs at different temperatures.
  5. Calculate the total parameters for a Transformer with d_model=512, n_heads=8, d_ff=2048, n_layers=6.

Further Reading