step 10 · build
Sampling and decoding
Greedy, temperature, top-k, top-p (nucleus). Same model, four very different writers.
Step 09’s model.generate() is greedy: at every step, it picks the token with the highest logit. That’s a deterministic, reproducible policy, and a boring one. Greedy decoding tends to produce bland, repetitive text: each high-probability choice steers the model toward more high-probability continuations, and the result drifts into clichés.
This step adds three knobs that turn a deterministic argmax into a sampler with character: temperature, top-k, and top-p (nucleus). Same model, very different output styles. By the end you’ll have a sample() method on GPT and a clear sense of when to use which knob.
If you’ve used the Sampling Knobs demo, this is the algorithm behind it — the demo runs these knobs on real GPT-2 logits.
Why greedy is bad
A trained model has a probability distribution over the vocab at each step. Greedy commits to the mode every time. Two failure modes:
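- Repetition loops. Greedy is prone to “the cat sat on the mat. the cat sat on the mat. the cat sat on the mat…” Once the text after “mat.” looks like a state the model has already been in, the argmax leads straight back into the same sentence. The model does see the earlier repetition in its context, but nothing in greedy decoding discourages repeating it, and the repeated pattern tends to make the same continuation even more likely. Sampling breaks the loop because it occasionally picks the second- or third-likeliest token.
- Bland output. The mode of a distribution isn’t always the most interesting choice. If the model gives “happy” 0.30, “joyful” 0.28, and “delighted” 0.20, greedy picks “happy” 100% of the time, even though “joyful” is nearly as likely. Sampling picks each word roughly in proportion to its probability, which preserves that variety.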
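To make the difference concrete, here is a small dependency-free simulation using the illustrative probabilities above. The `<other>` bucket standing in for the remaining 0.22 of probability mass is an assumption added for the example, and `random.choices` plays the role that `torch.multinomial` plays in the real sampler.

```python
import random
from collections import Counter

# Illustrative next-token distribution from the example above.
# "<other>" is a hypothetical bucket holding the remaining 0.22 of probability mass.
dist = {"happy": 0.30, "joyful": 0.28, "delighted": 0.20, "<other>": 0.22}

def greedy_pick(d):
    # Greedy: always return the single most likely token.
    return max(d, key=d.get)

def sample_pick(d, rng):
    # Sampling: draw one token with probability proportional to its weight.
    tokens, weights = zip(*d.items())
    return rng.choices(tokens, weights=weights, k=1)[0]

rng = random.Random(0)
n = 100_000
greedy_counts = Counter(greedy_pick(dist) for _ in range(n))
sample_counts = Counter(sample_pick(dist, rng) for _ in range(n))

print("greedy :", {t: greedy_counts[t] / n for t in dist})
print("sampled:", {t: round(sample_counts[t] / n, 3) for t in dist})
```

Greedy says “happy” in all 100,000 draws; sampling says it about 30% of the time and “joyful” about 28% of the time.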
What we want: a way to pick from the high-probability region of the distribution, but not always the single peak. Three knobs for that.
Knob 1: Temperature
Temperature divides the logits before softmax:
probs = softmax(logits / T)
- T = 1.0 (default): unchanged distribution
- T < 1.0 (e.g. 0.7): sharpens; the peak gets taller and low-probability tokens get even lower
- T > 1.0 (e.g. 1.5): flattens; the distribution moves closer to uniform
- T → 0: reduces to argmax (greedy)
- T → ∞: reduces to uniform random
In practice, T ∈ [0.6, 1.0] works for most generation. Going higher tends to produce increasingly incoherent text; going lower collapses toward greedy.
Conceptually, temperature is a boldness knob — how willing is the model to deviate from its top guess?
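The formula softmax(logits / T) is easy to try directly. This is a minimal pure-Python sketch on three made-up logits; `softmax_with_temperature` is a hypothetical helper, and it rejects T = 0 because that case is handled as a plain argmax rather than a division.

```python
import math

def softmax_with_temperature(logits, temperature):
    """probs = softmax(logits / T). Requires T > 0 (T = 0 means argmax instead)."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use argmax for T = 0")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]  # hypothetical logits for a 3-token vocabulary
for t in (0.01, 0.5, 1.0, 2.0, 100.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t:>6}: " + ", ".join(f"{p:.3f}" for p in probs))
```

As T shrinks, the first probability approaches 1 (argmax behavior); at T = 100 the three probabilities are nearly equal.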
Knob 2: top-k
After softmax, keep only the top k tokens and renormalize. Sample from those.
- k = 1: greedy
- k = 50: only consider the 50 most likely tokens, and sample from them weighted by probability
- k = vocab_size: identical to plain sampling (no truncation)
The intuition: “only the k most likely candidates are plausible; ignore the long tail of everything else.” This bounds how strange the output can get.
Drawback: k is a fixed cap regardless of how confident the model is. If the model is very confident (one token has 0.95 probability), you’d want k = 1 for that step. If it’s uncertain (probability spread thinly, say about 0.01 each across a hundred tokens), you’d want k ≈ 100. Top-k can’t adapt: the same k applies at every step.
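A minimal pure-Python sketch of the fixed-cap problem, using two made-up distributions over a five-token vocabulary. `top_k_filter` is a hypothetical helper for illustration, not part of tiny_llm; it breaks ties at the k-th value by index, so exactly k tokens survive.

```python
def top_k_filter(probs, k):
    """Keep the k most likely tokens and renormalize; return {index: prob}.

    Python's sort is stable (also with reverse=True), so ties at the
    k-th value are broken by index and exactly k tokens survive.
    """
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = ranked[:k]
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}

confident = [0.95, 0.02, 0.01, 0.01, 0.01]  # one clear winner
uncertain = [0.20, 0.20, 0.20, 0.20, 0.20]  # five equally plausible tokens

# The same k = 3 is applied to both steps.
print(top_k_filter(confident, 3))  # still admits two near-noise tail tokens
print(top_k_filter(uncertain, 3))  # drops two tokens as likely as the ones it keeps
```

Both calls keep exactly three tokens: too many for the confident step, too few for the uncertain one.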
Knob 3: top-p (nucleus)
Top-p fixes the adaptiveness problem. Instead of “keep the top k”, keep the smallest set of tokens whose cumulative probability exceeds p, then renormalize and sample.
Concretely, p = 0.9 means “keep the smallest set that contains 90% of the total probability mass.”
- If the distribution is sharp (one token has 0.95), the kept set is just that one token. Acts like greedy.
- If it’s flat, the kept set might be 50+ tokens. Acts like wider sampling.
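The adaptive behavior is easy to see on a few toy distributions. This pure-Python sketch uses a hypothetical `top_p_filter` helper that walks tokens in descending order and stops right after the cumulative probability first exceeds p, keeping the token that crosses the threshold.

```python
def top_p_filter(probs, p):
    """Keep the smallest top-ranked set whose cumulative probability exceeds p,
    then renormalize; return {index: prob}. The token that crosses p is kept."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = [], 0.0
    for i in ranked:
        kept.append(i)
        cum += probs[i]
        if cum > p:
            break
    # `cum` is exactly the mass of the kept tokens, so dividing renormalizes.
    return {i: probs[i] / cum for i in kept}

sharp = [0.95, 0.02, 0.01, 0.01, 0.01]  # confident step
medium = [0.5, 0.3, 0.15, 0.05]         # somewhat spread out
flat = [0.2, 0.2, 0.2, 0.2, 0.2]        # maximally uncertain

for name, d in [("sharp", sharp), ("medium", medium), ("flat", flat)]:
    print(f"{name}: keeps {len(top_p_filter(d, 0.9))} token(s)")
```

With the same p = 0.9, the kept set has 1, 3, and 5 tokens respectively: no per-step tuning required.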
Top-p adapts to the model’s confidence at every step, which is why it is a common default for chat-style generation; most hosted LLM APIs (OpenAI’s and Anthropic’s among them) expose a top_p parameter.
All three together
Real samplers combine them. The standard pipeline:
logits → divide by T → top-k filter → top-p filter → softmax → multinomial sample
You generally pair temperature with top-k or top-p. Using both is safe, since they compose (the chat row below does exactly that). Good defaults vary by use case; treat these as starting points:
| Use case | T | top-k | top-p |
|---|---|---|---|
| Creative writing | 1.0 | — | 0.95 |
| Chat / Q&A | 0.7 | 50 | 0.9 |
| Code generation | 0.2–0.4 | — | — |
| Deterministic eval | 0 (greedy) | 1 | — |
Code generation prefers low temperature because there’s usually one correct answer; creative writing prefers high temperature because there are many good outputs.
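One way to carry the table into code is a small dictionary of keyword arguments for the sample() method added in the Setup section. `SAMPLING_PRESETS` is a hypothetical name, and the values are simply the rows above (with code generation pinned to 0.3 from its 0.2–0.4 range), so tune them for your own model.

```python
# Hypothetical presets mirroring the use-case table, written as keyword
# arguments for GPT.sample(). Starting points only, not tuned values.
SAMPLING_PRESETS = {
    "creative": {"temperature": 1.0, "top_p": 0.95},
    "chat": {"temperature": 0.7, "top_k": 50, "top_p": 0.9},
    "code": {"temperature": 0.3},  # anywhere in the 0.2-0.4 range
    "eval": {"temperature": 0},    # greedy; top_k=1 would be redundant
}

# Usage, once `model` and `ids` exist:
#   out = model.sample(ids, max_new_tokens=80, **SAMPLING_PRESETS["chat"])
```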
Setup
We’re going to add a sample() method directly to the GPT class in tiny_llm/gpt.py. Open that file and add this method alongside generate:
```python
# tiny_llm/gpt.py (new method on GPT)
@torch.no_grad()
def sample(
    self,
    prompt_ids: torch.Tensor,
    max_new_tokens: int,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
) -> torch.Tensor:
    """Generate `max_new_tokens` tokens from `prompt_ids` using
    temperature + (optional) top-k + (optional) top-p sampling.

    prompt_ids: (B, T_prompt) starting token IDs
    returns:    (B, T_prompt + max_new_tokens)

    For deterministic argmax, set temperature=0 (or use .generate()).
    """
    self.eval()
    ids = prompt_ids
    for _ in range(max_new_tokens):
        # Truncate to the context window.
        ids_cond = ids if ids.size(1) <= self.config.max_seq_len else ids[:, -self.config.max_seq_len:]
        logits = self(ids_cond)         # (B, T, vocab_size)
        next_logits = logits[:, -1, :]  # (B, vocab_size)

        if temperature == 0:
            # Greedy.
            next_id = next_logits.argmax(dim=-1, keepdim=True)
        else:
            # Apply temperature.
            next_logits = next_logits / temperature

            # Top-k filter: set every logit below the k-th largest to -inf.
            if top_k is not None and top_k > 0:
                k = min(top_k, next_logits.size(-1))  # torch.topk errors if k > vocab_size
                top_vals, _ = torch.topk(next_logits, k, dim=-1)
                kth_value = top_vals[:, -1:].expand_as(next_logits)
                next_logits = torch.where(
                    next_logits < kth_value,
                    torch.full_like(next_logits, float("-inf")),
                    next_logits,
                )

            # Top-p filter: mask the tail past cumulative probability p.
            if top_p is not None and 0.0 < top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(next_logits, descending=True, dim=-1)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                cum_probs = sorted_probs.cumsum(dim=-1)
                # Tokens whose cumulative probability (including themselves) exceeds p
                # are candidates for removal; the first one that crosses p is kept.
                remove = cum_probs > top_p
                # Shift right by one so we always keep at least one token.
                remove[..., 1:] = remove[..., :-1].clone()
                remove[..., 0] = False
                # Scatter the removal mask back to the original token order.
                mask = torch.zeros_like(next_logits, dtype=torch.bool)
                mask.scatter_(-1, sorted_idx, remove)
                next_logits = next_logits.masked_fill(mask, float("-inf"))

            # Softmax to probabilities, then sample one token per row.
            probs = torch.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)

        ids = torch.cat([ids, next_id], dim=1)
    return ids
```
The two filter blocks are the crux. Let me walk through them.
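Top-k. torch.topk returns the largest k values and their indices. The k-th-largest value defines the threshold; anything strictly below it gets -inf. Because the comparison is strict, tokens tied with the k-th value survive too, so occasionally slightly more than k tokens remain. Setting a logit to -inf is the standard way to “delete” an entry before softmax: exp(-inf) is exactly 0, so that token gets exactly zero probability.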
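A dependency-free sketch of the same threshold logic on a made-up four-token example. `top_k_mask` and `softmax` are hypothetical helpers that mirror the torch.where version in sample(); the example shows that math.exp(-inf) is exactly 0.0 and that a tie at the k-th value lets an extra token through.

```python
import math

NEG_INF = float("-inf")

def softmax(logits):
    m = max(logits)  # assumes at least one finite logit
    exps = [math.exp(x - m) for x in logits]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def top_k_mask(logits, k):
    """Threshold-style top-k: anything strictly below the k-th largest logit becomes -inf."""
    kth_value = sorted(logits, reverse=True)[k - 1]
    return [NEG_INF if x < kth_value else x for x in logits]

logits = [3.0, 1.0, 1.0, 0.5]  # hypothetical; note the tie at 1.0
masked = top_k_mask(logits, 2)
print(masked)           # [3.0, 1.0, 1.0, -inf]: the tie keeps 3 tokens, not 2
print(softmax(masked))  # the last probability is exactly 0.0
```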
Top-p. This is the trickiest piece in the file. Three steps:
- Sort logits in descending order. Now the most probable token is at index 0.
- Compute the cumulative softmax: cum_probs[i] is the total probability of the top i+1 tokens.
- Mark every position after the first one that crosses p as removed. The “shift right by one” subtlety guarantees that the first token to cross the threshold is kept. Without it, top-p would drop that token (so the kept mass would fall short of p), and it would return the empty set whenever a single token alone has more than p of the probability.
We then scatter the per-rank removal mask back to original-vocab indices and apply it.
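The shift is easiest to see on concrete numbers. This pure-Python sketch (`top_p_remove_mask` and `kept_ranks` are hypothetical helpers) reproduces the per-rank mask from sample() on already-sorted probabilities, with and without the shift.

```python
from itertools import accumulate

def top_p_remove_mask(sorted_probs, p, shift=True):
    """Per-rank removal mask for top-p, mirroring sample():
    remove = cum_probs > p, then shift right by one and force rank 0 to be kept.
    `sorted_probs` must already be sorted in descending order."""
    cum_probs = list(accumulate(sorted_probs))
    remove = [c > p for c in cum_probs]
    if shift:
        remove = [False] + remove[:-1]
    return remove

def kept_ranks(remove):
    return [i for i, r in enumerate(remove) if not r]

probs = [0.5, 0.3, 0.15, 0.05]
print(kept_ranks(top_p_remove_mask(probs, 0.7)))               # [0, 1]
print(kept_ranks(top_p_remove_mask(probs, 0.7, shift=False)))  # [0]: kept mass 0.5 < p

sharp = [0.95, 0.03, 0.02]
print(kept_ranks(top_p_remove_mask(sharp, 0.9)))               # [0]
print(kept_ranks(top_p_remove_mask(sharp, 0.9, shift=False)))  # []: the empty set
```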
torch.multinomial(probs, num_samples=1) samples one token from each row’s probability distribution. That’s the actual stochastic step.
Sanity check
Add a small smoke test that calls each sampling mode on a fresh, untrained model, just to confirm that the shapes are right and nothing crashes. Output quality can only be judged on a trained checkpoint, which you’ll have after running step 09.
```python
# tiny_llm/gpt.py (extend the __main__ block)
if __name__ == "__main__":
    torch.manual_seed(0)
    config = GPTConfig()
    model = GPT(config)
    prompt = torch.randint(0, config.vocab_size, (1, 8))

    print("untrained sampling smoke test:")
    print(f" greedy: {model.generate(prompt, 5)[0, -5:].tolist()}")
    print(f" T=0: {model.sample(prompt, 5, temperature=0)[0, -5:].tolist()}")
    print(f" T=0.7, top_k=50: {model.sample(prompt, 5, temperature=0.7, top_k=50)[0, -5:].tolist()}")
    print(f" T=0.9, top_p=0.95: {model.sample(prompt, 5, temperature=0.9, top_p=0.95)[0, -5:].tolist()}")
```
Run it:

```shell
uv run python -m tiny_llm.gpt
```
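Expected output (the exact token IDs depend on your config and random state; the pattern is what matters):

```text
untrained sampling smoke test:
 greedy: [1234, 1234, 1234, 1234, 1234]
 T=0: [1234, 1234, 1234, 1234, 1234]
 T=0.7, top_k=50: [3287, 1819, 2740, 419, 3892]
 T=0.9, top_p=0.95: [3554, 1276, 2988, 3119, 802]
```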
On the untrained model, greedy and T=0 give the same sequence (often a single token repeated); the two stochastic modes give different sequences. After training (step 09’s train.py sample), try each mode on a real prompt:
```python
# In a quick scratch script after training.
# Assumes `model` (your trained GPT) and `tok` (the tokenizer from earlier steps)
# are already loaded.
import torch

prompt = "Once upon a time, in a forest far away, there was"
ids = torch.tensor([tok.encode(prompt)])  # (1, T_prompt)

for label, kwargs in [
    ("greedy", {"temperature": 0}),
    ("T=0.7", {"temperature": 0.7}),
    ("top-k=40", {"temperature": 0.9, "top_k": 40}),
    ("top-p=0.9", {"temperature": 0.9, "top_p": 0.9}),
]:
    out = model.sample(ids, max_new_tokens=80, **kwargs)
    print(f"\n--- {label} ---\n{tok.decode(out[0].tolist())}")
```
You’ll typically see that greedy is the most repetitive (“a small bear who lived in a small house in a small forest”), while top-p reads most naturally (“a small bear who loved to find berries and tell stories to the smaller animals”).
What we did and didn’t do
What we did:
- Temperature, top-k, and top-p sampling, all in one method
- Token removal via -inf masking before softmax, so filtered tokens get exactly zero probability
- Greedy as a special case (temperature=0)
- @torch.no_grad() to skip gradient bookkeeping during generation
What we didn’t:
- Beam search. Maintains K parallel hypotheses and picks the highest-scoring complete sequence. Used in machine translation and summarization, where there’s a “right answer.” For chat and creative generation, beam search produces blander output than top-p sampling, and it costs roughly K times the compute per step. The Beam Search demo shows the search tree expanding if you want to see it.
- Repetition penalties. Some generators add a penalty proportional to how many times a token has already appeared. Useful as a band-aid; it’s better to fix the underlying issue (an untrained model, too low a temperature, or too short a context).
- Logit biases. Force certain tokens to be more or less likely (e.g., to always begin with “Sure,” or to never produce profanity). Add later if you need it.
- Speculative decoding. Use a small “draft” model to propose several tokens, then have the big model verify them in one parallel pass. Speedups of roughly 2–3× are commonly reported. The Cost & Latency Calculator demo compares throughput with and without it.
Cross-references
- The Sampling Knobs demo lets you slide temperature, top-k, and top-p on real GPT-2 logits and see the distribution reshape live. Open it and play with it for two minutes — it’s the fastest way to build intuition for what these knobs do to the probability mass.
- The Beam Search demo for the deterministic-search alternative.
Next
Step 11 is the scaling article. Same architecture you’ve built, but with bigger configs: 1M → 10M → 100M parameters. What changes? Learning rate, batch size, training tokens, training time. You’ll see the Chinchilla scaling laws made concrete on the model you wrote.