The Transformer Block

The transformer block is the unit you stack. Modern LLMs have 30–100+ of them. Once you understand one, you understand the whole architecture.

The pre-norm block (modern default)

x = x + Attention(LayerNorm(x))
x = x + MLP(LayerNorm(x))

Two sub-layers, each wrapped in:

  1. LayerNorm (or RMSNorm)
  2. A residual connection (+ x)

Order matters. Pre-norm (LN before the sub-layer) has been the standard since around 2020; it's more stable for deep transformers. The original 2017 paper used post-norm, which works for shallow stacks but is harder to train at depth.
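
For contrast, post-norm applies the normalization after the residual addition:

x = LayerNorm(x + Attention(x))
x = LayerNorm(x + MLP(x))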

In code

import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, d_model),
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask=mask)   # attention sub-layer + residual
        x = x + self.mlp(self.norm2(x))                # MLP sub-layer + residual
        return x

About 20 lines. That’s the whole block.
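
A quick shape check (illustrative sizes; assumes the MultiHeadAttention module used above is defined):

import torch

block = TransformerBlock(d_model=768, num_heads=12, mlp_dim=3072)
x = torch.randn(2, 10, 768)        # (batch, seq_len, d_model)
out = block(x)                     # same shape: the block reads and writes the residual stream
print(out.shape)                   # torch.Size([2, 10, 768])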

The MLP / FFN

The feed-forward network inside each block. Typical:

hidden = Linear(d_model, 4 · d_model)
        → activation (GELU or SiLU)
        → Linear(4 · d_model, d_model)

The 4× expansion is conventional. The hidden dimension is where most of the parameter count lives — typically more than the attention layers. The MLP is ~⅔ of a transformer’s parameters.
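
Where the ⅔ comes from: attention has four d_model × d_model projections (Q, K, V, output), roughly 4 · d_model² parameters per block, while the two MLP matrices give roughly 8 · d_model². A quick check (bias terms ignored):

d_model = 768                                    # e.g. a GPT-2-small-sized model
attn_params = 4 * d_model * d_model              # Q, K, V, output projections
mlp_params = 2 * d_model * (4 * d_model)         # up-projection + down-projection
print(mlp_params / (attn_params + mlp_params))   # 0.666... → the MLP is ~2/3 of the block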

Modern variants:

  • GeGLU / SwiGLU: gated MLPs with two parallel linear layers, multiplied. LLaMA-style models use SwiGLU.
  • MoE FFN: sparse mixture of experts replaces the FFN — see Stage 07.

import torch.nn.functional as F

# SwiGLU: two parallel projections, one gated by SiLU, multiplied element-wise
class SwiGLU(nn.Module):
    def __init__(self, d_model, hidden):
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden, bias=False)   # gate projection
        self.w2 = nn.Linear(d_model, hidden, bias=False)   # value projection
        self.w3 = nn.Linear(hidden, d_model, bias=False)   # output projection

    def forward(self, x):
        # element-wise gating: silu(w1·x) scales w2·x, then project back to d_model
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

Residual stream

The series of x = x + sublayer(...) connections forms a residual stream that runs from input to output. Every sublayer reads from and writes to this stream.

This perspective (Anthropic’s transformer circuits work) is illuminating:

  • Each layer adds something to the residual stream.
  • The stream maintains the “current understanding” of each token.
  • Specific heads/MLPs can be analyzed as readers and writers on this stream.

Layer normalization

Stabilizes the residual stream by normalizing each token’s representation independently:

LN(x) = γ · (x − μ) / √(σ² + ε) + β    where μ, σ² are the mean and variance over the feature dim

γ and β are learned per-feature.

RMSNorm (used in LLaMA, Qwen, etc.) drops the mean centering:

RMSNorm(x) = γ · x / RMS(x)    where RMS(x) = √(mean(x²) + ε)

Slightly faster, comparable quality. Modern open-source default.
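
A minimal RMSNorm sketch (the learned scale γ is stored as `weight` here; naming and the exact placement of ε vary across implementations):

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))   # γ, learned per-feature

    def forward(self, x):
        # scale by the root-mean-square over the feature dim; no mean subtraction, no β
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x / rms)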

Putting it together: a full transformer

import torch

class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, mlp_dim, num_layers, max_seq):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq, d_model)   # or RoPE
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, mlp_dim) for _ in range(num_layers)]
        )
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Often: self.lm_head.weight = self.token_embed.weight  (weight tying)

    def forward(self, tokens):
        B, T = tokens.shape
        positions = torch.arange(T, device=tokens.device)
        x = self.token_embed(tokens) + self.pos_embed(positions)
        # causal mask: True marks future positions each token must not attend to
        mask = torch.triu(torch.ones(T, T, device=tokens.device), diagonal=1).bool()
        for block in self.blocks:
            x = block(x, mask=mask)
        x = self.norm(x)
        return self.lm_head(x)

A complete decoder-only language model in <50 lines.
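
A smoke test with toy sizes (illustrative values; again assumes MultiHeadAttention is defined):

model = GPT(vocab_size=1000, d_model=128, num_heads=4, mlp_dim=512, num_layers=2, max_seq=64)
tokens = torch.randint(0, 1000, (2, 16))    # batch of 2 sequences, 16 tokens each
logits = model(tokens)                      # one next-token distribution per position
print(logits.shape)                         # torch.Size([2, 16, 1000])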

Encoder-only vs decoder-only vs encoder-decoder

The transformer block is the same; the differences are in masking and overall topology.

Topology          Examples                           Use
Encoder-only      BERT, RoBERTa, ELECTRA             Classification, embedding
Decoder-only      GPT-2/3/4, LLaMA, Claude, Gemini   Generation, chat
Encoder-decoder   T5, BART                           Seq2seq (translation, summarization)

Encoder-only: bidirectional attention (no causal mask). Trained with masked LM.

Decoder-only: causal attention. Trained with next-token prediction.

Encoder-decoder: encoder is bidirectional; decoder is causal and has cross-attention to encoder output.
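
The masking difference in one snippet (using the same boolean convention as the GPT code above, where True means a position is blocked):

import torch

T = 4
causal = torch.triu(torch.ones(T, T), diagonal=1).bool()   # decoder-only: future positions blocked
bidirectional = torch.zeros(T, T, dtype=torch.bool)        # encoder-only: every position visible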

In 2026, decoder-only dominates for general-purpose models — cleaner, scales better, generalizes better with prompting.

What grows with model size

When you “scale” a transformer, you tune:

  • d_model (hidden dim) — wider
  • num_layers — deeper
  • num_heads — more attention paths
  • mlp_dim — wider FFN (usually 4 × d_model)
  • vocab_size — bigger tokenizer
  • context_length — longer sequences

Total parameters ≈ 12 · n_layers · d_model² for a transformer with the standard ratios (on top of that there's vocabulary, positional, and bias overhead).
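
A back-of-the-envelope check against GPT-2 small (12 layers, d_model = 768, vocab ≈ 50k):

n_layers, d_model, vocab = 12, 768, 50257
block_params = 12 * n_layers * d_model ** 2   # ≈ 85M from the attention + MLP weights
embed_params = vocab * d_model                # ≈ 39M token embedding (tied with lm_head)
print(block_params + embed_params)            # ≈ 124M, close to GPT-2 small's actual size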

Performance optimizations

For training and inference:

  • FlashAttention: tile-based attention computation. ~2–4× speedup, much less memory.
  • Mixed precision (bf16): 2× faster on modern GPUs.
  • Activation checkpointing: trade compute for memory.
  • Gradient accumulation: simulate larger batches.
  • ZeRO / FSDP / TP: distributed training across GPUs.
  • KV caching: at inference, cache past K/V so each new token attends against stored keys/values instead of recomputing the whole prefix (see the sketch after this list).
  • Speculative decoding: a small draft model proposes tokens; the big model verifies in parallel. Faster inference.
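
A toy single-head illustration of the KV-cache idea (standalone sketch, not wired into the GPT class above):

import torch
import torch.nn.functional as F

d = 16
wq, wk, wv = (torch.randn(d, d) for _ in range(3))
cache_k = torch.empty(0, d)          # grows by one row per generated token
cache_v = torch.empty(0, d)

x = torch.randn(1, d)                # hidden state of the newest token only
for step in range(5):
    q, k, v = x @ wq, x @ wk, x @ wv
    cache_k = torch.cat([cache_k, k])            # reuse every previous key
    cache_v = torch.cat([cache_v, v])            # reuse every previous value
    scores = (q @ cache_k.T) / d ** 0.5          # 1 × (step+1) scores: O(N) work per step
    x = F.softmax(scores, dim=-1) @ cache_v      # instead of recomputing the whole prefix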

Variants worth knowing

  • Parallel attention/MLP: compute both branches from the same normalized input and add both to the residual (sketch after this list). PaLM uses this.
  • No bias terms: some modern implementations drop biases everywhere — slight speed/memory win, no quality cost.
  • RMSNorm + SwiGLU + RoPE + GQA: this combination is the LLaMA-class default and has spread to most open models.
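
A sketch of the parallel formulation (same assumptions as the block above; one shared norm feeds both branches):

class ParallelBlock(nn.Module):
    def __init__(self, d_model, num_heads, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, d_model)
        )

    def forward(self, x, mask=None):
        h = self.norm(x)
        # both branches read the same normalized input and add their outputs to the residual
        return x + self.attn(h, mask=mask) + self.mlp(h)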

Watch it interactively

  • Pipeline — real GPT-2 small forward pass. Watch one input flow through 12 transformer blocks: per-layer residual norms, attention, then logits. Predict before clicking: the residual stream’s norm grows roughly linearly across layers (each block adds, doesn’t replace).
  • Layer Norm Lab — center, rescale, then γ/β. Drag γ to 0 and watch the layer collapse.
  • KV Cache — toggle caching on/off and watch per-step compute go from O(N²) to O(N).

Build it in code

See also