
examples/causal_llm: speculative decoding (n-gram + draft-model) #2370

Open
czoli1976 wants to merge 1 commit into sonos:main from czoli1976:feat/speculative-decoding


Adds speculative decoding to the causal_llm example: a cheap drafter proposes the next few tokens, the target model verifies them in a single forward pass and accepts the longest greedy-matching prefix plus its own next token. Output is identical to plain greedy decoding.
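
The accept rule described above can be sketched in a few lines. This is a minimal illustration, not the PR's `generate_speculative`: the names `argmax` and `verify_greedy` and the `Vec<f32>`-per-position logits layout are assumptions made for the example.

```rust
/// Index of the largest logit (the greedy choice). Ties resolve to the lowest index.
fn argmax(logits: &[f32]) -> u32 {
    let mut best = 0usize;
    for (i, &v) in logits.iter().enumerate() {
        if v > logits[best] {
            best = i;
        }
    }
    best as u32
}

/// Greedy verification of `drafts` against per-position target logits.
///
/// Assumed layout: `logits[i]` is the target's next-token distribution after
/// the context plus `drafts[..i]`, so `logits.len() == drafts.len() + 1`.
/// Returns the tokens to append: the matching draft prefix, followed by
/// exactly one token chosen by the target itself.
fn verify_greedy(drafts: &[u32], logits: &[Vec<f32>]) -> Vec<u32> {
    assert_eq!(logits.len(), drafts.len() + 1);
    let mut out = Vec::with_capacity(drafts.len() + 1);
    for (i, &draft) in drafts.iter().enumerate() {
        let target = argmax(&logits[i]);
        out.push(target);
        if target != draft {
            // First mismatch: `target` is the correction; later drafts are dropped.
            return out;
        }
    }
    // Every draft matched: the last position yields one extra token for free.
    out.push(argmax(&logits[drafts.len()]));
    out
}
```

Because every emitted token is the target's own argmax at that position, the result matches what plain greedy decoding would produce one token at a time. Each pass yields between 1 and `drafts.len() + 1` tokens.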

What's here

  • expose_all_logits transform (transformers) + Nnef::load_without_decluttering (api/rs): exports emit last-token-only logits, but verifying K drafts needs a next-token distribution at every draft position. The transform recomputes the final projection over the full hidden states (the last-position values are unchanged, so it stays a drop-in at decode time). It must run before declutter rearranges the last-token slice, hence the raw-load hook.
  • generate_speculative with KV-cache rollback: one forward over [tail · drafts], greedy-verify, accept the matching prefix + correction, truncate the cache to drop rejected drafts. A distribution-preserving rejection sampler (Leviathan/Chen) is included for temperature > 0.
  • Two drafters: NgramDrafter (prompt-lookup, no second model) and ModelDrafter (a smaller causal LM).
  • Benchmarks: bench_spec (E2E throughput + acceptance + lossless check), bench_micro (forward latency vs tokens-per-pass), and a criterion host-side bench.
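
To make the prompt-lookup idea behind the n-gram drafter concrete, here is a hedged sketch. It is not the PR's `NgramDrafter`: the function name `ngram_draft` and its parameters `n` (match length) and `k` (maximum draft length) are hypothetical.

```rust
/// Propose up to `k` draft tokens by prompt lookup: find the most recent
/// earlier occurrence of the last `n` tokens of `history` and copy the
/// tokens that followed it. Returns an empty draft when nothing matches.
fn ngram_draft(history: &[u32], n: usize, k: usize) -> Vec<u32> {
    if n == 0 || k == 0 || history.len() <= n {
        return Vec::new();
    }
    let suffix = &history[history.len() - n..];
    // Scan start positions from most recent to oldest. The range excludes
    // the suffix's own position, so at least one token follows any match.
    for start in (0..history.len() - n).rev() {
        if &history[start..start + n] == suffix {
            let from = start + n;
            let to = (from + k).min(history.len());
            return history[from..to].to_vec();
        }
    }
    Vec::new()
}
```

For example, with `history = [1, 2, 3, 4, 1, 2]`, `n = 2` and `k = 3`, the suffix `[1, 2]` also occurs at the start, so the draft is `[3, 4, 1]`. No second model is involved, which is why this kind of drafter pays off mainly on text that repeats itself.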

Correctness

Greedy speculative output is bit-identical to generate_next_token, verified on real models (Qwen3-1.7B and Llama-3.2-1B). The real-model tests are gated on a local .cached model path and skip when it is absent, matching the existing truncate_prefix_cache_hit test; the verification and drafter logic also has model-free unit tests. The scheme is lossless in exact arithmetic. Rare divergences on high-entropy text come from floating-point differences between batched verification and single-token decode, which is standard for speculative decoding.
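
The floating-point caveat comes down to f32 addition not being associative: two kernels that accumulate the same products in a different order can return slightly different logits, and a near-tie between two logits can then flip the argmax. The snippet below is a generic illustration of that effect, not code from this PR; `sum_in_order` is a hypothetical helper.

```rust
/// Sum the values strictly left to right in f32.
fn sum_in_order(xs: &[f32]) -> f32 {
    xs.iter().fold(0.0f32, |acc, &x| acc + x)
}

/// The same three values summed in two orders. At magnitude 1e8 adjacent
/// f32 values are 8 apart, so adding 1.0 to 1e8 rounds back to 1e8.
fn order_dependent_sums() -> (f32, f32) {
    let big_first = sum_in_order(&[1e8, 1.0, -1e8]); // the 1.0 is absorbed
    let cancel_first = sum_in_order(&[1e8, -1e8, 1.0]); // the 1.0 survives
    (big_first, cancel_first)
}
```

`order_dependent_sums()` returns `(0.0, 1.0)`. Real logit differences are far smaller than this contrived case, which is consistent with divergences being rare and limited to positions where the top two logits are nearly equal.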

Performance

n-gram speculation on Qwen3-1.7B (Metal): ~1.2–1.4× decode speedup at k=2 on repetitive / code-like text; neutral-to-slower on prose (low acceptance) and at larger k. The useful draft length is currently capped by the q4 GEMV→GEMM kernel selection; #2369 widens that and lifts k=4 from a slowdown to ~1.19×.
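
A rough analytic model helps explain why low acceptance or a more expensive multi-token pass erases the gain. The sketch below assumes each draft token is accepted independently with probability `alpha`, the simplification used in the Leviathan et al. analysis, and ignores drafter cost. The function names and the `pass_cost_ratio` parameter are hypothetical, and the model produces estimates, not the measurements reported above.

```rust
/// Expected tokens emitted per target pass with `k` drafts, assuming each
/// draft is accepted independently with probability `alpha`:
/// 1 + alpha + ... + alpha^k = (1 - alpha^(k+1)) / (1 - alpha).
fn expected_tokens_per_pass(alpha: f64, k: u32) -> f64 {
    if (1.0 - alpha).abs() < 1e-12 {
        // Limit as alpha approaches 1: every draft plus the bonus token.
        return (k + 1) as f64;
    }
    (1.0 - alpha.powi(k as i32 + 1)) / (1.0 - alpha)
}

/// Estimated decode speedup over single-token decoding.
/// `pass_cost_ratio` is the latency of one (k + 1)-token verification pass
/// divided by the latency of one single-token pass (drafter cost ignored).
fn estimated_speedup(alpha: f64, k: u32, pass_cost_ratio: f64) -> f64 {
    expected_tokens_per_pass(alpha, k) / pass_cost_ratio
}
```

With `alpha = 0.5` and `k = 2` the model gives 1.75 tokens per pass, so speculation only breaks even if the 3-token pass costs no more than 1.75 single-token passes. This is one way to read the sensitivity to kernel selection at larger k.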

🤖 Generated with Claude Code

…nsform

Add greedy/sampling speculative decoding to the causal_llm example: an
expose_all_logits ModelTransform (with a raw-load API hook) so the target emits
per-position logits, generate_speculative with KV-cache rollback, n-gram and
draft-model drafters, plus E2E and micro benchmarks.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>