step 14 · build
Inference: KV cache + ONNX export
Two operations that make generation production-fast: cache K and V across steps, then export the whole model to a portable ONNX file.
The model from step 09 generates tokens by running the full forward pass on the entire sequence at every step. For a 1000-token output, you do 1000 forward passes on sequences of length 1, 2, 3, …, 1000 — that’s a sum of T(T+1)/2 = 500,500 token-equivalents of compute. Quadratic in output length.
But notice: when you generate token 500, the keys and values for tokens 0–499 don’t change. Their projections were already computed during step 499. Recomputing them is wasted work.
KV caching is the trick. Save the K and V tensors per layer, per generation step, and only compute new ones for the freshly-emitted token. Compute drops from quadratic to linear per token, which is the difference between “ChatGPT runs in real time” and “ChatGPT takes minutes per response.”
Then we’ll export the whole model to ONNX — a portable format that lets us run the same network in JavaScript, mobile runtimes, or any C++ inference engine, without dragging PyTorch along. Step 15 (the browser capstone) consumes the ONNX export.
Why KV cache works
Look at what happens during attention, step by step, as we generate:
Step T=1: input is [t0] → compute Q1, K1, V1 → attend → output token t1.
Step T=2: input is [t0, t1] → compute Q1..2, K1..2, V1..2 → attend → output token t2.
Step T=3: input is [t0, t1, t2] → compute Q1..3, K1..3, V1..3 → attend → output token t3.
At step T, we only need the Q for the new last position (the only query that’s about to read). But we need K and V for all positions. The K and V for position 0 don’t change between step 1 and step 1000 — they’re a function of token t0 and the (frozen) projection weights only.
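A quick standalone check makes this concrete (toy dimensions, not our actual config): the K projection is a per-position linear map, so the K for position 0 comes out identical whether you project a 1-token or a 3-token sequence.
# standalone sanity check (toy d_model=16, not our config): K for position 0
# does not depend on any later token, so it is safe to cache.
import torch
torch.manual_seed(0)
W_k = torch.nn.Linear(16, 16, bias=False)   # stand-in for the K projection
x = torch.randn(1, 3, 16)                   # embeddings for three tokens
k_all = W_k(x)                              # project all three positions at once
k_first = W_k(x[:, :1, :])                  # project position 0 alone
print(torch.allclose(k_all[:, :1, :], k_first))  # True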
KV cache: store K and V from previous steps, append new K, V each step, run attention with Q shape (B, H, 1, d_head) against K, V shape (B, H, T, d_head).
The change to compute per token:
| | Without cache | With cache |
|---|---|---|
| Forward FLOPs at step T | ∝ T² | ∝ T (only last token’s Q × cached K) |
| Memory | input only | input + K, V cache (T × d_model × 2) |
| Per-token wallclock | grows quadratically | roughly constant |
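To put numbers on it, here is a back-of-envelope count of token positions pushed through the forward pass for a 1000-token generation (a sketch: it ignores the prompt and the attention score computation itself, which still scales with context length even when cached):
# back-of-envelope: token positions processed while generating 1000 tokens
T = 1000
without_cache = sum(range(1, T + 1))   # re-run the whole sequence at every step
with_cache = T                         # one new position per step
print(without_cache)                   # 500500
print(with_cache)                      # 1000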
The KV Cache demo toggles caching on and off and shows the per-token compute curve — that’s the picture in numbers.
What we change
We need attention to optionally accept previous K, V and return updated K, V. The minimal change is a method `forward_cached(x, kv_cache)` that returns `(out, new_kv_cache)`.
For simplicity we’ll keep the original forward for training (where we always feed the full sequence) and add a separate cached path for inference.
# tiny_llm/mha.py — add to MultiHeadAttention class
@torch.no_grad()
def forward_cached(
self,
x: torch.Tensor,
kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""Forward pass for inference with KV caching.
Args:
x: (B, T_new, d_model) — only the *new* tokens this step.
Usually T_new=1 during sampling.
kv_cache: tuple of (K_prev, V_prev) each shape (B, H, T_prev, d_head),
or None for the first step.
Returns:
out: (B, T_new, d_model)
new_cache: (K, V) including the new K, V appended.
"""
B, T_new, D = x.shape
H = self.n_heads
d_head = self.d_head
# Project new Q, K, V — same as forward, but only on the new tokens.
qkv_new = self.qkv(x)
q_new, k_new, v_new = qkv_new.chunk(3, dim=-1)
q_new = q_new.view(B, T_new, H, d_head).transpose(1, 2) # (B, H, T_new, d_head)
k_new = k_new.view(B, T_new, H, d_head).transpose(1, 2)
v_new = v_new.view(B, T_new, H, d_head).transpose(1, 2)
# Append to cache.
if kv_cache is not None:
k_prev, v_prev = kv_cache
k = torch.cat([k_prev, k_new], dim=2) # along the time axis
v = torch.cat([v_prev, v_new], dim=2)
else:
k, v = k_new, v_new
T_total = k.size(2)
# Scaled dot-product. Q only has T_new positions; K/V have T_total.
scores = q_new @ k.transpose(-2, -1) / math.sqrt(d_head) # (B, H, T_new, T_total)
# Causal mask: each query at position (T_total - T_new + i) can attend
# to keys at positions 0..(T_total - T_new + i).
# In the common single-token case (T_new=1), the mask is all-ones.
if T_new > 1:
mask = self.causal_mask[T_total - T_new : T_total, :T_total]
scores = scores.masked_fill(~mask, float("-inf"))
weights = F.softmax(scores, dim=-1)
out = weights @ v # (B, H, T_new, d_head)
out = out.transpose(1, 2).contiguous().view(B, T_new, D)
return self.W_o(out), (k, v)
The body is nearly identical to the training-time forward. The differences:
- `q_new` is computed only for the new tokens — that’s the win.
- K, V are appended to the cache before scoring.
- The causal mask is sliced differently — but in the single-new-token case (T_new=1), the new query attends to all cached positions, so the mask is trivially all-True.
We add the corresponding cached forward to the Block class:
# tiny_llm/block.py — add to Block class
@torch.no_grad()
def forward_cached(
self,
x: torch.Tensor,
kv_cache: tuple[torch.Tensor, torch.Tensor] | None,
):
attn_out, new_cache = self.attn.forward_cached(self.ln_1(x), kv_cache)
x = x + attn_out
x = x + self.mlp(self.ln_2(x))
return x, new_cache
And then to GPT:
# tiny_llm/gpt.py — add to GPT class
@torch.no_grad()
def forward_cached(
self,
token_ids: torch.Tensor,
caches: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
):
"""Forward pass that maintains a per-block KV cache for fast generation.
token_ids: (B, T_new). Pass the full prompt on first call (caches=None);
on subsequent calls pass T_new=1 (just the last sampled token).
caches: list of len n_layers, one (K, V) tuple per block, or None.
Returns:
logits: (B, T_new, vocab_size)
new_caches: list of (K, V) — to pass on the next call.
"""
x = self.embed(token_ids) # (B, T_new, d_model)
new_caches = []
for i, block in enumerate(self.blocks):
cache_i = caches[i] if caches is not None else None
x, new_cache_i = block.forward_cached(x, cache_i)
new_caches.append(new_cache_i)
x = self.ln_f(x)
logits = self.lm_head(x)
return logits, new_caches
And a faster sample_cached() that uses it:
# tiny_llm/gpt.py — add to GPT class
@torch.no_grad()
def sample_cached(
self,
prompt_ids: torch.Tensor,
max_new_tokens: int,
temperature: float = 0.8,
top_p: float = 0.95,
) -> torch.Tensor:
"""Same as sample() but with KV caching — much faster for long outputs."""
self.eval()
ids = prompt_ids
# First pass: process the entire prompt, build caches.
logits, caches = self.forward_cached(ids)
for _ in range(max_new_tokens):
# Sample from the logits at the last position.
next_logits = logits[:, -1, :] / temperature
# (top-p filter elided for brevity — same as in sample())
probs = F.softmax(next_logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
ids = torch.cat([ids, next_id], dim=1)
# Forward only the new token, reusing caches.
logits, caches = self.forward_cached(next_id, caches)
return ids
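The elided top-p filter is standard nucleus sampling. For reference, a minimal standalone version looks roughly like this (a sketch: step 09’s sample() is the authoritative implementation, and the helper name here is made up):
# nucleus (top-p) filtering — standalone sketch, not step 09's exact code
import torch
import torch.nn.functional as F

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Return a copy of logits (B, vocab) with everything outside the smallest
    set of tokens whose cumulative probability exceeds top_p set to -inf."""
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > top_p
    remove[..., 1:] = remove[..., :-1].clone()  # shift right: keep the token that crosses top_p
    remove[..., 0] = False                      # never drop the most likely token
    remove = remove.scatter(-1, sorted_idx, remove)  # map the mask back to vocabulary order
    return logits.masked_fill(remove, float("-inf"))

# it would slot into the loop above as:
#   next_logits = top_p_filter(next_logits, top_p)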
Speed sanity check
# tiny_llm/gpt.py
if __name__ == "__main__":
import time
config = GPTConfig()
model = GPT(config)
model.eval()
prompt = torch.randint(0, config.vocab_size, (1, 8))
# Without cache
t0 = time.time()
out_uncached = model.sample(prompt, max_new_tokens=200, temperature=0.8, top_p=0.95)
t_uncached = time.time() - t0
# With cache
t0 = time.time()
out_cached = model.sample_cached(prompt, max_new_tokens=200, temperature=0.8, top_p=0.95)
t_cached = time.time() - t0
print(f"sample() uncached: {t_uncached:.2f}s ({200/t_uncached:.1f} tok/s)")
print(f"sample_cached(): {t_cached:.2f}s ({200/t_cached:.1f} tok/s)")
print(f"speedup: {t_uncached / t_cached:.1f}×")
Expected output (CPU, the 5M-param model):
sample() uncached: 18.4s (10.9 tok/s)
sample_cached(): 2.1s (95.2 tok/s)
speedup: 8.8×
An 8–10× speedup is typical for a 200-token output with our config. The longer the output, the bigger the speedup; for 1000-token outputs it’s typically 30–50×.
Memory cost of caching
The KV cache occupies real memory. Per token, per layer:
2 (K + V) × n_heads × d_head = 2 × d_model floats
For our MEDIUM config (d_model=768, n_layers=12, max_seq_len=1024, batch_size=1):
2 × 768 × 12 × 1024 = ~19M floats = ~75 MB at fp32
Tractable for a single sequence. Becomes the dominant memory cost in production: a 70B-param model serving 32 concurrent 4096-token contexts is paying gigabytes in KV cache. This is why paged attention (vLLM) and multi-query attention (where K and V are shared across heads to shrink cache size) exist.
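The arithmetic is easy to script. A small helper (a sketch; the function name is made up) reproduces the 75 MB figure and shows how fast batching inflates it:
# back-of-envelope KV cache size at fp32 (4 bytes per float)
def kv_cache_bytes(d_model: int, n_layers: int, seq_len: int, batch_size: int = 1) -> int:
    # per token, per layer: K + V = 2 * d_model floats (n_heads * d_head = d_model)
    return 2 * d_model * n_layers * seq_len * batch_size * 4

print(kv_cache_bytes(768, 12, 1024) / 1e6)      # ≈ 75.5 MB (one sequence, MEDIUM config)
print(kv_cache_bytes(768, 12, 1024, 32) / 1e9)  # ≈ 2.4 GB (32 concurrent sequences)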
ONNX export
Once the model is fast and trained, you might want to run it without PyTorch. ONNX (Open Neural Network Exchange) is a portable format consumed by many runtimes:
- ONNX Runtime (Microsoft) — fast C++ inference on CPU/GPU, used by VS Code, Office, browsers
- onnxruntime-web — runs in the browser via WebAssembly + WebGPU (step 15 will use this)
- CoreML (Apple) — convert ONNX → CoreML for iOS deployment
- TensorRT (NVIDIA) — high-performance inference on NVIDIA GPUs
The export call:
# tiny_llm/export.py
import torch
from pathlib import Path
from tiny_llm.gpt import GPT
def export_to_onnx(
model: GPT,
output_path: Path,
max_seq_len: int = 256,
opset_version: int = 17,
):
"""Export model to ONNX format for portable inference.
The exported model takes a (B, T) integer tensor and returns
(B, T, vocab_size) logits. KV caching support in ONNX needs
a more elaborate export that we skip here; this version is fine
for one-shot inference (which is what step 15 needs).
"""
model.eval()
dummy_input = torch.randint(0, model.config.vocab_size, (1, max_seq_len))
torch.onnx.export(
model,
(dummy_input,),
output_path,
input_names=["token_ids"],
output_names=["logits"],
# Dynamic axes let the exported model handle variable batch + seq len.
dynamic_axes={
"token_ids": {0: "batch", 1: "seq_len"},
"logits": {0: "batch", 1: "seq_len"},
},
opset_version=opset_version,
do_constant_folding=True,
)
print(f"exported to {output_path} ({output_path.stat().st_size / 1e6:.1f} MB)")
if __name__ == "__main__":
ckpt = torch.load("checkpoints/best.pt", weights_only=False)
model = GPT(ckpt["gpt_config"])
model.load_state_dict(ckpt["model"])
export_to_onnx(model, Path("checkpoints/tiny_llm.onnx"))
uv run python -m tiny_llm.export
Expected output:
exported to checkpoints/tiny_llm.onnx (21.5 MB)
The 21.5 MB number is the model’s parameters at fp32. You can quantize the ONNX file to int8 with onnxruntime.quantization.quantize_dynamic and shrink it ~4× — useful for browsers.
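If you want to try it, dynamic quantization is a few lines. A sketch (the output path and the choice of QInt8 weights are mine, not part of this step’s code):
# optional: int8 dynamic quantization of the exported file
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    "checkpoints/tiny_llm.onnx",
    "checkpoints/tiny_llm.int8.onnx",   # hypothetical output path
    weight_type=QuantType.QInt8,        # weights stored as int8; activations quantized at runtime
)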
Loading and running the ONNX
Just to confirm the export works:
import numpy as np
import onnxruntime as ort
session = ort.InferenceSession("checkpoints/tiny_llm.onnx")
input_ids = np.random.randint(0, 4096, size=(1, 16)).astype(np.int64)
logits = session.run(["logits"], {"token_ids": input_ids})[0]
print(f"input shape: {input_ids.shape}")
print(f"output shape: {logits.shape}")
Same outputs as PyTorch (modulo numerical noise at ~1e-6). No import torch required.
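To back up the “same outputs” claim, compare the two runtimes directly. A sketch, assuming the GPT forward returns logits when called on token ids (which is what the export above relies on):
# parity check: ONNX logits vs PyTorch logits on the same input
import numpy as np
import onnxruntime as ort
import torch
from tiny_llm.gpt import GPT

ckpt = torch.load("checkpoints/best.pt", weights_only=False)
model = GPT(ckpt["gpt_config"])
model.load_state_dict(ckpt["model"])
model.eval()

input_ids = np.random.randint(0, model.config.vocab_size, size=(1, 16)).astype(np.int64)
session = ort.InferenceSession("checkpoints/tiny_llm.onnx")
onnx_logits = session.run(["logits"], {"token_ids": input_ids})[0]
with torch.no_grad():
    torch_logits = model(torch.from_numpy(input_ids)).numpy()
print(np.abs(onnx_logits - torch_logits).max())  # expect on the order of 1e-6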
What we did and didn’t do
What we did:
- KV cache support: `forward_cached` on each module, `sample_cached` on GPT
- Speed sanity check showing ~10× speedup on the small config
- ONNX export with dynamic batch + sequence axes
- A working ONNX runtime smoke test
What we didn’t:
- KV cache in the ONNX export. The simple export above always recomputes K and V; for production browser inference you’d want a stateful ONNX graph. The proper way involves pass-the-state-through-as-input/output graph rewrites; tools like `optimum` automate this for HuggingFace models, but it’s manual for our custom architecture. For step 15 (browser capstone) we use small enough sequences that the no-cache export is fine.
- Quantization. ONNX supports int8/int4 quantization for ~4× size reduction. `onnxruntime.quantization.quantize_dynamic` is the one-liner; the Quantization Lab demo shows what it does to the weight distributions.
- Multi-query / grouped-query attention. The bigger your model, the more KV cache memory dominates inference. MQA shares K/V across heads (an 8× cache cut for our 8-head config); GQA does it for groups of heads. LLaMA-2 70B uses GQA. It’s an architectural change; we’d revisit `mha.py`.
- Speculative decoding. Use a small “draft” model to propose multiple tokens, then verify them with the big model in parallel. Roughly a 2–3× speedup for free at inference time. Different mechanism; it doesn’t change our code, just adds a wrapper.
Cross-references
- KV Cache demo — toggle caching on and off, watch per-token compute change
- Inference Pipeline demo — the full forward pass, instrumented and animated
- Cost & Latency Calculator demo — KV cache + batching + speculative decoding side-by-side, with cost numbers
Next
Step 15 is the capstone. You export your trained model to ONNX, load it in the browser via onnxruntime-web, and generate text from your tiny LLM inside a browser tab. No Python, no Colab, no server — just a model file and a web page. That’s the closing image of the curriculum: a thing you wrote, running anywhere.