-
Notifications
You must be signed in to change notification settings - Fork 6.6k
[SM120] Add FlashInfer sparse MLA decode for DSv4-Flash #27455
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 12 commits
f33271c
042ffb6
c30adcd
148972b
ebb01e3
2be8f89
4fd93c4
d097ede
7bdc051
7c98036
cfb7733
a6c4f05
a57f897
c85a8fe
17d11e8
e9b4eef
95fecc2
4f063a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -193,17 +193,15 @@ def _sm120_sparse_decode_fwd( | |
| return out.to(torch.bfloat16), lse.permute(0, 2, 1) | ||
|
|
||
|
|
||
| # Default SM120 FlashMLA backend: "triton" (optimized) or "torch" (pure-PyTorch fallback). | ||
| # Controlled by SGLANG_SM120_TRITON_FLASHMLA env var (1=triton, 0=torch). | ||
| _sm120_default_backend = ( | ||
| "triton" if os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1" else "torch" | ||
| ) | ||
| # SM120 FlashMLA: default FlashInfer (CUTLASS SM120 sparse MLA decode). | ||
| # Override with SGLANG_SM120_FLASHMLA_BACKEND=triton|torch to force fallback. | ||
| _sm120_default_backend = os.environ.get("SGLANG_SM120_FLASHMLA_BACKEND", "flashinfer") | ||
|
|
||
|
|
||
| def flash_mla_with_kvcache_sm120(**kwargs): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we still call this "flashmla"
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about change to flashinfer_sparse_mla?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, that sounds better to me |
||
| """SM120 FlashMLA sparse decode entry point. | ||
|
|
||
| Dispatches to the Triton kernel (default) or PyTorch fallback. | ||
| Dispatches to FlashInfer (default if available), Triton, or PyTorch fallback. | ||
| """ | ||
| q = kwargs["q"] | ||
| k_cache = kwargs["k_cache"] | ||
|
|
@@ -218,6 +216,20 @@ def flash_mla_with_kvcache_sm120(**kwargs): | |
| extra_indices = kwargs.get("extra_indices_in_kvcache") | ||
| extra_topk_length = kwargs.get("extra_topk_length") | ||
|
|
||
| if _sm120_default_backend == "flashinfer": | ||
| return _flash_mla_flashinfer( | ||
| q, | ||
| k_cache, | ||
| indices, | ||
| topk_length, | ||
| attn_sink, | ||
| head_dim_v, | ||
| softmax_scale, | ||
| extra_k_cache, | ||
| extra_indices, | ||
| extra_topk_length, | ||
| ) | ||
|
|
||
| if _sm120_default_backend == "triton": | ||
| from sglang.srt.layers.attention.flash_mla_sm120_triton import ( | ||
| flash_mla_sparse_decode_triton, | ||
|
|
@@ -250,3 +262,217 @@ def flash_mla_with_kvcache_sm120(**kwargs): | |
| extra_topk_length, | ||
| ) | ||
| return (out, lse) | ||
|
|
||
|
|
||
| # Lazily initialized FlashInfer wrapper per device (reused across calls). | ||
| _flashinfer_wrapper = {} # device -> wrapper | ||
|
|
||
| # --- Page-split utilities: pbs=256 → pbs=64 --- | ||
| # SGLang SWA KV cache footer layout per 256-token page: | ||
| # [data: 256 * 576 bytes] [scale: 256 * 8 bytes] [padding] | ||
| # FlashInfer decode_dsv4 expects per 64-token page: | ||
| # [data: 64 * 576 bytes] [scale: 64 * 8 bytes] [padding to 37440] | ||
| _PBS_SRC = 256 # SGLang physical page size | ||
| _PBS_DST = 64 # FlashInfer page_block_size | ||
| _NOPE_ROPE_STRIDE = 576 # bytes per token for nope+rope | ||
| _SCALE_STRIDE = 8 # bytes per token for scale (7 + 1 pad) | ||
| _BYTES_PER_DST_PAGE = ( | ||
| _PBS_DST * _NOPE_ROPE_STRIDE + _PBS_DST * _SCALE_STRIDE | ||
| ) # 64*576 + 64*8 = 37376 + 512 = 37888 | ||
| # Padded to 576 alignment | ||
| import math as _math | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to make it a private import like this
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. |
||
|
|
||
| _BYTES_PER_DST_PAGE_PADDED = _math.ceil(_BYTES_PER_DST_PAGE / 576) * 576 # 37440 | ||
|
|
||
| # Pre-allocated buffer for page-split output per device (lazily sized). | ||
| _split_buf = {} # device -> tensor | ||
|
|
||
| import triton | ||
| import triton.language as tl | ||
|
|
||
|
|
||
| @triton.jit | ||
| def _page_split_kernel( | ||
| src_ptr, | ||
| dst_ptr, | ||
| N_pages, | ||
| src_stride0: tl.constexpr, | ||
| dst_stride0: tl.constexpr, | ||
| DATA_PER_SUB: tl.constexpr, # 64 * 576 = 36864 | ||
| SCALE_PER_SUB: tl.constexpr, # 64 * 8 = 512 | ||
| SRC_SCALE_OFF: tl.constexpr, # 256 * 576 = 147456 | ||
| DST_SCALE_OFF: tl.constexpr, # 64 * 576 = 36864 | ||
| RATIO: tl.constexpr, # 4 | ||
| BLOCK_SIZE: tl.constexpr, | ||
| ): | ||
| """Fused page-split: copy data+scale for all sub-pages in one kernel.""" | ||
| pid = tl.program_id(0) | ||
| page_idx = pid // RATIO | ||
| sub = pid % RATIO | ||
|
|
||
| if page_idx >= N_pages: | ||
| return | ||
|
|
||
| src_base = src_ptr + page_idx * src_stride0 | ||
| dst_base = dst_ptr + (page_idx * RATIO + sub) * dst_stride0 | ||
|
|
||
| # Copy data region: DATA_PER_SUB bytes from src offset sub*DATA_PER_SUB | ||
| data_src_off = sub * DATA_PER_SUB | ||
| for start in range(0, DATA_PER_SUB, BLOCK_SIZE): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we use tl.range? In both loops
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| offs = start + tl.arange(0, BLOCK_SIZE) | ||
| mask = offs < DATA_PER_SUB | ||
| vals = tl.load(src_base + data_src_off + offs, mask=mask) | ||
| tl.store(dst_base + offs, vals, mask=mask) | ||
|
|
||
| # Copy scale region: SCALE_PER_SUB bytes | ||
| scale_src_off = SRC_SCALE_OFF + sub * SCALE_PER_SUB | ||
| for start in range(0, SCALE_PER_SUB, BLOCK_SIZE): | ||
| offs = start + tl.arange(0, BLOCK_SIZE) | ||
| mask = offs < SCALE_PER_SUB | ||
| vals = tl.load(src_base + scale_src_off + offs, mask=mask) | ||
| tl.store(dst_base + DST_SCALE_OFF + offs, vals, mask=mask) | ||
|
|
||
|
|
||
| def _split_kv_pages_to_64(kv_u8: torch.Tensor, src_pbs: int) -> torch.Tensor: | ||
| """Split pbs=N footer-format pages into pbs=64 footer-format pages. | ||
|
|
||
| Uses a fused Triton kernel to do all sub-page copies in a single launch | ||
| instead of 8 separate copy kernels (4 sub-pages × 2 regions). | ||
| """ | ||
| assert src_pbs % _PBS_DST == 0 and src_pbs >= _PBS_DST | ||
| if src_pbs == _PBS_DST: | ||
| return kv_u8 | ||
|
|
||
| N = kv_u8.shape[0] | ||
| ratio = src_pbs // _PBS_DST | ||
| num_dst_pages = N * ratio | ||
|
|
||
| dev = kv_u8.device | ||
| buf = _split_buf.get(dev) | ||
| if buf is None or buf.shape[0] < num_dst_pages: | ||
| buf = torch.empty( | ||
| num_dst_pages, | ||
| _BYTES_PER_DST_PAGE_PADDED, | ||
| dtype=torch.uint8, | ||
| device=dev, | ||
| ) | ||
| _split_buf[dev] = buf | ||
| out = buf[:num_dst_pages] | ||
|
|
||
| # Get raw 2D view of source | ||
| src_2d = kv_u8 | ||
| if src_2d.ndim == 4: | ||
| src_stride0 = src_2d.stride(0) | ||
| src_2d = torch.as_strided(src_2d, (N, src_stride0), (src_stride0, 1)) | ||
| else: | ||
| src_stride0 = src_2d.stride(0) | ||
|
|
||
| grid = (N * ratio,) | ||
| _page_split_kernel[grid]( | ||
| src_2d, | ||
| out, | ||
| N, | ||
| src_stride0, | ||
| _BYTES_PER_DST_PAGE_PADDED, | ||
| _PBS_DST * _NOPE_ROPE_STRIDE, # DATA_PER_SUB = 36864 | ||
| _PBS_DST * _SCALE_STRIDE, # SCALE_PER_SUB = 512 | ||
| src_pbs * _NOPE_ROPE_STRIDE, # SRC_SCALE_OFF = 147456 | ||
| _PBS_DST * _NOPE_ROPE_STRIDE, # DST_SCALE_OFF = 36864 | ||
| ratio, # RATIO = 4 | ||
| 1024, # BLOCK_SIZE | ||
| ) | ||
|
|
||
| bpt = _NOPE_ROPE_STRIDE + _SCALE_STRIDE # 584 | ||
| return out.as_strided( | ||
| (num_dst_pages, _PBS_DST, 1, bpt), | ||
| (_BYTES_PER_DST_PAGE_PADDED, bpt, bpt, 1), | ||
| ) | ||
|
|
||
|
|
||
| def _flash_mla_flashinfer( | ||
| q, | ||
| k_cache, | ||
| indices, | ||
| topk_length, | ||
| attn_sink, | ||
| head_dim_v, | ||
| softmax_scale, | ||
| extra_k_cache, | ||
| extra_indices, | ||
| extra_topk_length, | ||
| ): | ||
| """FlashInfer SM120 sparse MLA via BatchSparseMLAPagedAttentionWrapper. | ||
|
|
||
| SGLang SWA pool uses page_size=256 (footer format: 256*576 bytes data + 256*8 bytes scale). | ||
| FlashInfer decode_dsv4 fast path requires page_block_size=64 (footer: 64*576 + 64*8). | ||
| We split 256-token pages into 4 virtual 64-token pages. | ||
| Token indices are invariant under page-split (identity mapping). | ||
| """ | ||
| from flashinfer.sparse_mla_sm120 import BatchSparseMLAPagedAttentionWrapper | ||
|
|
||
| B, _, H, D = q.shape # (batch, 1, num_heads, head_dim) | ||
|
|
||
| dev = q.device | ||
| wrapper = _flashinfer_wrapper.get(dev) | ||
| if wrapper is None: | ||
| wrapper = BatchSparseMLAPagedAttentionWrapper( | ||
| d_v=head_dim_v, | ||
| device=dev, | ||
| ) | ||
| _flashinfer_wrapper[dev] = wrapper | ||
|
|
||
| # --- Page-split: convert pbs=N kv_cache to pbs=64 view --- | ||
| kv_u8 = k_cache.view(torch.uint8) if k_cache.dtype != torch.uint8 else k_cache | ||
| src_pbs = k_cache.shape[1] if k_cache.ndim >= 3 else _PBS_SRC | ||
| kv_64 = _split_kv_pages_to_64(kv_u8, src_pbs) if src_pbs != _PBS_DST else kv_u8 | ||
|
|
||
| extra_kv_u8 = ( | ||
| extra_k_cache.view(torch.uint8) | ||
| if extra_k_cache is not None and extra_k_cache.dtype != torch.uint8 | ||
| else extra_k_cache | ||
| ) | ||
| # Extra cache (C4/C128) has its own page_block_size (e.g. 64 for C4, 2 for C128). | ||
| # FlashInfer decode_dsv4 supports variable extra_page_block_size natively, | ||
| # so pass extra cache as-is without splitting. | ||
| extra_kv_64 = extra_kv_u8 | ||
|
|
||
| # Indices: no remapping needed (page-split preserves token addressing). | ||
| idx = indices.squeeze(1) if indices.dim() == 3 else indices | ||
| extra_idx = ( | ||
| extra_indices.squeeze(1) | ||
| if extra_indices is not None and extra_indices.dim() == 3 | ||
| else extra_indices | ||
| ) | ||
|
|
||
| output = torch.empty(B, H, head_dim_v, dtype=torch.bfloat16, device=q.device) | ||
| out_lse = torch.empty(B, H, dtype=torch.float32, device=q.device) | ||
|
|
||
| # Pre-allocate split-K scratch for decode-dsv4 fast path. | ||
| topk = idx.shape[-1] | ||
| extra_topk = extra_idx.shape[-1] if extra_idx is not None else 0 | ||
| _BI = 64 | ||
| num_splits = (topk + _BI - 1) // _BI + ( | ||
| (extra_topk + _BI - 1) // _BI if extra_topk > 0 else 0 | ||
| ) | ||
| mid_out = torch.empty( | ||
| B, H, num_splits, head_dim_v, dtype=torch.bfloat16, device=q.device | ||
| ) | ||
| mid_lse = torch.empty(B, H, num_splits, dtype=torch.float32, device=q.device) | ||
|
|
||
| wrapper.run( | ||
| q=q, # (B, 1, H, D) — wrapper handles squeeze | ||
| kv_cache=kv_64, | ||
| indices=idx, | ||
| output=output, | ||
| sm_scale=softmax_scale, | ||
| topk_length=topk_length, | ||
| attn_sink=attn_sink, | ||
| extra_kv_cache=extra_kv_64, | ||
| extra_indices=extra_idx, | ||
| extra_topk_length=extra_topk_length, | ||
| out_lse=out_lse, | ||
| mid_out=mid_out, | ||
| mid_lse=mid_lse, | ||
| ) | ||
|
|
||
| return (output.unsqueeze(1), None) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use the envs module instead of os.environ
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in e9b4eef. Using
envs.SGLANG_SM120_FLASHMLA_BACKEND.get().