
step 17 · ship · production

Synthetic data + distillation

Compress a frontier model into a small specialist for a ~10× cost reduction. The pipeline behind every cost-conscious production deploy.

distillation · synthetic-data · cost · fine-tuning

You’ve shipped step 15. The bot works. The bill is real. A frontier model on every request is the most expensive way to do what you’re doing — and for ~80% of your traffic, the bot doesn’t actually need a frontier model. It needs the policy your frontier model has learned for your domain, in a smaller faster cheaper container.

This is what distillation buys you: you take the frontier teacher’s behavior on your prompts, capture it as a dataset, train a much smaller student to imitate it, and ship the student. The math is two equations. The pipeline is what we’re building in this step.

By the end you’ll have a Llama-3.1-8B that handles 92% of your traffic at 1/10th the cost, and a frontier-model fallback for the 8% that still needs the big model.

Why pair synthetic data with distillation

These two are not separate topics. Distillation needs a dataset; the dataset comes from the teacher. The teacher is the labeler. Pair them together because:

  • One pipeline produces both. A single frontier-model run gives you both the inputs (prompts) and the outputs (labels) you train the student on.
  • The “soft” labels (full distributions) are what makes distillation work. The student learns from the teacher’s uncertainty, not just the argmax — which is information you can only get from the teacher running, not from human annotation.
  • The economics flip together. Synthetic data is ~$1–10/M generated tokens; the resulting trained student costs ~10× less than the teacher per request. The cost trade is “spend $5K once on data, save $5K/month forever after.”

Stage 1 — generate synthetic data with the teacher

The teacher is your frontier model. The student is whatever small base you’re targeting. The dataset is what comes out of the teacher running on representative prompts.

Step 1.1 — seed prompts

Don’t generate “general” data. Generate data that matches your traffic distribution. The cheapest way: sample from your production logs (with PII redacted), cluster them, pick representatives.

# stack/distill.py
from __future__ import annotations
import json
from collections import defaultdict
from pathlib import Path
from typing import Callable, Iterator

from stack.llm import LLM
from stack.embed import Embedder

def sample_seed_prompts(
    prod_logs: Path,
    n_seeds: int = 200,
    embedder: Embedder | None = None,
) -> list[str]:
    """Pick `n_seeds` representative prompts from production logs.

    Strategy: embed all prompts, k-means into n_seeds clusters, pick the
    medoid of each cluster. Captures the diversity of real traffic without
    1000× duplicates of the most common query.
    """
    embedder = embedder or Embedder()
    lines = prod_logs.read_text().splitlines()
    prompts = [json.loads(l)["prompt"] for l in lines if l.strip()]
    if len(prompts) <= n_seeds:
        return prompts

    embs = embedder.embed_batch(prompts)
    from sklearn.cluster import KMeans
    km = KMeans(n_clusters=n_seeds, random_state=42, n_init=10).fit(embs)
    medoids = []
    for c in range(n_seeds):
        members = [(i, e) for i, e in enumerate(embs) if km.labels_[i] == c]
        # Medoid = closest member to the cluster centroid
        center = km.cluster_centers_[c]
        best = min(members, key=lambda x: ((x[1] - center) ** 2).sum())
        medoids.append(prompts[best[0]])
    return medoids

200 seed prompts is enough to anchor the distribution. We’ll multiply this 50× via teacher-side augmentation in step 1.2.

Step 1.2 — let the teacher write more data

For each seed prompt, ask the teacher to paraphrase it 50 ways. Each paraphrase becomes a new training prompt. Cheap, effective, and stays inside the production distribution.

PARAPHRASE_PROMPT = """\
The user might phrase the same intent in many ways. Generate {n} different
ways a real user could ask the same question. Vary the formality, length,
and word choice. Output one paraphrase per line, no numbering.

Original: {seed}
"""

def expand_prompts(
    seeds: list[str],
    teacher: LLM,
    multiplier: int = 50,
) -> list[str]:
    """For each seed, generate `multiplier` paraphrases via the teacher."""
    out: list[str] = []
    for seed in seeds:
        resp = teacher.chat([
            {"role": "user", "content": PARAPHRASE_PROMPT.format(n=multiplier, seed=seed)}
        ], temperature=0.9)   # high T = diversity
        text = resp["choices"][0]["message"]["content"] or ""
        out.extend([line.strip() for line in text.split("\n") if line.strip()])
    return out

200 seeds × 50 paraphrases = 10K prompts. ~$2 of teacher cost at frontier prices for the paraphrasing step.

Step 1.3 — the teacher labels

For each prompt, run the teacher in full inference mode (the same pipeline that handles real traffic). Record the response and the top-K logits at each generated token — these soft labels are what you’ll distill on.

def label_with_teacher(
    prompts: list[str],
    teacher: LLM,
    *,
    capture_logprobs: bool = True,
    top_k: int = 20,
) -> Iterator[dict]:
    """For each prompt, get the teacher's response + per-token top-K logprobs.

    The top-K logprobs are the "soft labels" — the teacher's full
    distribution, not just argmax. This is what makes distillation
    transfer more knowledge than plain SFT.
    """
    for prompt in prompts:
        resp = teacher.chat([{"role": "user", "content": prompt}],
            temperature=0.0,                # deterministic teacher
            logprobs=capture_logprobs,
            top_logprobs=top_k,
        )
        choice = resp["choices"][0]
        # Soft labels are at choice.logprobs.content[i].top_logprobs
        # Each entry: [{token, logprob}, ...] of length top_k
        yield {
            "prompt": prompt,
            "response": choice["message"]["content"],
            "soft_labels": choice.get("logprobs", {}).get("content", []),
        }

Cost note: capturing top-K logprobs roughly triples API cost (the provider has to send more data per token). Worth it — without soft labels, you’re back to plain SFT, which is ~30% less effective per training example.

Step 1.4 — quality filter

Not every teacher response is good. The teacher hallucinates ~3–5% of the time even on its own labels. Filter them.

QUALITY_RUBRIC = """\
Score this response 1–5 on whether it's a high-quality answer to the prompt.
Penalize: hallucinated facts, wrong format, off-topic, refusals where the
question is reasonable. Reward: directly answering, correct format,
appropriate uncertainty.

Prompt: {prompt}
Response: {response}

Output ONLY a JSON object: {{"score": <int>, "reason": "<short>"}}.
"""

def quality_filter(
    examples: list[dict],
    judge: LLM,
    threshold: int = 4,
) -> list[dict]:
    """Drop examples where a separate judge model scores below threshold."""
    kept = []
    for ex in examples:
        resp = judge.chat([
            {"role": "user", "content": QUALITY_RUBRIC.format(
                prompt=ex["prompt"], response=ex["response"])}
        ], temperature=0.0)
        try:
            obj = json.loads(resp["choices"][0]["message"]["content"] or "{}")
            if int(obj.get("score", 0)) >= threshold:
                kept.append(ex)
        except Exception:
            pass   # malformed judge output → drop
    return kept

Typical filter rate: 5–15% of examples drop. You’re spending ~$30 per 10K examples on judging; the alternative is shipping a student trained on garbage. Always filter.

Step 1.5 — dedupe

Paraphrasing can produce near-duplicates that bias training. Dedupe via embedding cosine similarity:

def dedupe(examples: list[dict], embedder: Embedder, threshold: float = 0.95) -> list[dict]:
    """Drop examples whose prompt is ≥0.95 cosine to a kept example."""
    kept: list[dict] = []
    kept_embs: list[list[float]] = []
    for ex in examples:
        emb = embedder.embed(ex["prompt"])
        if any(_cos(emb, k) > threshold for k in kept_embs):
            continue
        kept.append(ex)
        kept_embs.append(emb)
    return kept

def _cos(a: list[float], b: list[float]) -> float:
    import numpy as np
    a, b = np.array(a), np.array(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9))

Typical dedupe rate: 5–15% of examples drop. The training signal gets sharper.

After the quality and dedupe passes, ~10K → ~7K training examples. That’s enough.
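
Wiring the five steps together is a few lines on top of the helpers above. A minimal sketch, assuming a frontier teacher and a separate judge handle from stack.llm; run_stage1 and the output path are illustrative names, not part of the stack:

def run_stage1(
    prod_logs: Path,
    teacher: LLM,
    judge: LLM,
    out_path: Path = Path("data/distill_train.jsonl"),
) -> Path:
    """Seed → expand → label → quality-filter → dedupe → JSONL on disk."""
    embedder = Embedder()
    seeds = sample_seed_prompts(prod_logs, n_seeds=200, embedder=embedder)
    prompts = expand_prompts(seeds, teacher, multiplier=50)
    labeled = list(label_with_teacher(prompts, teacher, top_k=20))
    kept = quality_filter(labeled, judge, threshold=4)
    kept = dedupe(kept, embedder, threshold=0.95)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w") as f:
        for ex in kept:
            f.write(json.dumps(ex) + "\n")
    return out_path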

Stage 2 — distill the student

Now the math. You have your dataset of (prompt, teacher_response, teacher_top_K_logprobs). The student is a small base (Llama-3.1-8B in our running example).

The two-loss formulation

L_total = α · L_hard + (1 − α) · L_soft

L_hard = −Σ_t  log p_student(t̂_t | prefix)        # standard cross-entropy on the teacher's chosen tokens
L_soft = T² · Σ_t  KL( p_teacher(·/T) || p_student(·/T) )    # KL divergence on softened distributions, teacher as target

Two terms:

  • Hard loss = standard SFT. Train the student to predict the teacher’s picked token at each position. This is what plain fine-tuning does.
  • Soft loss = KL divergence between the full distributions of student and teacher (after softening with temperature T). This captures the teacher’s uncertainty — when it was 60/30/10 between three options, the student learns to be 60/30/10, not just 100/0/0.

The temperature T softens the distribution; at T=1, soft loss = standard KL; at T=4, both distributions are flattened and the student learns the relative preferences over the top tokens (which is more learnable than the sharp peaks). The T² factor compensates for the gradient scaling that softening introduces.

α (the mix) is typically 0.3–0.5 — soft labels carry more signal but hard labels stabilize training.
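
To make the temperature concrete, here’s a standalone illustration of what softening does to a sharp teacher distribution (the logits are invented; numbers rounded):

import numpy as np

def soften(logits: np.ndarray, T: float) -> np.ndarray:
    """Softmax of logits / T; higher T flattens the distribution."""
    z = logits / T
    z = z - z.max()                 # numerical stability
    p = np.exp(z)
    return p / p.sum()

teacher_logits = np.array([4.0, 3.3, 2.2, -1.0])
print(soften(teacher_logits, T=1.0))   # sharp: ≈ [0.60, 0.30, 0.10, 0.00]
print(soften(teacher_logits, T=4.0))   # flat:  ≈ [0.36, 0.30, 0.23, 0.10]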

The training loop

import torch
import torch.nn.functional as F

def distill_step(
    student: torch.nn.Module,
    batch: dict,                  # contains prompt_ids, teacher_top_logprobs, teacher_top_token_ids
    *,
    T: float = 2.0,
    alpha: float = 0.3,
    optimizer: torch.optim.Optimizer,
):
    prompt_ids = batch["prompt_ids"]                      # [B, full_len]: prompt + teacher-response token ids
    teacher_top_ids = batch["teacher_top_token_ids"]      # [B, gen_len, K]
    teacher_top_logp = batch["teacher_top_logprobs"]      # [B, gen_len, K]

    # Forward pass on the student over the full sequence. The logit at
    # position i predicts token i+1, so the distributions for the gen_len
    # response tokens sit one position earlier in the sequence.
    logits = student(prompt_ids).logits                   # [B, full_len, vocab]
    gen_len = teacher_top_ids.shape[1]
    gen_logits = logits[:, -gen_len - 1:-1, :]            # [B, gen_len, vocab]

    # ── HARD LOSS: cross-entropy on teacher's argmax token at each pos ──
    teacher_argmax = teacher_top_ids[:, :, 0]             # [B, gen_len]
    L_hard = F.cross_entropy(
        gen_logits.reshape(-1, gen_logits.size(-1)),
        teacher_argmax.reshape(-1),
        reduction="mean",
    )

    # ── SOFT LOSS: KL divergence on softened distributions ──
    # Build student's distribution over the teacher's top-K tokens at each pos.
    # We restrict to top-K because that's what the teacher gave us; softmax
    # over the K teacher tokens (renormalized).
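    # Note: indexing the student's logits by the teacher's token IDs assumes
    # the teacher and student share a tokenizer (same token-ID space).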
    student_logits_at_K = torch.gather(
        gen_logits, dim=-1, index=teacher_top_ids)        # [B, gen_len, K]
    student_logp_at_K = F.log_softmax(student_logits_at_K / T, dim=-1)
    teacher_logp_at_K = F.log_softmax(teacher_top_logp / T, dim=-1)

    L_soft = F.kl_div(
        student_logp_at_K,
        teacher_logp_at_K,
        reduction="batchmean",
        log_target=True,
    ) * (T ** 2)

    # Combined loss
    L = alpha * L_hard + (1 - alpha) * L_soft

    optimizer.zero_grad()
    L.backward()
    optimizer.step()
    return {"L_total": L.item(), "L_hard": L_hard.item(), "L_soft": L_soft.item()}

Hyperparameters that matter:

  • T (temperature) — start at 2; push to 4 if the student’s KL drops too quickly. Higher T = softer = more “ranked-list” learning.
  • α (loss mix) — start at 0.3; go to 0.5 if training is unstable. Soft loss has more signal but is also noisier.
  • Learning rate — same as fine-tuning, typically 1e-5 to 5e-5 for LoRA on a small base.
  • Batch size — bigger is better for KL stability. Aim for effective batch ≥ 16.

Distill on top of LoRA, not full

You almost never want to distill into the base weights of the student. LoRA on top is cheaper, faster, and usually works just as well. The result: a tiny adapter (~50 MB) that, loaded on top of Llama-3.1-8B, does what the teacher does in your domain.

If you’ve done /ship/16’s LoRA experiments, this is the same training loop, just with the soft-label term added.
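
A minimal sketch of that setup, using Hugging Face transformers and peft around the distill_step above; the base checkpoint, LoRA rank, target modules, and batching details are assumptions, not a prescription:

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

def build_student(base: str = "meta-llama/Llama-3.1-8B") -> torch.nn.Module:
    """Wrap the base model in a LoRA adapter so only the adapter weights train."""
    model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16)
    lora = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"],
                      task_type="CAUSAL_LM")
    return get_peft_model(model, lora)

def train(student: torch.nn.Module, dataset, epochs: int = 3) -> None:
    """Drive distill_step with the defaults discussed above.

    Assumes `dataset` yields the batch dict distill_step expects, with
    sequences already padded to a common length per batch.
    """
    optimizer = torch.optim.AdamW(student.parameters(), lr=2e-5)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    for _ in range(epochs):
        for batch in loader:
            distill_step(student, batch, T=2.0, alpha=0.3, optimizer=optimizer)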

Stage 3 — measure the gap

This is where most teams skip and ship a regression. Don’t.

# stack/distill.py (continued)
from stack.eval import grade_judge

def measure_distillation_gap(
    teacher: LLM,
    student: LLM,
    eval_cases: list[dict],
    judge: LLM,
) -> dict:
    """Run both teacher and student on the same eval set; compute the gap.

    Returns per-case scores plus aggregate stats. The gap is the
    quality you're trading for cost.
    """
    teacher_scores = []
    student_scores = []
    for case in eval_cases:
        t_resp = teacher.chat(
            [{"role": "user", "content": case["input"]}],
            temperature=0.0,
        )["choices"][0]["message"]["content"]
        s_resp = student.chat(
            [{"role": "user", "content": case["input"]}],
            temperature=0.0,
        )["choices"][0]["message"]["content"]
        t_score = grade_judge(judge, case, t_resp)[0]
        s_score = grade_judge(judge, case, s_resp)[0]
        teacher_scores.append(t_score)
        student_scores.append(s_score)

    teacher_mean = sum(teacher_scores) / len(teacher_scores)
    student_mean = sum(student_scores) / len(student_scores)
    parity = student_mean / teacher_mean
    return {
        "teacher_mean": teacher_mean,
        "student_mean": student_mean,
        "parity": parity,
        "n_cases": len(eval_cases),
        "delta_per_case": [s - t for s, t in zip(student_scores, teacher_scores)],
    }

What to look at:

  • Parity ratio = student_mean / teacher_mean. Target ≥ 0.92. Below this, you’ve over-compressed.
  • Per-case deltas — sort them (see the sketch after this list). The cases where the student is much worse than the teacher are your fallback set — route those to the teacher in production.
  • Where the gap concentrates — usually on long-context, multi-step reasoning, and edge cases. The student handles routine traffic; the teacher handles the tail.
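
Turning the sorted deltas into that fallback set is one function. A sketch, assuming the output of measure_distillation_gap plus the eval cases it was run on; the -1.0 cutoff is an arbitrary example threshold:

def fallback_set(gap: dict, eval_cases: list[dict], min_delta: float = -1.0) -> list[dict]:
    """Cases where the student scores well below the teacher; route these to the teacher."""
    ranked = sorted(zip(gap["delta_per_case"], eval_cases), key=lambda x: x[0])
    return [case for delta, case in ranked if delta <= min_delta]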

Stage 4 — production routing

Real production stacks do routing: the cheap student handles 80% of traffic; the expensive teacher handles the 20% that needs it.

# stack/router.py
from stack.llm import LLM

# TEACHER and STUDENT are module-level LLM handles wired up at startup.

def route(prompt: str) -> LLM:
    """Decide which model handles this request.

    Cheap heuristics first; escalate to teacher only when needed.
    """
    if len(prompt) > 8000:           # long context likely needs teacher
        return TEACHER
    if needs_reasoning(prompt):       # math, code, multi-step
        return TEACHER
    if is_high_stakes(prompt):        # billing, account changes, etc.
        return TEACHER
    return STUDENT

needs_reasoning and is_high_stakes can be:

  • a small classifier model trained on your eval set
  • a regex/keyword heuristic (“calculate”, “step-by-step”, “billing”; sketched after this list)
  • a confidence-based router that runs the student first and escalates if its self-rated confidence is low
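
A sketch of the simplest option, the keyword heuristic; the patterns are illustrative and should come from your own traffic, not this list:

import re

REASONING_PATTERNS = re.compile(r"\b(calculate|step[- ]by[- ]step|prove|debug|refactor)\b", re.I)
HIGH_STAKES_PATTERNS = re.compile(r"\b(billing|refund|cancel (my )?account|delete (my )?data)\b", re.I)

def needs_reasoning(prompt: str) -> bool:
    """Math, code, and multi-step requests go to the teacher."""
    return bool(REASONING_PATTERNS.search(prompt))

def is_high_stakes(prompt: str) -> bool:
    """Billing, account changes, and deletions go to the teacher."""
    return bool(HIGH_STAKES_PATTERNS.search(prompt))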

The router IS the cost lever. A 60/40 student/teacher split halves your bill vs all-teacher; a 90/10 split with proper routing matches teacher quality at ~15% of the cost.

The cost math

For a typical “compress an 8B model from a 70B teacher” run:

# one-time costs
synthetic data:   200 seeds × 50 paraphrases = 10K prompts
                   teacher labeling: 10K × ~500 tokens × $5/Mtok = $25
                   logprob capture (3× cost): + $50
                   judge filtering: 10K cases × ~$0.003/case = $30
                   total data cost: ~$105

distillation:     LoRA on 8B base, ~3 epochs, ~1 GPU-hour on H100
                   ~$3 of compute
                   total training cost: ~$3

# ongoing cost (per 1M production tokens)
all-teacher:      $5/Mtok input + $15/Mtok output = ~$10/Mtok blended
all-student:      $0.20/Mtok input + $0.60/Mtok output = ~$0.40/Mtok blended
80/20 routed:     0.8·$0.40 + 0.2·$10 = ~$2.30/Mtok blended

# break-even
total one-time:   ~$110
monthly traffic:  100M tokens
all-teacher cost: $1000/month
routed cost:      $230/month
savings:          $770/month
break-even:       <1 week

The 80/20 router is the typical sweet spot. The math gets even better at higher volumes — at 1B tokens/month, savings hit $7.7K/month and the engineering cost is in the noise.
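
To rerun the arithmetic for your own rates and volume, it's one function; the defaults below are the blended figures from the breakdown above:

def blended_cost_per_mtok(student_frac: float,
                          student_rate: float = 0.40,
                          teacher_rate: float = 10.0) -> float:
    """Blended $/Mtok for a given student/teacher traffic split."""
    return student_frac * student_rate + (1 - student_frac) * teacher_rate

routed = blended_cost_per_mtok(0.8)            # ≈ $2.32/Mtok
monthly_savings = (10.0 - routed) * 100        # at 100M tokens/month: ≈ $768
breakeven_days = 110 / monthly_savings * 30    # ≈ 4.3 days, well under a week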

Pitfalls

  • Train/eval contamination. Your synthetic data is similar to your eval set if both come from prod logs. Hold out a slice of prod traffic before generating the training data (a sketch follows this list).
  • Distribution shift. Six months in, your traffic looks different. Re-distill quarterly. Set a calendar reminder.
  • The student “learns” the teacher’s failures. If the teacher hallucinates 5% of the time, the student inherits that. Quality filtering helps; pairing distillation with a small RLHF pass on the student helps more.
  • Underestimating the routing layer. Distillation gets you to “cheap student model”; routing gets you to “production system.” A great student with a bad router still costs more than it should.
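
The contamination pitfall is cheap to avoid with a deterministic split before Stage 1 ever runs. A minimal sketch; the 10% fraction and the log format (one JSON object with a "prompt" field per line) are assumptions:

import hashlib
import json
from pathlib import Path

def split_holdout(prod_logs: Path, holdout_frac: float = 0.10) -> tuple[list[str], list[str]]:
    """Split prod prompts into a train-eligible slice and a held-out eval slice.

    Hash-bucketing (rather than random sampling) keeps the split stable across
    runs, so re-distilling later never leaks eval prompts into training data.
    """
    train, heldout = [], []
    for line in prod_logs.read_text().splitlines():
        if not line.strip():
            continue
        prompt = json.loads(line)["prompt"]
        bucket = int(hashlib.sha256(prompt.encode()).hexdigest(), 16) % 1000
        (heldout if bucket < holdout_frac * 1000 else train).append(prompt)
    return train, heldout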

What we shipped and didn’t ship

What we shipped:

  • A 200-seed → 10K-prompt synthetic-data pipeline using a frontier teacher
  • A two-loss distillation training step (hard label + soft label with KL)
  • A teacher/student parity measurement that surfaces per-case deltas
  • A production router that escalates the hard 20% to the teacher

What we didn’t:

  • Process reward models. For reasoning-model distillation, you’d capture per-step rewards, not just final outputs. Different pipeline.
  • Iterative self-distillation. Generate data with the student, train on it, generate again. Bootstraps quality but risks “stuck on bias” loops.
  • Cross-tokenizer alignment. As noted in the training loop, gathering the student's logits at the teacher's token IDs assumes a shared tokenizer; aligning mismatched vocabularies is its own project.
  • Continual distillation. Refresh weekly as the teacher gets better. The plumbing is there but you need an eval pipeline that runs nightly.


Next

You now have one more lever: the size of the model you’re running, not just which model. That changes how you think about scaling: when traffic 10×s, you don’t necessarily need a 10× cheaper API contract — sometimes you need a 10× smaller specialist.

The /case-studies side of this is CS-05 — the distilled docs assistant: take the bot from /case-studies/01 (Llama-3.1-8B running everything) and distill its specific behavior into a Llama-3.2-1B. Same eval set; ~7× cost reduction; ~5% parity gap. That’s where this technique earns its place.