feat: add sm120 support for DeepGEMM #324

Open
leavelet wants to merge 55 commits into deepseek-ai:nv_dev from leavelet:sm120

@leavelet

@leavelet leavelet commented May 1, 2026

This PR adds first-class sm_120 support to DeepGEMM, maintained on the nv_dev branch by the NVIDIA DevTech APAC team. It brings the full DeepGEMM kernel surface (dense, grouped/MoE, einsum, hyper-connection, and MQA-logits) to sm120 and sm121 devices such as the RTX PRO 6000 and DGX Spark.

Co-authored with @lucifer1004.

Feature surface

Dense GEMM (sm120_fp8_fp4_gemm_1d1d, sm120_bf16_gemm)

  • FP8 (e4m3, UE8M0 block-scaled, 1D1D), all layouts NT/NN/TN/TT
  • FP4 (e2m1, UE8M0 mxf4nvf4, gran_k=32)
  • BF16 and TF32
  • Mixed precision: FP8_A × FP4_B (mxf8f6f4 .e4m3.e2m1) and FP4_A × FP8_B
    (.e2m1.e4m3) — both directions
  • AB-swap small-M path (M = 1..16) via TMA + per-element strided epilogue
  • Wave-packing latency heuristic (BLOCK_M 64/128) and split-K for small dims
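
The AB-swap small-M path can be illustrated with a pure-Python sketch. It assumes (the PR text does not spell this out) that the swap relies on the identity (A·Bᵀ)ᵀ = B·Aᵀ, so the tiny M dimension moves into the N slot of the MMA tile and the result is stored transposed. The helper names are hypothetical and are not part of DeepGEMM.

```python
def gemm_nt(a, b):
    """Return a @ b.T for row-major lists of lists (NT layout)."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


def transpose(m):
    """Transpose a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


def gemm_nt_ab_swap(a, b):
    """Compute the same product by swapping operands and transposing the output."""
    d_t = gemm_nt(b, a)      # D^T = B @ A^T, shape N x M (M is now the narrow dimension)
    return transpose(d_t)    # store back as M x N


a = [[1.0, 2.0, 3.0]]                                   # M=1, K=3
b = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0],
     [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]                  # N=4, K=3
direct = gemm_nt(a, b)
swapped = gemm_nt_ab_swap(a, b)
```

Both paths produce the same 1×4 result; in a real kernel the transpose would be folded into the strided epilogue rather than done as a separate pass.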

Grouped GEMM / MoE

  • M-grouped contiguous, M-grouped masked, K-grouped contiguous
  • FP8 / FP4 / BF16

Batched GEMM / Einsum (sm120_bmk_bnk_mn)

  • K-major B (bhr,hdr->bhd) and MN-major B (bhd,hdr->bhr)
  • Split-S reduction (bmk,bnk->mn); FP8 / BF16; small-M AB-swap
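
As a reference for the two einsum strings above, here is a minimal pure-Python sketch of what each contraction computes. It is only a semantic reference under the usual einsum convention, not the kernel's implementation, and the function names are hypothetical.

```python
def einsum_bhr_hdr_to_bhd(x, w):
    """x[b][h][r], w[h][d][r] -> out[b][h][d]; contracts r, the innermost axis of w (K-major B)."""
    return [[[sum(xr * wr for xr, wr in zip(x[b][h], w[h][d]))
              for d in range(len(w[h]))]
             for h in range(len(w))]
            for b in range(len(x))]


def einsum_bhd_hdr_to_bhr(y, w):
    """y[b][h][d], w[h][d][r] -> out[b][h][r]; contracts d, the outer axis of w (MN-major B)."""
    return [[[sum(y[b][h][d] * w[h][d][r] for d in range(len(w[h])))
              for r in range(len(w[h][0]))]
             for h in range(len(w))]
            for b in range(len(y))]


x = [[[1, 2], [3, 4]]]                       # b=1, h=2, r=2
w = [[[1, 0], [0, 1], [1, 1]],               # h=2, d=3, r=2
     [[2, 0], [0, 2], [1, -1]]]
out_bhd = einsum_bhr_hdr_to_bhd(x, w)
out_bhr = einsum_bhd_hdr_to_bhr(out_bhd, w)
```

The same weight tensor `w` serves both directions; only the contracted axis changes, which is why the second form sees B as MN-major.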

HC Prenorm (sm120_tf32_hc_prenorm_gemm)

  • Fused TF32 GEMM + square-sum with split-K

MQA Logits (sm120_fp8_mqa_logits, sm120_fp4_mqa_logits + paged)

  • FP8 / FP4; dense (ragged, warp-specialized, L2-cached KV, split-KV) and
    paged (next_n / kPadOddN support)

Performance (RTX PRO 6000 Blackwell, single GPU, tests/test_*.py)

Roofline used: FP8 ≈ 814 TFLOPS, FP4 ≈ 1628 TFLOPS (2× FP8), BF16 ≈ 445 TFLOPS. GB/s is effective bandwidth (cache-inclusive for small/cached shapes).

Dense GEMM

| dtype | peak TFLOPS | % roofline | peak GB/s |
| --- | --- | --- | --- |
| FP8 | 778 | 96% | 1235 |
| FP4 | 1561 | 96% | 1337 |
| BF16 | 444 | ~98% | 1497 |

Grouped GEMM (MoE)

variant FP8 TFLOPS FP4 TFLOPS peak GB/s
M-grouped contiguous (prefill) 814 1571
M-grouped masked 670 1363 836
K-grouped contiguous (EP) 543 (gk32) 666 (gk128) 1011

Einsum

Peak 742 TFLOPS (b=8192,h=8,r=4096,d=1024); up to 2893 GB/s (small-batch, cache-bound); Split-S path scales down for small problems.

HC Prenorm

m=8192, n=24, k=28672, splits=16: 29 TFLOPS / 1248 GB/s (bandwidth-bound).

MQA Logits

| variant | peak TFLOPS |
| --- | --- |
| Ragged FP4 | 1049 |
| Paged FP4 | 685 |
| Ragged FP8 | 636 |
| Paged FP8 | 547 |

@leavelet leavelet mentioned this pull request May 1, 2026
@leavelet leavelet changed the title from "[WIP] Feat: Add sm120 support for DeepGEMM" to "[WIP] feat: Add sm120 support for DeepGEMM" May 1, 2026
@leavelet leavelet marked this pull request as ready for review May 9, 2026 04:37
@linjiapro

@leavelet This is nice. After it is merged into the nv_dev branch, should vllm-project/vllm#41834 be merged too, so that vLLM can work with the branch?

@jasl

jasl commented May 10, 2026

Contributor

I added my benchmarks at vllm-project/vllm#41834 for reference.

jasl added a commit to jasl/tokenspeed that referenced this pull request May 12, 2026
Replace the hand-written CUDA FP8 GEMV kernel (previously gated to
tokens==1) with a port of the SM120 FP8 einsum kernel from upstream
DeepGEMM's WIP SM120 support (deepseek-ai/DeepGEMM#324, file
`deep_gemm/include/deep_gemm/impls/sm120_fp8_einsum.cuh`). The DeepGEMM
kernel implements exactly the `bhr,hdr->bhd` einsum DeepSeek V4 needs,
with per-thread-per-output-cell GEMV using fp8x4 vectorized loads and
the same block-128 fp32 scale recipe.

Removing the tokens==1 gate: the kernel handles all token counts that
the SM12x dispatch predicate accepts (tokens <= 16 today; larger token
batches will arrive once T1-α expands graph capture).

Microbench (DSv4-Flash decode shape, groups=8, hidden=2048, out=1024,
GPU idle):

  tokens=1  cuda 0.026ms  triton 0.020ms  speedup 0.72x (was 0.40x)
  tokens=2  cuda 0.027ms  triton 0.020ms  speedup 0.72x (was 0.73x*)
  tokens=8  cuda 0.075ms  triton 0.020ms  speedup 0.27x (was 0.21x*)

  * Triton-as-default after the previous tokens==1 hotfix.

The kernel's grid is `tokens * groups * (out/128)`, one block per
`(token, group, out_tile=128)` triple. Because each block reads its
weight tile independently, total weight reads scale linearly with
`num_tokens`. At graph bs=2 (today) this dominates: tokens<=2 is the
production shape and the 0.72x is a real net win against the previous
Triton-fallback default. At tokens=8 (future, post-T1-α) the kernel
loses ~2x to Triton's m=16 cooperative tile; we will revisit with a
multi-token tile design before T1-α exposes that shape to production.
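
The grid and weight-traffic scaling described above can be sketched numerically. This is an illustrative model only: the function names are hypothetical, and it assumes `hidden` is the per-group reduction length and that weights are 1-byte FP8.

```python
def decode_grid(tokens, groups, out, out_tile=128):
    """Blocks launched: one per (token, group, out_tile) triple."""
    assert out % out_tile == 0
    return tokens * groups * (out // out_tile)


def weight_bytes_read(tokens, groups, hidden, out, bytes_per_elem=1):
    """Total weight bytes requested when every block re-reads its own tile.

    Each token's blocks cover the full [groups, out, hidden] weight once,
    so the total grows linearly with the token count (before cache hits).
    """
    return tokens * groups * out * hidden * bytes_per_elem


for tokens in (1, 2, 8):
    blocks = decode_grid(tokens, groups=8, out=1024)
    mib = weight_bytes_read(tokens, groups=8, hidden=2048, out=1024) / 2**20
    print(f"tokens={tokens}: {blocks} blocks, {mib:.0f} MiB of weight reads requested")
```

Under these assumptions the microbenchmark shape launches 64 blocks per token, so a single token cannot fill the GPU while eight tokens request eight times the weight traffic.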

Earlier hand-written attempts (one-cell-per-block, per-thread B=16
accumulator tile, 1-warp m16n8 MMA, 4-warp m16n32 cooperative MMA,
4-warp m16n128 MMA) are documented in
`docs/notes/2026-05-09-ds4-sm12x-rejected-experiments.md`. The MMA
designs hit either occupancy collapse (80 regs/thread) or insufficient
parallelism (64 blocks at decode shape vs Blackwell's 140 SMs), capping
out at ~0.51x. The DeepGEMM design wins at the production shape by
avoiding tensor cores entirely -- a per-thread GEMV with fp8x4
vectorization and L1/L2-friendly weight access fits the small-M decode
profile better than the m=16 MMA tile.

Attribution: kernel source ported under MIT license from upstream
DeepGEMM (Copyright (c) 2025 DeepSeek). Tokenspeed adaptations are the
tvm-ffi binding, stride/scale validation, and the SM12x dispatch
integration; the dot-product math is unchanged.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/tokenspeed that referenced this pull request May 12, 2026
Upstream PR lightseekorg#93 added a pre-flight DeepGEMM ``fp8_gemm_nt`` call to
``DeepseekV4Attention._compute_qr_kv``: on success it replaces the
reference FP8 linear path, on failure it logs a WARNING per layer and
falls back. DeepGEMM does not support SM120/SM121 yet (see PR
``deepseek-ai/DeepGEMM#324`` + ``reference_deepgemm_sm120`` memory),
so on the RTX Pro 6000 workstation every layer fires:

    DeepSeek V4 DeepGEMM FP8 linear failed; falling back to reference
    FP8 linear. reason=RuntimeError: Assertion error
    (csrc/apis/layout.hpp:59): Unknown SF transformation

The existing per-layer ``_deepseek_v4_deep_gemm_linear_disabled`` flag
already catches this for steady-state replay, but it costs one failed
call + one WARNING per layer at boot. Mirror the pattern used by
``_deepseek_v4_deepgemm_fp4_indexer_enabled_for_platform``: short-
circuit ``_deepseek_v4_get_fp8_linear_deep_gemm`` to ``None`` on SM12x
so the platform never tries the DeepGEMM path. Non-SM12x platforms
keep the new fast path.
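
A minimal sketch of the short-circuit pattern described above. The function names and the `(major, minor)` capability tuple are hypothetical stand-ins; the real tokenspeed helpers are not shown in this commit message.

```python
def is_sm12x(capability):
    """True for compute capability 12.x (SM120 / SM121)."""
    major, _minor = capability
    return major == 12


def get_fp8_linear_deep_gemm(capability, deep_gemm_fn):
    """Return the DeepGEMM callable, or None so callers take the reference FP8 path.

    Returning None up front avoids one failed call plus one WARNING per layer at boot.
    """
    if is_sm12x(capability):
        return None
    return deep_gemm_fn


def fake_fp8_gemm_nt(*args):
    """Placeholder standing in for the real DeepGEMM entry point."""
    return "deep_gemm"


on_sm120 = get_fp8_linear_deep_gemm((12, 0), fake_fp8_gemm_nt)
on_sm90 = get_fp8_linear_deep_gemm((9, 0), fake_fp8_gemm_nt)
```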

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 18, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.
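
The effect on CTA count can be sketched with a simple tile-count model. It assumes one CTA per (BLOCK_M, BLOCK_N) output tile, which the commit message implies but does not state; the problem shape below is a made-up example, not a measured configuration.

```python
def cta_count(m, n, block_m, block_n=128):
    """Number of output tiles, assuming one CTA per (block_m x block_n) tile."""
    tiles_m = -(-m // block_m)   # ceiling division
    tiles_n = -(-n // block_n)
    return tiles_m * tiles_n


# Hypothetical late-context prefill shape: 4096 query rows against 131072 KV positions.
before = cta_count(4096, 131072, block_m=8)
after = cta_count(4096, 131072, block_m=16)
```

Doubling BLOCK_M halves the number of tiles along M, so each CTA reuses its KV tile for twice as many query rows.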

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 18, 2026
jasl added a commit to jasl/vllm that referenced this pull request May 19, 2026
jasl added a commit to jasl/vllm that referenced this pull request May 19, 2026
jasl added a commit to jasl/vllm that referenced this pull request May 20, 2026
@Rachmanino

Nice work! May I ask whether the hardware used for testing here is a 5090 or an RTX PRO 6000?

jasl added a commit to jasl/vllm that referenced this pull request May 22, 2026
jasl added a commit to jasl/vllm that referenced this pull request May 22, 2026
DoradusResearch pushed a commit to DoradusResearch/vllm that referenced this pull request May 23, 2026
leavelet and others added 12 commits May 26, 2026 05:43
Phase 1a: infrastructure + dense FP8 GEMM kernel for SM120a (CC 12.0).

Architecture: warp-level mma.sync with block-scaled UE8M0 scale factors,
B128 XOR swizzle, persistent scheduling, register-based epilogue.
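
For readers unfamiliar with the B128 XOR swizzle, the sketch below shows the conventional 128-byte pattern: within a 1024-byte atom of 8 rows × 128 bytes, the 16-byte chunk index is XORed with the row index. This assumes the kernel uses the standard pattern (the PR's own swizzle utilities are not shown here), and the function name is hypothetical.

```python
def swizzle_128b(offset):
    """Map a byte offset inside a 1024-byte atom (8 rows x 128 B) to its swizzled offset."""
    row = (offset >> 7) & 0x7        # which 128-byte row (bits 7..9)
    return offset ^ (row << 4)       # XOR the row into the 16-byte chunk index (bits 4..6)


# Chunk 0 of each row lands in a different 16-byte column after swizzling,
# which is what spreads column-wise accesses across shared-memory banks.
columns = [(swizzle_128b(row * 128) >> 4) & 0x7 for row in range(8)]
```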

New files:
- SM120 heuristics, JIT codegen, MMA PTX wrappers, ldmatrix/swizzle utils
- CUDA kernel with warp-specialized TMA/math pipeline (3-9 stages)

Modified files:
- Arch detection, compiler flags (-gencode for SM120a)
- API dispatch (arch_major == 12), SF layout transform
- Default recipe for SM120

Correctness: 8/8 shapes pass (diff < 0.001 cosine distance)
Performance: ~73 TFLOPS (baseline, optimization pending)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… only

Drop the non-warp-specialized kernel path for SM120a (matching SM90/SM100
architecture), merging the warp-specialized implementation into the main
sm120_fp8_fp4_gemm_1d1d kernel. Add FP4 GEMM support using packed SMEM
with the mxf4nvf4 m16n8k64 MMA instruction.

Key changes:
- Consolidate: remove non-spec path, always BM=128/BK=128/384 threads
- FP4: packed 4-bit SMEM (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B), standard
  ldmatrix, uint16_t scale factors (scale_vec::2X), kKSteps=2 vs FP8's 4
- Heuristic: simplified to warp-spec only, correct SMEM sizing for FP4
- API: enable FP4 on SM120a (arch_major==12), add fp8_fp4_gemm_nt binding
- Fix SF hoist bug: hoist SFA/SFB independently for mixed gran_k configs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Kernel: Add TMA descriptor runtime update (tensormap.replace) in producer
  loop for K-grouped group transitions, fix SF_K_ALIGNMENT to kGranKA*4,
  fix SMEM layout (pipeline data at offset 0 for B128 swizzle alignment,
  tensor map descriptors at end), fix epilogue bounds for multi-group output.

- MMA wrappers: Replace CUTLASS mma_sm120.hpp dependency with custom inline
  asm using "+f" read-write constraints for accumulator registers. Eliminates
  CUTLASS header dependency and gives explicit control over MMA operand
  encoding for both FP8 (m16n8k32) and FP4 (m16n8k64) block-scaled MMA.

- JIT launcher: Add sm120_k_grouped_fp8_fp4_gemm_1d1d() with proper TMA
  descriptor creation (first_k base, FP4-aware stride), SF TMA covering
  concatenated groups, CD TMA with num_groups outer dimension.

- API dispatch: Add arch_major==12 path in k_grouped_fp8_gemm_nt_contiguous,
  relax recipe assertion to support gran_k=32/128, add SM120 SF layout
  transform with auto-detection of transposed K-major scale factors.

- Tests: Add dedicated SM120 K-grouped test (7 configs including zero-K
  edge case), fix K-major selection for SM120 in generators, fix test
  dispatch for SM120 in test_fp8_fp4.py, update FP4 test with perf comparison.

Tested: Dense FP8 8/8, Dense FP4 10/10, K-grouped FP8 7/7 — all PASS.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Einsum support:
- Add GemmType::Batched to FP8/FP4 and BF16 kernels with 3D TMA load/store
- Add IndexType::SF_K for batched SF coordinate computation
- Add MN-major B support to BF16 kernel (scalar SMEM loads, single-atom constraint)
- BF16 bhr,hdr->bhd: 384 TFLOPS, FP8: 681 TFLOPS (batch=8, b=8192)

M-grouped BF16:
- Add contiguous and masked M-grouped BF16 GEMM launchers

HC prenorm TF32:
- New fused GEMM + sqr_sum kernel using mma.sync.m16n8k8 TF32 (226T peak)
- BF16 A -> FP32 cast with fused sqr_sum accumulation
- Atom-aware FP32 B fragment loading from K-major SMEM
- Split-K support for large K / small M shapes
- 24/24 test shapes PASS, ~1.1 TB/s bandwidth on large shapes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Dense (ragged): 651 TFLOPS peak (80% of FP8 MMA peak 814T), 40/40 tests pass.
Paged (KV cache): 320 TFLOPS peak, 1.36 TB/s DRAM (91% of HBM BW), 8/8 tests pass.

Kernel design: warp-specialized mma.sync m16n8k32 FP8 (no block_scale).
8 math warps × 16 KV rows each = 128 BLOCK_KV. In-warp 2-shfl reduction
across 4 threads (lane%4) — only ~10 cycles, negligible vs MMA time.
Global stores are fire-and-forget on SM120a, so no epilogue warps needed.
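
The in-warp 2-shuffle reduction can be simulated in plain Python. This sketch assumes a sum reduction over the 4 lanes that share the same `lane // 4` (the quad that holds one accumulator row in the `mma.sync` m16n8 layout) and models `__shfl_xor_sync` as a list permutation; it is not the kernel code.

```python
def shfl_xor(values, lane_mask):
    """Simulate __shfl_xor_sync: lane i receives the value held by lane i ^ lane_mask."""
    return [values[i ^ lane_mask] for i in range(len(values))]


def quad_reduce_sum(values):
    """Two butterfly steps (masks 1 and 2) sum across the 4 lanes of each quad."""
    for mask in (1, 2):
        other = shfl_xor(values, mask)
        values = [v + o for v, o in zip(values, other)]
    return values


partials = [float(lane) for lane in range(32)]   # one partial sum per lane of a warp
reduced = quad_reduce_sum(partials)
```

After the two steps every lane of a quad holds the full quad sum, so any one of them can issue the global store.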

Key parameters: block_qh=128, num_heads=64, head_dim=128, 2 Q stages,
3 KV stages, 84KB SMEM (83% of 101KB capacity).

Paged variant: 2 groups of 4 warps, SPLIT_KV=128, per-group KV pipeline.
Fixed metadata split_kv mismatch and register budget overflow (TMA regs
64→40 to stay within 65536 register limit).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- skip_head_mid: Add SM120 dispatch in attention.hpp for EpilogueHeadSplits.
  Fix three issues: TMA CD descriptor uses d.size(-1)/d.stride(-2) instead
  of n; kernel uses stride_d parameter for D row stride and bounds checks;
  TMA store coordinates apply epilogue N-index remapping.

- MN-major B: Fix kernel TMA coordinate for M-grouped BF16 with MN-major B.
  Group offset moves to outer=K coordinate (not inner=N) when kBKMajor=false.

- FP8 kernel stride_d: Add stride_d parameter to decouple D tensor stride
  from computation dimension n, enabling epilogue transforms that expand N.
Replace per-element scalar SMEM loads in the MN-major B path with
ldmatrix.sync.aligned.x2.m8n8.trans.shared.b16 which natively loads
column-major 8x8 BF16 matrices — directly producing MMA B fragments.

Performance: MN-major B improves from ~220T to ~290T (dense) and ~330T
(M-grouped), a 30-50% gain. Remaining gap vs K-major (400T) is due to
heuristic selecting BLOCK_N=32/64 vs 128 (single swizzle atom constraint).

Verified by micro benchmark b_bf16_4: fragment layout 32x32 lanes PASS,
MMA pipeline 4 K-steps accumulation PASS.
Remove single-atom BLOCK_N constraint for MN-major B. ldmatrix.trans
correctly handles multi-atom SMEM (verified by micro benchmark b_bf16_5).

MN-major B now achieves 99-102% of K-major performance (was 53% with
scalar loads, 80% with single-atom ldmatrix.trans).
…bhr->hdr

New BF16 bmk,bnk->mn reduction kernel with split-S atomicAdd to FP32 output
(188T peak, HBM BW limited). FP8 einsum dispatch for bhd,hdr->bhr and
bhd,bhr->hdr via .contiguous() to K-major. Fix batched epilogue stride
formula: replace single stride_d with stride_cd_m + stride_cd_batch to
support arbitrary D layouts ([batch,M,N] vs [M,batch,N]). Add kBKMajor
template parameter to FP8 kernel with verified scalar-load MN-major B path
(correct but 3x slower than K-major ldmatrix, kept for future optimization).
K-grouped TN: single .t().contiguous() transpose with constant-stride TMA
(kKGroupedConstantStride) — per-group only replaces addr+dim, not stride.
TN achieves 99-101% of NT performance. New PTX tensor_map_replace_global_dim_in_smem.

Paged MQA varlen: fix 4 kernel bugs in sm120_fp8_paged_mqa_logits.cuh:
- TMA Q coordinate: use atom_to_token_idx() instead of hardcoded *kNextNAtom
- Prefetch advance: use get_atom_advance() instead of hardcoded +1
- Math loop: conditional iteration count via is_paired_atom for unpaired atoms
- KV block idx: reset kv_block_idx_ptr=32 on q_atom change
New kernels using mma.sync m16n8k64 block-scaled FP4 (mxf4nvf4, scale_vec::2X).
Architecture: 8 math warps + 4 TMA warps, B64 swizzle, kKSteps=2.
Block-scaled MMA folds UE8M0 SF into computation — no post-MMA scale.

Dense FP4 MQA: 1022 TFLOPS peak (63% FP4 peak), 1.6x vs FP8.
Paged FP4 MQA: 707 TFLOPS peak (varlen), 566 TFLOPS (non-varlen next_n=2).
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
…r changes

The SM120 MQA-logits work modified the shared launchers in two ways that broke
SM100 (and were latent on SM90), surfacing as test_attention failures on B300
(baseline=origin/nv_dev passes, sm120 HEAD fails). No SM100/SM90 device kernel
(.cuh) changed, so this is purely a host launch-config regression.

1. Dense MQA SMEM under-allocation (smxx_fp8_fp4_mqa_logits.hpp). The SMEM size
   dropped the `(num_math_threads / 128) * 2` mbarrier pairs that the SM90/SM100
   kernels still allocate (the SM120 kernel does not). On SM100 this under-sized
   the dynamic SMEM; compute-sanitizer reports an "Invalid __shared__ write" in
   sm100_fp8_mqa_logits, faulting (CUDA_ERROR_ILLEGAL_ADDRESS) at large configs
   such as seq_len_kv=130560. Restore the term for arch != 12.

2. Paged-MQA metadata capacity assert (smxx_fp8_fp4_paged_mqa_logits.hpp). An
   unconditional `smem_size <= SM120ArchSpec::smem_capacity` (99 KB) was added to
   the metadata launcher, which runs on every arch. On SM90/SM100 (228 KB) a
   large (varlen) batch's metadata SMEM legitimately exceeds 99 KB and falsely
   tripped the assert. Gate the capacity check to the running arch.
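
The two host-side fixes can be summarized in a small sketch. The function names and the dictionary are hypothetical; the capacities (99 KB for SM120, 228 KB for SM90/SM100) and the `(num_math_threads / 128) * 2` term are taken from the text above, while the actual launcher code is not reproduced here.

```python
# Per-arch shared-memory capacity in bytes, keyed by arch_major (values from the commit text).
SMEM_CAPACITY = {9: 228 * 1024, 10: 228 * 1024, 12: 99 * 1024}


def extra_math_barriers(arch_major, num_math_threads):
    """Fix 1: count the extra mbarriers only on archs whose kernels still allocate them."""
    if arch_major == 12:
        return 0                          # the SM120 kernel does not use them
    return (num_math_threads // 128) * 2  # SM90 / SM100 kernels still do


def metadata_smem_fits(arch_major, smem_size):
    """Fix 2: compare against the capacity of the running arch, not always SM120's."""
    return smem_size <= SMEM_CAPACITY[arch_major]


large_metadata = 120 * 1024  # hypothetical varlen batch needing more than 99 KB
```

With the check gated this way, the same 120 KB request is accepted on SM100 and rejected only on SM120.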

Verified on B300 (sm100): compute-sanitizer memcheck clean on the previously
out-of-bounds config (seq_len=510, seq_len_kv=130560); full test_attention
(dense + paged, NextN 1-6, FP8/FP4) passes. Fix (1) also corrects a latent
SMEM under-allocation on SM90.
jasl added a commit to jasl/vllm that referenced this pull request Jun 5, 2026
jasl added a commit to jasl/vllm that referenced this pull request Jun 6, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
Strip changes that should not ship upstream and align comment style with the
existing codebase (terse, no decorative banners). No change to production
GEMM/attention behavior — only removes a dev-only bench path and trims comments.

- .gitignore: revert to upstream; personal ignores (docs_internal/, benchmarks/,
  ncu_reports/, internal/) moved to local .git/info/exclude.
- Remove dev-only sm120_fp8_gemm_bench API and its override_layout tile-override
  param (gemm.hpp, __init__.py, sm120_fp8_fp4_gemm_1d1d.hpp); config selection
  collapses to the standard get_best_config path.
- Trim changelog-/rationale-style block comments to single lines (shared rationale
  lives in commit history): MQA-logits SMEM, paged-MQA capacity assert,
  cd_n_contiguous, split-K SF alignment, latency model, einsum MN-major notes.
- De-duplicate the AB-swap "transposed output" rationale (kept once in config.hpp)
  and the UE8M0 SF-tile note.
- Remove decorative // ==== banner comments from sm120 .cuh (no other kernel uses them).
- Remove an unused num_waves local in the SM120 latency model (clears a
  -Wunused-variable warning).
- Remove tests/test_split_k_swap.py (standalone dev harness, not pytest-style).
@leavelet

leavelet commented Jun 7, 2026

Author

@RayWang96 I have cleaned up the code and verified it on sm90, sm100, and sm120. The PR is ready to merge. Thanks!

jasl added a commit to jasl/vllm that referenced this pull request Jun 7, 2026
AliceChenyy added a commit to AliceChenyy/sglang that referenced this pull request Jun 7, 2026
…pGEMM@sm120)

Enable DeepGEMM grouped FP8×FP4 GEMM for MoE on SM120 (RTX 6000D/PRO 6000).
Requires leavelet/DeepGEMM@sm120 branch (deepseek-ai/DeepGEMM#324, not yet merged).

Changes:
- configurer.py: Allow SM120 only when SM120-compatible DeepGEMM is installed
  (checks for m_grouped_fp8_fp4_gemm_nt_contiguous availability)
- server_args.py: Auto-select deep_gemm MoE backend on SM120
- kernels.py: Add UE8M0 (power-of-2) FP8 quantization Triton kernel required
  by DeepGEMM's block-scaled dequantization on SM120
- deep_gemm.py: SM120 adaptations for DeepGEMM MoE runner:
  - TMA-aligned scale factors for grouped GEMM (hidden_states + down_input)
  - JIT EP activation fallback when hidden_dim/8 < num_experts (TP>=2)
  - In-place swiglu clamp replacing torch.chunk+cat (-7.4ms/step decode)
- fp8.py: Add .contiguous() before transform_sf_into_required_layout
- fp8_utils.py: Skip DeepGEMM dense FP8 linear on SM120 (bf16_gemm_nt unsupported)
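
To make the UE8M0 (power-of-2) scale concrete, here is a scalar Python sketch of how such a scale could be derived for one quantization block. It is not the Triton kernel from this commit: the function name is hypothetical, it assumes the scale is rounded up to the next power of two so that values fit the e4m3 range, and it assumes the usual E8M0 encoding with a bias of 127.

```python
import math

E4M3_MAX = 448.0  # largest finite magnitude of float8 e4m3 (e4m3fn variant)


def ue8m0_scale(amax):
    """Return (scale, encoded) for a block whose absolute maximum is amax > 0.

    scale is the smallest power of two with amax / scale <= E4M3_MAX;
    encoded is its unsigned 8-bit biased exponent (no sign, no mantissa).
    """
    raw = amax / E4M3_MAX
    mantissa, exp = math.frexp(raw)                    # raw = mantissa * 2**exp, 0.5 <= mantissa < 1
    exponent = exp - 1 if mantissa == 0.5 else exp     # ceil(log2(raw)) without float log error
    return 2.0 ** exponent, exponent + 127


scale_exact, code_exact = ue8m0_scale(448.0)   # already fits: scale 1.0
scale_up, code_up = ue8m0_scale(500.0)         # needs rounding up to 2.0
scale_small, code_small = ue8m0_scale(3.5)     # 3.5 / 448 is exactly 2**-7
```

Because the scale is a pure power of two, applying it only shifts the exponent of each value, which is what lets the block-scaled MMA fold it into the computation.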

Performance (TP=4, BS=1, RTX 6000D 85GB, ISL=8K):
  TTFT: 130ms (vs 400ms marlin, 3x faster)
  Decode ITL: 47ms (vs 41ms marlin, 15% slower)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
guqiong96 pushed a commit to guqiong96/Lvllmds4 that referenced this pull request Jun 8, 2026
jasl added a commit to jasl/vllm that referenced this pull request Jun 9, 2026
jasl added a commit to jasl/vllm that referenced this pull request Jun 11, 2026
jasl added a commit to jasl/vllm that referenced this pull request Jun 11, 2026
jasl added a commit to jasl/vllm that referenced this pull request Jun 11, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request Jun 11, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
xutizhou pushed a commit to xutizhou/sglang that referenced this pull request Jun 12, 2026
…pGEMM@sm120)

Enable DeepGEMM grouped FP8×FP4 GEMM for MoE on SM120 (RTX 6000D/PRO 6000).
Requires leavelet/DeepGEMM@sm120 branch (deepseek-ai/DeepGEMM#324, not yet merged).

Changes:
- configurer.py: Allow SM120 only when SM120-compatible DeepGEMM is installed
  (checks for m_grouped_fp8_fp4_gemm_nt_contiguous availability)
- server_args.py: Auto-select deep_gemm MoE backend on SM120
- kernels.py: Add UE8M0 (power-of-2) FP8 quantization Triton kernel required
  by DeepGEMM's block-scaled dequantization on SM120
- deep_gemm.py: SM120 adaptations for DeepGEMM MoE runner:
  - TMA-aligned scale factors for grouped GEMM (hidden_states + down_input)
  - JIT EP activation fallback when hidden_dim/8 < num_experts (TP>=2)
  - In-place swiglu clamp replacing torch.chunk+cat (-7.4ms/step decode)
- fp8.py: Add .contiguous() before transform_sf_into_required_layout
- fp8_utils.py: Skip DeepGEMM dense FP8 linear on SM120 (bf16_gemm_nt unsupported)
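
For readers unfamiliar with UE8M0 scaling, the sketch below shows one way a power-of-two block scale could be chosen for FP8 e4m3 data. It is not the Triton kernel from this commit: `ue8m0_scale` is a hypothetical helper, and rounding the scale up (so the scaled block maximum never exceeds 448) is an assumption about the quantization policy.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite e4m3 magnitude


def ue8m0_scale(block_amax: float) -> tuple[int, float]:
    """Return (biased exponent byte, scale), where scale is the smallest
    power of two with block_amax / scale <= 448 (hypothetical policy)."""
    if block_amax <= 0.0:
        return 127, 1.0  # all-zero block: any scale works, use 2**0
    ratio = block_amax / FP8_E4M3_MAX
    # frexp gives ratio == mantissa * 2**exponent with 0.5 <= mantissa < 1,
    # so ceil(log2(ratio)) can be read off exactly without calling log2.
    mantissa, exponent = math.frexp(ratio)
    exp = exponent - 1 if mantissa == 0.5 else exponent
    exp = max(-127, min(127, exp))  # E8M0 encodes 2**-127 .. 2**127
    return exp + 127, math.ldexp(1.0, exp)


for amax in (1.0, 448.0, 500.0):
    byte, scale = ue8m0_scale(amax)
    print(f"amax={amax}: byte={byte}, scale={scale}, scaled amax={amax / scale}")
```

Storing only the exponent byte is what makes the scale "power-of-2": dequantization becomes an exponent adjustment rather than a general multiply.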

Performance (TP=4, BS=1, RTX 6000D 85GB, ISL=8K):
  TTFT: 130ms (vs 400ms marlin, 3x faster)
  Decode ITL: 47ms (vs 41ms marlin, 15% slower)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
leavelet added 2 commits June 12, 2026 09:53
The SM120 fp8/fp4 1D1D kernel iterated every ceil(M_sum/BLOCK_M) m-block and
computed invalid (m_indices == -1) padding tiles against expert 0. SM90 skips
these via Scheduler::is_computation_valid, but that helper is marked SM90-only
and the SM120 producer issued the A/B/SF TMA copies unconditionally.

At MoE decode the contiguous worst-case M_sum reserves a block per local expert
(min(M*topk, local_experts)) while only a few are routed, so most m-blocks are
empty padding. Computing them wastes a full-width GEMM tile — the dominant cost
of EP decode. Measured on RTX PRO 6000, EP gate_up (N=4096, K=4096, 64 local
experts, ~20 routed): worst-case 20-valid + 45-empty of 65 blocks = 304us vs the
20 real blocks alone = 170us; the empty tiles cost 134us (44%).

Add the m_indices < 0 skip-continue to both the producer and consumer
while-loops. The check is identical in both, so no barrier ops are issued for
skipped blocks and the warp-specialized pipeline stays in sync. Empty experts
now cost nothing: EP decode MoE GEMM drops ~1.8x (304 -> 169us), matching the
TP-sharded layout. The all-valid path (prefill/dense) is unchanged.

Validated: m-grouped contiguous correctness for fp8xfp8 / fp4xfp4 / fp8xfp4
(diff < 2%), worst-case-padded result == tight, all-65-valid timing unchanged.
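
The skip described above can be modelled on the host in a few lines. This is a simplified sketch, not the CUDA kernel: `visited_blocks` is a hypothetical stand-in for both while-loops, and it assumes the first row of each BLOCK_M-aligned block decides whether the block is padding. The block layout is invented to match the 20-valid / 45-empty example; the timings are the ones quoted in the commit message.

```python
BLOCK_M = 128


def visited_blocks(m_indices, block_m, skip_padding):
    """Block ids processed by a loop over ceil(M_sum / block_m) m-blocks."""
    num_blocks = (len(m_indices) + block_m - 1) // block_m
    visited = []
    for blk in range(num_blocks):
        # Assumption: the first row of a block is -1 exactly when the block is padding.
        if skip_padding and m_indices[blk * block_m] < 0:
            continue
        visited.append(blk)
    return visited


# Invented layout: 65 m-blocks, every third of the first 60 is routed (20 valid).
valid = set(range(0, 60, 3))
m_indices = []
for blk in range(65):
    m_indices.extend([blk if blk in valid else -1] * BLOCK_M)

# Producer and consumer apply the same predicate, so they visit the same blocks
# and issue the same number of barrier operations.
producer = visited_blocks(m_indices, BLOCK_M, skip_padding=True)
consumer = visited_blocks(m_indices, BLOCK_M, skip_padding=True)
unskipped = visited_blocks(m_indices, BLOCK_M, skip_padding=False)
print(len(unskipped), len(producer), producer == consumer)

# Timings quoted in the commit message (microseconds).
worst_case_us, valid_only_us = 304, 170
empty_cost_us = worst_case_us - valid_only_us
print(empty_cost_us, f"{empty_cost_us / worst_case_us:.0%}")
```

The important property is that the skip decision depends only on data both roles can read, which is why no extra synchronization is needed to keep the pipeline in step.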
Adds the reverse of the existing fp8-A x fp4-B mixed path: A is fp4 (e2m1), B is
fp8 (e4m3), via mxf8f6f4 .e2m1.e4m3 (m16n8k32). Lets the expert weight sit on the
well-filled M axis and the (few) decode tokens on a small N axis.

- mma/sm120.cuh: fp4_fp8_mixed_mma_block_scaled (.e2m1.e4m3).
- common/sm120_utils.cuh: ldmatrix_m8n16_x4_b4x16_p64 + load_a_fragment_b4x16
  (4-reg b4x16 unpack for the fp4 A operand; same addressing as the fp8 A load).
- impls/sm120_fp8_fp4_gemm_1d1d.cuh: kAIsFP4 template flag; A loaded as
  fp4-unpacked (b4x16 + <<2), TMA_A_BYTES = SMEM_A/2, MMA dispatch, perNTileX4
  disabled for kAIsFP4.
- jit_kernels/impls/sm120_fp8_fp4_gemm_1d1d.hpp: detect A=fp4 & B=fp8 (a_is_fp4);
  is_fp4 redefined to symmetric fp4xfp4 (value-identical for existing cases);
  emit kAIsFP4.
- heuristics/sm120.hpp: a_padded_fp4 -> the fp4 A operand uses unpacked b4x16
  SMEM (row = block_k, swizzle 128), mirroring the fp4 B handling.
- tests/generators.py: add the FP4_A x FP8_B QuantConfig (32,128,True,False) to
  the SM120 sweep; the FP8 operand of a mixed config is 1D-scaled (per-token,
  recipe (1, gran)), not per-block.

Validated on RTX PRO 6000: full fp8_fp4 sweep (normal / m-grouped contiguous /
m-grouped masked, all dtype configs incl. the new FP4_A x FP8_B) passes
(diff < max_diff, ~0.007 for fp4xfp8). Existing paths unchanged (a_is_fp4
defaults false). SM100 untouched.

Note: for MoE decode this orientation is bandwidth-saturated (~92% DRAM) and only
~3-6% faster than the standard layout; the empty-tile skip is the real EP-decode
lever. The K-grouped launcher A-stride is not yet wired for fp4-A.
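
As a reading aid for the "b4x16 + <<2" step, the sketch below unpacks two packed fp4 (e2m1) values into 8-bit containers and decodes them. It is a host-side illustration under stated assumptions, not the kernel code: the low-nibble-first order and the helper names are hypothetical, and the left shift by 2 simply mirrors the `<<2` mentioned in the commit message. The e2m1 value table itself follows from the format (2 exponent bits, 1 mantissa bit).

```python
# Magnitudes for the 3 low bits of an e2m1 code; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit e2m1 code to its real value."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def unpack_fp4_pair(packed: int) -> tuple[int, int]:
    """Split one byte holding two fp4 codes into two 8-bit containers,
    each shifted left by 2 (nibble order is an assumption)."""
    lo, hi = packed & 0xF, (packed >> 4) & 0xF
    return lo << 2, hi << 2


packed = 0x7A  # high nibble 0x7, low nibble 0xA
containers = unpack_fp4_pair(packed)
values = tuple(decode_e2m1(c >> 2) for c in containers)
print([hex(c) for c in containers], values)
```

Because the packed operand occupies half a byte per element while the unpacked form uses a full byte, the bytes moved from global memory are half the unpacked footprint, which is consistent with `TMA_A_BYTES = SMEM_A/2` in the commit message.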
@leavelet

Author

@lucifer1004 @AliceChenyy cc @samuellees

Please switch to commit 33a715e for better MoE EP performance, especially for decode. The commit bump alone gives a 2x kernel speedup, and swapAB is an optional add-on worth roughly 6%.

Comment thread csrc/apis/gemm.hpp Outdated
Comment thread tests/test_einsum.py Outdated
z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16)
deep_gemm.einsum('bhr,hdr->bhd', x, y, z)
- assert calc_diff(z, ref_z) < 1e-10
+ assert calc_diff(z, ref_z) < 1e-7

Collaborator

What's the justification for the relaxation? Is it driven by SM120 not passing at 1e-10? If so, what makes SM120's einsum less accurate here.

Author

The test compares DeepGEMM's bf16 output against cuBLAS's; both use identical HMMA.16816.F32.BF16 + a single bf16 cast, so they differ only in FP32 K-reduction order. At small M, cuBLAS's nvjet kernel sums 3 partial accumulators (vs DeepGEMM's single linear accumulator), which is marginally more accurate and makes ~0.2% of low-magnitude outputs differ, exceeding the near-bit-exact 1e-10 (lands ~8e-9). Against an FP64 reference, DeepGEMM and cuBLAS are equally accurate (~1.4e-6, ratio 1.00). I will propose a fix that keeps 1e-10 for SM90/SM100, loosens to 1e-7 only on SM120, and adds an FP32-reference correctness check there (< 1e-5). Does this sound correct?
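
The effect of reduction order can be reproduced without a GPU. The sketch below emulates FP32 addition in pure Python and compares a single linear accumulator with three partial accumulators. The strided three-way split is an assumption chosen for illustration; it is not a description of how cuBLAS actually partitions K.

```python
import math
import random
import struct


def f32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_linear(values):
    """Single FP32 accumulator, summed in order."""
    acc = 0.0
    for v in values:
        acc = f32(acc + v)
    return acc


def sum_split(values, parts=3):
    """Several FP32 partial accumulators, combined at the end."""
    partials = [sum_linear(values[i::parts]) for i in range(parts)]
    return sum_linear(partials)


random.seed(0)
K = 4096
products = [f32(random.uniform(-1.0, 1.0)) for _ in range(K)]

reference = math.fsum(products)  # correctly rounded FP64 sum
linear = sum_linear(products)
split = sum_split(products)
print(f"linear error: {abs(linear - reference):.3e}")
print(f"split error:  {abs(split - reference):.3e}")
print(f"linear vs split: {abs(linear - split):.3e}")
```

Both orders stay close to the high-precision reference, yet they need not agree with each other in the last bits, which is the situation a near-bit-exact threshold cannot tolerate.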

Comment thread csrc/jit_kernels/heuristics/runtime.hpp Outdated
leavelet added 2 commits June 15, 2026 23:59
Behavior-preserving cleanups from upstream review.

gemm.hpp: pull the SM120 arm of fp8_fp4_gemm_nt into a dedicated
fp8_fp4_gemm_nt_sm120() helper and dispatch in natural arch order
(9/10 share the SF-transform path, then 12). Add a shared
sm120_to_k_major() helper for the "repack packed-FP4 / contiguous to
force K-major" idiom, reused by the FP8/FP4 NT GEMM, the m-grouped
contiguous arm, and the BF16 GEMM.

runtime.hpp: replace the coupled arch==12 ternaries in
get_theoretical_mk_alignment_for_contiguous_layout() with a per-arch
ContiguousMKAlignment policy plus a single shrink loop.

SM90/SM100/SM120 results are unchanged; verified on SM120 with
test_fp8_fp4.py and test_bf16.py.
The BF16 einsum tests compared DeepGEMM's bf16 output against cuBLAS's
at 1e-10 (near bit-exact). On SM120 the two differ only in FP32
accumulation order (cuBLAS uses a multi-accumulator K-reduction at
small M), so ~0.2% of low-magnitude outputs differ by 1 ULP and the
comparison lands ~8e-9 -- not a DeepGEMM accuracy issue (both match the
FP64 truth equally, ~1.4e-6).

Keep the strict 1e-10-vs-cuBLAS bound for SM90/SM100, loosen to 1e-7
only on SM120, and add an architecture-independent FP32-reference
correctness check on SM120 (calc_diff < 1e-5; DeepGEMM measures ~1.4e-6).
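
The per-architecture bound described above could be expressed roughly as follows. This is a sketch only: `cublas_tolerance` is a hypothetical helper, and the `calc_diff` shown here is an assumed form of the metric (one minus a normalized similarity) written on plain lists so the example runs without PyTorch.

```python
def calc_diff(x, y):
    """Assumed metric: 1 - 2*sum(x*y) / sum(x*x + y*y); 0 means identical."""
    denom = sum(a * a + b * b for a, b in zip(x, y))
    return 1.0 - 2.0 * sum(a * b for a, b in zip(x, y)) / denom


def cublas_tolerance(arch_major: int) -> float:
    """Bound for the DeepGEMM-vs-cuBLAS comparison (hypothetical helper)."""
    return 1e-7 if arch_major == 12 else 1e-10


FP32_REFERENCE_TOLERANCE = 1e-5  # extra check applied on SM120

out = [1.0, 2.0, 3.0, 4.0]
ref = [1.0, 2.0, 3.0, 4.0001]
diff = calc_diff(out, ref)
print(f"diff = {diff:.3e}")
for arch in (9, 10, 12):
    print(arch, diff < cublas_tolerance(arch))
```

Keeping the strict bound on SM90/SM100 means a regression there is still caught, while the FP32-reference check gives SM120 a correctness criterion that does not depend on cuBLAS's reduction order.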
WenbinHou added a commit to memorylake-ai/vllm that referenced this pull request Jun 17, 2026
## Summary

This PR backports the DeepSeek V4 FlashInfer sparse MLA / SM120 work
from the community PR vllm-project#43477 onto
`memorylake-ai/v0.22.1-dev`.

Community reference: vllm-project#43477

This is not a pure cherry-pick. The majority of the changed files are
overlaid from PR vllm-project#43477, but this branch also includes compatibility
fixes needed to make that code build, import, serve, and run
long-context DeepSeek V4 Flash on the current MemoryLake vLLM branch and
the locally built dependency stack.

## What changed

- Adds FlashInfer sparse MLA backend registration for SM120 / DeepSeek
V4:
  - `FLASHINFER_MLA_SPARSE_SM120`
  - `FLASHINFER_MLA_SPARSE_DSV4`
- Adds the DeepSeek V4 FlashInfer sparse attention implementation and
shared sparse MLA helpers.
- Adds FlashInfer sparse MLA warmup and autotune cache helpers.
- Adds DeepSeek V4 sparse/compressor integration needed by the
FlashInfer sparse path.
- Ports the CUTE sparse compressor implementation used by the DSv4 path.
- Updates CUDA platform/backend wiring so the new attention backend can
be selected through the usual vLLM attention backend enum path.
- Updates CMake external project handling for DeepGEMM and QuTLASS
source overrides used by this local build flow.

## Compatibility fixes on top of PR vllm-project#43477

Several local adjustments were required because this branch is older
than the community PR base and because the tested FlashInfer / DeepGEMM
APIs differ from the PR branch assumptions:

- FlashInfer sparse MLA symbols are imported from `flashinfer.mla` in
the locally built FlashInfer source, rather than `flashinfer.decode`.
- `MoERunner` import/API usage is adapted to the current
`memorylake-ai/v0.22.1-dev` fused-MoE layout.
- DeepGEMM MoE expert classes restore `supports_expert_map()` so they
satisfy the current modular expert interface.
- Triton FP8 block-scaled matmul decodes E8M0 / raw `uint8` scale
tensors before kernel launch, avoiding `torch.float8_e8m0fnu` binding
failures on this stack.
- The FlashInfer sparse O-projection path prepares `wo_a` weights and
scale tensors into the grouped layout expected by DeepGEMM `fp8_einsum`.
- TileLang MHC wrappers are adapted to the current `_tilelang_ops`
import path and custom-op registration shape.
- FlashMLA import in the DeepSeek V4 NVIDIA model path is made lazy so
the selected FlashInfer backend can import without requiring the
unrelated FlashMLA overlay class.
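
To illustrate the E8M0 scale decode mentioned in the list above, here is a minimal sketch that turns raw `uint8` scale bytes into floating-point multipliers. It is not the vLLM code: the helper names are hypothetical, and it relies only on the E8M0 definition (an 8-bit biased exponent with bias 127, where 255 encodes NaN).

```python
import math
import struct


def decode_e8m0(byte: int) -> float:
    """Decode one E8M0 byte to its scale 2**(byte - 127); 255 is NaN."""
    if not 0 <= byte <= 255:
        raise ValueError("E8M0 values are single bytes")
    if byte == 255:
        return math.nan
    return math.ldexp(1.0, byte - 127)


def decode_e8m0_bits(byte: int) -> float:
    """Same value for bytes 1..254, built by placing the byte in the
    FP32 exponent field (bits 23..30)."""
    return struct.unpack("<f", struct.pack("<I", byte << 23))[0]


raw_scales = bytes([119, 127, 128, 254])
print([decode_e8m0(b) for b in raw_scales])
print(all(decode_e8m0(b) == decode_e8m0_bits(b) for b in range(1, 255)))
```

The bit-placement variant shows why the decode is cheap on device: for normal values it is a shift into the FP32 exponent field, with bytes 0 and 255 needing separate handling.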

## Tested stack

Tested locally on `zp-nc522` with 8x RTX 6000D / SM120 using:

```bash
vllm serve /models/deepseek-ai/DeepSeek-V4-Flash \
  --served-model-name DeepSeek-V4-Flash \
  --host 0.0.0.0 \
  --port 8000 \
  --tensor-parallel-size 8 \
  --max-model-len 262144 \
  --max-num-seqs 1 \
  --max-num-batched-tokens 32768 \
  --enable-chunked-prefill \
  --gpu-memory-utilization 0.90 \
  --attention-backend FLASHINFER_MLA_SPARSE_DSV4 \
  --kv-cache-dtype fp8_ds_mla \
  --linear-backend triton
```

The local dependency stack used the MemoryLake vLLM source build, local
FlashInfer source build, local DeepGEMM source at the currently
validated SM120 branch point, local FlashMLA, and local vLLM
FlashAttention source.

## Validation

- Ran Python syntax validation on all staged Python files:

```bash
python -m py_compile $(git diff --cached --name-only -- '*.py')
```

- Built and imported the local editable vLLM package successfully,
including compiled extension probes for:
  - `vllm._C`
  - `vllm._moe_C`
  - `vllm._C_stable_libtorch`
  - `vllm.vllm_flash_attn._vllm_fa2_C`
  - `vllm.vllm_flash_attn._vllm_fa3_C`
- Verified runtime sparse MLA availability:
  - `AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4`
  - `AttentionBackendEnum.FLASHINFER_MLA_SPARSE_SM120`
  - `has_flashinfer_sparse_mla_sm120() == True`
  - `flashinfer.mla.trtllm_batch_decode_sparse_mla_dsv4` exists
- Started `vllm serve` successfully with `FLASHINFER_MLA_SPARSE_DSV4`,
`fp8_ds_mla`, TP=8, and `max_model_len=262144`.
- Verified OpenAI-compatible API responses:
  - `/v1/models` returned `DeepSeek-V4-Flash` with `max_model_len=262144`.
  - `/v1/completions` returned HTTP 200 for a short prompt.
- Verified long-context admission and streaming decode:
  - ~200K prompt tokens + `max_tokens=50000` admitted under the 262144
    context limit.
  - First streaming SSE token arrived after 28.314s.
  - A repeated curl streaming request also returned HTTP 200 and streamed
    SSE chunks.
- Ran 1-concurrency streaming benchmark with `stream=true`,
`ignore_eos=true`, `temperature=0`, `max_tokens=20000`, and distinct
prompt prefixes per case:

| Case | TTFT | Total time | Output tokens | End-to-end output throughput | Decode throughput after first token |
| --- | ---: | ---: | ---: | ---: | ---: |
| 10K input + 20K output | 0.798s | 274.425s | 20,000 | 72.88 tok/s | 73.09 tok/s |
| 100K input + 20K output | 15.549s | 295.159s | 20,000 | 67.76 tok/s | 71.52 tok/s |
| 200K input + 20K output | 22.417s | 304.302s | 20,000 | 65.72 tok/s | 70.95 tok/s |

All three benchmark requests returned HTTP 200 with
`finish_reason=length` and final usage matching the requested
prompt/completion token counts.
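
The two throughput columns can be recomputed from TTFT, total time, and output tokens. The sketch below does so with hypothetical helper names. Counting `tokens - 1` for the decode phase is an assumption (the first token is attributed to prefill); it happens to reproduce the reported figures to within 0.01 tok/s.

```python
# (case, TTFT seconds, total seconds, output tokens) copied from the table above.
runs = [
    ("10K input + 20K output", 0.798, 274.425, 20_000),
    ("100K input + 20K output", 15.549, 295.159, 20_000),
    ("200K input + 20K output", 22.417, 304.302, 20_000),
]


def end_to_end_tps(total_s: float, tokens: int) -> float:
    """Output tokens divided by wall-clock time for the whole request."""
    return tokens / total_s


def decode_tps(ttft_s: float, total_s: float, tokens: int) -> float:
    """Tokens generated after the first one, divided by the time after TTFT.
    Assumption: the first token belongs to prefill."""
    return (tokens - 1) / (total_s - ttft_s)


results = {
    case: (end_to_end_tps(total, n), decode_tps(ttft, total, n))
    for case, ttft, total, n in runs
}
for case, (e2e, dec) in results.items():
    print(f"{case}: {e2e:.2f} tok/s end-to-end, {dec:.2f} tok/s decode")
```

The gap between the two columns grows with prompt length because TTFT is charged against end-to-end throughput but excluded from the decode figure.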

## Notes

- The PR deliberately keeps the validated DeepGEMM branch point used by
the successful local serve/benchmark run. PR deepseek-ai/DeepGEMM#324
has since moved forward, so a later performance pass may want to
evaluate its newer head separately.
- The live local API server also required an environment-level
compatibility patch in `prometheus_fastapi_instrumentator` for the
installed FastAPI/Starlette route shape. That patch is outside this vLLM
repository and is not included here.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **High Risk**
> Touches the core DSv4 sparse attention and KV-cache paths, DeepGEMM
grouped-MoE alignment (prior IMA under cudagraph), and CMake external
deps—high blast radius for SM120 long-context serving.
> 
> **Overview**
> Backports **FlashInfer sparse MLA** for DeepSeek V4
(`FLASHINFER_MLA_SPARSE_DSV4` / SM120 variants) onto the MemoryLake
branch, with local compatibility fixes for CMake, DeepGEMM, MoE, and
warmup.
> 
> **Attention / DSv4:** Refactors V4 MLA into a shared
`DeepseekV4Attention` base with platform subclasses. Adds FlashInfer
TRTLLM-gen sparse decode (`flashinfer_sparse.py`), a Triton helper to
build mixed decode/prefill sparse index matrices, and KV paths for
**fp8_ds_mla** vs plain **bf16 / per-tensor fp8** rows (fused insert ops
+ compressor flags). Compressor/cache alignment only applies 576B
padding for `fp8_ds_mla`.
> 
> **Warmup:** Hooks **DSv4 mHC TileLang** warmup, **FlashInfer sparse
MLA** mixed-batch warmup/autotune (shared cache helpers), and expands
DeepGEMM grouped-GEMM warmup to cover multiple `(M_sum, align_used)`
cases under `mk_alignment_scope`.
> 
> **DeepGEMM MoE:** Introduces `compute_aligned_M_and_alignment`
(tighter padding, SM100/SM120 per-call BLOCK_M), threads `align_m`
through EP scatter, wraps grouped GEMMs in `mk_alignment_scope` to avoid
cudagraph IMAs, and extends FP4 DeepGEMM experts to **SM120**.
> 
> **MXFP4 / quant:** DeepGEMM MXFP4 uses
`deepgemm_post_process_weight_scale_block`; ROCm DSv4 prefers AITER
FlyDSL with native shuffle; E8M0/`uint8` scales decode before Triton FP8
matmul.
> 
> **Build:** `deepgemm.cmake` / `qutlass.cmake` use explicit
`FetchContent_Populate`, validate local `*_SRC_DIR`, skip re-fetch when
cached, and add **SM12x** to DeepGEMM arch lists.
> 
> **MHC:** Adds `hc_head_fused_kernel_tilelang` custom op wrapper.
> 
> <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit
2b9ea1c. Bugbot is set up for automated
code reviews on this repo. Configure
[here](https://www.cursor.com/dashboard/bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
7 participants