Multi-Head Attention
A single attention head can only “look one way at a time.” Multiple heads in parallel let the model attend to different things simultaneously — syntactic, semantic, positional, content-based.
The idea
Instead of a single attention with d_k = d_model, split into h heads each with d_k = d_model / h:
head_i = Attention(X W_Q^i, X W_K^i, X W_V^i) for i in 1..h
output = Concat(head_1, ..., head_h) · W_O
Total computation is roughly the same (smaller matrices, more of them), but the model now has h independent “views.”
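A quick sanity check of that claim (the numbers are GPT-2-small-like, purely for illustration): with h heads of width d_model / h, the Q/K/V projections hold exactly as many parameters, and do roughly as much matmul work, as one full-width head.

```python
d_model, h = 768, 12                    # illustrative GPT-2-small-like sizes
d_k = d_model // h                      # 64

single_head = 3 * d_model * d_model     # one head with d_k = d_model (W_Q, W_K, W_V)
multi_head = 3 * h * (d_model * d_k)    # h smaller heads, each d_model x d_k
print(single_head == multi_head)        # True: same projection parameters, h independent views
```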
Why it works
Each head can specialize:
- One head attends to the previous token (positional).
- Another head attends to syntactically related words (subject ↔ verb).
- Another to all instances of the same noun (coreference).
- Another to special tokens (BOS, separators).
Empirically, multi-head attention significantly outperforms a single big head with the same total parameter count.
Implementing it
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MultiHeadAttention(nn.Module):
        def __init__(self, d_model, num_heads):
            super().__init__()
            assert d_model % num_heads == 0
            self.num_heads = num_heads
            self.d_k = d_model // num_heads
            self.W_Q = nn.Linear(d_model, d_model, bias=False)
            self.W_K = nn.Linear(d_model, d_model, bias=False)
            self.W_V = nn.Linear(d_model, d_model, bias=False)
            self.W_O = nn.Linear(d_model, d_model, bias=False)

        def forward(self, x, mask=None):
            B, T, _ = x.shape
            Q = self.W_Q(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
            K = self.W_K(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
            V = self.W_V(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
            # Q, K, V shapes now: (B, H, T, d_k)
            scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)
            if mask is not None:
                scores = scores.masked_fill(mask, float("-inf"))
            attn = F.softmax(scores, dim=-1)
            out = attn @ V  # (B, H, T, d_k)
            out = out.transpose(1, 2).contiguous().view(B, T, -1)
            return self.W_O(out)
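A quick smoke test of the module on random input (the sizes here are arbitrary); the causal mask uses True for positions to hide, matching the masked_fill call in forward:

```python
torch.manual_seed(0)
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)   # (batch, seq_len, d_model)

# Upper-triangular boolean mask: True above the diagonal, so future positions are hidden.
causal = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)

y = mha(x, mask=causal)      # mask broadcasts against scores of shape (B, H, T, T)
print(y.shape)               # torch.Size([2, 10, 64])
```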
In modern PyTorch, prefer F.scaled_dot_product_attention — it dispatches to optimized kernels (FlashAttention, when available) and handles masking efficiently.
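For reference, a minimal sketch of the same forward pass on top of the fused kernel (the function name and passing the projection modules as arguments are illustrative choices, not a library API):

```python
import torch
import torch.nn.functional as F

def mha_fused(x, W_Q, W_K, W_V, W_O, num_heads):
    B, T, d_model = x.shape
    d_k = d_model // num_heads
    q = W_Q(x).view(B, T, num_heads, d_k).transpose(1, 2)
    k = W_K(x).view(B, T, num_heads, d_k).transpose(1, 2)
    v = W_V(x).view(B, T, num_heads, d_k).transpose(1, 2)
    # Fused attention: scaling, causal masking, and softmax happen inside the kernel.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)   # (B, H, T, d_k)
    return W_O(out.transpose(1, 2).contiguous().view(B, T, d_model))
```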
Multi-Query Attention (MQA)
A 2019 trick (Shazeer): share the same K and V across all heads, keeping only Q distinct.
head_i = Attention(X W_Q^i, X W_K, X W_V)
Why? At inference time, the KV cache dominates memory. With MQA, KV cache size is divided by the number of heads. Throughput goes up dramatically.
Quality cost: usually small with adequate training. Used in PaLM, Falcon.
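A rough back-of-the-envelope for the savings, using LLaMA-2-7B-like numbers (32 layers, 32 heads, head_dim 128, fp16) purely as an illustrative assumption:

```python
# KV cache size ~ 2 (K and V) * layers * kv_heads * head_dim * seq_len * bytes_per_value
def kv_cache_bytes(layers, kv_heads, head_dim, seq_len, bytes_per_value=2):
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_value

full = kv_cache_bytes(layers=32, kv_heads=32, head_dim=128, seq_len=4096)  # full multi-head
mqa = kv_cache_bytes(layers=32, kv_heads=1, head_dim=128, seq_len=4096)    # one shared KV head
print(full // 2**20, "MiB vs", mqa // 2**20, "MiB per sequence")           # 2048 MiB vs 64 MiB
```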
Grouped-Query Attention (GQA)
A 2023 compromise (Ainslie et al.): K and V are shared across groups of heads, not all heads.
Q heads: 32, KV groups: 8 → each KV group serves 4 query heads
Quality nearly matches full multi-head; memory and speed nearly match MQA.
GQA is the modern default — used in LLaMA-2/3, Mistral, Qwen, most recent models.
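A minimal sketch of the GQA bookkeeping (all names and sizes here are illustrative): K and V are projected with fewer heads, then each KV head is repeated across its group of query heads before the usual attention math.

```python
import torch

B, T, num_q_heads, num_kv_heads, d_k = 2, 10, 32, 8, 128
group = num_q_heads // num_kv_heads            # 4 query heads share each KV head

q = torch.randn(B, num_q_heads, T, d_k)        # from a full-size W_Q
k = torch.randn(B, num_kv_heads, T, d_k)       # from smaller W_K / W_V projections
v = torch.randn(B, num_kv_heads, T, d_k)

# Expand each KV head to cover its group, then attend exactly as in multi-head.
k = k.repeat_interleave(group, dim=1)          # (B, num_q_heads, T, d_k)
v = v.repeat_interleave(group, dim=1)
attn = torch.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)
out = attn @ v                                 # (B, num_q_heads, T, d_k)
```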
Sparse and local attention
For long contexts, full attention is expensive. Variants:
- Local attention: each token only attends to a window around it (see the sketch after this list).
- Strided attention: skip patterns to cover longer ranges with fewer ops.
- BigBird, Longformer: combinations of local + global + random.
- Blockwise sparse: structured sparsity for hardware efficiency.
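As a concrete example of the local variant, a causal sliding-window mask (window size chosen arbitrarily) that could be passed as the mask argument of the MultiHeadAttention module above:

```python
import torch

def sliding_window_mask(seq_len, window):
    # True = masked out. Each position sees itself and the previous window - 1 tokens.
    i = torch.arange(seq_len).unsqueeze(1)    # query positions (column vector)
    j = torch.arange(seq_len).unsqueeze(0)    # key positions (row vector)
    return (j > i) | (j < i - window + 1)     # hide the future and anything too far back

mask = sliding_window_mask(seq_len=8, window=3)   # (8, 8) boolean mask
```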
In practice, modern frontier models use dense attention with FlashAttention + GQA for context windows up to 128k–1M tokens, falling back to specialized architectures (Mamba, hybrids) for longer.
Latent attention (DeepSeek-V2/V3)
DeepSeek introduced Multi-head Latent Attention (MLA) — compresses K/V to a low-rank latent representation that can be expanded on demand. Lower KV cache memory than GQA, similar quality.
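A deliberately simplified sketch of the latent-KV idea (real MLA also compresses the queries and handles rotary positions with a decoupled key, which this omits; all dimension names are illustrative): only the small latent vector is cached, and per-head K/V are re-expanded from it when attention runs.

```python
import torch
import torch.nn as nn

d_model, num_heads, d_k, d_latent = 4096, 32, 128, 512

W_DKV = nn.Linear(d_model, d_latent, bias=False)          # down-projection; its output is what gets cached
W_UK = nn.Linear(d_latent, num_heads * d_k, bias=False)   # up-projection to per-head keys
W_UV = nn.Linear(d_latent, num_heads * d_k, bias=False)   # up-projection to per-head values

x = torch.randn(2, 10, d_model)
latent = W_DKV(x)                                         # (B, T, d_latent): the KV cache stores this
K = W_UK(latent).view(2, 10, num_heads, d_k).transpose(1, 2)   # (B, H, T, d_k)
V = W_UV(latent).view(2, 10, num_heads, d_k).transpose(1, 2)
# Cache cost per token: d_latent = 512 values instead of 2 * num_heads * d_k = 8192.
```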
This is an active area; expect more variants.
How many heads?
Depends on d_model. Common ratios:
| d_model | num_heads | head_dim |
|---|---|---|
| 768 (GPT-2 base) | 12 | 64 |
| 1024 (BERT-large) | 16 | 64 |
| 4096 (LLaMA-2 7B) | 32 | 128 |
| 8192 (larger LLaMAs) | 64 | 128 |
Most modern designs keep head_dim = 64 or 128. Wider models add heads rather than making each one wider.
Visualization and interpretability
You can plot the attention matrix of any head on any input. In trained transformers:
- Layer 1: heads attend to nearby tokens, syntactic patterns.
- Mid layers: more content-driven — heads attend to relevant entities.
- Late layers: heads attend to specific tokens needed for the next prediction.
Tools like BertViz make this easy. Worth exploring once.
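A quick way to pull the raw attention maps out of a pretrained GPT-2 with the Hugging Face transformers library (assuming it is installed); these are the matrices that tools like BertViz render:

```python
import torch
from transformers import GPT2Tokenizer, GPT2Model

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")

inputs = tok("The cat sat on the mat", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_attentions=True)

# out.attentions is a tuple with one entry per layer, each of shape (batch, heads, seq, seq).
layer, head = 5, 3
attn = out.attentions[layer][0, head]   # attention matrix of one head on this input
print(attn.shape)                       # (seq, seq) for this prompt
```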
Pitfalls
- Head dim too small. Below ~32 dim, individual heads can’t carry enough info.
- Forgetting to scale by √d_k. Catastrophic.
- Mask shape mismatches. A common bug. Always verify shapes when debugging; see the sanity check below.
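A tiny sanity check for the last pitfall: the mask must broadcast against the scores tensor of shape (B, H, T, T), so a (T, T) or (B, 1, 1, T) boolean mask works, while a raw (B, T) padding mask generally does not.

```python
import torch

B, H, T = 2, 8, 10
scores = torch.randn(B, H, T, T)

causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)   # (T, T): broadcasts cleanly
padding = torch.zeros(B, T, dtype=torch.bool)                         # (B, T): does NOT line up with (B, H, T, T)
padding = padding[:, None, None, :]                                   # reshape to (B, 1, 1, T) first

scores.masked_fill(causal, float("-inf"))    # ok
scores.masked_fill(padding, float("-inf"))   # ok after reshaping
```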
Watch it interactively
- Head Gallery — all 12 heads of one layer at once on real GPT-2 small. Predict before clicking: at layer 0, all 12 heads look diagonal-ish; by layer 11 you’ll find at least 3 distinct shapes — a broadcaster (column 0 dominates), a previous-token head (off-by-one diagonal), and a long-range semantic head.
- Attention Inspector — drill into one head at a time; click a token to see its top-5 keys.
Build it in code
/build/06 — multi-head attention — reshape Q/K/V into H heads, run attention per head, concat back to d_model. ~40 lines on top of the single-head code.
See also
- Self-attention (KQV)
- The transformer block
- Stage 13 — Cost & latency — KV cache sizing