Week 13 of 16

Image and Video Models

From Convolutional Neural Networks to Vision Transformers, CLIP, Multimodal Models, and the Frontier of Video Understanding

Advanced. Estimated time: 15-20 hours

Learning Objectives

Understand the Evolution

Trace the journey from CNNs to Vision Transformers and understand why the field made this shift.

Master GANs

Implement and understand Generative Adversarial Networks, including their training dynamics and limitations.

Build with ViT and CLIP

Implement Vision Transformers and use CLIP for zero-shot classification and multimodal search.

Work with Multimodal Models

Build applications using GPT-4o, Claude, and other vision-language models for real-world tasks.

1. Computer Vision Fundamentals

Computer vision has undergone one of the most dramatic architectural shifts in all of deep learning. In just a few years, the dominant paradigm moved from Convolutional Neural Networks (CNNs), which had reigned supreme since AlexNet in 2012, to Vision Transformers (ViTs), which adapt the Transformer architecture from NLP to images. Understanding this evolution is critical for any AI engineer working with visual data.

1.1 Convolutional Neural Networks Recap

CNNs exploit three key ideas that make them well-suited for image processing:

Key CNN Concepts

  • Local connectivity: Each neuron connects only to a small region of the input (the receptive field), rather than every pixel. This dramatically reduces parameters.
  • Weight sharing: The same filter (kernel) is applied across the entire image, so a feature is detected wherever it appears (translation equivariance) while keeping the parameter count small.
  • Hierarchical feature extraction: Early layers learn edges and textures, middle layers learn parts and patterns, and deep layers learn high-level semantic concepts.

The Convolution Operation

A convolution slides a small filter (e.g., 3x3) across the input image, computing the dot product at each position. For an input image I and a filter K:

2D Convolution:

(I * K)(i, j) = sum_m sum_n I(i+m, j+n) * K(m, n)

The output at position (i,j) is the sum of element-wise products between the filter and the corresponding image patch.
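The formula can be checked numerically. Note that deep-learning libraries, PyTorch included, actually compute cross-correlation (the filter is not flipped), which matches the formula exactly as written:

```python
import torch
import torch.nn.functional as F

# A 3x3 "image" and a 2x2 filter, as [batch, channels, H, W] tensors
I = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.],
                  [7., 8., 9.]]).reshape(1, 1, 3, 3)
K = torch.tensor([[1., 0.],
                  [0., -1.]]).reshape(1, 1, 2, 2)

out = F.conv2d(I, K)  # no padding: output is 2x2

# Check position (0, 0) against the formula:
# (I * K)(0, 0) = 1*1 + 2*0 + 4*0 + 5*(-1) = -4
print(out.reshape(2, 2))  # every entry is -4 for this particular filter
```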

import torch
import torch.nn as nn

# A simple CNN for image classification
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()

        # Convolutional layers: extract spatial features
        self.features = nn.Sequential(
            # Conv layer 1: 3 input channels (RGB), 32 output feature maps
            # Kernel size 3x3, padding 1 to preserve spatial dimensions
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # [B, 3, 32, 32] -> [B, 32, 32, 32]
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [B, 32, 32, 32] -> [B, 32, 16, 16]

            # Conv layer 2: deeper features
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # [B, 32, 16, 16] -> [B, 64, 16, 16]
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [B, 64, 16, 16] -> [B, 64, 8, 8]

            # Conv layer 3: even deeper features
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # [B, 64, 8, 8] -> [B, 128, 8, 8]
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # [B, 128, 8, 8] -> [B, 128, 4, 4]
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Flatten(),           # [B, 128, 4, 4] -> [B, 2048]
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Demonstrate the model
model = SimpleCNN(num_classes=10)
dummy_input = torch.randn(4, 3, 32, 32)  # Batch of 4 RGB 32x32 images
output = model(dummy_input)
print(f"Input shape:  {dummy_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

Pooling Layers

Pooling reduces spatial dimensions, providing translation invariance and reducing computation:

  • Max Pooling: Takes the maximum value in each window. Most common. Preserves the strongest activations.
  • Average Pooling: Takes the mean. Used in some architectures (e.g., global average pooling before classification).
  • Global Average Pooling: Averages each feature map down to a single value per channel. Replaces fully connected layers in modern architectures like ResNet.
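A quick shape check makes the differences concrete (the tensor sizes here are arbitrary):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 8, 8)  # [batch, channels, H, W]

max_pool = nn.MaxPool2d(2, 2)      # strongest activation per 2x2 window
avg_pool = nn.AvgPool2d(2, 2)      # mean per 2x2 window
gap = nn.AdaptiveAvgPool2d(1)      # global average: one value per channel

print(max_pool(x).shape)  # torch.Size([1, 64, 4, 4])
print(avg_pool(x).shape)  # torch.Size([1, 64, 4, 4])
print(gap(x).shape)       # torch.Size([1, 64, 1, 1])
```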

Key CNN Architectures (Historical)

Architecture         Year  Key Innovation                               Depth
AlexNet              2012  GPU training, ReLU, dropout                  8 layers
VGGNet               2014  Small 3x3 filters stacked deep               16-19 layers
GoogLeNet/Inception  2014  Inception modules (parallel filters)         22 layers
ResNet               2015  Skip connections (residual learning)        50-152 layers
EfficientNet         2019  Compound scaling (width, depth, resolution)  Variable
ConvNeXt             2022  Modernized CNN with Transformer ideas        Variable

1.2 Why Vision Went from CNNs to Transformers

The shift from CNNs to Transformers in vision was driven by several factors:

Limitations of CNNs

  1. Limited receptive field: CNNs see the image through small local windows. Even deep CNNs struggle to capture truly global relationships without many stacked layers.
  2. Fixed geometric priors: Convolution assumes a grid structure and translation invariance. This inductive bias helps with small data but limits flexibility.
  3. Scaling challenges: While CNNs benefit from depth, the gains plateau. Transformers scale more smoothly with data and compute.
  4. Difficulty with variable resolution: CNNs typically require fixed input sizes (or workarounds like adaptive pooling).

Advantages of Vision Transformers

  1. Global attention from the start: Self-attention lets every patch attend to every other patch, capturing long-range dependencies from layer 1.
  2. Fewer inductive biases: Less built-in assumption about the data structure means the model can learn more flexible representations with enough data.
  3. Unified architecture: The same Transformer architecture works for text, images, audio, and video, enabling true multimodal models.
  4. Better scaling: ViTs show near-linear improvements with more data and compute, following scaling laws similar to LLMs.
  5. Transfer learning: Pre-trained ViTs transfer remarkably well across tasks and domains.

Important caveat: CNNs are NOT dead. ConvNeXt (2022) showed that with modern training recipes, CNNs can match ViTs. In practice, the choice depends on your data size, compute budget, and task. For small datasets, CNNs often still win due to their built-in inductive biases.

2. Generative Adversarial Networks (GANs)

GANs, introduced by Ian Goodfellow in 2014, were the first generative model architecture to produce strikingly realistic images. While they have been largely superseded by diffusion models for image generation (as we will see in Week 14), understanding GANs is essential because they introduced foundational concepts in generative modeling that continue to influence the field.

2.1 GAN Architecture: Generator vs Discriminator

A GAN consists of two neural networks trained simultaneously in opposition:

The Two Players


  Random Noise (z)                    Real Images (x)
       |                                    |
       v                                    v
  +-----------+                      +-----------------+
  | Generator |  -- fake images -->  | Discriminator   |
  |   G(z)    |                      | D(x) = real?    |
  +-----------+                      | D(G(z)) = fake? |
       ^                             +-----------------+
       |                                    |
       +--- update G to fool D <--- loss ---+
                                            |
                            update D to detect fakes
                        
  • Generator (G): Takes random noise z ~ N(0, I) and produces a synthetic image G(z). Its goal is to produce images so realistic that the discriminator cannot tell them from real data.
  • Discriminator (D): Takes an image (real or generated) and outputs a probability that the image is real. Its goal is to correctly classify real vs fake images.

2.2 The Adversarial Training Game (Minimax)

GAN training is formulated as a two-player minimax game. The Generator tries to minimize the objective while the Discriminator tries to maximize it:

GAN Minimax Objective:

min_G max_D V(D, G) = E_{x~p_data}[log D(x)] + E_{z~p_z}[log(1 - D(G(z)))]

Term 1: E_{x~p_data}[log D(x)] -- For real data x, D wants D(x) close to 1, making log D(x) close to 0 (maximizing).
Term 2: E_{z~p_z}[log(1 - D(G(z)))] -- For fake data G(z), D wants D(G(z)) close to 0, making log(1 - D(G(z))) close to 0 (maximizing). G wants D(G(z)) close to 1, making this term very negative (minimizing).

In practice, early in training when G is poor, log(1 - D(G(z))) saturates (the gradient is tiny). So we instead train G to maximize log D(G(z)) rather than minimize log(1 - D(G(z))). This provides stronger gradients early in training.
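The saturation problem is easy to see numerically. This toy computation (not part of a real training loop) treats the discriminator output as a scalar and compares the gradient magnitudes of the two generator losses when D confidently rejects a fake:

```python
import torch

# D's output for a fake sample early in training, when D confidently
# rejects fakes: D(G(z)) is close to 0.
d1 = torch.tensor(0.01, requires_grad=True)
d2 = torch.tensor(0.01, requires_grad=True)

# Saturating loss: G minimizes log(1 - D(G(z)))
torch.log(1 - d1).backward()

# Non-saturating loss: G minimizes -log(D(G(z)))
(-torch.log(d2)).backward()

print(f"Saturating gradient magnitude:     {d1.grad.abs().item():.2f}")  # ~1.01
print(f"Non-saturating gradient magnitude: {d2.grad.abs().item():.2f}")  # ~100.00
```

The non-saturating variant gives roughly 100x stronger gradients in exactly the regime where the generator most needs a learning signal.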

Training Algorithm

# Pseudocode for GAN training
for epoch in range(num_epochs):
    for real_batch in dataloader:

        # ---- Step 1: Train Discriminator ----
        # Goal: maximize log D(real) + log(1 - D(fake))

        z = sample_noise(batch_size, latent_dim)
        fake_batch = generator(z)

        d_real = discriminator(real_batch)       # Should output ~1
        d_fake = discriminator(fake_batch.detach())  # Should output ~0

        d_optimizer.zero_grad()
        d_loss = -torch.mean(torch.log(d_real) + torch.log(1 - d_fake))
        d_loss.backward()
        d_optimizer.step()

        # ---- Step 2: Train Generator ----
        # Goal: maximize log D(G(z)) [non-saturating variant]

        z = sample_noise(batch_size, latent_dim)
        fake_batch = generator(z)
        d_fake = discriminator(fake_batch)       # Want this to be ~1

        g_optimizer.zero_grad()
        g_loss = -torch.mean(torch.log(d_fake))
        g_loss.backward()
        g_optimizer.step()

2.3 Types of GANs

DCGAN (Deep Convolutional GAN)

Published by Radford et al. (2015), DCGAN established the standard architectural guidelines for convolutional GANs:

  • Replace pooling layers with strided convolutions (discriminator) and transposed convolutions (generator)
  • Use batch normalization in both networks (except output layer of G and input layer of D)
  • Remove fully connected hidden layers for deeper architectures
  • Use ReLU in the generator (except output uses Tanh), LeakyReLU in the discriminator

StyleGAN (Karras et al., 2019-2024)

StyleGAN revolutionized face generation with several innovations:

  • Mapping network: An 8-layer MLP maps z to an intermediate latent space W, which is more disentangled
  • Adaptive Instance Normalization (AdaIN): Style is injected at each layer via learned affine transforms of the W vector
  • Progressive growing: Training starts at low resolution and progressively adds higher-resolution layers
  • Style mixing: Different W vectors can control different levels (coarse pose, fine details)

StyleGAN2 fixed water droplet artifacts, and StyleGAN3 achieved alias-free generation with continuous equivariance.
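As a rough sketch of the AdaIN idea (simplified; in StyleGAN proper, the scale and bias come from a learned affine transform of the W vector, and later versions replace AdaIN with weight demodulation):

```python
import torch

def adain(x, style_scale, style_bias, eps=1e-5):
    """Adaptive Instance Normalization: normalize each feature map per
    sample, then rescale and shift with style-derived parameters.
    x:           [B, C, H, W] feature maps
    style_scale: [B, C, 1, 1] scale (from an affine map of w in StyleGAN)
    style_bias:  [B, C, 1, 1] bias  (from the same affine map)
    """
    mean = x.mean(dim=(2, 3), keepdim=True)
    std = x.std(dim=(2, 3), keepdim=True) + eps
    return style_scale * (x - mean) / std + style_bias

x = torch.randn(2, 8, 4, 4)
scale = torch.ones(2, 8, 1, 1) * 2.0
bias = torch.zeros(2, 8, 1, 1)
out = adain(x, scale, bias)
print(out.shape)  # torch.Size([2, 8, 4, 4])
```

Because the statistics are computed per sample and per channel, the style vector fully controls each channel's mean and variance at that layer.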

CycleGAN (Unpaired Image-to-Image Translation)

CycleGAN translates images between two domains without paired training data (e.g., horses to zebras, summer to winter):

  • Two generators: G_AB (A to B) and G_BA (B to A)
  • Two discriminators: D_A and D_B
  • Cycle consistency loss: G_BA(G_AB(x_A)) should equal x_A. This constraint ensures content is preserved during translation.

Pix2Pix (Paired Image-to-Image Translation)

Pix2Pix uses paired training data for image-to-image translation tasks:

  • Conditional GAN: Both G and D receive the input image as a condition
  • U-Net generator with skip connections
  • PatchGAN discriminator (classifies NxN patches as real/fake rather than the whole image)
  • Applications: edges to photos, segmentation maps to images, day to night
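A minimal sketch of the PatchGAN idea (layer sizes are illustrative, not the exact Pix2Pix configuration): the discriminator is fully convolutional, so it outputs a grid of real/fake scores, one per local patch, instead of a single score for the whole image.

```python
import torch
import torch.nn as nn

# Fully convolutional discriminator: no Flatten/Linear at the end,
# so the output is a score map with one logit per local patch.
patch_d = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=1),
)

x = torch.randn(1, 3, 256, 256)
scores = patch_d(x)
print(scores.shape)  # torch.Size([1, 1, 63, 63]) -- one score per patch
```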

2.4 Mode Collapse and Training Instability

GANs are notoriously difficult to train. The two main failure modes are:

Mode Collapse

The generator discovers a small set of outputs that consistently fool the discriminator and stops exploring the full data distribution. For example, if trained on digits 0-9, the generator might only produce perfect 1s and 7s.

Why it happens: The generator finds a local optimum where a few "safe" outputs always get high discriminator scores. The discriminator then learns to reject those specific outputs, but the generator just switches to a different small set, creating oscillation.

Mitigation strategies:

  • Minibatch discrimination (let D see multiple samples at once)
  • Unrolled GANs (optimize G against future D, not current D)
  • Wasserstein loss (WGAN) with gradient penalty
  • Spectral normalization
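Spectral normalization is a one-line change in PyTorch: wrapping a layer rescales its weight by the largest singular value, which bounds how sharply the discriminator's output can change and stabilizes its gradients.

```python
import torch
import torch.nn as nn

# Wrap any weight-bearing layer; training and inference work as usual.
conv = nn.utils.spectral_norm(
    nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1)
)
x = torch.randn(8, 1, 28, 28)
print(conv(x).shape)  # torch.Size([8, 64, 14, 14])
```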

Training Instability

The adversarial training dynamic can easily become unstable:

  • Vanishing gradients: If D becomes too strong, G gets no useful gradient signal
  • Oscillation: D and G chase each other without converging
  • Non-convergence: The minimax game has no guarantee of reaching equilibrium

2.5 Why GANs Are Being Replaced by Diffusion Models

While GANs dominated image generation from 2014-2021, diffusion models have taken over for several reasons:

  • Training stability: Diffusion models use a simple MSE loss on noise prediction. No adversarial dynamics, no mode collapse.
  • Mode coverage: Diffusion models naturally cover the full data distribution. No mode collapse by design.
  • Quality: Diffusion models now match or exceed GAN quality, especially for diverse scenes (not just faces).
  • Controllability: Text conditioning via cross-attention is natural in diffusion models. GANs require additional conditioning mechanisms.
  • Compositionality: Diffusion models handle complex scenes with multiple objects better than GANs.

However, GANs still have advantages in speed (one forward pass vs many denoising steps) and are used in real-time applications. GAN-based super-resolution (Real-ESRGAN) and face restoration (GFPGAN, CodeFormer) remain popular.

2.6 PRACTICAL: Simple GAN for Generating Digits

Let us implement a complete GAN that generates handwritten digits using the MNIST dataset.

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# ====================
# Hyperparameters
# ====================
LATENT_DIM = 100      # Dimension of the noise vector z
IMG_SIZE = 28          # MNIST image size
IMG_CHANNELS = 1       # Grayscale
BATCH_SIZE = 128
NUM_EPOCHS = 50
LR = 0.0002
BETA1 = 0.5           # Adam beta1 (recommended for GANs)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====================
# Data Loading
# ====================
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1] to match Tanh output
])

dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# ====================
# Generator Network
# ====================
class Generator(nn.Module):
    """
    Takes a latent vector z of dimension LATENT_DIM
    and produces a 1x28x28 image.

    Architecture: Linear -> Reshape -> ConvTranspose2d layers
    """
    def __init__(self):
        super().__init__()

        # Project z to a 256 x 7 x 7 feature volume
        self.fc = nn.Sequential(
            nn.Linear(LATENT_DIM, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(True),
        )

        # Upsample 256 x 7 x 7 -> 128 x 14 x 14 -> 1 x 28 x 28
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, IMG_CHANNELS, kernel_size=4, stride=2, padding=1),
            nn.Tanh()  # Output in [-1, 1]
        )

    def forward(self, z):
        x = self.fc(z)                 # [B, 256*7*7]
        x = x.view(-1, 256, 7, 7)      # Reshape to image format
        return self.conv(x)            # [B, 1, 28, 28]


# ====================
# Discriminator Network
# ====================
class Discriminator(nn.Module):
    """
    Takes a 1x28x28 image and outputs a probability
    that the image is real (vs generated).

    Architecture: Conv2d layers -> Flatten -> Linear -> Sigmoid
    """
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            # 1 x 28 x 28 -> 64 x 14 x 14
            nn.Conv2d(IMG_CHANNELS, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            # 64 x 14 x 14 -> 128 x 7 x 7
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # Flatten and classify
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)


# ====================
# Initialize Models
# ====================
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# Weight initialization (important for GAN stability)
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

generator.apply(weights_init)
discriminator.apply(weights_init)

# ====================
# Optimizers and Loss
# ====================
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=LR, betas=(BETA1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=LR, betas=(BETA1, 0.999))

# Fixed noise for visualization (to track progress)
fixed_noise = torch.randn(64, LATENT_DIM, device=DEVICE)

# ====================
# Training Loop
# ====================
os.makedirs("gan_outputs", exist_ok=True)

g_losses = []
d_losses = []

print("Starting GAN Training...")
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(DEVICE)
        batch_size = real_images.size(0)

        # Labels for real and fake
        real_labels = torch.ones(batch_size, 1, device=DEVICE)
        fake_labels = torch.zeros(batch_size, 1, device=DEVICE)

        # ---------------------
        # Train Discriminator
        # ---------------------
        d_optimizer.zero_grad()

        # Loss on real images
        d_real_output = discriminator(real_images)
        d_real_loss = criterion(d_real_output, real_labels)

        # Loss on fake images
        z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
        fake_images = generator(z)
        d_fake_output = discriminator(fake_images.detach())  # .detach() to not update G
        d_fake_loss = criterion(d_fake_output, fake_labels)

        # Combined discriminator loss
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # ---------------------
        # Train Generator
        # ---------------------
        g_optimizer.zero_grad()

        # Generate fake images and try to fool discriminator
        z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
        fake_images = generator(z)
        g_output = discriminator(fake_images)

        # Generator wants discriminator to output 1 (real) for fakes
        g_loss = criterion(g_output, real_labels)
        g_loss.backward()
        g_optimizer.step()

        # Logging
        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

    # Save generated images each epoch
    with torch.no_grad():
        fake_samples = generator(fixed_noise)
        save_image(fake_samples, f"gan_outputs/epoch_{epoch:03d}.png",
                   nrow=8, normalize=True)

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}]  "
          f"D Loss: {d_loss.item():.4f}  G Loss: {g_loss.item():.4f}")

# ====================
# Plot Training Curves
# ====================
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label="Generator Loss", alpha=0.7)
plt.plot(d_losses, label="Discriminator Loss", alpha=0.7)
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("GAN Training Losses")
plt.legend()
plt.savefig("gan_outputs/training_curves.png")
plt.show()

print("Training complete! Check gan_outputs/ for generated images.")

Tips for stable GAN training:
  • Use Adam with a low learning rate (0.0002) and beta1=0.5
  • Apply spectral normalization to the discriminator for better stability
  • Use label smoothing (real labels = 0.9 instead of 1.0)
  • Train D more steps than G if D is too weak, or vice versa
  • Monitor both losses: if D loss goes to 0, it has won and G cannot learn
  • Use Wasserstein loss (WGAN-GP) for more stable training
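The WGAN-GP gradient penalty mentioned above can be sketched as follows (a minimal version; the standard recipe weights it by 10 in the critic's loss and removes the Sigmoid, so the discriminator acts as an unbounded "critic"):

```python
import torch

def gradient_penalty(discriminator, real, fake, device="cpu"):
    # Penalize the critic's gradient norm at random interpolations
    # between real and fake samples, pushing it toward 1 (Lipschitz-1).
    batch_size = real.size(0)
    eps = torch.rand(batch_size, 1, 1, 1, device=device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = discriminator(interp)
    grads = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # so the penalty itself can be backpropagated
    )[0]
    grad_norm = grads.view(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

# Toy check with a linear "critic": its input gradient is all ones, so
# for 2x2 single-channel inputs the per-sample norm is 2 and the
# penalty is (2 - 1)^2 = 1.
critic = lambda x: x.view(x.size(0), -1).sum(dim=1, keepdim=True)
real = torch.zeros(4, 1, 2, 2)
fake = torch.ones(4, 1, 2, 2)
print(gradient_penalty(critic, real, fake).item())  # 1.0
```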

3. Vision Transformers (ViT)

The Vision Transformer, introduced by Dosovitskiy et al. in the 2020 paper "An Image is Worth 16x16 Words," demonstrated that a pure Transformer architecture, with minimal image-specific inductive biases, can achieve excellent performance on image classification when pre-trained on sufficient data.

3.1 The Core Idea: Images as Sequences of Patches

The key insight of ViT is simple but powerful: treat an image as a sequence of patches, just like a sentence is a sequence of tokens.

ViT Pipeline


  Input Image (224 x 224 x 3)
       |
       v
  Split into patches (14 x 14 grid of 16x16 patches = 196 patches)
       |
       v
  Flatten each patch (16 x 16 x 3 = 768 values per patch)
       |
       v
  Linear projection (768 -> D, where D is the embedding dimension)
       |
       v
  Add position embeddings (learnable, one per patch)
       |
       v
  Prepend [CLS] token (learnable embedding for classification)
       |
       v
  Transformer Encoder (L layers of Multi-Head Self-Attention + FFN)
       |
       v
  Take [CLS] token output -> MLP Head -> Class prediction
                        

Patch Embedding in Detail

For an image of size H x W x C with patch size P:

Number of patches: N = (H * W) / P^2
Patch dimension: P^2 * C
For 224x224 RGB with P=16: N = 196, patch_dim = 768

Each patch is flattened into a vector and linearly projected into the model's embedding dimension D. This is equivalent to applying a convolution with kernel size P and stride P:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    """
    Split image into patches and project to embedding dimension.

    This is equivalent to a Conv2d with kernel_size=patch_size and stride=patch_size.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2  # 196 for 224/16

        # Linear projection of flattened patches
        # Using Conv2d is equivalent but more efficient than manual flatten + linear
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )
        # Input:  [B, 3, 224, 224]
        # Output: [B, 768, 14, 14] -> reshape to [B, 196, 768]

    def forward(self, x):
        # x: [B, C, H, W]
        x = self.projection(x)  # [B, embed_dim, H/P, W/P]
        x = x.flatten(2)        # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)   # [B, num_patches, embed_dim]
        return x

# Test
patch_embed = PatchEmbedding()
dummy = torch.randn(2, 3, 224, 224)
patches = patch_embed(dummy)
print(f"Patches shape: {patches.shape}")  # [2, 196, 768]

Position Embeddings

Unlike CNNs, Transformers have no inherent notion of spatial position. Position embeddings provide this information:

  • Learned 1D position embeddings (original ViT): Simple and effective. A learnable vector for each sequence position (all N patches, plus one for the [CLS] token).
  • 2D position embeddings: Separate embeddings for row and column. Slightly better for some tasks.
  • Sinusoidal embeddings: Fixed, not learned. Can generalize to different resolutions.
  • Rotary Position Embedding (RoPE): Used in newer vision models, encodes relative positions.
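When fine-tuning at a different resolution, learned position embeddings are commonly resized by 2D interpolation over the patch grid. A sketch of that trick (the function name and grid sizes are illustrative):

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, old_grid=14, new_grid=16):
    """Bilinearly interpolate learned 1D position embeddings to a new
    patch grid. pos_embed: [1, 1 + old_grid**2, D], CLS token first."""
    cls_pe, patch_pe = pos_embed[:, :1], pos_embed[:, 1:]
    D = patch_pe.shape[-1]
    # Lay the patch embeddings back out on their 2D grid, interpolate,
    # then flatten back into a sequence.
    patch_pe = patch_pe.reshape(1, old_grid, old_grid, D).permute(0, 3, 1, 2)
    patch_pe = F.interpolate(patch_pe, size=(new_grid, new_grid),
                             mode="bilinear", align_corners=False)
    patch_pe = patch_pe.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, D)
    return torch.cat([cls_pe, patch_pe], dim=1)

pe = torch.randn(1, 1 + 14 * 14, 768)   # a 224/16 ViT's embeddings
print(resize_pos_embed(pe).shape)       # torch.Size([1, 257, 768])
```

The [CLS] position is carried over unchanged; only the spatial positions are resampled.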

The CLS Token

A special learnable embedding [CLS] is prepended to the sequence of patch embeddings. After passing through the Transformer, the output corresponding to this token serves as the aggregate representation of the entire image, used for classification. This is analogous to the [CLS] token in BERT.

3.2 PRACTICAL: Implement a Simple Vision Transformer

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# ====================
# ViT Components
# ====================

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Conv2d(in_channels, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)    # [B, embed_dim, H/P, W/P]
        x = x.flatten(2)          # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)     # [B, num_patches, embed_dim]
        return x


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=256, num_heads=8, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, D = x.shape

        # Compute Q, K, V in one linear layer
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, N, N]
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # Apply attention to values
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        out = self.proj(out)
        out = self.proj_dropout(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=256, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Pre-norm architecture (better than post-norm for ViT)
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class VisionTransformer(nn.Module):
    """
    A simple Vision Transformer for image classification.
    Designed for CIFAR-10 (32x32 images).
    """
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        num_classes=10,
        embed_dim=256,
        depth=6,
        num_heads=8,
        mlp_ratio=4.0,
        dropout=0.1,
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # CLS token: learnable classification token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Position embeddings: one for each patch + CLS token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        self.pos_dropout = nn.Dropout(dropout)

        # Transformer encoder blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialize position embeddings with truncated normal
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        B = x.shape[0]

        # Create patch embeddings
        x = self.patch_embed(x)  # [B, num_patches, embed_dim]

        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B, 1, embed_dim]
        x = torch.cat([cls_tokens, x], dim=1)  # [B, num_patches + 1, embed_dim]

        # Add position embeddings
        x = x + self.pos_embed
        x = self.pos_dropout(x)

        # Pass through transformer blocks
        x = self.blocks(x)
        x = self.norm(x)

        # Extract CLS token output for classification
        cls_output = x[:, 0]  # [B, embed_dim]

        # Classification
        logits = self.head(cls_output)  # [B, num_classes]
        return logits


# ====================
# Training on CIFAR-10
# ====================

def train_vit():
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    EPOCHS = 20
    BATCH_SIZE = 128
    LR = 3e-4

    # Data augmentation for training
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    train_set = datasets.CIFAR10(root="./data", train=True, download=True,
                                  transform=train_transform)
    test_set = datasets.CIFAR10(root="./data", train=False, download=True,
                                 transform=test_transform)

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=2)
    test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=2)

    # Create model
    model = VisionTransformer(
        img_size=32,
        patch_size=4,       # 32/4 = 8x8 = 64 patches
        embed_dim=256,
        depth=6,
        num_heads=8,
        mlp_ratio=4.0,
        dropout=0.1,
        num_classes=10,
    ).to(DEVICE)

    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {param_count:,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        model.train()
        total_loss, correct, total = 0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            logits = model(images)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * images.size(0)
            correct += (logits.argmax(1) == labels).sum().item()
            total += images.size(0)

        scheduler.step()

        train_acc = 100 * correct / total
        avg_loss = total_loss / total

        # Evaluate
        model.eval()
        test_correct, test_total = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                logits = model(images)
                test_correct += (logits.argmax(1) == labels).sum().item()
                test_total += images.size(0)

        test_acc = 100 * test_correct / test_total
        print(f"Epoch {epoch+1}/{EPOCHS}  Loss: {avg_loss:.4f}  "
              f"Train Acc: {train_acc:.1f}%  Test Acc: {test_acc:.1f}%")

    return model

# Run training
# model = train_vit()

3.3 ViT vs CNN Performance Comparison

| Aspect | CNN (ResNet/ConvNeXt) | ViT |
|---|---|---|
| Small data (<10K) | Better (inductive bias helps) | Worse (needs lots of data) |
| Large data (>1M) | Good | Better (scales well) |
| Computational cost | Efficient (O(n) with kernel size) | Expensive (O(n^2) attention) |
| Global context | Limited (local receptive field) | Full (from layer 1) |
| Transfer learning | Good | Excellent |
| Interpretability | Feature maps | Attention maps |
| Multimodal integration | Requires adapters | Natural (same architecture) |
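The computational-cost comparison can be made concrete with a back-of-envelope FLOP count. This is a rough sketch with illustrative constants (single layer, 3x3 kernel, biases and constants ignored), not a benchmark:

```python
# Rough per-layer FLOP estimates. n = number of spatial positions / tokens,
# d = channel or embedding dimension, k = convolution kernel size.

def conv_flops(n: int, d: int, k: int = 3) -> int:
    """Convolution: each of n positions does a k*k*d*d multiply-accumulate."""
    return n * k * k * d * d

def attention_flops(n: int, d: int) -> int:
    """Self-attention: QK^T and attn@V each cost ~n*n*d; the Q/K/V and output
    projections together cost ~4*n*d*d."""
    return 2 * n * n * d + 4 * n * d * d

# Convolution grows linearly in n; attention has an n^2 term that dominates
# once the token count is large (i.e., high-resolution images).
d = 256
for n in [64, 256, 1024, 4096]:
    print(f"n={n:5d}  conv: {conv_flops(n, d)/1e6:9.1f} MFLOPs  "
          f"attention: {attention_flops(n, d)/1e6:9.1f} MFLOPs")
```

At small n the two are comparable; quadrupling the token count roughly quadruples the conv cost but grows the attention term nearly sixteen-fold, which is why ViTs use patchification to keep n manageable.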

4. CLIP (Contrastive Language-Image Pre-training)

CLIP, introduced by OpenAI in January 2021, is one of the most influential models in modern AI. It learns a joint embedding space for images and text using contrastive learning on 400 million image-text pairs from the internet. CLIP is the foundation that made models like DALL-E 2 and Stable Diffusion possible.

4.1 How CLIP Works

CLIP Architecture


  Image                             Text
    |                                 |
    v                                 v
  +---------------+           +----------------+
  | Image Encoder |           | Text Encoder   |
  | (ViT or       |           | (Transformer)  |
  | ResNet)       |           |                |
  +---------------+           +----------------+
    |                                 |
    v                                 v
  Image                          Text
  Embedding                      Embedding
  (d-dimensional)                (d-dimensional)
    |                                 |
    +--- cosine similarity matrix ---+
    |                                 |
    v                                 v
  Contrastive Loss (InfoNCE): matching pairs should be close,
                               non-matching pairs should be far
                        

Contrastive Learning Objective (InfoNCE Loss)

Given a batch of N image-text pairs, CLIP computes the cosine similarity between all NxN possible (image, text) combinations. The loss encourages the N correct pairs to have high similarity and the N^2 - N incorrect pairs to have low similarity.

InfoNCE Loss (image-to-text direction):

L_i2t = -log( exp(sim(I_i, T_i) / tau) / sum_j exp(sim(I_i, T_j) / tau) )

For each image I_i, we want its matching text T_i to have the highest similarity among all texts in the batch. tau is a learned temperature parameter that scales the logits. The total loss is the average over both directions (image-to-text and text-to-image).

import torch
import torch.nn.functional as F

def clip_loss(image_embeddings, text_embeddings, temperature=0.07):
    """
    Compute symmetric CLIP contrastive loss.

    Args:
        image_embeddings: [B, D] normalized image embeddings
        text_embeddings: [B, D] normalized text embeddings
        temperature: temperature for scaling the logits (learned in CLIP; fixed here)

    Returns:
        Symmetric contrastive loss
    """
    # Compute cosine similarity matrix [B, B]
    # Each entry (i, j) = similarity between image_i and text_j
    logits = (image_embeddings @ text_embeddings.T) / temperature

    # Labels: the diagonal entries are the correct pairs
    labels = torch.arange(len(logits), device=logits.device)

    # Symmetric loss: image-to-text + text-to-image
    loss_i2t = F.cross_entropy(logits, labels)         # rows
    loss_t2i = F.cross_entropy(logits.T, labels)       # columns

    loss = (loss_i2t + loss_t2i) / 2
    return loss

# Example
B, D = 32, 512
image_emb = F.normalize(torch.randn(B, D), dim=-1)
text_emb = F.normalize(torch.randn(B, D), dim=-1)
loss = clip_loss(image_emb, text_emb)
print(f"CLIP loss: {loss.item():.4f}")
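In the actual CLIP model the temperature is not a fixed hyperparameter: the model learns a logit scale, parameterized in log space so it stays positive, initialized to the equivalent of tau = 0.07 and clamped so the scale never exceeds 100. A minimal sketch of that parameterization:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# CLIP learns log(1/tau) directly; exponentiating guarantees a positive scale,
# and clamping at 100 caps how sharply the softmax can peak.
logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

def scaled_logits(image_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
    """Similarity logits with the learned temperature applied."""
    scale = logit_scale.exp().clamp(max=100.0)  # effective tau = 1/scale
    return scale * image_emb @ text_emb.T

B, D = 8, 512
img = F.normalize(torch.randn(B, D), dim=-1)
txt = F.normalize(torch.randn(B, D), dim=-1)
print(scaled_logits(img, txt).shape)  # torch.Size([8, 8])
```

Because `logit_scale` is an `nn.Parameter`, the optimizer tunes the temperature jointly with the encoders during contrastive training.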

4.2 Zero-Shot Classification with CLIP

One of CLIP's most remarkable capabilities is zero-shot classification. Without any task-specific fine-tuning, CLIP can classify images into arbitrary categories by comparing image embeddings to text embeddings of class descriptions.

How Zero-Shot Classification Works

  1. Encode the image using CLIP's image encoder
  2. Create text prompts for each class: "a photo of a {class_name}"
  3. Encode all text prompts using CLIP's text encoder
  4. Compute cosine similarity between the image and each text embedding
  5. The class with highest similarity is the prediction

4.3 PRACTICAL: Use CLIP for Zero-Shot Image Classification

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Load CLIP model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

def zero_shot_classify(image_path: str, candidate_labels: list[str]) -> dict:
    """
    Classify an image into one of the candidate labels using CLIP.

    Args:
        image_path: Path to the image file
        candidate_labels: List of class names to choose from

    Returns:
        Dictionary mapping labels to probabilities
    """
    image = Image.open(image_path).convert("RGB")

    # Create text prompts (prompt engineering matters!)
    text_prompts = [f"a photo of a {label}" for label in candidate_labels]

    # Process inputs
    inputs = processor(
        text=text_prompts,
        images=image,
        return_tensors="pt",
        padding=True,
    )

    # Get embeddings
    with torch.no_grad():
        outputs = model(**inputs)

    # Cosine similarity between image and each text, scaled by temperature
    logits_per_image = outputs.logits_per_image  # [1, num_labels]
    probs = logits_per_image.softmax(dim=-1)     # Convert to probabilities

    results = {label: prob.item() for label, prob in zip(candidate_labels, probs[0])}

    # Sort by probability
    results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
    return results


# Example usage
labels = ["cat", "dog", "bird", "car", "airplane", "ship", "horse", "deer"]
# results = zero_shot_classify("test_image.jpg", labels)
# for label, prob in results.items():
#     print(f"  {label}: {prob:.1%}")


# ====================
# Prompt Engineering for CLIP
# ====================
# The text prompt significantly affects CLIP's accuracy.
# Ensembling multiple prompts improves results:

def zero_shot_classify_ensemble(image_path: str, candidate_labels: list[str]) -> dict:
    """
    Improved zero-shot classification using prompt ensembling.
    Uses multiple prompt templates and averages the text embeddings.
    """
    prompt_templates = [
        "a photo of a {}",
        "a photograph of a {}",
        "an image of a {}",
        "a picture of a {}",
        "a close-up photo of a {}",
        "a good photo of a {}",
        "a photo of the {}",
        "art of a {}",
        "a rendition of a {}",
    ]

    image = Image.open(image_path).convert("RGB")

    # Process image
    image_inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        image_features = model.get_image_features(**image_inputs)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # For each label, average embeddings across all templates
        label_features = []
        for label in candidate_labels:
            texts = [template.format(label) for template in prompt_templates]
            text_inputs = processor(text=texts, return_tensors="pt", padding=True)
            text_features = model.get_text_features(**text_inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Average across templates
            mean_text_feature = text_features.mean(dim=0, keepdim=True)
            mean_text_feature = mean_text_feature / mean_text_feature.norm(dim=-1, keepdim=True)
            label_features.append(mean_text_feature)

        label_features = torch.cat(label_features, dim=0)  # [num_labels, D]

        # Compute similarity
        similarity = (image_features @ label_features.T) * 100  # Scale
        probs = similarity.softmax(dim=-1)

    results = {label: prob.item() for label, prob in zip(candidate_labels, probs[0])}
    return dict(sorted(results.items(), key=lambda x: x[1], reverse=True))

4.4 PRACTICAL: Build a Semantic Image Search Engine

CLIP's joint embedding space enables powerful semantic image search: you can retrieve images using natural-language queries, or find visually similar images given a query image.

import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from pathlib import Path

class CLIPImageSearchEngine:
    """
    A semantic image search engine powered by CLIP.

    Supports:
    - Text-to-image search (find images matching a description)
    - Image-to-image search (find visually similar images)
    - Building and saving an index for fast retrieval
    """

    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()

        # Index storage
        self.image_paths = []
        self.image_embeddings = None

    def build_index(self, image_directory: str, extensions=(".jpg", ".png", ".jpeg")):
        """
        Build a search index from all images in a directory.
        Computes and stores CLIP embeddings for each image.
        """
        image_dir = Path(image_directory)
        image_files = []
        for ext in extensions:
            image_files.extend(image_dir.glob(f"**/*{ext}"))

        image_files = sorted(image_files)
        print(f"Found {len(image_files)} images. Building index...")

        all_embeddings = []
        batch_size = 32

        for i in range(0, len(image_files), batch_size):
            batch_paths = image_files[i:i + batch_size]
            images = []

            for path in batch_paths:
                try:
                    img = Image.open(path).convert("RGB")
                    images.append(img)
                    self.image_paths.append(str(path))
                except Exception as e:
                    print(f"  Skipping {path}: {e}")
                    continue

            if images:
                inputs = self.processor(images=images, return_tensors="pt", padding=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    image_features = self.model.get_image_features(**inputs)
                    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                    all_embeddings.append(image_features.cpu())

            if (i // batch_size) % 10 == 0:
                print(f"  Processed {min(i + batch_size, len(image_files))}/{len(image_files)}")

        self.image_embeddings = torch.cat(all_embeddings, dim=0)
        print(f"Index built: {len(self.image_paths)} images, "
              f"embedding shape: {self.image_embeddings.shape}")

    def search_by_text(self, query: str, top_k: int = 5) -> list[dict]:
        """
        Search for images matching a text description.

        Args:
            query: Natural language description (e.g., "a sunset over the ocean")
            top_k: Number of results to return

        Returns:
            List of dicts with 'path', 'score', and 'rank'
        """
        if self.image_embeddings is None:
            raise ValueError("Index not built. Call build_index() first.")

        # Encode the text query
        text_inputs = self.processor(text=[query], return_tensors="pt", padding=True)
        text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}

        with torch.no_grad():
            text_features = self.model.get_text_features(**text_inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Compute cosine similarity with all images
        similarities = (text_features.cpu() @ self.image_embeddings.T).squeeze(0)

        # Get top-k results
        top_indices = similarities.argsort(descending=True)[:top_k]

        results = []
        for rank, idx in enumerate(top_indices):
            results.append({
                "rank": rank + 1,
                "path": self.image_paths[idx],
                "score": similarities[idx].item(),
            })

        return results

    def search_by_image(self, query_image_path: str, top_k: int = 5) -> list[dict]:
        """
        Find images similar to a query image.

        Args:
            query_image_path: Path to the query image
            top_k: Number of results to return

        Returns:
            List of dicts with 'path', 'score', and 'rank'
        """
        if self.image_embeddings is None:
            raise ValueError("Index not built. Call build_index() first.")

        image = Image.open(query_image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            query_features = self.model.get_image_features(**inputs)
            query_features = query_features / query_features.norm(dim=-1, keepdim=True)

        similarities = (query_features.cpu() @ self.image_embeddings.T).squeeze(0)
        top_indices = similarities.argsort(descending=True)[:top_k]

        results = []
        for rank, idx in enumerate(top_indices):
            results.append({
                "rank": rank + 1,
                "path": self.image_paths[idx],
                "score": similarities[idx].item(),
            })

        return results

    def save_index(self, save_path: str):
        """Save the built index to disk for later use."""
        torch.save({
            "image_paths": self.image_paths,
            "image_embeddings": self.image_embeddings,
        }, save_path)
        print(f"Index saved to {save_path}")

    def load_index(self, load_path: str):
        """Load a previously saved index."""
        data = torch.load(load_path, weights_only=True)
        self.image_paths = data["image_paths"]
        self.image_embeddings = data["image_embeddings"]
        print(f"Index loaded: {len(self.image_paths)} images")


# ====================
# Usage Example
# ====================

# Initialize search engine
# engine = CLIPImageSearchEngine()

# Build index from a directory of images
# engine.build_index("/path/to/your/images")

# Search by text
# results = engine.search_by_text("a dog playing in the snow")
# for r in results:
#     print(f"  #{r['rank']} (score: {r['score']:.3f}): {r['path']}")

# Search by image
# results = engine.search_by_image("/path/to/query_image.jpg")
# for r in results:
#     print(f"  #{r['rank']} (score: {r['score']:.3f}): {r['path']}")

4.5 CLIP's Impact on AI

CLIP has been foundational to many subsequent breakthroughs:

  • DALL-E 2: Uses CLIP embeddings as the bridge between text and images. The text is encoded by CLIP, then a diffusion model generates an image that matches the CLIP embedding.
  • Stable Diffusion: Uses CLIP's text encoder to condition the diffusion process via cross-attention.
  • Open-vocabulary detection: Models like OWL-ViT use CLIP to detect objects described in free-form text.
  • Image generation evaluation: CLIP score is widely used to measure text-image alignment.
  • Multimodal retrieval: CLIP powers image search across many applications.
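The CLIP score used for evaluating text-image alignment is simply a rescaled cosine similarity between the image and caption embeddings. A minimal sketch over precomputed embeddings, using one common convention (100 times the cosine, floored at zero; other rescalings exist in the literature):

```python
import torch
import torch.nn.functional as F

def clip_score(image_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
    """CLIP score for paired image/caption embeddings: 100 * max(cosine, 0).

    One common convention; the embeddings would come from CLIP's image and
    text encoders in practice.
    """
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    cos = (image_emb * text_emb).sum(dim=-1)  # per-pair cosine similarity
    return 100 * cos.clamp(min=0)

# Identical embeddings give the maximal score of 100 per pair
e = torch.randn(4, 512)
print(clip_score(e, e))
```

Higher scores indicate that a generated image better matches its prompt, which is why the metric is popular for evaluating text-to-image models.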

5. Multimodal Models

Multimodal models, particularly Vision-Language Models (VLMs), represent the convergence of computer vision and natural language processing into unified systems that can see and reason about visual content.

Multimodal Model Pipeline

  Image --> Vision Encoder (ViT / SigLIP) --> Projection Layer --> merged with tokenized text --> Large Language Model --> Text Output (Answer / Description)

5.1 Vision-Language Models (VLMs)

The major VLMs as of early 2026:

| Model | Provider | Key Features |
|---|---|---|
| GPT-4o | OpenAI | Native multimodal, fast, excellent reasoning |
| Claude 3.5 Sonnet / Claude 3 Opus | Anthropic | Strong document understanding, safety-focused |
| Gemini 2.0 | Google | Native multimodal, long context, video understanding |
| LLaVA-NeXT / LLaVA-OneVision | Open-source | Strong open-source VLM, multiple resolution support |
| Qwen2.5-VL | Alibaba | Competitive open-source, strong multilingual |
| InternVL 2.5 | Shanghai AI Lab | Excellent benchmark performance |
| Pixtral | Mistral | Efficient multimodal, strong for its size |

5.2 How Multimodal Models Process Images

Most VLMs follow a three-component architecture:

VLM Architecture Pattern


  Input Image                    Input Text
       |                              |
       v                              |
  +------------------+                |
  | Vision Encoder   |                |
  | (ViT, SigLIP,    |                |
  |  CLIP ViT, etc.) |                |
  +------------------+                |
       |                              |
       v                              |
  Image Features                      |
  [N visual tokens]                   |
       |                              |
       v                              |
  +------------------+                |
  | Projection Layer |                |
  | (MLP or          |                |
  |  Cross-attention)|                |
  +------------------+                |
       |                              |
       v                              v
  Visual Tokens     +    Text Tokens
       |                     |
       +--------+------------+
                |
                v
       +------------------+
       | Language Model   |
       | (LLM backbone)   |
       +------------------+
                |
                v
          Text Output
                        
  1. Vision Encoder: Processes the image into a sequence of feature vectors (visual tokens). Common choices: ViT, SigLIP, CLIP ViT. A 224x224 image split into 14x14-pixel patches yields 16x16 = 256 visual tokens.
  2. Projection Layer: Maps visual tokens into the LLM's embedding space. Can be a simple MLP (LLaVA) or cross-attention (Flamingo). Some models use a perceiver/resampler to reduce the number of visual tokens.
  3. Language Model: The LLM backbone processes both visual and text tokens together. The visual tokens are interleaved with text tokens in the sequence.
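The projection layer in step 2 can be sketched as a LLaVA-style two-layer MLP connector. The dimensions below are illustrative, not taken from any specific checkpoint:

```python
import torch
import torch.nn as nn

class VisualProjector(nn.Module):
    """Maps vision-encoder outputs into the LLM's token embedding space.

    A two-layer MLP with GELU, in the style of LLaVA-1.5's connector.
    vision_dim and llm_dim are illustrative placeholders.
    """
    def __init__(self, vision_dim: int = 1024, llm_dim: int = 4096):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, visual_tokens: torch.Tensor) -> torch.Tensor:
        # visual_tokens: [B, N, vision_dim] -> [B, N, llm_dim]
        return self.proj(visual_tokens)

# 256 visual tokens (e.g., a 224x224 image with 14x14 patches),
# projected so they can be interleaved with the LLM's text tokens
projector = VisualProjector()
visual_tokens = torch.randn(2, 256, 1024)
projected = projector(visual_tokens)
print(projected.shape)  # torch.Size([2, 256, 4096])
```

Once projected, these tokens are concatenated with text token embeddings and fed through the LLM backbone as a single sequence, which is what lets the model attend jointly over image and text.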

5.3 PRACTICAL: Use GPT-4o / Claude for Vision Tasks via API

import base64
from openai import OpenAI
from anthropic import Anthropic

# ====================
# Helper: Encode image to base64
# ====================
def encode_image_to_base64(image_path: str) -> str:
    """Read an image file and return its base64 encoding."""
    with open(image_path, "rb") as f:
        return base64.standard_b64encode(f.read()).decode("utf-8")


def get_image_media_type(image_path: str) -> str:
    """Determine the MIME type from file extension."""
    ext = image_path.lower().split(".")[-1]
    return {
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "png": "image/png",
        "gif": "image/gif",
        "webp": "image/webp",
    }.get(ext, "image/jpeg")


# ====================
# GPT-4o Vision
# ====================
def analyze_image_gpt4o(image_path: str, prompt: str) -> str:
    """
    Analyze an image using GPT-4o's vision capabilities.

    Args:
        image_path: Path to the image file
        prompt: Question or instruction about the image

    Returns:
        The model's text response
    """
    client = OpenAI()  # Uses OPENAI_API_KEY env variable

    base64_image = encode_image_to_base64(image_path)
    media_type = get_image_media_type(image_path)

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt,
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{media_type};base64,{base64_image}",
                            "detail": "high",  # "low", "high", or "auto"
                        },
                    },
                ],
            }
        ],
        max_tokens=1024,
    )

    return response.choices[0].message.content


# ====================
# Claude Vision
# ====================
def analyze_image_claude(image_path: str, prompt: str) -> str:
    """
    Analyze an image using Claude's vision capabilities.

    Args:
        image_path: Path to the image file
        prompt: Question or instruction about the image

    Returns:
        The model's text response
    """
    client = Anthropic()  # Uses ANTHROPIC_API_KEY env variable

    base64_image = encode_image_to_base64(image_path)
    media_type = get_image_media_type(image_path)

    response = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=1024,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": media_type,
                            "data": base64_image,
                        },
                    },
                    {
                        "type": "text",
                        "text": prompt,
                    },
                ],
            }
        ],
    )

    return response.content[0].text


# ====================
# Usage Examples
# ====================

# Describe an image
# description = analyze_image_gpt4o("photo.jpg", "Describe this image in detail.")
# print(description)

# Extract text from a document
# text = analyze_image_claude("invoice.png",
#     "Extract all text, amounts, and dates from this invoice. "
#     "Return the result as structured JSON."
# )
# print(text)

# Visual question answering
# answer = analyze_image_gpt4o("diagram.png",
#     "This is an architecture diagram. Explain the data flow step by step."
# )
# print(answer)

5.4 PRACTICAL: Build a Multimodal Application

Let us build a complete multimodal application that combines image captioning, visual Q&A, and document analysis.

import base64
import json
from pathlib import Path
from openai import OpenAI
from anthropic import Anthropic

class MultimodalAssistant:
    """
    A multimodal AI assistant that can:
    1. Caption images with detailed descriptions
    2. Answer questions about images
    3. Extract structured data from documents
    4. Compare multiple images
    5. Analyze charts and diagrams
    """

    def __init__(self, provider="openai"):
        self.provider = provider
        if provider == "openai":
            self.client = OpenAI()
            self.model = "gpt-4o"
        elif provider == "anthropic":
            self.client = Anthropic()
            self.model = "claude-sonnet-4-20250514"
        else:
            raise ValueError(f"Unknown provider: {provider}")

    def _encode_image(self, image_path: str) -> tuple[str, str]:
        """Encode image to base64 and determine media type."""
        with open(image_path, "rb") as f:
            data = base64.standard_b64encode(f.read()).decode("utf-8")
        ext = image_path.lower().split(".")[-1]
        media_types = {"jpg": "image/jpeg", "jpeg": "image/jpeg",
                       "png": "image/png", "webp": "image/webp"}
        return data, media_types.get(ext, "image/jpeg")

    def _call_model(self, messages: list, max_tokens: int = 2048) -> str:
        """Unified model calling for both providers."""
        if self.provider == "openai":
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content
        else:
            # Extract system message if present
            system = None
            user_messages = []
            for msg in messages:
                if msg["role"] == "system":
                    system = msg["content"]
                else:
                    user_messages.append(msg)

            kwargs = {"model": self.model, "max_tokens": max_tokens,
                      "messages": user_messages}
            if system:
                kwargs["system"] = system

            response = self.client.messages.create(**kwargs)
            return response.content[0].text

    def _make_image_content(self, image_path: str) -> dict:
        """Create image content block for the appropriate provider."""
        data, media_type = self._encode_image(image_path)
        if self.provider == "openai":
            return {
                "type": "image_url",
                "image_url": {"url": f"data:{media_type};base64,{data}", "detail": "high"},
            }
        else:
            return {
                "type": "image",
                "source": {"type": "base64", "media_type": media_type, "data": data},
            }

    # ---------- Feature 1: Detailed Image Captioning ----------
    def caption_image(self, image_path: str, style="detailed") -> str:
        """Generate a caption for an image."""
        prompts = {
            "brief": "Describe this image in one sentence.",
            "detailed": (
                "Provide a detailed description of this image. Include:\n"
                "- Main subject and setting\n"
                "- Colors, lighting, and mood\n"
                "- Notable details and composition\n"
                "- Any text visible in the image"
            ),
            "alt_text": (
                "Write concise alt text for this image suitable for screen readers. "
                "Be descriptive but brief (under 125 characters)."
            ),
        }

        messages = [
            {"role": "system", "content": "You are an expert image analyst."},
            {
                "role": "user",
                "content": [
                    self._make_image_content(image_path),
                    {"type": "text", "text": prompts.get(style, prompts["detailed"])},
                ],
            },
        ]
        return self._call_model(messages)

    # ---------- Feature 2: Visual Q&A ----------
    def ask_about_image(self, image_path: str, question: str) -> str:
        """Ask any question about an image."""
        messages = [
            {"role": "system", "content": "Answer questions about images accurately and concisely."},
            {
                "role": "user",
                "content": [
                    self._make_image_content(image_path),
                    {"type": "text", "text": question},
                ],
            },
        ]
        return self._call_model(messages)

    # ---------- Feature 3: Document Data Extraction ----------
    def extract_document_data(self, image_path: str, schema: dict = None) -> dict:
        """
        Extract structured data from a document image (invoice, receipt, form).

        Args:
            image_path: Path to document image
            schema: Optional JSON schema describing expected fields

        Returns:
            Extracted data as a dictionary
        """
        if schema:
            schema_str = json.dumps(schema, indent=2)
            prompt = (
                f"Extract data from this document according to this schema:\n"
                f"```json\n{schema_str}\n```\n\n"
                f"Return ONLY valid JSON matching this schema. No explanation."
            )
        else:
            prompt = (
                "Extract all structured data from this document. "
                "Include dates, amounts, names, addresses, line items, totals, etc. "
                "Return the result as valid JSON. No explanation."
            )

        messages = [
            {"role": "system", "content": "You are a document data extraction specialist. Always return valid JSON."},
            {
                "role": "user",
                "content": [
                    self._make_image_content(image_path),
                    {"type": "text", "text": prompt},
                ],
            },
        ]

        response = self._call_model(messages)

        # Try to parse JSON from response
        try:
            # Handle markdown code blocks
            if "```json" in response:
                response = response.split("```json")[1].split("```")[0]
            elif "```" in response:
                response = response.split("```")[1].split("```")[0]
            return json.loads(response.strip())
        except json.JSONDecodeError:
            return {"raw_text": response}

    # ---------- Feature 4: Compare Images ----------
    def compare_images(self, image_paths: list[str], comparison_prompt: str = None) -> str:
        """Compare two or more images."""
        content = []
        for i, path in enumerate(image_paths):
            content.append({"type": "text", "text": f"Image {i + 1}:"})
            content.append(self._make_image_content(path))

        default_prompt = (
            "Compare these images in detail. What are the similarities and differences? "
            "Consider content, style, composition, colors, and mood."
        )
        content.append({"type": "text", "text": comparison_prompt or default_prompt})

        messages = [
            {"role": "system", "content": "You are an expert at comparing and analyzing images."},
            {"role": "user", "content": content},
        ]
        return self._call_model(messages)

    # ---------- Feature 5: Chart Analysis ----------
    def analyze_chart(self, image_path: str) -> dict:
        """Analyze a chart or graph and extract insights."""
        prompt = (
            "Analyze this chart/graph and provide:\n"
            "1. Chart type (bar, line, pie, scatter, etc.)\n"
            "2. Title and axis labels\n"
            "3. Key data points and values\n"
            "4. Main trends or patterns\n"
            "5. Key insights and takeaways\n\n"
            "Return as JSON with keys: chart_type, title, axes, data_points, "
            "trends, insights"
        )

        messages = [
            {"role": "system", "content": "You are a data visualization expert. Return valid JSON."},
            {
                "role": "user",
                "content": [
                    self._make_image_content(image_path),
                    {"type": "text", "text": prompt},
                ],
            },
        ]

        response = self._call_model(messages)
        try:
            # Handle markdown code blocks (same logic as extract_document_data)
            if "```json" in response:
                response = response.split("```json")[1].split("```")[0]
            elif "```" in response:
                response = response.split("```")[1].split("```")[0]
            return json.loads(response.strip())
        except json.JSONDecodeError:
            return {"raw_analysis": response}


# ====================
# Usage
# ====================

# assistant = MultimodalAssistant(provider="openai")

# Caption an image
# caption = assistant.caption_image("sunset.jpg", style="detailed")
# print("Caption:", caption)

# Ask about an image
# answer = assistant.ask_about_image("diagram.png", "What architecture pattern is shown here?")
# print("Answer:", answer)

# Extract invoice data
# invoice_schema = {
#     "vendor": "string",
#     "invoice_number": "string",
#     "date": "string",
#     "line_items": [{"description": "string", "quantity": "number", "amount": "number"}],
#     "subtotal": "number",
#     "tax": "number",
#     "total": "number",
# }
# data = assistant.extract_document_data("invoice.png", schema=invoice_schema)
# print("Extracted:", json.dumps(data, indent=2))

6. Video Models

Video understanding and generation represent the frontier of visual AI. Video adds a temporal dimension to images, dramatically increasing the computational and conceptual challenges.

6.1 Challenges of Video

Why Video Is Hard

  • Massive data volume: A 1-minute 1080p video at 30 fps contains 1,800 frames, each 1920x1080 pixels with 3 color channels. That is over 11 billion pixel values per minute, uncompressed.
  • Temporal dimension: The model must understand not just what is in each frame but how things change over time (motion, causality, physics).
  • Long-range dependencies: Understanding a video often requires connecting events separated by hundreds of frames.
  • Computational cost: Processing video with Transformers requires attention over both spatial and temporal dimensions. Naive approaches are O(N^2) where N = frames x patches_per_frame.
  • Annotation difficulty: Labeling video data is orders of magnitude more expensive than images.
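The data-volume point is easy to verify with a little arithmetic. The numbers below assume the 1080p, 30 fps example from the list above:

```python
# Raw (uncompressed) data in one minute of 1080p, 30 fps video
frames = 30 * 60                          # 1,800 frames per minute
pixels_per_frame = 1920 * 1080            # spatial resolution
values_per_frame = pixels_per_frame * 3   # RGB channels

total_values = frames * values_per_frame
print(f"{total_values:,} pixel values per minute")  # ~11.2 billion

# At 8 bits per channel that is roughly 11 GB per minute, uncompressed
print(f"{total_values / 1e9:.1f} GB/min at 1 byte per channel")
```

Codecs like H.264 shrink this by two to three orders of magnitude, but a model that operates on decoded frames still has to contend with the raw volume.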

6.2 Video Understanding Architectures

ViViT: Video Vision Transformer

ViViT (Arnab et al., 2021) extends ViT to video by creating "tubelet" embeddings across time:

  • Tubelet embedding: Extract 3D patches (spatial + temporal) from the video, e.g., 2 frames x 16x16 pixels, so each token summarizes a small spacetime volume.
  • Factorized encoder: A spatial Transformer processes the patches within each frame, then a temporal Transformer models how the per-frame representations change over time.
  • Factorized self-attention: Within each Transformer block, attend spatially within a frame, then temporally across frames. Both factorizations are much cheaper than joint spacetime attention, which is quadratic in the total token count.
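To see why factorization matters, count the tokens and attention-score entries for a concrete clip. The figures below are illustrative assumptions (a 32-frame clip at 224x224 with 2x16x16 tubelets), not values from the ViViT paper:

```python
# Token count for ViViT-style tubelet embedding
T, H, W = 32, 224, 224   # frames, height, width (example clip)
t, p = 2, 16             # tubelet: 2 frames x 16x16 pixels

temporal_tokens = T // t                # 16 time steps
spatial_tokens = (H // p) * (W // p)    # 14 * 14 = 196 patches per step
N = temporal_tokens * spatial_tokens    # 3,136 tokens total

# Joint spacetime attention computes scores over all token pairs: N^2
joint_cost = N ** 2

# Factorized attention does spatial attention at each time step plus
# temporal attention at each spatial location
factorized_cost = (temporal_tokens * spatial_tokens ** 2
                   + spatial_tokens * temporal_tokens ** 2)

print(f"{N} tokens; joint: {joint_cost:,} vs factorized: {factorized_cost:,}")
```

Even for this short clip, factorization cuts the attention-score count by more than an order of magnitude, and the gap widens as the clip gets longer.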

Frame Sampling Strategies

Processing every frame is usually impractical. Common sampling strategies:

  • Uniform sampling: Select N frames evenly spaced across the video. Simple but may miss important events.
  • Random sampling: Randomly select N frames during training (data augmentation effect).
  • Keyframe extraction: Use scene change detection to identify informative frames.
  • Dense sampling with stride: Sample every k-th frame from a temporal window.
  • Adaptive sampling: Use a lightweight model to identify which frames are important, then process only those.
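The first two strategies are pure index arithmetic and take only a few lines. This sketch (using a hypothetical 300-frame clip, i.e., 10 s at 30 fps) shows the indices each strategy would select; the same logic appears inside the VideoAnalyzer class in the practical below:

```python
def uniform_indices(total_frames: int, n: int) -> list[int]:
    """N frames evenly spaced across the whole video."""
    return [int(i * total_frames / n) for i in range(n)]

def strided_indices(total_frames: int, stride: int, start: int = 0) -> list[int]:
    """Every k-th frame from a temporal window starting at `start`."""
    return list(range(start, total_frames, stride))

# A hypothetical 300-frame clip (10 s at 30 fps):
print(uniform_indices(300, 8))       # [0, 37, 75, 112, 150, 187, 225, 262]
print(strided_indices(300, 30)[:5])  # [0, 30, 60, 90, 120]
```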

6.3 Video Generation Models (2026)

  Model                Company           Key Features
  -----                -------           ------------
  Sora                 OpenAI            Diffusion Transformer (DiT), spacetime patches, up to 1 min, strong physics understanding
  Veo 2                Google DeepMind   High fidelity, 4K output, cinematic quality
  Runway Gen-3 Alpha   Runway            Fast iteration, motion brush control
  Kling 1.6            Kuaishou          Long-form video, strong motion quality
  Wan 2.1              Alibaba           Open-source, competitive quality
  Hunyuan Video        Tencent           Open-source, flow-matching based
  CogVideoX            Zhipu AI          Open-source, 3D VAE + DiT

6.4 PRACTICAL: Video Analysis with Multimodal APIs

import base64
import cv2
from openai import OpenAI

class VideoAnalyzer:
    """
    Analyze videos using multimodal LLM APIs.

    Since most APIs accept images (not video directly), we extract
    key frames and send them as a sequence of images.
    """

    def __init__(self):
        self.client = OpenAI()

    def extract_frames(
        self,
        video_path: str,
        num_frames: int = 8,
        strategy: str = "uniform",
    ) -> list[str]:
        """
        Extract frames from a video file.

        Args:
            video_path: Path to video file
            num_frames: Number of frames to extract
            strategy: "uniform" or "keyframe"

        Returns:
            List of base64-encoded frame images
        """
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames / fps if fps > 0 else 0

        print(f"Video: {total_frames} frames, {fps:.1f} fps, {duration:.1f}s")

        if strategy == "uniform":
            # Evenly spaced frames
            indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
        elif strategy == "keyframe":
            # Simple keyframe detection based on frame difference
            indices = self._detect_keyframes(cap, total_frames, num_frames)
        else:
            indices = list(range(0, min(total_frames, num_frames)))

        frames_base64 = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if ret:
                # Resize for efficiency (max 512px on longest side)
                h, w = frame.shape[:2]
                scale = min(512 / max(h, w), 1.0)
                if scale < 1.0:
                    frame = cv2.resize(frame, (int(w * scale), int(h * scale)))

                # Encode to JPEG base64
                _, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
                b64 = base64.standard_b64encode(buffer).decode("utf-8")
                frames_base64.append(b64)

        cap.release()
        print(f"Extracted {len(frames_base64)} frames")
        return frames_base64

    def _detect_keyframes(self, cap, total_frames, num_keyframes):
        """Simple keyframe detection using frame differencing."""
        diffs = []
        prev_frame = None

        # Sample every 10th frame for efficiency
        sample_indices = range(0, total_frames, max(1, total_frames // 200))

        for idx in sample_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                continue

            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            small = cv2.resize(gray, (64, 64))

            if prev_frame is not None:
                diff = cv2.absdiff(small, prev_frame).mean()
                diffs.append((idx, diff))

            prev_frame = small

        # Select frames with highest difference (scene changes)
        diffs.sort(key=lambda x: x[1], reverse=True)
        keyframe_indices = sorted([d[0] for d in diffs[:num_keyframes]])

        return keyframe_indices

    def analyze_video(self, video_path: str, prompt: str, num_frames: int = 8) -> str:
        """
        Analyze a video by extracting frames and sending to GPT-4o.

        Args:
            video_path: Path to video file
            prompt: Analysis question or instruction
            num_frames: Number of frames to extract

        Returns:
            Analysis text
        """
        frames = self.extract_frames(video_path, num_frames)

        # Build content with interleaved frame images
        content = [
            {"type": "text", "text": f"I'm showing you {len(frames)} frames extracted from a video, in chronological order.\n"},
        ]

        for i, frame_b64 in enumerate(frames):
            content.append({
                "type": "text",
                "text": f"Frame {i + 1}/{len(frames)}:",
            })
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{frame_b64}",
                    "detail": "low",  # Use "low" for frames to manage tokens
                },
            })

        content.append({"type": "text", "text": f"\n{prompt}"})

        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a video analysis expert. You are shown sequential frames "
                        "from a video. Analyze the visual content, motion, and temporal "
                        "progression to answer questions about the video."
                    ),
                },
                {"role": "user", "content": content},
            ],
            max_tokens=2048,
        )

        return response.choices[0].message.content

    def summarize_video(self, video_path: str, num_frames: int = 12) -> dict:
        """
        Generate a comprehensive video summary.

        Returns a structured summary with scene descriptions,
        key events, and overall narrative.
        """
        prompt = """Analyze this video and provide a structured summary:

1. **Overall Description**: What is this video about? (2-3 sentences)
2. **Scene Breakdown**: Describe each distinct scene or segment
3. **Key Events**: List the main events in chronological order
4. **Visual Style**: Describe the cinematography, colors, and mood
5. **Objects and People**: List notable objects, people, and their actions
6. **Audio Context** (inferred): What sounds or dialogue might accompany this video?

Return your analysis in a clear, structured format."""

        analysis = self.analyze_video(video_path, prompt, num_frames)
        return {"summary": analysis, "num_frames_analyzed": num_frames}


# ====================
# Usage
# ====================

# analyzer = VideoAnalyzer()

# Basic video analysis
# result = analyzer.analyze_video(
#     "cooking_video.mp4",
#     "What dish is being prepared? List the ingredients and steps shown."
# )
# print(result)

# Video summarization
# summary = analyzer.summarize_video("presentation.mp4", num_frames=16)
# print(summary["summary"])

Note on Gemini for Video: Google's Gemini models natively accept video input via their API, which avoids the frame-extraction approach. If you are working with video frequently, Gemini's native video understanding can be more accurate and simpler to use:

# Gemini can process video directly
# import google.generativeai as genai
# model = genai.GenerativeModel("gemini-2.0-flash")
# video_file = genai.upload_file("video.mp4")
# response = model.generate_content([video_file, "Describe this video"])

Summary and Key Takeaways

What We Covered This Week

  1. CNN to ViT Evolution: Computer vision moved from CNNs (local, translation-invariant) to Transformers (global attention, flexible) as data and compute scaled up.
  2. GANs: Generator-Discriminator adversarial training produces realistic images but suffers from mode collapse and training instability. They have largely been replaced by diffusion models.
  3. Vision Transformers: Split images into patches, treat them as tokens, apply standard Transformer. Simple but powerful with enough data.
  4. CLIP: Joint image-text embedding space via contrastive learning. Foundation for zero-shot classification, image search, and generative models.
  5. Multimodal Models: Vision encoder + projection + LLM backbone. GPT-4o, Claude, Gemini can understand images, documents, and more.
  6. Video Models: Extending vision to temporal dimension. Frame sampling, ViViT, and video generation (Sora, Veo) represent the frontier.

Preparation for Next Week

In Week 14: Diffusion Based Models, we will dive deep into the math and implementation of diffusion models -- the technology behind Stable Diffusion, DALL-E 3, Midjourney, and Sora. Make sure you are comfortable with:

  • Basic probability (Gaussian distributions, Bayes' rule)
  • Neural network training (loss functions, backpropagation)
  • The U-Net architecture concept (we will implement one)
  • The CLIP model (used as the text encoder in Stable Diffusion)

Exercises

Exercise 1: GAN Exploration

Modify the GAN implementation to use Wasserstein loss with gradient penalty (WGAN-GP). Compare the training stability and quality of generated digits with the original BCE loss.

Exercise 2: ViT Analysis

Visualize the attention maps from the ViT model. For a given input image, plot which patches each head attends to. Do different heads learn different patterns?

Exercise 3: CLIP Search Engine

Build a complete image search application using the CLIPImageSearchEngine class. Create a simple Gradio or Streamlit UI that lets users search a folder of images by typing natural language queries.

Exercise 4: Multimodal Pipeline

Build an automated document processing pipeline that: (1) accepts document images, (2) classifies the document type, (3) extracts structured data based on the type, and (4) stores results in a database.

Exercise 5: Video Understanding

Extend the VideoAnalyzer to support action recognition. Given a video, classify the primary action being performed from a predefined list (e.g., cooking, sports, presentation, conversation).