[SM120] Add FlashInfer sparse MLA decode for DSv4-Flash#27455

Open
AliceChenyy wants to merge 18 commits into
sgl-project:mainfrom
AliceChenyy:sm120-flashinfer-mla

@AliceChenyy AliceChenyy commented Jun 6, 2026

Contributor

Summary

  • Integrate FlashInfer SM120 sparse_mla_sm120_decode_dsv4 decode kernel for DeepSeek-V4-Flash on RTX PRO 6000 (SM120)
  • Decode TPOT 2.2-3.7x faster vs Triton kernel across batch sizes (ISL=8K, TP=4)
  • Defaults to FlashInfer when it is available and falls back to Triton otherwise
  • No changes to KV cache pool, compressor, or any non-SM120 code paths

Motivation

The existing Triton FlashMLA decode kernel (merged in #24692) works but leaves performance on the table. FlashInfer's native SM120 decode_dsv4 kernel uses CUTLASS with block-scaled MXFP8 MMA, achieving 2.2-3.7x decode speedup.

Approach

Challenge: SGLang SWA KV cache uses page_size=256 with footer layout (scales stored at page end), but FlashInfer's decode_dsv4 hardcodes page_block_size=64 as a CUDA template parameter.

Solution: a fused Triton page-split kernel at the call site. Before each FlashInfer call, it converts every 256-token page into 4 virtual 64-token pages with the correct footer layout in a single kernel launch. This replaces 8 separate copy kernels, which previously added up to 344 launches per decode step. Index remapping is not needed because the page split preserves linear token ordering. No changes to the KV cache pool or compressor.
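The page split described above can be sketched on the host as a plain byte shuffle. This is only an illustration, not the PR's Triton kernel. It assumes a footer layout in which a page stores all per-token data rows first and all per-token scale rows at the end, uses the per-token strides that appear in a code comment of this PR (576 data bytes, 8 scale bytes), and ignores any per-page alignment padding. The function name `split_page_footer` is hypothetical.

```python
DATA_STRIDE = 576   # data bytes per token (taken from a code comment in this PR)
SCALE_STRIDE = 8    # scale bytes per token (taken from a code comment in this PR)
SRC_PBS = 256       # tokens per SGLang SWA page
DST_PBS = 64        # tokens per page expected by decode_dsv4


def split_page_footer(page: bytes) -> list[bytes]:
    """Split one footer-layout page of SRC_PBS tokens into DST_PBS-token pages."""
    assert len(page) == SRC_PBS * (DATA_STRIDE + SCALE_STRIDE)
    scale_base = SRC_PBS * DATA_STRIDE  # the footer starts after all data rows
    sub_pages = []
    for sub in range(SRC_PBS // DST_PBS):
        first = sub * DST_PBS  # first token covered by this sub-page
        data = page[first * DATA_STRIDE:(first + DST_PBS) * DATA_STRIDE]
        scales = page[scale_base + first * SCALE_STRIDE:
                      scale_base + (first + DST_PBS) * SCALE_STRIDE]
        sub_pages.append(data + scales)  # each sub-page carries its own footer
    return sub_pages


# Synthetic page: token t has data bytes equal to t and scale bytes equal to 255 - t.
page = (b"".join(bytes([t]) * DATA_STRIDE for t in range(SRC_PBS))
        + b"".join(bytes([255 - t]) * SCALE_STRIDE for t in range(SRC_PBS)))
sub_pages = split_page_footer(page)
```

In this sketch token 70 of the source page ends up in sub-page 1 at offset 6, so its linear token index (1 * 64 + 6 = 70) is unchanged, which is why no index remapping is required.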

Reference: vLLM PR vllm-project/vllm#43477 (same FlashInfer kernel, different integration approach — vLLM uses block_size=64 natively).

Changes (3 files)

flash_mla_sm120.py:

  • Default FlashInfer backend, override via SGLANG_SM120_FLASHMLA_BACKEND=triton|torch
  • _page_split_kernel: Fused Triton kernel converting pbs=256 footer to pbs=64 footer in 1 launch
  • _split_kv_pages_to_64(): Page-split driver with lazy per-device buffer allocation
  • _flash_mla_flashinfer(): Direct call to sparse_mla_sm120_decode_dsv4() with page-split for SWA cache, extra cache (C4/C128) passed as-is
  • Pre-allocated mid_out/mid_lse/output/out_lse scratch buffers passed explicitly

deepseek_v4_backend.py (3 lines):

  • Read swa_page_size from pool instead of hardcoded 128
  • Relax assertion to swa_page_size % SWA_WINDOW == 0
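A minimal sketch of what the relaxed check permits. `SWA_WINDOW = 128` is an assumption made for illustration (the PR only states that 128 was the previously hardcoded page size), and `check_swa_page_size` is a hypothetical helper, not a function from the backend.

```python
SWA_WINDOW = 128  # assumed value, for illustration only


def check_swa_page_size(swa_page_size: int) -> int:
    """Return how many windows fit in one page; reject non-multiples."""
    assert swa_page_size % SWA_WINDOW == 0, (
        f"swa_page_size={swa_page_size} must be a multiple of {SWA_WINDOW}"
    )
    return swa_page_size // SWA_WINDOW


windows_per_page = check_swa_page_size(256)  # 256-token pages now pass the check
```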

environ.py (2 lines):

  • Add SGLANG_SM120_FLASHMLA_BACKEND env var (default: "flashinfer")

Performance

FlashInfer vs Triton MLA Decode (TP=4, 4x RTX PRO 6000 96GB, triton MoE, CUDA graph, ISL=8K OSL=32)

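| BS | Backend | TTFT (s) | TPOT (ms) | Throughput (tok/s) |
|---|---|---|---|---|
| 1 | FlashInfer | 0.92 | 22.2 | 17.71 |
| 1 | Triton | 6.17 | 82.7 | 3.21 |
| 4 | FlashInfer | 0.56 | 64.0 | 45.77 |
| 4 | Triton | 2.58 | 148.8 | 14.87 |
| 8 | FlashInfer | 0.91 | 114.4 | 52.61 |
| 8 | Triton | 3.39 | 253.8 | 19.83 |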

Speedup (FlashInfer over Triton):

| BS | TTFT | TPOT | Throughput |
|---|---|---|---|
| 1 | 6.7x | 3.7x | 5.5x |
| 4 | 4.6x | 2.3x | 3.1x |
| 8 | 3.7x | 2.2x | 2.7x |
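For reference, the speedup factors can be reproduced from the raw measurements in the first table. The snippet below is a small sketch that only restates that arithmetic; the dictionary and function names are chosen here for illustration.

```python
# (TTFT s, TPOT ms, throughput tok/s) per batch size, copied from the raw results
flashinfer = {1: (0.92, 22.2, 17.71), 4: (0.56, 64.0, 45.77), 8: (0.91, 114.4, 52.61)}
triton = {1: (6.17, 82.7, 3.21), 4: (2.58, 148.8, 14.87), 8: (3.39, 253.8, 19.83)}


def speedups(bs: int) -> tuple[float, float, float]:
    """Return (TTFT, TPOT, throughput) speedup of FlashInfer over Triton."""
    fi_ttft, fi_tpot, fi_tput = flashinfer[bs]
    tr_ttft, tr_tpot, tr_tput = triton[bs]
    # Latencies: lower is better, so divide Triton by FlashInfer.
    # Throughput: higher is better, so divide FlashInfer by Triton.
    return (
        round(tr_ttft / fi_ttft, 1),
        round(tr_tpot / fi_tpot, 1),
        round(fi_tput / tr_tput, 1),
    )


for bs in (1, 4, 8):
    print(bs, speedups(bs))
```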

Correctness

| Test | Result |
|---|---|
| GSM8K 10q (0-shot, TP=4, FlashInfer) | 10/10 = 100% |
| Unit test: FlashInfer vs Triton cosine similarity | pass (atol=5e-2) |
| CUDA graph capture BS=1,2,4,8,12,16 | All pass |
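As an illustration of the kind of comparison the unit test performs, here is a pure-Python sketch on synthetic data. How the real test combines cosine similarity with `atol=5e-2` is not spelled out in this PR, so the sketch simply checks both; the 0.999 similarity threshold, the helper names, and the input vectors are assumptions.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two equally sized vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Largest element-wise absolute difference."""
    return max(abs(x - y) for x, y in zip(a, b))


# Synthetic stand-ins: `reference` plays the Triton output, `candidate` the
# FlashInfer output with a small alternating perturbation.
reference = [0.10 * i for i in range(1, 65)]
candidate = [x + (0.01 if i % 2 else -0.01) for i, x in enumerate(reference)]

cos = cosine_similarity(reference, candidate)
diff = max_abs_diff(reference, candidate)
passed = cos > 0.999 and diff <= 5e-2
```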

Dependencies

Requires FlashInfer 0.6.13 from GitHub main with SM120 sparse MLA support:

pip install "flashinfer-python @ git+https://github.com/flashinfer-ai/flashinfer.git@main" --no-deps

Backend selection

# Default: FlashInfer (requires flashinfer with sparse_mla_sm120)
# Override with env var:
export SGLANG_SM120_FLASHMLA_BACKEND=triton   # force Triton
export SGLANG_SM120_FLASHMLA_BACKEND=torch    # force PyTorch fallback
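A rough sketch of the selection logic, using only the standard library. The PR itself reads the variable through SGLang's `envs` module; the validation step, the `VALID_BACKENDS` tuple, and the `select_backend` name are assumptions added here for illustration.

```python
import os

VALID_BACKENDS = ("flashinfer", "triton", "torch")


def select_backend(environ=os.environ) -> str:
    """Return the requested SM120 MLA backend, defaulting to FlashInfer."""
    value = environ.get("SGLANG_SM120_FLASHMLA_BACKEND", "flashinfer").strip().lower()
    if value not in VALID_BACKENDS:
        # Rejecting unknown values is an assumption of this sketch.
        raise ValueError(
            f"unknown SGLANG_SM120_FLASHMLA_BACKEND={value!r}; "
            f"expected one of {VALID_BACKENDS}"
        )
    return value


default_backend = select_backend({})  # no override set
forced_backend = select_backend({"SGLANG_SM120_FLASHMLA_BACKEND": "triton"})
```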

Test plan

  • Unit test: FlashInfer vs Triton cosine similarity pass
  • E2E: GSM8K 10q 0-shot = 100% (TP=4, FlashInfer)
  • CUDA graph capture BS=1,2,4,8,12,16
  • FlashInfer vs Triton A/B comparison (2.2-3.7x TPOT speedup)
  • CI (SM120 not available in CI — SM120-only code path, guarded by is_sm120_supported() + try-import)

🤖 Generated with Claude Code


CI States

Latest PR Test (Base): ❌ Run #27826308026
Latest PR Test (Extra): ❌ Run #27826307650

Integrate FlashInfer's SM120 sparse_mla_sm120 decode kernel for
DeepSeek-V4-Flash on RTX PRO 6000 (SM120). Decode ITL improves
1.84x vs Triton (59ms vs 109ms, ISL=8K OSL=1K BS=1).

Key challenge: SGLang SWA KV cache uses page_size=256 with footer
layout (scales at page end), but FlashInfer decode_dsv4 requires
page_block_size=64. Solution: page-split at call site — split
256-token pages into 4 virtual 64-token pages with correct footer
layout and remap indices before calling FlashInfer.

- Auto-detect FlashInfer SM120 via try-import, fallback to Triton
- Page-split + index remap in _flash_mla_flashinfer()
- Uses BatchSparseMLAPagedAttentionWrapper.run() for decode dispatch
- Extra cache (C4/C128) passed as-is (native dual-cache support)
- GSM8K 10/10 = 100%, CUDA graph BS=1,2,4,8 all pass
- Env var SGLANG_SM120_FLASHMLA_BACKEND={auto,flashinfer,triton}

Requires: flashinfer >= 0.6.12 with sparse-mla-sm120 support
(lucifer1004/flashinfer@sparse-mla-sm120, PR flashinfer-ai/flashinfer#3395)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces support for the FlashInfer SM120 sparse MLA backend, enabling page-split utilities to convert 256-token pages into 64-token pages for FlashInfer's decode fast path. It also updates the DeepSeek-V4 backend to dynamically retrieve the SWA page size and relaxes the window size assertion. Feedback highlights several optimization and robustness opportunities: reusing global scratch buffers to avoid allocating temporary workspace tensors on every decode step, eliminating the redundant _remap_indices_to_64 function which mathematically simplifies to an identity function, and adding a device check to the lazy allocation of _split_buf to ensure robustness during multi-GPU execution.

Comment on lines +257 to +258
# Lazily initialized FlashInfer wrapper (reused across calls).
_flashinfer_wrapper = None

Contributor

high

Define global scratch buffers to avoid allocating temporary workspace tensors (out_lse, mid_out, mid_lse) on every single decode step.

Suggested change
 # Lazily initialized FlashInfer wrapper (reused across calls).
 _flashinfer_wrapper = None
+_out_lse_buf = None
+_mid_out_buf = None
+_mid_lse_buf = None

Contributor Author

These tensors are allocated during CUDA graph capture and reused on replay — PyTorch's caching allocator returns the same memory block each time. Pre-allocating global buffers would add device/shape management complexity for no measurable gain in the CUDA graph hot path. Will revisit if profiling shows allocation overhead in non-graph paths.

Comment on lines +333 to +347
def _remap_indices_to_64(idx: torch.Tensor, src_pbs: int) -> torch.Tensor:
    """Remap token indices from pbs=N addressing to pbs=64 addressing.

    Invalid indices (-1) are preserved.
    """
    if src_pbs == _PBS_DST:
        return idx
    ratio = src_pbs // _PBS_DST
    valid_mask = idx >= 0
    page_src = torch.where(valid_mask, idx // src_pbs, torch.zeros_like(idx))
    offset_src = torch.where(valid_mask, idx % src_pbs, torch.zeros_like(idx))
    sub_page = offset_src // _PBS_DST
    offset64 = offset_src % _PBS_DST
    fi_idx = (page_src * ratio + sub_page) * _PBS_DST + offset64
    return torch.where(valid_mask, fi_idx, idx)

Contributor

high

The index remapping function _remap_indices_to_64 is mathematically an identity function for all valid indices (idx >= 0). Specifically, the formula:

$$\text{fi\_idx} = \left( \left\lfloor \frac{\text{idx}}{\text{src\_pbs}} \right\rfloor \times \text{ratio} + \left\lfloor \frac{\text{idx} \bmod \text{src\_pbs}}{\text{PBS\_DST}} \right\rfloor \right) \times \text{PBS\_DST} + \left( (\text{idx} \bmod \text{src\_pbs}) \bmod \text{PBS\_DST} \right)$$

simplifies exactly to idx because ratio * PBS_DST = src_pbs. Therefore, we can completely eliminate this function and the associated GPU overhead.

# Note: _remap_indices_to_64 is mathematically an identity function because
# (idx // src_pbs * ratio + (idx % src_pbs) // 64) * 64 + (idx % src_pbs) % 64
# simplifies exactly to idx. Therefore, we can omit this remapping entirely.
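The identity can also be checked exhaustively on the host. The sketch below is a scalar re-implementation of the formula (the name `remap_index` is hypothetical). Writing idx = p * src_pbs + s * 64 + o with 0 <= o < 64, the formula gives (p * ratio + s) * 64 + o, and since ratio * 64 = src_pbs this is idx again.

```python
PBS_DST = 64


def remap_index(idx: int, src_pbs: int) -> int:
    """Scalar version of the _remap_indices_to_64 formula."""
    if idx < 0:  # invalid indices (-1) are preserved
        return idx
    ratio = src_pbs // PBS_DST
    page_src, offset_src = divmod(idx, src_pbs)
    sub_page, offset64 = divmod(offset_src, PBS_DST)
    return (page_src * ratio + sub_page) * PBS_DST + offset64


# Check every index across 8 source pages for several source page sizes.
is_identity = all(
    remap_index(i, src_pbs) == i
    for src_pbs in (64, 128, 256)
    for i in range(-1, 8 * src_pbs)
)
```

The check only covers source page sizes that are multiples of 64; for other sizes `ratio * PBS_DST` no longer equals `src_pbs` and the identity would not hold.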

@AliceChenyy AliceChenyy Jun 9, 2026

Contributor Author

Removed _remap_indices_to_64 entirely in 7bdc051.

Comment on lines +387 to +389
# --- Index remapping: pbs=N token index -> pbs=64 token index ---
idx = indices.squeeze(1) if indices.dim() == 3 else indices
idx_64 = _remap_indices_to_64(idx, src_pbs) if src_pbs != _PBS_DST else idx

Contributor

high

Since _remap_indices_to_64 is an identity function, we can directly use idx without any remapping, avoiding redundant GPU operations in the hot path of decode.

Suggested change
 # --- Index remapping: pbs=N token index -> pbs=64 token index ---
-idx = indices.squeeze(1) if indices.dim() == 3 else indices
-idx_64 = _remap_indices_to_64(idx, src_pbs) if src_pbs != _PBS_DST else idx
+idx_64 = indices.squeeze(1) if indices.dim() == 3 else indices

Contributor Author

Done in 7bdc051. Using idx directly, no remapping.

Comment thread python/sglang/srt/layers/attention/flash_mla_sm120.py Outdated
Comment on lines +298 to +302
if _split_buf is None or _split_buf.shape[0] < num_dst_pages:
    _split_buf = torch.empty(
        num_dst_pages, _BYTES_PER_DST_PAGE_PADDED,
        dtype=torch.uint8, device=kv_u8.device,
    )

Contributor

medium

To ensure robustness across different GPU devices (e.g., during multi-GPU execution or unit testing), add a device check to the lazy allocation of _split_buf so it is reallocated if the device of kv_u8 changes.

Suggested change
-if _split_buf is None or _split_buf.shape[0] < num_dst_pages:
+if _split_buf is None or _split_buf.shape[0] < num_dst_pages or _split_buf.device != kv_u8.device:
     _split_buf = torch.empty(
         num_dst_pages, _BYTES_PER_DST_PAGE_PADDED,
         dtype=torch.uint8, device=kv_u8.device,
     )

Contributor Author

Already addressed in commit 042ffb6: _split_buf is a per-device dict ({} # device -> tensor), so each GPU gets its own buffer.

- Fix incorrect comment (swa_page_size=64 → describe page-split conversion)
- Add backward compat for legacy SGLANG_SM120_TRITON_FLASHMLA env var
- Restore one-arg-per-line formatting for triton/torch call paths
- Update docstring to reflect three-backend dispatch
- Use per-device dict for _flashinfer_wrapper and _split_buf (TP>1 safety)
- Replace 8 separate copy kernels in _split_kv_pages_to_64 with a fused
  Triton kernel (_page_split_kernel), reducing 344 kernel launches per
  decode step to 1 per layer. Verified 9.3% ITL improvement on RTX 6000D
  TP=4 (45.00ms → 40.83ms).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Comment thread python/sglang/srt/layers/attention/flash_mla_sm120.py Outdated
# Controlled by SGLANG_SM120_FLASHMLA_BACKEND env var (auto/flashinfer/triton/torch).
# Legacy: SGLANG_SM120_TRITON_FLASHMLA=0 maps to "torch" for backward compat.
def _detect_sm120_backend():
    env = os.environ.get("SGLANG_SM120_FLASHMLA_BACKEND", "auto")

Collaborator

Use the envs classes that are pre-existing

@AliceChenyy AliceChenyy Jun 9, 2026

Contributor Author

Simplified in 7bdc051. Removed _detect_sm120_backend() entirely — now just one line: os.environ.get("SGLANG_SM120_FLASHMLA_BACKEND", "flashinfer").

Collaborator

What I mean is, envs.get_env, or something like that.

Comment thread python/sglang/srt/layers/attention/flash_mla_sm120.py Outdated
    return "triton"


_sm120_default_backend = _detect_sm120_backend()


def flash_mla_with_kvcache_sm120(**kwargs):

Collaborator

Should we still call this "flashmla"

@AliceChenyy AliceChenyy Jun 9, 2026

Contributor Author

How about changing it to flashinfer_sparse_mla?

Collaborator

Sure, that sounds better to me

…pGEMM@sm120)

Enable DeepGEMM grouped FP8×FP4 GEMM for MoE on SM120 (RTX 6000D/PRO 6000).
Requires leavelet/DeepGEMM@sm120 branch (deepseek-ai/DeepGEMM#324, not yet merged).

Changes:
- configurer.py: Allow SM120 only when SM120-compatible DeepGEMM is installed
  (checks for m_grouped_fp8_fp4_gemm_nt_contiguous availability)
- server_args.py: Auto-select deep_gemm MoE backend on SM120
- kernels.py: Add UE8M0 (power-of-2) FP8 quantization Triton kernel required
  by DeepGEMM's block-scaled dequantization on SM120
- deep_gemm.py: SM120 adaptations for DeepGEMM MoE runner:
  - TMA-aligned scale factors for grouped GEMM (hidden_states + down_input)
  - JIT EP activation fallback when hidden_dim/8 < num_experts (TP>=2)
  - In-place swiglu clamp replacing torch.chunk+cat (-7.4ms/step decode)
- fp8.py: Add .contiguous() before transform_sf_into_required_layout
- fp8_utils.py: Skip DeepGEMM dense FP8 linear on SM120 (bf16_gemm_nt unsupported)

Performance (TP=4, BS=1, RTX 6000D 85GB, ISL=8K):
  TTFT: 130ms (vs 400ms marlin, 3x faster)
  Decode ITL: 47ms (vs 41ms marlin, 15% slower)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
AliceChenyy and others added 3 commits June 7, 2026 17:31
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@b8zhong b8zhong self-assigned this Jun 8, 2026
AliceChenyy and others added 3 commits June 8, 2026 19:26
…emap

- Default to FlashInfer directly instead of auto-detect via try-import
  (reviewer: FlashInfer is pinned, should import directly)
- Remove _remap_indices_to_64: mathematically identity function
  (page256*256+off == (page256*4+off//64)*64+off%64 for all valid idx)
- Simplify env var override to single os.environ.get with "flashinfer" default

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ff format

- Default to FlashInfer directly instead of auto-detect via try-import
  (reviewer: FlashInfer is pinned, should import directly)
- Remove _remap_indices_to_64: mathematically identity function
  (page256*256+off == (page256*4+off//64)*64+off%64 for all valid idx)
- Simplify env var override to single os.environ.get with "flashinfer" default
- Apply ruff formatter

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@AliceChenyy

Contributor Author

Thank you, @b8zhong! I will address these comments ASAP.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
AliceChenyy and others added 2 commits June 9, 2026 02:10
… overhead

Replace per-call torch.empty for out_lse/mid_out/mid_lse with lazily
allocated scratch buffers that grow-only and persist across decode steps.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
    _PBS_DST * _NOPE_ROPE_STRIDE + _PBS_DST * _SCALE_STRIDE
)  # 64*576 + 64*8 = 36864 + 512 = 37376
# Padded to 576 alignment
import math as _math

Collaborator

No need to make it a private import like this

Contributor Author

Fixed.


# Copy data region: DATA_PER_SUB bytes from src offset sub*DATA_PER_SUB
data_src_off = sub * DATA_PER_SUB
for start in range(0, DATA_PER_SUB, BLOCK_SIZE):

Collaborator

Can we use tl.range? In both loops

Contributor Author

Done.

@b8zhong b8zhong left a comment

Collaborator

Let's also add some tests for this feature (a unit test is OK) and register them on the appropriate CI device (I think stage-b small are 5090s/SM120). Also, do we need to update any cookbook recipes?

@poryfly

poryfly commented Jun 16, 2026


@AliceChenyy Does this work on RTX 5090, which is also SM120? Does it require NVLink? I launched on RTX 5090 with the RTX 6090 parameter config and it does not work. See issue #28299.

AliceChenyy and others added 3 commits June 18, 2026 15:32
…kernel

Address b8zhong review comments:
1. Remove private import alias: `import math as _math` → `import math`
2. Use `tl.range` instead of `range` in Triton kernel loops for
   better compiler optimization hints.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add test_flashinfer_backend_matches_triton to TestEntryPointDispatch.
Validates FlashInfer SM120 sparse MLA decode output matches Triton
reference (skips if flashinfer.sparse_mla_sm120 not available).

Registered on base-b, 1-gpu-large (existing test file CI registration).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Move `import math` to file top (E402 fix)
- Replace try/import with importlib.util.find_spec in test (F401 fix)
- Apply ruff format

Note: pre-existing E402 for `import triton` (line 288-289) is not
addressed here as it predates this PR.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
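A sketch of the `importlib.util.find_spec` availability check mentioned in the commit above. The helper name `has_module` and the test class name are hypothetical; the module path `flashinfer.sparse_mla_sm120` is the one named in this PR. Note that `find_spec` imports parent packages of a dotted name, so a missing parent raises instead of returning `None`, which the sketch handles explicitly.

```python
import importlib.util
import unittest


def has_module(name: str) -> bool:
    """Return True if `name` is importable, without importing the leaf module."""
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        # Raised (as ModuleNotFoundError) when a parent package is missing.
        return False


HAS_FLASHINFER_SM120 = has_module("flashinfer.sparse_mla_sm120")


class TestEntryPointDispatchSketch(unittest.TestCase):
    @unittest.skipUnless(
        HAS_FLASHINFER_SM120, "flashinfer.sparse_mla_sm120 not available"
    )
    def test_flashinfer_backend_matches_triton(self):
        # Placeholder: the real test compares FlashInfer output with the
        # Triton reference.
        pass
```

Compared with a bare `try: import ...` that never uses the imported name, this form avoids the unused-import lint (F401) while still skipping cleanly on machines without the SM120 kernel.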
)
# SM120 FlashMLA: default FlashInfer (CUTLASS SM120 sparse MLA decode).
# Override with SGLANG_SM120_FLASHMLA_BACKEND=triton|torch to force fallback.
_sm120_default_backend = os.environ.get("SGLANG_SM120_FLASHMLA_BACKEND", "flashinfer")

Collaborator

Please use the envs module instead of os.environ

Contributor Author

Done in e9b4eef. Using envs.SGLANG_SM120_FLASHMLA_BACKEND.get().

@b8zhong b8zhong left a comment

Collaborator

LGTM after resolving the one comment, but please fix the lint https://github.com/sgl-project/sglang/actions/runs/27744694747/job/82080017758?pr=27455

AliceChenyy and others added 3 commits June 18, 2026 23:34
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
FlashInfer merged PR flashinfer-ai/flashinfer#3395, changing the SM120 sparse MLA API from
BatchSparseMLAPagedAttentionWrapper class to a stateless function.
Move triton imports to file top to fix E402 lint.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Resolve environ.py conflict: keep SGLANG_SM120_FLASHMLA_BACKEND,
drop removed SGLANG_FLASHINFER_FORCE_POSIX_FD_TRANSPORT.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
3 participants