examples/causal_llm: speculative decoding (n-gram + draft-model) #2370
Open
czoli1976 wants to merge 1 commit into
…nsform

Add greedy/sampling speculative decoding to the causal_llm example: an expose_all_logits ModelTransform (with a raw-load API hook) so the target emits per-position logits, generate_speculative with KV-cache rollback, n-gram and draft-model drafters, plus E2E and micro benchmarks.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Adds speculative decoding to the `causal_llm` example: a cheap drafter proposes the next few tokens, and the target model verifies them in a single forward pass, accepting the longest greedy-matching prefix plus its own next token. In greedy mode, output is identical to plain greedy decoding.

What's here
- `expose_all_logits` transform (transformers) + `Nnef::load_without_decluttering` (api/rs): exports emit last-token-only logits, but verifying K drafts needs a next-token distribution at every draft position. The transform recomputes the final projection over the full hidden states (the last-position values are unchanged, so it stays a drop-in at decode time). It must run before declutter rearranges the last-token slice, hence the raw-load hook.
- `generate_speculative` with KV-cache rollback: one forward over `[tail · drafts]`, greedy-verify, accept the matching prefix + correction, truncate the cache to drop rejected drafts. A distribution-preserving rejection sampler (Leviathan/Chen) is included for `temperature > 0`.
- `NgramDrafter` (prompt-lookup, no second model) and `ModelDrafter` (a smaller causal LM).
- `bench_spec` (E2E throughput + acceptance + lossless check), `bench_micro` (forward latency vs tokens-per-pass), and a criterion host-side bench.

Correctness
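The greedy-verify and rollback step listed under "What's here" is what makes the output match plain greedy decoding. A minimal sketch of that step follows. It is illustrative only and not the PR's code: `argmax`, `greedy_verify`, and `kv_len_after` are hypothetical names, and the sketch assumes the tail token is not yet in the KV cache when the verification forward runs.

```rust
/// Index of the largest logit (the greedy choice); ties resolve to the lowest index.
fn argmax(logits: &[f32]) -> u32 {
    let mut best = 0usize;
    for (i, &v) in logits.iter().enumerate() {
        if v > logits[best] {
            best = i;
        }
    }
    best as u32
}

/// Greedy verification of `drafts` against per-position target logits.
///
/// Assumption: `logits[i]` is the target's next-token logits after consuming
/// the tail token and the first `i` drafts, so there are `drafts.len() + 1` rows.
/// Returns (number of accepted drafts, tokens to append to the output).
fn greedy_verify(drafts: &[u32], logits: &[Vec<f32>]) -> (usize, Vec<u32>) {
    assert_eq!(logits.len(), drafts.len() + 1);
    let mut out = Vec::with_capacity(drafts.len() + 1);
    let mut accepted = 0;
    for (i, &d) in drafts.iter().enumerate() {
        let t = argmax(&logits[i]);
        if t != d {
            // First mismatch: emit the target's own token as the correction and stop.
            out.push(t);
            return (accepted, out);
        }
        out.push(d);
        accepted += 1;
    }
    // Every draft matched: the last row yields one extra token from the target.
    out.push(argmax(&logits[drafts.len()]));
    (accepted, out)
}

/// KV-cache length to keep after verification (hypothetical helper).
/// The forward appended 1 tail entry plus one entry per draft; only the tail
/// and the accepted drafts stay, so rejected drafts are truncated away.
fn kv_len_after(cache_len_before: usize, accepted: usize) -> usize {
    cache_len_before + 1 + accepted
}
```

Because every emitted token is the target's own argmax given the tokens before it, the appended sequence is the one plain greedy decoding would produce in exact arithmetic; the drafter only changes how many tokens each forward pass yields.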
Greedy speculative output is bit-identical to `generate_next_token`, verified on real models (Qwen3-1.7B and Llama-3.2-1B). The real-model tests are gated on a local `.cached` model path and skip when absent, matching the existing `truncate_prefix_cache_hit` test; the verification/drafter logic also has model-free unit tests. Lossless in exact arithmetic; rare divergences on high-entropy text come from floating-point differences between batched verification and single-token decode (standard for speculative decoding).

Performance
n-gram speculation on Qwen3-1.7B (Metal): ~1.2–1.4× decode speedup at k=2 on repetitive / code-like text; neutral-to-slower on prose (low acceptance) and at larger k. The useful draft length is currently capped by the q4 GEMV→GEMM kernel selection; #2369 widens that and lifts k=4 from a slowdown to ~1.19×.
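To illustrate why acceptance is high on repetitive or code-like text and low on prose, here is a minimal prompt-lookup sketch. It is an assumption about the general technique, not the PR's `NgramDrafter`: `ngram_draft` and its parameters `n` (match length) and `k` (draft length) are hypothetical.

```rust
/// Propose up to `k` tokens by prompt lookup: find the most recent earlier
/// occurrence of the last `n` tokens of `history` and copy what followed it.
/// Returns an empty vector when there is no earlier match.
fn ngram_draft(history: &[u32], n: usize, k: usize) -> Vec<u32> {
    if n == 0 || k == 0 || history.len() <= n {
        return Vec::new();
    }
    let suffix_start = history.len() - n;
    let suffix = &history[suffix_start..];
    // Scan candidate windows from most recent to oldest, excluding the suffix itself.
    for start in (0..suffix_start).rev() {
        if &history[start..start + n] == suffix {
            let from = start + n;
            let to = (from + k).min(history.len());
            return history[from..to].to_vec();
        }
    }
    Vec::new()
}
```

On text that repeats earlier spans (identifiers, boilerplate, quoted input), the copied continuation often matches what the target would emit, so more drafts survive verification; on free-form prose a match is rarer and the extra verification work is wasted.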
🤖 Generated with Claude Code