[SM120] Add FlashInfer sparse MLA decode for DSv4-Flash #27455
AliceChenyy wants to merge 18 commits
Integrate FlashInfer's SM120 sparse_mla_sm120 decode kernel for
DeepSeek-V4-Flash on RTX PRO 6000 (SM120). Decode ITL improves
1.84x vs Triton (59ms vs 109ms, ISL=8K OSL=1K BS=1).
Key challenge: SGLang SWA KV cache uses page_size=256 with footer
layout (scales at page end), but FlashInfer decode_dsv4 requires
page_block_size=64. Solution: page-split at call site — split
256-token pages into 4 virtual 64-token pages with correct footer
layout and remap indices before calling FlashInfer.
- Auto-detect FlashInfer SM120 via try-import, fallback to Triton
- Page-split + index remap in _flash_mla_flashinfer()
- Uses BatchSparseMLAPagedAttentionWrapper.run() for decode dispatch
- Extra cache (C4/C128) passed as-is (native dual-cache support)
- GSM8K 10/10 = 100%, CUDA graph BS=1,2,4,8 all pass
- Env var SGLANG_SM120_FLASHMLA_BACKEND={auto,flashinfer,triton}
Requires: flashinfer >= 0.6.12 with sparse-mla-sm120 support
(lucifer1004/flashinfer@sparse-mla-sm120, PR flashinfer-ai/flashinfer#3395)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Code Review
This pull request introduces support for the FlashInfer SM120 sparse MLA backend, enabling page-split utilities to convert 256-token pages into 64-token pages for FlashInfer's decode fast path. It also updates the DeepSeek-V4 backend to dynamically retrieve the SWA page size and relaxes the window size assertion. Feedback highlights several optimization and robustness opportunities: reusing global scratch buffers to avoid allocating temporary workspace tensors on every decode step, eliminating the redundant _remap_indices_to_64 function which mathematically simplifies to an identity function, and adding a device check to the lazy allocation of _split_buf to ensure robustness during multi-GPU execution.
| # Lazily initialized FlashInfer wrapper (reused across calls). | ||
| _flashinfer_wrapper = None |
Define global scratch buffers to avoid allocating temporary workspace tensors (out_lse, mid_out, mid_lse) on every single decode step.
| # Lazily initialized FlashInfer wrapper (reused across calls). | |
| _flashinfer_wrapper = None | |
| # Lazily initialized FlashInfer wrapper (reused across calls). | |
| _flashinfer_wrapper = None | |
| _out_lse_buf = None | |
| _mid_out_buf = None | |
| _mid_lse_buf = None |
These tensors are allocated during CUDA graph capture and reused on replay — PyTorch's caching allocator returns the same memory block each time. Pre-allocating global buffers would add device/shape management complexity for no measurable gain in the CUDA graph hot path. Will revisit if profiling shows allocation overhead in non-graph paths.
| def _remap_indices_to_64(idx: torch.Tensor, src_pbs: int) -> torch.Tensor: | ||
| """Remap token indices from pbs=N addressing to pbs=64 addressing. | ||
|
|
||
| Invalid indices (-1) are preserved. | ||
| """ | ||
| if src_pbs == _PBS_DST: | ||
| return idx | ||
| ratio = src_pbs // _PBS_DST | ||
| valid_mask = idx >= 0 | ||
| page_src = torch.where(valid_mask, idx // src_pbs, torch.zeros_like(idx)) | ||
| offset_src = torch.where(valid_mask, idx % src_pbs, torch.zeros_like(idx)) | ||
| sub_page = offset_src // _PBS_DST | ||
| offset64 = offset_src % _PBS_DST | ||
| fi_idx = (page_src * ratio + sub_page) * _PBS_DST + offset64 | ||
| return torch.where(valid_mask, fi_idx, idx) |
The index remapping function _remap_indices_to_64 is mathematically an identity function for all valid indices (idx >= 0). Specifically, the formula:
(idx // src_pbs * ratio + (idx % src_pbs) // _PBS_DST) * _PBS_DST + (idx % src_pbs) % _PBS_DST
simplifies exactly to idx because ratio * _PBS_DST = src_pbs. Therefore, we can completely eliminate this function and the associated GPU overhead.
# Note: _remap_indices_to_64 is mathematically an identity function because
# (idx // src_pbs * ratio + (idx % src_pbs) // 64) * 64 + (idx % src_pbs) % 64
# simplifies exactly to idx. Therefore, we can omit this remapping entirely.
Removed _remap_indices_to_64 entirely in 7bdc051.
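The identity claim in this thread can be checked with a small scalar model. The sketch below is a plain-Python re-statement of the quoted tensor code, not the PR's implementation; the function names are made up for illustration, and it assumes src_pbs is a multiple of 64, which is the only case the thread discusses.

```python
PBS_DST = 64  # destination page block size, as in the quoted diff


def remap_index_to_64(idx: int, src_pbs: int) -> int:
    """Scalar model of the removed remap (hypothetical helper name)."""
    if idx < 0:
        return idx  # invalid indices (-1) are preserved
    ratio = src_pbs // PBS_DST
    page_src, offset_src = divmod(idx, src_pbs)
    sub_page, offset64 = divmod(offset_src, PBS_DST)
    return (page_src * ratio + sub_page) * PBS_DST + offset64


def is_identity(src_pbs: int, upto: int = 4096) -> bool:
    """Brute-force check over -1 and the first `upto` token indices."""
    return all(remap_index_to_64(i, src_pbs) == i for i in range(-1, upto))


if __name__ == "__main__":
    print(is_identity(256), is_identity(128))
```

The algebra behind it: with idx = p * src_pbs + o and src_pbs = ratio * 64, the remapped value is p * ratio * 64 + (o // 64) * 64 + o % 64 = p * src_pbs + o. If src_pbs were not a multiple of 64, the remap would no longer be an identity.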
| # --- Index remapping: pbs=N token index -> pbs=64 token index --- | ||
| idx = indices.squeeze(1) if indices.dim() == 3 else indices | ||
| idx_64 = _remap_indices_to_64(idx, src_pbs) if src_pbs != _PBS_DST else idx |
Since _remap_indices_to_64 is an identity function, we can directly use idx without any remapping, avoiding redundant GPU operations in the hot path of decode.
| # --- Index remapping: pbs=N token index -> pbs=64 token index --- | |
| idx = indices.squeeze(1) if indices.dim() == 3 else indices | |
| idx_64 = _remap_indices_to_64(idx, src_pbs) if src_pbs != _PBS_DST else idx | |
| # --- Index remapping: pbs=N token index -> pbs=64 token index --- | |
| idx_64 = indices.squeeze(1) if indices.dim() == 3 else indices |
Done in 7bdc051. Using idx directly, no remapping.
| if _split_buf is None or _split_buf.shape[0] < num_dst_pages: | ||
| _split_buf = torch.empty( | ||
| num_dst_pages, _BYTES_PER_DST_PAGE_PADDED, | ||
| dtype=torch.uint8, device=kv_u8.device, | ||
| ) |
To ensure robustness across different GPU devices (e.g., during multi-GPU execution or unit testing), add a device check to the lazy allocation of _split_buf so it is reallocated if the device of kv_u8 changes.
| if _split_buf is None or _split_buf.shape[0] < num_dst_pages: | |
| _split_buf = torch.empty( | |
| num_dst_pages, _BYTES_PER_DST_PAGE_PADDED, | |
| dtype=torch.uint8, device=kv_u8.device, | |
| ) | |
| if _split_buf is None or _split_buf.shape[0] < num_dst_pages or _split_buf.device != kv_u8.device: | |
| _split_buf = torch.empty( | |
| num_dst_pages, _BYTES_PER_DST_PAGE_PADDED, | |
| dtype=torch.uint8, device=kv_u8.device, | |
| ) |
Already addressed in commit 042ffb6 — _split_buf is a per-device dict ({} # device -> tensor), so each GPU gets its own buffer.
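A per-device, grow-only buffer cache of the kind described in this reply could look roughly like the sketch below. It is an illustration only: the class name is hypothetical, devices are modelled as strings, and a bytearray stands in for the uint8 tensor so the example runs without PyTorch.

```python
from typing import Callable, Dict


class GrowOnlyBufferCache:
    """Keeps one scratch buffer per device and reallocates only to grow."""

    def __init__(self, alloc: Callable[[int, str], bytearray]):
        self._alloc = alloc  # stand-in for torch.empty(..., device=...)
        self._bufs: Dict[str, bytearray] = {}  # device -> buffer

    def get(self, device: str, min_size: int) -> bytearray:
        buf = self._bufs.get(device)
        if buf is None or len(buf) < min_size:
            # First use on this device, or the cached buffer is too small.
            buf = self._alloc(min_size, device)
            self._bufs[device] = buf
        return buf


if __name__ == "__main__":
    cache = GrowOnlyBufferCache(lambda n, dev: bytearray(n))
    first = cache.get("cuda:0", 8)
    print(cache.get("cuda:0", 4) is first)  # smaller request reuses the buffer
```

Keying the cache by device removes the need for an explicit device comparison in the allocation condition, since a lookup for a different device simply misses.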
- Fix incorrect comment (swa_page_size=64 → describe page-split conversion)
- Add backward compat for legacy SGLANG_SM120_TRITON_FLASHMLA env var
- Restore one-arg-per-line formatting for triton/torch call paths
- Update docstring to reflect three-backend dispatch
- Use per-device dict for _flashinfer_wrapper and _split_buf (TP>1 safety)
- Replace 8 separate copy kernels in _split_kv_pages_to_64 with a fused Triton kernel (_page_split_kernel), reducing 344 kernel launches per decode step to 1 per layer.
Verified 9.3% ITL improvement on RTX 6000D TP=4 (45.00ms → 40.83ms).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
| # Controlled by SGLANG_SM120_FLASHMLA_BACKEND env var (auto/flashinfer/triton/torch). | ||
| # Legacy: SGLANG_SM120_TRITON_FLASHMLA=0 maps to "torch" for backward compat. | ||
| def _detect_sm120_backend(): | ||
| env = os.environ.get("SGLANG_SM120_FLASHMLA_BACKEND", "auto") |
Use the envs classes that are pre-existing
Simplified in 7bdc051. Removed _detect_sm120_backend() entirely — now just one line: os.environ.get("SGLANG_SM120_FLASHMLA_BACKEND", "flashinfer").
What I mean is, envs.get_env, or something like that.
| return "triton" | ||
|
|
||
|
|
||
| _sm120_default_backend = _detect_sm120_backend() | ||
|
|
||
|
|
||
| def flash_mla_with_kvcache_sm120(**kwargs): |
Should we still call this "flashmla"
How about change to flashinfer_sparse_mla?
Sure, that sounds better to me
…pGEMM@sm120)
Enable DeepGEMM grouped FP8×FP4 GEMM for MoE on SM120 (RTX 6000D/PRO 6000). Requires leavelet/DeepGEMM@sm120 branch (deepseek-ai/DeepGEMM#324, not yet merged).
Changes:
- configurer.py: Allow SM120 only when SM120-compatible DeepGEMM is installed (checks for m_grouped_fp8_fp4_gemm_nt_contiguous availability)
- server_args.py: Auto-select deep_gemm MoE backend on SM120
- kernels.py: Add UE8M0 (power-of-2) FP8 quantization Triton kernel required by DeepGEMM's block-scaled dequantization on SM120
- deep_gemm.py: SM120 adaptations for DeepGEMM MoE runner:
  - TMA-aligned scale factors for grouped GEMM (hidden_states + down_input)
  - JIT EP activation fallback when hidden_dim/8 < num_experts (TP>=2)
  - In-place swiglu clamp replacing torch.chunk+cat (-7.4ms/step decode)
- fp8.py: Add .contiguous() before transform_sf_into_required_layout
- fp8_utils.py: Skip DeepGEMM dense FP8 linear on SM120 (bf16_gemm_nt unsupported)
Performance (TP=4, BS=1, RTX 6000D 85GB, ISL=8K):
- TTFT: 130ms (vs 400ms marlin, 3x faster)
- Decode ITL: 47ms (vs 41ms marlin, 15% slower)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…emap
- Default to FlashInfer directly instead of auto-detect via try-import (reviewer: FlashInfer is pinned, should import directly)
- Remove _remap_indices_to_64: mathematically identity function (page256*256+off == (page256*4+off//64)*64+off%64 for all valid idx)
- Simplify env var override to single os.environ.get with "flashinfer" default
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…entity remap" This reverts commit 4fd93c4.
…ff format
- Default to FlashInfer directly instead of auto-detect via try-import (reviewer: FlashInfer is pinned, should import directly)
- Remove _remap_indices_to_64: mathematically identity function (page256*256+off == (page256*4+off//64)*64+off%64 for all valid idx)
- Simplify env var override to single os.environ.get with "flashinfer" default
- Apply ruff formatter
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Thank you @b8zhong! I will address these comments ASAP.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… overhead
Replace per-call torch.empty for out_lse/mid_out/mid_lse with lazily allocated scratch buffers that grow-only and persist across decode steps.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…llocator overhead" This reverts commit cfb7733.
| _PBS_DST * _NOPE_ROPE_STRIDE + _PBS_DST * _SCALE_STRIDE | ||
| ) # 64*576 + 64*8 = 36864 + 512 = 37376 | ||
| # Padded to 576 alignment | ||
| import math as _math |
No need to make it a private import like this
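For reference, the page-size arithmetic in the snippet above can be worked through directly. The sketch assumes the strides implied by the inline comment (576 bytes of NoPE+RoPE data and 8 bytes of scales per token) and assumes that "padded to 576 alignment" means rounding up to the next multiple of 576; the constant names are illustrative, not the PR's.

```python
import math

PBS_DST = 64            # tokens per destination page
NOPE_ROPE_STRIDE = 576  # data bytes per token (assumed from the comment)
SCALE_STRIDE = 8        # scale bytes per token (assumed from the comment)

# Data region followed by the scale footer.
bytes_per_dst_page = PBS_DST * NOPE_ROPE_STRIDE + PBS_DST * SCALE_STRIDE

# Round up to the next multiple of the data stride (assumed padding rule).
bytes_per_dst_page_padded = (
    math.ceil(bytes_per_dst_page / NOPE_ROPE_STRIDE) * NOPE_ROPE_STRIDE
)

if __name__ == "__main__":
    print(bytes_per_dst_page, bytes_per_dst_page_padded)
```

With these values the data region is 64 * 576 = 36864 bytes and the footer is 64 * 8 = 512 bytes, so the unpadded page is 37376 bytes and the padded page is 65 * 576 = 37440 bytes.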
|
|
||
| # Copy data region: DATA_PER_SUB bytes from src offset sub*DATA_PER_SUB | ||
| data_src_off = sub * DATA_PER_SUB | ||
| for start in range(0, DATA_PER_SUB, BLOCK_SIZE): |
Can we use tl.range? In both loops
@AliceChenyy Does this work on the RTX 5090, which is also SM120? Does it require NVLink? I launched it on an RTX 5090 with the RTX 6090 parameter config and it does not work (issue 28299).
…kernel
Address b8zhong review comments:
1. Remove private import alias: `import math as _math` → `import math`
2. Use `tl.range` instead of `range` in Triton kernel loops for better compiler optimization hints.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add test_flashinfer_backend_matches_triton to TestEntryPointDispatch.
Validates FlashInfer SM120 sparse MLA decode output matches Triton reference (skips if flashinfer.sparse_mla_sm120 not available).
Registered on base-b, 1-gpu-large (existing test file CI registration).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Move `import math` to file top (E402 fix)
- Replace try/import with importlib.util.find_spec in test (F401 fix)
- Apply ruff format
Note: pre-existing E402 for `import triton` (line 288-289) is not addressed here as it predates this PR.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
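The find_spec-based availability check mentioned in this commit might be written along the following lines. This is a sketch rather than the PR's test code: the helper and test class names are hypothetical, and only the module name flashinfer.sparse_mla_sm120 comes from the commit messages above.

```python
import importlib.util
import unittest


def has_module(name: str) -> bool:
    """Return True if `name` can be located without binding it to a name."""
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        # For a dotted name, find_spec imports the parent package first and
        # raises ModuleNotFoundError (an ImportError) if the parent is missing.
        return False


@unittest.skipUnless(
    has_module("flashinfer.sparse_mla_sm120"),
    "flashinfer.sparse_mla_sm120 not available",
)
class ExampleBackendTest(unittest.TestCase):  # hypothetical test class
    def test_placeholder(self):
        self.assertTrue(True)


if __name__ == "__main__":
    print(has_module("json"))
```

Because nothing is bound by an import statement, there is no unused name for the linter to flag, which is the F401 issue the commit refers to.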
| ) | ||
| # SM120 FlashMLA: default FlashInfer (CUTLASS SM120 sparse MLA decode). | ||
| # Override with SGLANG_SM120_FLASHMLA_BACKEND=triton|torch to force fallback. | ||
| _sm120_default_backend = os.environ.get("SGLANG_SM120_FLASHMLA_BACKEND", "flashinfer") |
Please use the envs module instead of os.environ
Done in e9b4eef. Using envs.SGLANG_SM120_FLASHMLA_BACKEND.get().
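Whichever accessor ends up reading the variable, the resolution logic discussed across these threads (a "flashinfer" default, triton/torch overrides, and the legacy SGLANG_SM120_TRITON_FLASHMLA=0 mapping to "torch") can be sketched as a pure function over a mapping. The function name, the precedence between the two variables, and the error on unknown values are assumptions made for illustration, not the PR's behavior.

```python
import os
from typing import Mapping

VALID_BACKENDS = ("flashinfer", "triton", "torch")


def resolve_sm120_backend(environ: Mapping[str, str] = os.environ) -> str:
    """Hypothetical resolver; takes a mapping so it is easy to test."""
    value = environ.get("SGLANG_SM120_FLASHMLA_BACKEND")
    if value is None:
        # Legacy switch: "0" used to mean "do not use Triton", i.e. torch.
        if environ.get("SGLANG_SM120_TRITON_FLASHMLA") == "0":
            return "torch"
        return "flashinfer"
    value = value.strip().lower()
    if value not in VALID_BACKENDS:
        raise ValueError(
            f"unknown backend {value!r}; expected one of {VALID_BACKENDS}"
        )
    return value


if __name__ == "__main__":
    print(resolve_sm120_backend({}))
```

Passing the environment in as an argument keeps the logic independent of os.environ, so the same function could sit behind a project-specific env accessor.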
b8zhong
left a comment
LGTM after resolving the one comment, but please fix the lint https://github.com/sgl-project/sglang/actions/runs/27744694747/job/82080017758?pr=27455
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
FlashInfer merged PR flashinfer-ai/flashinfer#3395, changing the SM120 sparse MLA API from BatchSparseMLAPagedAttentionWrapper class to a stateless function. Move triton imports to file top to fix E402 lint.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Resolve environ.py conflict: keep SGLANG_SM120_FLASHMLA_BACKEND, drop removed SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- `sparse_mla_sm120_decode_dsv4` decode kernel for DeepSeek-V4-Flash on RTX PRO 6000 (SM120)
Motivation
The existing Triton FlashMLA decode kernel (merged in #24692) works but leaves performance on the table. FlashInfer's native SM120 `decode_dsv4` kernel uses CUTLASS with block-scaled MXFP8 MMA, achieving 2.2-3.7x decode speedup.
Approach
Challenge: SGLang SWA KV cache uses `page_size=256` with footer layout (scales stored at page end), but FlashInfer's `decode_dsv4` hardcodes `page_block_size=64` as a CUDA template parameter.
Solution: Fused Triton page-split kernel at call site. Before each FlashInfer call, it converts 256-token pages into 4 virtual 64-token pages with correct footer layout in a single kernel launch (replaces 8 separate copy kernels = 344 launches/step saved). Index remapping is not needed since the page-split preserves linear token ordering. No changes to the KV cache pool or compressor.
Reference: vLLM PR vllm-project/vllm#43477 (same FlashInfer kernel, different integration approach: vLLM uses block_size=64 natively).
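The page split described in the Approach section can be modelled on raw bytes. The sketch below is a CPU-side illustration under stated assumptions, not the fused Triton kernel: it assumes a footer layout of N * data_stride data bytes followed by N * scale_stride scale bytes per page, uses the 576/8 strides implied elsewhere in the PR as defaults, ignores destination padding, and uses a made-up function name.

```python
from typing import List


def split_page(
    src: bytes,
    src_tokens: int = 256,
    dst_tokens: int = 64,
    data_stride: int = 576,
    scale_stride: int = 8,
) -> List[bytes]:
    """Split one footer-layout page into smaller footer-layout pages."""
    assert src_tokens % dst_tokens == 0
    assert len(src) == src_tokens * (data_stride + scale_stride)
    data_per_sub = dst_tokens * data_stride
    scale_per_sub = dst_tokens * scale_stride
    footer_start = src_tokens * data_stride  # scales begin after all data
    pages = []
    for sub in range(src_tokens // dst_tokens):
        data_off = sub * data_per_sub
        scale_off = footer_start + sub * scale_per_sub
        data = src[data_off:data_off + data_per_sub]
        scales = src[scale_off:scale_off + scale_per_sub]
        # Each sub-page gets its own footer directly after its data.
        pages.append(data + scales)
    return pages


if __name__ == "__main__":
    page = bytes(256 * (576 + 8))
    print(len(split_page(page)), len(split_page(page)[0]))
```

Since sub-page s of source page p becomes destination page 4 * p + s and token order inside each sub-page is unchanged, a linear token index addresses the same token before and after the split, which is why no index remapping is needed.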
Changes (3 files)
flash_mla_sm120.py:
- `SGLANG_SM120_FLASHMLA_BACKEND=triton|torch`
- `_page_split_kernel`: Fused Triton kernel converting pbs=256 footer to pbs=64 footer in 1 launch
- `_split_kv_pages_to_64()`: Page-split driver with lazy per-device buffer allocation
- `_flash_mla_flashinfer()`: Direct call to `sparse_mla_sm120_decode_dsv4()` with page-split for SWA cache, extra cache (C4/C128) passed as-is
- `mid_out`/`mid_lse`/`output`/`out_lse` scratch buffers passed explicitly
deepseek_v4_backend.py (3 lines):
- `swa_page_size` from pool instead of hardcoded 128
- `swa_page_size % SWA_WINDOW == 0`
environ.py (2 lines):
- `SGLANG_SM120_FLASHMLA_BACKEND` env var (default: `"flashinfer"`)
Performance
FlashInfer vs Triton MLA Decode (TP=4, 4x RTX PRO 6000 96GB, triton MoE, CUDA graph, ISL=8K OSL=32)
Speedup (FlashInfer over Triton):
Correctness
Dependencies
Requires FlashInfer 0.6.13 from GitHub main with SM120 sparse MLA support:
- `_sparse_mla_sm120.py` yet
Backend selection
Test plan
- `is_sm120_supported()` + try-import)
🤖 Generated with Claude Code
CI States
Latest PR Test (Base): ❌ Run #27826308026
Latest PR Test (Extra): ❌ Run #27826307650