The Transformer Block
The transformer block is the unit you stack. Modern LLMs have 30–100+ of them. Once you understand one, you understand the whole architecture.
The pre-norm block (modern default)
x = x + Attention(LayerNorm(x))
x = x + MLP(LayerNorm(x))
Two sub-layers, each wrapped in:
- LayerNorm (or RMSNorm)
- A residual connection (the `+ x`)
Order matters. Pre-norm (LN before sub-layer) is the standard since around 2020 — it’s more stable for deep transformers. The original 2017 paper used post-norm, which works for shallow nets but is harder to train at depth.
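Schematically, the two orderings differ only in where the norm sits (a sketch; `attention`, `mlp`, and `layer_norm` stand in for the real modules):

```
# Pre-norm: normalize the input to each sub-layer, add the result to the residual
x = x + attention(layer_norm(x))
x = x + mlp(layer_norm(x))

# Post-norm (original 2017): apply the sub-layer first, normalize after the residual add
x = layer_norm(x + attention(x))
x = layer_norm(x + mlp(x))
```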
In code
```python
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)  # multi-head attention module defined earlier
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, d_model),
        )

    def forward(self, x, mask=None):
        # Pre-norm: normalize, run the sub-layer, add the result back to the residual stream
        x = x + self.attn(self.norm1(x), mask=mask)
        x = x + self.mlp(self.norm2(x))
        return x
```
About 20 lines. That’s the whole block.
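A quick smoke test, assuming the MultiHeadAttention module from earlier accepts this constructor and an optional mask:

```python
import torch

block = TransformerBlock(d_model=512, num_heads=8, mlp_dim=2048)
x = torch.randn(2, 16, 512)        # (batch, seq_len, d_model)
out = block(x)                     # shape is preserved: (2, 16, 512)
assert out.shape == x.shape
```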
The MLP / FFN
The feed-forward network inside each block. Typical:
hidden = Linear(d_model, 4 · d_model)
→ activation (GELU or SiLU)
→ Linear(4 · d_model, d_model)
The 4× expansion is conventional. That hidden layer is where most of the parameter count lives: the MLPs account for roughly ⅔ of a transformer's parameters, more than the attention layers.
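The ⅔ figure is just counting: per block, attention has four d_model × d_model projections, while a 4× MLP has an up-projection and a down-projection of size d_model × 4·d_model each. A quick check (the d_model value here is an arbitrary example):

```python
d = 4096                          # example hidden size
attn_params = 4 * d * d           # Wq, Wk, Wv, Wo
mlp_params = 2 * d * (4 * d)      # up-projection + down-projection
print(mlp_params / (attn_params + mlp_params))   # 8d^2 / 12d^2 ≈ 0.667
```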
Modern variants:
- GeGLU / SwiGLU: gated MLPs with two parallel linear layers, multiplied. LLaMA-style models use SwiGLU.
- MoE FFN: sparse mixture of experts replaces the FFN — see Stage 07.
```python
# SwiGLU
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, d_model, hidden):
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden, bias=False)   # gate projection
        self.w2 = nn.Linear(d_model, hidden, bias=False)   # value projection
        self.w3 = nn.Linear(hidden, d_model, bias=False)   # down projection

    def forward(self, x):
        # Gated MLP: silu(W1·x) multiplied elementwise with (W2·x), then projected back to d_model
        return self.w3(F.silu(self.w1(x)) * self.w2(x))
```
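Because a gated MLP carries three weight matrices instead of two, implementations usually shrink the hidden width to roughly ⅔ of the usual 4 × d_model, often rounded up to a hardware-friendly multiple, so the parameter count stays comparable. A sketch of that sizing rule; the rounding multiple is an assumption, not a universal constant:

```python
def swiglu_hidden_dim(d_model, multiple_of=256):
    hidden = int(2 * (4 * d_model) / 3)              # ~2/3 of the conventional 4x width
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)
```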
Residual stream
The series of x = x + sublayer(...) connections forms a residual stream that runs from input to output. Every sublayer reads from and writes to this stream.
This perspective (Anthropic’s transformer circuits work) is illuminating:
- Each layer adds something to the residual stream.
- The stream maintains the “current understanding” of each token.
- Specific heads/MLPs can be analyzed as readers and writers on this stream.
Layer normalization
Stabilizes the residual stream by normalizing each token’s representation independently:
LN(x) = γ · (x − μ) / σ + β where μ, σ are over feature dim
γ and β are learned per-feature.
RMSNorm (used in LLaMA, Qwen, etc.) drops the mean centering:
RMSNorm(x) = γ · x / RMS(x) where RMS(x) = √(mean(x²) + ε)
Slightly faster, comparable quality. Modern open-source default.
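A minimal RMSNorm following the formula above (the ε value is a typical default, not the only choice):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))   # gamma, learned per feature

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms
```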
Putting it together: a full transformer
```python
import torch
import torch.nn as nn

class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, mlp_dim, num_layers, max_seq):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq, d_model)   # learned positions (or RoPE)
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, mlp_dim) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(d_model)                 # final norm before the head
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Often: self.lm_head.weight = self.token_embed.weight  (weight tying)

    def forward(self, tokens):
        B, T = tokens.shape
        positions = torch.arange(T, device=tokens.device)
        x = self.token_embed(tokens) + self.pos_embed(positions)
        # Causal mask: True above the diagonal marks positions a token may NOT attend to
        mask = torch.triu(torch.ones(T, T, device=tokens.device), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, mask=mask)
        x = self.norm(x)
        return self.lm_head(x)                            # logits: (B, T, vocab_size)
```
A complete decoder-only language model in <50 lines.
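A usage sketch: build a small model and compute the next-token cross-entropy loss (the hyperparameters below are arbitrary illustrations, not recommendations):

```python
import torch
import torch.nn.functional as F

model = GPT(vocab_size=50257, d_model=256, num_heads=8,
            mlp_dim=1024, num_layers=4, max_seq=128)
tokens = torch.randint(0, 50257, (2, 64))            # (batch, seq_len)

logits = model(tokens)                               # (2, 64, 50257)
# Next-token prediction: each position predicts the token one step to its right
loss = F.cross_entropy(logits[:, :-1].reshape(-1, 50257),
                       tokens[:, 1:].reshape(-1))
loss.backward()
```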
Encoder-only vs decoder-only vs encoder-decoder
The transformer block is the same; the differences are in masking and overall topology.
| Topology | Examples | Use |
|---|---|---|
| Encoder-only | BERT, RoBERTa, ELECTRA | Classification, embedding |
| Decoder-only | GPT-2/3/4, LLaMA, Claude, Gemini | Generation, chat |
| Encoder-decoder | T5, BART | Seq2seq (translation, summarization) |
Encoder-only: bidirectional attention (no causal mask). Trained with masked LM.
Decoder-only: causal attention. Trained with next-token prediction.
Encoder-decoder: encoder is bidirectional; decoder is causal and has cross-attention to encoder output.
In 2026, decoder-only dominates for general-purpose models — cleaner, scales better, generalizes better with prompting.
What grows with model size
When you “scale” a transformer, you tune:
- `d_model` (hidden dim) — wider
- `num_layers` — deeper
- `num_heads` — more attention paths
- `mlp_dim` — wider FFN (usually 4 × d_model)
- `vocab_size` — bigger tokenizer
- `context_length` — longer sequences
Total parameters ≈ 12 · n_layers · d_model² for a transformer with the standard ratios. (Approximately — there’s vocabulary, positional, and bias overhead.)
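A quick sanity check against GPT-2 small (12 layers, d_model = 768, vocab 50257):

```python
n_layers, d_model, vocab = 12, 768, 50257
core = 12 * n_layers * d_model ** 2     # ~85M parameters in the blocks
embed = vocab * d_model                 # ~39M in the token embedding
print(core + embed)                     # ~123.5M, close to the published ~124M
```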
Performance optimizations
For training and inference:
- FlashAttention: tile-based attention computation. ~2–4× speedup, much less memory.
- Mixed precision (bf16): 2× faster on modern GPUs.
- Activation checkpointing: trade compute for memory.
- Gradient accumulation: simulate larger batches.
- ZeRO / FSDP / TP: distributed training across GPUs.
- KV caching: at inference, cache past K/V to avoid recomputing (see the sketch after this list).
- Speculative decoding: a small draft model proposes tokens; the big model verifies in parallel. Faster inference.
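A toy sketch of the KV-caching idea mentioned above: a single-head decode loop where each step computes only the new token's key/value and reuses everything cached so far. Shapes and the `attend` helper are illustrative only, not any particular library's API:

```python
import torch

def attend(q, K, V):
    # q: (1, d) for the newest token; K, V: (t, d) cached so far
    scores = q @ K.T / K.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ V

d = 64
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
K_cache, V_cache = [], []

for step_input in torch.randn(10, 1, d):    # one new token embedding per decode step
    q = step_input @ Wq
    K_cache.append(step_input @ Wk)         # only the new K/V are computed this step
    V_cache.append(step_input @ Wv)
    out = attend(q, torch.cat(K_cache), torch.cat(V_cache))
```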
Variants worth knowing
- Parallel attention/MLP: compute both branches from the same normalized input simultaneously (see the sketch after this list). PaLM uses this.
- No bias terms: some modern implementations drop biases everywhere — slight speed/memory win, no quality cost.
- RMSNorm + SwiGLU + RoPE + GQA: this combination is the LLaMA-class default and has spread to most open models.
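For the parallel formulation, only the block's forward pass changes: both branches read one shared normalized input and their outputs are summed into the residual. A sketch built on the TransformerBlock above, not PaLM's exact code:

```python
class ParallelTransformerBlock(TransformerBlock):
    def forward(self, x, mask=None):
        h = self.norm1(x)                    # one shared norm feeds both branches
        return x + self.attn(h, mask=mask) + self.mlp(h)
```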
Watch it interactively
- Pipeline — real GPT-2 small forward pass. Watch one input flow through 12 transformer blocks: per-layer residual norms, attention, then logits. Predict before clicking: the residual stream’s norm grows roughly linearly across layers (each block adds, doesn’t replace).
- Layer Norm Lab — center, rescale, then γ/β. Drag γ to 0 and watch the layer collapse.
- KV Cache — toggle caching on/off and watch per-step compute go from O(N²) to O(N).
Build it in code
- `/build/07` — assemble the transformer block — pre-norm + attention + MLP + residual. ~80 lines.
- `/build/08` — wire up GPT — stack N blocks, train on TinyShakespeare, watch perplexity drop.