tiny-llm 09 / 16 28 min read · 1h hands-on

step 09 · build

The training loop

AdamW with weight-decay groups, warmup + cosine LR schedule, gradient clipping, periodic eval, checkpoints.


This is the step where things stop being a model and start being a trained model. We’ve built every architectural piece. Now we wire AdamW to nudge the parameters until the loss goes down.

The training loop is shorter than you’d think — about 120 lines for the full thing including eval, checkpointing, and logging. The interesting bits aren’t the loop itself; they’re the configuration choices around it that make training stable instead of divergent. Specifically:

  1. AdamW with proper weight-decay grouping — apply decay only to the right parameters
  2. Learning rate schedule — linear warmup then cosine decay
  3. Gradient clipping — keep updates from exploding
  4. Periodic eval — measure validation loss to detect overfitting
  5. Checkpointing — save the best model so far

Get these five right and a 5M-param model on TinyStories trains to coherent output in 10–30 minutes on a CPU, or 2–5 minutes on a single GPU. Get any of them wrong and you’ll spend hours debugging why the loss is NaN.

What you’ll have at the end

A train.py you can run from the command line:

uv run python -m tiny_llm.train

It loads the data prepared in step 03, instantiates the model from step 08, runs the training loop, and writes checkpoints/best.pt every time validation loss improves (plus checkpoints/final.pt at the end). After training, the Animated Gradient Descent demo is no longer abstract: your model is the ball, and the loss curve you’ve watched is your own.

Setup

Add tqdm for the progress bar:

uv add tqdm

Open a new file:

# tiny_llm/train.py
from __future__ import annotations
import math
import time
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from tiny_llm.gpt import GPT, GPTConfig
from tiny_llm.data import load_token_array, get_batch, DATA_DIR

Training-time configuration

A second dataclass holds hyperparameters that don’t change the architecture but do shape training:

# tiny_llm/train.py
@dataclass
class TrainConfig:
    """Hyperparameters for one training run.

    Architecture lives in GPTConfig; this is everything else.
    """
    # Optimizer
    lr: float = 3e-4              # peak learning rate (after warmup)
    weight_decay: float = 0.1     # AdamW decay (applied selectively)
    betas: tuple = (0.9, 0.95)    # AdamW betas — (0.9, 0.95) is the GPT-3 / LLaMA convention
    grad_clip: float = 1.0        # max grad norm; clip if larger

    # Schedule
    max_steps: int = 5000
    warmup_steps: int = 200
    min_lr: float = 3e-5          # peak / 10 — typical convention

    # Batches
    batch_size: int = 32
    grad_accum_steps: int = 1     # effective batch = batch_size * this

    # Eval
    eval_interval: int = 250
    eval_iters: int = 50          # number of batches to average eval loss over

    # Checkpoints
    out_dir: Path = field(default_factory=lambda: Path("checkpoints"))

    # Hardware
    device: str = "auto"          # "cuda", "mps", "cpu", or "auto"
    seed: int = 42

The defaults target the 5M model on TinyStories. We’ll revisit a few of them in step 11 when we scale up.

Building the optimizer

This is the part tutorials most often get wrong. AdamW applies weight decay to every parameter by default, but you don’t actually want weight decay on biases, LayerNorm γ/β, or 1D parameters in general. Decay slowly pulls them toward zero, which is wrong for them: they’re calibration, not features.

Standard practice (used in nanoGPT, GPT-NeoX, and roughly every modern recipe):

# tiny_llm/train.py
def make_optimizer(model: GPT, train_cfg: TrainConfig) -> torch.optim.AdamW:
    """Build AdamW with separate weight-decay groups.

    Convention: 2D parameters (Linear weights, Embedding weights) get weight
    decay; 1D parameters (biases, LayerNorm γ/β) do not.
    """
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.dim() >= 2:
            decay.append(p)
        else:
            no_decay.append(p)

    print(f"  decay params:    {len(decay):4d} tensors, "
          f"{sum(p.numel() for p in decay):,} elements")
    print(f"  no-decay params: {len(no_decay):4d} tensors, "
          f"{sum(p.numel() for p in no_decay):,} elements")

    optimizer_groups = [
        {"params": decay,    "weight_decay": train_cfg.weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(
        optimizer_groups,
        lr=train_cfg.lr,
        betas=train_cfg.betas,
    )
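To see the 2D / 1D rule in action without building the full GPT, here’s a minimal sketch that applies the same split to a toy module (an Embedding, a Linear, and a LayerNorm) and reports which tensors land in each group. `split_decay_groups` is a hypothetical helper mirroring make_optimizer, not part of tiny_llm.

```python
import torch
import torch.nn as nn


def split_decay_groups(model: nn.Module, weight_decay: float):
    """Hypothetical helper using make_optimizer's rule.

    Returns (param_groups, decay_names, no_decay_names).
    """
    decay, no_decay = [], []
    decay_names, no_decay_names = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.dim() >= 2:          # matrices: Linear / Embedding weights
            decay.append(p)
            decay_names.append(name)
        else:                     # vectors: biases, LayerNorm gamma/beta
            no_decay.append(p)
            no_decay_names.append(name)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return groups, decay_names, no_decay_names


toy = nn.Sequential(
    nn.Embedding(10, 4),   # weight (10, 4)                 -> decay
    nn.Linear(4, 4),       # weight (4, 4) -> decay; bias (4,) -> no decay
    nn.LayerNorm(4),       # weight (4,), bias (4,)         -> no decay
)
groups, decay_names, no_decay_names = split_decay_groups(toy, weight_decay=0.1)
optimizer = torch.optim.AdamW(groups, lr=3e-4, betas=(0.9, 0.95))

print(decay_names)      # ['0.weight', '1.weight']
print(no_decay_names)   # ['1.bias', '2.weight', '2.bias']
print([g["weight_decay"] for g in optimizer.param_groups])  # [0.1, 0.0]
```

If your GPT ties the LM head to the token embedding, this split still counts the shared tensor once: named_parameters() skips duplicates by default.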

weight_decay = 0.1 is generous compared to PyTorch’s AdamW default of 0.01. Modern LLM recipes use larger values (GPT-3 used 0.1), and the stable warmup + cosine schedule helps the model tolerate it.

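The betas = (0.9, 0.95) choice is also a departure from PyTorch’s default (0.9, 0.999). Both work; (0.9, 0.95) is what GPT-3 used and what most modern decoder-only training scripts adopt. β₂ sets the memory of the second-moment (squared-gradient) estimate, roughly 1 / (1 - β₂) steps: about 20 steps at 0.95 versus about 1,000 at 0.999. The shorter memory lets the per-parameter step size react quickly when the gradient scale shifts, which in practice tends to make transformer training more stable.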
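A quick way to feel the difference between the two β₂ values is to count how many steps the squared-gradient average needs to forget 90% of an old gradient scale. This is plain exponential-moving-average arithmetic, not a PyTorch API; `steps_to_forget` is a hypothetical helper.

```python
import math


def steps_to_forget(beta2: float, remaining: float = 0.1) -> int:
    """Steps until an old value's weight in an EMA drops to `remaining`.

    After n updates, an old value's weight has shrunk by beta2 ** n;
    solve beta2 ** n <= remaining for the smallest integer n.
    """
    return math.ceil(math.log(remaining) / math.log(beta2))


for beta2 in (0.95, 0.999):
    window = 1 / (1 - beta2)   # rough "memory length" of the EMA
    print(f"beta2={beta2}: ~{window:.0f}-step memory, "
          f"{steps_to_forget(beta2)} steps to forget 90%")
```

With β₂ = 0.95 an old gradient scale is 90% forgotten in under 50 steps; with 0.999 it lingers for more than 2,000.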

Learning rate schedule

Two phases: linear warmup over warmup_steps, then cosine decay over the rest.

# tiny_llm/train.py
def lr_at_step(step: int, train_cfg: TrainConfig) -> float:
    """Compute the learning rate at a given training step.

    Schedule:
      - 0..warmup_steps:  linear ramp from lr / warmup_steps up to peak (lr)
      - warmup..max:      cosine decay from peak to min_lr
      - past max_steps:   stays at min_lr
    """
    if step < train_cfg.warmup_steps:
        return train_cfg.lr * (step + 1) / train_cfg.warmup_steps
    if step >= train_cfg.max_steps:
        return train_cfg.min_lr

    progress = (step - train_cfg.warmup_steps) / (train_cfg.max_steps - train_cfg.warmup_steps)
    coeff = 0.5 * (1 + math.cos(math.pi * progress))   # 1.0 → 0.0 over the run
    return train_cfg.min_lr + coeff * (train_cfg.lr - train_cfg.min_lr)

Why warmup: at step 0 the model parameters are random; large gradient updates would push them somewhere arbitrary and slow convergence. We start near zero (lr / warmup_steps at step 0; the step + 1 avoids wasting a step at exactly zero LR) and ramp up linearly, so the first few hundred steps gently move toward a useful region.

Why cosine decay: empirically it matches or beats linear and step decay across a wide range of model sizes. The smooth taper means the last 10% of training takes small steps, refining without disturbing what’s been learned.

The min_lr floor of lr / 10 is a common convention (nanoGPT uses it, following the Chinchilla recipe of decaying the LR by about 10×). It keeps the final steps making small but nonzero updates instead of stalling as the LR approaches zero.
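To check the schedule numerically, here’s a self-contained copy of lr_at_step driven by a stand-in config. `ScheduleConfig` is a hypothetical dataclass holding only the four fields the function reads, with TrainConfig’s defaults.

```python
import math
from dataclasses import dataclass


@dataclass
class ScheduleConfig:          # hypothetical stand-in for TrainConfig
    lr: float = 3e-4
    min_lr: float = 3e-5
    warmup_steps: int = 200
    max_steps: int = 5000


def lr_at_step(step: int, cfg: ScheduleConfig) -> float:
    # Same logic as tiny_llm.train.lr_at_step.
    if step < cfg.warmup_steps:
        return cfg.lr * (step + 1) / cfg.warmup_steps
    if step >= cfg.max_steps:
        return cfg.min_lr
    progress = (step - cfg.warmup_steps) / (cfg.max_steps - cfg.warmup_steps)
    coeff = 0.5 * (1 + math.cos(math.pi * progress))
    return cfg.min_lr + coeff * (cfg.lr - cfg.min_lr)


cfg = ScheduleConfig()
for step in (0, 99, 199, 200, 2600, 4999, 5000):
    print(f"step {step:5d}: lr = {lr_at_step(step, cfg):.2e}")
```

Step 0 gets lr / 200 rather than exactly 0, the midpoint of the cosine phase (step 2600) sits halfway between peak and floor, and everything from step 5000 on is pinned at min_lr.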

The training step

The single update happens in a few lines. We wrap it in a function so we can call it repeatedly.

# tiny_llm/train.py
def train_step(
    model: GPT,
    optimizer: torch.optim.Optimizer,
    train_data: np.ndarray,
    train_cfg: TrainConfig,
    step: int,
) -> float:
    """One optimizer step; returns the loss value."""
    # Set learning rate for this step.
    lr = lr_at_step(step, train_cfg)
    for pg in optimizer.param_groups:
        pg["lr"] = lr

    model.train()

    # Gradient accumulation: average loss over `grad_accum_steps`
    # micro-batches before stepping. Lets us simulate a larger batch
    # without running out of memory.
    optimizer.zero_grad(set_to_none=True)
    total_loss = 0.0
    for _ in range(train_cfg.grad_accum_steps):
        x, y = get_batch(
            train_data,
            batch_size=train_cfg.batch_size,
            seq_len=model.config.max_seq_len,
            device=str(next(model.parameters()).device),
        )
        logits = model(x)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            y.view(-1),
        ) / train_cfg.grad_accum_steps
        loss.backward()
        total_loss += loss.item()

    # Gradient clipping — bound the global gradient norm.
    torch.nn.utils.clip_grad_norm_(model.parameters(), train_cfg.grad_clip)

    optimizer.step()
    return total_loss

zero_grad(set_to_none=True) sets the .grad tensors to None instead of zeroing them in place, so the next .backward() allocates fresh tensors. That skips a memset and frees gradient memory between steps. It has been the default since PyTorch 2.0; writing it out just makes the intent explicit.

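clip_grad_norm_ rescales all gradients together so that their global L2 norm (computed over every parameter at once) is at most grad_clip; the direction is preserved and only the magnitude shrinks. Without this, occasional huge gradients (which do happen, especially in the first few hundred steps) would push the parameters far from anywhere reasonable. A threshold of 1.0 is the common choice; GPT-3 and nanoGPT both clip at 1.0.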
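A minimal demonstration of the “global” part, using two one-element parameters with hand-set gradients of 3 and 4 (so the global norm is 5). clip_grad_norm_ returns the norm measured before clipping and scales every gradient by the same factor.

```python
import torch

a = torch.nn.Parameter(torch.zeros(1))
b = torch.nn.Parameter(torch.zeros(1))
a.grad = torch.tensor([3.0])
b.grad = torch.tensor([4.0])   # global L2 norm = sqrt(3**2 + 4**2) = 5

total_norm = torch.nn.utils.clip_grad_norm_([a, b], max_norm=1.0)

print(total_norm.item())              # the norm *before* clipping
print(a.grad.item(), b.grad.item())   # both scaled by ~1/5; the 3:4 ratio is kept
```

Neither gradient exceeds 1.0 on its own, yet both get shrunk, because the threshold applies to the combined vector.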

Gradient accumulation: with grad_accum_steps = 4, each “step” calls .backward() four times before .step(). Effective batch size is batch_size × grad_accum_steps. We default to 1 (no accumulation) but having the knob lets you train with effective batch 128 on a machine that only fits batch 32 in memory.
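The reason train_step divides each micro-batch loss by grad_accum_steps is that gradients add up across .backward() calls. A minimal check on a toy model (a single nn.Linear standing in for the GPT, an assumption for illustration) confirms that four scaled micro-batches of 8 produce the same gradient as one batch of 32:

```python
import torch
import torch.nn.functional as F


def grads_match(accum_steps: int = 4, batch: int = 32) -> bool:
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 3)            # toy stand-in for the GPT
    x = torch.randn(batch, 8)
    y = torch.randint(0, 3, (batch,))

    # One big batch.
    model.zero_grad(set_to_none=True)
    F.cross_entropy(model(x), y).backward()
    full = [p.grad.clone() for p in model.parameters()]

    # Same data as accum_steps micro-batches; .grad accumulates across backward().
    model.zero_grad(set_to_none=True)
    for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
        (F.cross_entropy(model(xb), yb) / accum_steps).backward()
    accumulated = [p.grad for p in model.parameters()]

    return all(torch.allclose(f, g, atol=1e-6) for f, g in zip(full, accumulated))


print(grads_match())
```

Without the division, the summed gradient (and the logged loss) would be accum_steps times too large, and the grad_clip threshold would no longer mean what it says. The equality is exact only when micro-batches are the same size, as they are in train_step.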

Periodic evaluation

Every eval_interval steps, measure validation loss. Average over eval_iters random batches to get a stable estimate.

# tiny_llm/train.py
@torch.no_grad()
def estimate_loss(
    model: GPT,
    data: np.ndarray,
    train_cfg: TrainConfig,
) -> float:
    """Average cross-entropy over `eval_iters` random batches."""
    model.eval()
    losses = torch.zeros(train_cfg.eval_iters)
    for k in range(train_cfg.eval_iters):
        x, y = get_batch(
            data,
            batch_size=train_cfg.batch_size,
            seq_len=model.config.max_seq_len,
            device=str(next(model.parameters()).device),
        )
        logits = model(x)
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            y.view(-1),
        )
        losses[k] = loss.item()
    model.train()
    return losses.mean().item()

We reuse get_batch, pointed at the validation array. Random sampling is fine: eval_iters = 50 batches of batch_size = 32 is 1,600 sequences, or 1,600 × max_seq_len tokens (over 400,000 tokens at a sequence length of 256), which is plenty for a tight estimate.

Putting the loop together

# tiny_llm/train.py
def train(
    gpt_cfg: GPTConfig | None = None,
    train_cfg: TrainConfig | None = None,
) -> GPT:
    gpt_cfg = gpt_cfg or GPTConfig()
    train_cfg = train_cfg or TrainConfig()
    train_cfg.out_dir.mkdir(parents=True, exist_ok=True)

    torch.manual_seed(train_cfg.seed)

    # Pick device.
    if train_cfg.device == "auto":
        device = "cuda" if torch.cuda.is_available() else (
            "mps" if torch.backends.mps.is_available() else "cpu"
        )
    else:
        device = train_cfg.device
    print(f"device: {device}")

    # Load data once. memmap'd arrays — cheap.
    train_data = load_token_array(DATA_DIR / "train.bin")
    valid_data = load_token_array(DATA_DIR / "valid.bin")
    print(f"train tokens: {len(train_data):,}")
    print(f"valid tokens: {len(valid_data):,}")

    # Build model, move to device.
    model = GPT(gpt_cfg).to(device)
    print(f"model params: {model.n_params:,}")

    # Optimizer
    print("\nbuilding optimizer:")
    optimizer = make_optimizer(model, train_cfg)

    # Training loop
    print(f"\ntraining for {train_cfg.max_steps:,} steps:")
    best_val = float("inf")
    t0 = time.time()
    pbar = tqdm(range(train_cfg.max_steps))
    for step in pbar:
        loss = train_step(model, optimizer, train_data, train_cfg, step)

        # Evaluate periodically.
        if step % train_cfg.eval_interval == 0 or step == train_cfg.max_steps - 1:
            val_loss = estimate_loss(model, valid_data, train_cfg)
            elapsed = time.time() - t0
            pbar.set_description(
                f"step {step:5d}  train {loss:.3f}  val {val_loss:.3f}  "
                f"lr {lr_at_step(step, train_cfg):.4f}  {elapsed:.0f}s"
            )
            if val_loss < best_val:
                best_val = val_loss
                ckpt = {
                    "model": model.state_dict(),
                    "gpt_config": gpt_cfg,
                    "train_config": train_cfg,
                    "step": step,
                    "val_loss": val_loss,
                }
                torch.save(ckpt, train_cfg.out_dir / "best.pt")

    # Save final checkpoint regardless.
    torch.save({
        "model": model.state_dict(),
        "gpt_config": gpt_cfg,
        "train_config": train_cfg,
        "step": train_cfg.max_steps,
    }, train_cfg.out_dir / "final.pt")

    print(f"\ndone in {(time.time() - t0):.0f}s. best val loss: {best_val:.3f}")
    return model
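One design note on the checkpoint format: storing the GPTConfig and TrainConfig objects themselves means torch.load has to unpickle arbitrary Python classes, which is why the sampling code passes weights_only=False (PyTorch 2.6 and later default to weights_only=True). If you’d rather load with the safer default, one option is to store the config as a plain dict via dataclasses.asdict and rebuild it on load. A minimal sketch, with a hypothetical `ToyConfig` standing in for GPTConfig and an in-memory buffer standing in for the checkpoint file:

```python
import io
from dataclasses import asdict, dataclass

import torch


@dataclass
class ToyConfig:               # hypothetical stand-in for GPTConfig
    vocab_size: int = 4096
    d_model: int = 128
    dropout: float = 0.0


model = torch.nn.Linear(4, 2)
buf = io.BytesIO()             # in-memory file; use a Path in real code
torch.save({"model": model.state_dict(), "gpt_config": asdict(ToyConfig())}, buf)

buf.seek(0)
ckpt = torch.load(buf, weights_only=True)   # only tensors and plain containers
cfg = ToyConfig(**ckpt["gpt_config"])        # rebuild the dataclass
restored = torch.nn.Linear(4, 2)
restored.load_state_dict(ckpt["model"])
print(cfg)
```

This works when every config field is a plain int, float, or string; TrainConfig’s Path field would likely need converting with str() first.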

A small main

# tiny_llm/train.py
if __name__ == "__main__":
    train()

Run it

uv run python -m tiny_llm.train

Expected output (approximate; numbers depend on seed and hardware):

device: mps
train tokens: 472,113,920
valid tokens: 4,872,406
model params: 5,266,944

building optimizer:
  decay params:      27 tensors, 5,265,792 elements
  no-decay params:   25 tensors, 1,152 elements

training for 5,000 steps:
step 4999  train 1.687  val 1.713  lr 0.0000  624s
done in 624s. best val loss: 1.704

What to expect during training:

  • Step 0: train loss ≈ 8.3 (≈ ln(4096), uniform over vocab). The model starts knowing nothing.
  • Step 100: train loss ≈ 5.0. Token frequencies learned (frequent tokens get more probability mass).
  • Step 500: train loss ≈ 3.0. Common bigrams learned. Output is real words but nonsense sentences.
  • Step 2000: train loss ≈ 2.0. Output is grammatical TinyStories-style sentences.
  • Step 5000: train loss ≈ 1.7. Coherent short stories with characters and plots.
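The step-0 number is easy to verify: if the model spreads probability uniformly over a 4,096-token vocabulary, cross-entropy is exactly ln(4096) ≈ 8.318 nats. A minimal check with all-zero logits, which softmax turns into a uniform distribution:

```python
import math

import torch
import torch.nn.functional as F

vocab_size = 4096
logits = torch.zeros(8, vocab_size)              # 8 positions, uniform prediction
targets = torch.randint(0, vocab_size, (8,))     # any targets give the same loss
loss = F.cross_entropy(logits, targets).item()

print(f"{loss:.4f}  vs  ln(V) = {math.log(vocab_size):.4f}")
```

A freshly initialized GPT with small weights should land close to this; a step-0 loss far above it usually means the initial logits are too large, which points back at weight init.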

Validation loss tracks training loss closely until ~step 4000 where we start overfitting (val begins flattening while train keeps falling). 5000 steps is roughly the sweet spot for this configuration.
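When reading a train/val gap, it helps to know how much of the training split the run actually touched. A back-of-envelope helper (hypothetical, not part of tiny_llm); seq_len = 256 is an assumed value, so substitute your own GPTConfig.max_seq_len:

```python
def tokens_seen(max_steps: int, batch_size: int,
                grad_accum_steps: int, seq_len: int) -> int:
    # Each optimizer step consumes batch_size * grad_accum_steps sequences.
    return max_steps * batch_size * grad_accum_steps * seq_len


n_train_tokens = 472_113_920          # "train tokens" from the run output above
seen = tokens_seen(max_steps=5000, batch_size=32, grad_accum_steps=1, seq_len=256)
print(f"{seen:,} tokens seen = {seen / n_train_tokens:.1%} of one epoch")
```

At a small fraction of an epoch, most training batches are text the model hasn’t seen before, which is worth factoring in before attributing a train/val gap to memorization.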

Sanity check: actually generate something

After training, load the best checkpoint and let it complete a prompt:

# Add to the bottom of tiny_llm/train.py, replacing the simple __main__:
if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1 and sys.argv[1] == "sample":
        # Load best checkpoint and generate.
        from tiny_llm.data import prepare

        # Re-train tokenizer (it's deterministic with our seed).
        tok = prepare()

        # map_location="cpu" lets a checkpoint saved on CUDA or MPS load on any machine.
        ckpt = torch.load(Path("checkpoints/best.pt"), map_location="cpu", weights_only=False)
        model = GPT(ckpt["gpt_config"])
        model.load_state_dict(ckpt["model"])
        model.eval()

        prompt = "Once upon a time"
        ids = torch.tensor([tok.encode(prompt)])
        out = model.generate(ids, max_new_tokens=80)
        print(tok.decode(out[0].tolist()))
    else:
        train()

After training:

uv run python -m tiny_llm.train sample

Expected output (approximate — your seed and run will differ):

Once upon a time there was a little girl named Lily. She loved to play
with her dog Max. One day, Lily and Max went to the park. They saw a
big red ball. Lily threw the ball and Max ran after it.

That’s a 5M-param model you wrote from scratch, generating coherent (if simple) English. Compare to the [0, 1, 2, 1234, 1234, 1234...] of step 08’s untrained sanity check. Same architecture; everything between is what training did.

A few things that go wrong

If the loss is NaN by step 50, in roughly decreasing likelihood:

  1. Forgot gradient clipping. Add clip_grad_norm_.
  2. LR is too high. Drop lr to 1e-4.
  3. Forgot weight init. Step 04’s nn.init.normal_(..., std=0.02) matters; PyTorch’s default nn.Embedding init, N(0, 1), is far too wide.
  4. Built the optimizer before calling model.to(device). The torch.optim docs recommend moving the model first, then constructing the optimizer over its (now on-device) parameters.
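When the loss does go NaN, it helps to know where it started. A small debugging helper, hypothetical and not part of tiny_llm, that you could call right after loss.backward() in train_step to name the first parameter whose gradient contains NaN or inf:

```python
from __future__ import annotations

import torch


def first_nonfinite_grad(model: torch.nn.Module) -> str | None:
    """Return the name of the first parameter with a NaN/inf gradient, else None."""
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            return name
    return None


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
model(torch.randn(3, 4)).sum().backward()
healthy = first_nonfinite_grad(model)       # all gradients finite

model[1].weight.grad[0, 0] = float("nan")   # simulate a blow-up
broken = first_nonfinite_grad(model)
print(healthy, broken)
```

If you only want to fail fast, clip_grad_norm_ also accepts error_if_nonfinite=True, which raises instead of silently clipping a NaN norm.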

If the loss plateaus around 8 (the uniform-distribution baseline):

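  1. Inputs and targets are misaligned. Step 03’s one-token (x, y) shift is the entire training signal; if y comes from the wrong offset, the model can only learn token frequencies and stalls. (The opposite mistake, y identical to x, shows up as loss collapsing toward 0 almost at once, since each position can simply copy its own input.)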
  2. vocab_size mismatch between tokenizer and model config. The LM head’s output dim has to match the tokenizer’s vocab size exactly.

If train loss falls but val loss doesn’t (or rises):

  1. Overfitting. Train for fewer steps, or use a smaller model, or add more data.

What we did and didn’t do

What we did:

  • AdamW with proper weight-decay grouping (no decay on biases / 1D params)
  • Linear warmup + cosine decay learning rate schedule
  • Global gradient norm clipping
  • Periodic eval on validation set with averaging
  • Checkpoint the best validation loss; also save the final
  • Gradient accumulation knob for effective-batch flexibility
  • Auto-device detection (CUDA / MPS / CPU)

What we didn’t:

  • Mixed-precision training (bf16/fp16). Significant speedup on GPUs; unnecessary at our scale. Add torch.amp.autocast if you scale up in step 11.
  • Distributed training (DDP/FSDP). Single-GPU assumption; multi-GPU is a different article.
  • Learning rate finder. Some recipes do a sweep first to find the optimal LR; we picked a reasonable default and trust it.
  • Gradient noise scale or other batch-size tuning. Empirically batch_size=32, accum=1 is fine for the 5M model; you can increase for larger.
  • Fancy schedulers (OneCycleLR, ReduceLROnPlateau). Warmup + cosine is the standard choice for LLM pretraining.
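If you do add mixed precision in step 11, the core change is small. A minimal sketch, assuming bf16 (which, unlike fp16, needs no gradient scaler) and using a toy nn.Linear on CPU so it runs anywhere; on a GPU you would pass device_type="cuda". `mixed_precision_step` is a hypothetical helper, not part of tiny_llm.

```python
import torch
import torch.nn.functional as F


def mixed_precision_step(model, optimizer, x, y, device_type: str = "cpu"):
    optimizer.zero_grad(set_to_none=True)
    # Matmuls inside this block run in bfloat16; the parameters stay float32.
    with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits = model(x)
    # Compute the loss in float32 for numerical safety.
    loss = F.cross_entropy(logits.float(), y)
    loss.backward()
    optimizer.step()
    return logits.dtype, loss.item()


torch.manual_seed(0)
model = torch.nn.Linear(16, 10)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
x, y = torch.randn(4, 16), torch.randint(0, 10, (4,))
dtype, loss = mixed_precision_step(model, optimizer, x, y)
print(dtype, loss, model.weight.dtype)
```

fp16 additionally needs a gradient scaler (torch.amp.GradScaler in recent PyTorch) to keep small gradients from underflowing.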

Cross-references

The Optimizer Race demo shows AdamW racing other optimizers on a toy loss landscape. The Gradient Descent demo and animated walkthrough are the visual versions of what optimizer.step() is doing under the hood.

Next

Step 10 covers sampling and decoding. Right now model.generate is greedy — it always picks the most-probable token. That produces bland, repetitive output. We’ll add temperature, top-k, and top-p (nucleus) sampling to make generations interesting and diverse. Then step 11 looks at how the same architecture scales when we pump the dimensions and depth.