step 12 · build
Fine-tune with LoRA
Adapt your trained base model to a new behavior — with roughly 100× fewer trainable parameters than full fine-tuning on this toy model, and far larger savings at scale.
You have a trained base model. It produces TinyStories-style fiction. You want it to do something specific — follow instructions, answer questions, write in a particular style. The naive approach is full fine-tuning: train every weight on a new dataset. That works, but it requires storing a separate copy of every parameter for every fine-tune.
LoRA (Low-Rank Adaptation, Hu et al. 2021) is the trick that makes fine-tuning practical at scale. Instead of updating every weight, LoRA learns a tiny low-rank “delta” that gets added to specific weights at inference time. The base model stays frozen.
By the end of this step you’ll have:
- A `LoRALinear` wrapper that turns any `nn.Linear` into a LoRA-adapted layer
- Code that swaps every attention linear in the trained model with the wrapper
- A short fine-tuning loop that updates only the LoRA parameters (~1% of the model)
- A demo that takes the TinyStories base model and teaches it to follow an `[INST] … [/INST]` instruction template — using ~200 example pairs
The LoRA Lab demo on this site visualizes the rank-r approximation; this step is its implementation.
What LoRA actually does
For any weight matrix W ∈ R^(d × d), LoRA replaces
y = W · x
with
y = (W + ΔW) · x where ΔW = B · A
with A ∈ R^(r × d) and B ∈ R^(d × r). The crucial detail: r (the rank) is tiny — typically 4, 8, or 16 — while d is the model dimension (could be 768 or 4096).
So instead of d² parameters in ΔW directly, we have 2 · r · d parameters in A and B. For d=768, r=8: 12,288 LoRA params replace what would be 589,824 — a 48× reduction per matrix.
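A quick back-of-envelope check of that arithmetic:

```python
d, r = 768, 8
full = d * d        # dense ΔW: 589,824 params
lora = 2 * r * d    # A and B combined: 12,288 params
print(full, lora, full // lora)  # 589824 12288 48
```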
The matrix A is initialized with random small values (same std=0.02 we’ve used everywhere). The matrix B is initialized to zero. So at training step 0, ΔW = B·A = 0, and the model behaves identically to the base. Training nudges B away from zero to learn the adaptation.
This is the entire trick. It works because the rank of the update is much smaller than the rank of the weight matrix — for fine-tuning, you typically don’t need to change all directions of weight space, just a few.
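To make "low rank" concrete: a `B·A` product can never have rank greater than r, no matter how large the matrix it produces. A two-line check:

```python
import torch

d, r = 64, 8
B, A = torch.randn(d, r), torch.randn(r, d)
delta = B @ A                            # a full 64×64 update matrix...
print(torch.linalg.matrix_rank(delta))   # ...with rank at most 8
```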
Setup
Add a new file:
# tiny_llm/lora.py
import torch
import torch.nn as nn
class LoRALinear(nn.Module):
"""Wraps an existing nn.Linear with a LoRA adapter.
forward(x) = base(x) + (B @ A) @ x · (alpha / r)
The base layer is frozen (requires_grad=False). Only A and B train.
"""
def __init__(
self,
base: nn.Linear,
r: int = 8,
alpha: int = 16,
) -> None:
super().__init__()
self.base = base
self.r = r
self.alpha = alpha
# Standard LoRA scaling factor — applied to ΔW, not to params themselves.
self.scaling = alpha / r
# Freeze the base layer. Only LoRA matrices train.
for p in self.base.parameters():
p.requires_grad = False
in_features = base.in_features
out_features = base.out_features
# A: (r, in_features). Random small init.
self.A = nn.Parameter(torch.zeros(r, in_features))
nn.init.normal_(self.A, mean=0.0, std=0.02)
# B: (out_features, r). Init to ZERO so ΔW = B @ A = 0 at start.
# This means the model behaves identically to the base on step 0.
self.B = nn.Parameter(torch.zeros(out_features, r))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# base(x) is just the original linear's projection.
base_out = self.base(x)
# LoRA delta: x @ A.T → (..., r) → @ B.T → (..., out_features)
# Then scale.
lora_out = (x @ self.A.T @ self.B.T) * self.scaling
return base_out + lora_out
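A quick sanity check that the wrapper really is a no-op at initialization (the shapes here are arbitrary):

```python
base = nn.Linear(64, 64)
wrapped = LoRALinear(base, r=8, alpha=16)
x = torch.randn(2, 10, 64)
assert torch.allclose(wrapped(x), base(x))  # B = 0, so ΔW = 0 at step 0
```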
Three things worth understanding in detail.
alpha / r scaling. The convention from the paper. The LoRA delta gets multiplied by alpha/r so that increasing r doesn’t change the effective magnitude of the update. With alpha=16, r=8, scaling is 2; with r=16, scaling drops to 1. This decouples the rank from the update strength — useful when sweeping r.
B initialized to zero. This is non-negotiable. The whole reason LoRA works for fine-tuning is that the model starts as the base model. If you initialized B randomly, the first forward pass would produce garbage and you’d have to recover from there. With B = 0, the adapter is a no-op at step 0 and gradient flow gently steers it.
Frozen base. requires_grad=False on the base layer means autograd never computes gradients for its weights. The parameters still show up in .parameters() — freezing doesn't remove them — but the optimizer we build later receives only A and B, so optimizer state shrinks dramatically and each step gets cheaper vs. full fine-tuning.
Wiring LoRA into the GPT
# tiny_llm/lora.py
def add_lora_to_model(model: nn.Module, r: int = 8, alpha: int = 16) -> None:
"""Walk the model and replace every linear in attention sub-layers
with a LoRA-wrapped version. Mutates `model` in place.
Convention: only the attention QKV and output projections get LoRA.
The MLP often gets LoRA too in production; we skip it here for
simplicity. Skipping the LM head is also conventional.
"""
    # Freeze the whole model up front — the only trainable parameters
    # after this call are the A/B matrices the wrappers create.
    for p in model.parameters():
        p.requires_grad = False
    # Collect target names first so we don't mutate the module tree
    # while named_modules() is still iterating over it.
    targets = [
        name for name, module in model.named_modules()
        if name.endswith(".attn.qkv") or name.endswith(".attn.W_o")
    ]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)
        base = getattr(parent, child_name)
        assert isinstance(base, nn.Linear)
        setattr(parent, child_name, LoRALinear(base, r=r, alpha=alpha))
def lora_parameters(model: nn.Module):
"""Iterator over only the LoRA parameters (A, B). Use this to build
the optimizer so it only updates the adapters."""
for name, param in model.named_parameters():
if param.requires_grad:
yield name, param
def lora_param_count(model: nn.Module) -> int:
"""Total trainable params after LoRA wrapping."""
return sum(p.numel() for _, p in lora_parameters(model))
The add_lora_to_model function freezes every parameter, then walks the module tree and swaps the matching nn.Linear instances for LoRALinear wrappers in place. After this call, every QKV/W_o projection in every transformer block is wrapped, and the A/B matrices are the only parameters with requires_grad=True.
We restrict LoRA to attention only (skipping the MLP) because (a) it's the conventional starting point from the LoRA paper, which adapts only attention weights, and (b) the MLP holds the majority of each block's parameters, so including it roughly doubles the LoRA-trainable count. Production setups tune which layers get LoRA based on the task — the sketch below shows how small that change is.
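If you do want MLP LoRA later, the only change is the name filter in add_lora_to_model. A sketch, assuming your step-06 MLP linears are reachable under names like `.mlp.fc_in` / `.mlp.fc_out` (adjust to whatever your blocks actually call them):

```python
# Hypothetical target list — the MLP names are assumptions, not our actual code.
LORA_TARGETS = (".attn.qkv", ".attn.W_o", ".mlp.fc_in", ".mlp.fc_out")

def wants_lora(name: str) -> bool:
    return any(name.endswith(t) for t in LORA_TARGETS)
```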
Fine-tuning data
For our toy fine-tune, we’ll teach the model to follow a specific instruction format. About 200 examples are enough to see the effect:
[INST] Write a story about a dog and a ball.[/INST] One day, a brown dog named Rex saw a red ball...
[INST] Tell a story with a moral about sharing.[/INST] Maya had a basket of berries...
In real instruction-tuning datasets (Alpaca, Dolly, OASST), this format is consistent across thousands of examples. We’ll fake it for our purposes by templating TinyStories prefixes:
# tiny_llm/finetune.py
import torch
import torch.nn.functional as F
from pathlib import Path
from tiny_llm.gpt import GPT
from tiny_llm.lora import add_lora_to_model, lora_parameters, lora_param_count
from tiny_llm.data import DATA_DIR, prepare
INSTRUCTION_PROMPTS = [
"Tell a short story about a friendship.",
"Write a story with a happy ending.",
"Tell a story where a child learns something new.",
"Write a short story about a brave animal.",
"Tell a story about a lost toy that comes home.",
# ...etc; ~50 of these
]
def make_instruction_dataset(stories: list[str]) -> list[str]:
"""Pair instructions with stories. We loop through the small set of
prompts; in real instruction tuning every example would have its
own prompt."""
dataset = []
for i, story in enumerate(stories):
prompt = INSTRUCTION_PROMPTS[i % len(INSTRUCTION_PROMPTS)]
dataset.append(f"[INST] {prompt}[/INST] {story}")
return dataset
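One rendered example, to make the template concrete:

```python
>>> make_instruction_dataset(["One day, a brown dog named Rex saw a red ball..."])[0]
'[INST] Tell a short story about a friendship.[/INST] One day, a brown dog named Rex saw a red ball...'
```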
Real instruction tuning loads a curated dataset (Alpaca, ShareGPT, etc.); for our purposes the format above is enough to teach the model to recognize the [INST] … [/INST] template and continue with a story.
The fine-tune loop
It’s nearly identical to step 09’s training loop, with two changes:
- Optimizer points at LoRA params only.
- Lower learning rate — fine-tuning needs gentler updates than pretraining.
# tiny_llm/finetune.py
def finetune(
base_ckpt: Path,
train_examples: list[str],
valid_examples: list[str],
out_dir: Path,
r: int = 8,
alpha: int = 16,
lr: float = 5e-5, # 6× smaller than pretraining
max_steps: int = 1000,
batch_size: int = 8,
seq_len: int = 256,
):
"""LoRA fine-tune the base model."""
out_dir.mkdir(parents=True, exist_ok=True)
# Load base model + tokenizer
ckpt = torch.load(base_ckpt, weights_only=False)
model = GPT(ckpt["gpt_config"])
model.load_state_dict(ckpt["model"])
    # In practice you'd save/load the tokenizer alongside the checkpoint; for
    # the demo we rebuild it via the standard data pipeline (see step 03).
    tok = prepare()
# Wrap with LoRA
print(f"\nbefore LoRA: {sum(p.numel() for p in model.parameters()):,} total params")
add_lora_to_model(model, r=r, alpha=alpha)
print(f"after LoRA:")
print(f" total params: {sum(p.numel() for p in model.parameters()):,}")
print(f" trainable params: {lora_param_count(model):,}")
ratio = sum(p.numel() for p in model.parameters()) / lora_param_count(model)
print(f" trainable ratio: 1 / {ratio:.0f}")
# Tokenize training examples once.
print(f"\ntokenizing {len(train_examples)} examples...")
train_ids = [
torch.tensor(tok.encode(ex)[:seq_len], dtype=torch.long)
for ex in train_examples
]
# Optimizer over LoRA params only.
optimizer = torch.optim.AdamW(
[p for _, p in lora_parameters(model)],
lr=lr,
weight_decay=0.0, # no decay on LoRA — convention
)
model.train()
for step in range(max_steps):
# Sample a random batch of examples and pad to a common length.
idx = torch.randint(0, len(train_ids), (batch_size,))
batch = [train_ids[i] for i in idx]
max_len = max(len(b) for b in batch)
pad_id = tok.vocab["<|pad|>"]
x = torch.stack([
torch.cat([b, torch.full((max_len - len(b),), pad_id, dtype=torch.long)])
for b in batch
])
y = x.clone()
y[x == pad_id] = -100 # ignore padding in the loss
logits = model(x[:, :-1])
loss = F.cross_entropy(
logits.reshape(-1, logits.size(-1)),
y[:, 1:].reshape(-1),
ignore_index=-100,
)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(
[p for _, p in lora_parameters(model)], 1.0,
)
optimizer.step()
if step % 50 == 0:
print(f" step {step:4d} loss {loss.item():.3f}")
# Save only the LoRA parameters — base stays in `base_ckpt`.
lora_state = {name: param.detach() for name, param in lora_parameters(model)}
torch.save({"lora": lora_state, "r": r, "alpha": alpha}, out_dir / "lora.pt")
print(f"\nsaved LoRA adapter to {out_dir / 'lora.pt'}")
Note ignore_index=-100 in F.cross_entropy. We mark padded positions with -100, and PyTorch’s loss skips them — otherwise we’d be training the model to predict padding from padding, which is meaningless.
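If the masking ever feels opaque, this tiny experiment shows it in isolation: the loss over a batch with one masked target equals the loss over the unmasked positions alone.

```python
logits = torch.randn(2, 5)
targets = torch.tensor([3, -100])
assert torch.allclose(
    F.cross_entropy(logits, targets, ignore_index=-100),
    F.cross_entropy(logits[:1], targets[:1]),  # only position 0 counts
)
```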
Sanity check
A __main__ that runs the whole flow:
# tiny_llm/finetune.py
if __name__ == "__main__":
# Quick smoke test with random "instructions" pulled from the
# validation set. In real use you'd load a curated set.
import random
base_ckpt = Path("checkpoints/best.pt")
if not base_ckpt.exists():
print("Need a base checkpoint. Run `tiny_llm.train` first.")
raise SystemExit(1)
# Pull 200 short stories from the validation file as our toy
# training set. This is a stand-in for a real instruction dataset.
valid = (DATA_DIR / "tinystories_valid.txt").read_text(encoding="utf-8")
stories = [s.strip() for s in valid.split("<|endoftext|>") if 50 < len(s) < 800][:200]
random.shuffle(stories)
train_examples = make_instruction_dataset(stories[:180])
valid_examples = make_instruction_dataset(stories[180:])
finetune(
base_ckpt=base_ckpt,
train_examples=train_examples,
valid_examples=valid_examples,
out_dir=Path("checkpoints/lora"),
max_steps=500,
)
Run it:
uv run python -m tiny_llm.finetune
Expected output (numbers depend on your base):
before LoRA: 5,266,944 total params
after LoRA:
total params: 5,316,096
trainable params: 49,152
trainable ratio: 1 / 108
tokenizing 180 examples...
step 0 loss 6.741
step 50 loss 3.812
step 100 loss 3.143
step 150 loss 2.654
step 200 loss 2.301
step 250 loss 2.087
step 300 loss 1.949
step 350 loss 1.836
step 400 loss 1.752
step 450 loss 1.701
saved LoRA adapter to checkpoints/lora/lora.pt
What to notice:
- 49k trainable params on a 5.3M model — the 1/108 ratio the script printed (`r=8, alpha=16` on the QKV + W_o linears in 6 blocks). On a real LLaMA-7B, a comparable setup gives ~4M trainable on 7B total — about 1/1700.
- Initial loss is around 6.7 — the base model has never seen the `[INST]…[/INST]` format and is confused. After 100 steps it's adapting; by 500 the format is clearly sinking in.
- The loss curve has the same shape as pretraining, just smaller. Steps are also cheaper: optimizer state and weight updates cover only ~1% of the parameters (backprop still flows through the frozen base, so it's not free).
Loading and using the adapter
To use the fine-tuned model, you’d load the base + adapter:
def load_with_lora(base_ckpt: Path, adapter_ckpt: Path) -> GPT:
base = torch.load(base_ckpt, weights_only=False)
model = GPT(base["gpt_config"])
model.load_state_dict(base["model"])
adapter = torch.load(adapter_ckpt, weights_only=False)
add_lora_to_model(model, r=adapter["r"], alpha=adapter["alpha"])
    # Load only the adapter weights into the wrapped layers. strict=False
    # because the state dict covers just the A/B matrices.
    model.load_state_dict(adapter["lora"], strict=False)
return model
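Usage, with the paths from the smoke test above:

```python
model = load_with_lora(Path("checkpoints/best.pt"), Path("checkpoints/lora/lora.pt"))
model.eval()
# ...then sample from it exactly as you did with the base model,
# prompting with the [INST] ... [/INST] template.
```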
The base checkpoint stays the same on disk — multiple LoRA adapters can sit beside it (lora_chat.pt, lora_summarize.pt, etc.) and you swap them at load time. In production, this is how systems like Stable Diffusion XL serve hundreds of style adapters from one base model.
What we did and didn’t do
What we did:
- A `LoRALinear` wrapper that freezes the base linear and adds a rank-`r` delta
- `add_lora_to_model` to swap attention linears in place
- A fine-tuning loop that trains only LoRA params (1/108 of the model)
- A saved adapter checkpoint that's tiny (~200 KB vs ~21 MB for the full model)
What we didn’t:
- DPO or RLHF. Different fine-tuning paradigms — instead of next-token prediction on labeled data, you optimize directly on preferences. The RLHF Preference demo shows the underlying preference model. Modern post-training stacks combine SFT (what we did) with DPO/RLHF on top.
- QLoRA. Quantize the base to 4-bit, then train LoRA on top. Cuts base memory by 4×; lets you fine-tune Llama-7B on a 16 GB GPU. The Quantization Lab shows what 4-bit weights look like.
- Apply LoRA to MLP layers too. Doubles trainable count, sometimes wins for harder tasks. Production rule of thumb: try attention-only first, add MLP LoRA if you’re underfitting.
- Merge LoRA back into the base. At inference time, you can compute `W' = W + (alpha/r)·B·A` once and store/serve a merged checkpoint, eliminating the runtime cost of the adapter. Worth doing in production — see the sketch below.
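Merging is mechanical given the fields our LoRALinear already carries; a minimal sketch:

```python
@torch.no_grad()
def merge_lora(model: nn.Module) -> None:
    """Fold each adapter's scaled delta into its base weight, then unwrap."""
    for name, module in list(model.named_modules()):
        if isinstance(module, LoRALinear):
            # W' = W + (alpha/r)·B·A — shapes: (out, r) @ (r, in) = (out, in)
            module.base.weight += module.scaling * (module.B @ module.A)
            parent = model.get_submodule(name.rpartition(".")[0])
            setattr(parent, name.rpartition(".")[2], module.base)
```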
Cross-references
- LoRA Lab demo — visualize the rank-r approximation. Slide the rank, watch a small `B·A` factorization approximate a full weight update.
- Quantization Lab demo — 4-bit base weights, the foundation of QLoRA.
- RLHF Preference demo — for the preference-learning side of post-training.
Next
Step 13 is evaluation. We’ve trained, we’ve fine-tuned, but how do we know the model is actually any good? Perplexity, generation quality, lightweight LLM-as-judge — three lenses, none of which is sufficient alone, all of which together let you make decisions instead of vibes.