Stage 06 — Transformers: Solutions

Worked solutions for Stage 6.

Dependencies: torch, matplotlib.

Implement attention from scratch on a 4-token toy

Implement scaled dot-product attention in under 30 lines of PyTorch on a 4-token input, and print the attention matrix.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

T, d_model, d_k = 4, 8, 8
x = torch.randn(T, d_model)

W_Q = nn.Linear(d_model, d_k, bias=False)
W_K = nn.Linear(d_model, d_k, bias=False)
W_V = nn.Linear(d_model, d_k, bias=False)

Q = W_Q(x)                                  # (4, 8)
K = W_K(x)                                  # (4, 8)
V = W_V(x)                                  # (4, 8)

# similarity of every query with every key, scaled by √d_k to keep logit variance near 1
scores = Q @ K.transpose(0, 1) / d_k**0.5   # (4, 4)
attn = F.softmax(scores, dim=-1)
out = attn @ V                              # (4, 8)

print("attention matrix (rows = queries, cols = keys):")
print(attn.detach().round(decimals=3))
print("each row sums to:", attn.sum(dim=-1).round(decimals=3))

Each row is a probability distribution (sums to 1). With random weights and no training, the matrix is mostly noise — but the structure holds.
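
As a sanity check, the same computation can be compared against PyTorch's built-in kernel. A minimal sketch, assuming PyTorch 2.0+ (which provides F.scaled_dot_product_attention); Q, K, V here are fresh random tensors rather than the projections above:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_k = 4, 8
Q, K, V = torch.randn(T, d_k), torch.randn(T, d_k), torch.randn(T, d_k)

# from-scratch scaled dot-product attention
manual = F.softmax(Q @ K.transpose(0, 1) / d_k**0.5, dim=-1) @ V

# built-in version (expects batch/head dims, hence the unsqueezing)
builtin = F.scaled_dot_product_attention(Q[None, None], K[None, None], V[None, None])[0, 0]

print("max abs difference:", (manual - builtin).abs().max().item())  # ~1e-7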

Causal mask check

Verify that, with a causal (decoder-only) mask, the attention outputs at positions 0-3 don't depend on the tokens at positions 4 and 5.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 6, 16

class CausalAttn(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(d, d, bias=False)
        self.W_K = nn.Linear(d, d, bias=False)
        self.W_V = nn.Linear(d, d, bias=False)

    def forward(self, x):
        Q, K, V = self.W_Q(x), self.W_K(x), self.W_V(x)
        scores = Q @ K.transpose(0, 1) / d**0.5
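        # mask the upper triangle (keys j > query i) so no position attends to the future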
        mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
        scores = scores.masked_fill(mask, float("-inf"))
        return F.softmax(scores, dim=-1) @ V

attn = CausalAttn()

x = torch.randn(T, d)
out_orig = attn(x)

# Modify only token 4 onwards (positions 4 and 5)
x_mod = x.clone()
x_mod[4:] = torch.randn(2, d) * 100

out_mod = attn(x_mod)

# Check: positions 0..3 should be unchanged
diff = (out_orig[:4] - out_mod[:4]).abs().max().item()
print(f"max diff in positions 0-3: {diff:.6f}")    # 0.000000

# Positions 4-5 should change
diff_late = (out_orig[4:] - out_mod[4:]).abs().max().item()
print(f"max diff in positions 4-5: {diff_late:.6f}")  # large

Causal attention is verified: future tokens don’t leak into past representations. This is what enables autoregressive generation — train on the whole sequence in parallel, but each position only “sees” the past.
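
To see the structure the check relies on, here is a minimal standalone sketch (random Q and K, not the CausalAttn module above) that prints a causally masked attention matrix; every entry above the diagonal is zero, so row i is a distribution over positions 0..i only:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d = 6, 16
Q, K = torch.randn(T, d), torch.randn(T, d)

scores = Q @ K.transpose(0, 1) / d**0.5
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
attn = F.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

print(attn.round(decimals=2))   # lower-triangular: zeros everywhere above the diagonal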

What happens without √d_k

Train a small transformer with and without √d_k scaling and compare the loss curves.

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

class Block(nn.Module):
    def __init__(self, d_model=64, d_k=64, scale=True):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_k, bias=False)
        self.W_O = nn.Linear(d_k, d_model, bias=False)
        self.scale = scale
        self.d_k = d_k
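        # note: a single LayerNorm is reused for both sub-layers here; separate norms per sub-layer are more common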
        self.norm = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, 4*d_model), nn.GELU(), nn.Linear(4*d_model, d_model))

    def forward(self, x):
        h = self.norm(x)
        Q, K, V = self.W_Q(h), self.W_K(h), self.W_V(h)
        scores = Q @ K.transpose(-2, -1)
        if self.scale: scores = scores / self.d_k**0.5
        attn = F.softmax(scores, dim=-1)
        x = x + self.W_O(attn @ V)
        x = x + self.mlp(self.norm(x))
        return x

# Toy task: predict the next token from random 32-token sequences over a 64-token vocab
# (attention here is unmasked; the point is optimization stability, not language modeling)
def train(scale, seed=0):
    torch.manual_seed(seed)
    vocab, T, d = 64, 32, 64
    embed = nn.Embedding(vocab, d)
    blocks = nn.ModuleList([Block(d, d, scale=scale) for _ in range(2)])
    head = nn.Linear(d, vocab)
    params = list(embed.parameters()) + list(blocks.parameters()) + list(head.parameters())
    opt = torch.optim.AdamW(params, lr=3e-4)

    losses = []
    for step in range(200):
        x = torch.randint(0, vocab, (4, T))
        h = embed(x)
        for b in blocks: h = b(h)
        logits = head(h[:, :-1])
        target = x[:, 1:]
        loss = F.cross_entropy(logits.reshape(-1, vocab), target.reshape(-1))
        opt.zero_grad(); loss.backward(); opt.step()
        losses.append(loss.item())
    return losses

scaled = train(scale=True)
unscaled = train(scale=False)

plt.plot(scaled, label="with √d_k")
plt.plot(unscaled, label="without √d_k")
plt.legend(); plt.title("attention scaling matters")
plt.savefig("attn_scale.png")

The scaled version trains smoothly. The unscaled version often diverges or hits NaN — softmax saturates on huge logits, gradients vanish or explode.
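
The why can be seen in isolation. A small sketch (independent of the training code above): dot products of d_k unit-variance components have standard deviation around √d_k, and a softmax over logits of that magnitude is close to one-hot, which is where the vanishing gradients come from:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k, n = 64, 10000

# the variance of q·k grows linearly with d_k, so its std is ~√d_k
q, k = torch.randn(n, d_k), torch.randn(n, d_k)
dots = (q * k).sum(dim=-1)
print("std of q·k:        ", round(dots.std().item(), 2))               # ~8 (= √64)
print("std of q·k / √d_k: ", round((dots / d_k**0.5).std().item(), 2))  # ~1

# softmax over logits that large is close to one-hot; scaling keeps it soft
logits = torch.randn(16) * d_k**0.5
print("max prob without scaling:", round(F.softmax(logits, dim=-1).max().item(), 3))
print("max prob with scaling:   ", round(F.softmax(logits / d_k**0.5, dim=-1).max().item(), 3))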

See also