Week 4: Causal Attention + Coding a Transformer from Scratch

Build a complete GPT-style language model from the ground up in PyTorch. Understand autoregressive generation, causal masking, and every detail of training.

Difficulty: Intermediate

1. Causal (Masked) Attention

1.1 Why Causal Masking?

In autoregressive language models (GPT, Llama, Claude), the model generates text one token at a time, left to right. When predicting the next token, it should only see tokens that came before it — never tokens from the future. This is enforced through causal masking.

Without causal masking, the model would "cheat" during training: the answer it is supposed to predict would be sitting in its input. It could drive the training loss toward zero by simply copying the next token. That shortcut is useless at generation time, because future tokens do not exist yet.

1.2 The Causal Mask Matrix

"""
Causal Mask - The Lower Triangular Matrix
============================================
"""

import torch

seq_len = 6
tokens = ["The", "cat", "sat", "on", "the", "mat"]

# Create causal mask (lower triangular matrix)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

print("Causal Mask (1 = can attend, . = cannot attend):")
print(f"\n{'':>6}", end="")
for t in tokens:
    print(f"{t:>6}", end="")  # every column is 6 characters wide
print()

for i, t in enumerate(tokens):
    print(f"{t:>6}", end="")
    for j in range(seq_len):
        val = int(causal_mask[i, j].item())
        marker = "1" if val == 1 else "."
        print(f"{marker:>6}", end="")  # same 6-character columns as the header
    print()

# Output:
#          The   cat   sat    on   the   mat
#    The     1     .     .     .     .     .
#    cat     1     1     .     .     .     .
#    sat     1     1     1     .     .     .
#     on     1     1     1     1     .     .
#    the     1     1     1     1     1     .
#    mat     1     1     1     1     1     1

# Reading row by row:
# "The" can only attend to itself
# "cat" can attend to "The" and itself
# "sat" can attend to "The", "cat", and itself
# ...
# "mat" can attend to ALL previous tokens (full context)

# This is a LOWER TRIANGULAR matrix - the key structural element
# of all autoregressive (decoder-only) transformers.
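Many implementations store the same mask in additive form: 0.0 where attention is allowed and -inf where it is not, added to the scores before softmax (this is the form PyTorch's `nn.Transformer.generate_square_subsequent_mask` produces). The short sketch below builds that form with `torch.triu`, checks that it marks exactly the same positions as the 0/1 matrix above, and counts the visible (query, key) pairs, which is T(T+1)/2 rather than T².

```python
import torch

seq_len = 6

# 0/1 form: 1 on and below the diagonal (allowed), 0 above it (future).
keep = torch.tril(torch.ones(seq_len, seq_len))

# Additive form: 0.0 where allowed, -inf strictly above the diagonal.
# torch.triu sets everything below `diagonal` to 0, so only future slots keep -inf.
additive = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# Both encodings mark exactly the same positions as visible.
assert torch.equal(additive == 0, keep.bool())

# Row i has i + 1 visible keys, so the total is T(T+1)/2 instead of T^2.
visible_per_row = keep.sum(dim=-1).long().tolist()
total_visible = int(keep.sum().item())
print(visible_per_row)               # [1, 2, 3, 4, 5, 6]
print(total_visible, seq_len ** 2)   # 21 36
```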

1.3 How Masking is Applied

"""
How Causal Masking Works in Practice
======================================
Masking is applied BEFORE softmax by setting future positions to -infinity.
"""

import torch
import torch.nn.functional as F
import math

def causal_self_attention(Q, K, V):
    """
    Self-attention with causal mask.

    Args:
        Q, K, V: (batch, n_heads, seq_len, d_k)

    Returns:
        output: (batch, n_heads, seq_len, d_k)
        weights: (batch, n_heads, seq_len, seq_len)
    """
    d_k = Q.size(-1)
    seq_len = Q.size(-2)

    # Step 1: Compute raw attention scores
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # Shape: (batch, n_heads, seq_len, seq_len)

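    # Step 2: Create causal mask
    # Strictly upper triangular boolean matrix: True marks positions to hide
    # (created on Q's device so masked_fill also works on GPU)
    mask = torch.triu(torch.ones(seq_len, seq_len, device=Q.device), diagonal=1).bool()
    # mask[i][j] = True if j > i (future positions)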

    # Step 3: Apply mask - set future positions to -infinity
    # This ensures softmax gives them ZERO probability
    scores = scores.masked_fill(mask, float('-inf'))

    # Step 4: Softmax (future positions become 0, past positions share 100%)
    weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = weights @ V

    return output, weights

# Demonstrate
torch.manual_seed(42)
batch, n_heads, seq_len, d_k = 1, 1, 5, 4

Q = torch.randn(batch, n_heads, seq_len, d_k)
K = torch.randn(batch, n_heads, seq_len, d_k)
V = torch.randn(batch, n_heads, seq_len, d_k)

output, weights = causal_self_attention(Q, K, V)

print("Causal Attention Weights (batch 0, head 0):")
print(weights[0, 0].numpy().round(3))
print()

# Expected pattern:
# [[1.000, 0.000, 0.000, 0.000, 0.000],  <- token 0 only attends to self
#  [0.XXX, 0.XXX, 0.000, 0.000, 0.000],  <- token 1 attends to 0,1
#  [0.XXX, 0.XXX, 0.XXX, 0.000, 0.000],  <- token 2 attends to 0,1,2
#  [0.XXX, 0.XXX, 0.XXX, 0.XXX, 0.000],  <- token 3 attends to 0,1,2,3
#  [0.XXX, 0.XXX, 0.XXX, 0.XXX, 0.XXX]]  <- token 4 attends to all

# Notice: upper triangle is all zeros!
# Each row sums to exactly 1.0 (valid probability distribution)
print("Row sums:", weights[0, 0].sum(dim=-1).numpy().round(3))
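PyTorch 2.0 and later also ship a fused kernel, `torch.nn.functional.scaled_dot_product_attention`, whose `is_causal=True` flag applies the same lower-triangular mask. A quick check, assuming PyTorch ≥ 2.0, that hand-written masked attention matches it. The helper `manual_causal_attention` is a name used only for this sketch; it repeats the steps of `causal_self_attention` above.

```python
import math
import torch
import torch.nn.functional as F

def manual_causal_attention(Q, K, V):
    """Masked attention written out by hand (same steps as causal_self_attention)."""
    T = Q.size(-2)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
    future = torch.triu(torch.ones(T, T, dtype=torch.bool, device=Q.device), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 3, 5, 8) for _ in range(3))  # (batch, heads, seq_len, d_k)

manual = manual_causal_attention(Q, K, V)
# Default scale is 1/sqrt(d_k), matching the manual version.
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))  # should print True
```

In a real model the fused call is usually preferable: it can dispatch to memory-efficient kernels, while the manual version materializes the full (T, T) score matrix.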

1.4 Encoder (Bidirectional) vs Decoder (Causal)

| Aspect | Encoder (Bidirectional) | Decoder (Causal/Unidirectional) |
|---|---|---|
| Attention mask | None (or only a padding mask); every token attends to all tokens | Lower triangular; token i can only attend to tokens 0..i |
| Examples | BERT, RoBERTa, DeBERTa | GPT-2/3/4, Llama, Claude, DeepSeek |
| Training task | Masked Language Modeling (predict [MASK] tokens) | Next Token Prediction (predict each next token) |
| Generation | Cannot generate text naturally | Natural autoregressive generation |
| Best for | Understanding, classification, embeddings | Text generation, chatbots, code generation |
[Diagram: Causal Masked Attention Pattern. Token 1 ("The") attends to The; token 2 ("cat") to The, cat; token 3 ("sat") to The, cat, sat; token 4 ("on") to The, cat, sat, on. The mask fills future positions with negative infinity, softmax turns them into zero probability, so each token only sees past and present tokens.]
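The practical difference in the table can be checked directly: change only the last token and see which outputs move. The sketch below uses a single attention step with no learned weights (Q = K = V = x, an illustrative simplification; `attend` is a hypothetical helper). In the bidirectional version every position's output changes; in the causal version positions 0..T-2 are untouched.

```python
import math
import torch
import torch.nn.functional as F

def attend(x, causal):
    """Single-head self-attention with Q = K = V = x (no learned weights)."""
    T, d = x.shape
    scores = x @ x.T / math.sqrt(d)
    if causal:
        future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    return F.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(5, 8)            # 5 token vectors of width 8
x_changed = x.clone()
x_changed[-1] = torch.randn(8)   # replace only the LAST token

results = {}
for causal in (False, True):
    before, after = attend(x, causal), attend(x_changed, causal)
    results[causal] = torch.allclose(before[:-1], after[:-1])
    print(f"causal={causal}: outputs at positions 0-3 unchanged? {results[causal]}")
# Bidirectional (encoder-style): False, because every token saw the changed last token.
# Causal (decoder-style): True, because no earlier token can see position 4.
```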

2. Autoregressive Models

2.1 Next Token Prediction Explained

"""
Next Token Prediction - The Core of GPT/Llama/Claude
======================================================

The model is trained to predict the next token at EVERY position
in the sequence simultaneously. This is incredibly efficient!
"""

# Training example:
# Input text: "The cat sat on the mat"
# After tokenization: [464, 3797, 3332, 319, 262, 2603]

# The model processes ALL tokens at once and produces predictions
# for each position:

# Position 0: Input "The"                -> Predict "cat" (ID 3797)
# Position 1: Input "The cat"            -> Predict "sat" (ID 3332)
# Position 2: Input "The cat sat"        -> Predict "on"  (ID 319)
# Position 3: Input "The cat sat on"     -> Predict "the" (ID 262)
# Position 4: Input "The cat sat on the" -> Predict "mat" (ID 2603)

# The causal mask ensures:
# - When predicting position 1, the model can only see position 0
# - When predicting position 2, the model can see positions 0-1
# - etc.

# In one forward pass, we get 5 training signals from a 6-token sequence!
# This is why language model training is so efficient.
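In code, "predict the next token at every position" is just a one-position shift. A minimal sketch using the token IDs from the example above, with random logits standing in for a real model's output (the vocabulary size 50,257 is GPT-2's and is used here only to size the fake tensor):

```python
import torch
import torch.nn.functional as F

tokens = torch.tensor([464, 3797, 3332, 319, 262, 2603])  # "The cat sat on the mat"
inputs = tokens[:-1]   # positions 0..4: what the model sees
targets = tokens[1:]   # the token that follows each position

for pos, (inp, tgt) in enumerate(zip(inputs.tolist(), targets.tolist())):
    print(f"position {pos}: last input id {inp:>5} -> target id {tgt}")

# A single forward pass yields logits for EVERY position at once.
vocab_size = 50257                       # only used to shape the stand-in logits
torch.manual_seed(0)
logits = torch.randn(len(inputs), vocab_size)   # stand-in for model(inputs)

# Cross-entropy averages over all 5 positions: 5 training signals from 1 pass.
loss = F.cross_entropy(logits, targets)
print(len(targets), loss.item())
```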

2.2 Teacher Forcing

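"""
Teacher Forcing
================
During TRAINING, we feed the model the CORRECT previous tokens
(not its own predictions). This is called "teacher forcing."

Why? If the model's own (often wrong) early predictions were fed
back during training, later positions would be conditioned on
garbage and learning would be slow and unstable. Teacher forcing
uses the ground truth to keep training stable.

The trade-off is "exposure bias": in training the model only ever
sees correct prefixes, but at inference it must condition on its
own outputs, so an early mistake can compound.
"""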

# Teacher forcing example:
# Target: "The cat sat on the mat"

# With teacher forcing (TRAINING):
# Step 1: Input = [BOS]          -> Model predicts "The"  (correct: "The")
# Step 2: Input = [BOS, "The"]   -> Model predicts "dog"  (correct: "cat")
# Step 3: Input = [BOS, "The", "cat"]  -> Use "cat" (NOT "dog"!)
#          The model gets the CORRECT token regardless of what it predicted.

# Without teacher forcing (INFERENCE / GENERATION):
# Step 1: Input = [BOS]          -> Model predicts "The"
# Step 2: Input = [BOS, "The"]   -> Model predicts "cat"
# Step 3: Input = [BOS, "The", "cat"] -> Model predicts "sat"
#          The model uses its OWN predictions as input.

# Key insight: thanks to the causal mask, teacher forcing works in
# ONE forward pass. The mask prevents the model from seeing future
# tokens, so all positions can be trained simultaneously.
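The claim that teacher forcing needs only one forward pass can be checked numerically. The sketch below builds a tiny, untrained causal model (`TinyCausalLM` is a hypothetical name; it relies on `F.scaled_dot_product_attention` with `is_causal=True`, available in PyTorch 2.0+) and compares the loss from one pass over the whole sequence with the loss from feeding each ground-truth prefix separately and scoring only its last position.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalLM(nn.Module):
    """One single-head causal attention layer plus an output head; untrained."""
    def __init__(self, vocab_size=20, d_model=16, max_len=16):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, idx):                               # idx: (B, T)
        T = idx.size(1)
        x = self.tok(idx) + self.pos(torch.arange(T))     # (B, T, d_model)
        q, k, v = (h.unsqueeze(1) for h in self.qkv(x).chunk(3, dim=-1))  # (B, 1, T, d)
        x = x + F.scaled_dot_product_attention(q, k, v, is_causal=True).squeeze(1)
        return self.head(x)                               # (B, T, vocab_size)

torch.manual_seed(0)
model = TinyCausalLM().eval()
seq = torch.tensor([[1, 5, 7, 2, 9, 4]])                  # ground-truth token IDs
inputs, targets = seq[:, :-1], seq[:, 1:]

with torch.no_grad():
    # Teacher forcing: ONE pass, every position predicts its next token.
    one_pass = F.cross_entropy(model(inputs).transpose(1, 2), targets)

    # The slow way: one pass per ground-truth prefix, scoring only its last position.
    step_losses = []
    for t in range(1, seq.size(1)):
        last_logits = model(seq[:, :t])[:, -1, :]         # predict token t from tokens 0..t-1
        step_losses.append(F.cross_entropy(last_logits, seq[:, t]))
    step_by_step = torch.stack(step_losses).mean()

print(torch.allclose(one_pass, step_by_step, atol=1e-5))  # should print True
```

The two numbers agree because, under the causal mask, the output at position t cannot depend on anything after t; the single pass simply computes all five prefix predictions in parallel.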
[Diagram: Autoregressive Generation Flow. Prompt ("The cat") -> Transformer forward pass -> logits over vocabulary -> sampling (greedy/top-k/top-p) -> next token ("sat") -> append to sequence, which loops back into the transformer and extends the output so far ("The cat sat ...").]

2.3 Decoding Strategies

"""
Decoding Strategies for Text Generation
==========================================
The functions below assume a `model` object whose forward(tokens) method
returns a NumPy array of logits with shape (len(tokens), vocab_size).
EOS_TOKEN is a placeholder; use your tokenizer's end-of-sequence ID.
"""

import numpy as np

EOS_TOKEN = 0  # placeholder end-of-sequence token ID

def softmax(x):
    """Numerically stable softmax for a 1-D array."""
    e = np.exp(x - np.max(x))
    return e / e.sum()

def log_softmax(x):
    """Numerically stable log-softmax for a 1-D array."""
    shifted = x - np.max(x)
    return shifted - np.log(np.exp(shifted).sum())

def greedy_decoding(model, prompt_tokens, max_tokens=50):
    """
    Greedy Decoding: always pick the most likely next token.

    Pros: Fast, deterministic, consistent
    Cons: Repetitive, boring, misses good alternatives
    """
    tokens = list(prompt_tokens)
    for _ in range(max_tokens):
        logits = model.forward(tokens)           # (len(tokens), vocab_size)
        next_token = int(np.argmax(logits[-1]))  # highest-scoring token at the last position
        tokens.append(next_token)
        if next_token == EOS_TOKEN:
            break
    return tokens

def beam_search(model, prompt_tokens, beam_width=5, max_tokens=50):
    """
    Beam Search: maintain the top beam_width candidates at each step.

    Keep beam_width "beams" (partial sequences), expand each one with
    its beam_width most likely next tokens, and keep the top beam_width
    sequences by total log-probability.

    Pros: Finds higher-probability sequences than greedy
    Cons: Still tends to be generic/repetitive, expensive
    """
    # Initialize beams: [(sequence, cumulative_log_prob)]
    beams = [(list(prompt_tokens), 0.0)]

    for _ in range(max_tokens):
        all_candidates = []

        for seq, score in beams:
            # Finished beams are carried over unchanged
            if seq[-1] == EOS_TOKEN:
                all_candidates.append((seq, score))
                continue

            logits = model.forward(seq)
            log_probs = log_softmax(logits[-1])

            # Expand this beam with its top beam_width next tokens
            top_k_indices = np.argsort(log_probs)[-beam_width:]

            for idx in top_k_indices:
                new_seq = seq + [int(idx)]
                new_score = score + float(log_probs[idx])
                all_candidates.append((new_seq, new_score))

        # Keep top beam_width candidates
        all_candidates.sort(key=lambda x: x[1], reverse=True)
        beams = all_candidates[:beam_width]

        # Stop early once every beam has emitted EOS
        if all(seq[-1] == EOS_TOKEN for seq, _ in beams):
            break

    # Return the best sequence
    return beams[0][0]

def sample_with_top_p(model, prompt_tokens, temperature=0.8, top_p=0.95,
                      max_tokens=50):
    """
    Sampling with Temperature + Top-p (nucleus sampling).

    A common default for chat assistants; exact settings vary by provider.

    Pros: Creative, diverse, natural-sounding
    Cons: Non-deterministic, can occasionally go off-track
    """
    tokens = list(prompt_tokens)
    for _ in range(max_tokens):
        logits = model.forward(tokens)
        next_logits = logits[-1] / temperature   # temperature must be > 0

        # Top-p filtering: keep the smallest set of tokens whose
        # cumulative probability reaches top_p
        probs = softmax(next_logits)
        sorted_indices = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_indices]
        cumsum = np.cumsum(sorted_probs)
        cutoff = np.searchsorted(cumsum, top_p) + 1

        # Zero out everything outside the nucleus and renormalize
        allowed = sorted_indices[:cutoff]
        filtered_probs = np.zeros_like(probs)
        filtered_probs[allowed] = probs[allowed]
        filtered_probs /= filtered_probs.sum()

        # Sample
        next_token = int(np.random.choice(len(filtered_probs), p=filtered_probs))
        tokens.append(next_token)
        if next_token == EOS_TOKEN:
            break

    return tokens

# Typical settings in practice (these vary by model, provider, and task):
# - temperature: 0.6-1.0 (lower for factual tasks, higher for creative)
# - top_p: 0.9-0.95
# - top_k: often not used when top_p is set
# - repetition_penalty: 1.0-1.2 (values above 1.0 penalize repeated tokens)
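The effect of `temperature` is easiest to see on one fixed set of logits: dividing by T < 1 sharpens the distribution toward greedy, while T > 1 flattens it toward uniform. Top-k, listed above but not implemented, keeps only the k largest logits. A small NumPy sketch; `softmax_with_temperature` and `top_k_filter` are illustrative helper names.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max()                      # subtract the max for numerical stability
    e = np.exp(z)                        # exp(-inf) = 0, so filtered tokens get 0
    return e / e.sum()

def top_k_filter(logits, k):
    """Set every logit outside the k largest to -inf (ties at the cutoff are kept)."""
    logits = np.asarray(logits, dtype=np.float64)
    kth_largest = np.sort(logits)[-k]
    return np.where(logits >= kth_largest, logits, -np.inf)

logits = np.array([2.0, 1.0, 0.5, -1.0])

for T in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: {np.round(probs, 3)}")   # the top token's share shrinks as T grows

filtered = softmax_with_temperature(top_k_filter(logits, k=2))
print(np.round(filtered, 3))  # only the top-2 tokens keep probability (about 0.731 and 0.269)
```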

3. Building a Complete GPT-Style Transformer from Scratch

[Diagram: Transformer Training Loop. Training data (text corpus) -> create batches (input_ids, target_ids) -> forward pass (embeddings to logits) -> cross-entropy loss (predictions vs targets) -> backward pass (compute gradients) -> gradient clipping (prevent exploding gradients) -> optimizer step (AdamW updates weights) -> learning-rate scheduler (warmup then decay) -> checkpoint? Either save model weights or continue with the next batch.]
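The scheduler step in the training loop (warmup, then decay) is implemented later as `get_lr`. To see the shape of that curve with the default config values (peak 3e-4, 100 warmup steps, 5,000 total steps, floor at 10% of peak), here is a standalone copy of the same formula that takes plain arguments instead of a config object; `lr_at` is a hypothetical name.

```python
import math

def lr_at(step, peak=3e-4, warmup=100, max_iters=5000, min_ratio=0.1):
    """Linear warmup to `peak`, then cosine decay to min_ratio * peak (same formula as get_lr)."""
    min_lr = peak * min_ratio
    if step < warmup:
        return peak * step / warmup
    decay_ratio = min((step - warmup) / (max_iters - warmup), 1.0)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))   # falls from 1 to 0
    return min_lr + coeff * (peak - min_lr)

for step in (0, 50, 100, 2550, 5000):
    print(f"step {step:>4}: lr = {lr_at(step):.2e}")
```

Note that step 0 yields a learning rate of exactly 0, so the very first optimizer update does nothing; some implementations use `(step + 1) / warmup` during warmup to avoid this.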

This is the main practical section. We will build a complete, working, trainable GPT-style language model in PyTorch. Every component is implemented from scratch with detailed explanations.

What we're building:
  • Token + positional embedding layer
  • Multi-head causal self-attention
  • Feed-forward network with GELU activation
  • Transformer blocks (attention + FFN + LayerNorm + residual)
  • N stacked transformer blocks
  • Output projection to vocabulary
  • Training loop with AdamW optimizer
  • Text generation function
We will train it on Shakespeare text and generate new text!
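Before reading the code, it can help to know where the parameters go. The function below (a hypothetical helper, not part of the model) counts them analytically for the architecture described next: learned token and position embeddings, bias-free QKV and output projections in attention, biased FFN layers, two LayerNorms per block plus a final one, and an output projection tied to the token embedding so it adds no parameters.

```python
def count_minigpt_params(vocab_size=256, max_seq_len=256, d_model=384,
                         n_layers=6, d_ff=384 * 4):
    embeddings = vocab_size * d_model + max_seq_len * d_model    # token + position tables
    layernorm = 2 * d_model                                      # weight + bias
    attention = d_model * 3 * d_model + d_model * d_model        # qkv_proj + out_proj, no biases
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)   # fc1 + fc2, with biases
    per_block = 2 * layernorm + attention + ffn
    final_ln = layernorm
    # The output projection shares its weight with the token embedding: +0 here.
    return embeddings + n_layers * per_block + final_ln

total = count_minigpt_params()
print(f"{total:,}")  # 10,834,944 with the default config
```

Each block holds about 1.77M parameters (two thirds of them in the FFN), while the two embedding tables together hold under 0.2M. Because `model.parameters()` returns a shared tensor only once, this count should match the number printed by `MiniGPT.__init__`.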
"""
=============================================================================
MINI-GPT: A Complete GPT-Style Transformer Language Model from Scratch
=============================================================================

This is a fully working implementation of a GPT-style decoder-only
transformer. You can train it on any text and generate new text.

Architecture:
  - Token Embeddings + Learned Positional Embeddings
  - N Transformer Blocks, each containing:
    - Pre-Norm (LayerNorm)
    - Multi-Head Causal Self-Attention
    - Residual Connection
    - Pre-Norm (LayerNorm)
    - Feed-Forward Network (with GELU)
    - Residual Connection
  - Final LayerNorm
  - Output Projection (linear layer to vocab_size)

This follows the GPT-2 architecture closely.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os

# ============================================================================
# CONFIGURATION
# ============================================================================

class GPTConfig:
    """
    Configuration for our mini-GPT model.

    Default values are small enough to train on a single GPU (a CPU also
    works, but slowly) yet large enough to learn interesting patterns.
    """
    # Model architecture
    vocab_size: int = 256          # Byte-level encoding: one token per UTF-8 byte (0-255)
    max_seq_len: int = 256         # Maximum sequence length (context window)
    d_model: int = 384             # Embedding dimension
    n_heads: int = 6               # Number of attention heads
    n_layers: int = 6              # Number of transformer blocks
    d_ff: int = 384 * 4            # FFN hidden dimension (4x d_model)
    dropout: float = 0.1           # Dropout rate

    # Training
    batch_size: int = 64           # Batch size
    learning_rate: float = 3e-4    # Peak learning rate
    max_iters: int = 5000          # Total training iterations
    warmup_iters: int = 100        # Learning rate warmup steps
    weight_decay: float = 0.1      # AdamW weight decay
    grad_clip: float = 1.0         # Gradient clipping norm

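    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if k not in type(self).__annotations__:
                raise TypeError(f"Unknown config option: {k}")
            setattr(self, k, v)
        # Ensure d_model is divisible by n_heads
        assert self.d_model % self.n_heads == 0, "d_model must be divisible by n_heads"

    def __repr__(self):
        # Defaults live on the class, not in self.__dict__, so list every annotated field
        params = ', '.join(f'{k}={getattr(self, k)}' for k in type(self).__annotations__)
        return f"GPTConfig({params})"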


config = GPTConfig()
print(config)

# ============================================================================
# COMPONENT 1: MULTI-HEAD CAUSAL SELF-ATTENTION
# ============================================================================

class CausalSelfAttention(nn.Module):
    """
    Multi-head causal (masked) self-attention.

    This is the core attention mechanism of GPT. Each token can only
    attend to itself and previous tokens (never future tokens).
    """

    def __init__(self, config):
        super().__init__()

        self.n_heads = config.n_heads
        self.d_model = config.d_model
        self.d_k = config.d_model // config.n_heads  # Dimension per head

        # Combined Q, K, V projection (more efficient than 3 separate)
        # Projects from d_model to 3 * d_model (Q, K, V concatenated)
        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model, bias=False)

        # Output projection
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        # Dropout
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Causal mask: register as a buffer (not a parameter)
        # This is a lower triangular matrix that ensures causal attention
        mask = torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
        self.register_buffer('mask', mask.view(
            1, 1, config.max_seq_len, config.max_seq_len
        ))

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)

        Returns:
            output: (batch_size, seq_len, d_model)
        """
        B, T, C = x.size()  # Batch, Time (seq_len), Channels (d_model)

        # Step 1: Compute Q, K, V in one matrix multiplication
        qkv = self.qkv_proj(x)  # (B, T, 3*d_model)

        # Split into Q, K, V
        Q, K, V = qkv.split(self.d_model, dim=2)  # Each: (B, T, d_model)

        # Step 2: Reshape for multi-head attention
        # (B, T, d_model) -> (B, T, n_heads, d_k) -> (B, n_heads, T, d_k)
        Q = Q.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Step 3: Compute attention scores
        # (B, n_heads, T, d_k) @ (B, n_heads, d_k, T) = (B, n_heads, T, T)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)

        # Step 4: Apply causal mask
        # Set future positions to -infinity so softmax gives them 0 probability
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))

        # Step 5: Softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Step 6: Weighted sum of values
        # (B, n_heads, T, T) @ (B, n_heads, T, d_k) = (B, n_heads, T, d_k)
        context = attn_weights @ V

        # Step 7: Combine heads
        # (B, n_heads, T, d_k) -> (B, T, n_heads, d_k) -> (B, T, d_model)
        context = context.transpose(1, 2).contiguous().view(B, T, C)

        # Step 8: Output projection
        output = self.resid_dropout(self.out_proj(context))

        return output


# ============================================================================
# COMPONENT 2: FEED-FORWARD NETWORK
# ============================================================================

class FeedForward(nn.Module):
    """
    Position-wise feed-forward network with GELU activation.

    FFN(x) = GELU(x @ W1 + b1) @ W2 + b2

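    GELU (Gaussian Error Linear Unit) is used in GPT-2, BERT, and many other
    models (Llama-style models use SwiGLU instead). It is a smooth alternative
    to ReLU: GELU(x) = x * Φ(x), where Φ is the standard normal CDF.
    GPT-2 uses the tanh approximation
    GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3))),
    available as nn.GELU(approximate='tanh'); nn.GELU() computes the exact form.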
    """

    def __init__(self, config):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff)
        self.fc2 = nn.Linear(config.d_ff, config.d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            output: (batch_size, seq_len, d_model)
        """
        x = self.fc1(x)          # (B, T, d_model) -> (B, T, d_ff)
        x = self.activation(x)   # GELU activation
        x = self.fc2(x)          # (B, T, d_ff) -> (B, T, d_model)
        x = self.dropout(x)
        return x


# ============================================================================
# COMPONENT 3: TRANSFORMER BLOCK
# ============================================================================

class TransformerBlock(nn.Module):
    """
    A single transformer block: LayerNorm -> Attention -> Residual ->
                                 LayerNorm -> FFN -> Residual

    This uses Pre-Norm (LayerNorm before attention/FFN), which is the
    modern standard used in GPT-2, GPT-3, Llama, etc.
    """

    def __init__(self, config):
        super().__init__()

        # Pre-norm LayerNorm for attention
        self.ln1 = nn.LayerNorm(config.d_model)

        # Multi-head causal self-attention
        self.attn = CausalSelfAttention(config)

        # Pre-norm LayerNorm for FFN
        self.ln2 = nn.LayerNorm(config.d_model)

        # Feed-forward network
        self.ffn = FeedForward(config)

    def forward(self, x):
        """
        Pre-Norm Transformer Block:
        x = x + Attention(LayerNorm(x))
        x = x + FFN(LayerNorm(x))

        The residual connections (x + ...) create "gradient highways"
        that allow gradients to flow directly through the network,
        enabling training of very deep models (100+ layers).
        """
        # Attention sub-layer with residual connection
        x = x + self.attn(self.ln1(x))

        # FFN sub-layer with residual connection
        x = x + self.ffn(self.ln2(x))

        return x


# ============================================================================
# COMPONENT 4: THE COMPLETE GPT MODEL
# ============================================================================

class MiniGPT(nn.Module):
    """
    Complete GPT-style language model.

    Architecture:
    1. Token embedding + Positional embedding
    2. Stack of N Transformer blocks
    3. Final LayerNorm
    4. Output projection to vocabulary (shared with token embedding weights)
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token embedding: maps token IDs to vectors
        # vocab_size -> d_model
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Positional embedding: learned (not sinusoidal)
        # max_seq_len -> d_model
        self.position_embedding = nn.Embedding(config.max_seq_len, config.d_model)

        # Dropout after embeddings
        self.embed_dropout = nn.Dropout(config.dropout)

        # Stack of transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        # Final layer normalization
        self.ln_final = nn.LayerNorm(config.d_model)

        # Output projection: d_model -> vocab_size
        # This converts the final hidden states to logits over the vocabulary
        self.output_proj = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: share weights between token embedding and output projection
        # This technique:
        # 1. Reduces parameters (vocab_size * d_model fewer)
        # 2. Often helps quality, since the same token vectors serve as input and output representations
        # GPT-2 does this; some larger modern models keep the two matrices separate
        self.output_proj.weight = self.token_embedding.weight

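        # Initialize weights
        self.apply(self._init_weights)

        # GPT-2 scaled init: projections that write into the residual stream
        # (attention out_proj and FFN fc2) get std = 0.02 / sqrt(2 * n_layers)
        for pn, p in self.named_parameters():
            if pn.endswith('out_proj.weight') or pn.endswith('fc2.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layers))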

        # Count parameters
        n_params = sum(p.numel() for p in self.parameters())
        print(f"MiniGPT initialized with {n_params:,} parameters")

    def _init_weights(self, module):
        """
        Initialize weights following GPT-2 conventions.

        - Linear layers: normal distribution with std=0.02
        - Embeddings: normal distribution with std=0.02
        - Biases: zero
        - Output projection of residual layers: scaled by 1/sqrt(2*n_layers)
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """
        Forward pass of the model.

        Args:
            idx: Token indices, shape (batch_size, seq_len)
            targets: Target token indices for computing loss, same shape as idx

        Returns:
            logits: (batch_size, seq_len, vocab_size)
            loss: Scalar loss value (if targets provided)
        """
        B, T = idx.shape
        device = idx.device

        assert T <= self.config.max_seq_len, \
            f"Sequence length {T} exceeds maximum {self.config.max_seq_len}"

        # Step 1: Token embeddings
        # (B, T) -> (B, T, d_model)
        tok_emb = self.token_embedding(idx)

        # Step 2: Positional embeddings
        # Create position indices: [0, 1, 2, ..., T-1]
        positions = torch.arange(0, T, dtype=torch.long, device=device)
        pos_emb = self.position_embedding(positions)  # (T, d_model)

        # Step 3: Add token and positional embeddings
        x = self.embed_dropout(tok_emb + pos_emb)  # (B, T, d_model)

        # Step 4: Pass through all transformer blocks
        for block in self.blocks:
            x = block(x)  # (B, T, d_model)

        # Step 5: Final layer normalization
        x = self.ln_final(x)  # (B, T, d_model)

        # Step 6: Project to vocabulary size
        logits = self.output_proj(x)  # (B, T, vocab_size)

        # Step 7: Compute loss (if targets provided)
        loss = None
        if targets is not None:
            # Cross-entropy loss for next token prediction
            # Reshape: (B*T, vocab_size) vs (B*T,)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, top_p=None):
        """
        Generate text autoregressively. Call model.eval() first so dropout is disabled.

        Args:
            idx: Starting token indices (B, T)
            max_new_tokens: How many tokens to generate
            temperature: Sampling temperature
            top_k: Top-k filtering (optional)
            top_p: Top-p nucleus filtering (optional)

        Returns:
            idx: Extended sequence (B, T + max_new_tokens)
        """
        for _ in range(max_new_tokens):
            # Crop to max_seq_len if necessary
            idx_cond = idx if idx.size(1) <= self.config.max_seq_len \
                       else idx[:, -self.config.max_seq_len:]

            # Forward pass (no targets = no loss computation)
            logits, _ = self(idx_cond)

            # Get logits for the last position only
            logits = logits[:, -1, :] / temperature  # (B, vocab_size)

            # Optional: top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Optional: top-p (nucleus) filtering
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                # Remove tokens with cumulative probability above top_p
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift so first token above threshold is kept
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0
                # Scatter back to original indexing
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Convert to probabilities
            probs = F.softmax(logits, dim=-1)

            # Sample
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # Append to sequence
            idx = torch.cat([idx, idx_next], dim=1)  # (B, T+1)

        return idx


# ============================================================================
# CREATE THE MODEL
# ============================================================================

config = GPTConfig(
    vocab_size=256,        # Byte-level (ASCII)
    max_seq_len=256,       # 256-character context window
    d_model=384,           # Embedding dimension
    n_heads=6,             # 6 attention heads (384/6 = 64 per head)
    n_layers=6,            # 6 transformer blocks
    d_ff=384 * 4,          # 1536 FFN hidden dim
    dropout=0.1,
)

model = MiniGPT(config)
# Expected output: "MiniGPT initialized with X,XXX,XXX parameters"
# ============================================================================
# DATA PREPARATION
# ============================================================================

"""
We'll use byte-level encoding for simplicity: text is encoded as UTF-8,
and each byte (0-255) is one token. ASCII characters are a single byte;
other characters span several bytes.

In production LLMs, you'd use BPE tokenization (covered in Week 2).
"""

# Download Shakespeare text (or use any text file)
import urllib.request

def download_data():
    """Download a small Shakespeare dataset."""
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    data_path = "shakespeare.txt"
    if not os.path.exists(data_path):
        print("Downloading Shakespeare dataset...")
        urllib.request.urlretrieve(url, data_path)
    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()
    return text

text = download_data()
print(f"Dataset size: {len(text):,} characters")
print(f"First 200 characters:\n{text[:200]}")

# Encode text as bytes
def encode(text):
    """Convert text to list of byte values."""
    return [b for b in text.encode('utf-8')]

def decode(tokens):
    """Convert list of byte values back to text."""
    return bytes(tokens).decode('utf-8', errors='replace')

data = torch.tensor(encode(text), dtype=torch.long)
print(f"\nEncoded tensor shape: {data.shape}")
print(f"Vocabulary size: {config.vocab_size} (byte-level)")

# Train/validation split (90% / 10%)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print(f"Train data: {len(train_data):,} tokens")
print(f"Val data: {len(val_data):,} tokens")


# ============================================================================
# DATA LOADER
# ============================================================================

def get_batch(split, config, train_data=train_data, val_data=val_data):
    """
    Generate a random batch of training examples.

    Each example is a sequence of max_seq_len tokens.
    The target is the same sequence shifted by 1 position.

    Input:  [T, h, e, _, c, a, t, _, s, a, ...]
    Target: [h, e, _, c, a, t, _, s, a, t, ...]

    For each position i, we predict target[i] from input[0:i+1].
    """
    data_source = train_data if split == 'train' else val_data

    # Random starting positions
    ix = torch.randint(len(data_source) - config.max_seq_len, (config.batch_size,))

    # Extract sequences
    x = torch.stack([data_source[i:i+config.max_seq_len] for i in ix])
    y = torch.stack([data_source[i+1:i+config.max_seq_len+1] for i in ix])

    return x, y  # Both shape: (batch_size, max_seq_len)


# Test the data loader
x_batch, y_batch = get_batch('train', config)
print(f"\nBatch shapes:")
print(f"  Input:  {x_batch.shape}")   # (64, 256)
print(f"  Target: {y_batch.shape}")   # (64, 256)
print(f"\nExample (first 50 chars):")
print(f"  Input:  '{decode(x_batch[0, :50].tolist())}'")
print(f"  Target: '{decode(y_batch[0, :50].tolist())}'")
print(f"  (Target is input shifted by 1 position)")
# ============================================================================
# TRAINING LOOP
# ============================================================================

"""
Training our mini-GPT model.

Key components:
1. AdamW optimizer (Adam with decoupled weight decay)
2. Learning rate scheduling (warmup + cosine decay)
3. Gradient clipping
4. Periodic evaluation on validation set
"""

def get_lr(step, config):
    """
    Learning rate schedule: linear warmup followed by cosine decay.

    This warmup + cosine shape is widely used for LLM pretraining (e.g. GPT-3, Llama).

    1. Warmup phase: linearly increase LR from 0 to peak
    2. Decay phase: cosine decay from peak to min_lr (10% of peak)
    """
    min_lr = config.learning_rate * 0.1

    # Warmup phase
    if step < config.warmup_iters:
        return config.learning_rate * (step / config.warmup_iters)

    # Cosine decay phase
    decay_ratio = (step - config.warmup_iters) / (config.max_iters - config.warmup_iters)
    decay_ratio = min(decay_ratio, 1.0)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))

    return min_lr + coeff * (config.learning_rate - min_lr)


@torch.no_grad()
def estimate_loss(model, config, eval_iters=50):
    """Estimate loss on train and validation sets."""
    model.eval()
    losses = {}
    for split in ['train', 'val']:
        total_loss = 0.0
        for _ in range(eval_iters):
            x, y = get_batch(split, config)
            _, loss = model(x, y)
            total_loss += loss.item()
        losses[split] = total_loss / eval_iters
    model.train()
    return losses


def train(model, config):
    """
    Main training loop.
    """
    # Set up optimizer
    # AdamW: Adam with decoupled weight decay regularization
    # We apply weight decay ONLY to weight matrices, not biases or LayerNorm
    param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad}

    # Separate parameters into those that need weight decay and those that don't
    decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
    no_decay_params = [p for n, p in param_dict.items() if p.dim() < 2]

    optimizer_groups = [
        {'params': decay_params, 'weight_decay': config.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    print(f"Decay params: {sum(p.numel() for p in decay_params):,}")
    print(f"No-decay params: {sum(p.numel() for p in no_decay_params):,}")

    optimizer = torch.optim.AdamW(
        optimizer_groups,
        lr=config.learning_rate,
        betas=(0.9, 0.95),  # Standard for LLM training
        eps=1e-8
    )

    # Training loop
    print(f"\nStarting training for {config.max_iters} iterations...")
    print("=" * 70)

    best_val_loss = float('inf')

    for step in range(config.max_iters):
        # Update learning rate
        lr = get_lr(step, config)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Get batch
        x, y = get_batch('train', config)

        # Forward pass
        logits, loss = model(x, y)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)  # Frees grad tensors instead of zero-filling them (the default since PyTorch 2.0)
        loss.backward()

        # Gradient clipping (prevents exploding gradients)
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

        # Update weights
        optimizer.step()

        # Logging
        if step % 100 == 0 or step == config.max_iters - 1:
            losses = estimate_loss(model, config)
            print(f"Step {step:5d}/{config.max_iters} | "
                  f"Train loss: {losses['train']:.4f} | "
                  f"Val loss: {losses['val']:.4f} | "
                  f"LR: {lr:.6f} | "
                  f"Grad norm: {grad_norm:.4f}")

            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']

        # Generate sample text periodically
        if step % 500 == 0 and step > 0:
            model.eval()
            prompt = encode("\n")
            prompt_tensor = torch.tensor([prompt], dtype=torch.long)
            generated = model.generate(
                prompt_tensor,
                max_new_tokens=200,
                temperature=0.8,
                top_k=40
            )
            print(f"\n--- Generated text (step {step}) ---")
            print(decode(generated[0].tolist()))
            print("--- End generated text ---\n")
            model.train()

    print(f"\nTraining complete! Best validation loss: {best_val_loss:.4f}")
    return model


# Run training
# model = train(model, config)

# Illustrative output. Exact losses depend on data, seed, and hardware, and
# the Grad norm column is omitted here. The LR column follows get_lr,
# assuming learning_rate=3e-4, warmup_iters=100, max_iters=5000. The last
# logged step is 4999 because range(config.max_iters) stops at max_iters - 1.
# Step     0/5000 | Train loss: 5.5432 | Val loss: 5.5398 | LR: 0.000000
# Step   100/5000 | Train loss: 3.2145 | Val loss: 3.2287 | LR: 0.000300
# Step   500/5000 | Train loss: 2.1234 | Val loss: 2.1456 | LR: 0.000296
# Step  1000/5000 | Train loss: 1.7823 | Val loss: 1.8234 | LR: 0.000278
# Step  2000/5000 | Train loss: 1.5234 | Val loss: 1.5987 | LR: 0.000212
# Step  3000/5000 | Train loss: 1.3456 | Val loss: 1.4567 | LR: 0.000127
# Step  4000/5000 | Train loss: 1.2345 | Val loss: 1.3789 | LR: 0.000057
# Step  4999/5000 | Train loss: 1.1789 | Val loss: 1.3456 | LR: 0.000030

# The model starts by generating gibberish and gradually learns (illustrative):
# Before training: "k3j#@nf9$kLm..."
# Step 500:  "the the the and and..."
# Step 1000: "KING: What is the..."
# Step 3000: "ROMEO: O, speak again, bright angel..."
# End of training: Reasonably coherent Shakespeare-like text!

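To check the LR column of a log like the one above, evaluate the schedule directly. The sketch below copies get_lr and uses a `SimpleNamespace` in place of the real config. The values learning_rate=3e-4, warmup_iters=100 and max_iters=5000 are assumptions, so substitute your own config.

```python
import math
from types import SimpleNamespace

def get_lr(step, config):
    """Same schedule as get_lr above: linear warmup, then cosine decay to 10% of peak."""
    min_lr = config.learning_rate * 0.1
    if step < config.warmup_iters:
        return config.learning_rate * (step / config.warmup_iters)
    decay_ratio = (step - config.warmup_iters) / (config.max_iters - config.warmup_iters)
    decay_ratio = min(decay_ratio, 1.0)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (config.learning_rate - min_lr)

# Assumed settings; the tutorial's real config is defined earlier
cfg = SimpleNamespace(learning_rate=3e-4, warmup_iters=100, max_iters=5000)

for step in [0, 100, 500, 1000, 2000, 3000, 4000, 4999]:
    print(f"Step {step:5d} | LR: {get_lr(step, cfg):.6f}")
# Prints 0.000000, 0.000300, 0.000296, 0.000278, 0.000212, 0.000127, 0.000057, 0.000030
```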
# ============================================================================
# TEXT GENERATION
# ============================================================================

"""
Generate text using our trained model.
"""

def generate_text(model, prompt="", max_tokens=500, temperature=0.8,
                  top_k=40, top_p=0.95):
    """
    Generate text from a prompt using our trained model.

    Args:
        model: Trained MiniGPT model
        prompt: Starting text (can be empty)
        max_tokens: Maximum number of tokens to generate
        temperature: Higher = more random, lower = more deterministic
        top_k: Only consider top-k tokens at each step
        top_p: Nucleus sampling threshold

    Returns:
        Generated text string
    """
    model.eval()

    # Encode prompt
    if prompt:
        tokens = encode(prompt)
    else:
        tokens = [ord('\n')]  # Start with newline

    idx = torch.tensor([tokens], dtype=torch.long)

    # Generate
    with torch.no_grad():
        generated = model.generate(
            idx,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )

    return decode(generated[0].tolist())


# After training, generate text:
# print(generate_text(model, prompt="ROMEO: ", max_tokens=300))

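# Expected output (after adequate training): text in Shakespeare's style.
# A small byte-level model usually gets the format right (speaker names,
# line breaks, archaic words) but makes misspellings and loose grammar.
# It will rarely reproduce real passages verbatim, such as this one from
# Romeo and Juliet (Act 1, Scene 5). Long verbatim quotes would suggest
# memorization:
# ROMEO: O, she doth teach the torches to burn bright!
# It seems she hangs upon the cheek of night
# Like a rich jewel in an Ethiope's ear;
# Beauty too rich for use, for earth too dear!
# ...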

# The model learns Shakespeare's style:
# - Character names followed by colons
# - Iambic pentameter (roughly)
# - Archaic English vocabulary
# - Dramatic themes
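generate_text forwards temperature, top_k and top_p to model.generate, whose filtering code is not shown in this section. The sketch below shows what that step typically does. `sample_next_token` is a hypothetical helper and may not match how MiniGPT.generate is written. It:

- scales the last-position logits by the temperature,
- keeps the top-k tokens,
- keeps the smallest set of those tokens whose probabilities sum to at least top_p,
- samples from what remains.

```python
import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=0.8, top_k=40, top_p=0.95, generator=None):
    """Sample one token id per row from (batch, vocab_size) logits (temperature > 0).

    Hypothetical helper; MiniGPT.generate may implement filtering differently.
    """
    logits = logits / temperature                        # <1 sharpens, >1 flattens
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))  # ties may keep >k
    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass of the tokens ranked strictly above each token
        mass_above = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        # Drop a token once the tokens above it already cover top_p.
        # The top token has mass_above == 0, so it is always kept.
        remove_sorted = mass_above >= top_p
        # Map the mask from sorted order back to vocabulary order
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)  # (batch, 1)

torch.manual_seed(0)
last_logits = torch.randn(2, 256)          # stand-in for last-position logits, byte vocab
next_ids = sample_next_token(last_logits)  # defaults match generate_text
print(next_ids.shape)  # torch.Size([2, 1])
```

Applying top-k before top-p means the nucleus is computed only over the surviving k tokens. This is a common ordering, but other implementations may differ.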

4. Model Training Concepts

4.1 Loss Function for Language Models

"""
Cross-Entropy Loss for Next Token Prediction
===============================================
This is the standard pretraining objective for autoregressive language models.
"""

import torch
import torch.nn.functional as F

# The model outputs logits of shape (batch, seq_len, vocab_size)
# For each position, we compute cross-entropy against the target token

# Example: vocab_size = 10 (tiny vocabulary for illustration)
vocab_size = 10
batch_size = 2
seq_len = 5

# Model outputs (raw logits, before softmax)
logits = torch.randn(batch_size, seq_len, vocab_size)

# Target: the correct next token at each position
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

print(f"Logits shape: {logits.shape}")      # (2, 5, 10)
print(f"Targets shape: {targets.shape}")     # (2, 5)

# Cross-entropy loss
# We reshape to (batch*seq, vocab_size) and (batch*seq,) for F.cross_entropy
loss = F.cross_entropy(
    logits.view(-1, vocab_size),  # (10, 10) - flattened
    targets.view(-1)              # (10,)    - flattened
)

print(f"\nCross-entropy loss: {loss.item():.4f}")

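# What does this loss mean?
# - Uniform prediction (all logits equal): loss = ln(vocab_size) = ln(10) ≈ 2.30
#   (random logits like the ones above score somewhat worse on average)
# - Perfect prediction: loss = 0
# - With GPT-2's 50,257-token vocabulary: uniform loss = ln(50257) ≈ 10.82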

# Perplexity (common evaluation metric):
perplexity = torch.exp(loss)
print(f"Perplexity: {perplexity.item():.2f}")
# Perplexity = e^loss
# Interpretation: "the model is as confused as if it were choosing
# uniformly among X options" where X = perplexity
# Lower perplexity = better model
# Perplexity depends on the tokenizer and the dataset, so only compare it
# between models that share both
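Flattening the tensors and calling F.cross_entropy is equivalent to averaging -log p(target) over every position. The sketch below checks that by hand with log_softmax and gather. It also confirms that uniform logits give loss ln(V) and perplexity V.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, batch_size, seq_len = 10, 2, 5
logits = torch.randn(batch_size, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

# Cross-entropy by hand: mean over positions of -log softmax(logits)[target]
log_probs = F.log_softmax(logits, dim=-1)                        # (B, T, V)
nll = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # (B, T)
manual_loss = nll.mean()

builtin_loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
print(f"manual: {manual_loss.item():.4f}  builtin: {builtin_loss.item():.4f}")

# A uniform prediction (all logits equal) gives loss ln(V) and perplexity V
uniform_loss = F.cross_entropy(torch.zeros(batch_size * seq_len, vocab_size), targets.view(-1))
print(f"uniform loss: {uniform_loss.item():.4f}")             # 2.3026 (= ln 10)
print(f"uniform perplexity: {torch.exp(uniform_loss).item():.2f}")  # 10.00
```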

4.2 AdamW Optimizer

"""
AdamW: The Optimizer of Choice for Transformers
=================================================
"""

import torch

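# Standard Adam update at step t (m and v start at 0):
# m = β1 * m + (1 - β1) * grad           (momentum / first moment)
# v = β2 * v + (1 - β2) * grad^2         (adaptive learning rate / second moment)
# m_hat = m / (1 - β1^t),  v_hat = v / (1 - β2^t)   (bias correction)
# param = param - lr * m_hat / (sqrt(v_hat) + eps)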

# AdamW adds DECOUPLED weight decay:
# param = param - lr * weight_decay * param  (applied SEPARATELY from Adam step)

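# Why decoupled? With Adam + L2, the decay term is added to the gradient.
# It is then divided by sqrt(v) like any other gradient, so weights with
# large historical gradients are decayed less. AdamW applies the decay to
# the weights directly, so every weight shrinks by the same factor
# (1 - lr * weight_decay) per step.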

# Typical hyperparameters for LLM training:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,                  # Peak learning rate
    betas=(0.9, 0.95),        # β1=0.9 (momentum), β2=0.95 (for LLMs; 0.999 in original)
    eps=1e-8,                 # Numerical stability
    weight_decay=0.1,         # Decoupled weight decay (not an L2 loss term)
)

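# Important: weight decay is usually NOT applied to:
# 1. Bias terms (few parameters, so decaying them adds little regularization)
# 2. LayerNorm parameters (decaying gamma toward 0 would shrink the normalized activations)
# Only apply it to weight matrices (2D+ tensors). This rule also decays the
# embedding tables, because they are 2D.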

# This is why we use parameter groups:
decay_params = [p for n, p in model.named_parameters()
                if p.dim() >= 2]  # Weight matrices
no_decay_params = [p for n, p in model.named_parameters()
                   if p.dim() < 2]  # Biases, LayerNorm

optimizer = torch.optim.AdamW([
    {'params': decay_params, 'weight_decay': 0.1},
    {'params': no_decay_params, 'weight_decay': 0.0},
], lr=3e-4, betas=(0.9, 0.95))
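The sketch below makes the update rules above concrete. It runs a hand-written AdamW step and checks it against torch.optim.AdamW for three steps. `adamw_step` is an illustrative helper. It follows the order used by PyTorch's default (non-amsgrad) path: first scale the weights by (1 - lr * weight_decay), then take the bias-corrected Adam step.

```python
import torch

def adamw_step(p, grad, m, v, t, lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1):
    """One hand-written AdamW update; t is the 1-based step count."""
    beta1, beta2 = betas
    p = p * (1 - lr * weight_decay)              # decoupled weight decay
    m = beta1 * m + (1 - beta1) * grad           # first moment
    v = beta2 * v + (1 - beta2) * grad * grad    # second moment
    m_hat = m / (1 - beta1 ** t)                 # bias correction (moments start at 0)
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (v_hat.sqrt() + eps)
    return p, m, v

torch.manual_seed(0)
w0 = torch.randn(4, 3)
grads = [torch.randn(4, 3) for _ in range(3)]

# Reference: PyTorch's AdamW applied to a copy of the same weights
w_ref = torch.nn.Parameter(w0.clone())
opt = torch.optim.AdamW([w_ref], lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)

w, m, v = w0.clone(), torch.zeros_like(w0), torch.zeros_like(w0)
for t, g in enumerate(grads, start=1):
    w, m, v = adamw_step(w, g, m, v, t)
    w_ref.grad = g.clone()   # feed the same gradient to the reference optimizer
    opt.step()

print("max abs difference:", (w - w_ref.detach()).abs().max().item())
```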

4.3 Learning Rate Scheduling

"""
Learning Rate Scheduling: Warmup + Cosine Decay
==================================================
The standard schedule for training LLMs.
"""

import math

def cosine_schedule_with_warmup(step, total_steps, warmup_steps,
                                 peak_lr, min_lr_ratio=0.1):
    """
    Compute learning rate at a given step.

    Phase 1 (warmup): Linear increase from 0 to peak_lr
    Phase 2 (decay):  Cosine decay from peak_lr to min_lr

    Why warmup?
    - At the start of training, the model's parameters are random and Adam's
      moment estimates (m, v) are built from only a few noisy gradients
    - Large learning rates applied on top of these early estimates cause unstable updates
    - Warmup keeps the first steps small until those statistics settle

    Why cosine decay?
    - Gradual reduction allows fine-tuning in later stages
    - Smoother than step-wise decay
    - Widely used and effective in practice
    """
    min_lr = peak_lr * min_lr_ratio

    if step < warmup_steps:
        # Linear warmup
        return peak_lr * (step / warmup_steps)
    elif step >= total_steps:
        return min_lr
    else:
        # Cosine decay
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Demonstrate the schedule
total_steps = 10000
warmup_steps = 200
peak_lr = 3e-4

print("Learning Rate Schedule:")
print(f"{'Step':>8} {'LR':>12}")
print("-" * 22)
for step in [0, 50, 100, 150, 200, 500, 1000, 3000, 5000, 8000, 10000]:
    lr = cosine_schedule_with_warmup(step, total_steps, warmup_steps, peak_lr)
    bar = "=" * int(lr / peak_lr * 40)
    print(f"{step:>8} {lr:.6f} |{bar}")
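In a PyTorch training loop, you can attach a per-step schedule like this to the optimizer with torch.optim.lr_scheduler.LambdaLR, instead of setting param_group['lr'] by hand. LambdaLR multiplies the optimizer's initial lr by the value the lambda returns, so the sketch divides by peak_lr. The shortened 1,000-step schedule and the dummy quadratic loss are assumptions chosen so the example runs quickly.

```python
import math
import torch

def cosine_schedule_with_warmup(step, total_steps, warmup_steps, peak_lr, min_lr_ratio=0.1):
    """Same schedule as above: linear warmup, then cosine decay to min_lr."""
    min_lr = peak_lr * min_lr_ratio
    if step < warmup_steps:
        return peak_lr * (step / warmup_steps)
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Small, assumed settings so the demo runs quickly
peak_lr, warmup_steps, total_steps = 3e-4, 50, 1000

w = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.AdamW([w], lr=peak_lr)   # base lr = peak_lr

# LambdaLR sets lr = base_lr * lr_lambda(step), so return a multiplier, not an LR
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda s: cosine_schedule_with_warmup(s, total_steps, warmup_steps, peak_lr) / peak_lr,
)

lrs = []
for step in range(total_steps + 1):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used at this step
    loss = (w - 1.0).pow(2).sum()                 # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()                              # call after optimizer.step()

print(f"step 0: {lrs[0]:.2e}, step {warmup_steps}: {lrs[warmup_steps]:.2e}, "
      f"step {total_steps}: {lrs[total_steps]:.2e}")
```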

4.4 Gradient Clipping

"""
Gradient Clipping
==================
Prevents exploding gradients by capping the gradient norm.
"""

import torch
import torch.nn as nn

# Without gradient clipping, a single bad batch can produce enormous
# gradients that destroy the model's learned weights.

# Two types of gradient clipping:

# 1. Clip by norm (most common for transformers)
# If ||grad|| > max_norm, scale all gradients so ||grad|| = max_norm
def clip_by_norm_example(model, max_norm=1.0):
    """
    Gradient norm clipping.

    Computes the total norm across ALL parameters:
    total_norm = sqrt(sum(||grad_i||^2))

    If total_norm > max_norm:
        scale all gradients by max_norm / total_norm
    """
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=max_norm
    )
    return total_norm

# 2. Clip by value (less common)
# Clamp each gradient element to [-clip_value, clip_value]
def clip_by_value_example(model, clip_value=1.0):
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)

# In practice, GPT models use clip_grad_norm with max_norm=1.0
# This is applied AFTER loss.backward() and BEFORE optimizer.step():

# loss.backward()
# grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# optimizer.step()

print("Gradient clipping is essential for stable transformer training.")
print("Typical max_norm values: 0.5 - 1.0")
print("If you see training loss spike suddenly, gradient clipping may need adjustment.")
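You can check the total-norm formula above numerically. This sketch builds a small throwaway model and inflates its gradients with a scaled loss. It then compares three values:

- the norm returned by clip_grad_norm_, which is the norm measured before clipping,
- a manual computation of the same norm,
- the norm left after clipping.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
x = torch.randn(32, 8)
loss = model(x).pow(2).sum() * 100.0   # scaled up so the gradient norm is large
loss.backward()

# Total norm = sqrt(sum of squared per-tensor L2 norms)
manual_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

max_norm = 1.0
returned = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

print(f"norm before: {manual_norm.item():.3f}, returned: {returned.item():.3f}, "
      f"after: {after.item():.4f}")
```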

4.5 Mixed Precision Training

"""
Mixed Precision Training (FP16 / BF16)
=========================================
Use lower precision for faster computation while maintaining accuracy.
"""

import torch

# Standard training uses FP32 (32-bit floating point):
# - 1 sign bit + 8 exponent bits + 23 mantissa bits
# - Range: ±3.4 × 10^38
# - Memory: 4 bytes per parameter

# FP16 (16-bit floating point / Half precision):
# - 1 sign bit + 5 exponent bits + 10 mantissa bits
# - Range: ±65,504
# - Memory: 2 bytes per parameter (50% savings!)
# - Risk: limited range can cause overflow/underflow

# BF16 (Brain Float 16 / used in Google TPUs and modern GPUs):
# - 1 sign bit + 8 exponent bits + 7 mantissa bits
# - Same range as FP32 (±3.4 × 10^38) but less precision
# - Memory: 2 bytes per parameter
# - The common default for LLM training on hardware that supports it

# Mixed precision: use FP16/BF16 for most operations, FP32 for critical ones
# - Forward/backward pass: FP16/BF16 (fast)
# - Weight updates: FP32 (accurate)
# - Loss scaling: prevents FP16 gradient underflow (unnecessary with BF16)

# PyTorch Automatic Mixed Precision (AMP).
# torch.cuda.amp.autocast and torch.cuda.amp.GradScaler are deprecated in recent
# PyTorch releases in favor of torch.autocast and torch.amp.GradScaler.
scaler = torch.amp.GradScaler("cuda")  # Loss scaling for FP16 (older releases: torch.cuda.amp.GradScaler())

# Training step with mixed precision:
def mixed_precision_training_step(model, optimizer, x, y):
    """
    One training step with automatic mixed precision.

    On GPUs with tensor cores this is often substantially faster than FP32
    training (commonly up to ~2x), with negligible quality loss.
    """
    optimizer.zero_grad(set_to_none=True)

    # Forward pass: autocast runs eligible ops (e.g. matmuls) in FP16/BF16
    with torch.autocast(device_type="cuda", dtype=torch.float16):  # or torch.bfloat16
        logits, loss = model(x, y)

    # Backward pass with scaled loss (prevents FP16 gradient underflow)
    scaler.scale(loss).backward()

    # Unscale gradients first so clipping sees their true magnitude
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # Optimizer step on the FP32 weights (skipped if inf/NaN gradients were found)
    scaler.step(optimizer)
    scaler.update()

    return loss.item()

# Memory comparison:
print("Memory per parameter (weights only):")
print(f"  FP32: 4 bytes  | 1B params = 4 GB")
print(f"  FP16: 2 bytes  | 1B params = 2 GB")
print(f"  BF16: 2 bytes  | 1B params = 2 GB")
print(f"  FP8:  1 byte   | 1B params = 1 GB  (needs hardware support, e.g. NVIDIA H100)")
print()
print("Training memory with Adam (weights + grads + optimizer states, excluding activations):")
print(f"  FP32 training:   ~16 bytes/param (4+4+4+4)   | 7B model ≈ 112 GB")
print(f"  Mixed precision: ~16 bytes/param (2+2+4+4+4) | 7B model ≈ 112 GB")
print(f"    (FP32 master weights and Adam states remain; savings come from activations and speed)")
print(f"  With ZeRO-3: these states are sharded across GPUs")
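The range differences between FP16 and BF16 are easy to observe with torch.finfo and a couple of casts. The values 70000 and 1e-8 below are arbitrary examples. The first is above FP16's maximum of 65504. The second is below its smallest subnormal (about 6e-8).

```python
import torch

# Size, range and precision of the three formats discussed above
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} bytes={info.bits // 8}  max={info.max:.4g}  eps={info.eps:.3g}")

big = torch.tensor(70000.0)
as_fp16 = big.to(torch.float16)
as_bf16 = big.to(torch.bfloat16)
print(as_fp16)  # tensor(inf, dtype=torch.float16): overflow
print(as_bf16)  # tensor(70144., dtype=torch.bfloat16): in range, but rounded

small = torch.tensor(1e-8)           # e.g. a tiny gradient value
print(small.to(torch.float16))       # underflows to 0 in FP16
print(small.to(torch.bfloat16))      # still nonzero in BF16
```

The underflow case is the one that FP16 loss scaling guards against. BF16 keeps FP32's exponent range, so it does not need scaling.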

5. Understanding Model Parameters

5.1 How to Calculate Total Parameters

"""
Calculating Transformer Parameters
=====================================
Understand exactly where all the parameters live.
"""

def calculate_transformer_params(vocab_size, d_model, n_heads, n_layers,
                                  d_ff, max_seq_len, tie_weights=True):
    """
    Calculate the total number of parameters in a GPT-style transformer.

    Args:
        vocab_size: Size of vocabulary
        d_model: Model dimension
        n_heads: Number of attention heads
        n_layers: Number of transformer layers
        d_ff: FFN hidden dimension
        max_seq_len: Maximum sequence length
        tie_weights: Whether to tie input/output embeddings

    Returns:
        Dictionary with parameter breakdown
    """
    params = {}

    # ---- 1. Token Embedding ----
    # Maps each token to a d_model-dimensional vector
    # Shape: (vocab_size, d_model)
    params['token_embedding'] = vocab_size * d_model

    # ---- 2. Positional Embedding ----
    # Maps each position to a d_model-dimensional vector
    # Shape: (max_seq_len, d_model)
    params['position_embedding'] = max_seq_len * d_model

    # ---- Per Layer Parameters ----
    layer_params = {}

    # ---- 3. Multi-Head Attention ----
    # Q, K, V projections: each is (d_model, d_model)
    # We have 3 of them: W_Q, W_K, W_V
    layer_params['attention_qkv'] = 3 * d_model * d_model

    # Output projection: (d_model, d_model)
    layer_params['attention_output'] = d_model * d_model

    # Total attention per layer. Attention biases are omitted for simplicity
    # (GPT-2 has them, adding 4 * d_model per layer; Llama-style models drop them)
    layer_params['attention_total'] = 4 * d_model * d_model

    # ---- 4. Feed-Forward Network ----
    # W1: (d_model, d_ff) + bias (d_ff)
    layer_params['ffn_w1'] = d_model * d_ff + d_ff

    # W2: (d_ff, d_model) + bias (d_model)
    layer_params['ffn_w2'] = d_ff * d_model + d_model

    layer_params['ffn_total'] = layer_params['ffn_w1'] + layer_params['ffn_w2']

    # ---- 5. Layer Normalization ----
    # Two LayerNorms per layer: each has gamma (d_model) + beta (d_model)
    layer_params['layer_norm'] = 2 * 2 * d_model  # 2 norms * (gamma + beta)

    # ---- Total per layer ----
    layer_params['total_per_layer'] = (layer_params['attention_total'] +
                                        layer_params['ffn_total'] +
                                        layer_params['layer_norm'])

    # ---- 6. All layers ----
    params['all_layers'] = n_layers * layer_params['total_per_layer']

    # ---- 7. Final LayerNorm ----
    params['final_layer_norm'] = 2 * d_model

    # ---- 8. Output Projection ----
    if tie_weights:
        params['output_projection'] = 0  # Shared with token embedding
    else:
        params['output_projection'] = vocab_size * d_model

    # ---- TOTAL ----
    total = sum(params.values())

    return {
        'breakdown': params,
        'per_layer': layer_params,
        'total': total,
    }


# ========================================
# Example 1: Our Mini-GPT
# ========================================
print("=" * 70)
print("PARAMETER CALCULATION: Mini-GPT")
print("=" * 70)
result = calculate_transformer_params(
    vocab_size=256, d_model=384, n_heads=6,
    n_layers=6, d_ff=1536, max_seq_len=256
)
print(f"\nTotal parameters: {result['total']:,}")
print(f"\nBreakdown:")
for name, count in result['breakdown'].items():
    pct = count / result['total'] * 100
    print(f"  {name:25s}: {count:>12,} ({pct:5.1f}%)")

print(f"\nPer-layer breakdown:")
for name, count in result['per_layer'].items():
    print(f"  {name:25s}: {count:>12,}")


# ========================================
# Example 2: GPT-2 (124M)
# ========================================
print("\n" + "=" * 70)
print("PARAMETER CALCULATION: GPT-2 (124M)")
print("=" * 70)
result = calculate_transformer_params(
    vocab_size=50257, d_model=768, n_heads=12,
    n_layers=12, d_ff=3072, max_seq_len=1024
)
print(f"\nTotal parameters: {result['total']:,}")
print(f"Expected: ~124M")


# ========================================
# Example 3: Llama 3.1 8B (approximate)
# ========================================
print("\n" + "=" * 70)
print("PARAMETER CALCULATION: Llama 3.1 8B (approximate)")
print("=" * 70)

# Llama uses SwiGLU (3 FFN matrices instead of 2) and RMSNorm
# so the exact calculation differs slightly

def calculate_llama_params(vocab_size, d_model, n_heads, n_kv_heads,
                            n_layers, d_ff, max_seq_len):
    """
    Calculate parameters for Llama-style model.

    Llama differences from standard GPT:
    1. Grouped Query Attention (GQA): fewer K,V heads than Q heads
    2. SwiGLU FFN: 3 weight matrices instead of 2
    3. RMSNorm: no bias parameter
    4. No bias in attention projections
    5. RoPE positions: no learned position table, so max_seq_len
       does not affect the parameter count
    """
    params = {}

    # Token embedding
    params['token_embedding'] = vocab_size * d_model

    # Per-layer attention with GQA
    d_k = d_model // n_heads
    # Q: full n_heads
    attn_q = d_model * (n_heads * d_k)
    # K, V: reduced to n_kv_heads (GQA)
    attn_k = d_model * (n_kv_heads * d_k)
    attn_v = d_model * (n_kv_heads * d_k)
    attn_o = (n_heads * d_k) * d_model

    per_layer_attn = attn_q + attn_k + attn_v + attn_o

    # SwiGLU FFN: gate_proj, up_proj, down_proj (3 matrices)
    per_layer_ffn = d_model * d_ff + d_model * d_ff + d_ff * d_model

    # RMSNorm (gamma only, no beta)
    per_layer_norm = 2 * d_model

    per_layer = per_layer_attn + per_layer_ffn + per_layer_norm
    params['all_layers'] = n_layers * per_layer

    # Final RMSNorm
    params['final_norm'] = d_model

    # Output projection (not tied in Llama)
    params['output_proj'] = vocab_size * d_model

    total = sum(params.values())

    return total, per_layer

# Llama 3.1 8B configuration
total, per_layer = calculate_llama_params(
    vocab_size=128256,
    d_model=4096,
    n_heads=32,
    n_kv_heads=8,    # GQA: 8 KV heads (4 Q heads per KV head)
    n_layers=32,
    d_ff=14336,       # SwiGLU intermediate size
    max_seq_len=131072  # 128K context
)

print(f"\nTotal parameters: {total:,}")
print(f"Total parameters: {total/1e9:.1f}B")
print(f"Parameters per layer: {per_layer:,}")
print(f"Expected: ~8.0B")


# ========================================
# General formula
# ========================================
print("\n" + "=" * 70)
print("GENERAL PARAMETER FORMULA")
print("=" * 70)
print("""
For a standard GPT model (no GQA, standard FFN, learned positions, tied output):

Total ≈ vocab_size × d_model           (token embeddings)
      + max_seq_len × d_model          (positional embeddings)
      + n_layers × [
          4 × d_model²                  (attention: Q, K, V, O)
          + 2 × d_model × d_ff          (FFN: up + down)
          + d_ff + d_model               (FFN biases)
          + 4 × d_model                  (2 LayerNorms × (gamma + beta))
        ]
      + 2 × d_model                     (final LayerNorm)

Simplified rule of thumb (non-embedding parameters, with d_ff = 4 × d_model):
  Total ≈ n_layers × (4 + 8) × d_model² = 12 × n_layers × d_model²

For GPT-2 (12 layers, d_model=768):
  12 × 12 × 768² ≈ 84.9M non-embedding parameters
  + 50257 × 768 + 1024 × 768 ≈ 39.4M embedding parameters
  ≈ 124M total
""")

Putting It All Together

"""
Complete Summary: Data Flow Through a GPT Model
==================================================

Let's trace a single input through the entire model.

Input: "The cat" (as byte tokens: [84, 104, 101, 32, 99, 97, 116])
Task: Predict the next character after "The cat"
"""

print("""
DATA FLOW THROUGH GPT
=======================

1. INPUT TOKENS
   "The cat" -> [84, 104, 101, 32, 99, 97, 116]
   Shape: (1, 7)  [batch=1, seq_len=7]

2. TOKEN EMBEDDING
   Each token ID -> 384-dimensional vector
   token_embedding[84]  -> [0.02, -0.01, 0.03, ...]  (for 'T')
   token_embedding[104] -> [-0.01, 0.02, 0.01, ...]  (for 'h')
   Shape: (1, 7, 384)

3. POSITIONAL EMBEDDING
   position_embedding[0] -> [0.01, 0.00, -0.02, ...]
   position_embedding[1] -> [0.00, 0.01, 0.01, ...]
   Shape: (1, 7, 384)

4. ADD: token_emb + pos_emb
   Shape: (1, 7, 384)

5. TRANSFORMER BLOCK 1 (of 6):

   5a. LayerNorm (normalize across d_model=384 dimension)
       Shape: (1, 7, 384)

   5b. Multi-Head Causal Self-Attention (6 heads, 64 dim each)
       - Project to Q, K, V: (1, 7, 384) -> Q,K,V each (1, 7, 384)
       - Split heads: (1, 7, 384) -> (1, 6, 7, 64)
       - Attention scores: (1, 6, 7, 64) @ (1, 6, 64, 7) = (1, 6, 7, 7)
       - Apply causal mask (upper triangle -> -inf)
       - Softmax -> attention weights (1, 6, 7, 7)
       - Weighted values: (1, 6, 7, 7) @ (1, 6, 7, 64) = (1, 6, 7, 64)
       - Combine heads: (1, 6, 7, 64) -> (1, 7, 384)
       - Output projection: (1, 7, 384) -> (1, 7, 384)

   5c. Residual: x = x + attention_output
       Shape: (1, 7, 384)

   5d. LayerNorm
       Shape: (1, 7, 384)

   5e. FFN: Linear(384->1536) -> GELU -> Linear(1536->384)
       Shape: (1, 7, 384) -> (1, 7, 1536) -> (1, 7, 384)

   5f. Residual: x = x + ffn_output
       Shape: (1, 7, 384)

6. TRANSFORMER BLOCKS 2-6: Same as above
   Shape remains: (1, 7, 384)

7. FINAL LAYER NORM
   Shape: (1, 7, 384)

8. OUTPUT PROJECTION (Linear: 384 -> 256)
   Shape: (1, 7, 256)  [256 = vocab_size]

9. TAKE LAST POSITION: logits[:, -1, :] = (1, 256)
   These are the logits for the next character after "The cat"

10. SOFTMAX + SAMPLING
    logits / temperature -> top-k/top-p filter -> softmax -> sample
    Result: token ID 32 (which is ' ' = space)

11. APPEND AND REPEAT
    New input: "The cat " -> predict next character -> 's'
    "The cat s" -> 'a'
    "The cat sa" -> 't'
    "The cat sat" -> ' '
    "The cat sat " -> 'o'
    "The cat sat o" -> 'n'
    ...
""")

# ============================================================================
# FINAL VERIFICATION: Run the complete model
# ============================================================================

# Verify the model works end-to-end
config = GPTConfig(
    vocab_size=256,
    max_seq_len=256,
    d_model=384,
    n_heads=6,
    n_layers=6,
    d_ff=1536,
    dropout=0.1,
)

model = MiniGPT(config)

# Test forward pass
x = torch.randint(0, 256, (2, 128))  # Random input
y = torch.randint(0, 256, (2, 128))  # Random targets

logits, loss = model(x, y)
print(f"Forward pass successful!")
print(f"  Input shape:  {x.shape}")
print(f"  Logits shape: {logits.shape}")
print(f"  Loss: {loss.item():.4f}")
print(f"  Expected random loss: ln(256) = {math.log(256):.4f}")

# Test generation
prompt = torch.randint(0, 256, (1, 10))
generated = model.generate(prompt, max_new_tokens=50, temperature=1.0)
print(f"\n  Generation input:  {prompt.shape}")
print(f"  Generation output: {generated.shape}")
print(f"  Generated text (random model): '{decode(generated[0].tolist())}'")
print(f"  (Gibberish expected from untrained model)")

# After training on Shakespeare, this would generate coherent text!

Week 4 Summary

Key Takeaways

  1. Causal masking uses a lower triangular matrix to prevent tokens from attending to future positions. Applied before softmax by setting masked positions to -infinity.
  2. Autoregressive generation produces text one token at a time. During training, teacher forcing provides ground truth tokens. During inference, the model uses its own predictions.
  3. A complete GPT model consists of: token + positional embeddings, N transformer blocks (each with causal attention + FFN + LayerNorm + residual connections), and an output projection to vocabulary logits.
  4. Cross-entropy loss on next-token prediction is the training objective. Perplexity = e^loss is a common evaluation metric.
  5. AdamW with warmup + cosine decay is the standard optimizer setup. Gradient clipping (max_norm=1.0) prevents training instability.
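  6. Mixed precision (BF16) roughly halves the memory for activations and 16-bit tensors. It can substantially speed up training on GPUs with tensor cores, with minimal quality impact. FP32 master weights and Adam states usually remain.
  7. Parameters are dominated by attention (4 * d_model^2 per layer) and FFN (2 * d_model * d_ff per layer). Rule of thumb: with d_ff = 4 * d_model, non-embedding params are approximately 12 * n_layers * d_model^2.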

Exercises

  1. Train the Mini-GPT model on Shakespeare (or any text of your choice). Experiment with different hyperparameters (d_model, n_layers, learning rate).
  2. Modify the model to use RMSNorm instead of LayerNorm. Does training stability change?
  3. Implement SwiGLU activation in the FFN. Compare training curves with GELU.
  4. Add Grouped Query Attention (GQA) where n_kv_heads < n_heads. How does this affect memory usage?
  5. Calculate the total parameters for GPT-3 (d_model=12288, n_heads=96, n_layers=96, d_ff=49152, vocab_size=50257). Does your calculation match the reported 175B?
  6. Implement KV-cache for faster inference (avoid recomputing K, V for previous tokens during generation).
  7. Train two models: one with 2 layers of d_model=512 and one with 8 layers of d_model=256 (similar parameter count). Which performs better?

Further Reading