feat: add sm120 support for DeepGEMM #324
Conversation
|
@leavelet this is nice, after it is merged into the nv-dev branch, should vllm-project/vllm#41834 merge too in order for vllm to be able to work with the branch. |
|
I add my benchmark at vllm-project/vllm#41834 as references |
Replace the hand-written CUDA FP8 GEMV kernel (previously gated to tokens==1) with a port of the SM120 FP8 einsum kernel from upstream DeepGEMM's WIP SM120 support (deepseek-ai/DeepGEMM#324, file `deep_gemm/include/deep_gemm/impls/sm120_fp8_einsum.cuh`). The DeepGEMM kernel implements exactly the `bhr,hdr->bhd` einsum DeepSeek V4 needs, with per-thread-per-output-cell GEMV using fp8x4 vectorized loads and the same block-128 fp32 scale recipe. Removing the tokens==1 gate: the kernel handles all token counts that the SM12x dispatch predicate accepts (tokens <= 16 today; larger token batches will arrive once T1-α expands graph capture). Microbench (DSv4-Flash decode shape, groups=8, hidden=2048, out=1024, GPU idle): tokens=1 cuda 0.026ms triton 0.020ms speedup 0.72x (was 0.40x) tokens=2 cuda 0.027ms triton 0.020ms speedup 0.72x (was 0.73x*) tokens=8 cuda 0.075ms triton 0.020ms speedup 0.27x (was 0.21x*) * Triton-as-default after the previous tokens==1 hotfix. The kernel's grid is `tokens * groups * (out/128)`, one block per `(token, group, out_tile=128)` triple. Because each block reads its weight tile independently, total weight reads scale linearly with `num_tokens`. At graph bs=2 (today) this dominates: tokens<=2 is the production shape and the 0.72x is a real net win against the previous Triton-fallback default. At tokens=8 (future, post-T1-α) the kernel loses ~2x to Triton's m=16 cooperative tile; we will revisit with a multi-token tile design before T1-α exposes that shape to production. Earlier hand-written attempts (one-cell-per-block, per-thread B=16 accumulator tile, 1-warp m16n8 MMA, 4-warp m16n32 cooperative MMA, 4-warp m16n128 MMA) are documented in `docs/notes/2026-05-09-ds4-sm12x-rejected-experiments.md`. The MMA designs hit either occupancy collapse (80 regs/thread) or insufficient parallelism (64 blocks at decode shape vs Blackwell's 140 SMs), capping out at ~0.51x. The DeepGEMM design wins at the production shape by avoiding tensor cores entirely -- a per-thread GEMV with fp8x4 vectorization and L1/L2-friendly weight access fits the small-M decode profile better than the m=16 MMA tile. Attribution: kernel source ported under MIT license from upstream DeepGEMM (Copyright (c) 2025 DeepSeek). Tokenspeed adaptations are the tvm-ffi binding, stride/scale validation, and the SM12x dispatch integration; the dot-product math is unchanged. Signed-off-by: jasl <jasl9187@hotmail.com>
Upstream PR lightseekorg#93 added a pre-flight DeepGEMM ``fp8_gemm_nt`` call to ``DeepseekV4Attention._compute_qr_kv``: on success it replaces the reference FP8 linear path, on failure it logs a WARNING per layer and falls back. DeepGEMM does not support SM120/SM121 yet (see PR ``deepseek-ai/DeepGEMM#324`` + ``reference_deepgemm_sm120`` memory), so on the RTX Pro 6000 workstation every layer fires: DeepSeek V4 DeepGEMM FP8 linear failed; falling back to reference FP8 linear. reason=RuntimeError: Assertion error (csrc/apis/layout.hpp:59): Unknown SF transformation The existing per-layer ``_deepseek_v4_deep_gemm_linear_disabled`` flag already catches this for steady-state replay, but it costs one failed call + one WARNING per layer at boot. Mirror the pattern used by ``_deepseek_v4_deepgemm_fp4_indexer_enabled_for_platform``: short- circuit ``_deepseek_v4_get_fp8_linear_deep_gemm`` to ``None`` on SM12x so the platform never tries the DeepGEMM path. Non-SM12x platforms keep the new fast path. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
|
nice work! may I ask the hardware for testing here is either 5090 or RTX6000pro? |
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Phase 1a: infrastructure + dense FP8 GEMM kernel for SM120a (CC 12.0). Architecture: warp-level mma.sync with block-scaled UE8M0 scale factors, B128 XOR swizzle, persistent scheduling, register-based epilogue. New files: - SM120 heuristics, JIT codegen, MMA PTX wrappers, ldmatrix/swizzle utils - CUDA kernel with warp-specialized TMA/math pipeline (3-9 stages) Modified files: - Arch detection, compiler flags (-gencode for SM120a) - API dispatch (arch_major == 12), SF layout transform - Default recipe for SM120 Correctness: 8/8 shapes pass (diff < 0.001 cosine distance) Performance: ~73 TFLOPS (baseline, optimization pending) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… only Drop the non-warp-specialized kernel path for SM120a (matching SM90/SM100 architecture), merging the warp-specialized implementation into the main sm120_fp8_fp4_gemm_1d1d kernel. Add FP4 GEMM support using packed SMEM with the mxf4nvf4 m16n8k64 MMA instruction. Key changes: - Consolidate: remove non-spec path, always BM=128/BK=128/384 threads - FP4: packed 4-bit SMEM (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B), standard ldmatrix, uint16_t scale factors (scale_vec::2X), kKSteps=2 vs FP8's 4 - Heuristic: simplified to warp-spec only, correct SMEM sizing for FP4 - API: enable FP4 on SM120a (arch_major==12), add fp8_fp4_gemm_nt binding - Fix SF hoist bug: hoist SFA/SFB independently for mixed gran_k configs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Kernel: Add TMA descriptor runtime update (tensormap.replace) in producer loop for K-grouped group transitions, fix SF_K_ALIGNMENT to kGranKA*4, fix SMEM layout (pipeline data at offset 0 for B128 swizzle alignment, tensor map descriptors at end), fix epilogue bounds for multi-group output. - MMA wrappers: Replace CUTLASS mma_sm120.hpp dependency with custom inline asm using "+f" read-write constraints for accumulator registers. Eliminates CUTLASS header dependency and gives explicit control over MMA operand encoding for both FP8 (m16n8k32) and FP4 (m16n8k64) block-scaled MMA. - JIT launcher: Add sm120_k_grouped_fp8_fp4_gemm_1d1d() with proper TMA descriptor creation (first_k base, FP4-aware stride), SF TMA covering concatenated groups, CD TMA with num_groups outer dimension. - API dispatch: Add arch_major==12 path in k_grouped_fp8_gemm_nt_contiguous, relax recipe assertion to support gran_k=32/128, add SM120 SF layout transform with auto-detection of transposed K-major scale factors. - Tests: Add dedicated SM120 K-grouped test (7 configs including zero-K edge case), fix K-major selection for SM120 in generators, fix test dispatch for SM120 in test_fp8_fp4.py, update FP4 test with perf comparison. Tested: Dense FP8 8/8, Dense FP4 10/10, K-grouped FP8 7/7 — all PASS. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Einsum support: - Add GemmType::Batched to FP8/FP4 and BF16 kernels with 3D TMA load/store - Add IndexType::SF_K for batched SF coordinate computation - Add MN-major B support to BF16 kernel (scalar SMEM loads, single-atom constraint) - BF16 bhr,hdr->bhd: 384 TFLOPS, FP8: 681 TFLOPS (batch=8, b=8192) M-grouped BF16: - Add contiguous and masked M-grouped BF16 GEMM launchers HC prenorm TF32: - New fused GEMM + sqr_sum kernel using mma.sync.m16n8k8 TF32 (226T peak) - BF16 A -> FP32 cast with fused sqr_sum accumulation - Atom-aware FP32 B fragment loading from K-major SMEM - Split-K support for large K / small M shapes - 24/24 test shapes PASS, ~1.1 TB/s bandwidth on large shapes Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Dense (ragged): 651 TFLOPS peak (80% of FP8 MMA peak 814T), 40/40 tests pass. Paged (KV cache): 320 TFLOPS peak, 1.36 TB/s DRAM (91% of HBM BW), 8/8 tests pass. Kernel design: warp-specialized mma.sync m16n8k32 FP8 (no block_scale). 8 math warps × 16 KV rows each = 128 BLOCK_KV. In-warp 2-shfl reduction across 4 threads (lane%4) — only ~10 cycles, negligible vs MMA time. Global stores are fire-and-forget on SM120a, so no epilogue warps needed. Key parameters: block_qh=128, num_heads=64, head_dim=128, 2 Q stages, 3 KV stages, 84KB SMEM (83% of 101KB capacity). Paged variant: 2 groups of 4 warps, SPLIT_KV=128, per-group KV pipeline. Fixed metadata split_kv mismatch and register budget overflow (TMA regs 64→40 to stay within 65536 register limit). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- skip_head_mid: Add SM120 dispatch in attention.hpp for EpilogueHeadSplits. Fix three issues: TMA CD descriptor uses d.size(-1)/d.stride(-2) instead of n; kernel uses stride_d parameter for D row stride and bounds checks; TMA store coordinates apply epilogue N-index remapping. - MN-major B: Fix kernel TMA coordinate for M-grouped BF16 with MN-major B. Group offset moves to outer=K coordinate (not inner=N) when kBKMajor=false. - FP8 kernel stride_d: Add stride_d parameter to decouple D tensor stride from computation dimension n, enabling epilogue transforms that expand N.
Replace per-element scalar SMEM loads in the MN-major B path with ldmatrix.sync.aligned.x2.m8n8.trans.shared.b16 which natively loads column-major 8x8 BF16 matrices — directly producing MMA B fragments. Performance: MN-major B improves from ~220T to ~290T (dense) and ~330T (M-grouped), a 30-50% gain. Remaining gap vs K-major (400T) is due to heuristic selecting BLOCK_N=32/64 vs 128 (single swizzle atom constraint). Verified by micro benchmark b_bf16_4: fragment layout 32x32 lanes PASS, MMA pipeline 4 K-steps accumulation PASS.
Remove single-atom BLOCK_N constraint for MN-major B. ldmatrix.trans correctly handles multi-atom SMEM (verified by micro benchmark b_bf16_5). MN-major B now achieves 99-102% of K-major performance (was 53% with scalar loads, 80% with single-atom ldmatrix.trans).
…bhr->hdr New BF16 bmk,bnk->mn reduction kernel with split-S atomicAdd to FP32 output (188T peak, HBM BW limited). FP8 einsum dispatch for bhd,hdr->bhr and bhd,bhr->hdr via .contiguous() to K-major. Fix batched epilogue stride formula: replace single stride_d with stride_cd_m + stride_cd_batch to support arbitrary D layouts ([batch,M,N] vs [M,batch,N]). Add kBKMajor template parameter to FP8 kernel with verified scalar-load MN-major B path (correct but 3x slower than K-major ldmatrix, kept for future optimization).
K-grouped TN: single .t().contiguous() transpose with constant-stride TMA (kKGroupedConstantStride) — per-group only replaces addr+dim, not stride. TN achieves 99-101% of NT performance. New PTX tensor_map_replace_global_dim_in_smem. Paged MQA varlen: fix 4 kernel bugs in sm120_fp8_paged_mqa_logits.cuh: - TMA Q coordinate: use atom_to_token_idx() instead of hardcoded *kNextNAtom - Prefetch advance: use get_atom_advance() instead of hardcoded +1 - Math loop: conditional iteration count via is_paired_atom for unpaired atoms - KV block idx: reset kv_block_idx_ptr=32 on q_atom change
New kernels using mma.sync m16n8k64 block-scaled FP4 (mxf4nvf4, scale_vec::2X). Architecture: 8 math warps + 4 TMA warps, B64 swizzle, kKSteps=2. Block-scaled MMA folds UE8M0 SF into computation — no post-MMA scale. Dense FP4 MQA: 1022 TFLOPS peak (63% FP4 peak), 1.6x vs FP8. Paged FP4 MQA: 707 TFLOPS peak (varlen), 566 TFLOPS (non-varlen next_n=2).
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
…r changes The SM120 MQA-logits work modified the shared launchers in two ways that broke SM100 (and were latent on SM90), surfacing as test_attention failures on B300 (baseline=origin/nv_dev passes, sm120 HEAD fails). No SM100/SM90 device kernel (.cuh) changed, so this is purely a host launch-config regression. 1. Dense MQA SMEM under-allocation (smxx_fp8_fp4_mqa_logits.hpp). The SMEM size dropped the `(num_math_threads / 128) * 2` mbarrier pairs that the SM90/SM100 kernels still allocate (the SM120 kernel does not). On SM100 this under-sized the dynamic SMEM; compute-sanitizer reports an "Invalid __shared__ write" in sm100_fp8_mqa_logits, faulting (CUDA_ERROR_ILLEGAL_ADDRESS) at large configs such as seq_len_kv=130560. Restore the term for arch != 12. 2. Paged-MQA metadata capacity assert (smxx_fp8_fp4_paged_mqa_logits.hpp). An unconditional `smem_size <= SM120ArchSpec::smem_capacity` (99 KB) was added to the metadata launcher, which runs on every arch. On SM90/SM100 (228 KB) a large (varlen) batch's metadata SMEM legitimately exceeds 99 KB and falsely tripped the assert. Gate the capacity check to the running arch. Verified on B300 (sm100): compute-sanitizer memcheck clean on the previously out-of-bounds config (seq_len=510, seq_len_kv=130560); full test_attention (dense + paged, NextN 1-6, FP8/FP4) passes. Fix (1) also corrects a latent SMEM under-allocation on SM90.
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Strip changes that should not ship upstream and align comment style with the existing codebase (terse, no decorative banners). No change to production GEMM/attention behavior — only removes a dev-only bench path and trims comments. - .gitignore: revert to upstream; personal ignores (docs_internal/, benchmarks/, ncu_reports/, internal/) moved to local .git/info/exclude. - Remove dev-only sm120_fp8_gemm_bench API and its override_layout tile-override param (gemm.hpp, __init__.py, sm120_fp8_fp4_gemm_1d1d.hpp); config selection collapses to the standard get_best_config path. - Trim changelog-/rationale-style block comments to single lines (shared rationale lives in commit history): MQA-logits SMEM, paged-MQA capacity assert, cd_n_contiguous, split-K SF alignment, latency model, einsum MN-major notes. - De-duplicate the AB-swap "transposed output" rationale (kept once in config.hpp) and the UE8M0 SF-tile note. - Remove decorative // ==== banner comments from sm120 .cuh (no other kernel uses them). - Remove an unused num_waves local in the SM120 latency model (clears a -Wunused-variable warning). - Remove tests/test_split_k_swap.py (standalone dev harness, not pytest-style).
|
@RayWang96 I have cleaned up the code and verified them on sm90, sm100 and sm120, the PR is ready for merge. Thanks! |
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
…pGEMM@sm120) Enable DeepGEMM grouped FP8×FP4 GEMM for MoE on SM120 (RTX 6000D/PRO 6000). Requires leavelet/DeepGEMM@sm120 branch (deepseek-ai/DeepGEMM#324, not yet merged). Changes: - configurer.py: Allow SM120 only when SM120-compatible DeepGEMM is installed (checks for m_grouped_fp8_fp4_gemm_nt_contiguous availability) - server_args.py: Auto-select deep_gemm MoE backend on SM120 - kernels.py: Add UE8M0 (power-of-2) FP8 quantization Triton kernel required by DeepGEMM's block-scaled dequantization on SM120 - deep_gemm.py: SM120 adaptations for DeepGEMM MoE runner: - TMA-aligned scale factors for grouped GEMM (hidden_states + down_input) - JIT EP activation fallback when hidden_dim/8 < num_experts (TP>=2) - In-place swiglu clamp replacing torch.chunk+cat (-7.4ms/step decode) - fp8.py: Add .contiguous() before transform_sf_into_required_layout - fp8_utils.py: Skip DeepGEMM dense FP8 linear on SM120 (bf16_gemm_nt unsupported) Performance (TP=4, BS=1, RTX 6000D 85GB, ISL=8K): TTFT: 130ms (vs 400ms marlin, 3x faster) Decode ITL: 47ms (vs 41ms marlin, 15% slower) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
…pGEMM@sm120) Enable DeepGEMM grouped FP8×FP4 GEMM for MoE on SM120 (RTX 6000D/PRO 6000). Requires leavelet/DeepGEMM@sm120 branch (deepseek-ai/DeepGEMM#324, not yet merged). Changes: - configurer.py: Allow SM120 only when SM120-compatible DeepGEMM is installed (checks for m_grouped_fp8_fp4_gemm_nt_contiguous availability) - server_args.py: Auto-select deep_gemm MoE backend on SM120 - kernels.py: Add UE8M0 (power-of-2) FP8 quantization Triton kernel required by DeepGEMM's block-scaled dequantization on SM120 - deep_gemm.py: SM120 adaptations for DeepGEMM MoE runner: - TMA-aligned scale factors for grouped GEMM (hidden_states + down_input) - JIT EP activation fallback when hidden_dim/8 < num_experts (TP>=2) - In-place swiglu clamp replacing torch.chunk+cat (-7.4ms/step decode) - fp8.py: Add .contiguous() before transform_sf_into_required_layout - fp8_utils.py: Skip DeepGEMM dense FP8 linear on SM120 (bf16_gemm_nt unsupported) Performance (TP=4, BS=1, RTX 6000D 85GB, ISL=8K): TTFT: 130ms (vs 400ms marlin, 3x faster) Decode ITL: 47ms (vs 41ms marlin, 15% slower) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The SM120 fp8/fp4 1D1D kernel iterated every ceil(M_sum/BLOCK_M) m-block and computed invalid (m_indices == -1) padding tiles against expert 0. SM90 skips these via Scheduler::is_computation_valid, but that helper is marked SM90-only and the SM120 producer issued the A/B/SF TMA copies unconditionally. At MoE decode the contiguous worst-case M_sum reserves a block per local expert (min(M*topk, local_experts)) while only a few are routed, so most m-blocks are empty padding. Computing them wastes a full-width GEMM tile — the dominant cost of EP decode. Measured on RTX PRO 6000, EP gate_up (N=4096, K=4096, 64 local experts, ~20 routed): worst-case 20-valid + 45-empty of 65 blocks = 304us vs the 20 real blocks alone = 170us; the empty tiles cost 134us (44%). Add the m_indices < 0 skip-continue to both the producer and consumer while-loops. The check is identical in both, so no barrier ops are issued for skipped blocks and the warp-specialized pipeline stays in sync. Empty experts now cost nothing: EP decode MoE GEMM drops ~1.8x (304 -> 169us), matching the TP-sharded layout. The all-valid path (prefill/dense) is unchanged. Validated: m-grouped contiguous correctness for fp8xfp8 / fp4xfp4 / fp8xfp4 (diff < 2%), worst-case-padded result == tight, all-65-valid timing unchanged.
Adds the reverse of the existing fp8-A x fp4-B mixed path: A is fp4 (e2m1), B is fp8 (e4m3), via mxf8f6f4 .e2m1.e4m3 (m16n8k32). Lets the expert weight sit on the well-filled M axis and the (few) decode tokens on a small N axis. - mma/sm120.cuh: fp4_fp8_mixed_mma_block_scaled (.e2m1.e4m3). - common/sm120_utils.cuh: ldmatrix_m8n16_x4_b4x16_p64 + load_a_fragment_b4x16 (4-reg b4x16 unpack for the fp4 A operand; same addressing as the fp8 A load). - impls/sm120_fp8_fp4_gemm_1d1d.cuh: kAIsFP4 template flag; A loaded as fp4-unpacked (b4x16 + <<2), TMA_A_BYTES = SMEM_A/2, MMA dispatch, perNTileX4 disabled for kAIsFP4. - jit_kernels/impls/sm120_fp8_fp4_gemm_1d1d.hpp: detect A=fp4 & B=fp8 (a_is_fp4); is_fp4 redefined to symmetric fp4xfp4 (value-identical for existing cases); emit kAIsFP4. - heuristics/sm120.hpp: a_padded_fp4 -> the fp4 A operand uses unpacked b4x16 SMEM (row = block_k, swizzle 128), mirroring the fp4 B handling. - tests/generators.py: add the FP4_A x FP8_B QuantConfig (32,128,True,False) to the SM120 sweep; the FP8 operand of a mixed config is 1D-scaled (per-token, recipe (1, gran)), not per-block. Validated on RTX PRO 6000: full fp8_fp4 sweep (normal / m-grouped contiguous / m-grouped masked, all dtype configs incl. the new FP4_A x FP8_B) passes (diff < max_diff, ~0.007 for fp4xfp8). Existing paths unchanged (a_is_fp4 defaults false). SM100 untouched. Note: for MoE decode this orientation is bandwidth-saturated (~92% DRAM) and only ~3-6% faster than the standard layout; the empty-tile skip is the real EP-decode lever. The K-grouped launcher A-stride is not yet wired for fp4-A.
|
@lucifer1004 @AliceChenyy cc @samuellees Please switch to commit 33a715e for better MoE EP performance, especially for decode. The commit bump alone gives a 2x kernel speedup, and swapAB is an optional add-on worth roughly 6% |
| z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16) | ||
| deep_gemm.einsum('bhr,hdr->bhd', x, y, z) | ||
| assert calc_diff(z, ref_z) < 1e-10 | ||
| assert calc_diff(z, ref_z) < 1e-7 |
There was a problem hiding this comment.
What's the justification for the relaxation? Is it driven by SM120 not passing at 1e-10? If so, what makes SM120's einsum less accurate here.
There was a problem hiding this comment.
The test compares DeepGEMM's bf16 output against cuBLAS's; both use identical HMMA.16816.F32.BF16 + a single bf16 cast, so they differ only in FP32 K-reduction order. At small M, cuBLAS's nvjet kernel sums 3 partial accumulators (vs DeepGEMM's single linear accumulator), which is marginally more accurate and makes ~0.2% of low-magnitude outputs differ, exceeding the near-bit-exact 1e-10 (lands ~8e-9). Against an FP64 reference DeepGEMM and cuBLAS are equally accurate (~1.4e-6, ratio 1.00). I will propose a fix, keeps 1e-10for SM90/SM100, loosens to1e-7 only on SM120, and adds an FP32-reference correctness check there (< 1e-5). Does this sound correct?
Behavior-preserving cleanups from upstream review. gemm.hpp: pull the SM120 arm of fp8_fp4_gemm_nt into a dedicated fp8_fp4_gemm_nt_sm120() helper and dispatch in natural arch order (9/10 share the SF-transform path, then 12). Add a shared sm120_to_k_major() helper for the "repack packed-FP4 / contiguous to force K-major" idiom, reused by the FP8/FP4 NT GEMM, the m-grouped contiguous arm, and the BF16 GEMM. runtime.hpp: replace the coupled arch==12 ternaries in get_theoretical_mk_alignment_for_contiguous_layout() with a per-arch ContiguousMKAlignment policy plus a single shrink loop. SM90/SM100/SM120 results are unchanged; verified on SM120 with test_fp8_fp4.py and test_bf16.py.
The BF16 einsum tests compared DeepGEMM's bf16 output against cuBLAS's at 1e-10 (near bit-exact). On SM120 the two differ only in FP32 accumulation order (cuBLAS uses a multi-accumulator K-reduction at small M), so ~0.2% of low-magnitude outputs differ by 1 ULP and the comparison lands ~8e-9 -- not a DeepGEMM accuracy issue (both match the FP64 truth equally, ~1.4e-6). Keep the strict 1e-10-vs-cuBLAS bound for SM90/SM100, loosen to 1e-7 only on SM120, and add an architecture-independent FP32-reference correctness check on SM120 (calc_diff < 1e-5; DeepGEMM measures ~1.4e-6).
## Summary This PR backports the DeepSeek V4 FlashInfer sparse MLA / SM120 work from the community PR vllm-project#43477 onto `memorylake-ai/v0.22.1-dev`. Community reference: vllm-project#43477 This is not a pure cherry-pick. The majority of the changed files are overlaid from PR vllm-project#43477, but this branch also includes compatibility fixes needed to make that code build, import, serve, and run long-context DeepSeek V4 Flash on the current MemoryLake vLLM branch and the locally built dependency stack. ## What changed - Adds FlashInfer sparse MLA backend registration for SM120 / DeepSeek V4: - `FLASHINFER_MLA_SPARSE_SM120` - `FLASHINFER_MLA_SPARSE_DSV4` - Adds the DeepSeek V4 FlashInfer sparse attention implementation and shared sparse MLA helpers. - Adds FlashInfer sparse MLA warmup and autotune cache helpers. - Adds DeepSeek V4 sparse/compressor integration needed by the FlashInfer sparse path. - Ports the CUTE sparse compressor implementation used by the DSv4 path. - Updates CUDA platform/backend wiring so the new attention backend can be selected through the usual vLLM attention backend enum path. - Updates CMake external project handling for DeepGEMM and QuTLASS source overrides used by this local build flow. ## Compatibility fixes on top of PR vllm-project#43477 Several local adjustments were required because this branch is older than the community PR base and because the tested FlashInfer / DeepGEMM APIs differ from the PR branch assumptions: - FlashInfer sparse MLA symbols are imported from `flashinfer.mla` in the locally built FlashInfer source, rather than `flashinfer.decode`. - `MoERunner` import/API usage is adapted to the current `memorylake-ai/v0.22.1-dev` fused-MoE layout. - DeepGEMM MoE expert classes restore `supports_expert_map()` so they satisfy the current modular expert interface. - Triton FP8 block-scaled matmul decodes E8M0 / raw `uint8` scale tensors before kernel launch, avoiding `torch.float8_e8m0fnu` binding failures on this stack. - The FlashInfer sparse O-projection path prepares `wo_a` weights and scale tensors into the grouped layout expected by DeepGEMM `fp8_einsum`. - TileLang MHC wrappers are adapted to the current `_tilelang_ops` import path and custom-op registration shape. - FlashMLA import in the DeepSeek V4 NVIDIA model path is made lazy so the selected FlashInfer backend can import without requiring the unrelated FlashMLA overlay class. ## Tested stack Tested locally on `zp-nc522` with 8x RTX 6000D / SM120 using: ```bash vllm serve /models/deepseek-ai/DeepSeek-V4-Flash \ --served-model-name DeepSeek-V4-Flash \ --host 0.0.0.0 \ --port 8000 \ --tensor-parallel-size 8 \ --max-model-len 262144 \ --max-num-seqs 1 \ --max-num-batched-tokens 32768 \ --enable-chunked-prefill \ --gpu-memory-utilization 0.90 \ --attention-backend FLASHINFER_MLA_SPARSE_DSV4 \ --kv-cache-dtype fp8_ds_mla \ --linear-backend triton ``` The local dependency stack used the MemoryLake vLLM source build, local FlashInfer source build, local DeepGEMM source at the currently validated SM120 branch point, local FlashMLA, and local vLLM FlashAttention source. ## Validation - Ran Python syntax validation on all staged Python files: ```bash python -m py_compile $(git diff --cached --name-only -- '*.py') ``` - Built and imported the local editable vLLM package successfully, including compiled extension probes for: - `vllm._C` - `vllm._moe_C` - `vllm._C_stable_libtorch` - `vllm.vllm_flash_attn._vllm_fa2_C` - `vllm.vllm_flash_attn._vllm_fa3_C` - Verified runtime sparse MLA availability: - `AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4` - `AttentionBackendEnum.FLASHINFER_MLA_SPARSE_SM120` - `has_flashinfer_sparse_mla_sm120() == True` - `flashinfer.mla.trtllm_batch_decode_sparse_mla_dsv4` exists - Started `vllm serve` successfully with `FLASHINFER_MLA_SPARSE_DSV4`, `fp8_ds_mla`, TP=8, and `max_model_len=262144`. - Verified OpenAI-compatible API responses: - `/v1/models` returned `DeepSeek-V4-Flash` with `max_model_len=262144`. - `/v1/completions` returned HTTP 200 for a short prompt. - Verified long-context admission and streaming decode: - ~200K prompt tokens + `max_tokens=50000` admitted under the 262144 context limit. - First streaming SSE token arrived after 28.314s. - A repeated curl streaming request also returned HTTP 200 and streamed SSE chunks. - Ran 1-concurrency streaming benchmark with `stream=true`, `ignore_eos=true`, `temperature=0`, `max_tokens=20000`, and distinct prompt prefixes per case: | Case | TTFT | Total time | Output tokens | End-to-end output throughput | Decode throughput after first token | | --- | ---: | ---: | ---: | ---: | ---: | | 10K input + 20K output | 0.798s | 274.425s | 20,000 | 72.88 tok/s | 73.09 tok/s | | 100K input + 20K output | 15.549s | 295.159s | 20,000 | 67.76 tok/s | 71.52 tok/s | | 200K input + 20K output | 22.417s | 304.302s | 20,000 | 65.72 tok/s | 70.95 tok/s | All three benchmark requests returned HTTP 200 with `finish_reason=length` and final usage matching the requested prompt/completion token counts. ## Notes - The PR deliberately keeps the validated DeepGEMM branch point used by the successful local serve/benchmark run. PR deepseek-ai/DeepGEMM#324 has since moved forward, so a later performance pass may want to evaluate its newer head separately. - The live local API server also required an environment-level compatibility patch in `prometheus_fastapi_instrumentator` for the installed FastAPI/Starlette route shape. That patch is outside this vLLM repository and is not included here. <!-- CURSOR_SUMMARY --> --- > [!NOTE] > **High Risk** > Touches the core DSv4 sparse attention and KV-cache paths, DeepGEMM grouped-MoE alignment (prior IMA under cudagraph), and CMake external deps—high blast radius for SM120 long-context serving. > > **Overview** > Backports **FlashInfer sparse MLA** for DeepSeek V4 (`FLASHINFER_MLA_SPARSE_DSV4` / SM120 variants) onto the MemoryLake branch, with local compatibility fixes for CMake, DeepGEMM, MoE, and warmup. > > **Attention / DSv4:** Refactors V4 MLA into a shared `DeepseekV4Attention` base with platform subclasses. Adds FlashInfer TRTLLM-gen sparse decode (`flashinfer_sparse.py`), a Triton helper to build mixed decode/prefill sparse index matrices, and KV paths for **fp8_ds_mla** vs plain **bf16 / per-tensor fp8** rows (fused insert ops + compressor flags). Compressor/cache alignment only applies 576B padding for `fp8_ds_mla`. > > **Warmup:** Hooks **DSv4 mHC TileLang** warmup, **FlashInfer sparse MLA** mixed-batch warmup/autotune (shared cache helpers), and expands DeepGEMM grouped-GEMM warmup to cover multiple `(M_sum, align_used)` cases under `mk_alignment_scope`. > > **DeepGEMM MoE:** Introduces `compute_aligned_M_and_alignment` (tighter padding, SM100/SM120 per-call BLOCK_M), threads `align_m` through EP scatter, wraps grouped GEMMs in `mk_alignment_scope` to avoid cudagraph IMAs, and extends FP4 DeepGEMM experts to **SM120**. > > **MXFP4 / quant:** DeepGEMM MXFP4 uses `deepgemm_post_process_weight_scale_block`; ROCm DSv4 prefers AITER FlyDSL with native shuffle; E8M0/`uint8` scales decode before Triton FP8 matmul. > > **Build:** `deepgemm.cmake` / `qutlass.cmake` use explicit `FetchContent_Populate`, validate local `*_SRC_DIR`, skip re-fetch when cached, and add **SM12x** to DeepGEMM arch lists. > > **MHC:** Adds `hc_head_fused_kernel_tilelang` custom op wrapper. > > <sup>Reviewed by [Cursor Bugbot](https://cursor.com/bugbot) for commit 2b9ea1c. Bugbot is set up for automated code reviews on this repo. Configure [here](https://www.cursor.com/dashboard/bugbot).</sup> <!-- /CURSOR_SUMMARY -->
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
This PR adds first-class
sm_120support to DeepGEMM, maintained on thenv_devbranch by the NVIDIA DevTech APAC team. It brings the full DeepGEMM kernel surface — dense, grouped/MoE, einsum, hyper-connection, and MQA-logits — to the sm120 and sm121 devices like RTX Pro 6000 and DGX Spark.co-authored with @lucifer1004.
Feature surface
Dense GEMM (
sm120_fp8_fp4_gemm_1d1d,sm120_bf16_gemm)mxf4nvf4, gran_k=32)mxf8f6f4 .e4m3.e2m1) and FP4_A × FP8_B(
.e2m1.e4m3) — both directionsGrouped GEMM / MoE
Batched GEMM / Einsum (
sm120_bmk_bnk_mn)bhr,hdr->bhd) and MN-major B (bhd,hdr->bhr)bmk,bnk->mn); FP8 / BF16; small-M AB-swapHC Prenorm (
sm120_tf32_hc_prenorm_gemm)**MQA Logits ** (
sm120_fp8_mqa_logits,sm120_fp4_mqa_logits+ paged)paged (next_n / kPadOddN support)
Performance (RTX PRO 6000 Blackwell, single GPU,
tests/test_*.py)Roofline used: FP8 ≈ 814 TFLOPS, FP4 ≈ 1628 TFLOPS (2× FP8), BF16 ≈ 445 TFLOPS. GB/s is effective bandwidth (cache-inclusive for small/cached shapes).
Dense GEMM
Grouped GEMM (MoE)
Einsum
Peak 742 TFLOPS (
b=8192,h=8,r=4096,d=1024); up to 2893 GB/s (small-batch, cache-bound); Split-S path scales down for small problems.HC Prenorm
m=8192, n=24, k=28672, splits=16: 29 TFLOPS / 1248 GB/s (bandwidth-bound).MQA Logits