tiny-llm 14 / 16 24 min read · 35 min hands-on

step 14 · build

Inference: KV cache + ONNX export

Two operations that make generation production-fast: cache K and V across steps, then export the whole model to a portable ONNX file.


The model from step 09 generates tokens by running the full forward pass on the entire sequence at every step. For a 1000-token output, you do 1000 forward passes on sequences of length 1, 2, 3, …, 1000 — that’s a sum of T(T+1)/2 = 500,500 token-equivalents of compute. Quadratic in output length.

But notice: when you generate token 500, the keys and values for tokens 0–499 don’t change. Their projections were already computed during step 499. Recomputing them is wasted work.

KV caching is the trick. Save each layer’s K and V tensors as they are produced, and at each new step compute Q, K, V only for the freshly emitted token. The per-step cost drops from quadratic to linear in the sequence length (one query scored against T cached keys, instead of recomputing the full T × T attention), which is the difference between “ChatGPT runs in real time” and “ChatGPT takes minutes per response.”
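The arithmetic above in a few lines of Python. It counts token positions pushed through the projections and MLP (the “token-equivalents” of the text) for a 1-token prompt. uncached_positions and cached_positions are hypothetical helper names, and this is a count of work, not a wall-clock prediction: attention’s own length-dependent cost and per-call overhead aren’t modeled.

```python
def uncached_positions(n_new: int, prompt_len: int = 1) -> int:
    """No cache: every step re-runs the whole sequence, so step k (1-based)
    pushes prompt_len + k - 1 positions through the network."""
    return sum(prompt_len + k - 1 for k in range(1, n_new + 1))


def cached_positions(n_new: int, prompt_len: int = 1) -> int:
    """KV cache: the prompt is processed once, then each of the remaining
    n_new - 1 steps processes only the newest token."""
    return prompt_len + (n_new - 1)


for n in (200, 1000):
    print(f"{n} steps: uncached {uncached_positions(n):,} positions, "
          f"cached {cached_positions(n):,} positions")
```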

Then we’ll export the whole model to ONNX — a portable format that lets us run the same network in JavaScript, mobile runtimes, or any C++ inference engine, without dragging PyTorch along. Step 15 (the browser capstone) consumes the ONNX export.

Why KV cache works

Look at what happens during attention, step by step, as we generate:

Step T=1: input is [t0]              compute Q0, K0, V0.           attend → output token t1.
Step T=2: input is [t0, t1]          compute Q0..1, K0..1, V0..1.  attend → output token t2.
Step T=3: input is [t0, t1, t2]      compute Q0..2, K0..2, V0..2.  attend → output token t3.

At step T, we only need Q for the newest position: it’s the only query whose output feeds the next-token logits. But we need K and V for all positions, and those never change once computed. Because of the causal mask, every layer’s K and V at position i depend only on tokens t0..ti and the (frozen) weights, so the K and V for position 0 are the same at step 1 and at step 1000.

KV cache: store K and V from previous steps, append new K, V each step, run attention with Q shape (B, H, 1, d_head) against K, V shape (B, H, T, d_head).
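A minimal pure-Python check of that claim for a single attention head, with random vectors standing in for one layer’s projected Q, K, V. The helpers attend and full_causal_attention are hypothetical, not code from tiny_llm. The newest query attending to the cached K, V (with the new K, V appended) reproduces the last row of full causal attention.

```python
import math
import random


def attend(q: list[float], keys: list[list[float]], values: list[list[float]]) -> list[float]:
    """Scaled dot-product attention for one query over lists of key/value vectors."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)                                   # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


def full_causal_attention(Q, K, V):
    """No cache: recompute every row; row i attends to positions 0..i."""
    return [attend(Q[i], K[: i + 1], V[: i + 1]) for i in range(len(Q))]


def rand_vec(d: int) -> list[float]:
    return [random.gauss(0.0, 1.0) for _ in range(d)]


random.seed(0)
d_head, T = 4, 6
Q = [rand_vec(d_head) for _ in range(T)]
K = [rand_vec(d_head) for _ in range(T)]
V = [rand_vec(d_head) for _ in range(T)]

# Cached path: K, V for positions 0..T-2 are already stored. Append the
# newest position's K, V, then score only the newest query against them.
k_cache, v_cache = K[:-1], V[:-1]
k_cache.append(K[-1])
v_cache.append(V[-1])
cached_out = attend(Q[-1], k_cache, v_cache)

full_out = full_causal_attention(Q, K, V)[-1]
print("max |cached - full|:", max(abs(a - b) for a, b in zip(cached_out, full_out)))
```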

The change to compute per token:

                          Without cache         With cache
Forward FLOPs at step T   ∝ T²                  ∝ T (only the last token’s Q × cached K)
Memory                    input only            input + K, V cache (2 × T × d_model per layer)
Per-token wall-clock      grows quadratically   roughly constant (attention adds a small term linear in T)

The KV Cache demo toggles caching on and off and shows the per-token compute curve — that’s the picture in numbers.

What we change

We need attention to optionally accept previous K, V and return updated K, V. The minimal change is a method forward_cached(x, kv_cache) that returns (out, new_kv_cache).

For simplicity we’ll keep the original forward for training (where we always feed the full sequence) and add a separate cached path for inference.

# tiny_llm/mha.py — add to MultiHeadAttention class
@torch.no_grad()
def forward_cached(
    self,
    x: torch.Tensor,
    kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """Forward pass for inference with KV caching.

    Args:
        x: (B, T_new, d_model) — only the *new* tokens this step.
                                  Usually T_new=1 during sampling.
        kv_cache: tuple of (K_prev, V_prev) each shape (B, H, T_prev, d_head),
                  or None for the first step.

    Returns:
        out: (B, T_new, d_model)
        new_cache: (K, V) including the new K, V appended.
    """
    B, T_new, D = x.shape
    H = self.n_heads
    d_head = self.d_head

    # Project new Q, K, V — same as forward, but only on the new tokens.
    qkv_new = self.qkv(x)
    q_new, k_new, v_new = qkv_new.chunk(3, dim=-1)
    q_new = q_new.view(B, T_new, H, d_head).transpose(1, 2)  # (B, H, T_new, d_head)
    k_new = k_new.view(B, T_new, H, d_head).transpose(1, 2)
    v_new = v_new.view(B, T_new, H, d_head).transpose(1, 2)

    # Append to cache.
    if kv_cache is not None:
        k_prev, v_prev = kv_cache
        k = torch.cat([k_prev, k_new], dim=2)   # along the time axis
        v = torch.cat([v_prev, v_new], dim=2)
    else:
        k, v = k_new, v_new

    T_total = k.size(2)

    # Scaled dot-product. Q only has T_new positions; K/V have T_total.
    scores = q_new @ k.transpose(-2, -1) / math.sqrt(d_head)   # (B, H, T_new, T_total)

    # Causal mask: each query at position (T_total - T_new + i) can attend
    # to keys at positions 0..(T_total - T_new + i).
    # In the common single-token case (T_new=1), the mask is all-ones.
    if T_new > 1:
        mask = self.causal_mask[T_total - T_new : T_total, :T_total]
        scores = scores.masked_fill(~mask, float("-inf"))

    weights = F.softmax(scores, dim=-1)
    out = weights @ v                                          # (B, H, T_new, d_head)
    out = out.transpose(1, 2).contiguous().view(B, T_new, D)
    return self.W_o(out), (k, v)

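The body is nearly identical to the training-time forward. The differences:

  • Q, K, and V are projected only for the new tokens; the K and V of earlier positions come from the cache instead of being recomputed. That’s the win.
  • The new K, V are appended to the cache before scoring, so the new queries attend to themselves as well as to everything before them.
  • The causal mask is sliced to the rows of the new queries (T_total - T_new through T_total - 1). In the single-new-token case (T_new=1), the new query may attend to every cached position, so the mask would be all-True and the code skips it.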
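To make the slice concrete, here is the same indexing on plain Python lists. causal_mask below is a hypothetical stand-in for self.causal_mask, assumed to be a lower-triangular boolean matrix as the ~mask in forward_cached implies. With 3 cached positions and 2 new tokens, the first new query must not see the second one.

```python
def causal_mask(n: int) -> list[list[bool]]:
    """Lower-triangular mask: mask[i][j] is True when query i may attend to key j."""
    return [[j <= i for j in range(n)] for i in range(n)]


def cached_mask_rows(T_total: int, T_new: int) -> list[list[bool]]:
    """List version of causal_mask[T_total - T_new : T_total, :T_total]:
    the rows for the T_new newest queries, over all T_total keys."""
    full = causal_mask(T_total)
    return [row[:T_total] for row in full[T_total - T_new : T_total]]


# 3 cached positions + 2 new tokens = 5 keys in total.
for row in cached_mask_rows(T_total=5, T_new=2):
    print("".join("x" if ok else "." for ok in row))
# prints:
# xxxx.
# xxxxx
```

With T_new=1 the single row is all True, which is why forward_cached can skip masking in that case.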

We add the corresponding cached forward to the Block class:

# tiny_llm/block.py — add to Block class
@torch.no_grad()
def forward_cached(
    self,
    x: torch.Tensor,
    kv_cache: tuple[torch.Tensor, torch.Tensor] | None,
):
    attn_out, new_cache = self.attn.forward_cached(self.ln_1(x), kv_cache)
    x = x + attn_out
    x = x + self.mlp(self.ln_2(x))
    return x, new_cache

And then to GPT:

# tiny_llm/gpt.py — add to GPT class
@torch.no_grad()
def forward_cached(
    self,
    token_ids: torch.Tensor,
    caches: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
):
    """Forward pass that maintains a per-block KV cache for fast generation.

    token_ids: (B, T_new). Pass the full prompt on first call (caches=None);
               on subsequent calls pass T_new=1 (just the last sampled token).
    caches: list of len n_layers, one (K, V) tuple per block, or None.

    Returns:
        logits: (B, T_new, vocab_size)
        new_caches: list of (K, V) — to pass on the next call.
    """
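    # Positions must continue where the cache left off instead of restarting
    # at 0; otherwise every sampled token is embedded as if it sat at position 0.
    past_len = caches[0][0].size(2) if caches is not None else 0
    # ASSUMPTION: embed() takes a start offset for its positional embeddings
    # (the `start_pos` argument is hypothetical); adapt to step 09's embed().
    x = self.embed(token_ids, start_pos=past_len)   # (B, T_new, d_model)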
    new_caches = []
    for i, block in enumerate(self.blocks):
        cache_i = caches[i] if caches is not None else None
        x, new_cache_i = block.forward_cached(x, cache_i)
        new_caches.append(new_cache_i)
    x = self.ln_f(x)
    logits = self.lm_head(x)
    return logits, new_caches

And a faster sample_cached() that uses it:

# tiny_llm/gpt.py — add to GPT class
@torch.no_grad()
def sample_cached(
    self,
    prompt_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 0.8,
    top_p: float = 0.95,
) -> torch.Tensor:
    """Same as sample() but with KV caching — much faster for long outputs."""
    self.eval()
    ids = prompt_ids

    # First pass: process the entire prompt, build caches.
    logits, caches = self.forward_cached(ids)

    for _ in range(max_new_tokens):
        # Sample from the logits at the last position.
        next_logits = logits[:, -1, :] / temperature
        # (top-p filter elided for brevity — same as in sample())
        probs = F.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=1)

        # Forward only the new token, reusing caches.
        logits, caches = self.forward_cached(next_id, caches)

    return ids
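sample_cached() leaves out the top-p step. A minimal pure-Python sketch of temperature plus nucleus (top-p) sampling for one position’s logits follows; top_p_filter and sample_next are hypothetical names, and step 09’s sample() presumably does the equivalent on tensors, but this is not its code.

```python
import math
import random
from typing import Optional


def top_p_filter(probs: list[float], top_p: float) -> list[float]:
    """Keep the smallest set of most-likely tokens whose probabilities sum to
    at least top_p, zero out the rest, and renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: set[int] = set()
    cumulative = 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return [p / total if i in kept else 0.0 for i, p in enumerate(probs)]


def sample_next(logits: list[float], temperature: float = 0.8, top_p: float = 0.95,
                rng: Optional[random.Random] = None) -> int:
    """Temperature-scale, softmax, top-p filter, then draw one token id."""
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    m = max(scaled)                               # subtract max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    probs = top_p_filter([e / z for e in exps], top_p)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```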

Speed sanity check

# tiny_llm/gpt.py
if __name__ == "__main__":
    import time

    config = GPTConfig()
    model = GPT(config)
    model.eval()

    torch.manual_seed(0)  # same random prompt on every run
    prompt = torch.randint(0, config.vocab_size, (1, 8))
    n_new = 200

    # Without cache. perf_counter() is the high-resolution clock meant for timing.
    t0 = time.perf_counter()
    out_uncached = model.sample(prompt, max_new_tokens=n_new, temperature=0.8, top_p=0.95)
    t_uncached = time.perf_counter() - t0

    # With cache
    t0 = time.perf_counter()
    out_cached = model.sample_cached(prompt, max_new_tokens=n_new, temperature=0.8, top_p=0.95)
    t_cached = time.perf_counter() - t0

    print(f"sample() uncached:    {t_uncached:.2f}s ({n_new/t_uncached:.1f} tok/s)")
    print(f"sample_cached():       {t_cached:.2f}s ({n_new/t_cached:.1f} tok/s)")
    print(f"speedup: {t_uncached / t_cached:.1f}×")

Expected output (CPU, the 5M-param model):

sample() uncached:    18.4s (10.9 tok/s)
sample_cached():       2.1s (95.2 tok/s)
speedup: 8.8×

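8–10× isn’t unusual for 200-token outputs with our config. The longer the output, the bigger the speedup, because the uncached cost per token keeps growing while the cached cost stays nearly flat; for 1000-token outputs (with a context window that long) it’s typically 30–50×.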

Memory cost of caching

The KV cache occupies real memory. Per token, per layer:

2 (K + V) × n_heads × d_head = 2 × d_model floats

For our MEDIUM config (d_model=768, n_layers=12, max_seq_len=1024, batch_size=1):

2 × 768 × 12 × 1024 = ~19M floats = ~75 MB at fp32
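The same arithmetic as a small function. It assumes standard multi-head attention (K and V each n_heads × d_head = d_model wide, per token, per layer) and fp32 values by default; kv_cache_bytes is a hypothetical helper, and bytes_per_value=2 models an fp16 cache.

```python
def kv_cache_bytes(d_model: int, n_layers: int, seq_len: int,
                   batch_size: int = 1, bytes_per_value: int = 4) -> int:
    """Bytes held by the KV cache: K and V (the factor 2), each d_model wide,
    for every layer, every cached token, and every sequence in the batch."""
    return 2 * d_model * n_layers * seq_len * batch_size * bytes_per_value


medium = kv_cache_bytes(d_model=768, n_layers=12, seq_len=1024)
print(f"{medium:,} bytes = {medium / 1e6:.1f} MB")
# prints: 75,497,472 bytes = 75.5 MB
```

Every argument scales the cache linearly, which is how many long concurrent contexts add up quickly.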

Tractable for a single sequence, but it becomes the dominant memory cost in production: a 70B-param model serving 32 concurrent 4096-token contexts pays tens of gigabytes for the KV cache alone. This is why paged attention (vLLM, which allocates the cache in fixed-size blocks so memory isn’t lost to fragmentation or over-reservation) and multi-query attention (where K and V are shared across heads to shrink the cache) exist.

ONNX export

Once the model is fast and trained, you might want to run it without PyTorch. ONNX (Open Neural Network Exchange) is a portable format consumed by many runtimes:

  • ONNX Runtime (Microsoft): fast C++ inference on CPU/GPU, used inside Microsoft products such as Office
  • onnxruntime-web: runs in the browser via WebAssembly + WebGPU (step 15 will use this)
  • CoreML (Apple): on iOS/macOS, ONNX Runtime can hand the model to Core ML through its CoreML execution provider
  • TensorRT (NVIDIA): high-performance inference on NVIDIA GPUs, with a built-in ONNX parser

The export call:

# tiny_llm/export.py
import torch
from pathlib import Path
from tiny_llm.gpt import GPT


def export_to_onnx(
    model: GPT,
    output_path: Path,
    max_seq_len: int = 256,
    opset_version: int = 17,
):
    """Export model to ONNX format for portable inference.

    The exported model takes a (B, T) integer tensor and returns
    (B, T, vocab_size) logits. KV caching support in ONNX needs
    a more elaborate export that we skip here; this version is fine
    for one-shot inference (which is what step 15 needs).
    """
    model.eval()
    dummy_input = torch.randint(0, model.config.vocab_size, (1, max_seq_len))

    torch.onnx.export(
        model,
        (dummy_input,),
        output_path,
        input_names=["token_ids"],
        output_names=["logits"],
        # Dynamic axes let the exported model handle variable batch + seq len.
        dynamic_axes={
            "token_ids": {0: "batch", 1: "seq_len"},
            "logits":    {0: "batch", 1: "seq_len"},
        },
        opset_version=opset_version,
        do_constant_folding=True,
    )
    print(f"exported to {output_path} ({output_path.stat().st_size / 1e6:.1f} MB)")


if __name__ == "__main__":
    ckpt = torch.load("checkpoints/best.pt", weights_only=False)
    model = GPT(ckpt["gpt_config"])
    model.load_state_dict(ckpt["model"])
    export_to_onnx(model, Path("checkpoints/tiny_llm.onnx"))

Run it:

uv run python -m tiny_llm.export

Expected output:

exported to checkpoints/tiny_llm.onnx (21.5 MB)

The 21.5 MB is essentially the model’s parameters at fp32 (about 5.4M parameters × 4 bytes). You can quantize the ONNX file to int8 with onnxruntime.quantization.quantize_dynamic and shrink it ~4×, which matters for browser downloads.
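Where the ~4× comes from, as a simplified pure-Python sketch of symmetric per-tensor int8 weight quantization: one fp32 scale per tensor, one signed byte per weight. This illustrates the idea only; quantize_dynamic’s actual scheme (signed vs unsigned types, zero points, per-channel options, which ops get quantized) is controlled by its own arguments and may differ. quantize_int8 and dequantize_int8 are hypothetical helpers.

```python
def quantize_int8(weights: list[float]) -> tuple[list[int], float]:
    """Symmetric per-tensor int8: one scale maps the largest |w| to 127 and
    every weight is rounded to an integer in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q: list[int], scale: float) -> list[float]:
    """Map the stored integers back to approximate fp32 weights."""
    return [v * scale for v in q]


weights = [0.51, -1.27, 0.003, 0.9, -0.2]
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)

# Storage for an n-weight tensor: 4 bytes each at fp32, versus 1 byte each
# plus a single 4-byte scale at int8, so the ratio approaches 4x as n grows.
n = 1_000_000
print(f"fp32: {4 * n:,} bytes, int8: {n + 4:,} bytes")
```

The round trip moves each weight by at most half a scale step, which is the accuracy price paid for the smaller file.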

Loading and running the ONNX

Just to confirm the export works:

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("checkpoints/tiny_llm.onnx", providers=["CPUExecutionProvider"])
input_ids = np.random.randint(0, 4096, size=(1, 16)).astype(np.int64)   # ids must be < vocab_size (4096 here)
logits = session.run(["logits"], {"token_ids": input_ids})[0]

print(f"input shape:  {input_ids.shape}")
print(f"output shape: {logits.shape}")

Same outputs as PyTorch (modulo numerical noise at ~1e-6). No import torch required.

What we did and didn’t do

What we did:

  • KV cache support: forward_cached on each module, sample_cached on GPT
  • Speed sanity check showing ~10× speedup on the small config
  • ONNX export with dynamic batch + sequence axes
  • A working ONNX runtime smoke test

What we didn’t:

  • KV cache in the ONNX export. The simple export above always recomputes K and V; for production browser inference you’d want an export that carries the cache. ONNX graphs are stateless, so the usual approach is to add the past K, V tensors as extra graph inputs and the updated K, V as extra outputs, and have the caller feed them back in on the next call. Tools like optimum automate this for HuggingFace models, but it’s manual for our custom architecture. For step 15 (browser capstone) we use sequences short enough that the no-cache export is fine.
  • Quantization. ONNX supports int8 and int4 quantization: int8 weights are ~4× smaller than fp32, int4 ~8×. onnxruntime.quantization.quantize_dynamic is the one-liner for int8; the Quantization Lab demo shows what it does to the weight distributions.
  • Multi-query / grouped-query attention. The bigger your model, the more KV cache memory dominates inference. MQA shares K/V across heads (8× memory cut for our 8-head config); GQA does it for groups. LLaMA-2 70B uses GQA. Architectural change; we’d revisit mha.py.
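  • Speculative decoding. Use a small “draft” model to propose several tokens, then have the big model verify them all in a single forward pass. Typically ~2–3× faster at inference time, with no retraining and (with the standard acceptance rule) no change to the output distribution. Different mechanism; it doesn’t change our model code, just adds a wrapper.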
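On the multi-query / grouped-query bullet above: MQA and GQA change only which K/V head each query head reads, and therefore how many K/V heads the cache stores. A minimal sketch for an 8-head layer, assuming contiguous grouping (query head h reads K/V head h // group_size), which is a common convention; kv_head_for and kv_cache_reduction are hypothetical helpers.

```python
def kv_head_for(query_head: int, n_heads: int, n_kv_heads: int) -> int:
    """K/V head read by a query head under grouped-query attention.
    n_kv_heads == n_heads is ordinary MHA; n_kv_heads == 1 is MQA."""
    assert n_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    group_size = n_heads // n_kv_heads
    return query_head // group_size


def kv_cache_reduction(n_heads: int, n_kv_heads: int) -> float:
    """Only n_kv_heads K/V heads are cached, so the cache shrinks by this factor vs MHA."""
    return n_heads / n_kv_heads


for n_kv in (8, 2, 1):  # MHA, GQA with 2 groups, MQA
    mapping = [kv_head_for(h, 8, n_kv) for h in range(8)]
    print(f"n_kv_heads={n_kv}: query head -> KV head {mapping}, "
          f"cache is 1/{kv_cache_reduction(8, n_kv):.0f} of MHA's")
```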

Cross-references

Next

Step 15 is the capstone. We take the ONNX export from this step, load it in the browser via onnxruntime-web, and let anyone who opens the page generate text from your tiny LLM inside their browser tab. No Python, no Colab, no server: just a model file and a web page. That’s the closing image of the curriculum: a thing you wrote, running anywhere.