
Add SM90 FP8 paged MQA logits support for next_n=3 (Fix sm90 nextn3 paged mqa logits) #340

Open: yangsiqt wants to merge 4 commits into deepseek-ai:main from yangsiqt:fix-sm90-nextn3-paged-mqa-logits



yangsiqt commented May 28, 2026


Summary

This PR adds SM90/Hopper support for FP8 paged MQA logits with next_n=3.

The immediate motivation is vLLM MTP speculative decoding:

num_speculative_tokens=2 -> DeepGEMM next_n=3

Before this change, the SM90 paged MQA logits path rejected this case with:

next_n == 1 or next_n == 2

This PR keeps the existing next_n=1 and next_n=2 behavior unchanged, and only extends the SM90 non-varlen path to handle next_n=3.
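The mapping and the before/after acceptance rule described above can be written out as a small sketch. It assumes that next_n is one verified token plus the number of draft tokens (the PR only states the num_speculative_tokens=2 case), and the function names are hypothetical, not part of DeepGEMM or vLLM.

```python
# Hypothetical helper names; not part of DeepGEMM or vLLM.
SUPPORTED_BEFORE = (1, 2)     # next_n values the SM90 path accepted before this PR
SUPPORTED_AFTER = (1, 2, 3)   # next_n values accepted on SM90 non-varlen with this PR


def next_n_from_spec_tokens(num_speculative_tokens: int) -> int:
    """Assumption: each decode step scores one verified token plus k draft tokens."""
    if num_speculative_tokens < 0:
        raise ValueError("num_speculative_tokens must be non-negative")
    return num_speculative_tokens + 1


def is_supported(next_n: int, patched: bool) -> bool:
    """Mirror of the assertion 'next_n == 1 or next_n == 2', widened to 3 when patched."""
    return next_n in (SUPPORTED_AFTER if patched else SUPPORTED_BEFORE)


if __name__ == "__main__":
    n = next_n_from_spec_tokens(2)
    print(n, is_supported(n, patched=False), is_supported(n, patched=True))
```

Under that assumption, num_speculative_tokens=2 is rejected without the patch and accepted with it, while next_n=4 stays unsupported in both cases.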

Related vLLM issue:

vllm-project/vllm#43457

Implementation Notes

The SM90 path already handles next_n=1 and next_n=2. Because next_n=3 is odd, it does not fit the existing paired-token structure, so this PR does not simply widen the assertion.

Instead, the implementation maps next_n=3 as:

tokens 0 and 1 -> existing paired-token atom path
token 2        -> single-token tail path

Concretely:

  • The paged MQA logits metadata/scheduler can group the SM90 next_n=3 case as one logical next-n atom.
  • The main WGMMA path handles the first two tokens as the existing paired-token case.
  • A tail WGMMA path handles the third token.
  • The block-table row mapping remains tied to the original batch row, while the token offset selects the token inside next_n.

This avoids treating next_n=3 as a full 3-token atom, which would increase register/shared-memory pressure and was not supported by the original SM90 kernel structure.
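A minimal host-side sketch of this decomposition follows. It is illustrative only: the function names are hypothetical, and the batch-major row layout (flat row = batch_row * next_n + token_offset) is an assumption inferred from the (B * next_n, L) output shapes reported in the benchmarks, not taken from the kernel source.

```python
from typing import List, Tuple


def split_next_n(next_n: int) -> List[Tuple[str, Tuple[int, ...]]]:
    """Return (path, token_offsets) pieces for one logical next-n atom."""
    if next_n == 1:
        return [("single", (0,))]
    if next_n == 2:
        return [("pair", (0, 1))]
    if next_n == 3:
        # Tokens 0 and 1 reuse the paired-token path; token 2 goes to the tail path.
        return [("pair", (0, 1)), ("tail", (2,))]
    raise ValueError("only next_n in {1, 2, 3} is covered by this sketch")


def locate_row(flat_row: int, next_n: int) -> Tuple[int, int]:
    """Map a flat logits row to (batch_row, token_offset).

    Assumption: rows are laid out batch-major, so the block-table lookup uses
    batch_row while token_offset selects the token inside next_n.
    """
    if flat_row < 0:
        raise ValueError("flat_row must be non-negative")
    return divmod(flat_row, next_n)


def output_shape(batch: int, next_n: int, max_len: int) -> Tuple[int, int]:
    """Logits shape as reported in the benchmarks: (B * next_n, L)."""
    return (batch * next_n, max_len)


if __name__ == "__main__":
    print(split_next_n(3))
    print(locate_row(7, 3))
    print(output_shape(256, 3, 8192))
```

The point of the split is that the third token never forces a wider atom: it reuses the same batch row for the block table and only differs in its token offset.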

Scope

This PR intentionally keeps the scope narrow:

  • Adds SM90 non-varlen next_n=3 support for paged MQA logits.
  • Does not add SM90 varlen support.
  • Does not add next_n > 3 support.
  • Does not change the existing SM100 behavior.
  • Does not change the existing SM90 next_n=1/2 behavior.

If broader next_n=4 or scheduler-generalization work lands separately, this PR should still be understood as the minimal next_n=3 path needed by vLLM num_speculative_tokens=2.

Validation

Negative / Positive Probe

Unpatched DeepGEMM direct next_n=3 probe fails with the expected assertion:

RuntimeError: Assertion error (.../smxx_fp8_fp4_paged_mqa_logits.hpp:233): next_n == 1 or next_n == 2

With this patch, the same direct next_n=3 probe passes.

DeepGEMM Attention Test

Ran from /root/DeepGEMM/tests:

OMP_NUM_THREADS=1 \
PYTHONPATH=/root/DeepGEMM \
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
python - <<'PY'
from test_attention import test_paged_mqa_logits

test_paged_mqa_logits()
PY

This covers SM90 FP8 paged MQA logits with:

NextN = 1, 2, 3
logits dtype = FP32, BF16
BLOCK_KV = 64
BSZ = 256
H = 64
D = 128
L = 8192, 32768

Observed NextN=3 lines on H20:

FP4=False, BF16=False, BLOCK_KV=64, BSZ=256, NextN=3, H=64, D=128, L=8192:   267 TFLOPS, 382 us
FP4=False, BF16=False, BLOCK_KV=64, BSZ=256, NextN=3, H=64, D=128, L=32768:  278 TFLOPS, 1494 us
FP4=False, BF16=True,  BLOCK_KV=64, BSZ=256, NextN=3, H=64, D=128, L=8192:   267 TFLOPS, 392 us
FP4=False, BF16=True,  BLOCK_KV=64, BSZ=256, NextN=3, H=64, D=128, L=32768:  278 TFLOPS, 1504 us

Standalone DeepGEMM Microbenchmark

Direct standalone DeepGEMM benchmark on H20:

B = 256
H = 64
D = 128
max_model_len / L = 8192
block_kv = 64
warmup = 8
iters = 80
repeats = 5
metric = CUDA event median ms/call

| Case | Status | Output shape | Median time |
|---|---|---|---|
| next_n=1 | existing path | (256, 8192) | 0.1306 ms/call |
| next_n=2 | existing path | (512, 8192) | 0.2519 ms/call |
| next_n=3 | this PR | (768, 8192) | 0.3807 ms/call |

Structured microbenchmark at B=256, L=4096:

| Case | Median time |
|---|---|
| next_n=1 | 0.0683 ms |
| next_n=2 | 0.1301 ms |
| next_n=3 | 0.1981 ms |

These numbers are intended to show that the new path is functional and in the expected performance range. I would not describe this as performance-final; large-B/long-L next_n=3 may still have optimization headroom.
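One way to read "expected performance range" is to compare each case against next_n=1 and against the per-token cost. The sketch below only does arithmetic on the medians reported above; the helper name is hypothetical.

```python
from typing import Dict, Tuple

# Median ms/call copied from the two microbenchmark tables (B=256).
MEDIAN_MS_L8192 = {1: 0.1306, 2: 0.2519, 3: 0.3807}
MEDIAN_MS_L4096 = {1: 0.0683, 2: 0.1301, 3: 0.1981}


def scaling(median_ms: Dict[int, float]) -> Dict[int, Tuple[float, float]]:
    """Return {next_n: (time relative to next_n=1, ms per token position)}."""
    base = median_ms[1]
    return {n: (t / base, t / n) for n, t in sorted(median_ms.items())}


if __name__ == "__main__":
    for label, table in (("L=8192", MEDIAN_MS_L8192), ("L=4096", MEDIAN_MS_L4096)):
        for n, (ratio, per_token) in scaling(table).items():
            print(f"{label} next_n={n}: {ratio:.2f}x of next_n=1, {per_token:.4f} ms per token")
```

On these numbers next_n=3 costs roughly 2.9x the next_n=1 call at both lengths, so the per-token cost stays close to the existing next_n=1 and next_n=2 paths.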

vLLM End-to-End Check

I also validated the issue path in vLLM on 2x H20:

vLLM: 0.21.1rc1.dev186+g1e48d8139.cu129
vLLM commit: 1e48d81395d48ff73ea98a5509a7e29d60d788f1
DeepGEMM commit: dccfe0922dbc06d09f84905cf652714502de7413
Model: DeepSeek-V4-Flash-W4A16-FP8-MTP
TP: 2
max_model_len: 4096
kv_cache_dtype: fp8

Server configuration:

vllm serve "$MODEL" \
  --host 127.0.0.1 \
  --port 8000 \
  --tensor-parallel-size 2 \
  --kv-cache-dtype fp8 \
  --dtype auto \
  --gpu-memory-utilization 0.92 \
  --max-model-len 4096 \
  --max-num-seqs 4 \
  --max-num-batched-tokens 4096 \
  --safetensors-load-strategy prefetch \
  --speculative-config '{"method":"mtp","num_speculative_tokens":2}'

Result:

mtp_k2 server starts successfully
/v1/completions returns 200 OK
no DeepGEMM next_n assertion
spec_decode metrics are emitted

I used max_model_len=4096 on this H20 setup due to KV-cache capacity. The important path exercised here is still:

num_speculative_tokens=2 -> next_n=3 -> DeepGEMM paged MQA logits

End-to-End Performance Context

For context, I ran two sequential streaming vLLM benchmarks on the same 2x H20 setup.

The first workload uses the original short realistic prompt set:

workload: short realistic prompts
num_prompts: 64
max_tokens: 256
max_concurrency: 1
stream: true
temperature: 0
ignore_eos: true
mean prompt words: 29.8
median prompt words: 29.5
prompt word range: 18 to 56

The second workload uses HumanEval-style code-completion prompts and a longer decode length:

workload: HumanEval code-completion prompts
num_prompts: 64
prompt tokens: min 221, max 320, mean 235.2, median 229.0
max_tokens: 512
max_concurrency: 1
stream: true
temperature: 0
ignore_eos: true
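For readers who want to reproduce a comparable request, the workload settings above could be expressed as a completions payload roughly as follows. This is a sketch, not the benchmark client used for this PR: the helper name is hypothetical, and it assumes the vLLM OpenAI-compatible /v1/completions endpoint accepts ignore_eos as an extra sampling parameter in the request body.

```python
import json
from typing import Any, Dict


def build_completion_request(model: str, prompt: str, max_tokens: int) -> Dict[str, Any]:
    """Build a streaming, greedy, fixed-length completion request body."""
    return {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,  # 256 for the short workload, 512 for HumanEval
        "temperature": 0,          # greedy decoding
        "stream": True,            # streaming, so TTFT can be measured
        "ignore_eos": True,        # assumption: vLLM extra parameter forcing full-length output
    }


if __name__ == "__main__":
    body = build_completion_request("MODEL_NAME_PLACEHOLDER", "def add(a, b):", 512)
    print(json.dumps(body, indent=2))
```

With max_concurrency=1, such requests would be sent one at a time, so the measured TPOT reflects single-stream decode latency rather than batched throughput.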

End-to-end results:

| Workload | Case | Completed | Failed | Median TTFT | Approx TPOT | Output tok/s | Acceptance | Speedup vs baseline |
|---|---|---|---|---|---|---|---|---|
| short realistic, output=256 | baseline | 64 | 0 | 0.123s | 11.118 ms | 89.94 | N/A | 1.00x |
| short realistic, output=256 | mtp_k1 | 64 | 0 | 0.130s | 6.693 ms | 149.41 | 81.3% | 1.66x |
| short realistic, output=256 | mtp_k2 | 64 | 0 | 0.133s | 6.414 ms | 155.91 | 63.0% | 1.73x |
| HumanEval code, output=512 | baseline | 64 | 0 | 0.117s | 9.9369 ms | 100.64 | N/A | 1.00x |
| HumanEval code, output=512 | mtp_k1 | 64 | 0 | 0.119s | 5.9092 ms | 169.23 | 95.3% | 1.68x |
| HumanEval code, output=512 | mtp_k2 | 64 | 0 | 0.130s | 5.3010 ms | 188.64 | 83.4% | 1.87x |

For mtp_k2, per-position acceptance was:

| Workload | Draft steps | Position 0 accepted | Position 0 acceptance | Position 1 accepted | Position 1 acceptance | Position 1 conditional |
|---|---|---|---|---|---|---|
| short realistic, output=256 | 7,258 | 5,799 | 79.9% | 3,346 | 46.1% | 57.7% |
| HumanEval code, output=512 | 12,290 | 11,686 | 95.1% | 8,818 | 71.7% | 75.5% |
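The acceptance figures follow directly from the raw counts. The sketch below recomputes them; the function name is hypothetical, and the tokens-per-step figure assumes each draft step emits its accepted draft tokens plus one verified token, which the PR does not state.

```python
from typing import Dict, List, Union


def acceptance(draft_steps: int, accepted: List[int]) -> Dict[str, Union[float, List[float]]]:
    """Derive acceptance rates from per-position accepted counts."""
    per_position = [a / draft_steps for a in accepted]
    # Conditional rate: position i accepted given position i-1 was accepted.
    conditional = [accepted[i] / accepted[i - 1] for i in range(1, len(accepted))]
    overall = sum(accepted) / (draft_steps * len(accepted))
    # Assumption: one verified token per step in addition to accepted drafts.
    tokens_per_step = 1.0 + sum(accepted) / draft_steps
    return {
        "per_position": per_position,
        "conditional": conditional,
        "overall": overall,
        "tokens_per_step": tokens_per_step,
    }


if __name__ == "__main__":
    for name, steps, counts in (
        ("short realistic", 7258, [5799, 3346]),
        ("HumanEval code", 12290, [11686, 8818]),
    ):
        r = acceptance(steps, counts)
        pos = ", ".join(f"{p:.1%}" for p in r["per_position"])
        cond = ", ".join(f"{c:.1%}" for c in r["conditional"])
        print(f"{name}: per-position {pos}; conditional {cond}; overall {r['overall']:.1%}")
```

The overall rate computed this way matches the Acceptance column for mtp_k2 in the end-to-end table (63.0% and 83.4%), which suggests that column averages acceptance over both draft positions.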

Measured relative speedups:

| Workload | k1 vs baseline | k2 vs baseline | k2 relative to k1 |
|---|---|---|---|
| short realistic, output=256 | 1.66x | 1.73x | 1.04x |
| HumanEval code, output=512 | 1.68x | 1.87x | 1.115x |
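These ratios can be recomputed from the output tok/s column of the end-to-end table. The sketch below does that, and also checks that the Approx TPOT column appears consistent with 1000 / (output tok/s); that relationship is an observation from the reported numbers, not something the PR states. Helper names are hypothetical.

```python
from typing import Dict

# Output tok/s copied from the end-to-end results table.
TOK_PER_S: Dict[str, Dict[str, float]] = {
    "short realistic": {"baseline": 89.94, "mtp_k1": 149.41, "mtp_k2": 155.91},
    "HumanEval code": {"baseline": 100.64, "mtp_k1": 169.23, "mtp_k2": 188.64},
}


def speedup(workload: str, case: str, reference: str = "baseline") -> float:
    """Throughput of `case` relative to `reference` for one workload."""
    rates = TOK_PER_S[workload]
    return rates[case] / rates[reference]


def approx_tpot_ms(tok_per_s: float) -> float:
    """Approximate time per output token in milliseconds for a single stream."""
    return 1000.0 / tok_per_s


if __name__ == "__main__":
    for workload in TOK_PER_S:
        k1 = speedup(workload, "mtp_k1")
        k2 = speedup(workload, "mtp_k2")
        k2_over_k1 = speedup(workload, "mtp_k2", "mtp_k1")
        print(f"{workload}: k1 {k1:.2f}x, k2 {k2:.2f}x, k2/k1 {k2_over_k1:.3f}x")
```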

The HumanEval/code-completion workload shows clearer mtp_k2 gains because the second draft token is accepted much more often and the longer decode length better amortizes fixed serving overhead:

short realistic position-1 acceptance: 46.1%
HumanEval code position-1 acceptance: 71.7%

The main purpose of this PR is still to fix the SM90 functional blocker for next_n=3. The vLLM numbers are included to show the issue path runs end to end; speculative decoding speedup remains workload-dependent and further performance tuning can be handled separately.

yangsiqt added 3 commits May 29, 2026 00:01
Implements Hopper NextN=3 via 2-token scheduler atoms with fused
KV tile reuse and single-token tail WGMMA, avoiding the register
pressure of a native 3-token shape.

- Introduce kPadOddN / kNextNAtom / kNumNextNAtoms compile-time
  constants for odd-N support
- Fuse KV tile load: one load per row task serves 2 Q stages
- Use TailWGMMA (1-head shape) for the 3rd token tail pass
- Increase Q pipeline to 4 stages for NextN=3
- Update JIT templates and scheduler with num_next_n_atoms

Benchmark on H800 PCIe (B=256, L=8192, H=64, D=128):
  next_n=2: 0.169 ms/call
  next_n=3: 0.192 ms/call (previously 0.324 ms/call with C++ fallback)