Stage 06 — Transformers: Solutions
Worked solutions for Stage 6.
Dependencies: torch, matplotlib.
Implement attention from scratch on a 4-token toy
In under 30 lines of PyTorch, compute single-head self-attention on a 4-token input and print the attention matrix.
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
T, d_model, d_k = 4, 8, 8
x = torch.randn(T, d_model)
W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)
Q = W_Q(x) # (4, 8)
K = W_K(x) # (4, 8)
V = W_V(x) # (4, 8)
scores = Q @ K.transpose(0, 1) / d_k**0.5 # (4, 4)
attn = F.softmax(scores, dim=-1)
out = attn @ V # (4, 8)
print("attention matrix (rows = queries, cols = keys):")
print(attn.detach().round(decimals=3))
print("each row sums to:", attn.sum(dim=-1).round(decimals=3))
Each row is a probability distribution (sums to 1). With random weights and no training, the matrix is mostly noise — but the structure holds.
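As a sanity check, the same result can be reproduced with PyTorch's built-in fused kernel, which computes the identical softmax(QKᵀ/√d_k)V (this assumes PyTorch ≥ 2.0, where F.scaled_dot_product_attention was added):
# Cross-check against PyTorch's fused implementation (PyTorch >= 2.0).
# It expects a leading batch dimension, so add one and remove it after.
out_builtin = F.scaled_dot_product_attention(
    Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0)
).squeeze(0)
print("matches built-in:", torch.allclose(out, out_builtin, atol=1e-6))  # True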
Causal mask check
Verify that in decoder-only (causal) attention, the output at position 3 doesn't depend on positions 4 and beyond.
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
T, d = 6, 16
class CausalAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(d, d, bias=False)
        self.W_K = nn.Linear(d, d, bias=False)
        self.W_V = nn.Linear(d, d, bias=False)

    def forward(self, x):
        Q, K, V = self.W_Q(x), self.W_K(x), self.W_V(x)
        scores = Q @ K.transpose(0, 1) / d**0.5
        # Build the mask from the input's length (so prefixes work too);
        # the upper triangle (j > i) is blocked: no attending to the future.
        t = x.size(0)
        mask = torch.triu(torch.ones(t, t), diagonal=1).bool()
        scores = scores.masked_fill(mask, float("-inf"))
        return F.softmax(scores, dim=-1) @ V
attn = CausalAttn()
x = torch.randn(T, d)
out_orig = attn(x)
# Modify only token 4 onwards (positions 4 and 5)
x_mod = x.clone()
x_mod[4:] = torch.randn(2, d) * 100
out_mod = attn(x_mod)
# Check: positions 0..3 should be unchanged
diff = (out_orig[:4] - out_mod[:4]).abs().max().item()
print(f"max diff in positions 0-3: {diff:.6f}") # 0.000000
# Positions 4-5 should change
diff_late = (out_orig[4:] - out_mod[4:]).abs().max().item()
print(f"max diff in positions 4-5: {diff_late:.6f}") # large
Causal attention is verified: future tokens don’t leak into past representations. This is what enables autoregressive generation — train on the whole sequence in parallel, but each position only “sees” the past.
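To see the parallel-training claim concretely, feed just a prefix and compare: since the mask is built from the input's length, the last row of a length-3 prefix's output should match row 2 of the full-sequence output. A quick sketch reusing attn, x, and out_orig from above:
# Position 2 attends only to positions 0..2, so a length-3 prefix
# reproduces its output exactly.
with torch.no_grad():
    out_prefix = attn(x[:3])
print("prefix matches full sequence:",
      torch.allclose(out_prefix[-1], out_orig[2], atol=1e-6))  # True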
What happens without √d_k
Train a small transformer with and without the √d_k scaling and compare the loss curves.
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
class Block(nn.Module):
    def __init__(self, d_model=64, d_k=64, scale=True):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
        self.W_O = nn.Linear(d_k, d_model, bias=False)
        self.scale = scale
        self.d_k = d_k
        self.norm = nn.LayerNorm(d_model)  # shared by both sublayers to keep the toy small
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        h = self.norm(x)
        Q, K, V = self.W_Q(h), self.W_K(h), self.W_V(h)
        scores = Q @ K.transpose(-2, -1)
        if self.scale:
            scores = scores / self.d_k**0.5  # the sqrt(d_k) under test
        attn = F.softmax(scores, dim=-1)
        x = x + self.W_O(attn @ V)
        x = x + self.mlp(self.norm(x))
        return x
# Toy task: predict the next token from 32-token sequences over a 64-token vocabulary
def train(scale, seed=0):
    torch.manual_seed(seed)
    vocab, T, d = 64, 32, 64
    embed = nn.Embedding(vocab, d)
    blocks = nn.ModuleList([Block(d, d, scale=scale) for _ in range(2)])
    head = nn.Linear(d, vocab)
    params = list(embed.parameters()) + list(blocks.parameters()) + list(head.parameters())
    opt = torch.optim.AdamW(params, lr=3e-4)
    losses = []
    for step in range(200):
        x = torch.randint(0, vocab, (4, T))
        h = embed(x)
        for b in blocks:
            h = b(h)
        logits = head(h[:, :-1])
        target = x[:, 1:]
        loss = F.cross_entropy(logits.reshape(-1, vocab), target.reshape(-1))
        opt.zero_grad(); loss.backward(); opt.step()
        losses.append(loss.item())
    return losses
scaled = train(scale=True)
unscaled = train(scale=False)
plt.plot(scaled, label="with √d_k")
plt.plot(unscaled, label="without √d_k")
plt.legend(); plt.title("attention scaling matters")
plt.savefig("attn_scale.png")
The scaled version trains smoothly. The unscaled version often diverges or hits NaN — softmax saturates on huge logits, gradients vanish or explode.
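The underlying arithmetic: for queries and keys with unit-variance entries, q·k is a sum of d_k independent products, so its variance is d_k and its standard deviation is √d_k. Dividing by √d_k restores unit-scale logits. A quick standalone check:
import torch

torch.manual_seed(0)
for d_k in (16, 64, 256):
    q = torch.randn(10_000, d_k)
    k = torch.randn(10_000, d_k)
    dots = (q * k).sum(-1)  # 10,000 random query-key dot products
    print(f"d_k={d_k:4d}  std(q.k)={dots.std():7.2f}  sqrt(d_k)={d_k**0.5:6.2f}")
At d_k = 64 the unscaled logits have a standard deviation around 8; after the softmax, nearly all the probability mass collapses onto one key, and the gradient through the softmax is close to zero.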
See also
- Stage 6 — README
- Stage 6 — GPT from scratch — full GPT implementation already in the article
- Stage 3 solutions