Quantization and Fine-Tuning
Master the techniques that make large language models practical: KV caching for fast inference, quantization for memory reduction, attention optimizations for throughput, and parameter-efficient fine-tuning with LoRA and QLoRA.
Learning Objectives
KV Cache Mechanics
Understand how the KV cache accelerates autoregressive generation and how to calculate its memory requirements.
Quantization Mastery
Learn every major quantization format from FP32 to 1.58-bit, including GPTQ, AWQ, and GGUF.
Attention Optimizations
Understand FlashAttention, GQA, MQA, and PagedAttention for efficient inference.
LoRA and QLoRA
Master parameter-efficient fine-tuning to adapt LLMs to custom tasks on consumer hardware.
1. KV Cache
The KV (Key-Value) cache is one of the most important optimizations in LLM inference. Without it, autoregressive generation would be prohibitively slow. Understanding the KV cache is essential for anyone working with LLMs in production.
Why KV Cache Matters
During autoregressive generation, the model generates one token at a time. Each new token requires attending to all previous tokens. Without caching, this means recomputing the keys and values for all previous tokens at every generation step.
# Without KV Cache: O(n^2) total computation for n tokens
# Step 1: compute attention for token 1 -> 1 computation
# Step 2: compute attention for tokens 1,2 -> 2 computations
# Step 3: compute attention for tokens 1,2,3 -> 3 computations
# ...
# Step n: compute attention for all n tokens -> n computations
# Total: 1 + 2 + 3 + ... + n = n(n+1)/2 = O(n^2)
# With KV Cache: O(n) total computation
# Step 1: compute K,V for token 1, cache them -> 1 computation
# Step 2: compute K,V for token 2 only, attend to cached K,V -> 1 computation
# Step 3: compute K,V for token 3 only, attend to cached K,V -> 1 computation
# ...
# Step n: compute K,V for token n only, attend to cached K,V -> 1 computation
# Total: n computations = O(n)
How the KV Cache Works
In the transformer attention mechanism, each token produces three vectors: Query (Q), Key (K), and Value (V). The attention output is computed as:
Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
The key insight: for already-generated tokens, the K and V vectors never change. Only the newly generated token produces a new Q, K, V. So we can cache all previously computed K and V vectors and only compute the new ones.
Step-by-Step KV Cache Operation
- Prefill phase: Process the entire prompt at once. Compute K, V for all prompt tokens. Store them in the cache.
- Decode phase (step 1): Take the last token, compute its Q, K, V. Append new K, V to cache. Compute attention using new Q against all cached K, V.
- Decode phase (step 2): Take the newly generated token, compute its Q, K, V. Append to cache. Attend to all cached K, V.
- Repeat until end-of-sequence token or max length.
KV Cache Memory Calculations
The KV cache can consume significant memory, especially for long sequences and large batch sizes.
# KV Cache Memory Formula:
# memory = 2 * n_layers * d_model * seq_len * batch_size * bytes_per_param
#
# The "2" accounts for both Keys and Values
#
# Example: Llama 3.1 70B
# n_layers = 80
# d_model = 8192
# seq_len = 8192 (for 8K context)
# batch_size = 1
# bytes_per_param = 2 (FP16)
#
# KV cache = 2 * 80 * 8192 * 8192 * 1 * 2 bytes
# = 2 * 80 * 8192 * 8192 * 2
# = 21,474,836,480 bytes
# = ~21.5 GB for a SINGLE sequence!
#
# For Llama 3.1 70B with 128K context:
# KV cache = 2 * 80 * 8192 * 131072 * 1 * 2
# = ~343 GB -- more than the model weights!
# With Grouped Query Attention (GQA), the cache is smaller:
# Llama 3 70B uses 8 KV heads (vs 64 attention heads)
# Reduction factor: 8/64 = 1/8
# Effective KV cache for 8K context = 21.5 GB / 8 = ~2.7 GB
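The formula above can be wrapped in a small calculator (a sketch; `kv_cache_bytes` is a helper name chosen here, not a standard API):

```python
def kv_cache_bytes(n_layers, d_model, seq_len, batch_size=1,
                   bytes_per_param=2, n_heads=None, n_kv_heads=None):
    """KV cache size in bytes; pass n_heads/n_kv_heads to apply the GQA reduction."""
    size = 2 * n_layers * d_model * seq_len * batch_size * bytes_per_param
    if n_heads and n_kv_heads:
        size = size * n_kv_heads // n_heads  # fewer KV heads -> smaller cache
    return size

# Llama 3.1 70B, 8K context, FP16, batch size 1
print(kv_cache_bytes(80, 8192, 8192))  # 21474836480 bytes (~21.5 GB)

# Same model with GQA: 8 KV heads shared by 64 query heads
print(kv_cache_bytes(80, 8192, 8192, n_heads=64, n_kv_heads=8))  # ~2.7 GB
```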
KV Cache is Often the Bottleneck
For long-context models (128K+ tokens), the KV cache often exceeds the size of the model weights. This is why KV cache optimization techniques (GQA, MQA, quantized KV cache, sliding window attention) are so critical for production deployment.
KV Cache Optimization Techniques
- Grouped Query Attention (GQA): Share K,V heads across multiple Q heads, reducing KV cache by the group factor. Used by Llama 2/3.
- Multi-Query Attention (MQA): All Q heads share a single K,V head. Maximum cache reduction but can impact quality. Used by Falcon.
- Quantized KV Cache: Store cached K,V in INT8 or INT4 instead of FP16, reducing memory by 2-4x with minimal quality loss.
- Sliding Window Attention: Only cache the last W tokens, discarding older ones. Used by Mistral (W=4096).
- PagedAttention (vLLM): Manage KV cache memory like virtual memory pages, avoiding fragmentation and enabling efficient batching.
- Multi-head Latent Attention (MLA): DeepSeek V3's approach -- compress K,V into a low-rank latent space before caching, dramatically reducing cache size.
Practical: Implement a Simple KV Cache
"""
KV Cache Implementation
=========================
Demonstrate how the KV cache works in transformer attention
by implementing it from scratch.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
from typing import Optional, Tuple
class CausalSelfAttention(nn.Module):
"""
Self-attention with KV cache support.
Demonstrates the difference between cached and uncached inference.
"""
def __init__(self, d_model: int, n_heads: int, max_seq_len: int = 2048):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.max_seq_len = max_seq_len
# Q, K, V projections
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(
self,
x: torch.Tensor,
kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
"""
Forward pass with optional KV cache.
Args:
x: Input tensor of shape (batch_size, seq_len, d_model)
kv_cache: Tuple of (cached_keys, cached_values) or None
use_cache: Whether to return updated cache
Returns:
output: Attention output of shape (batch_size, seq_len, d_model)
new_cache: Updated KV cache (if use_cache=True)
"""
batch_size, seq_len, _ = x.shape
# Compute Q, K, V for current input
Q = self.W_q(x) # (B, seq_len, d_model)
K = self.W_k(x) # (B, seq_len, d_model)
V = self.W_v(x) # (B, seq_len, d_model)
# Reshape to (B, n_heads, seq_len, head_dim)
Q = Q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
# If we have a KV cache, concatenate with current K, V
if kv_cache is not None:
cached_K, cached_V = kv_cache
K = torch.cat([cached_K, K], dim=2) # Append new K to cache
V = torch.cat([cached_V, V], dim=2) # Append new V to cache
# Store updated cache if requested
new_cache = (K, V) if use_cache else None
# Compute attention scores
# Q: (B, n_heads, seq_len_q, head_dim)
# K: (B, n_heads, seq_len_kv, head_dim)
scale = math.sqrt(self.head_dim)
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# attn_scores: (B, n_heads, seq_len_q, seq_len_kv)
# Apply causal mask (only needed during prefill, not during cached decode)
seq_len_kv = K.size(2)
if seq_len > 1:
# During prefill: create causal mask
causal_mask = torch.triu(
torch.ones(seq_len, seq_len_kv, device=x.device),
diagonal=seq_len_kv - seq_len + 1,
).bool()
attn_scores = attn_scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
# Softmax and weighted sum
attn_weights = F.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_weights, V)
# output: (B, n_heads, seq_len, head_dim)
# Reshape back to (B, seq_len, d_model)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
output = self.W_o(output)
return output, new_cache
class SimpleTransformerBlock(nn.Module):
"""A simple transformer block with KV cache support."""
def __init__(self, d_model: int, n_heads: int):
super().__init__()
self.attention = CausalSelfAttention(d_model, n_heads)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Linear(4 * d_model, d_model),
)
def forward(self, x, kv_cache=None, use_cache=False):
# Pre-norm attention with residual
normed = self.norm1(x)
attn_out, new_cache = self.attention(normed, kv_cache, use_cache)
x = x + attn_out
# FFN with residual
x = x + self.ffn(self.norm2(x))
return x, new_cache
def benchmark_kv_cache():
"""
Benchmark inference with and without KV cache
to demonstrate the speedup.
"""
# Model configuration (small for demonstration)
d_model = 512
n_heads = 8
n_layers = 6
vocab_size = 1000
seq_len_prompt = 64
gen_length = 128
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create model layers
embedding = nn.Embedding(vocab_size, d_model).to(device)
blocks = nn.ModuleList([
SimpleTransformerBlock(d_model, n_heads).to(device)
for _ in range(n_layers)
])
lm_head = nn.Linear(d_model, vocab_size, bias=False).to(device)
# Create a sample prompt
prompt_ids = torch.randint(0, vocab_size, (1, seq_len_prompt), device=device)
# ===========================
# Method 1: WITHOUT KV Cache
# ===========================
print("=== Generating WITHOUT KV Cache ===")
start_time = time.time()
generated_ids = prompt_ids.clone()
with torch.no_grad():
for step in range(gen_length):
# Recompute everything from scratch each step
x = embedding(generated_ids)
for block in blocks:
x, _ = block(x, kv_cache=None, use_cache=False)
logits = lm_head(x[:, -1:, :])
next_token = logits.argmax(dim=-1)
generated_ids = torch.cat([generated_ids, next_token], dim=1)
no_cache_time = time.time() - start_time
print(f"Time: {no_cache_time:.3f}s")
print(f"Tokens/sec: {gen_length / no_cache_time:.1f}")
# ===========================
# Method 2: WITH KV Cache
# ===========================
print("\n=== Generating WITH KV Cache ===")
start_time = time.time()
with torch.no_grad():
# Prefill: process entire prompt, build initial cache
x = embedding(prompt_ids)
caches = []
for block in blocks:
x, cache = block(x, kv_cache=None, use_cache=True)
caches.append(cache)
# Get last token logits for first generation step
logits = lm_head(x[:, -1:, :])
next_token = logits.argmax(dim=-1)
all_tokens = [next_token]
# Decode: process one token at a time using cache
for step in range(gen_length - 1):
x = embedding(next_token) # Only embed the new token
new_caches = []
for i, block in enumerate(blocks):
x, cache = block(x, kv_cache=caches[i], use_cache=True)
new_caches.append(cache)
caches = new_caches
logits = lm_head(x) # x is (1, 1, d_model)
next_token = logits.argmax(dim=-1)
all_tokens.append(next_token)
cache_time = time.time() - start_time
print(f"Time: {cache_time:.3f}s")
print(f"Tokens/sec: {gen_length / cache_time:.1f}")
speedup = no_cache_time / cache_time
print(f"\n=== KV Cache Speedup: {speedup:.2f}x ===")
# Calculate KV cache memory
# Per layer: 2 * d_model * total_seq_len * batch_size * bytes
total_seq_len = seq_len_prompt + gen_length
cache_bytes = 2 * n_layers * d_model * total_seq_len * 1 * 4 # FP32
print(f"KV Cache size: {cache_bytes / 1024:.1f} KB")
if __name__ == "__main__":
benchmark_kv_cache()
2. Quantization Deep Dive
Quantization is the process of reducing the numerical precision of a model's weights (and sometimes activations) to use fewer bits. This reduces memory usage, increases inference speed, and enables running large models on consumer hardware.
Why Quantize?
| Model | FP32 Memory | FP16 Memory | INT8 Memory | INT4 Memory |
|---|---|---|---|---|
| Llama 3.1 8B | 32 GB | 16 GB | 8 GB | 4 GB |
| Llama 3.1 70B | 280 GB | 140 GB | 70 GB | 35 GB |
| Llama 3.1 405B | 1,620 GB | 810 GB | 405 GB | ~203 GB |
INT4 quantization lets a 70B model (~35 GB of weights) fit on a single 48 GB GPU such as an A6000, or be split across two 24 GB consumer GPUs. Without quantization, you would need 2-4 high-end GPUs just to load the model.
Number Representations
Understanding the different number formats is crucial for grasping quantization tradeoffs.
FP32 (32-bit floating point)
# IEEE 754 FP32 Layout:
# [1 sign bit] [8 exponent bits] [23 mantissa bits]
# S EEEEEEEE MMMMMMMMMMMMMMMMMMMMMMM
#
# Range: ~1.2e-38 to ~3.4e+38
# Precision: ~7 decimal digits
# Memory: 4 bytes per parameter
#
# This is the "full precision" -- most neural networks are trained in FP32
# but it's wasteful for inference.
FP16 (16-bit floating point)
# IEEE 754 FP16 Layout:
# [1 sign bit] [5 exponent bits] [10 mantissa bits]
# S EEEEE MMMMMMMMMM
#
# Range: ~6.1e-5 to ~6.5e+4
# Precision: ~3.3 decimal digits
# Memory: 2 bytes per parameter (50% reduction)
#
# Problem: Limited range can cause overflow/underflow during training.
# The exponent range (5 bits = 0-31) is often insufficient for
# large gradients in training.
BF16 (Brain Floating Point 16)
# BF16 Layout (designed by Google Brain):
# [1 sign bit] [8 exponent bits] [7 mantissa bits]
# S EEEEEEEE MMMMMMM
#
# Range: Same as FP32 (~1.2e-38 to ~3.4e+38)
# Precision: ~2.4 decimal digits (lower than FP16)
# Memory: 2 bytes per parameter
#
# Key insight: BF16 keeps the same exponent range as FP32,
# sacrificing precision for range. This makes it much more
# stable for training than FP16.
#
# Used by: Llama 3, DeepSeek V3, most modern training runs
INT8 (8-bit integer)
# INT8 Layout:
# [8 bits representing signed integer]
# Range: -128 to 127
# No fractional part -- requires a scaling factor
#
# Quantization: int8_value = round((float_value - zero_point) / scale)
# Dequantization: float_value = scale * int8_value + zero_point
#
# Memory: 1 byte per parameter (75% reduction from FP32)
# Quality loss: Minimal for most models (~0.1-0.5% on benchmarks)
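The quantize/dequantize round trip takes only a few lines (the `scale` value here is illustrative; real quantizers derive it from the observed weight range):

```python
def quantize_int8(x, scale, zero_point=0.0):
    """Map a float to an INT8 code, clamped to [-128, 127]."""
    q = round((x - zero_point) / scale)
    return max(-128, min(127, q))

def dequantize_int8(q, scale, zero_point=0.0):
    """Recover an approximate float from an INT8 code."""
    return scale * q + zero_point

scale = 0.05                          # covers roughly [-6.4, 6.35]
q = quantize_int8(1.234, scale)       # 25
x_hat = dequantize_int8(q, scale)     # 1.25 -- rounding error <= scale / 2
print(q, x_hat)
```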
INT4 (4-bit integer)
# INT4 Layout:
# [4 bits representing signed integer]
# Range: -8 to 7 (or 0 to 15 unsigned)
# Only 16 possible values!
#
# Memory: 0.5 bytes per parameter (87.5% reduction from FP32)
# Quality loss: Noticeable but acceptable for many tasks (~1-3% on benchmarks)
#
# Two values packed per byte, requiring special packing/unpacking logic.
# Most popular format for running large models on consumer hardware.
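The packing logic mentioned above can be sketched directly (unsigned 0-15 codes, even-length input assumed; the nibble order here is one convention, not the layout of any specific format):

```python
def pack_int4(values):
    """Pack pairs of unsigned 4-bit codes (0..15) into bytes, low nibble first."""
    out = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        out.append((hi << 4) | lo)
    return bytes(out)

def unpack_int4(data):
    """Recover the list of 4-bit codes from packed bytes."""
    vals = []
    for b in data:
        vals.append(b & 0x0F)  # low nibble
        vals.append(b >> 4)    # high nibble
    return vals

packed = pack_int4([3, 12, 0, 15])   # 4 weights -> 2 bytes
print(packed.hex())                  # 'c3f0'
assert unpack_int4(packed) == [3, 12, 0, 15]
```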
1.58-bit (Ternary Quantization)
# 1.58-bit Quantization (BitNet b1.58):
# Each weight is one of {-1, 0, +1}
# log2(3) = 1.58 bits of information per weight
#
# Advantages:
# - Matrix multiplication becomes addition/subtraction (no multiply!)
# - Extreme memory efficiency
# - Very fast on specialized hardware
#
# Research paper: "The Era of 1-bit LLMs" (Microsoft, 2024)
# Still experimental but shows promising results, especially
# when models are trained natively in this format.
Post-Training Quantization (PTQ) vs Quantization-Aware Training (QAT)
Post-Training Quantization (PTQ)
Quantize a pre-trained model without any additional training. This is the most common approach because it is fast and needs at most a small calibration set rather than a full training pipeline.
- Pros: Fast (minutes to hours), no training needed, works with any model
- Cons: Some quality degradation, especially at very low bit-widths (INT4)
- Methods: GPTQ, AWQ, Round-to-Nearest (RTN), SmoothQuant
Quantization-Aware Training (QAT)
Simulate quantization during training so the model learns to be robust to reduced precision.
- Pros: Better quality at low bit-widths, model adapts to quantization noise
- Cons: Requires full training pipeline, expensive, slower
- Methods: BitNet, AQLM, QAT fine-tuning
GPTQ Algorithm
GPTQ (2022) is one of the most popular PTQ methods for LLMs. It quantizes weights one layer at a time, using a small calibration dataset to minimize the quantization error.
# GPTQ Algorithm (simplified):
#
# For each layer:
# 1. Run a small calibration dataset through the model to get
# input activations for this layer
# 2. For each column of the weight matrix (one output feature):
# a. Find the optimal quantized value that minimizes the
# squared error: ||W*X - Q(W)*X||^2
# b. Use the Hessian (X^T * X) to account for which weights
# are most important (based on activation magnitudes)
# c. After quantizing one column, update remaining columns
# to compensate for the error introduced
# 3. Result: quantized weights that minimize the output error
#
# Key insight: Columns are processed in a specific order (by
# quantization difficulty), and errors are compensated greedily.
# This produces much better results than naive round-to-nearest.
AWQ (Activation-aware Weight Quantization)
AWQ (2023) takes a different approach: instead of trying to find the best quantized values, it identifies which weights are most important (based on activation magnitudes) and protects them.
# AWQ Algorithm (simplified):
#
# Key observation: A small fraction of weights (~1%) are much more
# important than others. These "salient" weights correspond to
# channels with large activation magnitudes.
#
# Approach:
# 1. Run calibration data to find activation magnitudes per channel
# 2. Identify salient channels (high activation magnitude)
# 3. Scale up salient channels before quantization (multiply by s > 1)
# 4. Scale down the corresponding output channels (divide by s)
# 5. This mathematically preserves the output while reducing
# quantization error for important weights
#
# Advantages over GPTQ:
# - Faster quantization (no iterative per-column optimization)
# - Better generalization (not overfit to calibration data)
# - Often slightly better quality at INT4
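The scaling identity at the heart of AWQ can be checked numerically: scaling a weight column up by `s` while dividing the matching activation channel by `s` leaves the output unchanged (a sketch using random matrices; the "salient" channel is chosen arbitrarily here):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 8))   # weight matrix (out_features, in_features)
x = rng.standard_normal(8)        # input activations

s = np.ones(8)
s[2] = 4.0                        # treat input channel 2 as "salient"

W_scaled = W * s                  # scale UP the salient weight column
x_scaled = x / s                  # scale DOWN the matching activation channel

# The product is mathematically unchanged...
assert np.allclose(W @ x, W_scaled @ x_scaled)
# ...but the salient column now occupies more of the quantization grid,
# so its rounding error after INT4 quantization is smaller.
```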
GGUF Format and llama.cpp
GGUF (GPT-Generated Unified Format) is the file format used by llama.cpp, the most popular framework for running LLMs on CPUs and consumer hardware.
llama.cpp Quantization Types
| Quant Type | Bits/Weight | Description | Quality | Speed |
|---|---|---|---|---|
| Q2_K | 2.5 | 2-bit with K-quant optimization | Poor | Fastest |
| Q3_K_M | 3.4 | 3-bit medium quality | Fair | Very Fast |
| Q4_K_M | 4.6 | 4-bit medium quality (most popular) | Good | Fast |
| Q5_K_M | 5.7 | 5-bit medium quality | Very Good | Moderate |
| Q6_K | 6.6 | 6-bit quantization | Excellent | Moderate |
| Q8_0 | 8.5 | 8-bit quantization | Near-perfect | Slow |
| F16 | 16 | Full half-precision | Perfect | Slowest |
The "K-quant" variants (K_S, K_M, K_L) use different quantization parameters for different layers -- giving more bits to the most important layers (first and last layers, attention layers) and fewer bits to less sensitive layers.
Practical: Quantize a Model with bitsandbytes
"""
Model Quantization with bitsandbytes
========================================
Quantize a model to INT8 and INT4, then compare
performance, memory usage, and output quality.
pip install transformers bitsandbytes accelerate torch
"""
import torch
import time
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
def get_model_size_mb(model) -> float:
"""Calculate model size in MB."""
param_bytes = sum(
p.nelement() * p.element_size() for p in model.parameters()
)
return param_bytes / (1024 * 1024)
def get_gpu_memory_mb() -> float:
"""Get current GPU memory usage in MB."""
if torch.cuda.is_available():
return torch.cuda.memory_allocated() / (1024 * 1024)
return 0
def generate_text(model, tokenizer, prompt, max_new_tokens=100):
"""Generate text and measure time."""
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
start = time.time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False, # Greedy for reproducibility
temperature=1.0,
)
elapsed = time.time() - start
generated_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return text, elapsed, generated_tokens
def main():
model_name = "meta-llama/Llama-3.2-1B" # Use a small model
test_prompt = "Explain the theory of relativity in simple terms:"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
results = {}
# ===========================
# 1. FP16 (baseline)
# ===========================
print("\n" + "=" * 60)
print("Loading model in FP16 (baseline)...")
print("=" * 60)
torch.cuda.reset_peak_memory_stats() if torch.cuda.is_available() else None
model_fp16 = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto",
)
gpu_mem = get_gpu_memory_mb()
text, elapsed, n_tokens = generate_text(model_fp16, tokenizer, test_prompt)
tokens_per_sec = n_tokens / elapsed
print(f"GPU Memory: {gpu_mem:.0f} MB")
print(f"Speed: {tokens_per_sec:.1f} tokens/sec")
print(f"Output: {text[:200]}...")
results["FP16"] = {
"memory_mb": gpu_mem,
"tokens_per_sec": tokens_per_sec,
"output": text,
}
del model_fp16
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# ===========================
# 2. INT8 Quantization
# ===========================
print("\n" + "=" * 60)
print("Loading model in INT8...")
print("=" * 60)
quantization_config_8bit = BitsAndBytesConfig(
load_in_8bit=True,
)
model_int8 = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config_8bit,
device_map="auto",
)
gpu_mem = get_gpu_memory_mb()
text, elapsed, n_tokens = generate_text(model_int8, tokenizer, test_prompt)
tokens_per_sec = n_tokens / elapsed
print(f"GPU Memory: {gpu_mem:.0f} MB")
print(f"Speed: {tokens_per_sec:.1f} tokens/sec")
print(f"Output: {text[:200]}...")
results["INT8"] = {
"memory_mb": gpu_mem,
"tokens_per_sec": tokens_per_sec,
"output": text,
}
del model_int8
torch.cuda.empty_cache() if torch.cuda.is_available() else None
# ===========================
# 3. INT4 Quantization (NF4)
# ===========================
print("\n" + "=" * 60)
print("Loading model in INT4 (NF4 -- used by QLoRA)...")
print("=" * 60)
quantization_config_4bit = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # NormalFloat4
bnb_4bit_use_double_quant=True, # Double quantization
bnb_4bit_compute_dtype=torch.bfloat16, # Compute in BF16
)
model_int4 = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config_4bit,
device_map="auto",
)
gpu_mem = get_gpu_memory_mb()
text, elapsed, n_tokens = generate_text(model_int4, tokenizer, test_prompt)
tokens_per_sec = n_tokens / elapsed
print(f"GPU Memory: {gpu_mem:.0f} MB")
print(f"Speed: {tokens_per_sec:.1f} tokens/sec")
print(f"Output: {text[:200]}...")
results["INT4"] = {
"memory_mb": gpu_mem,
"tokens_per_sec": tokens_per_sec,
"output": text,
}
# ===========================
# Summary
# ===========================
print("\n" + "=" * 60)
print("QUANTIZATION COMPARISON SUMMARY")
print("=" * 60)
print(f"{'Format':<10} {'Memory (MB)':<15} {'Speed (tok/s)':<15} {'Mem Reduction':<15}")
print("-" * 55)
baseline_mem = results["FP16"]["memory_mb"]
for fmt, data in results.items():
reduction = (1 - data["memory_mb"] / baseline_mem) * 100 if baseline_mem > 0 else 0
print(
f"{fmt:<10} {data['memory_mb']:<15.0f} "
f"{data['tokens_per_sec']:<15.1f} "
f"{reduction:<15.1f}%"
)
if __name__ == "__main__":
main()
Practical: Run a Quantized Model with llama.cpp
# Install llama.cpp
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
make -j$(nproc)
# Download a GGUF model (e.g., from HuggingFace)
# Many models are available pre-quantized on HuggingFace:
# https://huggingface.co/TheBloke (for older models)
# https://huggingface.co/bartowski (for newer models)
# Example: Download Llama 3.2 1B in Q4_K_M format
huggingface-cli download bartowski/Llama-3.2-1B-Instruct-GGUF \
Llama-3.2-1B-Instruct-Q4_K_M.gguf \
--local-dir ./models
# Run inference
./llama-cli \
-m ./models/Llama-3.2-1B-Instruct-Q4_K_M.gguf \
-p "Explain quantum computing:" \
-n 256 \
--temp 0.7 \
-ngl 99 # Offload all layers to GPU (if available)
# Run as a server (OpenAI-compatible API)
./llama-server \
-m ./models/Llama-3.2-1B-Instruct-Q4_K_M.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 99
# To quantize your own model from HuggingFace format:
# 1. Convert to GGUF
python convert_hf_to_gguf.py /path/to/hf/model --outfile model-f16.gguf
# 2. Quantize to desired format
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
# Using llama.cpp from Python with llama-cpp-python
# pip install llama-cpp-python
from llama_cpp import Llama
# Load quantized model
llm = Llama(
model_path="./models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
n_ctx=4096, # Context window
n_gpu_layers=-1, # Use all GPU layers
verbose=False,
)
# Generate text
response = llm.create_chat_completion(
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is machine learning?"},
],
max_tokens=256,
temperature=0.7,
)
print(response["choices"][0]["message"]["content"])
3. Attention Optimizations
Standard self-attention is O(n^2) in sequence length and memory-hungry. Several optimizations have been developed to make attention faster and more memory-efficient.
FlashAttention
FlashAttention (Dao et al., 2022) is an IO-aware exact attention algorithm that is both faster and more memory-efficient than standard attention. It achieves this by restructuring the computation to minimize memory reads/writes between GPU HBM (high-bandwidth memory) and SRAM (on-chip memory).
The Core Idea: Tiling
# Standard Attention Memory Access Pattern:
# 1. Load Q, K from HBM to SRAM
# 2. Compute S = Q @ K^T -> write S to HBM (n x n matrix!)
# 3. Load S from HBM, compute P = softmax(S) -> write P to HBM
# 4. Load P, V from HBM, compute O = P @ V -> write O to HBM
#
# Total HBM reads/writes: O(n^2) -- dominated by the n x n attention matrix
# The attention matrix S (n x n) must be materialized in HBM
# FlashAttention Memory Access Pattern:
# Process Q, K, V in TILES that fit in SRAM:
#
# For each tile of Q (size B_r):
# For each tile of K, V (size B_c):
# 1. Load Q tile, K tile, V tile from HBM to SRAM
# 2. Compute partial attention IN SRAM (no HBM write for S!)
# 3. Update running softmax statistics
# 4. Accumulate partial output
# Write final output tile to HBM
#
# Total HBM reads/writes: O(n^2 * d^2 / SRAM_size)
# The n x n attention matrix is NEVER materialized in HBM!
#
# Memory: O(n) instead of O(n^2) -- massive savings for long sequences
The key mathematical insight is that softmax can be computed incrementally (the "online softmax" trick), allowing FlashAttention to process tiles without ever storing the full attention matrix.
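The online softmax trick can be demonstrated on a 1-D toy: a single pass that keeps only a running max, normalizer, and accumulator produces the same softmax-weighted sum as the two-pass version (a sketch of the idea, not the actual GPU kernel):

```python
import math

def online_softmax_dot(scores, values):
    """One-pass softmax-weighted sum: keeps only a running max (m),
    normalizer (denom), and accumulator (acc) -- no full score vector."""
    m, denom, acc = float("-inf"), 0.0, 0.0
    for s, v in zip(scores, values):
        m_new = max(m, s)
        corr = math.exp(m - m_new)  # rescales old stats (0.0 on the first step)
        denom = denom * corr + math.exp(s - m_new)
        acc = acc * corr + math.exp(s - m_new) * v
        m = m_new
    return acc / denom

# Matches the ordinary two-pass softmax-weighted sum
scores, values = [1.0, 3.0, 2.0], [10.0, 20.0, 30.0]
exps = [math.exp(s - max(scores)) for s in scores]
ref = sum(e * v for e, v in zip(exps, values)) / sum(exps)
assert abs(online_softmax_dot(scores, values) - ref) < 1e-9
```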
FlashAttention Versions
| Version | Year | Key Improvement |
|---|---|---|
| FlashAttention | 2022 | IO-aware tiled attention, 2-4x speedup |
| FlashAttention-2 | 2023 | Better work partitioning, ~2x faster than FA-1 |
| FlashAttention-3 | 2024 | Optimized for H100 GPUs, FP8 support, asynchronous execution |
Grouped Query Attention (GQA) and Multi-Query Attention (MQA)
Standard multi-head attention uses separate K, V projections for each head. GQA and MQA reduce the number of K, V heads to save memory and increase throughput.
# Standard Multi-Head Attention (MHA):
# n_heads = 32, n_kv_heads = 32
# Each attention head has its own Q, K, V
# KV cache per layer: 2 * 32 * head_dim * seq_len
# Multi-Query Attention (MQA):
# n_heads = 32, n_kv_heads = 1
# All Q heads share 1 K and 1 V head
# KV cache per layer: 2 * 1 * head_dim * seq_len (32x smaller!)
# Used by: Falcon, PaLM
# Downside: Can reduce quality, especially on complex tasks
# Grouped Query Attention (GQA):
# n_heads = 32, n_kv_heads = 8
# Every 4 Q heads share 1 K and 1 V head
# KV cache per layer: 2 * 8 * head_dim * seq_len (4x smaller)
# Used by: Llama 2/3, Mistral
# Best of both worlds: significant cache savings with minimal quality loss
# Visual representation:
# MHA: Q1-K1-V1 Q2-K2-V2 Q3-K3-V3 Q4-K4-V4 (4 KV heads)
# GQA: Q1-K1-V1 Q2-K1-V1 Q3-K2-V2 Q4-K2-V2 (2 KV heads)
# MQA: Q1-K1-V1 Q2-K1-V1 Q3-K1-V1 Q4-K1-V1 (1 KV head)
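The head grouping above reduces to simple integer arithmetic: each query head reads K/V head `q_head // (n_heads // n_kv_heads)` (a sketch; the helper name is our own):

```python
def kv_head_for(q_head, n_heads, n_kv_heads):
    """Index of the shared K/V head that a given query head reads."""
    group_size = n_heads // n_kv_heads
    return q_head // group_size

# Llama 3 70B style: 64 query heads sharing 8 KV heads (groups of 8)
assert kv_head_for(0, 64, 8) == 0 and kv_head_for(63, 64, 8) == 7
# MQA is the n_kv_heads=1 special case; MHA is n_kv_heads=n_heads
assert kv_head_for(5, 8, 1) == 0
assert kv_head_for(5, 8, 8) == 5
```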
Sliding Window Attention (Mistral)
Instead of attending to all previous tokens, each token only attends to the last W tokens. This bounds the KV cache to size W regardless of sequence length.
# Sliding Window Attention (Mistral, W=4096):
#
# Token at position i attends to tokens [i-W, i]
#
# Advantages:
# - Fixed KV cache size: O(W) instead of O(n)
# - Inference memory is constant regardless of sequence length
# - Information can still propagate through layers:
# After L layers, effective receptive field = L * W
# For Mistral (32 layers, W=4096): 32 * 4096 = 131,072 tokens
#
# Disadvantage:
# - Direct attention to distant tokens is lost
# - Not suitable for tasks requiring precise long-range retrieval
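The bounded cache behaves like a fixed-length queue; in Python, `collections.deque` with `maxlen` captures the eviction policy exactly (a toy sketch, with ints standing in for (K, V) pairs):

```python
from collections import deque

W = 4                         # window size (Mistral uses W=4096)
kv_cache = deque(maxlen=W)    # oldest entry is evicted automatically

for pos in range(10):         # pretend each int is a (K, V) pair
    kv_cache.append(pos)

# Only the last W positions remain: memory is O(W), not O(n)
print(list(kv_cache))  # [6, 7, 8, 9]
```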
PagedAttention (vLLM)
PagedAttention (Kwon et al., 2023) manages KV cache memory like an operating system manages virtual memory. This is crucial for serving multiple requests simultaneously.
# The Problem:
# When serving multiple requests with different sequence lengths,
# pre-allocating KV cache for max_seq_len wastes huge amounts of memory.
#
# Request 1: 500 tokens (allocated 2048) -> 75% waste
# Request 2: 1800 tokens (allocated 2048) -> 12% waste
# Request 3: 100 tokens (allocated 2048) -> 95% waste
#
# PagedAttention Solution:
# - Divide KV cache into fixed-size "pages" (blocks)
# - Each block holds KV vectors for a fixed number of tokens
# - Pages are allocated on-demand (like virtual memory)
# - A page table maps logical token positions to physical memory
#
# Benefits:
# - Near-zero memory waste (only last block may be partially filled)
# - Memory sharing: multiple sequences from same prompt share pages
# - Efficient batch scheduling: more requests fit in memory
# - Foundation of vLLM's high-throughput serving
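A toy page table makes the idea concrete (the block size of 16 mirrors vLLM's default; the class and method names here are our own, not vLLM's API):

```python
BLOCK_SIZE = 16  # tokens per page (vLLM's default block size)

class PagedKVCache:
    """Toy page table mapping logical token positions to (page_id, slot)."""
    def __init__(self):
        self.page_table = []   # logical block index -> physical page id
        self.next_free = 0     # stand-in for a shared physical page pool

    def slot_for(self, pos):
        block_idx = pos // BLOCK_SIZE
        while len(self.page_table) <= block_idx:   # allocate on demand
            self.page_table.append(self.next_free)
            self.next_free += 1
        return self.page_table[block_idx], pos % BLOCK_SIZE

cache = PagedKVCache()
print(cache.slot_for(0))      # (0, 0) -- first token allocates page 0
print(cache.slot_for(37))     # (2, 5) -- token 37 lives in page 2, slot 5
print(len(cache.page_table))  # 3 pages cover 38 tokens; no max_len waste
```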
Practical: Efficient Inference with vLLM
"""
High-Performance LLM Serving with vLLM
==========================================
vLLM uses PagedAttention for efficient KV cache management,
enabling high-throughput serving of LLMs.
pip install vllm
"""
from vllm import LLM, SamplingParams
# ===========================
# Basic vLLM Usage
# ===========================
# Load model with vLLM (automatically applies PagedAttention)
llm = LLM(
model="meta-llama/Llama-3.2-1B-Instruct",
dtype="auto", # Auto-detect best dtype
gpu_memory_utilization=0.9, # Use 90% of GPU memory
max_model_len=4096, # Maximum sequence length
# quantization="awq", # Use AWQ quantized model
# tensor_parallel_size=2, # Use 2 GPUs with tensor parallelism
)
# Define sampling parameters
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.9,
max_tokens=256,
stop=["<|eot_id|>"],
)
# Batch inference -- vLLM handles batching efficiently
prompts = [
"What is quantum computing?",
"Explain the theory of relativity.",
"How do neural networks learn?",
"What causes the northern lights?",
"Describe the water cycle.",
]
# Generate responses for all prompts
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated = output.outputs[0].text
print(f"Prompt: {prompt[:50]}...")
print(f"Response: {generated[:200]}...")
print(f"Tokens generated: {len(output.outputs[0].token_ids)}")
print("-" * 60)
# ===========================
# vLLM as OpenAI-compatible server
# ===========================
# Run from command line:
# python -m vllm.entrypoints.openai.api_server \
# --model meta-llama/Llama-3.2-1B-Instruct \
# --port 8000
#
# Then use the OpenAI client:
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="not-needed", # vLLM doesn't require an API key
)
response = client.chat.completions.create(
model="meta-llama/Llama-3.2-1B-Instruct",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is machine learning?"},
],
max_tokens=256,
temperature=0.7,
)
print(response.choices[0].message.content)
4. Fine-Tuning Methods
Fine-tuning adapts a pre-trained model to a specific task or domain. The challenge is that full fine-tuning of large models requires enormous GPU memory. Parameter-Efficient Fine-Tuning (PEFT) methods solve this problem.
Full Fine-Tuning vs Parameter-Efficient Methods
| Method | Trainable Params | Memory Required (7B model) | Quality |
|---|---|---|---|
| Full Fine-Tuning | 100% | ~112 GB (FP16 + Adam) | Best |
| LoRA (r=16) | ~0.1-1% | ~16 GB | Very Good |
| QLoRA (4-bit + LoRA) | ~0.1-1% | ~6-8 GB | Good |
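The table's memory figures come from simple byte counting, which is worth being able to reproduce. Here is a rough sketch (assumptions: full fine-tuning stores FP16 weights and gradients plus FP32 Adam master weights and two moments, 16 bytes per parameter; the LoRA and QLoRA rows count only the frozen base weights, so adapter parameters, gradients, activations, and optimizer states add a few GB on top, which is why the table's totals are somewhat higher):

```python
def finetune_memory_gb(n_params: float) -> dict:
    """Base memory (GB, using 1 GB = 1e9 bytes) for three fine-tuning regimes."""
    bytes_per_param = {
        "full_ft": 16,  # FP16 weights + grads, FP32 master weights + Adam m and v
        "lora": 2,      # frozen FP16 base weights only
        "qlora": 0.5,   # frozen 4-bit NF4 base weights only
    }
    return {name: n_params * b / 1e9 for name, b in bytes_per_param.items()}

print(finetune_memory_gb(7e9))
# {'full_ft': 112.0, 'lora': 14.0, 'qlora': 3.5}
```

The same function applied to 70e9 parameters reproduces the 70B figures discussed later in this section.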
LoRA (Low-Rank Adaptation)
LoRA is the most important PEFT technique. Introduced by Hu et al. (2021), it is based on the insight that weight updates during fine-tuning have low intrinsic rank.
The Math Behind LoRA
# Standard fine-tuning updates the full weight matrix:
# W' = W_0 + delta_W
# where W_0 is the pre-trained weight and delta_W is the update
# delta_W has shape (d, k) -- same as the original weight
# LoRA insight: delta_W is approximately low-rank
# So we decompose it: delta_W = B @ A
# where B has shape (d, r) and A has shape (r, k)
# r << min(d, k) is the rank
# Example for a 4096 x 4096 weight matrix with rank r=16:
# Full fine-tuning trainable params: 4096 * 4096 = 16,777,216
# LoRA trainable params: (4096 * 16) + (16 * 4096) = 131,072
# Reduction: 128x fewer trainable parameters!
# During forward pass:
# h = W_0 @ x + B @ A @ x
# = W_0 @ x + delta_W @ x
#
# Key: W_0 is FROZEN (no gradients computed)
# Only B and A are trained
# Scaling factor alpha/r is applied: h = W_0 @ x + (alpha/r) * B @ A @ x
# Initialization:
# A is initialized with random Gaussian values
# B is initialized to zero
# This means delta_W = B @ A = 0 at the start (training begins from pre-trained weights)
Why LoRA Works
Research has shown that when fine-tuning large models on downstream tasks, the weight updates occupy a very low-dimensional subspace. The intrinsic dimension of common NLP tasks is surprisingly small (often rank 1-64 is sufficient). This means we can capture most of the fine-tuning benefit with a tiny fraction of the parameters.
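The decomposition is simple enough to sketch directly in PyTorch. This is an illustrative toy (the `LoRALinear` class and its initialization constants are hypothetical, not the peft implementation): the base weight is frozen, A gets a small Gaussian init, B starts at zero so the model is unchanged at step 0, and the update is scaled by alpha/r.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Toy LoRA wrapper: h = W_0 x + (alpha/r) * B A x, with W_0 frozen."""
    def __init__(self, base: nn.Linear, r: int = 16, alpha: int = 32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                      # W_0 is FROZEN
        d, k = base.out_features, base.in_features
        self.A = nn.Parameter(torch.randn(r, k) * 0.01)  # Gaussian init
        self.B = nn.Parameter(torch.zeros(d, r))         # zero init -> delta_W = 0
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

layer = LoRALinear(nn.Linear(4096, 4096, bias=False), r=16, alpha=32)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(f"{trainable:,}")  # 131,072 -- matches the calculation above
```

Because B is zero-initialized, `layer(x)` is exactly equal to `layer.base(x)` before the first training step; gradients flow only into A and B.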
Choosing LoRA Hyperparameters
| Parameter | Description | Recommended Values |
|---|---|---|
| r (rank) | Rank of the decomposition | 8-64 (16 is a good default) |
| alpha | Scaling factor (alpha/r is the effective scale) | Usually 2x rank (e.g., alpha=32 for r=16) |
| target_modules | Which weight matrices to apply LoRA to | q_proj, v_proj minimum; all linear layers for best results |
| dropout | Dropout on LoRA layers | 0.05-0.1 |
QLoRA: Quantized LoRA
QLoRA (Dettmers et al., 2023) combines 4-bit quantization with LoRA for extreme memory efficiency. The base model is loaded in 4-bit NormalFloat (NF4) precision, while LoRA adapters are trained in FP16/BF16.
# QLoRA Architecture:
#
# Base Model (FROZEN, 4-bit NF4):
# [W_0 in INT4] ----> dequantize to BF16 for computation
#
# LoRA Adapters (TRAINABLE, BF16):
# [B (d x r) in BF16] @ [A (r x k) in BF16]
#
# Forward pass:
# x_input (BF16)
# |
# |--> dequantize(W_0) @ x --> base output (BF16)
# |--> B @ A @ x --> lora output (BF16)
# |
# sum --> output (BF16)
#
# Key innovations:
# 1. NF4: Data type optimized for normally distributed weights
# 2. Double quantization: Quantize the quantization constants too
# 3. Paged optimizers: Use CPU RAM for optimizer states when GPU is full
# Memory comparison for Llama 3.1 70B:
# Full FT: 70B * 18 bytes (param + grad + Adam states) = ~1,260 GB
# LoRA: 70B * 2 bytes (frozen FP16) + LoRA params = ~140 GB
# QLoRA: 70B * 0.5 bytes (INT4) + LoRA params = ~35 GB + LoRA
# Fits on a single 48GB GPU!
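The NF4 idea can be demonstrated with a toy round trip: absmax-scale a block of weights into [-1, 1], then snap each value to the nearest of 16 fixed levels chosen as quantiles of a standard normal. This is a sketch, not the bitsandbytes kernel (which packs two 4-bit codes per byte and quantizes in blocks of 64), and the code values below only approximate the NF4 table from the QLoRA paper.

```python
import numpy as np

# Approximate NF4 levels: 16 quantiles of N(0, 1), rescaled to [-1, 1]
NF4 = np.array([-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911,
                0.0, 0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0])

def nf4_quantize(block):
    """Return 4-bit codes plus the per-block absmax scale."""
    absmax = np.abs(block).max()
    normalized = block / absmax  # now in [-1, 1]
    codes = np.abs(normalized[:, None] - NF4[None, :]).argmin(axis=1)
    return codes.astype(np.uint8), absmax

def nf4_dequantize(codes, absmax):
    """Look up each code's level and rescale by absmax."""
    return NF4[codes] * absmax

rng = np.random.default_rng(0)
w = rng.normal(0, 0.02, size=64)  # pretrained weights are roughly normal
codes, absmax = nf4_quantize(w)
w_hat = nf4_dequantize(codes, absmax)
print("max abs error:", np.abs(w - w_hat).max())
```

Because the levels are dense near zero, where normally distributed weights concentrate, NF4 wastes fewer codes than a uniform INT4 grid does.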
Practical: Fine-tune Llama 3.2 1B with QLoRA
Hands-On Project
This is a complete, end-to-end QLoRA fine-tuning project. We will fine-tune Llama 3.2 1B on a custom instruction dataset, covering dataset preparation, training configuration, the training loop, evaluation, and adapter merging.
"""
Complete QLoRA Fine-Tuning Pipeline
======================================
Fine-tune Llama 3.2 1B on a custom instruction dataset
using QLoRA (4-bit quantization + LoRA adapters).
Requirements:
pip install transformers datasets peft trl bitsandbytes accelerate torch
Hardware: Works on a single GPU with 8+ GB VRAM (e.g., RTX 3070)
"""
import torch
from datasets import Dataset, load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
    TaskType,
)
from trl import SFTTrainer, SFTConfig
import json
# ==================================================
# STEP 1: Dataset Preparation
# ==================================================
def create_custom_dataset():
    """
    Create a custom instruction-following dataset.
    In practice, you would load from a file or HuggingFace datasets.
    """
    examples = [
        {
            "instruction": "Classify the sentiment of this review.",
            "input": "This product exceeded my expectations. The quality is outstanding and shipping was fast.",
            "output": "Positive. The review expresses satisfaction with the product quality and shipping speed."
        },
        {
            "instruction": "Translate the following English text to French.",
            "input": "The weather is beautiful today.",
            "output": "Le temps est magnifique aujourd'hui."
        },
        {
            "instruction": "Write a SQL query to find all customers who made purchases over $100 in the last month.",
            "input": "",
            "output": """```sql
SELECT DISTINCT c.customer_id, c.name, c.email
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
WHERE o.total_amount > 100
  AND o.order_date >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH)
ORDER BY c.name;
```
This query joins the customers and orders tables, filters for orders over $100 in the last 30 days, and returns unique customer records sorted by name."""
        },
        {
            "instruction": "Explain the concept of recursion to a beginner programmer.",
            "input": "",
            "output": """Recursion is a programming technique where a function calls itself to solve a problem. Think of it like Russian nesting dolls -- each doll contains a smaller version of itself.
Here is how it works:
1. **Base case**: A simple condition where the function returns a result without calling itself (the smallest doll).
2. **Recursive case**: The function calls itself with a simpler version of the problem.
Example -- calculating factorial:
```python
def factorial(n):
    if n <= 1:  # Base case
        return 1
    return n * factorial(n - 1)  # Recursive case
```
`factorial(4)` works like this:
- factorial(4) = 4 * factorial(3)
- factorial(3) = 3 * factorial(2)
- factorial(2) = 2 * factorial(1)
- factorial(1) = 1 (base case!)
- Then it "unwinds": 2 * 1 = 2, 3 * 2 = 6, 4 * 6 = 24
Key rules: Always have a base case, and ensure each recursive call moves toward the base case to avoid infinite loops."""
        },
        {
            "instruction": "Summarize the key differences between TCP and UDP.",
            "input": "",
            "output": """TCP (Transmission Control Protocol) and UDP (User Datagram Protocol) are both transport layer protocols, but they differ in several key ways:
| Feature | TCP | UDP |
|---------|-----|-----|
| Connection | Connection-oriented (handshake required) | Connectionless |
| Reliability | Guaranteed delivery with acknowledgments | No delivery guarantee |
| Ordering | Maintains packet order | No ordering guarantee |
| Speed | Slower (overhead from reliability features) | Faster (minimal overhead) |
| Use cases | Web browsing, email, file transfer | Video streaming, gaming, DNS |
Choose TCP when data integrity matters (web, file transfer). Choose UDP when speed matters more than reliability (real-time video, gaming)."""
        },
    ] * 100  # Repeat for more training data
    return Dataset.from_list(examples)
def format_for_llama(example, tokenizer):
    """
    Format an instruction example into Llama 3 chat format.
    """
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant. Provide clear, accurate, and detailed responses."},
    ]
    if example["input"]:
        user_msg = f"{example['instruction']}\n\nInput: {example['input']}"
    else:
        user_msg = example["instruction"]
    messages.append({"role": "user", "content": user_msg})
    messages.append({"role": "assistant", "content": example["output"]})
    # Apply the model's chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return text
# ==================================================
# STEP 2: Model and Tokenizer Setup
# ==================================================
def setup_model_and_tokenizer(model_name: str = "meta-llama/Llama-3.2-1B-Instruct"):
    """Load the model with 4-bit quantization and prepare for QLoRA."""
    # Quantization configuration for QLoRA
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",              # NormalFloat4 -- optimal for normally distributed weights
        bnb_4bit_use_double_quant=True,         # Quantize the quantization constants (saves ~0.4 bits/param)
        bnb_4bit_compute_dtype=torch.bfloat16,  # Use BF16 for computation
    )
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"  # Required for training
    # Load model with quantization
    print(f"Loading {model_name} with 4-bit quantization...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        attn_implementation="sdpa",  # Use scaled dot-product attention
    )
    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=True,  # Trade compute for memory
    )
    return model, tokenizer
# ==================================================
# STEP 3: LoRA Configuration
# ==================================================
def setup_lora(model):
    """Configure and apply LoRA adapters."""
    lora_config = LoraConfig(
        r=16,             # Rank -- lower = fewer params, higher = more capacity
        lora_alpha=32,    # Scaling factor (alpha/r = effective scale)
        target_modules=[  # Which modules to apply LoRA to
            "q_proj",     # Query projection
            "k_proj",     # Key projection
            "v_proj",     # Value projection
            "o_proj",     # Output projection
            "gate_proj",  # MLP gate projection
            "up_proj",    # MLP up projection
            "down_proj",  # MLP down projection
        ],
        lora_dropout=0.05,             # Dropout for regularization
        bias="none",                   # Don't train bias terms
        task_type=TaskType.CAUSAL_LM,  # Task type
    )
    # Apply LoRA to the model
    model = get_peft_model(model, lora_config)
    # Print trainable parameters
    trainable, total = model.get_nb_trainable_parameters()
    print(f"\nTrainable parameters: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
    return model
# ==================================================
# STEP 4: Training
# ==================================================
def train_model(model, tokenizer, dataset, output_dir="./qlora_output"):
    """Train the model with QLoRA."""
    # Format the dataset
    def formatting_func(example):
        return format_for_llama(example, tokenizer)
    # Training configuration
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # Effective batch size = 4 * 4 = 16
        learning_rate=2e-4,             # Higher LR for LoRA than full FT
        weight_decay=0.01,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=2,             # Keep only the 2 most recent checkpoints
        bf16=True,                      # Use BF16 for training
        optim="paged_adamw_8bit",       # Memory-efficient optimizer
        max_grad_norm=0.3,              # Gradient clipping
        max_seq_length=1024,            # Maximum sequence length
        packing=True,                   # Pack short sequences together
        gradient_checkpointing=True,    # Trade compute for memory
        gradient_checkpointing_kwargs={
            "use_reentrant": False,
        },
        dataset_text_field=None,
        report_to="none",               # Set to "wandb" for logging
    )
    # Create trainer
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        formatting_func=formatting_func,
    )
    # Train!
    print("\nStarting QLoRA training...")
    print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
    result = trainer.train()
    # Print training results
    print("\nTraining complete!")
    print(f"Total steps: {result.global_step}")
    print(f"Training loss: {result.training_loss:.4f}")
    # Save the LoRA adapters (not the full model)
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"LoRA adapters saved to {output_dir}")
    return trainer
# ==================================================
# STEP 5: Evaluation
# ==================================================
def evaluate_model(model, tokenizer, test_prompts=None):
    """Evaluate the fine-tuned model on test prompts."""
    if test_prompts is None:
        test_prompts = [
            "Classify the sentiment: 'I absolutely love this new restaurant, the food was amazing!'",
            "Write a Python function to reverse a linked list.",
            "Explain the difference between SQL JOIN types.",
            "Summarize what containerization is in one paragraph.",
        ]
    model.eval()
    print("\n" + "=" * 60)
    print("MODEL EVALUATION")
    print("=" * 60)
    for prompt in test_prompts:
        messages = [
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": prompt},
        ]
        input_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
            )
        response = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        print(f"\nPrompt: {prompt}")
        print(f"Response: {response[:300]}...")
        print("-" * 60)
# ==================================================
# STEP 6: Merge Adapters and Export
# ==================================================
def merge_and_export(
    base_model_name: str,
    adapter_path: str,
    output_path: str = "./merged_model",
):
    """
    Merge LoRA adapters back into the base model for deployment.
    The merged model can be used without PEFT and even quantized with GGUF.
    """
    print(f"Loading base model: {base_model_name}")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print(f"Loading LoRA adapters from: {adapter_path}")
    model = PeftModel.from_pretrained(base_model, adapter_path)
    print("Merging adapters into base model...")
    merged_model = model.merge_and_unload()
    print(f"Saving merged model to: {output_path}")
    merged_model.save_pretrained(output_path)
    tokenizer = AutoTokenizer.from_pretrained(adapter_path)
    tokenizer.save_pretrained(output_path)
    print("Merge complete! The model can now be used without PEFT.")
    print("To convert to GGUF for llama.cpp:")
    print(f"  python convert_hf_to_gguf.py {output_path} --outfile model.gguf")
    print("  ./llama-quantize model.gguf model-q4_k_m.gguf Q4_K_M")
# ==================================================
# MAIN
# ==================================================
def main():
    # Configuration
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    output_dir = "./qlora_output"
    # Step 1: Prepare dataset
    print("Step 1: Creating dataset...")
    dataset = create_custom_dataset()
    print(f"Dataset size: {len(dataset)} examples")
    # Step 2: Load model and tokenizer
    print("\nStep 2: Loading model with 4-bit quantization...")
    model, tokenizer = setup_model_and_tokenizer(model_name)
    # Step 3: Apply LoRA
    print("\nStep 3: Applying LoRA adapters...")
    model = setup_lora(model)
    # Step 4: Train
    print("\nStep 4: Training with QLoRA...")
    trainer = train_model(model, tokenizer, dataset, output_dir)
    # Step 5: Evaluate
    print("\nStep 5: Evaluating fine-tuned model...")
    evaluate_model(model, tokenizer)
    # Step 6: Merge (optional -- for deployment)
    # Uncomment to merge adapters into base model:
    # print("\nStep 6: Merging adapters...")
    # merge_and_export(model_name, output_dir, "./merged_model")
    print("\n" + "=" * 60)
    print("QLORA FINE-TUNING PIPELINE COMPLETE!")
    print("=" * 60)

if __name__ == "__main__":
    main()
5. When to Fine-Tune vs Prompt Engineering vs RAG
One of the most common questions AI engineers face is: should I fine-tune a model, use prompt engineering, or build a RAG system? Here is a decision framework.
Decision Matrix
| Factor | Prompt Engineering | RAG | Fine-Tuning |
|---|---|---|---|
| Setup Cost | Minimal | Medium | High |
| Ongoing Cost | Higher (more tokens per request) | Medium (retrieval + generation) | Lower (shorter prompts) |
| Up-to-date Knowledge | No | Yes (update index) | No (unless retrained) |
| Custom Behavior | Limited | Limited | Strong |
| Domain Expertise | Limited (by context window) | Good (if documents available) | Excellent |
| Latency | Low | Medium (retrieval overhead) | Low |
| Hallucination Risk | High | Low (grounded in documents) | Medium |
| Data Requirements | None | Documents/knowledge base | Labeled training data |
Use Case Guide
Choose Prompt Engineering when:
- Quick prototyping and experimentation
- The task can be explained with examples in the prompt
- You need flexibility to change behavior rapidly
- The base model already knows the domain well
Choose RAG when:
- You need to answer questions about specific documents
- Knowledge changes frequently
- You need citations and verifiable answers
- You have a large corpus of domain-specific documents
Choose Fine-Tuning when:
- You need a specific output format or style consistently
- You have proprietary domain knowledge not in the base model
- You need to reduce per-request costs (shorter prompts)
- You have enough labeled training data (hundreds to thousands of examples)
Note that these approaches are not mutually exclusive: combining a fine-tuned model (for format and style) with RAG (for up-to-date knowledge) often gives the best results.
Summary and Key Takeaways
Week 6 Key Takeaways
- KV cache is essential for fast autoregressive generation, turning O(n^2) into O(n). Understanding its memory footprint is critical for production deployment.
- Quantization enables practical deployment: INT4 quantization reduces memory by 8x with 1-3% quality loss. GPTQ, AWQ, and GGUF are the three main approaches.
- FlashAttention provides 2-4x speedups by being IO-aware -- restructuring computation to minimize memory reads/writes.
- GQA reduces KV cache by sharing K,V heads across query heads. Most modern models (Llama 3, Mistral) use GQA.
- LoRA achieves 99%+ of full fine-tuning quality with less than 1% of the trainable parameters, by exploiting the low-rank nature of weight updates.
- QLoRA makes fine-tuning accessible: 4-bit quantization + LoRA enables fine-tuning 70B models on a single 48GB GPU.
- Choose wisely: Prompt engineering for prototyping, RAG for knowledge-intensive tasks with changing data, fine-tuning for custom behavior and style.
Next Steps
In Week 7: Retrieval Augmented Generation, we will build complete RAG systems from scratch. You will learn about document chunking strategies, vector embeddings, vector databases, and advanced RAG techniques like hybrid search, reranking, and GraphRAG.