Add SM90 FP8 paged MQA logits support for next_n=3
Fix sm90 nextn3 paged mqa logits #340
Open
yangsiqt wants to merge 4 commits into
Implements Hopper NextN=3 via 2-token scheduler atoms with fused KV tile reuse and single-token tail WGMMA, avoiding the register pressure of a native 3-token shape.

- Introduce kPadOddN / kNextNAtom / kNumNextNAtoms compile-time constants for odd-N support
- Fuse KV tile load: one load per row task serves 2 Q stages
- Use TailWGMMA (1-head shape) for the 3rd token tail pass
- Increase Q pipeline to 4 stages for NextN=3
- Update JIT templates and scheduler with num_next_n_atoms

Benchmark on H800 PCIe (B=256, L=8192, H=64, D=128):

- next_n=2: 0.169 ms/call
- next_n=3: 0.192 ms/call (previously 0.324 ms/call with C++ fallback)
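The atom decomposition described in the commit message can be illustrated with a small host-side sketch. This is not the kernel code: the function `plan_atoms` and its return format are hypothetical, and the rule used here (atom width equals next_n for next_n of 1 or 2, and 2 tokens plus a 1-token tail for next_n=3) is an assumption about how kPadOddN, kNextNAtom, and kNumNextNAtoms relate, inferred only from the description above.

```python
def plan_atoms(next_n: int) -> list[tuple[int, int]]:
    """Split next_n query tokens into (first_token, num_tokens) atoms.

    Hypothetical illustration only. Assumes:
      - next_n in {1, 2} is handled as a single native atom,
      - next_n == 3 is padded to 2-token atoms, leaving a 1-token tail.
    Larger next_n is out of scope for this sketch.
    """
    if next_n not in (1, 2, 3):
        raise ValueError(f"unsupported next_n={next_n}")

    pad_odd = next_n == 3                  # assumed meaning of kPadOddN
    atom = 2 if pad_odd else next_n        # assumed meaning of kNextNAtom
    num_atoms = -(-next_n // atom)         # ceil division, cf. kNumNextNAtoms

    atoms = []
    for i in range(num_atoms):
        first = i * atom
        # The last atom of an odd next_n covers only the remaining token.
        atoms.append((first, min(atom, next_n - first)))
    return atoms


if __name__ == "__main__":
    for n in (1, 2, 3):
        print(n, plan_atoms(n))
```

Under these assumptions, next_n=3 yields one 2-token atom followed by a 1-token tail, which corresponds to the full pass plus the TailWGMMA pass named in the commit message.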
Summary
This PR adds SM90/Hopper support for FP8 paged MQA logits with next_n=3.

The immediate motivation is vLLM MTP speculative decoding:

Before this change, the SM90 paged MQA logits path rejected this case with:

This PR keeps the existing next_n=1 and next_n=2 behavior unchanged, and only extends the SM90 non-varlen path to handle next_n=3.

Related vLLM issue: vllm-project/vllm#43457
Implementation Notes
The SM90 path already handles next_n=1 and next_n=2. next_n=3 is an odd case, so this PR does not simply widen the assertion.

Instead, the implementation maps next_n=3 as:

Concretely:

- ... next_n=3 case as one logical next-n atom.
- ... next_n.

This avoids treating next_n=3 as a full 3-token atom, which would increase register/shared-memory pressure and was not supported by the original SM90 kernel structure.

Scope
This PR intentionally keeps the scope narrow:

- ... next_n=3 support for paged MQA logits.
- ... next_n > 3 support.
- ... next_n=1/2 behavior.

If broader next_n=4 or scheduler-generalization work lands separately, this PR should still be understood as the minimal next_n=3 path needed by vLLM num_speculative_tokens=2.

Validation
Negative / Positive Probe
Unpatched DeepGEMM direct next_n=3 probe fails with the expected assertion:

With this patch, the same direct next_n=3 probe passes.

DeepGEMM Attention Test
Ran from /root/DeepGEMM/tests:

This covers SM90 FP8 paged MQA logits with:

Observed NextN=3 lines on H20:

Standalone DeepGEMM Microbenchmark
Direct standalone DeepGEMM benchmark on H20:

- next_n=1, (256, 8192): 0.1306 ms/call
- next_n=2, (512, 8192): 0.2519 ms/call
- next_n=3, (768, 8192): 0.3807 ms/call

Structured microbenchmark at B=256, L=4096:

- next_n=1: 0.0683 ms
- next_n=2: 0.1301 ms
- next_n=3: 0.1981 ms

These numbers are intended to show that the new path is functional and in the expected performance range. I would not describe this as performance-final; large-B/long-L next_n=3 may still have optimization headroom.

vLLM End-to-End Check
I also validated the issue path in vLLM on 2x H20:

Server configuration:

Result:

I used max_model_len=4096 on this H20 setup due to KV-cache capacity. The important path exercised here is still:

End-to-End Performance Context
For context, I ran two sequential streaming vLLM benchmarks on the same 2x H20 setup.

The first workload uses the original short realistic prompt set:

The second workload uses HumanEval-style code-completion prompts and a longer decode length:

End-to-end results:

0.123s | 11.118 ms | 89.94 | | 1.00x
0.130s | 6.693 ms | 149.41 | 81.3% | 1.66x
0.133s | 6.414 ms | 155.91 | 63.0% | 1.73x
0.117s | 9.9369 ms | 100.64 | | 1.00x
0.119s | 5.9092 ms | 169.23 | 95.3% | 1.68x
0.130s | 5.3010 ms | 188.64 | 83.4% | 1.87x

For mtp_k2, per-position acceptance was:

Measured relative speedups:

1.66x | 1.73x | 1.04x
1.68x | 1.87x | 1.115x

The HumanEval/code-completion workload shows clearer mtp_k2 gains because the second draft token is accepted much more often and the longer decode length better amortizes fixed serving overhead:

The main purpose of this PR is still to fix the SM90 functional blocker for next_n=3. The vLLM numbers are included to show the issue path runs end to end; speculative decoding speedup remains workload-dependent and further performance tuning can be handled separately.
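As a consistency check, the reported relative speedups can be re-derived from the end-to-end results. The sketch below rests on assumptions: it reads the values 89.94, 149.41, 155.91 and 100.64, 169.23, 188.64 as throughput figures for the two workloads, and the labels `baseline`, `mtp_k1`, `short_prompts`, and `humaneval_style` are hypothetical names chosen for illustration (only mtp_k2 is named in the text).

```python
# Throughput-like values taken from the end-to-end results.
# Row labels and the "throughput" interpretation are assumptions.
results = {
    "short_prompts":   {"baseline": 89.94,  "mtp_k1": 149.41, "mtp_k2": 155.91},
    "humaneval_style": {"baseline": 100.64, "mtp_k1": 169.23, "mtp_k2": 188.64},
}


def relative_speedups(row: dict[str, float]) -> dict[str, float]:
    """Return the three throughput ratios for one workload."""
    return {
        "k1_vs_baseline": row["mtp_k1"] / row["baseline"],
        "k2_vs_baseline": row["mtp_k2"] / row["baseline"],
        "k2_vs_k1": row["mtp_k2"] / row["mtp_k1"],
    }


if __name__ == "__main__":
    for workload, row in results.items():
        ratios = relative_speedups(row)
        print(workload, {name: round(value, 3) for name, value in ratios.items()})
```

Under this reading, the ratios agree with the reported speedups after rounding (1.66x, 1.73x, 1.04x for the first workload and 1.68x, 1.87x, 1.115x for the second).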