step 05 · build
Scaled dot-product attention, from scratch
The keystone of every transformer, in 30 lines of PyTorch.
Self-attention is the operation every transformer is built around. After this step you’ll have a working attention.py that computes the same operation GPT-2 runs — exactly the kind of matrix the Attention Inspector renders as a heatmap, but written by you.
Open the inspector in another tab right now if you can. We’re going to build the operation that fills those cells.
What we’re doing
For each token in a sequence, self-attention asks: given everything I’ve seen so far, which earlier tokens should I pay attention to? The output for that token is a weighted average of the values of all tokens it attends to.
Three projections do the work. Each is a learned linear layer:
- Query: Q = X · W_Q — “what am I looking for?”
- Key: K = X · W_K — “what do I offer?”
- Value: V = X · W_V — “if you pick me, here’s the info I bring.”
The famous formula:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
If you want the full derivation, it’s in the Self-Attention (KQV) article. The compressed version: take the dot product of every query with every key, normalize by √d_k to keep gradients stable, softmax to turn scores into probabilities, then use those as weights on the values.
That’s the whole operation. The next 30 lines are us writing it.
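If it helps to see the formula as raw tensor ops before we wrap it in a module, here’s a throwaway sketch. The attention function name and the toy shapes are ours for illustration; it is not part of the file we’re about to write.

# Throwaway sketch of the formula — not part of tiny_llm/attention.py.
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    # Q, K, V: (seq_len, d_k). Score every query against every key,
    # scale, softmax, then take a weighted average of the values.
    # No causal mask here — that's the piece the class below adds.
    scores = Q @ K.transpose(-2, -1) / math.sqrt(K.shape[-1])
    weights = F.softmax(scores, dim=-1)   # each row sums to 1
    return weights @ V

Q = torch.randn(4, 8); K = torch.randn(4, 8); V = torch.randn(4, 8)
print(attention(Q, K, V).shape)  # torch.Size([4, 8])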
Setup
We’ll work in a new file. Create tiny_llm/attention.py:
# tiny_llm/attention.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
That’s the entire import block. PyTorch’s nn.Linear will give us the projection layers; everything else we build with raw tensor ops.
The class skeleton
We’re writing single-head attention first. Multi-head comes in step 06; building the single-head version makes the multi-head extension obvious.
# tiny_llm/attention.py
class CausalSelfAttention(nn.Module):
    """Single-head causal self-attention.

    For every token at position t, computes a weighted sum of the values
    of all tokens at positions 0..t (the "causal" part — no peeking
    ahead). Weights come from the softmax of (Q · K^T / sqrt(d_k)).
    """

    def __init__(self, d_model: int, max_seq_len: int):
        super().__init__()
        self.d_model = d_model
        # Three independent linear projections. Each maps (..., d_model)
        # to (..., d_model). They share input but are *not* tied weights.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        # Final output projection — folds attention back to the residual stream.
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        # Pre-compute the causal mask once, register as a buffer so it
        # moves with the module across devices but isn't a parameter.
        mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
        self.register_buffer("causal_mask", mask)
Three things to call out before we write forward:
- bias=False on the projections. The standard configuration. Biases here add parameters but rarely change the result; LLaMA, GPT-NeoX, Qwen, and most modern decoder-only models drop them. We follow.
- W_o exists but doesn’t appear in the formula. It’s the output projection — applied after the attention computation to mix the per-head outputs back into the residual stream. With one head it’s nearly redundant; with multi-head (step 06) it’s essential. We add it now so we don’t restructure later.
- The causal mask. A lower-triangular boolean matrix where mask[i, j] = True iff j ≤ i. Position i may attend to positions 0..i and nowhere else. We pre-compute it at init because it never changes.
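To make the mask concrete, here is what torch.tril produces for a toy max_seq_len of 4 — a quick sketch you can paste into a REPL, with the output shown as comments:

# What the causal mask looks like for max_seq_len = 4.
import torch

mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
print(mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])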
The forward pass
Now the operation itself. I’ll write it all at once so you can read it as a unit, then walk through what each line does:
# tiny_llm/attention.py (continuing the class)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, seq_len, d_model) — residual stream
        returns (batch, seq_len, d_model)
        """
        B, T, D = x.shape

        # 1. Project to queries, keys, values.
        q = self.W_q(x)  # (B, T, D)
        k = self.W_k(x)  # (B, T, D)
        v = self.W_v(x)  # (B, T, D)

        # 2. Score: how much does each query attend to each key?
        #    (B, T, D) @ (B, D, T) -> (B, T, T)
        scores = q @ k.transpose(-2, -1)

        # 3. Scale by sqrt(d_k). This is the *scaled* in scaled
        #    dot-product attention. Without it, scores grow with d_k and
        #    softmax saturates, killing gradients.
        scores = scores / math.sqrt(D)

        # 4. Apply the causal mask — set future positions to -inf so
        #    softmax assigns them zero probability.
        mask = self.causal_mask[:T, :T]  # (T, T) -> top-left slice
        scores = scores.masked_fill(~mask, float("-inf"))

        # 5. Softmax along the key axis: each row (one query) sums to 1.
        weights = F.softmax(scores, dim=-1)  # (B, T, T)

        # 6. Combine — weighted sum of values.
        out = weights @ v  # (B, T, D)

        # 7. Output projection.
        return self.W_o(out)
Seven steps, one per line of math. The whole class is now ~40 lines. If you’ve never seen q @ k.transpose(-2, -1) before: the @ operator is matrix multiplication, and .transpose(-2, -1) swaps the last two dimensions. With shapes (B, T, D) and (B, D, T), the result is (B, T, T) — a square matrix per batch element where entry [i, j] is the dot product of query i with key j.
That’s it. You’ve written self-attention.
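One optional experiment before the sanity check. Step 3’s comment claims that unscaled scores grow with d_k and saturate the softmax; here is a quick way to see it. The sizes below are arbitrary, picked only for illustration.

# Dot products of random unit-variance vectors have variance ~ d_k,
# so unscaled scores spread out with the hidden size and softmax goes one-hot.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 512
q = torch.randn(1000, d_k)
k = torch.randn(1000, d_k)
raw = (q * k).sum(dim=-1)            # one dot product per row
print(raw.std())                     # ≈ sqrt(d_k) ≈ 22.6
print((raw / math.sqrt(d_k)).std())  # ≈ 1.0

# With scores this spread out, softmax puts nearly all mass on one key:
scores = torch.randn(8) * math.sqrt(d_k)
print(F.softmax(scores, dim=-1).max())  # usually very close to 1.0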
Sanity check
Add a small __main__ block at the bottom of the file:
# tiny_llm/attention.py (bottom of file)
if __name__ == "__main__":
    torch.manual_seed(0)
    attn = CausalSelfAttention(d_model=8, max_seq_len=4)

    # Two sequences, four tokens each, 8-d hidden.
    x = torch.randn(2, 4, 8)
    out = attn(x)
    print(f"input shape: {x.shape}")
    print(f"output shape: {out.shape}")

    # Sanity: the attention weights for one example should be a row-
    # stochastic lower-triangular matrix.
    with torch.no_grad():
        q = attn.W_q(x)
        k = attn.W_k(x)
        scores = q @ k.transpose(-2, -1) / math.sqrt(8)
        scores = scores.masked_fill(~attn.causal_mask[:4, :4], float("-inf"))
        weights = F.softmax(scores, dim=-1)

    print("\nrow sums of attention weights (should all be 1.0):")
    print(weights[0].sum(dim=-1))
    print("\nattention weights for first sequence (lower-triangular):")
    print(weights[0])
Run it:
uv run python -m tiny_llm.attention
Expected output:
input shape: torch.Size([2, 4, 8])
output shape: torch.Size([2, 4, 8])
row sums of attention weights (should all be 1.0):
tensor([1., 1., 1., 1.])
attention weights for first sequence (lower-triangular):
tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4123, 0.5877, 0.0000, 0.0000],
        [0.4321, 0.2187, 0.3492, 0.0000],
        [0.2845, 0.1822, 0.3133, 0.2200]])
Two things to notice:
- The upper triangle is exactly zero. That’s the causal mask working: token at position 0 sees only itself, position 1 sees positions 0–1, etc. Future positions are hard-masked before softmax.
- Row 0 has weight 1.0 on column 0. A position that can only attend to itself softmaxes to a one-hot. As we go down the rows, the model has more positions to distribute weight across, and the values reflect the model’s (random, untrained) preferences.
The exact decimal values depend on the seed and the random init. What matters is the shape: lower-triangular with row sums of 1.
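One detail worth convincing yourself of: masking with -inf before the softmax gives exact zeros, not merely tiny numbers, because exp(-inf) is 0. A one-liner to check:

import torch
import torch.nn.functional as F

print(F.softmax(torch.tensor([0.5, float("-inf")]), dim=-1))
# tensor([1., 0.])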
Tying it back to the demo
Now open the Attention Inspector in another tab. The heatmap you see is exactly what we just computed — the weights tensor — but for a real GPT-2 model on a real sentence, instead of random init on random tokens.
Three things you can verify by clicking around:
- The dark upper triangle. Same lower-triangular structure as our sanity-check output. GPT-2 is causal; ours is causal.
- Row sums to 1. The visualization normalizes per row; intensity within a row is comparable. Same as our softmax(dim=-1).
- Layer-0 attention is diffuse, layer-11 is sharp. That’s not an artifact of our code — it’s what trained attention learns. By step 09 we’ll have trained our own version, and ours will start diffuse and sharpen the same way.
What we did and didn’t do
What we did:
- Implemented scaled dot-product attention end-to-end
- Added a causal mask (decoder-only, no peeking)
- Validated row-stochastic weights and lower-triangular structure
- Confirmed our shapes match what comes out of the demo
What we didn’t:
- Multiple heads. We have one head. Real transformers use 8, 12, 16. That’s step 06.
- Dropout. Production attention has dropout on the weights. Ours doesn’t. We’ll add it during training; a sketch of where it goes follows this list.
- KV caching for inference. During generation, we don’t need to recompute K and V for tokens we’ve already seen. That’s step 14. The KV Cache demo shows what we’ll be optimizing.
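For the dropout bullet above, a minimal sketch of where it will slot in: on the attention weights, immediately after the softmax. The p=0.1 value below is a placeholder, not a tuned choice.

# Sketch only — where attention dropout goes, not part of the file yet.
import torch
import torch.nn as nn
import torch.nn.functional as F

drop = nn.Dropout(p=0.1)  # placeholder rate
weights = F.softmax(torch.randn(4, 4), dim=-1)
weights = drop(weights)   # during training: zeroes ~10% of entries, rescales the rest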
Next
Step 06 takes the single-head class we just wrote and runs n_heads copies of it in parallel, then concatenates. The clever trick: instead of n_heads separate nn.Linear layers, we use one big linear and reshape. Same math, much faster. You’ll see why when we look at the Head Specialization Gallery — twelve heads, all running on the same input, each learning a different pattern.