Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
f33271c
[SM120] Add FlashInfer sparse MLA decode for DSv4-Flash
AliceChenyy Jun 6, 2026
042ffb6
[SM120] Fix review issues and fuse page-split kernel for FlashInfer MLA
AliceChenyy Jun 7, 2026
c30adcd
[SM120] DeepGEMM MoE integration (experimental, requires leavelet/Dee…
AliceChenyy Jun 7, 2026
148972b
[SM120] Fix lint: ruff formatter compliance
AliceChenyy Jun 8, 2026
ebb01e3
Revert "[SM120] Fix lint: ruff formatter compliance"
AliceChenyy Jun 8, 2026
2be8f89
Revert "[SM120] DeepGEMM MoE integration (experimental, requires leav…
AliceChenyy Jun 8, 2026
4fd93c4
[SM120] Address review: simplify backend selection, remove identity r…
AliceChenyy Jun 9, 2026
d097ede
Revert "[SM120] Address review: simplify backend selection, remove id…
AliceChenyy Jun 9, 2026
7bdc051
[SM120] Address review: default FlashInfer, remove identity remap, ru…
AliceChenyy Jun 9, 2026
7c98036
[SM120] Fix black-jupyter format
AliceChenyy Jun 9, 2026
cfb7733
[SM120] Reuse scratch buffers in FlashInfer decode to avoid allocator…
AliceChenyy Jun 9, 2026
a6c4f05
Revert "[SM120] Reuse scratch buffers in FlashInfer decode to avoid a…
AliceChenyy Jun 9, 2026
a57f897
[SM120] Address review: use math (not _math), tl.range in page-split …
AliceChenyy Jun 18, 2026
c85a8fe
[SM120] Add FlashInfer backend unit test for sparse MLA decode
AliceChenyy Jun 18, 2026
17d11e8
[SM120] Fix lint: move math import to top, use importlib in test
AliceChenyy Jun 18, 2026
e9b4eef
[SM120] Use envs module for backend selection, fix lint formatting
AliceChenyy Jun 19, 2026
95fecc2
[SM120] Migrate FlashInfer API to sparse_mla_sm120_decode_dsv4
AliceChenyy Jun 19, 2026
4f063a2
Merge origin/main into sm120-flashinfer-mla
AliceChenyy Jun 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/attention/deepseek_v4_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,14 @@ def __init__(
self.softmax_scale: float = head_dim**-0.5
self.head_dim_v: int = model_runner.model_config.v_head_dim
self.cuda_int32_kwargs = {"device": self.device, "dtype": torch.int32}
self.swa_page_size = 128
assert model_runner.page_size is not None
assert model_runner.req_to_token_pool is not None
self.page_size = model_runner.page_size
assert self.page_size == 256, "the system hardcodes page_size=256"

self.req_to_token_pool = model_runner.req_to_token_pool
self.token_to_kv_pool: DeepSeekV4TokenToKVPool = model_runner.token_to_kv_pool
self.swa_page_size = self.token_to_kv_pool.swa_page_size
self.hisparse_coordinator = model_runner.hisparse_coordinator
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.MAX_SEQ_LEN_FOR_CAPTURE = self.req_to_token.shape[1]
Expand Down Expand Up @@ -1151,7 +1151,7 @@ def make_core_attn_metadata(
need_compress: bool = True,
is_prefill: bool = False,
) -> DSV4AttnMetadata:
assert self.swa_page_size == SWA_WINDOW
assert self.swa_page_size % SWA_WINDOW == 0

swa_page_indices = self.get_swa_page_indices(
seq_lens_casual=seq_lens_casual,
Expand Down
238 changes: 232 additions & 6 deletions python/sglang/srt/layers/attention/flash_mla_sm120.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,15 @@ def _sm120_sparse_decode_fwd(
return out.to(torch.bfloat16), lse.permute(0, 2, 1)


# Default SM120 FlashMLA backend: "triton" (optimized) or "torch" (pure-PyTorch fallback).
# Controlled by SGLANG_SM120_TRITON_FLASHMLA env var (1=triton, 0=torch).
_sm120_default_backend = (
"triton" if os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1" else "torch"
)
# SM120 FlashMLA: default FlashInfer (CUTLASS SM120 sparse MLA decode).
# Override with SGLANG_SM120_FLASHMLA_BACKEND=triton|torch to force fallback.
_sm120_default_backend = os.environ.get("SGLANG_SM120_FLASHMLA_BACKEND", "flashinfer")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the envs module instead of os.environ

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in e9b4eef. Using envs.SGLANG_SM120_FLASHMLA_BACKEND.get().



def flash_mla_with_kvcache_sm120(**kwargs):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we still call this "flashmla"

@AliceChenyy AliceChenyy Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about change to flashinfer_sparse_mla?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, that sounds better to me

"""SM120 FlashMLA sparse decode entry point.

Dispatches to the Triton kernel (default) or PyTorch fallback.
Dispatches to FlashInfer (default if available), Triton, or PyTorch fallback.
"""
q = kwargs["q"]
k_cache = kwargs["k_cache"]
Expand All @@ -218,6 +216,20 @@ def flash_mla_with_kvcache_sm120(**kwargs):
extra_indices = kwargs.get("extra_indices_in_kvcache")
extra_topk_length = kwargs.get("extra_topk_length")

if _sm120_default_backend == "flashinfer":
return _flash_mla_flashinfer(
q,
k_cache,
indices,
topk_length,
attn_sink,
head_dim_v,
softmax_scale,
extra_k_cache,
extra_indices,
extra_topk_length,
)

if _sm120_default_backend == "triton":
from sglang.srt.layers.attention.flash_mla_sm120_triton import (
flash_mla_sparse_decode_triton,
Expand Down Expand Up @@ -250,3 +262,217 @@ def flash_mla_with_kvcache_sm120(**kwargs):
extra_topk_length,
)
return (out, lse)


# Lazily initialized FlashInfer wrapper per device (reused across calls).
_flashinfer_wrapper = {} # device -> wrapper

# --- Page-split utilities: pbs=256 → pbs=64 ---
# SGLang SWA KV cache footer layout per 256-token page:
# [data: 256 * 576 bytes] [scale: 256 * 8 bytes] [padding]
# FlashInfer decode_dsv4 expects per 64-token page:
# [data: 64 * 576 bytes] [scale: 64 * 8 bytes] [padding to 37440]
_PBS_SRC = 256 # SGLang physical page size
_PBS_DST = 64 # FlashInfer page_block_size
_NOPE_ROPE_STRIDE = 576 # bytes per token for nope+rope
_SCALE_STRIDE = 8 # bytes per token for scale (7 + 1 pad)
_BYTES_PER_DST_PAGE = (
_PBS_DST * _NOPE_ROPE_STRIDE + _PBS_DST * _SCALE_STRIDE
) # 64*576 + 64*8 = 37376 + 512 = 37888
# Padded to 576 alignment
import math as _math

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to make it a private import like this

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.


_BYTES_PER_DST_PAGE_PADDED = _math.ceil(_BYTES_PER_DST_PAGE / 576) * 576 # 37440

# Pre-allocated buffer for page-split output per device (lazily sized).
_split_buf = {} # device -> tensor

import triton
import triton.language as tl


@triton.jit
def _page_split_kernel(
src_ptr,
dst_ptr,
N_pages,
src_stride0: tl.constexpr,
dst_stride0: tl.constexpr,
DATA_PER_SUB: tl.constexpr, # 64 * 576 = 36864
SCALE_PER_SUB: tl.constexpr, # 64 * 8 = 512
SRC_SCALE_OFF: tl.constexpr, # 256 * 576 = 147456
DST_SCALE_OFF: tl.constexpr, # 64 * 576 = 36864
RATIO: tl.constexpr, # 4
BLOCK_SIZE: tl.constexpr,
):
"""Fused page-split: copy data+scale for all sub-pages in one kernel."""
pid = tl.program_id(0)
page_idx = pid // RATIO
sub = pid % RATIO

if page_idx >= N_pages:
return

src_base = src_ptr + page_idx * src_stride0
dst_base = dst_ptr + (page_idx * RATIO + sub) * dst_stride0

# Copy data region: DATA_PER_SUB bytes from src offset sub*DATA_PER_SUB
data_src_off = sub * DATA_PER_SUB
for start in range(0, DATA_PER_SUB, BLOCK_SIZE):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use tl.range? In both loops

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

offs = start + tl.arange(0, BLOCK_SIZE)
mask = offs < DATA_PER_SUB
vals = tl.load(src_base + data_src_off + offs, mask=mask)
tl.store(dst_base + offs, vals, mask=mask)

# Copy scale region: SCALE_PER_SUB bytes
scale_src_off = SRC_SCALE_OFF + sub * SCALE_PER_SUB
for start in range(0, SCALE_PER_SUB, BLOCK_SIZE):
offs = start + tl.arange(0, BLOCK_SIZE)
mask = offs < SCALE_PER_SUB
vals = tl.load(src_base + scale_src_off + offs, mask=mask)
tl.store(dst_base + DST_SCALE_OFF + offs, vals, mask=mask)


def _split_kv_pages_to_64(kv_u8: torch.Tensor, src_pbs: int) -> torch.Tensor:
"""Split pbs=N footer-format pages into pbs=64 footer-format pages.

Uses a fused Triton kernel to do all sub-page copies in a single launch
instead of 8 separate copy kernels (4 sub-pages × 2 regions).
"""
assert src_pbs % _PBS_DST == 0 and src_pbs >= _PBS_DST
if src_pbs == _PBS_DST:
return kv_u8

N = kv_u8.shape[0]
ratio = src_pbs // _PBS_DST
num_dst_pages = N * ratio

dev = kv_u8.device
buf = _split_buf.get(dev)
if buf is None or buf.shape[0] < num_dst_pages:
buf = torch.empty(
num_dst_pages,
_BYTES_PER_DST_PAGE_PADDED,
dtype=torch.uint8,
device=dev,
)
_split_buf[dev] = buf
out = buf[:num_dst_pages]

# Get raw 2D view of source
src_2d = kv_u8
if src_2d.ndim == 4:
src_stride0 = src_2d.stride(0)
src_2d = torch.as_strided(src_2d, (N, src_stride0), (src_stride0, 1))
else:
src_stride0 = src_2d.stride(0)

grid = (N * ratio,)
_page_split_kernel[grid](
src_2d,
out,
N,
src_stride0,
_BYTES_PER_DST_PAGE_PADDED,
_PBS_DST * _NOPE_ROPE_STRIDE, # DATA_PER_SUB = 36864
_PBS_DST * _SCALE_STRIDE, # SCALE_PER_SUB = 512
src_pbs * _NOPE_ROPE_STRIDE, # SRC_SCALE_OFF = 147456
_PBS_DST * _NOPE_ROPE_STRIDE, # DST_SCALE_OFF = 36864
ratio, # RATIO = 4
1024, # BLOCK_SIZE
)

bpt = _NOPE_ROPE_STRIDE + _SCALE_STRIDE # 584
return out.as_strided(
(num_dst_pages, _PBS_DST, 1, bpt),
(_BYTES_PER_DST_PAGE_PADDED, bpt, bpt, 1),
)


def _flash_mla_flashinfer(
q,
k_cache,
indices,
topk_length,
attn_sink,
head_dim_v,
softmax_scale,
extra_k_cache,
extra_indices,
extra_topk_length,
):
"""FlashInfer SM120 sparse MLA via BatchSparseMLAPagedAttentionWrapper.

SGLang SWA pool uses page_size=256 (footer format: 256*576 bytes data + 256*8 bytes scale).
FlashInfer decode_dsv4 fast path requires page_block_size=64 (footer: 64*576 + 64*8).
We split 256-token pages into 4 virtual 64-token pages.
Token indices are invariant under page-split (identity mapping).
"""
from flashinfer.sparse_mla_sm120 import BatchSparseMLAPagedAttentionWrapper

B, _, H, D = q.shape # (batch, 1, num_heads, head_dim)

dev = q.device
wrapper = _flashinfer_wrapper.get(dev)
if wrapper is None:
wrapper = BatchSparseMLAPagedAttentionWrapper(
d_v=head_dim_v,
device=dev,
)
_flashinfer_wrapper[dev] = wrapper

# --- Page-split: convert pbs=N kv_cache to pbs=64 view ---
kv_u8 = k_cache.view(torch.uint8) if k_cache.dtype != torch.uint8 else k_cache
src_pbs = k_cache.shape[1] if k_cache.ndim >= 3 else _PBS_SRC
kv_64 = _split_kv_pages_to_64(kv_u8, src_pbs) if src_pbs != _PBS_DST else kv_u8

extra_kv_u8 = (
extra_k_cache.view(torch.uint8)
if extra_k_cache is not None and extra_k_cache.dtype != torch.uint8
else extra_k_cache
)
# Extra cache (C4/C128) has its own page_block_size (e.g. 64 for C4, 2 for C128).
# FlashInfer decode_dsv4 supports variable extra_page_block_size natively,
# so pass extra cache as-is without splitting.
extra_kv_64 = extra_kv_u8

# Indices: no remapping needed (page-split preserves token addressing).
idx = indices.squeeze(1) if indices.dim() == 3 else indices
extra_idx = (
extra_indices.squeeze(1)
if extra_indices is not None and extra_indices.dim() == 3
else extra_indices
)

output = torch.empty(B, H, head_dim_v, dtype=torch.bfloat16, device=q.device)
out_lse = torch.empty(B, H, dtype=torch.float32, device=q.device)

# Pre-allocate split-K scratch for decode-dsv4 fast path.
topk = idx.shape[-1]
extra_topk = extra_idx.shape[-1] if extra_idx is not None else 0
_BI = 64
num_splits = (topk + _BI - 1) // _BI + (
(extra_topk + _BI - 1) // _BI if extra_topk > 0 else 0
)
mid_out = torch.empty(
B, H, num_splits, head_dim_v, dtype=torch.bfloat16, device=q.device
)
mid_lse = torch.empty(B, H, num_splits, dtype=torch.float32, device=q.device)

wrapper.run(
q=q, # (B, 1, H, D) — wrapper handles squeeze
kv_cache=kv_64,
indices=idx,
output=output,
sm_scale=softmax_scale,
topk_length=topk_length,
attn_sink=attn_sink,
extra_kv_cache=extra_kv_64,
extra_indices=extra_idx,
extra_topk_length=extra_topk_length,
out_lse=out_lse,
mid_out=mid_out,
mid_lse=mid_lse,
)

return (output.unsqueeze(1), None)
Loading