Multi-Head Attention

A single attention head can only “look one way at a time.” Multiple heads in parallel let the model attend to different things simultaneously — syntactic, semantic, positional, content-based.

The idea

Instead of one attention head with d_k = d_model, split the model into h heads, each with its own projections W_Q^i, W_K^i, W_V^i and d_k = d_model / h:

head_i = Attention(X W_Q^i, X W_K^i, X W_V^i)    for i in 1..h
output = Concat(head_1, ..., head_h) · W_O

Total computation is roughly the same (smaller matrices, more of them), but the model now has h independent “views.”
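To make the split concrete, here is a minimal PyTorch sketch (all sizes are hypothetical). It shows that "splitting into heads" is only a reshape of one d_model-wide projection, and that h heads of width d_model / h use exactly as many projection parameters as one full-width head.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
B, T, d_model, h = 2, 5, 512, 8
d_k = d_model // h

# Stand-in for X @ W_Q computed for all heads at once: (B, T, d_model).
proj = torch.randn(B, T, d_model)

# Split: (B, T, d_model) -> (B, T, h, d_k) -> (B, h, T, d_k).
heads = proj.view(B, T, h, d_k).transpose(1, 2)

# Head i simply owns columns i*d_k .. (i+1)*d_k - 1 of the projection.
i = 3
assert torch.equal(heads[:, i], proj[..., i * d_k:(i + 1) * d_k])

# Merge (the Concat step): (B, h, T, d_k) -> (B, T, h, d_k) -> (B, T, d_model).
merged = heads.transpose(1, 2).reshape(B, T, d_model)
assert torch.equal(merged, proj)

# Parameter count: h heads x (W_Q^i, W_K^i, W_V^i), each d_model x d_k,
# equals one set of full-width Q/K/V projections.
assert h * 3 * d_model * d_k == 3 * d_model * d_model
```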

Why it works

Each head can specialize:

  • One head attends to the previous token (positional).
  • Another head attends to syntactically related words (subject ↔ verb).
  • Another to all instances of the same noun (coreference).
  • Another to special tokens (BOS, separators).

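Empirically, multiple heads outperform a single head with the same total dimension and parameter count; in the original Transformer ablations, single-head attention was noticeably worse (and very many tiny heads also hurt).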

Implementing it

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

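    def forward(self, x, mask=None):
        # x: (B, T, d_model). mask: optional boolean tensor, True = BLOCKED,
        # broadcastable to (B, H, T, T), e.g. (T, T) causal or (B, 1, 1, T) padding.
        B, T, _ = x.shape
        # Project once at full width, then split d_model into H heads of size d_k.
        Q = self.W_Q(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        # Q,K,V shapes now: (B, H, T, d_k)

        # Scaled dot-product scores, one T x T matrix per head: (B, H, T, T)
        scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)   # a fully blocked row would yield NaN
        out = attn @ V                     # (B, H, T, d_k)
        # Merge heads (the Concat step): (B, T, H, d_k) -> (B, T, d_model)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_O(out)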

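In modern PyTorch (2.0 and later), prefer F.scaled_dot_product_attention. It computes the same softmax(QKᵀ/√d_k)V, dispatches to fused kernels (FlashAttention or memory-efficient attention) when the inputs allow it, and supports causal masking via is_causal=True. Note that its boolean attn_mask uses the opposite convention from the masked_fill code above: True means the position may be attended to.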
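A sketch of the same module with the attention core delegated to F.scaled_dot_product_attention, assuming PyTorch 2.0 or later. The fused QKV projection, the class name SDPAMultiHeadAttention, and the `causal` flag are our own illustrative choices; the helper reference_attention just spells out the explicit formula so the fused call can be checked against it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPAMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # One fused projection producing Q, K and V (equivalent to three Linears).
        self.W_QKV = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, causal=False):
        B, T, d_model = x.shape
        q, k, v = self.W_QKV(x).split(d_model, dim=-1)
        # (B, T, d_model) -> (B, H, T, d_k) for each of q, k, v
        q, k, v = (t.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
                   for t in (q, k, v))
        # SDPA applies the 1/sqrt(d_k) scaling and, if requested, a causal mask.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        out = out.transpose(1, 2).contiguous().view(B, T, d_model)
        return self.W_O(out)


def reference_attention(q, k, v, causal=False):
    """Explicit softmax(QK^T / sqrt(d_k)) V, used to check the fused call."""
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5
    if causal:
        T = q.size(-2)
        blocked = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(blocked, float("-inf"))
    return F.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 10, 16) for _ in range(3))
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(fused, reference_attention(q, k, v, causal=True), atol=1e-5)

mha = SDPAMultiHeadAttention(d_model=64, num_heads=4)
y = mha(torch.randn(2, 10, 64), causal=True)
print(y.shape)  # torch.Size([2, 10, 64])
```

Fusing the three projections into one Linear is a common design choice because it launches one matmul instead of three; it computes the same thing as separate W_Q, W_K, W_V.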

Multi-Query Attention (MQA)

A 2019 trick (Shazeer): share the same K and V across all heads, keeping only Q distinct.

head_i = Attention(X W_Q^i, X W_K, X W_V)

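Why? At inference time, the KV cache often dominates memory. With MQA, the KV cache shrinks by a factor of h, the number of heads. Because autoregressive decoding is largely memory-bandwidth bound, loading less K/V per step (and fitting larger batches) can raise throughput substantially.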
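A back-of-the-envelope sketch of that saving. The formula counts one K and one V tensor per layer, per KV head, per token; the model shape (32 layers, 32 heads of dimension 128, fp16, 4,096-token context) is a hypothetical example chosen for round numbers.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch=1, bytes_per_elem=2):
    """KV cache size: K and V (factor 2) per layer, per KV head, per token."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_elem

GiB = 2 ** 30
# Hypothetical model: 32 layers, 32 query heads, head_dim 128, fp16 (2 bytes).
mha = kv_cache_bytes(num_layers=32, num_kv_heads=32, head_dim=128, seq_len=4096)
mqa = kv_cache_bytes(num_layers=32, num_kv_heads=1, head_dim=128, seq_len=4096)

print(mha / GiB)          # 2.0   (GiB per sequence)
print(mqa / GiB * 1024)   # 64.0  (MiB per sequence)
print(mha // mqa)         # 32, the number of heads
```

At batch size 32, the multi-head cache in this example would already be 64 GiB, which is why cutting it by the head count matters for serving.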

Quality cost: usually small with adequate training. Used in PaLM, Falcon.

Grouped-Query Attention (GQA)

A 2023 compromise (Ainslie et al.): K and V are shared across groups of heads, not all heads.

Q heads: 32      KV groups: 8     → each KV group serves 4 query heads

Quality nearly matches full multi-head; memory and speed nearly match MQA.
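A minimal sketch of the grouping, assuming PyTorch 2.0+. The function name grouped_query_attention is hypothetical; it expands the shared KV heads with repeat_interleave so that KV head j serves a contiguous block of query heads, then calls ordinary attention.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v, causal=False):
    """q: (B, H_q, T, d); k, v: (B, H_kv, T, d), with H_q divisible by H_kv.
    H_kv == H_q is ordinary multi-head attention; H_kv == 1 is MQA."""
    h_q, h_kv = q.size(1), k.size(1)
    assert h_q % h_kv == 0
    group = h_q // h_kv  # query heads served by each KV head
    # Repeat each KV head `group` times along the head axis, so KV head j
    # serves query heads j*group .. (j+1)*group - 1.
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)


torch.manual_seed(0)
B, T, d = 1, 6, 16
q = torch.randn(B, 32, T, d)   # 32 query heads
k = torch.randn(B, 8, T, d)    # 8 KV heads -> 4 query heads per group
v = torch.randn(B, 8, T, d)
out = grouped_query_attention(q, k, v, causal=True)
print(out.shape)  # torch.Size([1, 32, 6, 16])

# Query head 5 belongs to group 5 // 4 = 1, so it should use KV head 1.
solo = F.scaled_dot_product_attention(q[:, 5:6], k[:, 1:2], v[:, 1:2], is_causal=True)
assert torch.allclose(out[:, 5:6], solo, atol=1e-5)
```

The cache only ever stores the 8 KV heads; repeat_interleave makes temporary copies at compute time. Optimized kernels typically read the shared heads in place instead, but this is the simplest correct version.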

GQA is the modern default, used in LLaMA-2 70B, LLaMA-3, Mistral, Qwen2 and most other recent models.

Sparse and local attention

For long contexts, full attention is expensive: every token scores every other token, so compute grows quadratically with sequence length T. Variants:

  • Local attention: each token only attends to a window around it.
  • Strided attention: skip patterns to cover longer ranges with fewer ops.
  • Longformer, BigBird: combinations of local + global attention, with BigBird also adding random connections.
  • Blockwise sparse: structured sparsity for hardware efficiency.
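To illustrate the first variant, here is a sketch of a causal local (sliding-window) mask, assuming PyTorch 2.0+. The helper local_causal_mask and the window size are hypothetical; the mask follows the SDPA convention where True means "may attend".

```python
import torch
import torch.nn.functional as F

def local_causal_mask(T, window):
    """Boolean (T, T) mask, True = may attend (SDPA convention).
    Query i sees keys i-window+1 .. i: causal and limited to `window` tokens."""
    i = torch.arange(T).unsqueeze(1)  # query positions, column vector
    j = torch.arange(T).unsqueeze(0)  # key positions, row vector
    dist = i - j
    return (dist >= 0) & (dist < window)

mask = local_causal_mask(T=6, window=3)
print(mask.long())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])

torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 6, 8) for _ in range(3))
# The (T, T) mask broadcasts over batch and heads to (B, H, T, T).
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

This version still builds the full T × T score matrix and merely masks it, so it shows the pattern rather than the savings; real speedups come from kernels that skip the blocked tiles entirely.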

In practice, most publicly documented long-context models use dense attention with FlashAttention + GQA for context windows of roughly 128k to 1M tokens; state-space architectures (Mamba) and attention/SSM hybrids are the main alternatives for longer contexts.

Latent attention (DeepSeek-V2/V3)

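DeepSeek introduced Multi-head Latent Attention (MLA), which compresses each token's keys and values into a single low-rank latent vector. The cache holds that latent (plus a small shared positional key) instead of full per-head K and V, which are reconstructed by up-projection when needed. The result is a smaller KV cache than typical GQA configurations, with quality DeepSeek reports as on par with full multi-head attention.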
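A simplified sketch of the low-rank KV idea, not DeepSeek's implementation. The class LatentKVAttention and all sizes are hypothetical, and it omits MLA's decoupled RoPE key, query compression, and the inference-time trick of folding the up-projections into the query and output projections. The check at the end confirms that decoding one token from cached latents reproduces a full causal pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentKVAttention(nn.Module):
    """Hypothetical, simplified low-rank KV compression in the spirit of MLA."""
    def __init__(self, d_model, num_heads, head_dim, d_latent):
        super().__init__()
        self.h, self.d = num_heads, head_dim
        self.W_Q = nn.Linear(d_model, num_heads * head_dim, bias=False)
        self.W_down = nn.Linear(d_model, d_latent, bias=False)               # x -> latent c
        self.W_up_k = nn.Linear(d_latent, num_heads * head_dim, bias=False)  # c -> K
        self.W_up_v = nn.Linear(d_latent, num_heads * head_dim, bias=False)  # c -> V
        self.W_O = nn.Linear(num_heads * head_dim, d_model, bias=False)

    def forward(self, x, latent_cache=None):
        B, T, _ = x.shape
        c = self.W_down(x)  # (B, T, d_latent): the only per-token state cached here
        if latent_cache is not None:
            assert T == 1, "this sketch only decodes one token at a time"
            c = torch.cat([latent_cache, c], dim=1)
        S = c.size(1)
        q = self.W_Q(x).view(B, T, self.h, self.d).transpose(1, 2)
        k = self.W_up_k(c).view(B, S, self.h, self.d).transpose(1, 2)  # rebuilt on demand
        v = self.W_up_v(c).view(B, S, self.h, self.d).transpose(1, 2)
        # Prefill is causal; a single decode step may attend to every cached token.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=latent_cache is None)
        out = out.transpose(1, 2).reshape(B, T, self.h * self.d)
        return self.W_O(out), c


torch.manual_seed(0)
m = LatentKVAttention(d_model=256, num_heads=8, head_dim=32, d_latent=64)
x = torch.randn(1, 10, 256)

full, _ = m(x)                                  # causal pass over all 10 tokens
_, cache = m(x[:, :-1])                         # prefill 9 tokens, keep latents
step, cache = m(x[:, -1:], latent_cache=cache)  # decode token 10 from the cache
assert torch.allclose(full[:, -1], step[:, 0], atol=1e-5)

# Cached values per token per layer: 2 * heads * head_dim for MHA vs d_latent here.
print(2 * 8 * 32, cache.shape[-1])  # 512 64
```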

This is an active area; expect more variants.

How many heads?

Depends on d_model. Common ratios:

d_model               num_heads   head_dim
768 (GPT-2 small)     12          64
1024 (BERT-large)     16          64
4096 (LLaMA-2 7B)     32          128
8192 (larger LLaMAs)  64          128

Most modern designs keep head_dim = 64 or 128. Wider models add heads rather than making each one wider.

Visualization and interpretability

You can plot the attention matrix of any head on any input. In trained transformers, typical (not universal) patterns are:

  • Early layers: heads attend to nearby tokens and syntactic patterns.
  • Mid layers: more content-driven; heads attend to relevant entities.
  • Late layers: heads attend to specific tokens needed for the next prediction.

Tools like BertViz make this easy. Worth exploring once.
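Beyond eyeballing plots, heads can be scored numerically. A minimal sketch, assuming you already have one layer's attention weights as an (H, T, T) tensor; the function head_pattern_scores and its two heuristic scores are our own, and the demo uses synthetic patterns rather than a real model.

```python
import torch

def head_pattern_scores(attn):
    """attn: (H, T, T) attention weights for one layer on one input (rows sum to 1).
    Returns two (H,) tensors of heuristic scores."""
    # Previous-token score: mean weight query t puts on key t-1 (off-by-one diagonal).
    prev_token = attn.diagonal(offset=-1, dim1=-2, dim2=-1).mean(dim=-1)
    # Broadcaster score: mean weight all queries put on key 0 (column 0).
    first_token = attn[:, :, 0].mean(dim=-1)
    return prev_token, first_token


# Synthetic causal attention for three heads (a stand-in for real model weights).
T = 6
prev = torch.zeros(T, T)
prev[0, 0] = 1.0
prev[torch.arange(1, T), torch.arange(T - 1)] = 1.0   # query t -> key t-1
sink = torch.zeros(T, T)
sink[:, 0] = 1.0                                       # every query -> key 0
causal = torch.tril(torch.ones(T, T))
uniform = causal / causal.sum(dim=-1, keepdim=True)    # spread evenly over the past
attn = torch.stack([prev, sink, uniform])              # (H=3, T, T)

prev_score, first_score = head_pattern_scores(attn)
print(prev_score.argmax().item(), first_score.argmax().item())  # 0 1
```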

Pitfalls

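  • Head dim too small. Very narrow heads (below roughly 32 dimensions) have little capacity each, which tends to hurt quality.
  • Forgetting to scale by √d_k. The variance of q·k grows linearly with d_k, so unscaled logits saturate the softmax into near one-hot weights with vanishing gradients, and training becomes slow or unstable.
  • Mask shape and convention mismatches. Masks must broadcast to (B, H, T, T), and the meaning of True differs between APIs (blocked in the masked_fill code above, allowed in F.scaled_dot_product_attention). Check both when debugging.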
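A small sketch of the last two pitfalls, with hypothetical sizes: first measuring how the variance of raw dot products grows with d_k (and how dividing by √d_k brings it back to about 1), then broadcasting a (B, T) key-padding mask, where True marks padding, to the (B, 1, 1, T) shape that lines up with (B, H, T, T) scores.

```python
import torch

torch.manual_seed(0)

# 1) Why scale by sqrt(d_k): with unit-variance q and k entries, q . k has
#    variance d_k, so unscaled logits grow with head size and saturate softmax.
for d_k in (16, 64, 256):
    q, k = torch.randn(10_000, d_k), torch.randn(10_000, d_k)
    dots = (q * k).sum(dim=-1)
    print(d_k, round(dots.var().item(), 1), round((dots / d_k ** 0.5).var().item(), 2))

# 2) Padding mask broadcasting: a (B, T) key-padding mask must become
#    (B, 1, 1, T) to line up with scores of shape (B, H, T_query, T_key).
B, H, T = 2, 4, 5
is_pad = torch.tensor([[False, False, False, False, False],
                       [False, False, False, True,  True]])  # True = padding token
scores = torch.zeros(B, H, T, T)
masked = scores.masked_fill(is_pad[:, None, None, :], float("-inf"))
assert torch.isinf(masked[1, :, :, 3:]).all()  # padded keys blocked for every query
assert torch.isfinite(masked[0]).all()         # first sequence untouched
# A mistaken is_pad[:, None, :, None] would block padded *queries* instead,
# leaving fully masked rows that turn into NaN after softmax.
```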

Watch it interactively

  • Head Gallery — all 12 heads of one layer at once on real GPT-2 small. Predict before clicking: at layer 0, all 12 heads look diagonal-ish; by layer 11 you’ll find at least 3 distinct shapes — a broadcaster (column 0 dominates), a previous-token head (off-by-one diagonal), and a long-range semantic head.
  • Attention Inspector — drill into one head at a time; click a token to see its top-5 keys.
