tiny-llm 06 / 16 · 20 min read · 25 min hands-on

step 06 · build

Multi-head attention

Run n_heads attentions in parallel — efficiently, with one big projection matrix and a reshape.

model · attention

Step 05 gave us a single attention head. One head can only learn one kind of relationship — say, “the verb attends to its subject.” Real text has many simultaneous relationships: subject ↔ verb, pronoun ↔ antecedent, modifier ↔ noun, punctuation ↔ context. Multi-head attention runs n_heads heads in parallel, each on a smaller slice of the embedding dimension, then concatenates their outputs.

The implementation has one trick worth knowing: instead of creating n_heads separate nn.Linear layers and looping over them, we use one big linear that produces all the queries, keys, and values at once, then reshape. Same math, ~3× faster on GPUs, and the code is slightly cleaner.

By the end of this step you’ll have a MultiHeadAttention module that’s a near-drop-in replacement for step 05’s CausalSelfAttention — same interface, same shape contract, but with n_heads heads working in parallel. You’ll also see why the Head Specialization Gallery shows different heads learning different patterns.

What we’re doing

Single-head attention takes a (B, T, D) tensor and returns a (B, T, D) tensor. Multi-head attention does the same — same input, same output — but internally:

  1. Split the D dimension into H heads of size d_head = D / H each.
  2. Run scaled dot-product attention independently in each head. Each head has its own q_h, k_h, v_h projections to a d_head-dimensional subspace.
  3. Concatenate the per-head outputs back into a (B, T, D) tensor.
  4. Project the result through a final W_o to mix the heads.

If you set H = 1, you recover step 05 exactly. The interesting cases are H = 8 or 12 — what the original “Attention Is All You Need” and GPT-2 small use, respectively.

The math per head is unchanged from step 05 (the causal mask applies to the scores exactly as before; it’s omitted below for brevity):

Q_h = X · W_Q_h         # shape (B, T, d_head)
K_h = X · W_K_h         # shape (B, T, d_head)
V_h = X · W_V_h         # shape (B, T, d_head)
head_h = softmax(Q_h · K_h^T / √d_head) · V_h
out = concat([head_0, ..., head_{H-1}]) · W_O

What changes is the implementation: we’d hate to write H separate nn.Linear layers and loop. Instead, we reshape.
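
For contrast, here’s roughly what that loop-based version would look like — a hypothetical sketch (the name NaiveMultiHead is mine, and the causal mask is omitted for brevity), not code we’ll keep:

# sketch — not part of tiny_llm/mha.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveMultiHead(nn.Module):
    """Hypothetical loop-over-heads version: 3·H small Linears, H small matmuls."""

    def __init__(self, d_model: int, n_heads: int) -> None:
        super().__init__()
        d_head = d_model // n_heads
        # One Q, K, V projection per head — exactly what we want to avoid.
        self.q_projs = nn.ModuleList(nn.Linear(d_model, d_head, bias=False) for _ in range(n_heads))
        self.k_projs = nn.ModuleList(nn.Linear(d_model, d_head, bias=False) for _ in range(n_heads))
        self.v_projs = nn.ModuleList(nn.Linear(d_model, d_head, bias=False) for _ in range(n_heads))
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        heads = []
        for wq, wk, wv in zip(self.q_projs, self.k_projs, self.v_projs):
            q, k, v = wq(x), wk(x), wv(x)                    # each (B, T, d_head)
            scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
            heads.append(F.softmax(scores, dim=-1) @ v)      # (B, T, d_head)
        return self.W_o(torch.cat(heads, dim=-1))            # concat → (B, T, D)

Same outputs (given matching weights), but H Python-level iterations and 3·H separate kernel launches per call.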

The reshape trick

Stash this in your head: instead of H separate nn.Linear(D, d_head) layers per Q/K/V (so 3·H Linears), use one nn.Linear(D, 3·D) and reshape its output. The math is identical; the GPU loves it because all the matmuls happen in one fused kernel.

The shapes flow like this:

input        x:    (B, T, D)
↓ Linear(D → 3·D)
qkv:                (B, T, 3·D)
↓ chunk into 3
q, k, v:            three (B, T, D) tensors
↓ reshape into heads
q, k, v:            three (B, H, T, d_head) tensors    ← per-head, "batch" on H
↓ scaled dot-product attention (per head, batched)
out:                (B, H, T, d_head)
↓ recombine heads
out:                (B, T, D)
↓ Linear(D → D) — the "output projection" W_o
final:              (B, T, D)

The two reshapes are pure tensor-shape gymnastics (no FLOPs added), and the per-head attention happens in one batched matmul thanks to PyTorch broadcasting H like another batch dim.
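
If you want to convince yourself the reshapes are lossless, here’s a throwaway check (variable names are mine, not part of the file):

# sketch — reshape round-trip check, not part of tiny_llm/mha.py
import torch

B, T, H, d_head = 2, 12, 8, 8
D = H * d_head

q = torch.randn(B, T, D)
q_heads = q.view(B, T, H, d_head).transpose(1, 2)    # (B, H, T, d_head)

# Head h sees exactly columns [h*d_head : (h+1)*d_head] of the original.
h = 3
assert torch.equal(q_heads[:, h], q[..., h * d_head:(h + 1) * d_head])

# The inverse reshape restores the original tensor exactly.
q_back = q_heads.transpose(1, 2).contiguous().view(B, T, D)
assert torch.equal(q_back, q)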

Setup

Open tiny_llm/mha.py:

# tiny_llm/mha.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

Same imports as step 05. We’re going to build on the patterns you already know.

The class

# tiny_llm/mha.py
class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention.

    Identical interface to step 05's CausalSelfAttention; internally
    runs n_heads parallel heads using a fused QKV projection and a
    reshape-based per-head split.

    Input:  (B, T, d_model)
    Output: (B, T, d_model)
    """

    def __init__(self, d_model: int, n_heads: int, max_seq_len: int) -> None:
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # ONE big projection that produces Q, K, V all at once.
        # Output dim 3*d_model; we'll chunk into three (B, T, d_model)
        # tensors after the call.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)

        # Output projection. Same as step 05's W_o.
        self.W_o = nn.Linear(d_model, d_model, bias=False)

        # Same standard init as embeddings (step 04).
        nn.init.normal_(self.qkv.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.W_o.weight, mean=0.0, std=0.02)

        # Causal mask, pre-computed and registered as a non-parameter buffer.
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
        self.register_buffer("causal_mask", mask)

A note on the divisibility check: every multi-head config requires d_model % n_heads == 0. GPT-2 small uses d_model=768, n_heads=12, d_head=64. LLaMA-3 8B uses d_model=4096, n_heads=32, d_head=128. Powers of 2 for d_head keep CUDA tile shapes happy.
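
A quick check on those configs (throwaway snippet):

# sketch — config arithmetic, not part of tiny_llm/mha.py
for name, d_model, n_heads in [("GPT-2 small", 768, 12), ("LLaMA-3 8B", 4096, 32)]:
    assert d_model % n_heads == 0
    print(f"{name}: d_head = {d_model // n_heads}")   # 64 and 128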

The forward pass

# tiny_llm/mha.py (continuing the class)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        H = self.n_heads
        d_head = self.d_head

        # 1. Fused QKV projection: (B, T, D) → (B, T, 3*D).
        qkv = self.qkv(x)

        # 2. Split into Q, K, V along the last axis. Each is (B, T, D).
        q, k, v = qkv.chunk(3, dim=-1)

        # 3. Reshape into heads. (B, T, D) → (B, T, H, d_head) → (B, H, T, d_head).
        # The transpose puts the head axis next to the batch dim so the
        # attention matmul treats each (batch, head) pair as independent.
        q = q.view(B, T, H, d_head).transpose(1, 2)
        k = k.view(B, T, H, d_head).transpose(1, 2)
        v = v.view(B, T, H, d_head).transpose(1, 2)

        # 4. Scaled dot-product attention, batched over (B, H).
        # (B, H, T, d_head) @ (B, H, d_head, T) → (B, H, T, T)
        scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)

        # 5. Apply causal mask (broadcast across batch and heads).
        mask = self.causal_mask[:T, :T]                    # (T, T)
        scores = scores.masked_fill(~mask, float("-inf"))

        # 6. Softmax along the key axis: each row sums to 1.
        weights = F.softmax(scores, dim=-1)                # (B, H, T, T)

        # 7. Weighted sum of values.
        out = weights @ v                                  # (B, H, T, d_head)

        # 8. Recombine heads. (B, H, T, d_head) → (B, T, H, d_head) → (B, T, D).
        # .contiguous() because .view() needs a contiguous-memory tensor.
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        # 9. Output projection — mixes heads into the residual stream.
        return self.W_o(out)

Nine numbered steps, none of them surprising once you’ve read step 05. The differences from single-head:

  • Step 1 & 2 are the fused projection + chunk. Saves one matmul vs three.
  • Step 3 & 8 are the reshape gymnastics. Free in FLOPs.
  • Step 4–7 is exactly step 05’s attention, just with a H axis sitting between batch and time. PyTorch’s matmul broadcasts over leading dims, so the per-head computation is automatic — we never write a Python loop over heads.

The output projection (W_o) is functionally what mixes the per-head outputs back into the residual stream. With one head it’s nearly redundant; with H > 1 it’s essential — different heads write into different “slots” of the d_model vector and W_o learns how to combine them.
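
You can see the “slots” directly: after step 8’s recombine, head h’s output occupies columns h·d_head through (h+1)·d_head − 1 of the merged tensor, untouched until W_o mixes them. A throwaway demonstration (names are mine):

# sketch — head "slots" before W_o, not part of tiny_llm/mha.py
import torch

B, T, H, d_head = 1, 4, 8, 8
per_head = torch.randn(B, H, T, d_head)              # stand-in for step 7's output
merged = per_head.transpose(1, 2).contiguous().view(B, T, H * d_head)  # step 8

h = 5
# Head h's output sits untouched in its slice until W_o mixes it.
assert torch.equal(merged[..., h * d_head:(h + 1) * d_head], per_head[:, h])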

Sanity check

Add at the bottom:

# tiny_llm/mha.py (bottom of file)
if __name__ == "__main__":
    torch.manual_seed(0)

    mha = MultiHeadAttention(d_model=64, n_heads=8, max_seq_len=16)
    print(f"params: {sum(p.numel() for p in mha.parameters()):,}")
    print(f"d_head: {mha.d_head}")

    x = torch.randn(2, 12, 64)
    out = mha(x)
    print(f"\ninput shape:  {tuple(x.shape)}")
    print(f"output shape: {tuple(out.shape)}")

    # Verify the n_heads=1 special case matches step 05's expected shape.
    sha = MultiHeadAttention(d_model=64, n_heads=1, max_seq_len=16)
    sha_out = sha(x)
    print(f"\nsingle-head output shape: {tuple(sha_out.shape)} (should match {tuple(out.shape)})")

    # Compute attention weights for one example to verify per-head independence.
    with torch.no_grad():
        qkv = mha.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(2, 12, 8, 8).transpose(1, 2)
        k = k.view(2, 12, 8, 8).transpose(1, 2)
        scores = q @ k.transpose(-2, -1) / math.sqrt(8)
        scores = scores.masked_fill(~mha.causal_mask[:12, :12], float("-inf"))
        weights = F.softmax(scores, dim=-1)        # (B=2, H=8, T=12, T=12)

    print(f"\nattention weights shape: {tuple(weights.shape)}")
    print(f"row sums (should all be ~1.0):")
    print(f"  head 0, row 11: {weights[0, 0, 11].sum().item():.6f}")
    print(f"  head 7, row 11: {weights[0, 7, 11].sum().item():.6f}")

    # Heads should *differ* — that's the point of multi-head.
    head0_row11 = weights[0, 0, 11]
    head7_row11 = weights[0, 7, 11]
    print(f"\nhead 0 vs head 7 at row 11 differ: {not torch.allclose(head0_row11, head7_row11)}")

Run it:

uv run python -m tiny_llm.mha

Expected output:

params: 16,384
d_head: 8

input shape:  (2, 12, 64)
output shape: (2, 12, 64)

single-head output shape: (2, 12, 64) (should match (2, 12, 64))

attention weights shape: (2, 8, 12, 12)
row sums (should all be ~1.0):
  head 0, row 11: 1.000000
  head 7, row 11: 1.000000

head 0 vs head 7 at row 11 differ: True

What to notice:

  • params: 16,384. The fused qkv is nn.Linear(64, 192) (64 × 192 = 12,288 params) and W_o is nn.Linear(64, 64) (4,096 params), for 16,384 total. Step 05’s single-head version has four nn.Linear(64, 64) layers: 4 × 4,096 = 16,384. Same parameter count — multi-head is a free upgrade in parameter terms; we just split the same weights across heads (see the bookkeeping snippet after this list).
  • Output shape unchanged from single-head. The interface contract holds, and the rest of the model doesn’t care how many heads we use.
  • Per-head weights differ. Different heads see the same input but, with random init, attend differently. After training they’ll specialize — which is the entire pedagogical point of the Head Specialization Gallery.
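
The bookkeeping from the first bullet, as a throwaway snippet:

# sketch — fused vs. separate parameter counts, not part of tiny_llm/mha.py
import torch.nn as nn

fused_qkv = nn.Linear(64, 3 * 64, bias=False)                 # this step: 64 × 192 = 12,288
separate = [nn.Linear(64, 64, bias=False) for _ in range(3)]  # step 05 style: 3 × 4,096
w_o = nn.Linear(64, 64, bias=False)                           # both versions: 4,096

n_fused = fused_qkv.weight.numel() + w_o.weight.numel()
n_separate = sum(l.weight.numel() for l in separate) + w_o.weight.numel()
print(n_fused, n_separate)                                    # 16384 16384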

A small efficiency callout

What we wrote works correctly. PyTorch ≥ 2.0 also offers F.scaled_dot_product_attention(q, k, v, is_causal=True), which fuses scoring, masking, softmax, and value-weighting into a single CUDA kernel — typically 2–4× faster on GPUs, and it uses less memory because it never materializes the (T, T) weight matrix.

We’re staying with the manual implementation because it makes every step legible and matches the Attention Inspector demo’s output exactly. If you train a real-scale model later, swap in F.scaled_dot_product_attention and wrap your causal mask logic accordingly.
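
For reference, here’s roughly what steps 4–7 of forward would collapse into — a sketch, not part of this step’s file:

# sketch — the fused-kernel alternative to steps 4–7
# q, k, v are the (B, H, T, d_head) tensors from step 3.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # (B, H, T, d_head)
# is_causal=True applies the lower-triangular mask internally, so the
# causal_mask buffer, the masked_fill, and the explicit softmax all go away.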

What we did and didn’t do

What we did:

  • Multi-head causal self-attention with a fused QKV projection
  • Reshape-based per-head split (no Python loops, all batched ops)
  • Output-projection weight W_o that mixes heads into the residual stream
  • ~50 lines of PyTorch, drop-in compatible with step 05’s interface
  • Sanity-checked: same shape contract, per-head independence, expected param count

What we didn’t:

  • Use F.scaled_dot_product_attention. Faster, but opaque. Pedagogy first.
  • Multi-query attention (MQA) or grouped-query attention (GQA). LLaMA-2 70B uses GQA — fewer K/V heads than Q heads to save inference memory. We’ll mention this in step 14 (inference) where it actually matters.
  • Rotary position embeddings. RoPE rotates Q and K inside this function instead of adding a separate position embedding. Modern models (LLaMA, Qwen) use RoPE; we use learned absolute positions (step 04) which is simpler. The Positional Encoding Lab compares all four schemes.
  • Dropout on attention weights. A regularization trick from the original transformer paper. Most modern decoder-only models drop it. So do we.

Next

Step 07 wraps multi-head attention in a transformer block — adding layer norm, a feed-forward MLP, and residual connections. The block is what gets repeated N times to make a deep transformer. We’ll discover why “pre-norm” beats “post-norm” for training stability, why the MLP expands to 4× width, and why GELU replaces ReLU between the MLP’s two projections.