Self-Attention (KQV)

Self-attention is the operation that lets each token in a sequence selectively pull information from every other token. It’s the engine inside every transformer.

The core formula

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

That’s it. Five symbols. We’ll unpack what each is and why this works.

Setup: the three projections

For each token position i, we compute three vectors:

  • q_i — the query: “what am I looking for?”
  • k_i — the key: “what do I offer?”
  • v_i — the value: “if you pick me, here’s the info I bring”

These are produced by three learned linear projections of the token’s embedding x_i:

q_i = W_Q · x_i
k_i = W_K · x_i
v_i = W_V · x_i

Stacked across positions:

Q = X · W_Qᵀ      (T × d_k)
K = X · W_Kᵀ      (T × d_k)
V = X · W_Vᵀ      (T × d_v)

Here X is (T × d_model): T tokens, each a d_model-dimensional embedding. Sizes vary by model. For example, a model with d_model = 4096 and 32 heads (as in LLaMA-7B) uses d_k = d_v = d_model / num_heads = 128.
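Here is a quick NumPy check that stacking the per-token projections q_i = W_Q · x_i row by row gives the same matrix as Q = X · W_Qᵀ. NumPy stands in for PyTorch, the weights are random rather than learned, and the sizes are toy values, not the typical ones above:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, d_k = 5, 16, 4                  # toy sizes, chosen for illustration

X = rng.standard_normal((T, d_model))       # one row per token embedding x_i
W_Q = rng.standard_normal((d_k, d_model))   # random stand-in for the learned W_Q

# Per-token form: q_i = W_Q · x_i, one position at a time
q_rows = np.stack([W_Q @ X[i] for i in range(T)])

# Stacked form: Q = X · W_Qᵀ, all positions at once
Q = X @ W_Q.T

print(Q.shape)                   # (5, 4)
print(np.allclose(q_rows, Q))    # True
```

The same identity holds for K with W_K and for V with W_V.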

The similarity matrix

QKᵀ ∈ ℝ^(T × T)

Entry (i, j) is q_i · k_j — how much token i’s query matches token j’s key. Higher = more relevant.

Scaling by √d_k

Without scaling, dot products grow with dimension. For d_k = 128, the dot product of two random vectors with independent N(0,1) entries has variance 128 (one unit per dimension). Scores that large saturate the softmax: one weight ends up near 1, the rest near 0, and gradients through the softmax become vanishingly small.

Dividing by √d_k divides that variance by d_k, which brings it back to about 1 regardless of dimension. Stable softmax, stable gradients.
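Here is a small empirical check of both claims, sketched in NumPy with random Gaussian vectors standing in for real queries and keys. The `softmax` helper is hypothetical and written here for illustration. The sample sizes and dimensions are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000  # number of random (q, k) pairs per dimension

def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

variances = {}
for d_k in (16, 128, 512):
    q = rng.standard_normal((n, d_k))
    k = rng.standard_normal((n, d_k))
    dots = (q * k).sum(axis=1)        # n independent dot products q · k
    scaled = dots / np.sqrt(d_k)
    variances[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:4d}  var(q·k)={dots.var():7.1f}  var(q·k/√d_k)={scaled.var():.2f}")

# One row of 8 attention scores at d_k = 128, with and without scaling
d_k = 128
q = rng.standard_normal(d_k)
keys = rng.standard_normal((8, d_k))
raw = keys @ q
max_unscaled = softmax(raw).max()
max_scaled = softmax(raw / np.sqrt(d_k)).max()
print(f"largest weight unscaled: {max_unscaled:.3f}, scaled: {max_scaled:.3f}")
```

The unscaled variance tracks d_k, and the scaled one stays near 1. Dividing by √d_k acts like raising the softmax temperature, so the largest weight always shrinks and the row becomes less one-hot.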

Softmax

For each row i, normalize across columns to a probability distribution:

α_{i,j} = softmax_j(q_i · k_j / √d_k)

Row i tells you how much attention token i pays to each token j (including itself). Each row sums to 1.
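Here is a row-wise softmax sketched in NumPy. The helper name `row_softmax` is hypothetical. Subtracting each row's maximum before exponentiating does not change the result, because softmax is shift-invariant, but it keeps `exp` from overflowing on large scores.

```python
import numpy as np

def row_softmax(scores):
    """Softmax over the last axis: one probability distribution per row i."""
    shifted = scores - scores.max(axis=-1, keepdims=True)  # same result, no overflow
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T, d_k = 4, 8
Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))

alpha = row_softmax(Q @ K.T / np.sqrt(d_k))  # alpha[i, j]: weight token i gives token j
print(alpha.round(3))
print(alpha.sum(axis=-1))                    # each row sums to 1 (up to rounding)
```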

The output

For each position i, take the weighted sum of value vectors:

output_i = Σ_j α_{i,j} · v_j

In matrix form:

output = α · V    where α = softmax(QKᵀ/√d_k)

That’s the new representation of token i — a context-sensitive blend of value vectors from across the sequence.
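To confirm that the per-position sum and the matrix form agree, here is a NumPy sketch with random toy matrices. Note that d_v does not have to equal d_k.

```python
import numpy as np

rng = np.random.default_rng(1)
T, d_k, d_v = 4, 8, 6                 # toy sizes; d_v need not equal d_k

Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))

scores = Q @ K.T / np.sqrt(d_k)
alpha = np.exp(scores - scores.max(axis=1, keepdims=True))
alpha /= alpha.sum(axis=1, keepdims=True)        # row-wise softmax

# Per-position form: output_i = Σ_j α_{i,j} · v_j
out_loop = np.stack([sum(alpha[i, j] * V[j] for j in range(T)) for i in range(T)])

# Matrix form: output = α · V
out_matrix = alpha @ V

print(out_matrix.shape)                    # (4, 6)
print(np.allclose(out_loop, out_matrix))   # True
```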

Walking through with shapes

For a single head with d_k = 64, batch size 1, sequence length 8, d_model = 512:

Step                    Shape
Input X                 (8, 512)
Q = X W_Qᵀ              (8, 64)
K = X W_Kᵀ              (8, 64)
V = X W_Vᵀ              (8, 64)
QKᵀ                     (8, 8)
÷ √d_k                  (8, 8)
softmax (row-wise)      (8, 8)
× V                     (8, 64)
Final output            (8, 64)

You can apply many such heads in parallel — that’s multi-head attention (next article).
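Here are the same shapes traced in NumPy, keeping the batch dimension of size 1. The weight matrices are random stand-ins for learned ones, scaled by 1/√d_model only to keep values in a reasonable range.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, d_model, d_k = 1, 8, 512, 64    # sizes from the table (batch size 1)

X = rng.standard_normal((B, T, d_model))
# Random stand-ins for the learned (d_k, d_model) projection matrices
W_Q, W_K, W_V = (rng.standard_normal((d_k, d_model)) / np.sqrt(d_model) for _ in range(3))

Q, K, V = X @ W_Q.T, X @ W_K.T, X @ W_V.T          # each (B, T, d_k)
scores = Q @ K.swapaxes(-1, -2) / np.sqrt(d_k)     # (B, T, T)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)     # row-wise softmax, (B, T, T)
out = weights @ V                                  # (B, T, d_k)

for name, arr in [("X", X), ("Q", Q), ("K", K), ("V", V),
                  ("scores", scores), ("weights", weights), ("out", out)]:
    print(f"{name:8s} {arr.shape}")
```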

Causal masking (decoder-only)

For autoregressive generation, token t must not attend to tokens after it, because that would leak the future it is trying to predict.

Before the softmax, add a mask:

mask[i][j] = 0    if i ≥ j (allowed)
mask[i][j] = -∞   if i < j  (blocked)

After softmax, the masked positions have weight 0. The model only looks at past tokens.

# True strictly above the diagonal (j > i): the future positions to block
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = scores.masked_fill(mask, float("-inf"))  # scores: (..., T, T)

GPT-style decoder-only models always use this mask. BERT-style encoders don't: every token attends to the whole sequence in both directions.
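Here is the same masking convention (True = blocked) in a self-contained NumPy sketch, along with a direct test of the "no future leakage" property. The helper `causal_attention` and the random weights are illustrative assumptions. If you change the inputs of later tokens, the outputs for earlier positions stay the same.

```python
import numpy as np

def causal_attention(X, W_Q, W_K, W_V):
    """Single-head causal self-attention on X of shape (T, d_model)."""
    T = X.shape[0]
    Q, K, V = X @ W_Q.T, X @ W_K.T, X @ W_V.T
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    # True strictly above the diagonal (j > i): future positions are blocked
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))  # exp(-inf) = 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
T, d_model, d_k = 6, 16, 8
X = rng.standard_normal((T, d_model))
W_Q, W_K, W_V = (rng.standard_normal((d_k, d_model)) / np.sqrt(d_model) for _ in range(3))

out, weights = causal_attention(X, W_Q, W_K, W_V)
print(np.allclose(np.triu(weights, k=1), 0))   # True: zero weight on future tokens

# Replace tokens 3..5: outputs at positions 0..2 must not change
X2 = X.copy()
X2[3:] = rng.standard_normal((T - 3, d_model))
out2, _ = causal_attention(X2, W_Q, W_K, W_V)
print(np.allclose(out[:3], out2[:3]))          # True
```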

Self-attention vs cross-attention

  • Self-attention: Q, K, V all come from the same sequence. Each token attends to others in the same sequence.
  • Cross-attention: Q comes from one sequence (e.g. the decoder), K and V from another (e.g. the encoder). Used in encoder-decoder (seq2seq) models such as T5 and the original Transformer.

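Some vision-language models, such as BLIP and Flamingo, use cross-attention to mix image features into text generation. Others, such as LLaVA, instead project image features into the language model's input sequence as extra tokens and rely on ordinary self-attention.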
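Here is a minimal NumPy sketch of cross-attention. It assumes, for simplicity, that both sequences share the same width d_model. In real models, the encoder or image side can have a different width, and W_K and W_V would map from that width instead. The helper `cross_attention` is hypothetical. The main point is that the score matrix is T_q × T_kv rather than square, and there is one output per query position.

```python
import numpy as np

def cross_attention(X_dec, X_enc, W_Q, W_K, W_V):
    """Queries from X_dec (T_q, d_model); keys and values from X_enc (T_kv, d_model)."""
    Q = X_dec @ W_Q.T                         # (T_q, d_k)
    K = X_enc @ W_K.T                         # (T_kv, d_k)
    V = X_enc @ W_V.T                         # (T_kv, d_v)
    scores = Q @ K.T / np.sqrt(Q.shape[-1])   # (T_q, T_kv): not square in general
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                        # (T_q, d_v)

rng = np.random.default_rng(0)
d_model, d_k, d_v = 16, 8, 8
X_dec = rng.standard_normal((3, d_model))     # e.g. 3 decoder tokens
X_enc = rng.standard_normal((10, d_model))    # e.g. 10 encoder tokens or image patches
W_Q = rng.standard_normal((d_k, d_model))
W_K = rng.standard_normal((d_k, d_model))
W_V = rng.standard_normal((d_v, d_model))

out = cross_attention(X_dec, X_enc, W_Q, W_K, W_V)
print(out.shape)   # (3, 8)
```

No causal mask is applied here, because every query may look at the entire encoder sequence.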

Implementing it

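import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """A single attention head: softmax(QKᵀ / √d_k) · V, with d_v = d_k."""

    def __init__(self, d_model, d_k):
        super().__init__()
        # Learned projections of the same embedding into query, key, value spaces
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
        self.scale = d_k ** -0.5  # 1 / √d_k

    def forward(self, x, mask=None):
        # x: (batch, T, d_model) -> Q, K, V: (batch, T, d_k)
        Q, K, V = self.W_Q(x), self.W_K(x), self.W_V(x)
        # scores: (batch, T, T); entry (i, j) is q_i · k_j / √d_k
        scores = (Q @ K.transpose(-2, -1)) * self.scale
        if mask is not None:
            # mask: bool tensor broadcastable to scores, True where attention is blocked
            scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1)  # each row sums to 1
        return attn @ V  # (batch, T, d_k)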

That’s a single attention head. ~15 lines.

What attention learns

In trained transformers, different heads specialize:

  • Some attend to syntactic neighbors (the previous word, the head verb).
  • Some link entities to their pronouns (“Alice … she”).
  • Some attend to delimiters (sentence boundaries).
  • Some focus on rare tokens or specific tokens of interest.

This emergent specialization is what makes attention so flexible. It’s also what mechanistic interpretability research is unpacking.

Quadratic complexity

The QKᵀ matrix is (T × T). For T = 100k, that's 10^10 entries: about 20 GB at 2 bytes per entry (fp16), for a single head in a single layer. For T = 1M, it's 10^12 entries, which is infeasible to materialize.
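Here is the arithmetic behind those numbers for a few context lengths, assuming 2 bytes per entry (fp16/bf16) and counting one head in one layer. The helper `score_matrix_gib` is hypothetical.

```python
BYTES_PER_ENTRY = 2  # fp16 / bf16

def score_matrix_gib(T, bytes_per_entry=BYTES_PER_ENTRY):
    """Size in GiB of one materialized T x T score matrix (one head, one layer)."""
    return T * T * bytes_per_entry / 2**30

for T in (4_096, 32_768, 131_072, 1_000_000):
    print(f"T = {T:>9,}: {T * T:.1e} entries, {score_matrix_gib(T):,.2f} GiB")
```

Multiply by the number of heads and layers (and batch size) to see why materializing these matrices stops being an option long before T reaches a million.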

Mitigations:

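  • FlashAttention (Dao et al. 2022; later FlashAttention-2 and -3): reorders the computation so the full T × T matrix is never materialized, computing exact attention in tiles held in fast on-chip GPU memory (SRAM). Memory becomes linear in T, though compute is still quadratic. Now standard in most training and inference stacks.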
  • Sparse attention: each token only attends to a subset (window, strided, learned).
  • Linear attention: replaces softmax with a kernelized approximation, reducing to O(T).
  • Mamba / SSMs: avoid attention entirely for long sequences.
  • MQA/GQA: reduce KV memory cost by sharing keys/values across heads (next article).
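The core trick that lets attention run tile by tile is an "online" softmax. It keeps a running row maximum and denominator, and rescales earlier partial sums whenever the maximum grows. Below is a NumPy sketch of that math only, not FlashAttention itself: the real kernels also tile Q, fuse everything into a single GPU kernel, and manage SRAM explicitly. The function names and the block size are illustrative assumptions.

```python
import numpy as np

def naive_attention(Q, K, V):
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=4):
    """Visit K/V in blocks with a running softmax; never builds the full T x T matrix."""
    T_q, d_k = Q.shape
    m = np.full(T_q, -np.inf)             # running max of each score row
    l = np.zeros(T_q)                     # running softmax denominator
    acc = np.zeros((T_q, V.shape[-1]))    # running un-normalized sum of weighted values
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Q @ Kb.T / np.sqrt(d_k)               # (T_q, block): scores for this tile only
        m_new = np.maximum(m, s.max(axis=-1))
        correction = np.exp(m - m_new)            # rescale earlier sums to the new max
        p = np.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ Vb
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
T, d_k = 10, 8
Q, K, V = (rng.standard_normal((T, d_k)) for _ in range(3))
print(np.allclose(naive_attention(Q, K, V), tiled_attention(Q, K, V, block=4)))  # True
```

Only a (T_q × block) tile of scores exists at any moment. The total arithmetic is still quadratic in T, which is why FlashAttention saves memory and memory traffic rather than FLOPs.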

For most practical models in 2026, FlashAttention is the default, and the quadratic cost stays manageable up to roughly 128k tokens or more. Beyond that, models increasingly turn to hybrid architectures that combine attention with the alternatives above.

KV caching

At inference time, when generating token t+1, the K and V for tokens 1..t were already computed. Cache them. Each new token then computes only its own q, k, and v: its query is scored against the cached keys (plus its own key), and its new k and v are appended to the cache.

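Without a KV cache, each new token means recomputing attention over the whole prefix: O(T²) per generated token. With the cache, only the new query is scored against T cached keys: O(T) per token. For long sequences, this is the difference between feasible and not.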

The KV cache is also the dominant memory consumer at inference for large models — see Stage 13.
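Here is a NumPy sketch of decoding with a KV cache. It checks that feeding tokens one at a time, with each new query attending over the cached keys and values, reproduces full causal attention computed all at once. The weights are random stand-ins. The cache is a plain Python list for clarity, whereas real implementations typically preallocate fixed-size tensors.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, d_k = 6, 16, 8
X = rng.standard_normal((T, d_model))       # token embeddings, fed in one at a time
W_Q, W_K, W_V = (rng.standard_normal((d_k, d_model)) for _ in range(3))

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Reference: full causal attention over the whole sequence at once
Q, K, V = X @ W_Q.T, X @ W_K.T, X @ W_V.T
scores = Q @ K.T / np.sqrt(d_k)
scores[np.triu(np.ones((T, T), dtype=bool), k=1)] = -np.inf
full = softmax(scores) @ V

# Incremental decoding: one new token per step, reusing cached K and V
K_cache, V_cache, outputs = [], [], []
for t in range(T):
    x_t = X[t]
    q_t = W_Q @ x_t                          # only the new token's query
    K_cache.append(W_K @ x_t)                # append the new key and value
    V_cache.append(W_V @ x_t)
    K_c, V_c = np.stack(K_cache), np.stack(V_cache)   # (t+1, d_k)
    w = softmax(K_c @ q_t / np.sqrt(d_k))    # (t+1,) weights: O(t) work this step
    outputs.append(w @ V_c)

print(np.allclose(np.stack(outputs), full))  # True
```

No mask is needed in the cached loop, because the cache only ever holds the current and past tokens.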

Common confusions

  • Q, K, V are not three different things stored per token. They’re three projections of the same embedding.
  • Attention is not “softmax weighted average of words.” It’s a softmax weighted average of value projections of words — which can be very different from the words themselves.
  • The d_k in the scaling factor is the per-head dimension, not the model dimension. Don't divide by √d_model when the projections are smaller.

Exercises

  1. Implement attention from scratch in <30 lines of PyTorch on a 4-token toy example. Print the attention matrix.
  2. Causal mask check. Verify that a decoder-only attention output for token 3 doesn’t depend on tokens 4+.
  3. What happens without √d_k. Train a small transformer with and without scaling. Watch the loss curve.
  4. Visualize learned attention. Train a tiny model, then plot the attention weights for one head on one input. Notice the patterns.

Watch it interactively

Three demos that exercise the equation above on real models:

  • Attention Inspector — real GPT-2 small attention tensors. Pick a sentence, slide layer (0–11) and head (0–11), click any token to see its top-5 keys. Predict before clicking: at layer 0 attention is mostly diagonal (positional); by layer 11 you’ll see semantic patterns — e.g., “it” attending to “trophy” via coreference.
  • Head Gallery — all 12 heads of one layer side-by-side. Watch how heads specialize across the layer.
  • Linear Algebra Lab — the dot product q·k made geometric. Drag two vectors, see how alignment becomes attention score.
