From 40b9a04dd42f0aa420d7e737ce96e724b364c3c1 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 02:25:00 +0800 Subject: [PATCH 01/11] Add SM12x DeepSeek V4 fallback runtime Co-authored-by: OpenAI Codex Signed-off-by: jasl --- .../attention/test_deepgemm_attention.py | 89 +- tests/models/test_deepseek_v4_mega_moe.py | 55 +- tests/models/test_deepseek_v4_pp.py | 9 + .../quantization/test_fp8_scale_parameter.py | 33 + tests/quantization/test_mxfp4.py | 38 + .../test_deepseek_v4_sparse_mla_reference.py | 2781 +++++++++++++++++ .../test_sm120_deepgemm_fallbacks.py | 131 + .../v1/attention/test_sparse_attn_indexer.py | 40 + .../v1/attention/test_sparse_mla_backends.py | 614 +++- tests/v1/attention/test_sparse_mla_env.py | 96 + vllm/config/compilation.py | 1 + vllm/envs.py | 41 + .../kernels/linear/scaled_mm/cutlass.py | 45 + .../layers/deepseek_v4_attention.py | 654 +++- .../layers/deepseek_v4_triton_kernels.py | 1282 ++++++++ vllm/model_executor/layers/fused_moe/layer.py | 29 +- .../layers/quantization/utils/fp8_utils.py | 49 +- .../layers/sparse_attn_indexer.py | 164 +- vllm/model_executor/models/deepseek_v4.py | 135 +- vllm/utils/deep_gemm.py | 523 +++- .../attention/backends/mla/flashmla_sparse.py | 18 + vllm/v1/attention/backends/mla/indexer.py | 27 +- .../attention/backends/mla/sparse_mla_env.py | 150 + .../backends/mla/sparse_mla_kernels.py | 2694 ++++++++++++++++ .../backends/mla/sparse_mla_reference.py | 242 ++ vllm/v1/attention/backends/mla/sparse_swa.py | 47 + .../attention/ops/deepseek_v4_ops/__init__.py | 6 + .../ops/deepseek_v4_ops/cache_utils.py | 220 +- .../ops/deepseek_v4_ops/fp8_einsum.py | 177 ++ 29 files changed, 10197 insertions(+), 193 deletions(-) create mode 100644 tests/models/test_deepseek_v4_pp.py create mode 100644 tests/quantization/test_fp8_scale_parameter.py create mode 100644 tests/quantization/test_mxfp4.py create mode 100644 tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py create mode 100644 tests/v1/attention/test_sm120_deepgemm_fallbacks.py create mode 100644 tests/v1/attention/test_sparse_attn_indexer.py create mode 100644 tests/v1/attention/test_sparse_mla_env.py create mode 100644 vllm/model_executor/layers/deepseek_v4_triton_kernels.py create mode 100644 vllm/v1/attention/backends/mla/sparse_mla_env.py create mode 100644 vllm/v1/attention/backends/mla/sparse_mla_kernels.py create mode 100644 vllm/v1/attention/backends/mla/sparse_mla_reference.py create mode 100644 vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index 0cea46d6284f..01f030836527 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -10,10 +10,12 @@ _ceil_to_ue8m0, calc_diff, fp8_fp4_mqa_logits, - fp8_fp4_paged_mqa_logits, get_num_sms, get_paged_mqa_logits_metadata, ) +from vllm.utils.deep_gemm import ( + fp8_fp4_paged_mqa_logits as fp8_paged_mqa_logits, +) from vllm.utils.import_utils import has_deep_gemm from vllm.utils.math_utils import cdiv @@ -90,10 +92,64 @@ def _ref_fp8_mqa_logits( return logits +def _supports_deepgemm_optimized_mqa_logits() -> bool: + return current_platform.is_cuda() and ( + current_platform.is_device_capability(90) + or current_platform.is_device_capability_family(100) + ) + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_mqa_logits_torch_path(): + torch.manual_seed(0) + + seq_len, seq_len_kv, num_heads, head_dim = 9, 17, 32, 32 + q = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = (torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3) + cu_seqlen_ke = torch.minimum( + torch.arange(seq_len, device="cuda", dtype=torch.int32) + 4, + torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32), + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4) + kv_scale = (kv_amax / 448.0).squeeze(1).contiguous() + kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn) + + logits = fp8_fp4_mqa_logits( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + + kv_dequant = kv_fp8.float() * kv_scale[:, None] + score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant) + ref_logits = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0) + offsets = torch.arange(seq_len_kv, device="cuda") + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + ref_logits = ref_logits.masked_fill(~valid, float("-inf")) + + assert torch.equal(torch.isneginf(logits), torch.isneginf(ref_logits)) + finite = torch.isfinite(ref_logits) + assert (logits[finite] - ref_logits[finite]).abs().max() < 1e-4 + + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") @pytest.mark.skipif( - not current_platform.has_device_capability(90), reason="SM90 and SM100 only" + not _supports_deepgemm_optimized_mqa_logits(), reason="SM90 and SM100 only" ) @pytest.mark.parametrize("clean_logits", [True, False]) def test_deepgemm_fp8_mqa_logits(clean_logits: bool): @@ -150,7 +206,7 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool): assert diff < 1e-3, f"{diff=}" -def _ref_fp8_fp4_paged_mqa_logits( +def _ref_fp8_paged_mqa_logits( q: torch.Tensor, kv_cache: torch.Tensor, weights: torch.Tensor, @@ -203,12 +259,10 @@ def _ref_fp8_fp4_paged_mqa_logits( @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") @pytest.mark.skipif( - not current_platform.has_device_capability(90), reason="SM90 and SM100 only" + not _supports_deepgemm_optimized_mqa_logits(), reason="SM90 and SM100 only" ) -def test_deepgemm_fp8_fp4_paged_mqa_logits(): - # NOTE: clean_logits=True is incompatible with the 2D context_lens - # required by csrc/apis/attention.hpp; only the False path is exercised. - clean_logits = False +@pytest.mark.parametrize("clean_logits", [True, False]) +def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool): torch.manual_seed(0) random.seed(0) @@ -260,29 +314,24 @@ def test_deepgemm_fp8_fp4_paged_mqa_logits(): q_fp8 = q.to(torch.float8_e4m3fn) kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) - # deep_gemm paged MQA logits requires 2D context_lens of - # shape (B, next_n) (csrc/apis/attention.hpp:332-335); - # see indexer.py:607-608. For each batch/next_n token, the - # effective context length is context_lens[b] - next_n + j + 1. - next_n_arange = torch.arange(next_n, device="cuda", dtype=torch.int32) - context_lens_2d = ( - context_lens.unsqueeze(-1) - next_n + 1 + next_n_arange - ).contiguous() + deepgemm_context_lens = ( + context_lens[:, None].expand(-1, next_n).contiguous() + ) schedule_metadata = get_paged_mqa_logits_metadata( - context_lens_2d, blocksize, get_num_sms() + deepgemm_context_lens, blocksize, get_num_sms() ) - logits = fp8_fp4_paged_mqa_logits( + logits = fp8_paged_mqa_logits( (q_fp8, None), kv_cache_fp8, weights, - context_lens_2d, + deepgemm_context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=clean_logits, ) - ref_logits = _ref_fp8_fp4_paged_mqa_logits( + ref_logits = _ref_fp8_paged_mqa_logits( q, kv_cache, weights, diff --git a/tests/models/test_deepseek_v4_mega_moe.py b/tests/models/test_deepseek_v4_mega_moe.py index 304f044868a3..7da906f58446 100644 --- a/tests/models/test_deepseek_v4_mega_moe.py +++ b/tests/models/test_deepseek_v4_mega_moe.py @@ -9,6 +9,7 @@ from vllm.model_executor.models.deepseek_v4 import ( DeepseekV4MegaMoEExperts, _stage_deepseek_v4_mega_moe_inputs, + _use_deepseek_v4_mega_moe, make_deepseek_v4_expert_params_mapping, ) from vllm.platforms import current_platform @@ -19,6 +20,52 @@ ) +def _make_mega_moe_config( + *, + enable_expert_parallel: bool = True, + moe_backend: str = "auto", +): + return SimpleNamespace( + parallel_config=SimpleNamespace( + enable_expert_parallel=enable_expert_parallel + ), + kernel_config=SimpleNamespace(moe_backend=moe_backend), + ) + + +def test_deepseek_v4_mega_moe_selection_preserves_kernel_config(monkeypatch): + from vllm import envs + + monkeypatch.delenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", raising=False) + envs.disable_envs_cache() + + assert _use_deepseek_v4_mega_moe( + _make_mega_moe_config(moe_backend="deep_gemm_mega_moe") + ) + assert not _use_deepseek_v4_mega_moe(_make_mega_moe_config()) + with pytest.raises(NotImplementedError, match="requires expert parallel"): + _use_deepseek_v4_mega_moe( + _make_mega_moe_config( + enable_expert_parallel=False, + moe_backend="deep_gemm_mega_moe", + ) + ) + + +def test_deepseek_v4_mega_moe_selection_env_override(monkeypatch): + from vllm import envs + + monkeypatch.setenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "1") + envs.disable_envs_cache() + assert _use_deepseek_v4_mega_moe(_make_mega_moe_config()) + + monkeypatch.setenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "0") + envs.disable_envs_cache() + assert not _use_deepseek_v4_mega_moe( + _make_mega_moe_config(moe_backend="deep_gemm_mega_moe") + ) + + def test_deepseek_v4_mega_moe_expert_mapping(): mapping = make_deepseek_v4_expert_params_mapping(2) @@ -46,7 +93,8 @@ def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float(): def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): vllm_config = SimpleNamespace( - scheduler_config=SimpleNamespace(max_num_batched_tokens=4) + scheduler_config=SimpleNamespace(max_num_batched_tokens=4), + compilation_config=SimpleNamespace(static_forward_context={}), ) experts = DeepseekV4MegaMoEExperts( vllm_config, @@ -111,7 +159,10 @@ def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.", ) def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact(): - from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8 + per_token_cast_to_fp8 = pytest.importorskip( + "deep_gemm.utils", + reason="DeepGEMM helper package is required for FP8 staging parity.", + ).per_token_cast_to_fp8 device = torch.device("cuda") num_tokens = 7 diff --git a/tests/models/test_deepseek_v4_pp.py b/tests/models/test_deepseek_v4_pp.py new file mode 100644 index 000000000000..7c0ae5dfd725 --- /dev/null +++ b/tests/models/test_deepseek_v4_pp.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.models.deepseek_v4 import DeepseekV4ForCausalLM +from vllm.model_executor.models.interfaces import supports_pp + + +def test_deepseek_v4_declares_pipeline_parallel_support(): + assert supports_pp(DeepseekV4ForCausalLM) diff --git a/tests/quantization/test_fp8_scale_parameter.py b/tests/quantization/test_fp8_scale_parameter.py new file mode 100644 index 000000000000..c95cdbb98be2 --- /dev/null +++ b/tests/quantization/test_fp8_scale_parameter.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm.model_executor.parameter as parameter +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + create_fp8_scale_parameter, +) +from vllm.model_executor.parameter import BlockQuantScaleParameter + + +@pytest.mark.skipif( + not hasattr(torch, "float8_e8m0fnu"), + reason="torch does not expose float8_e8m0fnu", +) +def test_create_fp8_scale_parameter_initializes_e8m0(monkeypatch): + monkeypatch.setattr(parameter, "get_tensor_model_parallel_rank", lambda: 0) + monkeypatch.setattr(parameter, "get_tensor_model_parallel_world_size", lambda: 1) + + scale = create_fp8_scale_parameter( + BlockQuantScaleParameter, + output_partition_sizes=[128], + input_size_per_partition=128, + block_size=[128, 128], + weight_loader=None, + scale_dtype=torch.float8_e8m0fnu, + ) + + assert scale.dtype == torch.float8_e8m0fnu + raw_scale = scale.data.view(torch.uint8) + assert torch.equal(raw_scale, torch.zeros_like(raw_scale)) diff --git a/tests/quantization/test_mxfp4.py b/tests/quantization/test_mxfp4.py new file mode 100644 index 000000000000..6e6792e60cae --- /dev/null +++ b/tests/quantization/test_mxfp4.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +def test_mxfp4_e8m0_scale_loading_preserves_raw_bytes(): + from types import SimpleNamespace + + import pytest + import torch + + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if e8m0_dtype is None: + pytest.skip("torch does not expose float8_e8m0fnu") + + layer = object.__new__(FusedMoE) + layer.moe_config = SimpleNamespace(is_act_and_mul=True) + + expert_data = torch.zeros((4, 2), dtype=torch.uint8) + loaded_scale = torch.tensor( + [[0.0078125, 0.015625], [0.5, 1.0]], + dtype=e8m0_dtype, + ) + + layer._load_w13( + expert_data=expert_data, + shard_dim=0, + shard_id="w1", + loaded_weight=loaded_scale, + tp_rank=0, + ) + + torch.testing.assert_close( + expert_data[:2], + loaded_scale.view(torch.uint8), + rtol=0, + atol=0, + ) diff --git a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py new file mode 100644 index 000000000000..f2e43c59ce61 --- /dev/null +++ b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py @@ -0,0 +1,2781 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Correctness tests for the DeepSeek V4 Triton sparse MLA path and reference oracle.""" + +from types import SimpleNamespace + +import pytest +import torch + +from vllm.config.compilation import ( + CompilationConfig, + CompilationMode, + CUDAGraphMode, +) +from vllm.model_executor.layers import ( + deepseek_v4_attention as deepseek_v4_attention_module, +) +from vllm.model_executor.layers.deepseek_v4_attention import ( + _allocate_deepseek_v4_wo_a_output, + _deepseek_v4_fp8_einsum_config, + _sparse_mla_prefill_workspace_bounds, + deepseek_v4_fp8_einsum, +) +from vllm.utils.deep_gemm import fp8_einsum +from vllm.v1.attention.backend import AttentionCGSupport +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + FlashMLASparseMetadataBuilder, +) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + disable_triton_sparse_mla_cudagraphs_if_enabled, + triton_sparse_mla_topk_chunk_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk, + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_fp8ds_paged_sparse_mla_attention_chunk, + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead, + accumulate_gathered_sparse_mla_attention_chunk, + accumulate_indexed_sparse_mla_attention_chunk, + build_combined_sparse_mla_decode_valid_mask, + finish_gathered_sparse_mla_attention, + finish_materialized_sparse_mla_scores_with_sink, + finish_sparse_mla_attention_with_sink, + finish_two_sparse_mla_attention_states_with_sink, + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead, + fp8ds_paged_sparse_mla_attention_with_sink_multihead, + matmul_sparse_mla_attention_with_sink, + merge_sparse_mla_subset_with_sink, + merge_two_sparse_mla_subsets_with_sink, + sparse_mla_decode_head_block_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_reference import ( + accumulate_reference_attention_chunk, + finish_reference_attention_no_sink, + merge_reference_attention_with_sink, + new_reference_attention_state, + reference_attention_no_sink, + reference_sparse_mla_prefill, + sink_aware_reference_attention, +) +from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadataBuilder +from vllm.v1.attention.ops.deepseek_v4_ops import ( + dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + dequantize_global_slots_k_cache, +) +from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import ( + deepseek_v4_sm12_fp8_einsum, +) +from vllm.v1.kv_cache_interface import MLAAttentionSpec, SlidingWindowMLASpec + +_FP8_DIM = 448 +_ROPE_DIM = 64 +_SCALE_DIM = 8 +_TOKEN_DATA_SIZE = _FP8_DIM + _ROPE_DIM * 2 + + +class _FakeWorkspaceManager: + def get_simultaneous(self, *specs): + return tuple(torch.empty(shape, dtype=dtype) for shape, dtype in specs) + + +def _assert_fp8_einsum_close(actual: torch.Tensor, expected: torch.Tensor) -> None: + # The Triton path and DeepGEMM reference both accumulate in FP32, but + # their reduction orders are not bit-identical before the final BF16 store. + torch.testing.assert_close(actual.float(), expected.float(), rtol=5e-2, atol=3e-4) + + +def test_deepseek_v4_fp8_einsum_is_piecewise_split_op() -> None: + assert "vllm::deepseek_v4_fp8_einsum" in CompilationConfig._attention_ops + + +def test_wo_a_output_allocation_uses_empty_during_compile(monkeypatch) -> None: + class FailingWorkspaceManager: + def get_simultaneous(self, *args, **kwargs): + raise AssertionError("compiled allocation must not grow workspace") + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: FailingWorkspaceManager(), + ) + monkeypatch.setattr(torch.compiler, "is_compiling", lambda: True) + + output = _allocate_deepseek_v4_wo_a_output( + 2, + 3, + 5, + torch.bfloat16, + torch.device("cpu"), + ) + + assert output.shape == (2, 3, 5) + assert output.dtype == torch.bfloat16 + + +def test_wo_a_output_allocation_uses_workspace_outside_compile(monkeypatch) -> None: + captured = {} + + class FakeWorkspaceManager: + def get_simultaneous(self, *shapes_and_dtypes): + captured["request"] = shapes_and_dtypes + return [ + torch.empty(shape, dtype=dtype) for shape, dtype in shapes_and_dtypes + ] + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: FakeWorkspaceManager(), + ) + monkeypatch.setattr(torch.compiler, "is_compiling", lambda: False) + + output = _allocate_deepseek_v4_wo_a_output( + 2, + 3, + 5, + torch.bfloat16, + torch.device("cpu"), + ) + + assert captured["request"] == (((2, 3, 5), torch.bfloat16),) + assert output.shape == (2, 3, 5) + assert output.dtype == torch.bfloat16 + + +def test_triton_sparse_mla_default_topk_chunk_size(monkeypatch) -> None: + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) + + assert triton_sparse_mla_topk_chunk_size() == 512 + + +def test_sparse_mla_prefill_workspace_bounds_use_active_prefill_lengths() -> None: + seq_lens_cpu = torch.tensor([15_000, 2_048], dtype=torch.int32) + gather_lens_cpu = torch.tensor([15_000, 2_048], dtype=torch.int32) + + compressed_region_size, row_stride = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=4, + swa_only=False, + ) + + assert compressed_region_size == 3_750 + assert row_stride == 18_750 + + +def test_sparse_mla_prefill_workspace_bounds_for_swa_only() -> None: + seq_lens_cpu = torch.tensor([15_000], dtype=torch.int32) + gather_lens_cpu = torch.tensor([15_000], dtype=torch.int32) + + compressed_region_size, row_stride = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=1, + swa_only=True, + ) + + assert compressed_region_size == 0 + assert row_stride == 15_000 + + +@pytest.mark.parametrize( + ("num_decode_tokens", "expected_head_block_size"), + [ + (0, 1), + (1, 1), + (4, 1), + (5, 2), + (8, 2), + (15, 2), + (16, 4), + (32, 4), + ], +) +def test_triton_sparse_mla_decode_head_block_size( + num_decode_tokens: int, + expected_head_block_size: int, + monkeypatch, +) -> None: + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", raising=False) + + assert ( + sparse_mla_decode_head_block_size(num_decode_tokens) == expected_head_block_size + ) + + +@pytest.mark.parametrize("configured_head_block_size", ["1", "2", "4"]) +def test_triton_sparse_mla_decode_head_block_size_env_override( + configured_head_block_size: str, + monkeypatch, +) -> None: + monkeypatch.setenv( + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", + configured_head_block_size, + ) + + assert sparse_mla_decode_head_block_size(1) == int(configured_head_block_size) + assert sparse_mla_decode_head_block_size(32) == int(configured_head_block_size) + + +@pytest.mark.parametrize("configured_head_block_size", ["0", "3", "invalid"]) +def test_triton_sparse_mla_decode_head_block_size_ignores_invalid_env_override( + configured_head_block_size: str, + monkeypatch, +) -> None: + monkeypatch.setenv( + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", + configured_head_block_size, + ) + + assert sparse_mla_decode_head_block_size(8) == 2 + + +def test_swa_mtp_decode_triton_uses_global_swa_slots(monkeypatch) -> None: + captured: dict[str, torch.Tensor] = {} + + def fail_paged_attention_with_sink_multihead(**kwargs) -> None: + raise AssertionError("MTP SWA decode must use explicit SWA indices") + + def fake_accumulate_global_slots(**kwargs) -> None: + captured["slot_ids"] = kwargs["slot_ids"] + captured["lens"] = kwargs["lens"] + + def fake_finish_with_sink(*args, **kwargs) -> None: + kwargs["output"].zero_() + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: _FakeWorkspaceManager(), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "fp8ds_paged_sparse_mla_attention_with_sink_multihead", + fail_paged_attention_with_sink_multihead, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead", + fake_accumulate_global_slots, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "finish_sparse_mla_attention_with_sink", + fake_finish_with_sink, + ) + + attention = SimpleNamespace( + num_heads=2, + scale=0.1, + attn_sink=torch.zeros(2, dtype=torch.float32), + ) + swa_indices = torch.arange(48, dtype=torch.int32).reshape(6, 1, 8) + swa_lens = torch.tensor([2, 3, 4, 2, 3, 4], dtype=torch.int32) + metadata = SimpleNamespace( + num_decodes=2, + num_decode_tokens=6, + decode_swa_lens=swa_lens, + decode_swa_indices=swa_indices, + seq_lens=torch.tensor([11, 22], dtype=torch.int32), + block_table=torch.empty((2, 4), dtype=torch.int32), + block_size=256, + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32), + ) + + deepseek_v4_attention_module.DeepseekV4MLAAttention._forward_sparse_mla_swa_decode_triton( + attention, + q=torch.empty((6, 1, 2, 512), dtype=torch.bfloat16), + swa_k_cache=torch.empty((1, 256, 584), dtype=torch.uint8), + swa_metadata=metadata, + output=torch.empty((6, 2, 512), dtype=torch.bfloat16), + ) + + torch.testing.assert_close(captured["slot_ids"], swa_indices) + torch.testing.assert_close(captured["lens"], swa_lens) + + +def test_compressed_mtp_decode_triton_uses_global_swa_slots(monkeypatch) -> None: + captured: list[torch.Tensor] = [] + + def fail_matmul_decode(**kwargs) -> None: + raise AssertionError("MTP compressed decode must not stage paged SWA") + + def fail_direct_global_paged(**kwargs) -> None: + raise AssertionError("MTP compressed decode must not use paged SWA window") + + def fake_accumulate_global_slots(**kwargs) -> None: + captured.append(kwargs["slot_ids"]) + + def fake_finish_two_states(*args, **kwargs) -> None: + kwargs["output"].zero_() + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: _FakeWorkspaceManager(), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "dequantize_combined_sparse_mla_decode_kv", + fail_matmul_decode, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "fp8ds_global_paged_sparse_mla_attention_with_sink_multihead", + fail_direct_global_paged, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead", + fake_accumulate_global_slots, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "finish_two_sparse_mla_attention_states_with_sink", + fake_finish_two_states, + ) + + attention = SimpleNamespace( + num_heads=2, + scale=0.1, + attn_sink=torch.zeros(2, dtype=torch.float32), + compress_ratio=4, + ) + swa_indices = torch.arange(48, dtype=torch.int32).reshape(6, 1, 8) + topk_slot_ids = torch.arange(24, dtype=torch.int32).reshape(6, 1, 4) + swa_metadata = SimpleNamespace( + num_decodes=2, + num_decode_tokens=6, + decode_swa_lens=torch.full((6,), 3, dtype=torch.int32), + decode_swa_indices=swa_indices, + seq_lens=torch.tensor([11, 22], dtype=torch.int32), + block_table=torch.empty((2, 4), dtype=torch.int32), + block_size=256, + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32), + ) + + deepseek_v4_attention_module.DeepseekV4MLAAttention._forward_sparse_mla_compressed_decode_triton( + attention, + q=torch.empty((6, 1, 2, 512), dtype=torch.bfloat16), + compressed_k_cache=torch.empty((1, 64, 584), dtype=torch.uint8), + swa_k_cache=torch.empty((1, 256, 584), dtype=torch.uint8), + topk_indices=topk_slot_ids, + topk_lens=torch.full((6,), 4, dtype=torch.int32), + swa_metadata=swa_metadata, + attn_metadata=SimpleNamespace(block_size=256), + output=torch.empty((6, 2, 512), dtype=torch.bfloat16), + ) + + assert len(captured) == 2 + torch.testing.assert_close(captured[0], topk_slot_ids[:, 0]) + torch.testing.assert_close(captured[1], swa_indices) + + +@pytest.mark.parametrize( + ("capability_major", "expected_recipe", "expected_tma_aligned"), + [ + (9, (1, 128, 128), False), + (10, (1, 1, 128), True), + (12, (1, 128, 128), False), + ], +) +def test_deepseek_v4_fp8_einsum_config_for_sm12x( + capability_major: int, + expected_recipe: tuple[int, int, int], + expected_tma_aligned: bool, +) -> None: + assert _deepseek_v4_fp8_einsum_config(capability_major) == ( + expected_recipe, + expected_tma_aligned, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("use_e8m0_scale", [False, True]) +def test_deepseek_v4_sm12_triton_fp8_einsum_matches_deepgemm_reference( + use_e8m0_scale: bool, +) -> None: + if use_e8m0_scale and not hasattr(torch, "float8_e8m0fnu"): + pytest.skip("torch does not expose float8_e8m0fnu") + torch.manual_seed(0) + num_tokens = 17 + num_groups = 4 + hidden_size = 4096 + out_rank = 1024 + recipe = (1, 128, 128) + + a_backing = torch.randn( + (num_groups, num_tokens, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + a = a_backing.transpose(0, 1) + a_scale_backing = torch.empty( + (num_groups, num_tokens, hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + a_scale = a_scale_backing.transpose(0, 1) + b_flat = torch.randn( + (num_groups * out_rank, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + b = b_flat.view(num_groups, out_rank, hidden_size) + if use_e8m0_scale: + scale_choices = torch.tensor( + [0.00390625, 0.0078125, 0.015625, 0.03125], + device="cuda", + dtype=torch.float32, + ) + scale_indices = torch.randint( + 0, + len(scale_choices), + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + ) + b_scale_flat = scale_choices[scale_indices].to(torch.float8_e8m0fnu) + b_scale_ref_flat = b_scale_flat.to(torch.float32) + else: + b_scale_flat = torch.empty( + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + b_scale_ref_flat = b_scale_flat + b_scale_ref = b_scale_ref_flat.view(num_groups, out_rank // 128, hidden_size // 128) + expected = torch.empty( + (num_tokens, num_groups, out_rank), + device="cuda", + dtype=torch.bfloat16, + ) + actual = torch.empty_like(expected) + + fp8_einsum( + "bhr,hdr->bhd", + (a, a_scale), + (b, b_scale_ref), + expected, + recipe=recipe, + ) + deepseek_v4_fp8_einsum( + a, + a_scale, + b_flat, + b_scale_flat, + actual, + "bhr,hdr->bhd", + list(recipe), + ) + + _assert_fp8_einsum_close(actual, expected) + + +def test_deepseek_v4_fp8_einsum_slices_full_group_weight_for_tp( + monkeypatch, +) -> None: + captured: dict[str, torch.Tensor] = {} + num_tokens = 2 + local_groups = 4 + full_groups = 8 + out_rank = 512 + hidden_size = 4096 + recipe = [1, 128, 128] + + def fake_sm12_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + ) -> None: + captured["b"] = b + captured["b_scale"] = b_scale + + monkeypatch.setattr( + deepseek_v4_attention_module.current_platform, + "get_device_capability", + lambda: SimpleNamespace(major=12), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "get_tensor_model_parallel_rank", + lambda: 1, + raising=False, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "deepseek_v4_sm12_fp8_einsum", + fake_sm12_fp8_einsum, + ) + + a = torch.empty( + (num_tokens, local_groups, hidden_size), + dtype=torch.float8_e4m3fn, + ) + a_scale = torch.empty( + (num_tokens, local_groups, hidden_size // 128), + dtype=torch.float32, + ) + b = torch.arange( + full_groups * out_rank * hidden_size, + dtype=torch.uint8, + ).view(torch.float8_e4m3fn) + b = b.view(full_groups * out_rank, hidden_size) + b_scale = torch.arange( + full_groups * (out_rank // 128) * (hidden_size // 128), + dtype=torch.float32, + ).view(full_groups * (out_rank // 128), hidden_size // 128) + out = torch.empty((num_tokens, local_groups, out_rank), dtype=torch.bfloat16) + + deepseek_v4_fp8_einsum( + a, + a_scale, + b, + b_scale, + out, + "bhr,hdr->bhd", + recipe, + ) + + expected_b = b.view(full_groups, out_rank, hidden_size)[local_groups:] + expected_b_scale = b_scale.view(full_groups, out_rank // 128, hidden_size // 128)[ + local_groups: + ] + assert torch.equal(captured["b"].view(torch.uint8), expected_b.view(torch.uint8)) + assert torch.equal(captured["b_scale"], expected_b_scale) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_deepseek_v4_sm12_triton_fp8_einsum_primitive_matches_reference() -> None: + torch.manual_seed(0) + num_tokens = 17 + num_groups = 4 + hidden_size = 4096 + out_rank = 1024 + recipe = (1, 128, 128) + + a_backing = torch.randn( + (num_groups, num_tokens, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + a = a_backing.transpose(0, 1) + a_scale_backing = torch.empty( + (num_groups, num_tokens, hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + a_scale = a_scale_backing.transpose(0, 1) + b_flat = torch.randn( + (num_groups * out_rank, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + b = b_flat.view(num_groups, out_rank, hidden_size) + b_scale_flat = torch.empty( + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + b_scale = b_scale_flat.view(num_groups, out_rank // 128, hidden_size // 128) + expected = torch.empty( + (num_tokens, num_groups, out_rank), + device="cuda", + dtype=torch.bfloat16, + ) + actual = torch.empty_like(expected) + + fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe) + deepseek_v4_sm12_fp8_einsum(a, a_scale, b, b_scale, actual) + + _assert_fp8_einsum_close(actual, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("num_groups", [1, 2, 4]) +def test_deepseek_v4_sm12_triton_fp8_einsum_supports_tp_local_group_counts( + num_groups: int, +) -> None: + torch.manual_seed(18 + num_groups) + num_tokens = 5 + hidden_size = 4096 + out_rank = 1024 + recipe = (1, 128, 128) + + a_backing = torch.randn( + (num_groups, num_tokens, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + a = a_backing.transpose(0, 1) + a_scale_backing = torch.empty( + (num_groups, num_tokens, hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + a_scale = a_scale_backing.transpose(0, 1) + b_flat = torch.randn( + (num_groups * out_rank, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + b = b_flat.view(num_groups, out_rank, hidden_size) + b_scale_flat = torch.empty( + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + b_scale = b_scale_flat.view(num_groups, out_rank // 128, hidden_size // 128) + expected = torch.empty( + (num_tokens, num_groups, out_rank), + device="cuda", + dtype=torch.bfloat16, + ) + actual = torch.empty_like(expected) + + fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe) + deepseek_v4_sm12_fp8_einsum(a, a_scale, b, b_scale, actual) + + _assert_fp8_einsum_close(actual, expected) + + +def _masked_scores( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, +) -> torch.Tensor: + q_bhd = q[:, 0].float() if q.dim() == 4 else q.float() + scores = torch.einsum("bhd,btd->bht", q_bhd, kv.float()) * scale + return scores.masked_fill(~valid_tokens[:, None, :], float("-inf")) + + +def _golden_no_sink_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor]: + scores = _masked_scores(q, kv, valid_tokens, scale) + lse = torch.logsumexp(scores, dim=-1) + weights = torch.exp(scores - lse[:, :, None]) + weights = torch.where( + valid_tokens[:, None, :], + weights, + torch.zeros((), dtype=weights.dtype, device=weights.device), + ) + weights = torch.nan_to_num(weights) + output = torch.einsum("bht,btd->bhd", weights, kv.float()) + valid = valid_tokens.any(dim=-1) + output = torch.where( + valid[:, None, None], + output, + torch.zeros((), dtype=output.dtype, device=output.device), + ) + return output, lse + + +def _golden_sink_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, +) -> torch.Tensor: + scores = _masked_scores(q, kv, valid_tokens, scale) + sink = attn_sink[None, :].float() + score_max = scores.amax(dim=-1) + merge_max = torch.maximum(score_max, sink) + + weights = torch.exp(scores - merge_max[:, :, None]) + weights = torch.where( + valid_tokens[:, None, :], + weights, + torch.zeros((), dtype=weights.dtype, device=weights.device), + ) + weights = torch.nan_to_num(weights) + + sink_weight = torch.exp(sink - merge_max) + sink_weight = torch.nan_to_num(sink_weight) + denom = weights.sum(dim=-1) + sink_weight + numerator = torch.einsum("bht,btd->bhd", weights, kv.float()) + return numerator / denom[:, :, None] + + +def _chunked_no_sink_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + q_bhd, max_score, denom, acc = new_reference_attention_state(q) + for chunk_start in range(0, kv.shape[1], chunk_size): + chunk_end = min(chunk_start + chunk_size, kv.shape[1]) + max_score, denom, acc = accumulate_reference_attention_chunk( + q_bhd=q_bhd, + kv=kv[:, chunk_start:chunk_end], + valid_tokens=valid_tokens[:, chunk_start:chunk_end], + max_score=max_score, + denom=denom, + acc=acc, + scale=scale, + ) + return finish_reference_attention_no_sink(max_score, denom, acc) + + +def _write_fp8_ds_mla_token( + k_cache: torch.Tensor, + slot: int, + block_size: int, +) -> torch.Tensor: + block_idx = slot // block_size + block_offset = slot % block_size + + values = ( + (torch.arange(_FP8_DIM, device=k_cache.device, dtype=torch.float32) % 17) - 8 + ) / 16.0 + values = values + float(slot) / 32.0 + scale_exponents = torch.tensor( + [-2, -1, 0, 1, 2, -2, 1], + device=k_cache.device, + dtype=torch.float32, + ) + scales = torch.exp2(scale_exponents) + scale_per_dim = scales.repeat_interleave(64) + + fp8_values = (values / scale_per_dim).to(torch.float8_e4m3fn) + expected_nope = fp8_values.float() * scale_per_dim + rope = ( + torch.linspace(-1.0, 1.0, _ROPE_DIM, device=k_cache.device) + float(slot) / 16.0 + ).to(torch.bfloat16) + + flat_block = k_cache[block_idx].view(-1) + token_data_start = block_offset * _TOKEN_DATA_SIZE + token_scale_start = block_size * _TOKEN_DATA_SIZE + block_offset * _SCALE_DIM + flat_block[token_data_start : token_data_start + _FP8_DIM] = fp8_values.view( + torch.uint8 + ) + flat_block[token_data_start + _FP8_DIM : token_data_start + _TOKEN_DATA_SIZE] = ( + rope.view(torch.uint8) + ) + + encoded_scales = (scale_exponents.to(torch.int32) + 127).to(torch.uint8) + flat_block[token_scale_start : token_scale_start + encoded_scales.numel()] = ( + encoded_scales + ) + flat_block[ + token_scale_start + encoded_scales.numel() : token_scale_start + _SCALE_DIM + ] = 127 + + return torch.cat([expected_nope, rope.float()]).to(torch.bfloat16) + + +def test_reference_attention_no_sink_matches_logsumexp() -> None: + torch.manual_seed(0) + scale = 0.25 + q = torch.randn(3, 4, 5) + kv = torch.randn(3, 6, 5) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, False], + [False, False, False, False, False, False], + [True, False, True, True, True, False], + ], + dtype=torch.bool, + ) + output, lse = reference_attention_no_sink(q, kv, valid_tokens, scale) + expected_output, expected_lse = _golden_no_sink_attention( + q, + kv, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=1e-6, atol=1e-6) + torch.testing.assert_close(lse, expected_lse, rtol=1e-6, atol=1e-6) + + +def test_reference_attention_ignores_nan_kv_for_invalid_tokens() -> None: + torch.manual_seed(24) + q = torch.randn(2, 1, 3, 8) + kv = torch.randn(2, 4, 8) + kv[:, 2:] = float("nan") + valid_tokens = torch.tensor( + [[True, True, False, False], [True, False, False, False]], + dtype=torch.bool, + ) + + output, lse = reference_attention_no_sink( + q=q, + kv=kv, + valid_tokens=valid_tokens, + scale=0.125, + ) + + assert torch.isfinite(output).all() + assert torch.isfinite(lse).all() + + +def test_sink_aware_reference_attention_matches_dense_golden() -> None: + torch.manual_seed(1) + scale = 0.125 + q = torch.randn(3, 1, 4, 5) + kv = torch.randn(3, 6, 5) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, False], + [False, False, False, False, False, False], + [False, True, True, False, True, True], + ], + dtype=torch.bool, + ) + sink = torch.tensor([-1.0, 0.25, 1.5, -0.5]) + output = torch.empty(3, 4, 5) + sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, output) + expected = _golden_sink_attention(q, kv, valid_tokens, scale, sink) + + torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6) + + +def test_lse_merge_with_sink_matches_concatenated_attention() -> None: + torch.manual_seed(2) + scale = 0.2 + q = torch.randn(4, 3, 7) + compressed_kv = torch.randn(4, 5, 7) + swa_kv = torch.randn(4, 3, 7) + compressed_kv[:, 1] = compressed_kv[:, 0] + swa_kv[:, 2] = compressed_kv[:, 0] + compressed_valid = torch.tensor( + [ + [True, True, False, True, False], + [False, False, False, False, False], + [True, False, True, True, False], + [False, False, False, False, False], + ], + dtype=torch.bool, + ) + swa_valid = torch.tensor( + [ + [True, False, True], + [True, True, False], + [False, False, False], + [False, False, False], + ], + dtype=torch.bool, + ) + sink = torch.tensor([-0.25, 0.75, 1.25]) + output = torch.empty(4, 3, 7) + comp_output, comp_lse = reference_attention_no_sink( + q, + compressed_kv, + compressed_valid, + scale, + ) + swa_output, swa_lse = reference_attention_no_sink(q, swa_kv, swa_valid, scale) + merge_reference_attention_with_sink( + subset_outputs=[comp_output, swa_output], + subset_lses=[comp_lse, swa_lse], + attn_sink=sink, + output=output, + ) + + expected = _golden_sink_attention( + q, + torch.cat([compressed_kv, swa_kv], dim=1), + torch.cat([compressed_valid, swa_valid], dim=1), + scale, + sink, + ) + torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6) + assert torch.equal(output[3], torch.zeros_like(output[3])) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_lse_merge_with_sink_matches_reference() -> None: + torch.manual_seed(5) + comp_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32) + swa_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32) + comp_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32) + swa_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32) + comp_lse[1, 2] = float("-inf") + swa_lse[2, 1] = float("-inf") + sink = torch.tensor([-0.5, 0.25, 1.0, -1.5], device="cuda") + + output = torch.empty(3, 4, 9, device="cuda", dtype=torch.bfloat16) + expected = torch.empty_like(output) + merge_two_sparse_mla_subsets_with_sink( + subset0_output=comp_output, + subset0_lse=comp_lse, + subset1_output=swa_output, + subset1_lse=swa_lse, + attn_sink=sink, + output=output, + ) + merge_reference_attention_with_sink( + subset_outputs=[comp_output, swa_output], + subset_lses=[comp_lse, swa_lse], + attn_sink=sink, + output=expected, + ) + + torch.testing.assert_close(output.float(), expected.float(), rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_single_lse_merge_with_sink_matches_reference() -> None: + torch.manual_seed(14) + subset_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32) + subset_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32) + subset_lse[1, 2] = float("-inf") + sink = torch.tensor([-0.5, 0.25, 1.0, -1.5], device="cuda") + + output = torch.empty(3, 4, 9, device="cuda", dtype=torch.bfloat16) + expected = torch.empty_like(output) + merge_sparse_mla_subset_with_sink( + subset_output=subset_output, + subset_lse=subset_lse, + attn_sink=sink, + output=output, + ) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=sink, + output=expected, + ) + + torch.testing.assert_close(output.float(), expected.float(), rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_finish_with_sink_matches_finish_then_merge_reference() -> None: + torch.manual_seed(18) + max_score = torch.randn(4, 3, device="cuda", dtype=torch.float32) + denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1 + denom[1, 2] = 0.0 + max_score[1, 2] = float("-inf") + acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32) + sink = torch.tensor( + [-0.5, 0.25, 1.0, -float("inf"), -float("inf")], + device="cuda", + dtype=torch.float32, + ) + + output = torch.full((4, 5, 17), -7.0, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, output) + + subset_output = torch.empty_like(acc) + subset_lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=subset_output, + lse=subset_lse, + ) + expected = torch.empty(4, 3, 17, device="cuda", dtype=torch.bfloat16) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=sink[:3], + output=expected, + ) + + torch.testing.assert_close( + output[:, :3].float(), expected.float(), rtol=1e-2, atol=1e-2 + ) + torch.testing.assert_close( + output[:, 3:].float(), + torch.full_like(output[:, 3:].float(), -7.0), + rtol=0, + atol=0, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_finish_with_sink_returns_zero_when_no_tokens_or_sink() -> None: + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.full((2, 3, 17), float("nan"), device="cuda") + sink = torch.full((3,), float("-inf"), device="cuda") + + single_output = torch.full((2, 3, 17), 7.0, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink( + max_score, + denom, + acc, + sink, + output=single_output, + ) + torch.testing.assert_close( + single_output.float(), + torch.zeros_like(single_output.float()), + rtol=0, + atol=0, + ) + + two_output = torch.full((2, 3, 17), 7.0, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + max_score, + denom, + acc, + max_score, + denom, + acc, + sink, + output=two_output, + ) + torch.testing.assert_close( + two_output.float(), + torch.zeros_like(two_output.float()), + rtol=0, + atol=0, + ) + + +def test_triton_finish_two_states_with_sink_matches_finish_then_merge() -> None: + torch.manual_seed(22) + comp_max = torch.randn(4, 3, device="cuda", dtype=torch.float32) + comp_denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1 + comp_acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32) + swa_max = torch.randn(4, 3, device="cuda", dtype=torch.float32) + swa_denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1 + swa_acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32) + sink = torch.tensor( + [-0.5, 0.25, 1.0, -float("inf"), -float("inf")], + device="cuda", + dtype=torch.float32, + ) + + comp_denom[0, 1] = 0.0 + comp_max[0, 1] = float("-inf") + swa_denom[2, 0] = 0.0 + swa_max[2, 0] = float("-inf") + comp_denom[3, 2] = 0.0 + comp_max[3, 2] = float("-inf") + swa_denom[3, 2] = 0.0 + swa_max[3, 2] = float("-inf") + + output = torch.full((4, 5, 17), -7.0, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + comp_max, + comp_denom, + comp_acc, + swa_max, + swa_denom, + swa_acc, + sink, + output, + ) + + comp_output = torch.empty_like(comp_acc) + comp_lse = torch.empty_like(comp_max) + swa_output = torch.empty_like(swa_acc) + swa_lse = torch.empty_like(swa_max) + finish_gathered_sparse_mla_attention( + comp_max, + comp_denom, + comp_acc, + comp_output, + comp_lse, + ) + finish_gathered_sparse_mla_attention( + swa_max, + swa_denom, + swa_acc, + swa_output, + swa_lse, + ) + expected = torch.empty(4, 3, 17, device="cuda", dtype=torch.bfloat16) + merge_two_sparse_mla_subsets_with_sink( + subset0_output=comp_output, + subset0_lse=comp_lse, + subset1_output=swa_output, + subset1_lse=swa_lse, + attn_sink=sink[:3], + output=expected, + ) + + torch.testing.assert_close( + output[:, :3].float(), expected.float(), rtol=1e-2, atol=1e-2 + ) + torch.testing.assert_close( + output[:, 3:].float(), + torch.full_like(output[:, 3:].float(), -7.0), + rtol=0, + atol=0, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_dim", [16, 512]) +def test_triton_gathered_attention_chunk_matches_reference(head_dim: int) -> None: + torch.manual_seed(6) + scale = 0.125 + q = torch.randn(2, 1, 5, head_dim, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :, :3] + kv = torch.randn(2, 5, head_dim, device="cuda", dtype=torch.bfloat16) + slot_ids = torch.tensor( + [ + [0, 1, -1, 3, 4], + [5, -1, 7, 8, -1], + ], + dtype=torch.int32, + device="cuda", + ) + lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.zeros((2, 3, head_dim), device="cuda") + + accumulate_gathered_sparse_mla_attention_chunk( + q=q, + kv=kv[:, :2], + slot_ids=slot_ids[:, :2], + lens=lens, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + accumulate_gathered_sparse_mla_attention_chunk( + q=q, + kv=kv[:, 2:], + slot_ids=slot_ids[:, 2:], + lens=lens, + candidate_offset=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + offsets = torch.arange(slot_ids.shape[1], device="cuda") + valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0) + expected_output, expected_lse = reference_attention_no_sink( + q_active, + kv, + valid_tokens, + scale, + ) + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_gathered_attention_chunk_matches_reference_without_slot_ids() -> None: + torch.manual_seed(8) + scale = 0.2 + q = torch.randn(3, 1, 2, 32, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(3, 6, 32, device="cuda", dtype=torch.bfloat16) + lens = torch.tensor([6, 3, 0], dtype=torch.int32, device="cuda") + max_score = torch.full((3, 2), float("-inf"), device="cuda") + denom = torch.zeros((3, 2), device="cuda") + acc = torch.zeros((3, 2, 32), device="cuda") + + accumulate_gathered_sparse_mla_attention_chunk( + q=q, + kv=kv, + slot_ids=None, + lens=lens, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + offsets = torch.arange(kv.shape[1], device="cuda") + valid_tokens = offsets[None, :] < lens[:, None] + expected_output, expected_lse = reference_attention_no_sink( + q, + kv, + valid_tokens, + scale, + ) + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_dequantize_global_slots_k_cache_fp8_ds_mla_layout() -> None: + block_size = 4 + num_blocks = 2 + k_cache = torch.zeros( + num_blocks, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + expected_by_slot = { + slot: _write_fp8_ds_mla_token(k_cache, slot, block_size) for slot in (0, 3, 4) + } + slot_ids = torch.tensor( + [ + [0, 3, -1, 4], + [4, 0, 3, -1], + ], + dtype=torch.int32, + device="cuda", + ) + + output = torch.empty(2, 4, 512, dtype=torch.bfloat16, device="cuda") + dequantize_global_slots_k_cache(output, k_cache, slot_ids, block_size) + + expected = torch.zeros_like(output) + for token_idx in range(slot_ids.shape[0]): + for topk_idx in range(slot_ids.shape[1]): + slot = int(slot_ids[token_idx, topk_idx].item()) + if slot >= 0: + expected[token_idx, topk_idx] = expected_by_slot[slot] + + torch.testing.assert_close(output.float(), expected.float(), rtol=0, atol=0) + + output_from_3d_indices = torch.empty_like(output) + dequantize_global_slots_k_cache( + output_from_3d_indices, + k_cache, + slot_ids.unsqueeze(1), + block_size, + ) + torch.testing.assert_close( + output_from_3d_indices.float(), + expected.float(), + rtol=0, + atol=0, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_dequantize_combined_sparse_mla_decode_kv_writes_direct_views() -> None: + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 2, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 3, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + for slot in (0, 3, 4): + _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size) + for slot in (0, 1, 2, 3, 4): + _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size) + + compressed_slot_ids = torch.tensor( + [[0, 3, -1], [4, 0, 3]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([5, 7], dtype=torch.int32, device="cuda") + swa_lens = torch.tensor([2, 3], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [[0, 1, 2], [2, 0, 1]], + dtype=torch.int32, + device="cuda", + ) + + combined = torch.full( + (2, 6, 512), + -7, + dtype=torch.bfloat16, + device="cuda", + ) + dequantize_combined_sparse_mla_decode_kv( + combined, + compressed_cache, + compressed_slot_ids, + compressed_block_size, + swa_cache, + seq_lens, + swa_lens, + block_table, + swa_block_size, + ) + + expected_comp = torch.empty(2, 3, 512, dtype=torch.bfloat16, device="cuda") + expected_swa = torch.full( + (2, 3, 512), + -7, + dtype=torch.bfloat16, + device="cuda", + ) + dequantize_global_slots_k_cache( + expected_comp, + compressed_cache, + compressed_slot_ids, + compressed_block_size, + ) + dequantize_and_gather_k_cache( + expected_swa, + swa_cache, + seq_lens=seq_lens, + gather_lens=swa_lens, + block_table=block_table, + block_size=swa_block_size, + offset=0, + ) + expected = torch.full_like(combined, -7) + expected[:, :3].copy_(expected_comp) + expected[:, 3:].copy_(expected_swa) + + torch.testing.assert_close(combined.float(), expected.float(), rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_fp8ds_global_slots_attention_chunk_matches_reference() -> None: + torch.manual_seed(10) + block_size = 4 + num_blocks = 3 + k_cache = torch.zeros( + num_blocks, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + expected_by_slot = { + slot: _write_fp8_ds_mla_token(k_cache, slot, block_size) + for slot in (0, 1, 3, 4, 7, 8) + } + slot_ids = torch.tensor( + [ + [0, 3, -1, 8, 1], + [7, -1, 4, 0, 8], + ], + dtype=torch.int32, + device="cuda", + ) + lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 3, 512, device="cuda", dtype=torch.bfloat16) + scale = 0.0625 + + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.zeros((2, 3, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, :2], + lens=lens, + block_size=block_size, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, 2:], + lens=lens, + block_size=block_size, + candidate_offset=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16) + for token_idx in range(slot_ids.shape[0]): + for topk_idx in range(slot_ids.shape[1]): + slot = int(slot_ids[token_idx, topk_idx].item()) + if slot >= 0: + gathered[token_idx, topk_idx] = expected_by_slot[slot] + offsets = torch.arange(slot_ids.shape[1], device="cuda") + valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0) + expected_output, expected_lse = reference_attention_no_sink( + q, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_global_slots_multihead_attention_matches_reference( + head_block_size: int, +) -> None: + torch.manual_seed(19) + block_size = 4 + num_blocks = 3 + k_cache = torch.zeros( + num_blocks, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + expected_by_slot = { + slot: _write_fp8_ds_mla_token(k_cache, slot, block_size) + for slot in (0, 1, 3, 4, 7, 8) + } + slot_ids = torch.tensor( + [ + [0, 3, -1, 8, 1], + [7, -1, 4, 0, 8], + ], + dtype=torch.int32, + device="cuda", + ) + lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :, :5] + scale = 0.0625 + + max_score = torch.full((2, 5), float("-inf"), device="cuda") + denom = torch.zeros((2, 5), device="cuda") + acc = torch.zeros((2, 5, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, :2], + lens=lens, + block_size=block_size, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=head_block_size, + ) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, 2:], + lens=lens, + block_size=block_size, + candidate_offset=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=head_block_size, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16) + for token_idx in range(slot_ids.shape[0]): + for topk_idx in range(slot_ids.shape[1]): + slot = int(slot_ids[token_idx, topk_idx].item()) + if slot >= 0: + gathered[token_idx, topk_idx] = expected_by_slot[slot] + offsets = torch.arange(slot_ids.shape[1], device="cuda") + valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0) + expected_output, expected_lse = reference_attention_no_sink( + q_active, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_fp8ds_paged_attention_chunk_matches_reference() -> None: + torch.manual_seed(12) + block_size = 4 + k_cache = torch.zeros( + 3, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [ + [1, 0, 2], + [2, 1, 0], + ], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([6, 9], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 4], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 3, 512, device="cuda", dtype=torch.bfloat16) + scale = 0.0625 + + gathered = torch.zeros(2, 4, 512, device="cuda", dtype=torch.bfloat16) + expected_by_slot: dict[int, torch.Tensor] = {} + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + block_idx = pos // block_size + block_offset = pos % block_size + physical_block = int(block_table[token_idx, block_idx].item()) + slot = physical_block * block_size + block_offset + expected_by_slot.setdefault( + slot, + _write_fp8_ds_mla_token(k_cache, slot, block_size), + ) + gathered[token_idx, gather_idx] = expected_by_slot[slot] + + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.zeros((2, 3, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=2, + num_candidates=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + offsets = torch.arange(gathered.shape[1], device="cuda") + valid_tokens = offsets[None, :] < gather_lens[:, None] + expected_output, expected_lse = reference_attention_no_sink( + q, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_paged_multihead_attention_matches_singlehead_and_reference( + head_block_size: int, +) -> None: + torch.manual_seed(23) + block_size = 4 + k_cache = torch.zeros( + 4, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [ + [1, 0, 2, 3], + [2, 3, 1, 0], + ], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :, :5] + scale = 0.0625 + + gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16) + expected_by_slot: dict[int, torch.Tensor] = {} + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + block_idx = pos // block_size + block_offset = pos % block_size + physical_block = int(block_table[token_idx, block_idx].item()) + slot = physical_block * block_size + block_offset + expected_by_slot.setdefault( + slot, + _write_fp8_ds_mla_token(k_cache, slot, block_size), + ) + gathered[token_idx, gather_idx] = expected_by_slot[slot] + + single_max = torch.full((2, 5), float("-inf"), device="cuda") + single_denom = torch.zeros((2, 5), device="cuda") + single_acc = torch.zeros((2, 5, 512), device="cuda") + multi_max = torch.full_like(single_max, float("-inf")) + multi_denom = torch.zeros_like(single_denom) + multi_acc = torch.zeros_like(single_acc) + + for candidate_offset, num_candidates in ((0, 2), (2, 3)): + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=candidate_offset, + num_candidates=num_candidates, + scale=scale, + max_score=single_max, + denom=single_denom, + acc=single_acc, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=candidate_offset, + num_candidates=num_candidates, + scale=scale, + max_score=multi_max, + denom=multi_denom, + acc=multi_acc, + head_block_size=head_block_size, + ) + + torch.testing.assert_close(multi_max, single_max, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(multi_denom, single_denom, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(multi_acc, single_acc, rtol=2e-2, atol=2e-2) + + output = torch.empty_like(multi_acc) + lse = torch.empty_like(multi_max) + finish_gathered_sparse_mla_attention( + max_score=multi_max, + denom=multi_denom, + acc=multi_acc, + output=output, + lse=lse, + ) + offsets = torch.arange(gathered.shape[1], device="cuda") + valid_tokens = offsets[None, :] < gather_lens[:, None] + expected_output, expected_lse = reference_attention_no_sink( + q_active, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_fp8ds_paged_attention_with_sink_matches_reference() -> None: + torch.manual_seed(15) + block_size = 4 + k_cache = torch.zeros( + 3, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor([[1, 0, 2]], dtype=torch.int32, device="cuda") + seq_lens = torch.tensor([7], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([4], dtype=torch.int32, device="cuda") + q = torch.randn(1, 1, 3, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.tensor([-0.25, 0.5, 1.25], device="cuda") + scale = 0.0625 + + gathered = torch.zeros(1, 4, 512, device="cuda", dtype=torch.bfloat16) + expected_by_slot: dict[int, torch.Tensor] = {} + start_pos = int(seq_lens[0].item() - gather_lens[0].item()) + for gather_idx in range(int(gather_lens[0].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[0, pos // block_size].item()) + slot = physical_block * block_size + pos % block_size + expected_by_slot.setdefault( + slot, + _write_fp8_ds_mla_token(k_cache, slot, block_size), + ) + gathered[0, gather_idx] = expected_by_slot[slot] + + max_score = torch.full((1, 3), float("-inf"), device="cuda") + denom = torch.zeros((1, 3), device="cuda") + acc = torch.zeros((1, 3, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=4, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + subset_output = torch.empty_like(acc) + subset_lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=subset_output, + lse=subset_lse, + ) + + output = torch.empty(1, 3, 512, device="cuda", dtype=torch.bfloat16) + merge_sparse_mla_subset_with_sink( + subset_output=subset_output, + subset_lse=subset_lse, + attn_sink=sink, + output=output, + ) + valid_tokens = torch.ones(1, 4, device="cuda", dtype=torch.bool) + expected = _golden_sink_attention(q, gathered, valid_tokens, scale, sink) + + torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_paged_attention_with_sink_direct_matches_state_path( + head_block_size: int, +) -> None: + torch.manual_seed(29) + block_size = 4 + k_cache = torch.zeros( + 4, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-0.5, 0.5, 5, device="cuda") + scale = 0.0625 + + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // block_size].item()) + slot = physical_block * block_size + pos % block_size + _write_fp8_ds_mla_token(k_cache, slot, block_size) + + max_score = torch.full((2, 5), float("-inf"), device="cuda") + denom = torch.zeros((2, 5), device="cuda") + acc = torch.zeros((2, 5, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=1, + ) + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, expected) + + actual = torch.empty_like(expected) + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=head_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_global_paged_attention_with_sink_direct_matches_state_path( + head_block_size: int, +) -> None: + torch.manual_seed(31) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 4, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 4, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + slot_ids = torch.tensor( + [[0, 3, -1, 8, 1], [7, -1, 4, 0, 8]], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-1.0, 1.0, 5, device="cuda") + scale = 0.0625 + + for slot in (0, 1, 3, 4, 7, 8): + _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size) + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // swa_block_size].item()) + slot = physical_block * swa_block_size + pos % swa_block_size + _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size) + + comp_max = torch.full((2, 5), float("-inf"), device="cuda") + comp_denom = torch.zeros((2, 5), device="cuda") + comp_acc = torch.zeros((2, 5, 512), device="cuda") + swa_max = torch.full((2, 5), float("-inf"), device="cuda") + swa_denom = torch.zeros((2, 5), device="cuda") + swa_acc = torch.zeros((2, 5, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_cache, + slot_ids=slot_ids, + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=0, + scale=scale, + max_score=comp_max, + denom=comp_denom, + acc=comp_acc, + head_block_size=1, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=swa_block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=swa_max, + denom=swa_denom, + acc=swa_acc, + head_block_size=1, + ) + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + comp_max, + comp_denom, + comp_acc, + swa_max, + swa_denom, + swa_acc, + sink, + expected, + ) + + actual = torch.empty_like(expected) + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_cache, + slot_ids=slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + swa_block_size=swa_block_size, + num_compressed_candidates=5, + num_swa_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=head_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_matmul_sparse_mla_attention_with_sink_matches_reference() -> None: + torch.manual_seed(41) + q = torch.randn(2, 1, 5, 512, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(2, 7, 512, device="cuda", dtype=torch.bfloat16) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, True, True], + [False, True, True, False, True, False, False], + ], + dtype=torch.bool, + device="cuda", + ) + sink = torch.linspace(-0.25, 0.25, 5, device="cuda") + scale = 0.0625 + + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + sink_aware_reference_attention( + q, + kv, + valid_tokens, + scale, + sink, + expected, + ) + + actual = torch.empty_like(expected) + matmul_sparse_mla_attention_with_sink( + q, + kv, + valid_tokens, + scale, + sink, + actual, + num_heads=5, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_matmul_sparse_mla_attention_accepts_bf16_score_buffer() -> None: + torch.manual_seed(67) + q = torch.randn(2, 1, 5, 512, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(2, 7, 512, device="cuda", dtype=torch.bfloat16) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, True, True], + [False, True, True, False, True, False, False], + ], + dtype=torch.bool, + device="cuda", + ) + sink = torch.linspace(-0.25, 0.25, 5, device="cuda") + scale = 0.0625 + + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, expected) + + actual = torch.empty_like(expected) + score_buffer = torch.empty(2, 5, 7, device="cuda", dtype=torch.bfloat16) + matmul_sparse_mla_attention_with_sink( + q, + kv, + valid_tokens, + scale, + sink, + actual, + num_heads=5, + score_buffer=score_buffer, + value_block_size=512, + candidate_block_size=128, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize( + ("candidate_block_size", "value_block_size"), + [(32, 128), (64, 128), (64, 256), (128, 512)], +) +def test_finish_materialized_scores_candidate_block_matches_reference( + candidate_block_size: int, + value_block_size: int, +) -> None: + torch.manual_seed(61) + q = torch.randn(3, 1, 7, 512, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(3, 13, 512, device="cuda", dtype=torch.bfloat16) + valid_tokens = torch.tensor( + [ + [ + True, + True, + False, + True, + False, + True, + True, + False, + True, + True, + False, + True, + False, + ], + [ + False, + True, + True, + False, + True, + False, + False, + True, + False, + True, + True, + False, + True, + ], + [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + ], + dtype=torch.bool, + device="cuda", + ) + sink = torch.linspace(-0.25, 0.25, 7, device="cuda") + scale = 0.0625 + + expected = torch.empty(3, 7, 512, device="cuda", dtype=torch.bfloat16) + sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, expected) + + scores = torch.bmm(q[:, 0].float(), kv.float().transpose(1, 2)) + scores.mul_(scale) + actual = torch.empty_like(expected) + finish_materialized_sparse_mla_scores_with_sink( + scores, + kv, + valid_tokens, + sink, + actual, + num_heads=7, + value_block_size=value_block_size, + candidate_block_size=candidate_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("value_block_size", [128, 256]) +def test_finish_materialized_scores_value_block_matches_reference( + value_block_size: int, +) -> None: + torch.manual_seed(53) + q = torch.randn(3, 1, 7, 512, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(3, 11, 512, device="cuda", dtype=torch.bfloat16) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, True, True, False, True, True, False], + [False, True, True, False, True, False, False, True, False, True, True], + [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + ], + dtype=torch.bool, + device="cuda", + ) + sink = torch.linspace(-0.25, 0.25, 7, device="cuda") + scale = 0.0625 + + expected = torch.empty(3, 7, 512, device="cuda", dtype=torch.bfloat16) + sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, expected) + + scores = torch.bmm( + q[:, 0].float(), + kv.float().transpose(1, 2), + ) + scores.mul_(scale) + actual = torch.empty_like(expected) + finish_materialized_sparse_mla_scores_with_sink( + scores, + kv, + valid_tokens, + sink, + actual, + num_heads=7, + value_block_size=value_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_build_combined_sparse_mla_decode_valid_mask_matches_torch() -> None: + compressed_slot_ids = torch.tensor( + [ + [7, 4, -1, 9, 11], + [2, -1, 3, 8, 10], + [-1, -1, -1, -1, -1], + ], + device="cuda", + dtype=torch.int32, + ) + topk_lens = torch.tensor([4, 3, 0], device="cuda", dtype=torch.int32) + swa_lens = torch.tensor([3, 1, 0], device="cuda", dtype=torch.int32) + valid_tokens = torch.empty(3, 9, device="cuda", dtype=torch.bool) + + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + + comp_offsets = torch.arange(5, device="cuda", dtype=torch.int32) + swa_offsets = torch.arange(4, device="cuda", dtype=torch.int32) + expected = torch.empty_like(valid_tokens) + expected[:, :5] = (comp_offsets[None, :] < topk_lens[:, None]) & ( + compressed_slot_ids >= 0 + ) + expected[:, 5:] = swa_offsets[None, :] < swa_lens[:, None] + + torch.testing.assert_close(valid_tokens, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("num_heads", [8, 16, 32, 64]) +def test_triton_fp8ds_paged_attention_with_sink_supports_tp_local_heads( + num_heads: int, +) -> None: + torch.manual_seed(37 + num_heads) + block_size = 4 + k_cache = torch.zeros( + 4, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, num_heads, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-0.5, 0.5, num_heads, device="cuda") + scale = 0.0625 + + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // block_size].item()) + slot = physical_block * block_size + pos % block_size + _write_fp8_ds_mla_token(k_cache, slot, block_size) + + max_score = torch.full((2, num_heads), float("-inf"), device="cuda") + denom = torch.zeros((2, num_heads), device="cuda") + acc = torch.zeros((2, num_heads, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=1, + ) + expected = torch.empty(2, num_heads, 512, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, expected) + + actual = torch.empty_like(expected) + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=4, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("num_heads", [8, 16, 32, 64]) +def test_triton_fp8ds_global_paged_attention_with_sink_supports_tp_local_heads( + num_heads: int, +) -> None: + torch.manual_seed(41 + num_heads) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 4, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 4, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + slot_ids = torch.tensor( + [[0, 3, -1, 8, 1], [7, -1, 4, 0, 8]], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, num_heads, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-1.0, 1.0, num_heads, device="cuda") + scale = 0.0625 + + for slot in (0, 1, 3, 4, 7, 8): + _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size) + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // swa_block_size].item()) + slot = physical_block * swa_block_size + pos % swa_block_size + _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size) + + comp_max = torch.full((2, num_heads), float("-inf"), device="cuda") + comp_denom = torch.zeros((2, num_heads), device="cuda") + comp_acc = torch.zeros((2, num_heads, 512), device="cuda") + swa_max = torch.full((2, num_heads), float("-inf"), device="cuda") + swa_denom = torch.zeros((2, num_heads), device="cuda") + swa_acc = torch.zeros((2, num_heads, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_cache, + slot_ids=slot_ids, + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=0, + scale=scale, + max_score=comp_max, + denom=comp_denom, + acc=comp_acc, + head_block_size=1, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=swa_block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=swa_max, + denom=swa_denom, + acc=swa_acc, + head_block_size=1, + ) + expected = torch.empty(2, num_heads, 512, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + comp_max, + comp_denom, + comp_acc, + swa_max, + swa_denom, + swa_acc, + sink, + expected, + ) + + actual = torch.empty_like(expected) + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_cache, + slot_ids=slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + swa_block_size=swa_block_size, + num_compressed_candidates=5, + num_swa_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=4, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_indexed_bf16_prefill_chunks_match_reference() -> None: + torch.manual_seed(17) + q = torch.randn(5, 5, 16, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :3] + kv = torch.randn(2, 7, 16, device="cuda", dtype=torch.bfloat16) + kv_flat = kv.reshape(-1, q.shape[-1]) + combined_indices = torch.tensor( + [ + [0, 3, -1, 5, 3, 1], + [4, -1, 2, 2, 1, 8], + [-1, -1, -1, -1, -1, -1], + [8, 0, 9, -1, 7, 4], + [13, 12, 0, 12, -1, 3], + ], + dtype=torch.int64, + device="cuda", + ) + combined_lens = torch.tensor([5, 4, 0, 6, 5], dtype=torch.int32, device="cuda") + sink = torch.tensor([-0.5, 1.0, 0.25], dtype=torch.float32, device="cuda") + scale = 0.375 + output = torch.empty_like(q_active) + + for token_start in (0, 2, 4): + token_end = min(token_start + 2, q.shape[0]) + q_chunk = q[token_start:token_end] + indices_chunk = combined_indices[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + max_score = torch.full( + (q_chunk.shape[0], q_active.shape[1]), + float("-inf"), + device="cuda", + ) + denom = torch.zeros_like(max_score) + acc = torch.zeros( + q_chunk.shape[0], + q_active.shape[1], + q_chunk.shape[-1], + device="cuda", + dtype=torch.float32, + ) + for index_start in (0, 3): + index_end = min(index_start + 3, combined_indices.shape[-1]) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + subset_output = torch.empty_like(acc) + subset_lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=subset_output, + lse=subset_lse, + ) + merge_sparse_mla_subset_with_sink( + subset_output=subset_output, + subset_lse=subset_lse, + attn_sink=sink, + output=output[token_start:token_end], + ) + + expected = torch.empty_like(q_active) + reference_sparse_mla_prefill( + q=q_active, + kv=kv, + combined_indices=combined_indices, + combined_lens=combined_lens, + scale=scale, + attn_sink=sink, + output=expected, + topk_chunk_size=3, + query_chunk_size=2, + ) + + torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.parametrize( + ("topk_chunk_size", "query_chunk_size"), + [(1, 1), (2, 3), (5, 2)], +) +def test_reference_sparse_mla_prefill_matches_dense_golden( + topk_chunk_size: int, + query_chunk_size: int, +) -> None: + torch.manual_seed(4) + scale = 0.375 + q = torch.randn(4, 2, 3) + kv = torch.randn(2, 5, 3) + combined_indices = torch.tensor( + [ + [0, 3, -1, 5, 3], + [4, -1, 2, 2, 1], + [-1, -1, -1, -1, -1], + [8, 0, 9, -1, 7], + ], + dtype=torch.int64, + ) + combined_lens = torch.tensor([4, 3, 0, 5], dtype=torch.int32) + sink = torch.tensor([-0.5, 1.0]) + output = torch.empty_like(q) + + reference_sparse_mla_prefill( + q=q, + kv=kv, + combined_indices=combined_indices, + combined_lens=combined_lens, + scale=scale, + attn_sink=sink, + output=output, + topk_chunk_size=topk_chunk_size, + query_chunk_size=query_chunk_size, + ) + + kv_flat = kv.reshape(-1, q.shape[-1]) + offsets = torch.arange(combined_indices.shape[-1]) + valid_tokens = (offsets[None, :] < combined_lens[:, None]) & (combined_indices >= 0) + safe_indices = torch.where( + valid_tokens, + combined_indices, + torch.zeros((), dtype=combined_indices.dtype), + ).long() + gathered_kv = kv_flat[safe_indices] + expected = _golden_sink_attention(q, gathered_kv, valid_tokens, scale, sink) + + torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("chunk_size", [1, 2, 5]) +def test_chunked_reference_accumulation_matches_one_shot(chunk_size: int) -> None: + torch.manual_seed(3) + scale = 0.3 + q = torch.randn(3, 2, 4) + kv = torch.randn(3, 9, 4) + valid_tokens = torch.tensor( + [ + [True, False, True, True, False, False, True, False, True], + [False, False, False, False, False, False, False, False, False], + [True, True, True, False, True, False, True, True, False], + ], + dtype=torch.bool, + ) + output, lse = _chunked_no_sink_attention( + q, + kv, + valid_tokens, + scale, + chunk_size, + ) + expected_output, expected_lse = _golden_no_sink_attention( + q, + kv, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=1e-6, atol=1e-6) + torch.testing.assert_close(lse, expected_lse, rtol=1e-6, atol=1e-6) + + +def test_triton_sparse_mla_path_allows_cudagraph_support_by_default( + monkeypatch, +) -> None: + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1") + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", raising=False) + + mla_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + alignment=576, + compress_ratio=4, + model_version="deepseek_v4", + ) + swa_spec = SlidingWindowMLASpec( + block_size=64, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + sliding_window=128, + cache_dtype_str="fp8_ds_mla", + alignment=576, + model_version="deepseek_v4", + ) + + assert FlashMLASparseMetadataBuilder.get_cudagraph_support(None, mla_spec) is ( + AttentionCGSupport.UNIFORM_BATCH + ) + assert DeepseekSparseSWAMetadataBuilder.get_cudagraph_support(None, swa_spec) is ( + AttentionCGSupport.UNIFORM_BATCH + ) + + vllm_config = SimpleNamespace( + compilation_config=SimpleNamespace( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], + compile_ranges_endpoints=[8192], + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4], + max_cudagraph_capture_size=4, + ) + ) + disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) + + assert vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE + assert vllm_config.compilation_config.compile_sizes == [1, 2] + assert vllm_config.compilation_config.compile_ranges_endpoints == [8192] + assert ( + vllm_config.compilation_config.cudagraph_mode + == CUDAGraphMode.FULL_AND_PIECEWISE + ) + assert vllm_config.compilation_config.cudagraph_capture_sizes == [1, 2, 4] + assert vllm_config.compilation_config.max_cudagraph_capture_size == 4 + + +def test_triton_sparse_mla_path_can_disable_cudagraphs(monkeypatch) -> None: + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1") + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", "0") + + mla_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + alignment=576, + compress_ratio=4, + model_version="deepseek_v4", + ) + swa_spec = SlidingWindowMLASpec( + block_size=64, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + sliding_window=128, + cache_dtype_str="fp8_ds_mla", + alignment=576, + model_version="deepseek_v4", + ) + + assert FlashMLASparseMetadataBuilder.get_cudagraph_support(None, mla_spec) is ( + AttentionCGSupport.NEVER + ) + assert DeepseekSparseSWAMetadataBuilder.get_cudagraph_support(None, swa_spec) is ( + AttentionCGSupport.NEVER + ) + + vllm_config = SimpleNamespace( + compilation_config=SimpleNamespace( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], + compile_ranges_endpoints=[8192], + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4], + max_cudagraph_capture_size=4, + ) + ) + disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) + + assert vllm_config.compilation_config.mode == CompilationMode.NONE + assert vllm_config.compilation_config.compile_sizes == [] + assert vllm_config.compilation_config.compile_ranges_endpoints == [] + assert vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + assert vllm_config.compilation_config.cudagraph_capture_sizes == [] + assert vllm_config.compilation_config.max_cudagraph_capture_size == 0 + + +def test_triton_sparse_mla_path_disables_cudagraphs_for_mtp( + monkeypatch, +) -> None: + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1") + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", raising=False) + + mla_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + alignment=576, + compress_ratio=4, + model_version="deepseek_v4", + ) + swa_spec = SlidingWindowMLASpec( + block_size=64, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + sliding_window=128, + cache_dtype_str="fp8_ds_mla", + alignment=576, + model_version="deepseek_v4", + ) + vllm_config = SimpleNamespace( + speculative_config=SimpleNamespace( + method="mtp", + num_speculative_tokens=2, + ), + compilation_config=SimpleNamespace( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], + compile_ranges_endpoints=[8192], + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4], + max_cudagraph_capture_size=4, + ), + ) + + assert ( + FlashMLASparseMetadataBuilder.get_cudagraph_support( + vllm_config, + mla_spec, + ) + is AttentionCGSupport.NEVER + ) + assert ( + DeepseekSparseSWAMetadataBuilder.get_cudagraph_support( + vllm_config, + swa_spec, + ) + is AttentionCGSupport.NEVER + ) + + disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) + + assert vllm_config.compilation_config.mode == CompilationMode.NONE + assert vllm_config.compilation_config.compile_sizes == [] + assert vllm_config.compilation_config.compile_ranges_endpoints == [] + assert vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + assert vllm_config.compilation_config.cudagraph_capture_sizes == [] + assert vllm_config.compilation_config.max_cudagraph_capture_size == 0 diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py new file mode 100644 index 000000000000..88d337ea5998 --- /dev/null +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm.utils.deep_gemm as deep_gemm_utils +from vllm.model_executor.layers.sparse_attn_indexer import ( + _decode_logits_width, + _decode_topk_logits_width, + _sparse_indexer_requires_deep_gemm, +) +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv + + +def test_decode_logits_width_uses_active_context_bound(): + assert _decode_logits_width(262144, 1024) == 1024 + assert _decode_logits_width(4096, 8192) == 4096 + assert _decode_logits_width(4096, 0) == 4096 + assert _decode_logits_width(0, 1024) == 0 + + +def test_decode_topk_logits_width_keeps_topk_kernel_width(): + assert _decode_topk_logits_width(262144, 1024, 512) == 1024 + assert _decode_topk_logits_width(262144, 128, 512) == 512 + assert _decode_topk_logits_width(300, 128, 512) == 300 + assert _decode_topk_logits_width(0, 128, 512) == 0 + + +def test_sm120_sparse_indexer_does_not_require_deep_gemm(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + + assert _sparse_indexer_requires_deep_gemm() is False + + +def test_non_sm120_cuda_sparse_indexer_still_requires_deep_gemm(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: False, + ) + + assert _sparse_indexer_requires_deep_gemm() is True + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_paged_mqa_direct_topk_matches_truncated_decode_width( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(7) + batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32 + block_size, max_model_len, num_blocks = 4, 64, 16 + active_max_len = 13 + topk_tokens = 6 + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr(deep_gemm_utils, "_SM120_PAGED_MQA_TOPK_CHUNK_SIZE", 7) + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = torch.empty( + num_blocks, + block_size, + 1, + head_dim + 4, + device="cuda", + dtype=torch.uint8, + ) + fused_kv[..., :head_dim] = kv_fp8.view(torch.uint8) + fused_kv[..., head_dim:] = kv_scale.contiguous().view(torch.uint8) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor( + [[5, active_max_len], [9, 12]], device="cuda", dtype=torch.int32 + ) + block_tables = ( + torch.arange( + batch_size * cdiv(max_model_len, block_size), + device="cuda", + dtype=torch.int32, + ).reshape(batch_size, -1) + % num_blocks + ) + + full_width_topk = torch.empty( + batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32 + ) + truncated_width_topk = torch.empty_like(full_width_topk) + + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + max_model_len, + full_width_topk, + ) + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + active_max_len, + truncated_width_topk, + ) + + torch.testing.assert_close(truncated_width_topk, full_width_topk, rtol=0, atol=0) diff --git a/tests/v1/attention/test_sparse_attn_indexer.py b/tests/v1/attention/test_sparse_attn_indexer.py new file mode 100644 index 000000000000..eb09cf058d8f --- /dev/null +++ b/tests/v1/attention/test_sparse_attn_indexer.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.model_executor.layers.sparse_attn_indexer import ( + SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH, + SM120_SHORT_ROW_TOPK_MAX_WIDTH, + _should_use_sm120_short_row_topk_decode, +) + + +@pytest.mark.parametrize( + ("topk_tokens", "logits_width", "num_rows", "is_cuda_sm120", "expected"), + [ + (512, SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH, 32, True, True), + (512, 8192, 16, True, True), + (512, 8192, 32, True, True), + (512, 12288, 32, True, False), + (512, SM120_SHORT_ROW_TOPK_MAX_WIDTH, 1, True, False), + (512, 4096, 1, False, False), + (2048, 4096, 1, True, False), + ], +) +def test_sm120_short_row_topk_decode_selector( + topk_tokens: int, + logits_width: int, + num_rows: int, + is_cuda_sm120: bool, + expected: bool, +) -> None: + assert ( + _should_use_sm120_short_row_topk_decode( + topk_tokens, + logits_width, + num_rows, + is_cuda_sm120, + ) + is expected + ) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 22acc748d24b..8becd568b680 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -8,6 +8,7 @@ import pytest import torch +import vllm.utils.deep_gemm as deep_gemm_utils from tests.v1.attention.test_mla_backends import ( BATCH_SPECS, BatchSpec, @@ -42,9 +43,16 @@ FlashMLASparseBackend, triton_convert_req_index_to_global_index, ) -from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks +from vllm.v1.attention.backends.mla.indexer import ( + sparse_indexer_max_logits_bytes, + split_indexer_prefill_chunks, +) from vllm.v1.attention.backends.utils import split_prefill_chunks from vllm.v1.attention.ops import flashmla +from vllm.v1.attention.ops.deepseek_v4_ops import ( + combine_topk_swa_indices, + compute_global_topk_indices_and_lens, +) SPARSE_BACKEND_BATCH_SPECS = { name: BATCH_SPECS[name] @@ -67,6 +75,487 @@ DEVICE_TYPE = current_platform.device_type +def _make_packed_fp8_indexer_cache( + kv_fp8: torch.Tensor, + kv_scale: torch.Tensor, +) -> torch.Tensor: + num_blocks, block_size, num_kv_heads, head_dim = kv_fp8.shape + assert num_kv_heads == 1 + kv_scale_bytes = kv_scale.contiguous().view(torch.uint8).reshape( + num_blocks, block_size, num_kv_heads, -1 + ) + scale_bytes = kv_scale_bytes.shape[-1] + fused_kv = torch.empty( + num_blocks, + block_size, + head_dim + scale_bytes, + device=kv_fp8.device, + dtype=torch.uint8, + ) + fused_kv_blocks = fused_kv.view(num_blocks, -1) + value_end = block_size * head_dim + scale_end = value_end + block_size * scale_bytes + fused_kv_blocks[:, :value_end] = kv_fp8.view(torch.uint8).reshape( + num_blocks, -1 + ) + fused_kv_blocks[:, value_end:scale_end] = kv_scale_bytes.reshape( + num_blocks, -1 + ) + return fused_kv + + +def test_sm120_fp8_mqa_logits_chunk_sizes_cap_large_scores(): + assert deep_gemm_utils._fp8_mqa_logits_head_chunk_size(128, 128, 32) == 8 + assert deep_gemm_utils._fp8_mqa_logits_head_chunk_size(8192, 8192, 32) == 1 + assert deep_gemm_utils._fp8_mqa_logits_k_chunk_size(128, 128, 8) == 128 + assert deep_gemm_utils._fp8_mqa_logits_k_chunk_size(8192, 8192, 1) == 2048 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_tf32_hc_prenorm_gemm_fallback_matches_split_abi( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(0) + num_tokens, out_features, hidden_size = 7, 12, 64 + x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + fn = torch.randn(out_features, hidden_size, device="cuda", dtype=torch.float32) + + out = torch.empty(num_tokens, out_features, device="cuda", dtype=torch.float32) + sqrsum = torch.empty(num_tokens, device="cuda", dtype=torch.float32) + deep_gemm_utils._tf32_hc_prenorm_gemm_torch(x, fn, out, sqrsum, num_split=1) + + expected_out = x.float() @ fn.T + expected_sqrsum = x.float().square().sum(dim=-1) + torch.testing.assert_close(out, expected_out, rtol=0, atol=0) + torch.testing.assert_close(sqrsum, expected_sqrsum, rtol=0, atol=0) + + split_out = torch.empty(3, num_tokens, out_features, device="cuda") + split_sqrsum = torch.empty(3, num_tokens, device="cuda") + deep_gemm_utils._tf32_hc_prenorm_gemm_torch( + x, fn, split_out, split_sqrsum, num_split=3 + ) + torch.testing.assert_close(split_out.sum(dim=0), expected_out, rtol=0, atol=0) + torch.testing.assert_close(split_sqrsum.sum(dim=0), expected_sqrsum, rtol=0, atol=0) + + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr(deep_gemm_utils, "_tf32_hc_prenorm_gemm_impl", None) + wrapper_out = torch.empty_like(split_out) + wrapper_sqrsum = torch.empty_like(split_sqrsum) + deep_gemm_utils.tf32_hc_prenorm_gemm( + x, fn, wrapper_out, wrapper_sqrsum, num_split=3 + ) + torch.testing.assert_close( + wrapper_out.sum(dim=0), expected_out, rtol=2e-2, atol=2e-2 + ) + torch.testing.assert_close( + wrapper_sqrsum.sum(dim=0), expected_sqrsum, rtol=1e-4, atol=1e-4 + ) + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_paged_mqa_logits_fallback_matches_reference( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(1) + batch_size, next_n, num_heads, head_dim = 2, 2, 4, 32 + block_size, max_model_len, num_blocks = 4, 12, 4 + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn) + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = _make_packed_fp8_indexer_cache(kv_fp8, kv_scale) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor([[3, 6], [7, 11]], device="cuda", dtype=torch.int32) + block_tables = torch.tensor( + [[0, 1, 2], [1, 2, 3]], device="cuda", dtype=torch.int32 + ) + expected = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device="cuda", + dtype=torch.float32, + ) + kv_dequant = kv_fp8.float() * kv_scale + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + for token_idx in range(int(context_lens[batch_idx, next_idx].item())): + block = int(block_tables[batch_idx, token_idx // block_size].item()) + offset = token_idx % block_size + score = ( + q_fp8[batch_idx, next_idx].float() * kv_dequant[block, offset, 0] + ).sum(dim=1) + expected[row, token_idx] = (score.relu() * weights[row]).sum() + + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr(deep_gemm_utils, "_fp8_fp4_paged_mqa_logits_impl", None) + + def fail_torch_path(*args, **kwargs): + raise AssertionError("torch paged fallback should not be used") + + monkeypatch.setattr(deep_gemm_utils, "_fp8_paged_mqa_logits_torch", fail_torch_path) + actual = deep_gemm_utils.fp8_fp4_paged_mqa_logits( + (q_fp8.contiguous(), None), + fused_kv, + weights, + context_lens, + block_tables, + schedule_metadata=torch.empty(0, device="cuda", dtype=torch.int32), + max_model_len=max_model_len, + clean_logits=False, + ) + torch.testing.assert_close(actual, expected, rtol=0, atol=1e-5) + + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_paged_mqa_logits_triton, + ) + + triton_actual = fp8_paged_mqa_logits_triton( + q_fp8.contiguous(), fused_kv, weights, context_lens, block_tables, max_model_len + ) + assert torch.equal(torch.isneginf(triton_actual), torch.isneginf(expected)) + finite = torch.isfinite(expected) + assert (triton_actual[finite] - expected[finite]).abs().max() < 2e-2 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_paged_mqa_rowwise_logits_matches_reference(): + torch.manual_seed(11) + batch_size, next_n, num_heads, head_dim = 2, 1, 8, 64 + block_size, max_model_len, num_blocks = 4, 18, 8 + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = _make_packed_fp8_indexer_cache(kv_fp8, kv_scale) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor([[7], [17]], device="cuda", dtype=torch.int32) + block_tables = ( + torch.arange( + batch_size * cdiv(max_model_len, block_size), + device="cuda", + dtype=torch.int32, + ).reshape(batch_size, -1) + % num_blocks + ) + + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_paged_mqa_logits_rowwise_triton, + ) + + actual = fp8_paged_mqa_logits_rowwise_triton( + q_fp8, fused_kv, weights, context_lens, block_tables, max_model_len + ) + expected = deep_gemm_utils._fp8_paged_mqa_logits_torch( + (q_fp8, None), fused_kv, weights, context_lens, block_tables, max_model_len + ) + + assert torch.equal(torch.isneginf(actual), torch.isneginf(expected)) + finite = torch.isfinite(expected) + assert (actual[finite] - expected[finite]).abs().max() < 2e-2 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_paged_mqa_topk_indices_streams_chunks( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(3) + batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32 + block_size, max_model_len, num_blocks = 4, 20, 8 + topk_tokens = 5 + monkeypatch.setattr( + deep_gemm_utils, + "_SM120_PAGED_MQA_TOPK_CHUNK_SIZE", + 7, + ) + monkeypatch.setattr( + torch, + "cat", + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("paged MQA top-k should reuse candidate buffers") + ), + ) + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn) + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = _make_packed_fp8_indexer_cache(kv_fp8, kv_scale) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor([[3, 11], [17, 20]], device="cuda", dtype=torch.int32) + block_tables = ( + torch.arange( + batch_size * cdiv(max_model_len, block_size), + device="cuda", + dtype=torch.int32, + ).reshape(batch_size, -1) + % num_blocks + ) + topk_indices = torch.empty( + batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32 + ) + + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8.contiguous(), None), + fused_kv, + weights, + context_lens, + block_tables, + max_model_len, + topk_indices, + ) + + logits = deep_gemm_utils._fp8_paged_mqa_logits_torch( + (q_fp8.contiguous(), None), + fused_kv, + weights, + context_lens, + block_tables, + max_model_len, + ) + expected = torch.full_like(topk_indices, -1) + flat_context_lens = context_lens.reshape(-1) + for row in range(batch_size * next_n): + valid_count = int(flat_context_lens[row].item()) + row_topk = min(topk_tokens, valid_count) + if row_topk > 0: + expected[row, :row_topk] = ( + logits[row].topk(row_topk).indices.to(torch.int32) + ) + + for row in range(batch_size * next_n): + row_topk = min(topk_tokens, int(flat_context_lens[row].item())) + assert set(topk_indices[row, :row_topk].tolist()) == set( + expected[row, :row_topk].tolist() + ) + assert torch.all(topk_indices[row, row_topk:] == -1) + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_mqa_logits_torch_path_streams_head_chunks( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(0) + seq_len, seq_len_kv, num_heads, head_dim = 9, 17, 32, 32 + monkeypatch.setattr( + deep_gemm_utils, + "_SM120_MQA_LOGITS_MAX_SCORE_BYTES", + seq_len * 5 * 4, + ) + + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3 + cu_seqlen_ke = torch.minimum( + torch.arange(seq_len, device="cuda", dtype=torch.int32) + 4, + torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32), + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4) + kv_scale = (kv_amax / 448.0).squeeze(1).contiguous() + kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn) + + logits = deep_gemm_utils._fp8_mqa_logits_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + + kv_dequant = kv_fp8.float() * kv_scale[:, None] + score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant) + ref_logits = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0) + offsets = torch.arange(seq_len_kv, device="cuda") + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + ref_logits = ref_logits.masked_fill(~valid, float("-inf")) + + assert torch.equal(torch.isneginf(logits), torch.isneginf(ref_logits)) + finite = torch.isfinite(ref_logits) + assert (logits[finite] - ref_logits[finite]).abs().max() < 1e-4 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_mqa_logits_wrapper_uses_triton_when_deepgemm_missing( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(2) + seq_len, seq_len_kv, num_heads, head_dim = 5, 13, 8, 32 + + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3 + cu_seqlen_ke = torch.minimum( + cu_seqlen_ks + 6, + torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32), + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4) + kv_scale = (kv_amax / 448.0).squeeze(1).contiguous() + kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn) + + kv_dequant = kv_fp8.float() * kv_scale[:, None] + score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant) + expected = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0) + offsets = torch.arange(seq_len_kv, device="cuda") + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + expected = expected.masked_fill(~valid, float("-inf")) + + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr(deep_gemm_utils, "_fp8_fp4_mqa_logits_impl", None) + + def fail_torch_path(*args, **kwargs): + raise AssertionError("torch fallback should not be used") + + monkeypatch.setattr(deep_gemm_utils, "_fp8_mqa_logits_torch", fail_torch_path) + actual = deep_gemm_utils.fp8_fp4_mqa_logits( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + + assert torch.equal(torch.isneginf(actual), torch.isneginf(expected)) + finite = torch.isfinite(expected) + assert (actual[finite] - expected[finite]).abs().max() < 2e-2 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_mqa_logits_topk_streams_k_chunks( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(1) + seq_len, seq_len_kv, num_heads, head_dim = 11, 23, 16, 32 + topk_tokens = 5 + monkeypatch.setattr( + deep_gemm_utils, + "_SM120_MQA_LOGITS_MAX_SCORE_BYTES", + seq_len * 5 * 4, + ) + monkeypatch.setattr( + torch, + "cat", + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("MQA top-k should reuse candidate buffers") + ), + ) + + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 4 + valid_lens = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 7 + cu_seqlen_ke = torch.minimum( + cu_seqlen_ks + valid_lens, + torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32), + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4) + kv_scale = (kv_amax / 448.0).squeeze(1).contiguous() + kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn) + + topk_indices = deep_gemm_utils._fp8_mqa_logits_topk_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_tokens, + ) + + logits = deep_gemm_utils._fp8_mqa_logits_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + expected = torch.full_like(topk_indices, -1) + for row in range(seq_len): + valid_count = int((cu_seqlen_ke[row] - cu_seqlen_ks[row]).item()) + row_topk = min(topk_tokens, valid_count) + if row_topk > 0: + expected[row, :row_topk] = ( + logits[row].topk(row_topk).indices.to(torch.int32) + ) + + for row in range(seq_len): + valid_count = int((cu_seqlen_ke[row] - cu_seqlen_ks[row]).item()) + row_topk = min(topk_tokens, valid_count) + assert set(topk_indices[row, :row_topk].tolist()) == set( + expected[row, :row_topk].tolist() + ) + assert torch.all(topk_indices[row, row_topk:] == -1) + + def _float_to_e8m0_truncate(f: float) -> float: """Simulate SM100's float -> e8m0 -> bf16 scale conversion. e8m0 format only stores the exponent (power of 2). @@ -218,8 +707,14 @@ def test_sparse_backend_decode_correctness( if not ok: pytest.skip(reason) elif backend_cls == FlashInferMLASparseBackend: - if not current_platform.has_device_capability(100): - pytest.skip("FlashInferMLASparseBackend requires SM 10.0 or higher") + capability = current_platform.get_device_capability() + if capability is None or not backend_cls.supports_compute_capability( + capability + ): + pytest.skip( + "FlashInferMLASparseBackend does not support " + f"{capability} on this platform" + ) batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name] use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla" @@ -781,6 +1276,119 @@ def test_split_indexer_prefill_chunks( assert out == expected +def test_sparse_indexer_max_logits_bytes_uses_sm12x_safe_default(monkeypatch): + monkeypatch.delenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", raising=False) + + assert sparse_indexer_max_logits_bytes(is_sm12x=True) == 256 * 1024 * 1024 + assert sparse_indexer_max_logits_bytes(is_sm12x=False) == 512 * 1024 * 1024 + + +def test_sparse_indexer_max_logits_bytes_honors_env_override(monkeypatch): + monkeypatch.setenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "384") + + assert sparse_indexer_max_logits_bytes(is_sm12x=True) == 384 * 1024 * 1024 + assert sparse_indexer_max_logits_bytes(is_sm12x=False) == 384 * 1024 * 1024 + + +def test_compute_global_topk_indices_supports_in_place_output(): + device = torch.device(DEVICE_TYPE) + block_size = 4 + topk_indices = torch.tensor( + [[0, 3, 4, -1], [2, 5, -1, -1], [1, 7, -1, -1]], + dtype=torch.int32, + device=device, + ) + token_to_req = torch.tensor([0, 1, 1], dtype=torch.int32, device=device) + block_table = torch.tensor( + [[10, 11, 12], [20, 21, 22]], dtype=torch.int32, device=device + ) + is_valid = torch.tensor([True, True, False], device=device) + + expected_indices = torch.tensor( + [ + [40, 43, 44, -1], + [82, 85, -1, -1], + [-1, -1, -1, -1], + ], + dtype=torch.int32, + device=device, + ) + expected_lens = torch.tensor([3, 2, 0], dtype=torch.int32, device=device) + + out, lens = compute_global_topk_indices_and_lens( + topk_indices, + token_to_req, + block_table, + block_size, + is_valid, + ) + torch.testing.assert_close(out, expected_indices, rtol=0, atol=0) + torch.testing.assert_close(lens, expected_lens, rtol=0, atol=0) + + in_place = topk_indices.clone() + provided_lens = torch.empty(3, dtype=torch.int32, device=device) + out, lens = compute_global_topk_indices_and_lens( + in_place, + token_to_req, + block_table, + block_size, + is_valid, + global_topk_indices=in_place, + topk_lens=provided_lens, + ) + assert out is in_place + assert lens is provided_lens + torch.testing.assert_close(in_place, expected_indices, rtol=0, atol=0) + torch.testing.assert_close(provided_lens, expected_lens, rtol=0, atol=0) + + +def test_combine_topk_swa_indices_supports_workspace_outputs(): + device = torch.device(DEVICE_TYPE) + num_tokens = 6 + topk = 4 + window_size = 8 + topk_indices = ( + torch.arange(num_tokens * topk, dtype=torch.int32, device=device) + .reshape(num_tokens, topk) + .remainder(5) + ) + query_start_loc = torch.tensor([0, num_tokens], dtype=torch.int32, device=device) + seq_lens = torch.tensor([20], dtype=torch.int32, device=device) + gather_lens = torch.tensor([8], dtype=torch.int32, device=device) + + expected_indices, expected_lens = combine_topk_swa_indices( + topk_indices, + query_start_loc, + seq_lens, + gather_lens, + window_size, + 4, + topk, + 16, + 12, + ) + workspace_indices = torch.empty_like(expected_indices) + workspace_lens = torch.empty_like(expected_lens) + actual_indices, actual_lens = combine_topk_swa_indices( + topk_indices, + query_start_loc, + seq_lens, + gather_lens, + window_size, + 4, + topk, + 16, + 12, + combined_indices=workspace_indices, + combined_lens=workspace_lens, + ) + + assert actual_indices.data_ptr() == workspace_indices.data_ptr() + assert actual_lens.data_ptr() == workspace_lens.data_ptr() + torch.testing.assert_close(actual_indices, expected_indices, rtol=0, atol=0) + torch.testing.assert_close(actual_lens, expected_lens, rtol=0, atol=0) + + def test_split_indexer_prefill_chunks_single_request_overflow(): """Test that single request exceeding budget is sub-chunked on query dim.""" seq_lens = torch.tensor([1000, 50]) diff --git a/tests/v1/attention/test_sparse_mla_env.py b/tests/v1/attention/test_sparse_mla_env.py new file mode 100644 index 000000000000..9745ea169cc2 --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_env.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from collections.abc import Iterator +from contextlib import contextmanager + +import torch + +from vllm.envs import environment_variables +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled, + triton_sparse_mla_cudagraphs_allowed, + triton_sparse_mla_head_block_size, + triton_sparse_mla_query_chunk_size, + triton_sparse_mla_topk_chunk_size, +) + +_SPARSE_MLA_ENV_NAMES = ( + "VLLM_TRITON_MLA_SPARSE", + "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", + "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", + "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", +) + + +@contextmanager +def _patched_sparse_mla_env(**updates: str) -> Iterator[None]: + previous = {name: os.environ.get(name) for name in _SPARSE_MLA_ENV_NAMES} + try: + for name in _SPARSE_MLA_ENV_NAMES: + os.environ.pop(name, None) + os.environ.update(updates) + yield + finally: + for name, value in previous.items(): + if value is None: + os.environ.pop(name, None) + else: + os.environ[name] = value + + +def test_triton_sparse_mla_env_uses_new_name() -> None: + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE="0"): + assert not is_triton_sparse_mla_enabled(torch.device("cpu")) + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE="1"): + assert is_triton_sparse_mla_enabled(torch.device("cpu")) + + +def test_sparse_mla_cudagraph_env_defaults_to_allowed() -> None: + with _patched_sparse_mla_env(): + assert triton_sparse_mla_cudagraphs_allowed() + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH="0"): + assert not triton_sparse_mla_cudagraphs_allowed() + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH="1"): + assert triton_sparse_mla_cudagraphs_allowed() + + +def test_sparse_mla_head_block_env_accepts_supported_values() -> None: + with _patched_sparse_mla_env(): + assert triton_sparse_mla_head_block_size() is None + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="1"): + assert triton_sparse_mla_head_block_size() == 1 + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="2"): + assert triton_sparse_mla_head_block_size() == 2 + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="4"): + assert triton_sparse_mla_head_block_size() == 4 + + +def test_sparse_mla_head_block_env_ignores_invalid_values() -> None: + for value in ("0", "3", "invalid"): + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE=value): + assert triton_sparse_mla_head_block_size() is None + + +def test_sparse_mla_head_block_env_is_registered_with_vllm_envs() -> None: + assert "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE" in environment_variables + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="4"): + assert environment_variables["VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE"]() == 4 + + +def test_sparse_mla_chunk_env_defaults_invalid_values() -> None: + with _patched_sparse_mla_env( + VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE="invalid", + VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE="-7", + ): + assert triton_sparse_mla_topk_chunk_size() == 512 + assert triton_sparse_mla_query_chunk_size() == 1 diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 0f02a92681c1..91e5a8913e88 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -750,6 +750,7 @@ class CompilationConfig: "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", "vllm::deepseek_v4_attention", + "vllm::deepseek_v4_fp8_einsum", ] def compute_hash(self) -> str: diff --git a/vllm/envs.py b/vllm/envs.py index ded474dc085a..07a3e55025de 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -166,6 +166,12 @@ VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True + VLLM_TRITON_MLA_SPARSE: bool | None = None + VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 + VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 + VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH: bool = True + VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE: int | None = None + VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE: bool | None = None VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", "full", @@ -249,6 +255,7 @@ VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False + VLLM_DEEPSEEK_V4_USE_MEGA_MOE: bool = False VLLM_LOG_MODEL_INSPECTION: bool = False VLLM_DEBUG_MFU_METRICS: bool = False VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY: bool = False @@ -1275,6 +1282,34 @@ def _get_or_set_default() -> str: "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) ), + # Experimental sparse MLA fallback controls. + # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse + # is unavailable; set 0/1 to force-disable/force-enable the fallback. + "VLLM_TRITON_MLA_SPARSE": lambda: ( + None + if os.getenv("VLLM_TRITON_MLA_SPARSE") is None + else os.getenv("VLLM_TRITON_MLA_SPARSE", "").lower() + in ("1", "true", "yes", "on") + ), + "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE": lambda: maybe_convert_int( + os.getenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", "512") + ), + "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE": lambda: maybe_convert_int( + os.getenv("VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", "256") + ), + "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH": lambda: ( + os.getenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", "1").lower() + in ("1", "true", "yes", "on") + ), + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE": lambda: maybe_convert_int( + os.getenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE") + ), + "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE": lambda: ( + None + if os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE") is None + else os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE", "").lower() + in ("1", "true", "yes", "on") + ), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine @@ -1711,6 +1746,12 @@ def _get_or_set_default() -> str: "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), + # Optional override for the DeepGEMM MegaMoE fused expert kernel in + # DeepSeek V4. If unset, kernel_config.moe_backend decides; set to 1/0 to + # force-enable or force-disable this path during bring-up. + "VLLM_DEEPSEEK_V4_USE_MEGA_MOE": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "0")) + ), # Log model inspection after loading. # If enabled, logs a transformers-style hierarchical view of the model # with quantization methods and attention backends. diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 618084029159..6baedd3bbcbc 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -7,6 +7,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) @@ -26,6 +29,20 @@ ) +def _is_sm12x_compute_capability(compute_capability) -> bool: + if compute_capability is None: + return current_platform.is_device_capability_family(120) + + if isinstance(compute_capability, tuple): + return compute_capability[0] == 12 + + to_int = getattr(compute_capability, "to_int", None) + if callable(to_int): + return to_int() // 10 == 12 + + return int(compute_capability) // 10 == 12 + + class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def is_supported( @@ -196,6 +213,9 @@ def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None: @classmethod def is_supported(cls, compute_capability=None): + if _is_sm12x_compute_capability(compute_capability): + return False, "CUTLASS block-scaled FP8 GEMM is not supported on SM12x." + if not CUTLASS_BLOCK_FP8_SUPPORTED: return ( False, @@ -219,6 +239,31 @@ def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): ) return True, None + def process_weights_after_loading(self, layer: torch.nn.Module): + super().process_weights_after_loading(layer) + params = self._get_layer_params(layer) + weight_scale = ( + params.weight_scale + if params.weight_scale_inv is None + else params.weight_scale_inv + ) + scale_attr_name = ( + params.WEIGHT_SCALE + if params.weight_scale_inv is None + else params.WEIGHT_SCALE_INV + ) + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and weight_scale is not None + and weight_scale.dtype == e8m0_dtype + ): + replace_parameter( + layer, + scale_attr_name, + _upcast_e8m0_to_fp32(weight_scale), + ) + def apply_block_scaled_mm( self, A: torch.Tensor, diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 494d61338084..446935c7c739 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -18,15 +18,21 @@ ReplicatedLinear, ) from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer +from vllm.platforms import current_platform from vllm.utils.deep_gemm import fp8_einsum from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.deepseek_v4_ops import ( combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, + sparse_prefill_combined_topk_size, +) +from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import ( + deepseek_v4_sm12_fp8_einsum, ) from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( rocm_forward_decode_fallback, @@ -44,7 +50,10 @@ VllmConfig, get_current_vllm_config, ) -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer @@ -58,7 +67,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) -from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, @@ -73,6 +81,25 @@ DeepseekV4IndexerBackend, get_max_prefill_buffer_size, ) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + disable_triton_sparse_mla_cudagraphs_if_enabled, + is_triton_sparse_mla_enabled, + triton_sparse_mla_matmul_decode_enabled, + triton_sparse_mla_query_chunk_size, + triton_sparse_mla_topk_chunk_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead, + accumulate_indexed_sparse_mla_attention_chunk, + build_combined_sparse_mla_decode_valid_mask, + finish_sparse_mla_attention_with_sink, + finish_two_sparse_mla_attention_states_with_sink, + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead, + fp8ds_paged_sparse_mla_attention_with_sink_multihead, + matmul_sparse_mla_attention_with_sink, + sparse_mla_decode_head_block_size, +) from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache from vllm.v1.attention.ops.flashmla import ( flash_mla_sparse_fwd, @@ -83,6 +110,68 @@ logger = init_logger(__name__) + +def _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu: torch.Tensor, + gather_lens_cpu: torch.Tensor, + compress_ratio: int, + swa_only: bool, +) -> tuple[int, int]: + if seq_lens_cpu.numel() == 0: + return 0, 0 + + max_gather_len = int(gather_lens_cpu.max().item()) + if swa_only: + return 0, max_gather_len + + compressed_region_size = int((seq_lens_cpu // compress_ratio).max().item()) + return compressed_region_size, compressed_region_size + max_gather_len + + +def _deepseek_v4_fp8_einsum_config( + capability_major: int, +) -> tuple[tuple[int, int, int], bool]: + if capability_major == 10: + return (1, 1, 128), True + return (1, 128, 128), False + + +def _use_deepseek_v4_sm12_triton_fp8_einsum( + equation: str, + recipe: list[int], + b_scale: torch.Tensor, +) -> bool: + capability = current_platform.get_device_capability() + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + return ( + capability is not None + and capability.major == 12 + and equation == "bhr,hdr->bhd" + and tuple(recipe) == (1, 128, 128) + and b_scale.dtype in (torch.float32, e8m0_dtype) + ) + + +def _allocate_deepseek_v4_wo_a_output( + num_tokens: int, + num_groups: int, + output_rank: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + shape = (num_tokens, num_groups, output_rank) + if torch.compiler.is_compiling(): + # Workspace growth can call torch.accelerator.empty_cache(), which + # Dynamo intentionally refuses to trace. During compilation this is a + # normal graph allocation, matching the o_padded allocation above. + return torch.empty(shape, dtype=dtype, device=device) + + (output,) = current_workspace_manager().get_simultaneous( + (shape, dtype), + ) + return output + + # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time # reservation in attention_impl's dummy-run branch). @@ -172,6 +261,8 @@ def __init__( self.compress_ratio = compress_ratio if compress_ratio is not None else 1 self.prefix = prefix + disable_triton_sparse_mla_cudagraphs_if_enabled(mla_modules.vllm_config) + # Extract config from vllm_config config = mla_modules.vllm_config.model_config.hf_config tp_size = get_tensor_model_parallel_world_size() @@ -202,12 +293,13 @@ def __init__( self.wo_b = mla_modules.wo_b # Pick fp8_einsum recipe based on GPU arch: - # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 - # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1 + # SM90/SM120: FP32 block scales stay [g, r/128, d/128]. + # SM100: INT32 packed scales become [g, r, ...]. cap = current_platform.get_device_capability() assert cap is not None, "DeepseekV4 attention requires a CUDA device" - self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) - self._tma_aligned_scales = cap.major >= 10 + self._einsum_recipe, self._tma_aligned_scales = _deepseek_v4_fp8_einsum_config( + cap.major + ) self.rotary_emb = mla_modules.rotary_emb self.indexer_rotary_emb = mla_modules.indexer_rotary_emb @@ -336,10 +428,12 @@ def forward( wo_a_fp8 = self.wo_a.weight wo_a_scale = self.wo_a.weight_scale_inv - z = torch.empty( - (num_tokens, self.n_local_groups, self.o_lora_rank), - device=o.device, - dtype=torch.bfloat16, + z = _allocate_deepseek_v4_wo_a_output( + num_tokens, + self.n_local_groups, + self.o_lora_rank, + torch.bfloat16, + hidden_states.device, ) torch.ops.vllm.deepseek_v4_fp8_einsum( o_fp8, @@ -494,20 +588,6 @@ def wq_b_kv_insert() -> torch.Tensor: # Handle dummy run (no metadata). if not isinstance(attn_metadata, dict): - # Reserve _forward_prefill's bf16-gather workspace; the dummy - # run returns before mla_attn runs, so without this the shared - # workspace locks below the real prefill size. - sub = self.mla_attn - swa_only = sub.compress_ratio <= 1 - N = ( - 0 - if swa_only - else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio - ) - M = N + sub.window_size + sub.max_num_batched_tokens - current_workspace_manager().get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - ) out.zero_() return @@ -594,6 +674,68 @@ def deepseek_v4_fp8_einsum( equation: str, recipe: list[int], ) -> None: + if equation == "bhr,hdr->bhd" and b.dim() == 2: + num_groups = out.shape[1] + out_rank = out.shape[2] + hidden_size = a.shape[2] + if b.shape[0] % out_rank != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight rows must be divisible by " + f"out_rank={out_rank}, got {b.shape[0]}" + ) + b_groups = b.shape[0] // out_rank + group_start = 0 + if b_groups != num_groups: + if b_groups % num_groups != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight groups must match the " + "TP-local output groups or be an integer multiple of " + f"them, got weight_groups={b_groups}, " + f"output_groups={num_groups}" + ) + group_partitions = b_groups // num_groups + group_start = ( + get_tensor_model_parallel_rank() % group_partitions + ) * num_groups + b = b.view(b_groups, out_rank, hidden_size) + if group_start != 0 or b_groups != num_groups: + b = b.narrow(0, group_start, num_groups) + + if b_scale.dim() == 2: + scale_mn = recipe[1] + scale_k_pack = 4 if b_scale.dtype == torch.int32 else 1 + scale_k = recipe[2] * scale_k_pack + scale_out_blocks = (out_rank + scale_mn - 1) // scale_mn + scale_hidden_blocks = (hidden_size + scale_k - 1) // scale_k + if b_scale.shape[0] % scale_out_blocks != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum scale rows must be divisible by " + f"scale_out_blocks={scale_out_blocks}, " + f"got {b_scale.shape[0]}" + ) + scale_groups = b_scale.shape[0] // scale_out_blocks + if scale_groups not in (num_groups, b_groups): + raise RuntimeError( + "DeepSeek V4 fp8 einsum scale groups must match the " + "TP-local output groups or weight groups, got " + f"scale_groups={scale_groups}, output_groups={num_groups}, " + f"weight_groups={b_groups}" + ) + b_scale = b_scale.view( + scale_groups, + scale_out_blocks, + scale_hidden_blocks, + ) + if scale_groups == b_groups and scale_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + elif b_scale.dim() == 3 and b_scale.shape[0] == b_groups: + if b_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + + if _use_deepseek_v4_sm12_triton_fp8_einsum(equation, recipe, b_scale): + deepseek_v4_sm12_fp8_einsum(a, a_scale, b, b_scale, out) + return + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) @@ -711,7 +853,11 @@ def __init__( assert cache_config is not None cache_config.cache_dtype = "fp8_ds_mla" kv_cache_dtype = "fp8_ds_mla" - logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") + logger.info_once( + "Using DeepSeek's fp8_ds_mla KV cache format. To use standard " + "fp8 kv-cache format, please set `--attention-backend " + "FLASHINFER_MLA_SPARSE`" + ) self.kv_cache_dtype = kv_cache_dtype @@ -743,6 +889,332 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: model_version="deepseek_v4", ) + def _forward_sparse_mla_swa_decode_triton( + self, + q: torch.Tensor, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + output: torch.Tensor, + ) -> None: + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if not mtp_decode: + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=max_swa_len, + scale=self.scale, + attn_sink=self.attn_sink, + output=output, + head_block_size=head_block_size, + num_heads=self.num_heads, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + return + + ( + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ) + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_sparse_mla_attention_with_sink( + swa_max_score, + swa_denom, + swa_acc, + self.attn_sink, + output=output, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + + def _forward_sparse_mla_compressed_decode_triton( + self, + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + swa_k_cache: torch.Tensor, + topk_indices: torch.Tensor, + topk_lens: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata, + output: torch.Tensor, + ) -> None: + if self.compress_ratio not in (4, 128): + raise NotImplementedError( + "Triton sparse MLA compressed decode currently supports " + f"compress_ratio=4 or 128, got {self.compress_ratio}" + ) + + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + compressed_block_size = attn_metadata.block_size // self.compress_ratio + compressed_topk = topk_indices.shape[-1] + topk_chunk_size = min( + compressed_topk, + triton_sparse_mla_topk_chunk_size(), + ) + compressed_slot_ids = topk_indices[:, 0, :] + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if ( + not mtp_decode + and compressed_topk <= topk_chunk_size + and triton_sparse_mla_matmul_decode_enabled() + ): + total_candidates = compressed_topk + max_swa_len + ( + combined_kv, + valid_tokens, + score_buffer, + ) = current_workspace_manager().get_simultaneous( + ( + (num_decode_tokens, total_candidates, q.shape[-1]), + torch.bfloat16, + ), + ((num_decode_tokens, total_candidates), torch.bool), + ((num_decode_tokens, self.num_heads, total_candidates), torch.bfloat16), + ) + dequantize_combined_sparse_mla_decode_kv( + combined_kv, + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + swa_k_cache, + swa_metadata.seq_lens[:num_decodes], + swa_lens, + swa_metadata.block_table[:num_decodes], + swa_metadata.block_size, + ) + + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + use_dot_finish = num_decode_tokens <= 16 + matmul_sparse_mla_attention_with_sink( + q=q, + kv=combined_kv, + valid_tokens=valid_tokens, + scale=self.scale, + attn_sink=self.attn_sink, + output=output, + num_heads=self.num_heads, + score_buffer=score_buffer, + value_block_size=512 if use_dot_finish else 256, + candidate_block_size=128 if use_dot_finish else None, + ) + return + + if not mtp_decode and compressed_topk <= topk_chunk_size: + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + swa_block_size=swa_metadata.block_size, + num_compressed_candidates=compressed_topk, + num_swa_candidates=max_swa_len, + scale=self.scale, + attn_sink=self.attn_sink, + output=output, + head_block_size=head_block_size, + num_heads=self.num_heads, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + return + + ( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ) + comp_max_score.fill_(float("-inf")) + comp_denom.zero_() + comp_acc.zero_() + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + + for chunk_start in range(0, compressed_topk, topk_chunk_size): + chunk_end = min(chunk_start + topk_chunk_size, compressed_topk) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids[:, chunk_start:chunk_end], + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=chunk_start, + scale=self.scale, + max_score=comp_max_score, + denom=comp_denom, + acc=comp_acc, + head_block_size=head_block_size, + ) + if mtp_decode: + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + else: + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=max_swa_len, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_two_sparse_mla_attention_states_with_sink( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + self.attn_sink, + output=output, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + + def _forward_sparse_mla_prefill_triton( + self, + q: torch.Tensor, + kv: torch.Tensor, + combined_indices: torch.Tensor, + combined_lens: torch.Tensor, + output: torch.Tensor, + state_buffers: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> None: + kv_flat = kv.reshape(-1, q.shape[-1]) + topk_chunk_size = min( + combined_indices.shape[-1], + triton_sparse_mla_topk_chunk_size(), + ) + query_chunk_size = min( + q.shape[0], + triton_sparse_mla_query_chunk_size(), + ) + if state_buffers is None: + ( + max_score_buffer, + denom_buffer, + output_buffer, + ) = current_workspace_manager().get_simultaneous( + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32), + ) + else: + max_score_buffer, denom_buffer, output_buffer = state_buffers + + for token_start in range(0, q.shape[0], query_chunk_size): + token_end = min(token_start + query_chunk_size, q.shape[0]) + q_chunk = q[token_start:token_end] + indices_chunk_full = combined_indices[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + num_tokens = token_end - token_start + max_score = max_score_buffer[:num_tokens] + denom = denom_buffer[:num_tokens] + subset_acc = output_buffer[:num_tokens] + max_score.fill_(float("-inf")) + denom.zero_() + subset_acc.zero_() + + for index_start in range(0, combined_indices.shape[-1], topk_chunk_size): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=self.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + ) + + finish_sparse_mla_attention_with_sink( + max_score, + denom, + subset_acc, + self.attn_sink, + output=output[token_start:token_end], + ) + if output.shape[1] > self.num_heads: + output[token_start:token_end, self.num_heads :].zero_() + def forward( self, q: torch.Tensor, @@ -823,12 +1295,14 @@ def _forward_decode( if self.compress_ratio == 4: # C4A: local indices differ per layer (filled by Indexer). assert self.topk_indices_buffer is not None + local_topk_indices = self.topk_indices_buffer[:num_decode_tokens] global_indices, topk_lens = compute_global_topk_indices_and_lens( - self.topk_indices_buffer[:num_decode_tokens], + local_topk_indices, swa_metadata.token_to_req_indices, attn_metadata.block_table[:num_decodes], block_size, is_valid, + global_topk_indices=local_topk_indices, ) topk_indices = global_indices.view(num_decode_tokens, 1, -1) else: @@ -867,9 +1341,35 @@ def _forward_decode( # Use unsqueeze to preserve strides (handles padded blocks correctly) swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2) # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) + compressed_k_cache = kv_cache if kv_cache is not None: kv_cache = kv_cache.unsqueeze(-2) + if is_triton_sparse_mla_enabled(q.device): + if swa_only: + self._forward_sparse_mla_swa_decode_triton( + q=q, + swa_k_cache=self.swa_cache_layer.kv_cache, + swa_metadata=swa_metadata, + output=output, + ) + return + if self.compress_ratio in (4, 128): + assert compressed_k_cache is not None + assert attn_metadata is not None + assert topk_indices is not None + assert topk_lens is not None + self._forward_sparse_mla_compressed_decode_triton( + q=q, + compressed_k_cache=compressed_k_cache, + swa_k_cache=self.swa_cache_layer.kv_cache, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + output=output, + ) + return # One FlashMLASchedMeta per layer type, shared across all same-type # layers within this decode step. The first forward call per type # triggers the in-kernel planner (allocating tile_scheduler_metadata @@ -932,8 +1432,12 @@ def _forward_prefill( # Use pre-computed prefill metadata. seq_lens = swa_metadata.prefill_seq_lens gather_lens = swa_metadata.prefill_gather_lens + seq_lens_cpu = swa_metadata.prefill_seq_lens_cpu + gather_lens_cpu = swa_metadata.prefill_gather_lens_cpu assert seq_lens is not None assert gather_lens is not None + assert seq_lens_cpu is not None + assert gather_lens_cpu is not None # Derive prefill-local token offsets from the full query_start_loc_cpu. query_start_loc_cpu = swa_metadata.query_start_loc_cpu @@ -952,24 +1456,69 @@ def _forward_prefill( assert attn_metadata is not None topk_indices = attn_metadata.c128a_prefill_topk_indices top_k = topk_indices.shape[-1] - # Compressed region must fit the full compressed pool (seq_len // - # compress_ratio), not just top_k. top_k bounds how many indices - # the indexer selects, not the pool size it indexes into. - N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio else: # NOTE(woosuk): topk_indices will not be used for SWA-only layers. assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[num_decode_tokens:] top_k = 0 - N = 0 - M = N + self.window_size + self.max_num_batched_tokens + N, M = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=self.compress_ratio, + swa_only=swa_only, + ) num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE + max_query_chunk_tokens = 0 + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * PREFILL_CHUNK_SIZE + chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) + query_start = ( + query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base + ) + query_end = ( + query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base + ) + max_query_chunk_tokens = max( + max_query_chunk_tokens, int(query_end - query_start) + ) + combined_topk = sparse_prefill_combined_topk_size(top_k, self.window_size) workspace_manager = current_workspace_manager() - kv = workspace_manager.get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - )[0] + triton_sparse_mla_enabled = is_triton_sparse_mla_enabled(q.device) + if triton_sparse_mla_enabled: + query_chunk_size = min(q.shape[0], triton_sparse_mla_query_chunk_size()) + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + max_score_buffer, + denom_buffer, + output_buffer, + ) = workspace_manager.get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32), + ) + prefill_state_buffers = ( + max_score_buffer, + denom_buffer, + output_buffer, + ) + else: + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + ) = workspace_manager.get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ) + prefill_state_buffers = None for chunk_idx in range(num_chunks): chunk_start = chunk_idx * PREFILL_CHUNK_SIZE chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) @@ -1008,6 +1557,7 @@ def _forward_prefill( query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base ) + query_tokens = query_end - query_start combined_indices, combined_lens = combine_topk_swa_indices( topk_indices[query_start:query_end], query_start_loc[ @@ -1020,8 +1570,21 @@ def _forward_prefill( top_k, M, N, + combined_indices=combined_indices_buffer[:query_tokens], + combined_lens=combined_lens_buffer[:query_tokens], ) + if triton_sparse_mla_enabled: + self._forward_sparse_mla_prefill_triton( + q=q[query_start:query_end], + kv=kv[:chunk_size], + combined_indices=combined_indices, + combined_lens=combined_lens, + output=output[query_start:query_end], + state_buffers=prefill_state_buffers, + ) + continue + if current_platform.is_rocm(): rocm_sparse_attn_prefill( q=q[query_start:query_end], @@ -1033,16 +1596,17 @@ def _forward_prefill( attn_sink=self.attn_sink, output=output[query_start:query_end], ) - else: - output_chunk, _, _ = flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], - ) + continue + + output_chunk, _, _ = flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): diff --git a/vllm/model_executor/layers/deepseek_v4_triton_kernels.py b/vllm/model_executor/layers/deepseek_v4_triton_kernels.py new file mode 100644 index 000000000000..b5048c5fd013 --- /dev/null +++ b/vllm/model_executor/layers/deepseek_v4_triton_kernels.py @@ -0,0 +1,1282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Triton fallback kernels used by the local DeepSeek V4 path.""" + +import torch + +from vllm.triton_utils import LOG2E, tl, triton + +DEEPSEEK_V4_MLA_HEAD_DIM = 512 +FP8_DS_MLA_FP8_DIM = 448 +FP8_DS_MLA_SCALE_GROUP = 64 +FP8_DS_MLA_SCALE_BYTES = 8 +FP8_DS_MLA_TOKEN_BYTES = 576 + + +def _view_packed_fp8_paged_mqa_kv_cache( + kv_cache: torch.Tensor, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Return FP8 values and fp32 scales from indexer cache block storage.""" + if kv_cache.dtype != torch.uint8: + raise TypeError(f"Expected uint8 kv_cache, got {kv_cache.dtype}") + if kv_cache.dim() == 3: + num_blocks, block_size, head_dim_with_scale = kv_cache.shape + num_kv_heads = 1 + elif kv_cache.dim() == 4: + num_blocks, block_size, num_kv_heads, head_dim_with_scale = kv_cache.shape + else: + raise ValueError( + f"Expected 3D or 4D kv_cache, got {kv_cache.dim()} dimensions" + ) + if num_kv_heads != 1: + raise ValueError(f"Expected one KV head, got {num_kv_heads}") + + scale_bytes = head_dim_with_scale - head_dim + if scale_bytes <= 0 or scale_bytes % torch.float32.itemsize != 0: + raise ValueError( + "Expected kv_cache last dimension to contain FP8 values followed " + f"by fp32 scale bytes; got head_dim={head_dim}, " + f"last_dim={head_dim_with_scale}" + ) + + block_stride = kv_cache.stride(0) + base_storage_offset = kv_cache.storage_offset() + scale_elems = scale_bytes // torch.float32.itemsize + kv_values = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, head_dim), + stride=(block_stride, head_dim, head_dim, 1), + storage_offset=base_storage_offset, + ).view(torch.float8_e4m3fn) + kv_scale = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, scale_bytes), + stride=(block_stride, scale_bytes, scale_bytes, 1), + storage_offset=base_storage_offset + block_size * head_dim, + ).view(torch.float32) + return kv_values, kv_scale[..., :scale_elems] + + +@triton.jit +def _sparse_attention_bf16_kernel( + q_ptr, + kv_ptr, + indices_ptr, + lengths_ptr, + sink_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_heads: tl.constexpr, + seq_kv: tl.constexpr, + index_topk: tl.constexpr, + sm_scale_log2: tl.constexpr, + stride_qt: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_k: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HAS_SINK: tl.constexpr, + LOG2E_CONST: tl.constexpr, +): + token_id = tl.program_id(0) + head_block = tl.program_id(1) + heads = head_block * BLOCK_H + tl.arange(0, BLOCK_H) + offs_d = tl.arange(0, BLOCK_D) + mask_h = heads < num_heads + + q = tl.load( + q_ptr + + token_id * stride_qt + + heads[:, None] * stride_qh + + offs_d[None, :] * stride_qd, + mask=mask_h[:, None], + other=0.0, + ) + + if HAS_SINK: + sink = tl.load(sink_ptr + heads, mask=mask_h, other=-float("inf")) + e_max = sink * LOG2E_CONST + e_sum = tl.where(mask_h, 1.0, 0.0) + else: + e_max = tl.full((BLOCK_H,), -float("inf"), dtype=tl.float32) + e_sum = tl.zeros((BLOCK_H,), dtype=tl.float32) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + + length = tl.load(lengths_ptr + token_id) + for start in range(0, index_topk, BLOCK_N): + offs_n = start + tl.arange(0, BLOCK_N) + idx = tl.load( + indices_ptr + token_id * stride_indices_t + offs_n * stride_indices_k, + mask=offs_n < index_topk, + other=-1, + ) + mask_kv = (offs_n < length) & (idx >= 0) & (idx < seq_kv) + k = tl.load( + kv_ptr + idx[None, :] * stride_kv_t + offs_d[:, None] * stride_kv_d, + mask=mask_kv[None, :], + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) * sm_scale_log2 + qk = tl.where( + mask_h[:, None] & mask_kv[None, :], + qk, + -3.4028234663852886e38, + ) + + v = tl.load( + kv_ptr + idx[:, None] * stride_kv_t + offs_d[None, :] * stride_kv_d, + mask=mask_kv[:, None], + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp2(e_max - n_e_max) + p = tl.exp2(qk - n_e_max[:, None]) + p = tl.where(mask_h[:, None] & mask_kv[None, :], p, 0.0) + acc = acc * re_scale[:, None] + tl.dot(p.to(v.dtype), v) + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + acc = acc / tl.maximum(e_sum, 1.0e-20)[:, None] + tl.store( + out_ptr + + token_id * stride_out_t + + heads[:, None] * stride_out_h + + offs_d[None, :] * stride_out_d, + acc.to(tl.bfloat16), + mask=mask_h[:, None], + ) + + +def sparse_attention_triton( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + lengths: torch.Tensor, + scale: float, + attn_sink: torch.Tensor | None, + out: torch.Tensor, +) -> None: + if indices.ndim == 3: + indices = indices.squeeze(1) + if kv.ndim == 3: + kv = kv.squeeze(1) + + num_tokens, num_heads, head_dim = q.shape + if num_tokens == 0: + return + if head_dim != DEEPSEEK_V4_MLA_HEAD_DIM: + raise ValueError( + "DeepSeek V4 sparse Triton fallback expects " + f"D={DEEPSEEK_V4_MLA_HEAD_DIM}, got {head_dim}" + ) + assert kv.shape[-1] == head_dim + assert out.shape[-1] == head_dim + + grid = (num_tokens, triton.cdiv(num_heads, 8)) + _sparse_attention_bf16_kernel[grid]( + q, + kv, + indices, + lengths, + attn_sink if attn_sink is not None else q, + out, + num_tokens, + num_heads, + kv.shape[0], + indices.shape[-1], + scale * LOG2E, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + indices.stride(0), + indices.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_H=8, + BLOCK_N=16, + BLOCK_D=DEEPSEEK_V4_MLA_HEAD_DIM, + HAS_SINK=attn_sink is not None, + LOG2E_CONST=LOG2E, + num_warps=8, + ) + + +@triton.jit +def _decode_sparse_attention_fp8_kernel( + q_ptr, + swa_cache_fp8_ptr, + swa_cache_bf16_ptr, + swa_cache_u8_ptr, + swa_indices_ptr, + swa_lens_ptr, + extra_cache_fp8_ptr, + extra_cache_bf16_ptr, + extra_cache_u8_ptr, + extra_indices_ptr, + extra_lens_ptr, + sink_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_heads: tl.constexpr, + swa_index_topk: tl.constexpr, + extra_index_topk: tl.constexpr, + swa_num_blocks: tl.constexpr, + extra_num_blocks: tl.constexpr, + swa_block_size: tl.constexpr, + extra_block_size: tl.constexpr, + swa_stride_block_bytes: tl.constexpr, + extra_stride_block_bytes: tl.constexpr, + sm_scale_log2: tl.constexpr, + stride_qt: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_swa_indices_t: tl.constexpr, + stride_swa_indices_k: tl.constexpr, + stride_extra_indices_t: tl.constexpr, + stride_extra_indices_k: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + FP8_DIM: tl.constexpr, + SCALE_GROUP: tl.constexpr, + SCALE_BYTES: tl.constexpr, + TOKEN_BYTES: tl.constexpr, + HAS_EXTRA: tl.constexpr, + HAS_SINK: tl.constexpr, + LOG2E_CONST: tl.constexpr, +): + token_id = tl.program_id(0) + head_block = tl.program_id(1) + heads = head_block * BLOCK_H + tl.arange(0, BLOCK_H) + offs_d = tl.arange(0, BLOCK_D) + mask_h = heads < num_heads + + q = tl.load( + q_ptr + + token_id * stride_qt + + heads[:, None] * stride_qh + + offs_d[None, :] * stride_qd, + mask=mask_h[:, None], + other=0.0, + ) + + if HAS_SINK: + sink = tl.load(sink_ptr + heads, mask=mask_h, other=-float("inf")) + e_max = sink * LOG2E_CONST + e_sum = tl.where(mask_h, 1.0, 0.0) + else: + e_max = tl.full((BLOCK_H,), -float("inf"), dtype=tl.float32) + e_sum = tl.zeros((BLOCK_H,), dtype=tl.float32) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + + swa_len = tl.load(swa_lens_ptr + token_id) + extra_len = tl.load(extra_lens_ptr + token_id) if HAS_EXTRA else 0 + total_len = extra_len + swa_len + + for start in range(0, extra_index_topk + swa_index_topk, BLOCK_N): + offs_n = start + tl.arange(0, BLOCK_N) + use_extra = HAS_EXTRA & (offs_n < extra_len) + use_swa = (offs_n >= extra_len) & (offs_n < total_len) + + extra_cols = offs_n + swa_cols = offs_n - extra_len + extra_idx = tl.load( + extra_indices_ptr + + token_id * stride_extra_indices_t + + extra_cols * stride_extra_indices_k, + mask=HAS_EXTRA & (extra_cols < extra_index_topk), + other=-1, + ) + swa_idx = tl.load( + swa_indices_ptr + + token_id * stride_swa_indices_t + + swa_cols * stride_swa_indices_k, + mask=(swa_cols >= 0) & (swa_cols < swa_index_topk), + other=-1, + ) + idx = tl.where(use_extra, extra_idx, swa_idx) + + extra_block = idx // extra_block_size + extra_pos = idx - extra_block * extra_block_size + swa_block = idx // swa_block_size + swa_pos = idx - swa_block * swa_block_size + valid_extra = use_extra & (idx >= 0) & (extra_block < extra_num_blocks) + valid_swa = use_swa & (idx >= 0) & (swa_block < swa_num_blocks) + valid = valid_extra | valid_swa + + extra_token_base = extra_block * extra_stride_block_bytes + extra_token_base += extra_pos * TOKEN_BYTES + swa_token_base = swa_block * swa_stride_block_bytes + swa_token_base += swa_pos * TOKEN_BYTES + token_base = tl.where(use_extra, extra_token_base, swa_token_base) + block_size = tl.where(use_extra, extra_block_size, swa_block_size) + stride_block_bytes = tl.where( + use_extra, extra_stride_block_bytes, swa_stride_block_bytes + ) + pos = tl.where(use_extra, extra_pos, swa_pos) + + is_fp8 = offs_d < FP8_DIM + scale_offsets = ( + tl.where(use_extra, extra_block, swa_block)[:, None] + * stride_block_bytes[:, None] + + block_size[:, None] * TOKEN_BYTES + + pos[:, None] * SCALE_BYTES + + (offs_d[None, :] // SCALE_GROUP) + ) + encoded_scale = tl.load( + tl.where(use_extra[:, None], extra_cache_u8_ptr, swa_cache_u8_ptr) + + scale_offsets, + mask=valid[:, None] & is_fp8[None, :], + other=127, + ).to(tl.float32) + fp8_scale = tl.exp2(encoded_scale - 127.0) + + fp8_offsets = token_base[:, None] + offs_d[None, :] + fp8_vals = ( + tl.load( + tl.where(use_extra[:, None], extra_cache_fp8_ptr, swa_cache_fp8_ptr) + + fp8_offsets, + mask=valid[:, None] & is_fp8[None, :], + other=0.0, + ).to(tl.float32) + * fp8_scale + ) + + bf16_offsets = (token_base[:, None] + FP8_DIM) // 2 + bf16_offsets += offs_d[None, :] - FP8_DIM + bf16_vals = tl.load( + tl.where(use_extra[:, None], extra_cache_bf16_ptr, swa_cache_bf16_ptr) + + bf16_offsets, + mask=valid[:, None] & (~is_fp8[None, :]), + other=0.0, + ).to(tl.float32) + k = tl.where(is_fp8[None, :], fp8_vals, bf16_vals) + + qk = tl.dot(q, tl.trans(k.to(q.dtype))) * sm_scale_log2 + qk = tl.where( + mask_h[:, None] & valid[None, :], + qk, + -3.4028234663852886e38, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp2(e_max - n_e_max) + p = tl.exp2(qk - n_e_max[:, None]) + p = tl.where(mask_h[:, None] & valid[None, :], p, 0.0) + acc = acc * re_scale[:, None] + tl.dot(p.to(k.dtype), k) + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + acc = acc / tl.maximum(e_sum, 1.0e-20)[:, None] + tl.store( + out_ptr + + token_id * stride_out_t + + heads[:, None] * stride_out_h + + offs_d[None, :] * stride_out_d, + acc.to(tl.bfloat16), + mask=mask_h[:, None], + ) + + +def decode_sparse_attention_triton( + q: torch.Tensor, + swa_cache: torch.Tensor, + swa_indices: torch.Tensor, + swa_lens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor | None, + out: torch.Tensor, + extra_cache: torch.Tensor | None = None, + extra_indices: torch.Tensor | None = None, + extra_lens: torch.Tensor | None = None, +) -> None: + if swa_indices.ndim == 3: + swa_indices = swa_indices.squeeze(1) + if extra_indices is not None and extra_indices.ndim == 3: + extra_indices = extra_indices.squeeze(1) + + num_tokens, num_heads, head_dim = q.shape + if num_tokens == 0: + return + if head_dim != DEEPSEEK_V4_MLA_HEAD_DIM: + raise ValueError( + "DeepSeek V4 decode Triton fallback expects " + f"D={DEEPSEEK_V4_MLA_HEAD_DIM}, got {head_dim}" + ) + has_extra = ( + extra_cache is not None and extra_indices is not None and extra_lens is not None + ) + if not has_extra: + extra_cache = swa_cache + extra_indices = swa_indices[:, :1] + extra_lens = swa_lens + + assert extra_cache is not None + assert extra_indices is not None + assert extra_lens is not None + grid = (num_tokens, triton.cdiv(num_heads, 8)) + _decode_sparse_attention_fp8_kernel[grid]( + q, + swa_cache.view(torch.float8_e4m3fn), + swa_cache.view(torch.bfloat16), + swa_cache, + swa_indices, + swa_lens, + extra_cache.view(torch.float8_e4m3fn), + extra_cache.view(torch.bfloat16), + extra_cache, + extra_indices, + extra_lens, + attn_sink if attn_sink is not None else q, + out, + num_tokens, + num_heads, + swa_indices.shape[-1], + extra_indices.shape[-1] if has_extra else 0, + swa_cache.shape[0], + extra_cache.shape[0], + swa_cache.shape[1], + extra_cache.shape[1], + swa_cache.stride(0), + extra_cache.stride(0), + scale * LOG2E, + q.stride(0), + q.stride(1), + q.stride(2), + swa_indices.stride(0), + swa_indices.stride(1), + extra_indices.stride(0), + extra_indices.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_H=8, + BLOCK_N=16, + BLOCK_D=DEEPSEEK_V4_MLA_HEAD_DIM, + FP8_DIM=FP8_DS_MLA_FP8_DIM, + SCALE_GROUP=FP8_DS_MLA_SCALE_GROUP, + SCALE_BYTES=FP8_DS_MLA_SCALE_BYTES, + TOKEN_BYTES=FP8_DS_MLA_TOKEN_BYTES, + HAS_EXTRA=has_extra, + HAS_SINK=attn_sink is not None, + LOG2E_CONST=LOG2E, + num_warps=8, + ) + + +@triton.jit +def _deepseek_v4_fp8_einsum_triton_kernel( + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + out_ptr, + B: tl.constexpr, + G: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + a_stride_b: tl.constexpr, + a_stride_g: tl.constexpr, + a_stride_k: tl.constexpr, + as_stride_b: tl.constexpr, + as_stride_g: tl.constexpr, + as_stride_kb: tl.constexpr, + b_stride_g: tl.constexpr, + b_stride_n: tl.constexpr, + b_stride_k: tl.constexpr, + bs_stride_g: tl.constexpr, + bs_stride_nb: tl.constexpr, + bs_stride_kb: tl.constexpr, + out_stride_b: tl.constexpr, + out_stride_g: tl.constexpr, + out_stride_n: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_g = tl.program_id(1) + pid_n = tl.program_id(2) + + offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + acc = tl.zeros((BLOCK_B, BLOCK_N), dtype=tl.float32) + for k0 in range(0, K, BLOCK_K): + k = k0 + offs_k + kb = k0 // BLOCK_K + + a = tl.load( + a_ptr + + offs_b[:, None] * a_stride_b + + pid_g * a_stride_g + + k[None, :] * a_stride_k, + mask=(offs_b[:, None] < B) & (k[None, :] < K), + other=0.0, + ) + b = tl.load( + b_ptr + + pid_g * b_stride_g + + offs_n[:, None] * b_stride_n + + k[None, :] * b_stride_k, + mask=(offs_n[:, None] < N) & (k[None, :] < K), + other=0.0, + ) + a_s = tl.load( + a_scale_ptr + + offs_b * as_stride_b + + pid_g * as_stride_g + + kb * as_stride_kb, + mask=offs_b < B, + other=0.0, + ).to(tl.float32) + b_s = tl.load( + b_scale_ptr + + pid_g * bs_stride_g + + (offs_n // BLOCK_K) * bs_stride_nb + + kb * bs_stride_kb, + mask=offs_n < N, + other=0.0, + ).to(tl.float32) + acc += ( + tl.dot(a, tl.trans(b), out_dtype=tl.float32) * a_s[:, None] * b_s[None, :] + ) + + tl.store( + out_ptr + + offs_b[:, None] * out_stride_b + + pid_g * out_stride_g + + offs_n[None, :] * out_stride_n, + acc, + mask=(offs_b[:, None] < B) & (offs_n[None, :] < N), + ) + + +def _e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + return (scale.view(torch.uint8).to(torch.int32) << 23).view(torch.float32) + + +def _unpack_int32_e8m0_scales( + packed_scale: torch.Tensor, + num_blocks: int, +) -> torch.Tensor: + shifts = torch.arange(4, device=packed_scale.device, dtype=torch.int32) * 8 + unpacked = (packed_scale.to(torch.int32).unsqueeze(-1) >> shifts) & 0xFF + unpacked = unpacked.reshape(*packed_scale.shape[:-1], -1)[..., :num_blocks] + return (unpacked << 23).view(torch.float32) + + +def _normalize_deepseek_v4_fp8_einsum_inputs( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, G, K = a.shape + _, out_g, N = out.shape + assert out_g == G + k_blocks = triton.cdiv(K, 128) + n_blocks = triton.cdiv(N, 128) + + if b.ndim == 2: + b = b.view(G, N, K) + if b_scale.ndim == 2: + b_scale = b_scale.view(G, n_blocks, k_blocks) + + if a_scale.dtype == torch.int32: + a_scale = _unpack_int32_e8m0_scales(a_scale, k_blocks) + if b_scale.dtype == torch.int32: + b_scale = _unpack_int32_e8m0_scales(b_scale, k_blocks) + + if a_scale.dtype == torch.float8_e8m0fnu: + a_scale = _e8m0_to_fp32(a_scale) + if b_scale.dtype == torch.float8_e8m0fnu: + b_scale = _e8m0_to_fp32(b_scale) + + return a, a_scale.contiguous(), b, b_scale.contiguous() + + +def deepseek_v4_fp8_einsum_triton( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + a, a_scale, b, b_scale = _normalize_deepseek_v4_fp8_einsum_inputs( + a, a_scale, b, b_scale, out + ) + B, G, K = a.shape + N = out.shape[-1] + grid = (triton.cdiv(B, 16), G, triton.cdiv(N, 32)) + _deepseek_v4_fp8_einsum_triton_kernel[grid]( + a, + a_scale, + b, + b_scale, + out, + B, + G, + N, + K, + a.stride(0), + a.stride(1), + a.stride(2), + a_scale.stride(0), + a_scale.stride(1), + a_scale.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + b_scale.stride(0), + b_scale.stride(1), + b_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_B=16, + BLOCK_N=32, + BLOCK_K=128, + num_warps=4, + ) + + +@triton.jit +def _fp8_mqa_logits_kernel( + q_ptr, + k_ptr, + scale_ptr, + weights_ptr, + cu_seqlen_ks_ptr, + cu_seqlen_ke_ptr, + logits_ptr, + num_q: tl.constexpr, + seq_len_kv: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kn: tl.constexpr, + stride_kd: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_q + valid_n = offs_n < seq_len_kv + seq_start = tl.load(cu_seqlen_ks_ptr + offs_m, mask=valid_m, other=0) + seq_end = tl.load(cu_seqlen_ke_ptr + offs_m, mask=valid_m, other=0) + seq_mask = (offs_n[None, :] >= seq_start[:, None]) & ( + offs_n[None, :] < seq_end[:, None] + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + offs_m[:, None] * stride_qm + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + k_ptr + offs_n[:, None] * stride_kn + d[None, :] * stride_kd, + mask=valid_n[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.dot(q, tl.trans(k), input_precision="tf32") + scale = tl.load(scale_ptr + offs_n, mask=valid_n, other=0.0) + weighted = tl.maximum(scores * scale[None, :], 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(seq_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_n[None, :] * stride_ln, + logits, + mask=store_mask, + ) + + +def fp8_mqa_logits_triton( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + k_fp8, scale = kv + num_q, num_heads, head_dim = q.shape + seq_len_kv = k_fp8.shape[0] + logits = torch.empty( + (num_q, seq_len_kv), + device=q.device, + dtype=torch.float32, + ) + if num_q == 0 or seq_len_kv == 0: + return logits + + grid = (triton.cdiv(num_q, 8), triton.cdiv(seq_len_kv, 64)) + _fp8_mqa_logits_kernel[grid]( + q, + k_fp8, + scale, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + logits, + num_q, + seq_len_kv, + num_heads, + head_dim, + q.stride(0), + q.stride(1), + q.stride(2), + k_fp8.stride(0), + k_fp8.stride(1), + weights.stride(0), + weights.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=8, + BLOCK_N=64, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + q_ptr, + kv_ptr, + scale_ptr, + weights_ptr, + context_lens_ptr, + block_tables_ptr, + logits_ptr, + token_start, + num_rows: tl.constexpr, + logits_width: tl.constexpr, + next_n: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + block_size: tl.constexpr, + stride_qb: tl.constexpr, + stride_qn: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvb: tl.constexpr, + stride_kvs: tl.constexpr, + stride_kvd: tl.constexpr, + stride_sb: tl.constexpr, + stride_ss: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_clb: tl.constexpr, + stride_cln: tl.constexpr, + stride_btb: tl.constexpr, + stride_btk: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = token_start + offs_local_n + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_rows + valid_n = offs_local_n < logits_width + batch = offs_m // next_n + q_pos = offs_m - batch * next_n + context_len = tl.load( + context_lens_ptr + batch * stride_clb + q_pos * stride_cln, + mask=valid_m, + other=0, + ) + context_mask = valid_n[None, :] & (offs_n[None, :] < context_len[:, None]) + + block_rank = offs_n // block_size + block_offset = offs_n - block_rank * block_size + block_idx = tl.load( + block_tables_ptr + + batch[:, None] * stride_btb + + block_rank[None, :] * stride_btk, + mask=valid_m[:, None] & valid_n[None, :], + other=0, + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + scale = tl.load( + scale_ptr + block_idx * stride_sb + block_offset[None, :] * stride_ss, + mask=context_mask, + other=0.0, + ) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + batch[:, None] * stride_qb + + q_pos[:, None] * stride_qn + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + kv_ptr + + block_idx[:, :, None] * stride_kvb + + block_offset[None, :, None] * stride_kvs + + d[None, None, :] * stride_kvd, + mask=context_mask[:, :, None] & (d[None, None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.sum(q[:, None, :] * k, axis=2) + weighted = tl.maximum(scores * scale, 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(context_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_local_n[None, :] * stride_ln, + logits, + mask=store_mask, + ) + + +def fp8_paged_mqa_logits_triton( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + token_start: int = 0, + token_count: int | None = None, +) -> torch.Tensor: + batch_size, next_n, num_heads, head_dim = q.size() + if next_n == 1 and head_dim % 64 == 0 and num_heads % 4 == 0: + return fp8_paged_mqa_logits_rowwise_triton( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + token_start=token_start, + token_count=token_count, + ) + + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_size, _, _ = kv_values.size() + num_rows = batch_size * next_n + if token_count is None: + token_count = max_model_len - token_start + assert token_start >= 0 + assert token_count >= 0 + assert token_start + token_count <= max_model_len + logits = torch.empty( + (num_rows, token_count), + device=q.device, + dtype=torch.float32, + ) + if num_rows == 0 or token_count == 0: + return logits + + context_lens_2d = context_lens.reshape(batch_size, -1) + if context_lens_2d.shape[1] == 1 and next_n != 1: + context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() + grid = (triton.cdiv(num_rows, 4), triton.cdiv(token_count, 64)) + _fp8_paged_mqa_logits_kernel[grid]( + q, + kv_values, + kv_scale, + weights, + context_lens_2d, + block_tables, + logits, + token_start, + num_rows, + token_count, + next_n, + num_heads, + head_dim, + block_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv_values.stride(0), + kv_values.stride(1), + kv_values.stride(3), + kv_scale.stride(0), + kv_scale.stride(1), + weights.stride(0), + weights.stride(1), + context_lens_2d.stride(0), + context_lens_2d.stride(1), + block_tables.stride(0), + block_tables.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=4, + BLOCK_N=64, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +@triton.jit +def _fp8_paged_mqa_logits_rowwise_kernel( + q_ptr, + kv_ptr, + scale_ptr, + weights_ptr, + context_lens_ptr, + block_tables_ptr, + logits_ptr, + token_start, + num_rows: tl.constexpr, + logits_width: tl.constexpr, + next_n: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + block_size: tl.constexpr, + stride_qb: tl.constexpr, + stride_qn: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvb: tl.constexpr, + stride_kvs: tl.constexpr, + stride_kvd: tl.constexpr, + stride_sb: tl.constexpr, + stride_ss: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_clb: tl.constexpr, + stride_cln: tl.constexpr, + stride_btb: tl.constexpr, + stride_btk: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + pid_n = tl.program_id(1) + offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = token_start + offs_local_n + offs_d = tl.arange(0, BLOCK_D) + + valid_row = row < num_rows + valid_n = offs_local_n < logits_width + batch = row // next_n + q_pos = row - batch * next_n + context_len = tl.load( + context_lens_ptr + batch * stride_clb + q_pos * stride_cln, + mask=valid_row, + other=0, + ) + context_mask = valid_n & (offs_n < context_len) + + block_rank = offs_n // block_size + block_offset = offs_n - block_rank * block_size + block_idx = tl.load( + block_tables_ptr + batch * stride_btb + block_rank * stride_btk, + mask=valid_row & valid_n, + other=0, + ) + + scale = tl.load( + scale_ptr + block_idx * stride_sb + block_offset * stride_ss, + mask=context_mask, + other=0.0, + ) + logits = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for h0 in tl.range(0, num_heads, BLOCK_H): + heads = h0 + tl.arange(0, BLOCK_H) + valid_h = heads < num_heads + scores = tl.zeros((BLOCK_H, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + batch * stride_qb + + q_pos * stride_qn + + heads[:, None] * stride_qh + + d[None, :] * stride_qd, + mask=valid_row & valid_h[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + kv_ptr + + block_idx[None, :] * stride_kvb + + block_offset[None, :] * stride_kvs + + d[:, None] * stride_kvd, + mask=context_mask[None, :] & (d[:, None] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.dot(q, k, input_precision="tf32") + + weighted = tl.maximum(scores * scale[None, :], 0.0) + weight = tl.load( + weights_ptr + row * stride_wm + heads * stride_wh, + mask=valid_row & valid_h, + other=0.0, + ) + logits += tl.sum(weighted * weight[:, None], axis=0) + + logits = tl.where(context_mask & valid_row, logits, float("-inf")) + tl.store( + logits_ptr + row * stride_lm + offs_local_n * stride_ln, + logits, + mask=valid_row & valid_n, + ) + + +def fp8_paged_mqa_logits_rowwise_triton( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + token_start: int = 0, + token_count: int | None = None, +) -> torch.Tensor: + batch_size, next_n, num_heads, head_dim = q.size() + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_size, _, _ = kv_values.size() + num_rows = batch_size * next_n + if token_count is None: + token_count = max_model_len - token_start + assert token_start >= 0 + assert token_count >= 0 + assert token_start + token_count <= max_model_len + logits = torch.empty( + (num_rows, token_count), + device=q.device, + dtype=torch.float32, + ) + if num_rows == 0 or token_count == 0: + return logits + + context_lens_2d = context_lens.reshape(batch_size, -1) + if context_lens_2d.shape[1] == 1 and next_n != 1: + context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() + block_n = 128 + grid = (num_rows, triton.cdiv(token_count, block_n)) + _fp8_paged_mqa_logits_rowwise_kernel[grid]( + q, + kv_values, + kv_scale, + weights, + context_lens_2d, + block_tables, + logits, + token_start, + num_rows, + token_count, + next_n, + num_heads, + head_dim, + block_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv_values.stride(0), + kv_values.stride(1), + kv_values.stride(3), + kv_scale.stride(0), + kv_scale.stride(1), + weights.stride(0), + weights.stride(1), + context_lens_2d.stride(0), + context_lens_2d.stride(1), + block_tables.stride(0), + block_tables.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_N=block_n, + BLOCK_D=64, + BLOCK_H=8, + num_warps=4, + ) + return logits + + +@triton.jit +def _tf32_hc_prenorm_gemm_kernel( + x_ptr, + fn_ptr, + out_ptr, + sqrsum_ptr, + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_fnn: tl.constexpr, + stride_fnk: tl.constexpr, + stride_outs: tl.constexpr, + stride_outm: tl.constexpr, + stride_outn: tl.constexpr, + stride_sqs: tl.constexpr, + stride_sqm: tl.constexpr, + NUM_SPLIT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_s = tl.program_id(2) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + split_k = tl.cdiv(K, NUM_SPLIT) + split_begin = pid_s * split_k + split_end = tl.minimum(split_begin + split_k, K) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + sq = tl.zeros((BLOCK_M,), dtype=tl.float32) + + for k0 in tl.range(0, split_k, BLOCK_K): + k = split_begin + k0 + offs_k + k_mask = k < split_end + x = tl.load( + x_ptr + offs_m[:, None] * stride_xm + k[None, :] * stride_xk, + mask=(offs_m[:, None] < M) & k_mask[None, :], + other=0.0, + ).to(tl.float32) + fn = tl.load( + fn_ptr + offs_n[None, :] * stride_fnn + k[:, None] * stride_fnk, + mask=(offs_n[None, :] < N) & k_mask[:, None], + other=0.0, + ).to(tl.float32) + + acc += tl.dot(x, fn, input_precision="tf32", out_dtype=tl.float32) + sq += tl.sum(x * x, axis=1) + + tl.store( + out_ptr + + pid_s * stride_outs + + offs_m[:, None] * stride_outm + + offs_n[None, :] * stride_outn, + acc, + mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), + ) + + if pid_n == 0: + tl.store( + sqrsum_ptr + pid_s * stride_sqs + offs_m * stride_sqm, + sq, + mask=offs_m < M, + ) + + +def tf32_hc_prenorm_gemm_triton( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> None: + assert x.dim() == 2 + assert fn.dim() == 2 + assert out.dim() == 3 + assert sqrsum.dim() == 2 + + m, k = x.shape + n = fn.shape[0] + assert fn.shape[1] == k + assert out.shape == (num_split, m, n) + assert sqrsum.shape == (num_split, m) + + if m == 0: + return + + block_m = 16 + block_n = triton.next_power_of_2(n) + block_n = min(max(block_n, 16), 32) + block_k = 64 + grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n), num_split) + _tf32_hc_prenorm_gemm_kernel[grid]( + x, + fn, + out, + sqrsum, + m, + k, + n, + x.stride(0), + x.stride(1), + fn.stride(0), + fn.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + sqrsum.stride(0), + sqrsum.stride(1), + num_split, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=4, + ) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 456f40bbf7a3..79edfa6f2d92 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -786,6 +786,19 @@ def update_expert_map(self): dp_size=get_dp_group().world_size, ) + @staticmethod + def _normalize_loaded_weight_for_copy( + expert_data: torch.Tensor, loaded_weight: torch.Tensor + ) -> torch.Tensor: + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and expert_data.dtype == torch.uint8 + and loaded_weight.dtype == e8m0_dtype + ): + return loaded_weight.view(torch.uint8) + return loaded_weight + def _load_per_tensor_weight_scale( self, shard_id: str, @@ -799,10 +812,12 @@ def _load_per_tensor_weight_scale( # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight + target = param_data[expert_id][idx] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) # If we are in the row parallel case (down_proj) elif shard_id == "w2": - param_data[expert_id] = loaded_weight + target = param_data[expert_id] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) def _load_combined_w13_weight_scale( self, @@ -819,7 +834,7 @@ def _load_combined_w13_weight_scale( loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size ) - param.copy_(loaded_weight) + param.copy_(self._normalize_loaded_weight_for_copy(param, loaded_weight)) def _load_model_weight_or_group_weight_scale( self, @@ -986,7 +1001,9 @@ def _load_w13( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_w2( self, @@ -1022,7 +1039,9 @@ def _load_w2( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_single_value( self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index d9aab35c25f4..e2df12488652 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -817,6 +817,35 @@ def get_w8a8_block_fp8_configs( return None +def _get_default_w8a8_block_fp8_config( + M: int, + block_n: int, + block_k: int, +) -> dict[str, Any]: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_n and + # BLOCK_SIZE_K must be divisible by block_k. + # M-aware tuning for low-M decode: BLOCK_SIZE_M=64 wastes most of the + # M-dim for single-request decode and short MTP-style draft batches. SM12x + # keeps benefiting from the low-M tile through M=32 on DeepSeek V4 shapes. + capability = current_platform.get_device_capability() + capability_major = getattr(capability, "major", None) + if capability_major is None and capability is not None: + capability_major = capability[0] + low_m_limit = 32 if capability_major == 12 else 8 + if low_m_limit >= M: + block_m, num_stages = 16, (2 if current_platform.is_rocm() else 3) + else: + block_m, num_stages = 64, 2 + return { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": num_stages, + } + + def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -861,6 +890,12 @@ def w8a8_triton_block_scaled_mm( N, K = B.shape assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if e8m0_dtype is not None: + if As.dtype == e8m0_dtype: + As = _upcast_e8m0_to_fp32(As) + if Bs.dtype == e8m0_dtype: + Bs = _upcast_e8m0_to_fp32(Bs) C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) @@ -870,17 +905,7 @@ def w8a8_triton_block_scaled_mm( # Get the optimal config if there is one config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: - # Default config - # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] - # BLOCK_SIZE_K must be divisible by block_size[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_size[0], - "BLOCK_SIZE_K": block_size[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 2, - } + config = _get_default_w8a8_block_fp8_config(M, block_size[0], block_size[1]) def grid(META): return ( @@ -1215,6 +1240,8 @@ def create_fp8_scale_parameter( if dtype == torch.float32: scale[:] = torch.finfo(torch.float32).min + elif dtype == getattr(torch, "float8_e8m0fnu", None): + scale[:] = 0 set_weight_attrs(scale, {"scale_type": "weight_scale"}) return scale diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 4bf52a49c43f..f974c873fdc1 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -4,7 +4,6 @@ import torch -import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops from vllm.forward_context import get_forward_context from vllm.logger import init_logger @@ -12,7 +11,9 @@ from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( fp8_fp4_mqa_logits, + fp8_fp4_mqa_topk_indices, fp8_fp4_paged_mqa_logits, + fp8_fp4_paged_mqa_topk_indices, has_deep_gemm, ) from vllm.utils.torch_utils import ( @@ -23,6 +24,7 @@ ) from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, + sparse_indexer_max_logits_bytes, ) from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.worker.workspace import current_workspace_manager @@ -35,11 +37,60 @@ logger = init_logger(__name__) RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 +SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH = 4096 +SM120_SHORT_ROW_TOPK_MAX_WIDTH = 12288 # MXFP4 layout: 2 values packed per byte, ue8m0 (1-byte) scale per block of 32. MXFP4_BLOCK_SIZE = 32 +def _should_use_sm120_short_row_topk_decode( + topk_tokens: int, + logits_width: int, + num_rows: int, + is_cuda_sm120: bool, +) -> bool: + if not is_cuda_sm120 or topk_tokens != 512: + return False + if logits_width <= SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH: + return True + return logits_width < SM120_SHORT_ROW_TOPK_MAX_WIDTH + + +def _use_sm120_short_row_topk_decode( + logits: torch.Tensor, + topk_tokens: int, +) -> bool: + return _should_use_sm120_short_row_topk_decode( + topk_tokens, + logits.shape[1], + logits.shape[0], + current_platform.is_cuda() + and current_platform.is_device_capability_family(120), + ) + + +def _decode_logits_width(max_model_len: int, max_seq_len: int) -> int: + if max_model_len <= 0: + return 0 + if max_seq_len <= 0: + return max_model_len + return min(max_model_len, max_seq_len) + + +def _decode_topk_logits_width( + max_model_len: int, max_seq_len: int, topk_tokens: int +) -> int: + logits_width = _decode_logits_width(max_model_len, max_seq_len) + return min(max_model_len, max(logits_width, topk_tokens)) + + +def _sparse_indexer_requires_deep_gemm() -> bool: + return current_platform.is_cuda() and not ( + current_platform.is_device_capability_family(120) + ) + + def _gather_workspace_shapes( total_seq_lens: int, head_dim: int, @@ -118,7 +169,7 @@ def sparse_attn_indexer( # Dummy allocation to simulate for peak logits tensor memory during inference. # FP8 elements so elements == bytes - max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_elems = sparse_indexer_max_logits_bytes() _ = torch.empty( max_logits_elems, dtype=torch.uint8, device=hidden_states.device ) @@ -220,6 +271,19 @@ def sparse_attn_indexer( q_slice_cast = q_slice k_quant_cast = k_quant k_scale_cast = k_scale.view(torch.float32).squeeze(-1) + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + if fp8_fp4_mqa_topk_indices( + (q_slice_cast, q_scale_slice), + (k_quant_cast, k_scale_cast), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + ): + continue + logits = fp8_fp4_mqa_logits( (q_slice_cast, q_scale_slice), (k_quant_cast, k_scale_cast), @@ -230,10 +294,6 @@ def sparse_attn_indexer( ) num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[ - chunk.token_start : chunk.token_end, :topk_tokens - ] - if current_platform.is_xpu(): xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined] logits, @@ -307,35 +367,38 @@ def sparse_attn_indexer( if use_fp4_cache else padded_q_quant_decode_tokens ) - logits = fp8_fp4_paged_mqa_logits( - (padded_q_quant_cast, padded_q_scale), - kv_cache, - weights[:num_padded_tokens], - seq_lens, - decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len=max_model_len, - clean_logits=False, - ) - num_rows = logits.shape[0] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - - if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048): - workspace_manager = current_workspace_manager() - (topk_workspace,) = workspace_manager.get_simultaneous( - ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), - ) - torch.ops._C.persistent_topk( - logits, + logits_width = _decode_topk_logits_width( + max_model_len, attn_metadata_narrowed.max_seq_len, topk_tokens + ) + logits_bytes = num_padded_tokens * logits_width * torch.float32.itemsize + used_direct_topk = False + if logits_bytes > sparse_indexer_max_logits_bytes(): + used_direct_topk = fp8_fp4_paged_mqa_topk_indices( + (padded_q_quant_cast, padded_q_scale), + kv_cache, + weights[:num_padded_tokens], seq_lens, + decode_metadata.block_table, + logits_width, topk_indices, - topk_workspace, - topk_tokens, - attn_metadata_narrowed.max_seq_len, ) - else: - if current_platform.is_xpu(): - xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined] + + if not used_direct_topk: + logits = fp8_fp4_paged_mqa_logits( + (padded_q_quant_cast, padded_q_scale), + kv_cache, + weights[:num_padded_tokens], + seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=logits_width, + clean_logits=False, + ) + num_rows = logits.shape[0] + + if _use_sm120_short_row_topk_decode(logits, topk_tokens): + torch.ops._C.top_k_per_row_decode( logits, next_n, seq_lens, @@ -345,17 +408,42 @@ def sparse_attn_indexer( logits.stride(1), topk_tokens, ) - else: - torch.ops._C.top_k_per_row_decode( + elif current_platform.is_cuda() and topk_tokens in (512, 2048): + workspace_manager = current_workspace_manager() + (topk_workspace,) = workspace_manager.get_simultaneous( + ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), + ) + torch.ops._C.persistent_topk( logits, - next_n, seq_lens, topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), + topk_workspace, topk_tokens, + logits_width, ) + else: + if current_platform.is_xpu(): + xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined] + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + else: + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) if decode_metadata.requires_padding: # if padded, we need to unpack @@ -438,7 +526,7 @@ def __init__( self.topk_indices_buffer = topk_indices_buffer self.skip_k_cache_insert = skip_k_cache_insert self.use_fp4_cache = use_fp4_cache - if current_platform.is_cuda() and not has_deep_gemm(): + if _sparse_indexer_requires_deep_gemm() and not has_deep_gemm(): raise RuntimeError( "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed." ) diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index cef4038dc2e6..1c3adb3ac4b0 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -8,10 +8,12 @@ import torch import torch.nn as nn +from vllm import envs from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import ( get_ep_group, + get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) @@ -55,10 +57,14 @@ from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op +from .interfaces import SupportsPP from .utils import ( AutoWeightsLoader, + PPMissingLayer, WeightsMapper, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) @@ -703,6 +709,25 @@ def _deepseek_v4_mega_moe_experts_op_fake( ) +def _use_deepseek_v4_mega_moe(vllm_config: VllmConfig) -> bool: + use_mega_moe = ( + vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" + ) + + env_name = "VLLM_DEEPSEEK_V4_USE_MEGA_MOE" + if envs.is_set(env_name): + use_mega_moe = envs.VLLM_DEEPSEEK_V4_USE_MEGA_MOE + + if use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: + raise NotImplementedError( + "DeepSeek V4 MegaMoE currently requires expert parallel. " + "Enable it with --enable-expert-parallel, or pick a different " + "moe backend." + ) + + return use_mega_moe + + class DeepseekV4MoE(nn.Module): def __init__( self, @@ -715,15 +740,7 @@ def __init__( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.prefix = prefix - self.use_mega_moe = ( - vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" - ) - if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: - raise NotImplementedError( - "DeepSeek V4 MegaMoE currently requires expert parallel. " - "Enable it with --enable-expert-parallel, or pick a different " - "moe backend." - ) + self.use_mega_moe = _use_deepseek_v4_mega_moe(vllm_config) self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) self.hidden_size = config.hidden_size @@ -1226,15 +1243,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.use_mega_moe = ( - vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" - ) - if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: - raise NotImplementedError( - "DeepSeek V4 MegaMoE currently requires expert parallel. " - "Enable it with --enable-expert-parallel, or pick a different " - "moe backend." - ) + self.use_mega_moe = _use_deepseek_v4_mega_moe(vllm_config) self.vocab_size = config.vocab_size self.hc_eps = config.hc_eps self.hc_mult = config.hc_mult @@ -1261,12 +1270,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): device=self.device, ) - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -1279,7 +1291,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.layers", ) - self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.hc_dim + ) self.hc_head_fn = nn.Parameter( torch.empty( @@ -1304,26 +1323,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Pre-hc_head residual stream buffer for the MTP draft. Stable # address (outside the cudagraph pool) so the copy_ in forward() # refreshes it correctly across captured shapes. - self._mtp_hidden_buffer = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, - self.hc_dim, - dtype=vllm_config.model_config.dtype, - device=self.device, - ) + if get_pp_group().is_last_rank: + self._mtp_hidden_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + self.hc_dim, + dtype=vllm_config.model_config.dtype, + device=self.device, + ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( self, - input_ids: torch.Tensor, + input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: - hidden_states = self.embed_input_ids(input_ids) - hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) - if self.use_mega_moe: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"].view( + -1, self.hc_mult, self.config.hidden_size + ) + + if self.use_mega_moe and input_ids is not None: input_ids = input_ids.to(torch.int64) for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer( @@ -1332,6 +1362,9 @@ def forward( input_ids, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states.flatten(1)}) + # Stash pre-hc_head residual for the MTP draft (captured copy_). num_tokens = hidden_states.shape[0] self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1)) @@ -1379,6 +1412,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + break param = params_dict[name] weight_loader = param.weight_loader @@ -1396,11 +1431,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: and loaded_weight.dtype == torch.float8_e8m0fnu ): loaded_weight = loaded_weight.view(torch.uint8) + skip_expert_weight = False for mapping in expert_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name_mapped = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name_mapped, self): + skip_expert_weight = True + break param = params_dict[name_mapped] # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other @@ -1419,15 +1458,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if success: name = name_mapped break + if skip_expert_weight: + continue loaded_params.add(name_mapped) continue elif "attn_sink" in name: + if is_pp_missing_parameter(name, self): + continue narrow_weight = loaded_weight[head_rank_start:head_rank_end] n = narrow_weight.shape[0] params_dict[name][:n].copy_(narrow_weight) loaded_params.add(name) continue else: + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader @@ -1525,7 +1570,7 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: ) -class DeepseekV4ForCausalLM(nn.Module): +class DeepseekV4ForCausalLM(nn.Module, SupportsPP): model_cls = DeepseekV4Model # Default mapper assumes the original FP4-expert checkpoint layout. @@ -1544,12 +1589,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = self.model_cls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) @@ -1563,7 +1614,7 @@ def compute_logits( def forward( self, - input_ids: torch.Tensor, + input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b89f5c33203..77ea8e9779ee 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -338,6 +338,254 @@ def transform_sf_into_required_layout(*args, **kwargs): ) +_SM120_MQA_LOGITS_MAX_SCORE_BYTES = 64 * 1024 * 1024 +_SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192 + + +def _fp8_mqa_logits_head_chunk_size( + seq_len: int, + seq_len_kv: int, + num_heads: int, +) -> int: + # The SM120 torch path is used on long prefill paths where materializing + # [head_chunk, M, N] scores can otherwise allocate multiple GiB. Keep the + # transient score tensor bounded, while still using larger head chunks for + # short prompts where they are faster. + score_elems_per_head = max(1, seq_len * seq_len_kv) + max_heads = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_head * 4) + return max(1, min(8, num_heads, max_heads)) + + +def _fp8_mqa_logits_k_chunk_size( + seq_len: int, + seq_len_kv: int, + head_chunk_size: int, +) -> int: + score_elems_per_key = max(1, seq_len * head_chunk_size) + max_keys = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_key * 4) + return max(1, min(seq_len_kv, max_keys)) + + +def _fp8_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA logits torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + logits = torch.zeros( + (seq_len, seq_len_kv), device=q_values.device, dtype=torch.float32 + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + k_chunk_size = _fp8_mqa_logits_k_chunk_size( + seq_len, seq_len_kv, head_end - head_start + ) + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + logits[:, k_start:k_end].add_( + scores[0] if scores.shape[0] == 1 else scores.sum(dim=0) + ) + + if clean_logits: + offsets = torch.arange(seq_len_kv, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + logits = logits.masked_fill(~valid, float("-inf")) + + return logits + + +def _fp8_mqa_logits_topk_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_tokens: int, + out: torch.Tensor | None = None, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA top-k torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + if out is None: + out = torch.empty( + (seq_len, topk_tokens), device=q_values.device, dtype=torch.int32 + ) + else: + assert out.shape == (seq_len, topk_tokens) + assert out.dtype == torch.int32 + out.fill_(-1) + + best_values = torch.full( + (seq_len, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + k_chunk_size = _fp8_mqa_logits_k_chunk_size(seq_len, seq_len_kv, head_chunk_size) + max_chunk_topk = min(topk_tokens, k_chunk_size) + chunk_values_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (seq_len, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + chunk_logits = torch.zeros( + (seq_len, k_end - k_start), + device=q_values.device, + dtype=torch.float32, + ) + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + chunk_logits.add_(scores[0] if scores.shape[0] == 1 else scores.sum(dim=0)) + + offsets = torch.arange(k_start, k_end, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + chunk_logits.masked_fill_(~valid, float("-inf")) + + chunk_topk = min(topk_tokens, k_end - k_start) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(k_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(out) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=out) + best_values, next_best_values = next_best_values, best_values + out.masked_fill_(~torch.isfinite(best_values), -1) + + return out + + +def fp8_fp4_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 MQA top-k indices without materializing full logits.""" + _lazy_init() + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + _fp8_mqa_logits_topk_torch( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices.shape[1], + out=topk_indices, + ) + return True + + +def _fp8_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q + if clean_logits and q_scale is None and q_values.dim() == 3 and kv[0].dim() == 2: + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_mqa_logits_triton, + ) + + return fp8_mqa_logits_triton(q_values, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + return _fp8_mqa_logits_torch( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) + + def fp8_fp4_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv: tuple[torch.Tensor, torch.Tensor], @@ -371,6 +619,10 @@ def fp8_fp4_mqa_logits( Logits tensor of shape [M, N], dtype `torch.float32`. """ _lazy_init() + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_mqa_logits_sm12x( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) if _fp8_fp4_mqa_logits_impl is None: return _missing() return _fp8_fp4_mqa_logits_impl( @@ -404,6 +656,216 @@ def get_paged_mqa_logits_metadata( return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) +def _fp8_paged_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 paged MQA torch path only supports FP8 Q") + + batch_size, next_n, num_heads, head_dim = q_values.shape + head_dim_with_scale = kv_cache.shape[-1] + assert head_dim_with_scale > head_dim + assert weights.shape == (batch_size * next_n, num_heads) + assert context_lens.shape == (batch_size, next_n) + + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + _view_packed_fp8_paged_mqa_kv_cache, + ) + + kv_values, kv_scales = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_kv, _, _ = kv_values.shape + logits = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + + q_f32 = q_values.float() + score_bytes = _SM120_MQA_LOGITS_MAX_SCORE_BYTES + max_tokens_per_chunk = max(1, score_bytes // max(1, num_heads * 4)) + token_offsets_cache: dict[int, torch.Tensor] = {} + + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + context_len = int(context_lens[batch_idx, next_idx].item()) + if context_len <= 0: + continue + + q_row = q_f32[batch_idx, next_idx] + row_weights = weights[row] + for token_start in range(0, context_len, max_tokens_per_chunk): + token_end = min(context_len, token_start + max_tokens_per_chunk) + chunk_len = token_end - token_start + token_offsets = token_offsets_cache.get(chunk_len) + if token_offsets is None or token_offsets.device != q_values.device: + token_offsets = torch.arange( + chunk_len, device=q_values.device, dtype=torch.long + ) + token_offsets_cache[chunk_len] = token_offsets + token_ids = token_start + token_offsets + logical_blocks = token_ids // block_kv + token_in_block = token_ids - logical_blocks * block_kv + physical_blocks = block_tables[batch_idx, logical_blocks] + kv_chunk = kv_values[physical_blocks, token_in_block, 0].float() + scale_chunk = kv_scales[physical_blocks, token_in_block, 0].squeeze(-1) + kv_chunk.mul_(scale_chunk[:, None]) + scores = torch.matmul(q_row, kv_chunk.T) + scores.relu_() + scores.mul_(row_weights[:, None]) + logits[row, token_start:token_end] = scores.sum(dim=0) + + return logits + + +def _fp8_paged_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if ( + q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_paged_mqa_logits_triton, + ) + + return fp8_paged_mqa_logits_triton( + q_values, kv_cache, weights, context_lens, block_tables, max_model_len + ) + return _fp8_paged_mqa_logits_torch( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) + + +def fp8_fp4_paged_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 paged MQA top-k indices without full logits.""" + _lazy_init() + q_values, q_scale = q + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + return False + + num_rows = q_values.shape[0] * q_values.shape[1] + topk_tokens = topk_indices.shape[1] + assert topk_indices.shape == (num_rows, topk_tokens) + assert topk_indices.dtype == torch.int32 + topk_indices.fill_(-1) + if num_rows == 0 or topk_tokens == 0 or max_model_len == 0: + return True + + best_values = torch.full( + (num_rows, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + chunk_size = max(1, _SM120_PAGED_MQA_TOPK_CHUNK_SIZE) + max_chunk_topk = min(topk_tokens, chunk_size) + chunk_values_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (num_rows, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_paged_mqa_logits_triton, + ) + + for token_start in range(0, max_model_len, chunk_size): + token_count = min(chunk_size, max_model_len - token_start) + chunk_logits = fp8_paged_mqa_logits_triton( + q_values, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + token_start=token_start, + token_count=token_count, + ) + chunk_topk = min(topk_tokens, token_count) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(token_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(topk_indices) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=topk_indices) + best_values, next_best_values = next_best_values, best_values + topk_indices.masked_fill_(~torch.isfinite(best_values), -1) + + return True + + def fp8_fp4_paged_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv_cache: torch.Tensor, @@ -425,9 +887,10 @@ def fp8_fp4_paged_mqa_logits( [B, next_n, H, D] float8_e4m3fn and q_scale is None. FP4 path: q_values is packed uint8 and q_scale is the companion block-scale tensor. - kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1, - D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos) - storing the float dequant scale. + kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, D+4] + or [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. Within + each block, the D-byte FP8 values for every token are stored first, + followed by per-token fp32 scale bytes. weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. context_lens: Tensor of shape [B], dtype int32; effective context length for each batch element. @@ -443,6 +906,10 @@ def fp8_fp4_paged_mqa_logits( `torch.float32`. """ _lazy_init() + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_paged_mqa_logits_sm12x( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) if _fp8_fp4_paged_mqa_logits_impl is None: return _missing() return _fp8_fp4_paged_mqa_logits_impl( @@ -457,6 +924,52 @@ def fp8_fp4_paged_mqa_logits( ) +def _tf32_hc_prenorm_gemm_torch( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + """Portable SM12x HyperConnection prenorm GEMM fallback. + + DeepGEMM's split ABI only requires that downstream consumers recover the + full result by summing over the split dimension. Keep the implementation + simple by writing the full product to split zero and clearing the rest. + """ + del num_split + product = x.float() @ fn.float().T + norm = x.float().square().sum(dim=-1) + + if out.dim() == 3: + out.zero_() + sqrsum.zero_() + out[0].copy_(product) + sqrsum[0].copy_(norm) + else: + out.copy_(product) + sqrsum.copy_(norm) + return out + + +def _tf32_hc_prenorm_gemm_sm12x( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + if out.dim() == 3 and sqrsum.dim() == 2: + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + tf32_hc_prenorm_gemm_triton, + ) + + tf32_hc_prenorm_gemm_triton(x, fn, out, sqrsum, num_split) + return out + + return _tf32_hc_prenorm_gemm_torch(x, fn, out, sqrsum, num_split) + + def tf32_hc_prenorm_gemm( x: torch.Tensor, fn: torch.Tensor, @@ -472,6 +985,8 @@ def tf32_hc_prenorm_gemm( See the caller function for shape requirement """ _lazy_init() + if current_platform.is_device_capability_family(120): + return _tf32_hc_prenorm_gemm_sm12x(x, fn, out, sqrsum, num_split) if _tf32_hc_prenorm_gemm_impl is None: return _missing() return _tf32_hc_prenorm_gemm_impl( @@ -570,7 +1085,9 @@ def should_use_deepgemm_for_fp8_linear( "m_grouped_fp8_fp4_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "fp8_fp4_mqa_logits", + "fp8_fp4_mqa_topk_indices", "fp8_fp4_paged_mqa_logits", + "fp8_fp4_paged_mqa_topk_indices", "get_paged_mqa_logits_metadata", "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 474a5b2d421e..b2399c683eb5 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -30,6 +30,10 @@ SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_cudagraphs_allowed, +) from vllm.v1.attention.backends.mla.sparse_utils import ( triton_convert_req_index_to_global_index, ) @@ -266,6 +270,20 @@ def get_prefill_workspace_size(max_model_len: int): class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_triton_sparse_mla_enabled_for_platform() + and not triton_sparse_mla_cudagraphs_allowed(vllm_config) + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + def __init__( self, kv_cache_spec: AttentionSpec, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 7c0715a9e8b6..eb0ea8f528b9 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import dataclass import torch -import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -32,6 +32,20 @@ logger = init_logger(__name__) +def sparse_indexer_max_logits_bytes(is_sm12x: bool | None = None) -> int: + configured_mb = os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB") + if configured_mb is not None: + return int(configured_mb) * 1024 * 1024 + + if is_sm12x is None: + is_sm12x = ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + ) + default_mb = 256 if is_sm12x else 512 + return default_mb * 1024 * 1024 + + @triton.jit def _prepare_uniform_decode_kernel( seq_lens_ptr, @@ -269,13 +283,12 @@ def __init__(self, *args, **kwargs): self.reorder_batch_threshold += self.num_speculative_tokens # NOTE(zyongye) fp4 indexer cache only natively supports next_n in # natively_supported_next_n_fp4; for other next_n values we fall back - # to the flattening path. Outside the SM100 datacenter family the FP8 - # paged MQA logits kernel has the same [1, 2] constraint (deepgemm - # smxx_fp8_fp4_paged_mqa_logits.hpp:233), so flatten there too. + # to the flattening path. When fp4 indexer cache is disabled, the + # native (non-flattening) path handles all next_n values. self.use_flattening = ( self.use_fp4_indexer_cache - or not current_platform.is_device_capability_family(100) - ) and next_n not in self.natively_supported_next_n_fp4 + and next_n not in self.natively_supported_next_n_fp4 + ) sm_count = num_compute_units(self.device.index) self.num_sms = sm_count @@ -520,7 +533,7 @@ def build( prefill_query_lens_cpu = torch.diff( query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1] ) - max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_bytes = sparse_indexer_max_logits_bytes() # Upper bound is exact for prefill rows (the `[num_decodes:]` # slice below). assert common_attn_metadata.seq_lens_cpu_upper_bound is not None diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py new file mode 100644 index 000000000000..52af38e4cea3 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Environment controls for the portable Triton sparse MLA path.""" + +import os + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +_TRITON_MLA_SPARSE_ENV = "VLLM_TRITON_MLA_SPARSE" +_TRITON_MLA_SPARSE_TOPK_CHUNK_ENV = "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE" +_TRITON_MLA_SPARSE_QUERY_CHUNK_ENV = "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE" +_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV = ( + "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH" +) +_TRITON_MLA_SPARSE_HEAD_BLOCK_ENV = "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE" +_TRITON_MLA_SPARSE_MATMUL_DECODE_ENV = "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE" + +_ENV_TRUE_VALUES = {"1", "true", "yes", "on"} +_ENV_FALSE_VALUES = {"0", "false", "no", "off"} + +logger = init_logger(__name__) + + +def _optional_env_flag(name: str) -> bool | None: + raw_value = os.getenv(name) + if raw_value is None: + return None + value = raw_value.lower() + if value in _ENV_TRUE_VALUES: + return True + if value in _ENV_FALSE_VALUES: + return False + return None + + +def _is_sm12x_device(device: torch.device) -> bool: + if not torch.cuda.is_available(): + return False + index = device.index if device.index is not None else torch.cuda.current_device() + return torch.cuda.get_device_capability(index)[0] == 12 + + +def triton_sparse_mla_configured() -> bool | None: + return _optional_env_flag(_TRITON_MLA_SPARSE_ENV) + + +def is_triton_sparse_mla_enabled_for_platform() -> bool: + configured = triton_sparse_mla_configured() + if configured is not None: + return configured + return current_platform.is_device_capability_family(120) + + +def is_triton_sparse_mla_enabled(device: torch.device) -> bool: + configured = triton_sparse_mla_configured() + if configured is not None: + return configured + return _is_sm12x_device(device) + + +def _uses_speculative_decoding(vllm_config) -> bool: + return bool(getattr(vllm_config, "speculative_config", None)) + + +def triton_sparse_mla_cudagraphs_allowed(vllm_config=None) -> bool: + configured = _optional_env_flag(_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV) + if configured is not None: + return configured + return not ( + vllm_config is not None and _uses_speculative_decoding(vllm_config) + ) + + +def disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) -> None: + if not is_triton_sparse_mla_enabled_for_platform(): + return + if triton_sparse_mla_cudagraphs_allowed(vllm_config): + logger.warning_once( + "Keeping vLLM compile and CUDA graphs enabled for the DeepSeek V4 " + "Triton sparse MLA path because " + f"{_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV}=1 or speculative " + "decoding is not configured. This is an " + "experimental performance mode." + ) + return + + from vllm.config.compilation import CompilationMode, CUDAGraphMode + + compilation_config = vllm_config.compilation_config + if ( + compilation_config.mode == CompilationMode.NONE + and compilation_config.cudagraph_mode == CUDAGraphMode.NONE + ): + return + + logger.warning_once( + "Disabling vLLM compile and CUDA graphs for the DeepSeek V4 Triton " + "sparse MLA path because the current Triton sparse MLA path is not " + "compile/graph-safe yet, or because speculative decoding uses " + "multi-token sparse MLA decode." + ) + compilation_config.mode = CompilationMode.NONE + compilation_config.compile_sizes = [] + compilation_config.compile_ranges_endpoints = [] + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + compilation_config.cudagraph_capture_sizes = [] + compilation_config.max_cudagraph_capture_size = 0 + + +def triton_sparse_mla_topk_chunk_size() -> int: + raw_value = os.getenv(_TRITON_MLA_SPARSE_TOPK_CHUNK_ENV) + if raw_value is None: + return 512 + try: + return max(1, int(raw_value)) + except ValueError: + return 512 + + +def triton_sparse_mla_query_chunk_size() -> int: + raw_value = os.getenv(_TRITON_MLA_SPARSE_QUERY_CHUNK_ENV) + if raw_value is None: + return 256 + try: + return max(1, int(raw_value)) + except ValueError: + return 256 + + +def triton_sparse_mla_head_block_size() -> int | None: + raw_value = os.getenv(_TRITON_MLA_SPARSE_HEAD_BLOCK_ENV) + if raw_value is None: + return None + try: + value = int(raw_value) + except ValueError: + return None + if value in (1, 2, 4): + return value + return None + + +def triton_sparse_mla_matmul_decode_enabled() -> bool: + configured = _optional_env_flag(_TRITON_MLA_SPARSE_MATMUL_DECODE_ENV) + if configured is not None: + return configured + return current_platform.is_device_capability_family(120) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py new file mode 100644 index 000000000000..834ecda43032 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -0,0 +1,2694 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Portable sparse MLA Triton kernels.""" + +import torch + +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + triton_sparse_mla_head_block_size, +) + + +def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: + """Choose the SM12x sparse MLA head grouping for decode kernels. + + Single-token decode is latency sensitive and does best with one head per + program. Once there are enough query tokens, grouping heads lets the kernel + reuse each dequantized KV row across multiple heads. + """ + + configured_head_block_size = triton_sparse_mla_head_block_size() + if configured_head_block_size is not None: + return configured_head_block_size + if num_decode_tokens <= 4: + return 1 + if num_decode_tokens < 16: + return 2 + return 4 + + +@triton.jit +def _merge_two_subsets_with_sink_kernel( + out0_ptr, + lse0_ptr, + out1_ptr, + lse1_ptr, + sink_ptr, + output_ptr, + stride_out0_t: tl.constexpr, + stride_out0_h: tl.constexpr, + stride_out0_d: tl.constexpr, + stride_lse0_t: tl.constexpr, + stride_lse0_h: tl.constexpr, + stride_out1_t: tl.constexpr, + stride_out1_h: tl.constexpr, + stride_out1_d: tl.constexpr, + stride_lse1_t: tl.constexpr, + stride_lse1_h: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + lse0 = tl.load(lse0_ptr + token_idx * stride_lse0_t + head_idx * stride_lse0_h) + lse1 = tl.load(lse1_ptr + token_idx * stride_lse1_t + head_idx * stride_lse1_h) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(tl.maximum(lse0, lse1), sink) + + weight0 = tl.exp(lse0 - merge_max) + weight1 = tl.exp(lse1 - merge_max) + weight_sink = tl.exp(sink - merge_max) + denom = weight0 + weight1 + weight_sink + + out0 = tl.load( + out0_ptr + + token_idx * stride_out0_t + + head_idx * stride_out0_h + + offsets * stride_out0_d, + mask=mask, + other=0.0, + ).to(tl.float32) + out1 = tl.load( + out1_ptr + + token_idx * stride_out1_t + + head_idx * stride_out1_h + + offsets * stride_out1_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = (out0 * weight0 + out1 * weight1) / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_two_sparse_mla_subsets_with_sink( + subset0_output: torch.Tensor, + subset0_lse: torch.Tensor, + subset1_output: torch.Tensor, + subset1_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset0_output.shape == subset1_output.shape + assert subset0_output.shape == output.shape + assert subset0_lse.shape == subset1_lse.shape + assert subset0_lse.shape == subset0_output.shape[:2] + assert attn_sink.shape[0] == subset0_output.shape[1] + assert subset0_output.is_cuda + assert subset1_output.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset0_output.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_two_subsets_with_sink_kernel[grid]( + subset0_output, + subset0_lse, + subset1_output, + subset1_lse, + attn_sink, + output, + subset0_output.stride(0), + subset0_output.stride(1), + subset0_output.stride(2), + subset0_lse.stride(0), + subset0_lse.stride(1), + subset1_output.stride(0), + subset1_output.stride(1), + subset1_output.stride(2), + subset1_lse.stride(0), + subset1_lse.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _merge_single_subset_with_sink_kernel( + subset_output_ptr, + subset_lse_ptr, + sink_ptr, + output_ptr, + stride_subset_t: tl.constexpr, + stride_subset_h: tl.constexpr, + stride_subset_d: tl.constexpr, + stride_lse_t: tl.constexpr, + stride_lse_h: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + subset_lse = tl.load( + subset_lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h + ) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(subset_lse, sink) + + subset_weight = tl.exp(subset_lse - merge_max) + sink_weight = tl.exp(sink - merge_max) + denom = subset_weight + sink_weight + subset_output = tl.load( + subset_output_ptr + + token_idx * stride_subset_t + + head_idx * stride_subset_h + + offsets * stride_subset_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = subset_output * subset_weight / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_sparse_mla_subset_with_sink( + subset_output: torch.Tensor, + subset_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset_output.shape == output.shape + assert subset_lse.shape == subset_output.shape[:2] + assert attn_sink.shape[0] == subset_output.shape[1] + assert subset_output.is_cuda + assert subset_lse.is_cuda + assert attn_sink.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset_output.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_single_subset_with_sink_kernel[grid]( + subset_output, + subset_lse, + attn_sink, + output, + subset_output.stride(0), + subset_output.stride(1), + subset_output.stride(2), + subset_lse.stride(0), + subset_lse.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _build_combined_decode_valid_mask_kernel( + output_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_lens_ptr, + stride_output_t: tl.constexpr, + stride_output_c: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + num_compressed_candidates: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_C: tl.constexpr, +): + token_idx = tl.program_id(0) + offsets = tl.arange(0, BLOCK_C) + candidate_mask = offsets < num_candidates + + topk_lens = tl.load(topk_lens_ptr + token_idx) + swa_lens = tl.load(swa_lens_ptr + token_idx) + is_compressed = offsets < num_compressed_candidates + swa_offsets = offsets - num_compressed_candidates + slot_ids = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + offsets * stride_slot_c, + mask=is_compressed, + other=-1, + ) + valid_compressed = is_compressed & (offsets < topk_lens) & (slot_ids >= 0) + valid_swa = (~is_compressed) & (swa_offsets < swa_lens) + valid = valid_compressed | valid_swa + tl.store( + output_ptr + token_idx * stride_output_t + offsets * stride_output_c, + valid, + mask=candidate_mask, + ) + + +def build_combined_sparse_mla_decode_valid_mask( + output: torch.Tensor, + compressed_slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + swa_lens: torch.Tensor, +) -> None: + """Build `[compressed, SWA]` validity mask for SM12x decode.""" + if compressed_slot_ids.dim() == 3: + assert compressed_slot_ids.shape[1] == 1 + compressed_slot_ids = compressed_slot_ids[:, 0, :] + + assert output.dim() == 2 + assert output.dtype == torch.bool + assert compressed_slot_ids.dim() == 2 + assert output.shape[0] == compressed_slot_ids.shape[0] + assert output.shape[0] == topk_lens.shape[0] + assert output.shape[0] == swa_lens.shape[0] + assert output.shape[1] >= compressed_slot_ids.shape[1] + assert output.is_cuda + assert compressed_slot_ids.is_cuda + assert topk_lens.is_cuda + assert swa_lens.is_cuda + + num_candidates = output.shape[1] + block_c = triton.next_power_of_2(num_candidates) + _build_combined_decode_valid_mask_kernel[(output.shape[0],)]( + output, + compressed_slot_ids, + topk_lens, + swa_lens, + output.stride(0), + output.stride(1), + compressed_slot_ids.stride(0), + compressed_slot_ids.stride(1), + compressed_slot_ids.shape[1], + num_candidates, + BLOCK_C=block_c, + num_warps=4, + ) + + +def matmul_sparse_mla_attention_with_sink( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + num_heads: int | None = None, + score_buffer: torch.Tensor | None = None, + head_block_size: int = 1, + value_block_size: int | None = None, + candidate_block_size: int | None = None, +) -> None: + """Compute sink-aware sparse MLA over materialized BF16 KV. + + This path intentionally dequantizes/gathers KV once, computes scores with + batched matrix multiplication, and finishes the sink-aware value reduction + in Triton. It is useful for the SM12x decode path where the direct Triton + kernel otherwise repeats fp8_ds_mla dequantization once per head group. + """ + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert valid_tokens.shape == kv.shape[:2] + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert q.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + q_active = q[:, :active_heads] + num_tokens = q.shape[0] + num_candidates = kv.shape[1] + if score_buffer is None: + score_buffer = torch.empty( + (num_tokens, active_heads, num_candidates), + dtype=torch.float32, + device=q.device, + ) + assert score_buffer.shape == (num_tokens, active_heads, num_candidates) + assert score_buffer.device == q.device + assert score_buffer.dtype in (torch.float32, torch.bfloat16) + if score_buffer.dtype == torch.float32: + q_score = q_active.float() + kv_score = kv.float() + else: + q_score = q_active.to(score_buffer.dtype) + kv_score = kv.to(score_buffer.dtype) + torch.bmm(q_score, kv_score.transpose(1, 2), out=score_buffer) + score_buffer.mul_(scale) + finish_materialized_sparse_mla_scores_with_sink( + score_buffer, + kv, + valid_tokens, + attn_sink, + output, + num_heads=active_heads, + head_block_size=head_block_size, + value_block_size=value_block_size, + candidate_block_size=candidate_block_size, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + running_max = tl.load(attn_sink_ptr + head_offsets, mask=head_mask, other=0.0).to( + tl.float32 + ) + running_denom = tl.full((HEAD_BLOCK,), 1.0, tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + for candidate_idx in range(0, num_candidates): + is_valid = tl.load( + valid_tokens_ptr + + token_idx * stride_valid_t + + candidate_idx * stride_valid_c + ) + if is_valid: + score = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_offsets * stride_scores_h + + candidate_idx * stride_scores_c, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + result = running_acc / running_denom[:, None] + tl.store( + output_ptr + + token_idx * stride_out_t + + head_offsets[:, None] * stride_out_h + + dim_offsets[None, :] * stride_out_d, + result, + mask=matrix_mask, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_candidate_block_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + dim_block_idx = tl.program_id(2) + candidate_offsets = tl.arange(0, BLOCK_K) + dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = dim_offsets < head_dim + + max_score = tl.load(attn_sink_ptr + head_idx).to(tl.float32) + for candidate_start in range(0, num_candidates, BLOCK_K): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < num_candidates + is_valid = tl.load( + valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c, + mask=candidate_mask, + other=0, + ).to(tl.int1) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidates * stride_scores_c, + mask=candidate_mask & is_valid, + other=-float("inf"), + ).to(tl.float32) + max_score = tl.maximum(max_score, tl.max(scores, axis=0)) + + denom = tl.exp(tl.load(attn_sink_ptr + head_idx).to(tl.float32) - max_score) + acc = tl.zeros((BLOCK_D,), tl.float32) + for candidate_start in range(0, num_candidates, BLOCK_K): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < num_candidates + is_valid = tl.load( + valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c, + mask=candidate_mask, + other=0, + ).to(tl.int1) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidates * stride_scores_c, + mask=candidate_mask & is_valid, + other=-float("inf"), + ).to(tl.float32) + weights = tl.exp(scores - max_score) + denom += tl.sum(weights, axis=0) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidates[:, None] * stride_kv_c + + dim_offsets[None, :] * stride_kv_d, + mask=(candidate_mask & is_valid)[:, None] & dim_mask[None, :], + other=0.0, + ) + acc += tl.sum(kv.to(tl.float32) * weights[:, None], axis=0) + + tl.store( + output_ptr + + token_idx * stride_out_t + + head_idx * stride_out_h + + dim_offsets * stride_out_d, + acc / denom, + mask=dim_mask, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_value_block_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + dim_block_idx = tl.program_id(2) + dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = dim_offsets < head_dim + + running_max = tl.load(attn_sink_ptr + head_idx).to(tl.float32) + running_denom = tl.full((), 1.0, tl.float32) + running_acc = tl.zeros((BLOCK_D,), tl.float32) + + for candidate_idx in range(0, num_candidates): + is_valid = tl.load( + valid_tokens_ptr + + token_idx * stride_valid_t + + candidate_idx * stride_valid_c + ) + if is_valid: + score = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidate_idx * stride_scores_c + ).to(tl.float32) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + result = running_acc / running_denom + tl.store( + output_ptr + + token_idx * stride_out_t + + head_idx * stride_out_h + + dim_offsets * stride_out_d, + result, + mask=dim_mask, + ) + + +def finish_materialized_sparse_mla_scores_with_sink( + scores: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, + num_heads: int | None = None, + head_block_size: int = 1, + value_block_size: int | None = None, + candidate_block_size: int | None = None, +) -> None: + assert scores.dim() == 3 + assert kv.dim() == 3 + assert valid_tokens.shape == kv.shape[:2] + assert scores.shape[0] == kv.shape[0] + assert scores.shape[2] == kv.shape[1] + assert output.shape[0] == kv.shape[0] + assert output.shape[2] == kv.shape[2] + assert scores.dtype in (torch.float32, torch.bfloat16) + assert head_block_size in (1, 2, 4) + if value_block_size is not None: + assert value_block_size in (64, 128, 256, 512) + if candidate_block_size is not None: + assert candidate_block_size in (16, 32, 64, 128) + assert scores.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= scores.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + num_tokens, _, num_candidates = scores.shape + head_dim = kv.shape[2] + if candidate_block_size is not None: + block_d = value_block_size if value_block_size is not None else 128 + candidate_grid = (num_tokens, active_heads, triton.cdiv(head_dim, block_d)) + _finish_materialized_scores_with_sink_candidate_block_kernel[candidate_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + head_dim, + num_candidates, + BLOCK_K=candidate_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + return + + if value_block_size is not None and value_block_size < head_dim: + value_grid = ( + num_tokens, + active_heads, + triton.cdiv(head_dim, value_block_size), + ) + _finish_materialized_scores_with_sink_value_block_kernel[value_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + head_dim, + num_candidates, + BLOCK_D=value_block_size, + num_warps=4, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + return + + block_d = min(1024, triton.next_power_of_2(head_dim)) + head_grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _finish_materialized_scores_with_sink_kernel[head_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + active_heads, + head_dim, + num_candidates, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + + +@triton.jit +def _accumulate_gathered_attention_chunk_kernel( + q_ptr, + kv_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HAS_SLOT_IDS: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + is_valid = (candidate_offset + candidate_idx) < valid_len + if HAS_SLOT_IDS: + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = is_valid & (slot_id >= 0) + + if is_valid: + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_gathered_sparse_mla_attention_chunk( + q: torch.Tensor, + kv: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + slot_ids: torch.Tensor | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + if slot_ids is not None: + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + assert slot_ids.dim() == 2 + assert slot_ids.shape == kv.shape[:2] + assert slot_ids.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = kv.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_gathered_attention_chunk_kernel[grid]( + q, + kv, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + slot_ids.stride(0) if slot_ids is not None else 0, + slot_ids.stride(1) if slot_ids is not None else 0, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HAS_SLOT_IDS=slot_ids is not None, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_indexed_attention_chunk_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (kv_index >= 0) + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_indexed_sparse_mla_attention_chunk( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = indices.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_indexed_attention_chunk_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_fp8ds_global_slots_attention_chunk_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int, + num_candidates: int, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_fp8ds_paged_attention_chunk_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int, + num_candidates: int, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _fp8ds_paged_attention_with_sink_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + sink_ptr, + output_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + candidate_offset: tl.constexpr, + num_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) + + +def fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + candidate_offset: int, + num_candidates: int, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert head_block_size in (1, 2, 4) + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_paged_attention_with_sink_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + attn_sink, + output, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + output.stride(0), + output.stride(1), + output.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + active_heads, + head_dim, + candidate_offset, + num_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _fp8ds_global_paged_attention_with_sink_multihead_kernel( + q_ptr, + compressed_k_cache_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + sink_ptr, + output_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_block_table_t, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + compressed_cache_block_size: tl.constexpr, + compressed_block_stride: tl.constexpr, + swa_cache_block_size: tl.constexpr, + swa_block_stride: tl.constexpr, + token_data_size: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_compressed_candidates: tl.constexpr, + num_swa_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + topk_len = tl.load(topk_lens_ptr + token_idx) + + for candidate_idx in range(0, num_compressed_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = (candidate_idx < topk_len) & (slot_id >= 0) + if is_valid: + block_idx = slot_id // compressed_cache_block_size + pos_in_block = slot_id % compressed_cache_block_size + cache_block_ptr = ( + compressed_k_cache_ptr + + block_idx.to(tl.int64) * compressed_block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + compressed_cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + for candidate_idx in range(0, num_swa_candidates): + is_valid = candidate_idx < gather_len + if is_valid: + pos = start_pos + candidate_idx + block_in_seq = pos // swa_cache_block_size + pos_in_block = pos % swa_cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = ( + swa_k_cache_ptr + physical_block.to(tl.int64) * swa_block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + swa_cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) + + +def fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + compressed_block_size: int, + swa_k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + swa_block_size: int, + num_compressed_candidates: int, + num_swa_candidates: int, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert topk_lens.shape[0] == q.shape[0] + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert head_block_size in (1, 2, 4) + assert compressed_k_cache.dtype == torch.uint8 + assert swa_k_cache.dtype == torch.uint8 + assert q.is_cuda and compressed_k_cache.is_cuda and swa_k_cache.is_cuda + assert slot_ids.is_cuda and topk_lens.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_global_paged_attention_with_sink_multihead_kernel[grid]( + q, + compressed_k_cache, + slot_ids, + topk_lens, + swa_k_cache, + seq_lens, + gather_lens, + block_table, + attn_sink, + output, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + block_table.stride(0), + output.stride(0), + output.stride(1), + output.stride(2), + compressed_block_size, + compressed_k_cache.stride(0), + swa_block_size, + swa_k_cache.stride(0), + token_data_size, + token_fp8_dim, + token_scale_dim, + quant_block_size, + active_heads, + head_dim, + num_compressed_candidates, + num_swa_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _finish_attention_state_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + output_ptr, + lse_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + stride_lse_t: tl.constexpr, + stride_lse_h: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + is_valid = running_denom > 0.0 + inv_denom = tl.where(is_valid, 1.0 / running_denom, 0.0) + subset_lse = tl.where( + is_valid, + running_max + tl.log(running_denom), + -float("inf"), + ) + + acc = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + subset_output = acc * inv_denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + subset_output, + mask=dim_mask, + ) + if block_d == 0: + tl.store( + lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h, + subset_lse, + ) + + +def finish_gathered_sparse_mla_attention( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + output: torch.Tensor, + lse: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape == acc.shape + assert lse.shape == max_score.shape + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert output.dtype == torch.float32 + assert lse.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert output.is_cuda and lse.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_kernel[grid]( + max_score, + denom, + acc, + output, + lse, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + lse.stride(0), + lse.stride(1), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _finish_attention_state_with_sink_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + sink_ptr, + output_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + sink = tl.load(sink_ptr + head_idx) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + subset_weight = running_denom * subset_scale + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = subset_weight + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc_values = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc_values = tl.where(has_tokens, acc_values, 0.0) + output = acc_values * subset_scale * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +@triton.jit +def _finish_two_attention_states_with_sink_kernel( + max0_ptr, + denom0_ptr, + acc0_ptr, + max1_ptr, + denom1_ptr, + acc1_ptr, + sink_ptr, + output_ptr, + stride_state0_t: tl.constexpr, + stride_state0_h: tl.constexpr, + stride_acc0_t: tl.constexpr, + stride_acc0_h: tl.constexpr, + stride_acc0_d: tl.constexpr, + stride_state1_t: tl.constexpr, + stride_state1_h: tl.constexpr, + stride_acc1_t: tl.constexpr, + stride_acc1_h: tl.constexpr, + stride_acc1_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state0_offset = token_idx * stride_state0_t + head_idx * stride_state0_h + state1_offset = token_idx * stride_state1_t + head_idx * stride_state1_h + max0 = tl.load(max0_ptr + state0_offset) + denom0 = tl.load(denom0_ptr + state0_offset) + max1 = tl.load(max1_ptr + state1_offset) + denom1 = tl.load(denom1_ptr + state1_offset) + sink = tl.load(sink_ptr + head_idx) + + has0 = denom0 > 0.0 + has1 = denom1 > 0.0 + has_sink = sink > -float("inf") + valid_max0 = tl.where(has0, max0, -float("inf")) + valid_max1 = tl.where(has1, max1, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(tl.maximum(valid_max0, valid_max1), valid_sink) + has_any = has0 | has1 | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_max0 = tl.where(has0, max0, safe_merge_max) + safe_max1 = tl.where(has1, max1, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + scale0 = tl.where(has0, tl.exp(safe_max0 - safe_merge_max), 0.0) + scale1 = tl.where(has1, tl.exp(safe_max1 - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = denom0 * scale0 + denom1 * scale1 + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc0 = tl.load( + acc0_ptr + + token_idx * stride_acc0_t + + head_idx * stride_acc0_h + + offsets * stride_acc0_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc1 = tl.load( + acc1_ptr + + token_idx * stride_acc1_t + + head_idx * stride_acc1_h + + offsets * stride_acc1_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc0 = tl.where(has0, acc0, 0.0) + acc1 = tl.where(has1, acc1, 0.0) + output = (acc0 * scale0 + acc1 * scale1) * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +def finish_two_sparse_mla_attention_states_with_sink( + max_score0: torch.Tensor, + denom0: torch.Tensor, + acc0: torch.Tensor, + max_score1: torch.Tensor, + denom1: torch.Tensor, + acc1: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score0.shape == denom0.shape + assert max_score1.shape == denom1.shape + assert max_score0.shape == max_score1.shape + assert acc0.shape == acc1.shape + assert acc0.shape[:2] == max_score0.shape + assert output.shape[0] == acc0.shape[0] + assert output.shape[1] >= acc0.shape[1] + assert output.shape[2] == acc0.shape[2] + assert attn_sink.shape[0] >= acc0.shape[1] + assert max_score0.dtype == torch.float32 + assert denom0.dtype == torch.float32 + assert acc0.dtype == torch.float32 + assert max_score1.dtype == torch.float32 + assert denom1.dtype == torch.float32 + assert acc1.dtype == torch.float32 + assert max_score0.is_cuda and denom0.is_cuda and acc0.is_cuda + assert max_score1.is_cuda and denom1.is_cuda and acc1.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc0.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_two_attention_states_with_sink_kernel[grid]( + max_score0, + denom0, + acc0, + max_score1, + denom1, + acc1, + attn_sink, + output, + max_score0.stride(0), + max_score0.stride(1), + acc0.stride(0), + acc0.stride(1), + acc0.stride(2), + max_score1.stride(0), + max_score1.stride(1), + acc1.stride(0), + acc1.stride(1), + acc1.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +def finish_sparse_mla_attention_with_sink( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape[0] == acc.shape[0] + assert output.shape[1] >= acc.shape[1] + assert output.shape[2] == acc.shape[2] + assert attn_sink.shape[0] >= acc.shape[1] + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_with_sink_kernel[grid]( + max_score, + denom, + acc, + attn_sink, + output, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_reference.py b/vllm/v1/attention/backends/mla/sparse_mla_reference.py new file mode 100644 index 000000000000..203b64188202 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_reference.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Reference sparse MLA attention helpers. + +The helpers in this module intentionally use PyTorch tensor operations. They +are the correctness-first contract for portable sparse MLA fallbacks and tests; +optimized Triton/CUDA kernels should preserve these semantics. +""" + +import torch + + +def new_reference_attention_state( + q: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if q.dim() == 4: + q_bhd = q[:, 0, :, :].float() + else: + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + q_bhd = q.float() + + num_tokens = q_bhd.shape[0] + num_heads = q_bhd.shape[1] + head_dim = q_bhd.shape[2] + max_score = torch.full( + (num_tokens, num_heads), + float("-inf"), + dtype=torch.float32, + device=q.device, + ) + denom = torch.zeros_like(max_score) + acc = torch.zeros( + (num_tokens, num_heads, head_dim), + dtype=torch.float32, + device=q.device, + ) + return q_bhd, max_score, denom, acc + + +def accumulate_reference_attention_chunk( + q_bhd: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + kv_btd = kv.float() + kv_btd = torch.where( + valid_tokens[:, :, None], + kv_btd, + torch.zeros((), dtype=kv_btd.dtype, device=kv_btd.device), + ) + scores = torch.einsum("bhd,btd->bht", q_bhd, kv_btd) * scale + scores = scores.masked_fill(~valid_tokens[:, None, :], float("-inf")) + + chunk_max = scores.amax(dim=-1) + next_max = torch.maximum(max_score, chunk_max) + + previous_scale = torch.exp(max_score - next_max) + previous_scale = torch.nan_to_num(previous_scale) + weights = torch.exp(scores - next_max[:, :, None]) + weights = torch.where( + valid_tokens[:, None, :], + weights, + torch.zeros((), dtype=weights.dtype, device=weights.device), + ) + weights = torch.nan_to_num(weights) + + acc = acc * previous_scale[:, :, None] + denom = denom * previous_scale + acc = acc + torch.einsum("bht,btd->bhd", weights, kv_btd) + denom = denom + weights.sum(dim=-1) + return next_max, denom, acc + + +def finish_reference_attention_no_sink( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + valid = denom > 0 + safe_denom = torch.where(valid, denom, torch.ones_like(denom)) + subset_output = acc / safe_denom[:, :, None] + subset_output = torch.where( + valid[:, :, None], + subset_output, + torch.zeros((), dtype=subset_output.dtype, device=subset_output.device), + ) + subset_lse = torch.where( + valid, + max_score + torch.log(safe_denom), + torch.full_like(max_score, float("-inf")), + ) + return subset_output, subset_lse + + +def reference_attention_no_sink( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor]: + q_bhd, max_score, denom, acc = new_reference_attention_state(q) + max_score, denom, acc = accumulate_reference_attention_chunk( + q_bhd=q_bhd, + kv=kv, + valid_tokens=valid_tokens, + max_score=max_score, + denom=denom, + acc=acc, + scale=scale, + ) + return finish_reference_attention_no_sink(max_score, denom, acc) + + +def merge_reference_attention_with_sink( + subset_outputs: list[torch.Tensor], + subset_lses: list[torch.Tensor], + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset_outputs, "At least one attention subset is required" + assert len(subset_outputs) == len(subset_lses) + + sink = attn_sink[None, :].float() + merge_max = sink + for subset_lse in subset_lses: + merge_max = torch.maximum(merge_max, subset_lse) + + safe_merge_max = torch.where( + torch.isfinite(merge_max), merge_max, torch.zeros_like(merge_max) + ) + merged_acc = torch.zeros_like(subset_outputs[0], dtype=torch.float32) + sink_weight = torch.exp(sink - safe_merge_max) + sink_weight = torch.nan_to_num(sink_weight) + merged_denom = sink_weight + for subset_output, subset_lse in zip(subset_outputs, subset_lses): + subset_weight = torch.exp(subset_lse - safe_merge_max) + subset_weight = torch.nan_to_num(subset_weight) + merged_acc = merged_acc + subset_output.float() * subset_weight[:, :, None] + merged_denom = merged_denom + subset_weight + + safe_denom = torch.where( + merged_denom > 0, merged_denom, torch.ones_like(merged_denom) + ) + reference_output = merged_acc / safe_denom[:, :, None] + reference_output = torch.where( + (merged_denom > 0)[:, :, None], + reference_output, + torch.zeros((), dtype=reference_output.dtype, device=reference_output.device), + ) + output.copy_(reference_output.to(dtype=output.dtype)) + + +def sink_aware_reference_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + subset_output, subset_lse = reference_attention_no_sink( + q=q, + kv=kv, + valid_tokens=valid_tokens, + scale=scale, + ) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=attn_sink, + output=output, + ) + + +def reference_sparse_mla_prefill( + q: torch.Tensor, + kv: torch.Tensor, + combined_indices: torch.Tensor, + combined_lens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + topk_chunk_size: int, + query_chunk_size: int, +) -> None: + kv_flat = kv.reshape(-1, q.shape[-1]) + topk_chunk_size = min(combined_indices.shape[-1], topk_chunk_size) + query_chunk_size = min(q.shape[0], query_chunk_size) + + for token_start in range(0, q.shape[0], query_chunk_size): + token_end = min(token_start + query_chunk_size, q.shape[0]) + q_chunk = q[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + indices_chunk_full = combined_indices[token_start:token_end] + q_bhd, max_score, denom, acc = new_reference_attention_state(q_chunk) + + for index_start in range(0, combined_indices.shape[-1], topk_chunk_size): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + indices_chunk = indices_chunk_full[:, index_start:index_end] + index_offsets = torch.arange( + index_start, + index_end, + device=q.device, + ) + valid_tokens = ( + (index_offsets[None, :] < lens_chunk[:, None]) + & (indices_chunk >= 0) + ) + safe_indices = torch.where( + valid_tokens, + indices_chunk, + torch.zeros((), dtype=indices_chunk.dtype, device=q.device), + ).long() + gathered_kv = kv_flat[safe_indices] + max_score, denom, acc = accumulate_reference_attention_chunk( + q_bhd=q_bhd, + kv=gathered_kv, + valid_tokens=valid_tokens, + max_score=max_score, + denom=denom, + acc=acc, + scale=scale, + ) + + subset_output, subset_lse = finish_reference_attention_no_sink( + max_score, + denom, + acc, + ) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=attn_sink, + output=output[token_start:token_end], + ) diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 28564e6a97d3..7689cf9e155a 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -16,9 +16,15 @@ CommonAttentionMetadata, MultipleOf, ) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled, + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_cudagraphs_allowed, +) from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata from vllm.v1.kv_cache_interface import ( + AttentionSpec, KVCacheSpec, MLAAttentionSpec, SlidingWindowMLASpec, @@ -162,6 +168,8 @@ class DeepseekSparseSWAMetadata: # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. prefill_seq_lens: torch.Tensor | None = None prefill_gather_lens: torch.Tensor | None = None + prefill_seq_lens_cpu: torch.Tensor | None = None + prefill_gather_lens_cpu: torch.Tensor | None = None # Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta # per present DeepseekV4 layer type, shared across all ~60 layers of that type @@ -195,6 +203,20 @@ class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder): reorder_batch_threshold: int = 1 _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_triton_sparse_mla_enabled_for_platform() + and not triton_sparse_mla_cudagraphs_allowed(vllm_config) + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec) @@ -313,6 +335,8 @@ def build( num_prefills, seq_lens, query_start_loc, + query_start_loc_cpu, + common_attn_metadata.seq_lens_cpu_upper_bound, ) # Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta @@ -363,6 +387,8 @@ def build_tile_scheduler( } if num_decode_tokens == 0 or current_platform.is_rocm(): return out + if is_triton_sparse_mla_enabled(self.device): + return out for layer_type in self._layer_types: # get_mla_metadata() is the official FlashMLA entry point that # returns a fresh empty FlashMLASchedMeta; using it keeps this @@ -377,6 +403,8 @@ def _build_deepseek_v4_metadata( num_prefills: int, seq_lens: torch.Tensor, query_start_loc: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + seq_lens_cpu_upper_bound: torch.Tensor | None, ) -> dict[str, torch.Tensor | None]: """Pre-compute DeepseekV4 prefill metadata during the metadata build phase. @@ -403,8 +431,27 @@ def _build_deepseek_v4_metadata( BLOCK_SIZE=triton.next_power_of_2(num_prefills), ) + assert seq_lens_cpu_upper_bound is not None + seq_lens_cpu = seq_lens_cpu_upper_bound + prefill_seq_lens_cpu = seq_lens_cpu[ + num_decodes : num_decodes + num_prefills + ] + query_lens_cpu = ( + query_start_loc_cpu[ + num_decodes + 1 : num_decodes + num_prefills + 1 + ] + - query_start_loc_cpu[num_decodes : num_decodes + num_prefills] + ) + prefix_lens_cpu = prefill_seq_lens_cpu - query_lens_cpu + prefill_gather_lens_cpu = query_lens_cpu + torch.minimum( + prefix_lens_cpu, + torch.full_like(prefix_lens_cpu, self.window_size - 1), + ) + result["prefill_seq_lens"] = seq_lens[num_decodes:] result["prefill_gather_lens"] = pfx_gather_lens + result["prefill_seq_lens_cpu"] = prefill_seq_lens_cpu + result["prefill_gather_lens_cpu"] = prefill_gather_lens_cpu return result diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index 959a79f292a5..da04498f384f 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -5,7 +5,10 @@ combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + dequantize_global_slots_k_cache, quantize_and_insert_k_cache, + sparse_prefill_combined_topk_size, ) from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant @@ -16,8 +19,11 @@ "combine_topk_swa_indices", "compute_global_topk_indices_and_lens", "dequantize_and_gather_k_cache", + "dequantize_combined_sparse_mla_decode_kv", + "dequantize_global_slots_k_cache", "fused_indexer_q_rope_quant", "fused_inv_rope_fp8_quant", "fused_q_kv_rmsnorm", "quantize_and_insert_k_cache", + "sparse_prefill_combined_topk_size", ] diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 69d20c107e11..33cfe699236f 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -349,12 +349,166 @@ def dequantize_and_gather_k_cache( ) +@triton.jit +def _dequantize_global_slots_k_kernel( + out_ptr, + out_stride_token, + out_stride_slot, + k_cache_ptr, + slot_ids_ptr, + slot_ids_stride_token, + slot_ids_stride_slot, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + bf16_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + output_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + topk_idx = tl.program_id(1) + + slot_id = tl.load( + slot_ids_ptr + + token_idx * slot_ids_stride_token + + topk_idx * slot_ids_stride_slot + ) + offsets = tl.arange(0, BLOCK_D) + output_row = out_ptr + token_idx * out_stride_token + topk_idx * out_stride_slot + + if slot_id < 0: + tl.store( + output_row + offsets, + tl.zeros((BLOCK_D,), dtype=tl.float32).to(tl.bfloat16), + mask=offsets < output_dim, + ) + return + + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + cache_block_size * token_data_size + pos_in_block * scale_dim + ) + + fp8_offsets = tl.arange(0, 512) + fp8_mask = fp8_offsets < fp8_dim + x_uint8 = tl.load(token_data_ptr + fp8_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + + scale_offsets = fp8_offsets // quant_block + encoded_scale = tl.load(token_scale_ptr + scale_offsets, mask=fp8_mask, other=127) + scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * scale + tl.store(output_row + fp8_offsets, x_dequant.to(tl.bfloat16), mask=fp8_mask) + + bf16_offsets = tl.arange(0, 64) + bf16_cache_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + bf16_vals = tl.load(bf16_cache_ptr + bf16_offsets, mask=bf16_offsets < bf16_dim) + tl.store( + output_row + fp8_dim + bf16_offsets, + bf16_vals, + mask=bf16_offsets < bf16_dim, + ) + + +def dequantize_global_slots_k_cache( + out: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + block_size: int, +) -> None: + """Dequantize fp8_ds_mla cache rows addressed by physical global slot ids.""" + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0, :] + assert slot_ids.dim() == 2, ( + f"slot_ids must be [num_tokens, topk], got {slot_ids.shape}" + ) + assert out.shape[:2] == slot_ids.shape + assert out.shape[-1] == 512 + assert out.dtype == torch.bfloat16 + assert k_cache.dtype == torch.uint8 + + TOKEN_FP8_DIM = 448 + TOKEN_BF16_DIM = 64 + TOKEN_SCALE_DIM = 8 + QUANT_BLOCK_SIZE = 64 + TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2 + + grid = slot_ids.shape + _dequantize_global_slots_k_kernel[grid]( + out, + out.stride(0), + out.stride(1), + k_cache, + slot_ids, + slot_ids.stride(0), + slot_ids.stride(1), + cache_block_size=block_size, + token_data_size=TOKEN_DATA_SIZE, + block_stride=k_cache.stride(0), + fp8_dim=TOKEN_FP8_DIM, + bf16_dim=TOKEN_BF16_DIM, + scale_dim=TOKEN_SCALE_DIM, + quant_block=QUANT_BLOCK_SIZE, + output_dim=512, + BLOCK_D=triton.next_power_of_2(512), + ) + + +def dequantize_combined_sparse_mla_decode_kv( + combined_kv: torch.Tensor, + compressed_k_cache: torch.Tensor, + compressed_slot_ids: torch.Tensor, + compressed_block_size: int, + swa_k_cache: torch.Tensor, + seq_lens: torch.Tensor, + swa_lens: torch.Tensor, + block_table: torch.Tensor, + swa_block_size: int, +) -> None: + """Fill `[compressed, SWA]` decode candidates into one output buffer.""" + assert combined_kv.dim() == 3 + compressed_topk = compressed_slot_ids.shape[-1] + assert combined_kv.shape[0] == compressed_slot_ids.shape[0] + assert combined_kv.shape[-1] == 512 + assert combined_kv.dtype == torch.bfloat16 + assert combined_kv.shape[1] >= compressed_topk + + dequantize_global_slots_k_cache( + combined_kv[:, :compressed_topk], + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + ) + swa_out = combined_kv[:, compressed_topk:] + if swa_out.shape[1] == 0: + return + dequantize_and_gather_k_cache( + swa_out, + swa_k_cache, + seq_lens=seq_lens, + gather_lens=swa_lens, + block_table=block_table, + block_size=swa_block_size, + offset=0, + ) + + def compute_global_topk_indices_and_lens( topk_indices: torch.Tensor, token_to_req_indices: torch.Tensor, block_table: torch.Tensor, block_size: int, is_valid_token: torch.Tensor, + global_topk_indices: torch.Tensor | None = None, + topk_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Map local topk indices to global KV cache slots and count valid entries. @@ -364,8 +518,20 @@ def compute_global_topk_indices_and_lens( 3. Masking padding tokens to length 0 """ num_tokens = topk_indices.shape[0] - global_topk_indices = torch.empty_like(topk_indices) - topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=topk_indices.device) + if global_topk_indices is None: + global_topk_indices = torch.empty_like(topk_indices) + else: + assert global_topk_indices.shape == topk_indices.shape + assert global_topk_indices.dtype == topk_indices.dtype + assert global_topk_indices.device == topk_indices.device + if topk_lens is None: + topk_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert topk_lens.shape == (num_tokens,) + assert topk_lens.dtype == torch.int32 + assert topk_lens.device == topk_indices.device _compute_global_topk_indices_and_lens_kernel[(num_tokens,)]( global_topk_indices, global_topk_indices.stride(0), @@ -412,7 +578,7 @@ def _compute_global_topk_indices_and_lens_kernel( mask=mask, other=-1, ) - is_valid = local_idx >= 0 + is_valid = (local_idx >= 0) & is_valid_token block_indices = local_idx // block_size block_numbers = tl.load( @@ -442,6 +608,14 @@ def _compute_global_topk_indices_and_lens_kernel( _SPARSE_PREFILL_TOPK_ALIGNMENT = 128 +def sparse_prefill_combined_topk_size(topk: int, window_size: int) -> int: + return ( + (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1) + // _SPARSE_PREFILL_TOPK_ALIGNMENT + * _SPARSE_PREFILL_TOPK_ALIGNMENT + ) + + def combine_topk_swa_indices( topk_indices: torch.Tensor, query_start_loc: torch.Tensor, @@ -452,23 +626,35 @@ def combine_topk_swa_indices( topk: int, M: int, N: int, + combined_indices: torch.Tensor | None = None, + combined_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: num_tokens = topk_indices.shape[0] num_reqs = seq_lens.shape[0] - combined_topk = ( - (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1) - // _SPARSE_PREFILL_TOPK_ALIGNMENT - * _SPARSE_PREFILL_TOPK_ALIGNMENT - ) - combined_indices = torch.full( - (num_tokens, combined_topk), - fill_value=-1, - dtype=torch.int32, - device=topk_indices.device, - ) - combined_lens = torch.empty( - num_tokens, dtype=torch.int32, device=topk_indices.device - ) + combined_topk = sparse_prefill_combined_topk_size(topk, window_size) + if combined_indices is None: + combined_indices = torch.full( + (num_tokens, combined_topk), + fill_value=-1, + dtype=torch.int32, + device=topk_indices.device, + ) + else: + assert combined_indices.shape[0] >= num_tokens + assert combined_indices.shape[1] >= combined_topk + assert combined_indices.dtype == torch.int32 + assert combined_indices.device == topk_indices.device + combined_indices = combined_indices[:num_tokens, :combined_topk] + combined_indices.fill_(-1) + if combined_lens is None: + combined_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert combined_lens.shape[0] >= num_tokens + assert combined_lens.dtype == torch.int32 + assert combined_lens.device == topk_indices.device + combined_lens = combined_lens[:num_tokens] NUM_WORKERS = 128 _combine_topk_swa_indices_kernel[(num_reqs, NUM_WORKERS)]( diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py new file mode 100644 index 000000000000..71a6199e1d9d --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SM12x Triton FP8 einsum kernels for DeepSeek V4.""" + +import torch + +from vllm.triton_utils import tl, triton + + +def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + exp_bits = scale.view(torch.uint8).to(torch.int32) + fp32_bits = exp_bits << 23 + return fp32_bits.view(torch.float32) + + +@triton.jit +def _deepseek_v4_sm12_fp8_einsum_kernel( + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_groups: tl.constexpr, + out_rank: tl.constexpr, + hidden_size: tl.constexpr, + a_stride_token: tl.constexpr, + a_stride_group: tl.constexpr, + a_stride_hidden: tl.constexpr, + a_scale_stride_token: tl.constexpr, + a_scale_stride_group: tl.constexpr, + a_scale_stride_hidden: tl.constexpr, + b_stride_group: tl.constexpr, + b_stride_out: tl.constexpr, + b_stride_hidden: tl.constexpr, + b_scale_stride_group: tl.constexpr, + b_scale_stride_out: tl.constexpr, + b_scale_stride_hidden: tl.constexpr, + out_stride_token: tl.constexpr, + out_stride_group: tl.constexpr, + out_stride_rank: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, + BLOCK_OUT: tl.constexpr, + BLOCK_HIDDEN: tl.constexpr, +) -> None: + token_block = tl.program_id(0) + out_block = tl.program_id(1) + group = tl.program_id(2) + + token_offsets = token_block * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) + out_offsets = out_block * BLOCK_OUT + tl.arange(0, BLOCK_OUT) + hidden_offsets = tl.arange(0, BLOCK_HIDDEN) + accum = tl.zeros((BLOCK_TOKENS, BLOCK_OUT), dtype=tl.float32) + + for hidden_start in range(0, hidden_size, BLOCK_HIDDEN): + hidden = hidden_start + hidden_offsets + a = tl.load( + a_ptr + + token_offsets[:, None] * a_stride_token + + group * a_stride_group + + hidden[None, :] * a_stride_hidden, + mask=(token_offsets[:, None] < num_tokens) + & (hidden[None, :] < hidden_size), + other=0.0, + ) + b = tl.load( + b_ptr + + group * b_stride_group + + out_offsets[None, :] * b_stride_out + + hidden[:, None] * b_stride_hidden, + mask=(out_offsets[None, :] < out_rank) + & (hidden[:, None] < hidden_size), + other=0.0, + ) + raw = tl.dot(a, b, out_dtype=tl.float32) + hidden_scale_block = hidden_start // BLOCK_HIDDEN + a_scale = tl.load( + a_scale_ptr + + token_offsets * a_scale_stride_token + + group * a_scale_stride_group + + hidden_scale_block * a_scale_stride_hidden, + mask=token_offsets < num_tokens, + other=0.0, + ) + b_scale = tl.load( + b_scale_ptr + + group * b_scale_stride_group + + (out_offsets // 128) * b_scale_stride_out + + hidden_scale_block * b_scale_stride_hidden, + mask=out_offsets < out_rank, + other=0.0, + ) + accum += raw * a_scale[:, None] * b_scale[None, :] + + tl.store( + out_ptr + + token_offsets[:, None] * out_stride_token + + group * out_stride_group + + out_offsets[None, :] * out_stride_rank, + accum, + mask=(token_offsets[:, None] < num_tokens) + & (out_offsets[None, :] < out_rank), + ) + + +def deepseek_v4_sm12_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + """Compute ``bhr,hdr->bhd`` with FP32 block scales on SM12x. + + ``a`` is the transposed output of ``fused_inv_rope_fp8_quant`` with shape + ``[tokens, groups, hidden]``. ``b`` is ``wo_a`` reshaped to + ``[groups, out_rank, hidden]``. + """ + num_tokens, num_groups, hidden_size = a.shape + b_groups, out_rank, b_hidden_size = b.shape + assert b_groups == num_groups + assert b_hidden_size == hidden_size + assert out.shape == (num_tokens, num_groups, out_rank) + assert hidden_size % 128 == 0 + assert out_rank % 128 == 0 + assert a.dtype == torch.float8_e4m3fn + assert b.dtype == torch.float8_e4m3fn + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if a_scale.dtype == e8m0_dtype: + a_scale = _upcast_e8m0_to_fp32(a_scale) + if b_scale.dtype == e8m0_dtype: + b_scale = _upcast_e8m0_to_fp32(b_scale) + assert a_scale.dtype == torch.float32 + assert b_scale.dtype == torch.float32 + + if num_tokens == 0: + return + + block_tokens = 16 + block_out = 128 + block_hidden = 128 + grid = ( + triton.cdiv(num_tokens, block_tokens), + triton.cdiv(out_rank, block_out), + num_groups, + ) + _deepseek_v4_sm12_fp8_einsum_kernel[grid]( + a, + a_scale, + b, + b_scale, + out, + num_tokens, + num_groups, + out_rank, + hidden_size, + a.stride(0), + a.stride(1), + a.stride(2), + a_scale.stride(0), + a_scale.stride(1), + a_scale.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + b_scale.stride(0), + b_scale.stride(1), + b_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_TOKENS=block_tokens, + BLOCK_OUT=block_out, + BLOCK_HIDDEN=block_hidden, + num_warps=4, + num_stages=3, + ) From a8746555ef68e7042c1733e0b8e31ebb19837ac9 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 02:25:00 +0800 Subject: [PATCH 02/11] Align DeepSeek V4 API semantics Co-authored-by: OpenAI Codex Signed-off-by: jasl --- .../test_deepseekv3_reasoning_parser.py | 38 ++- tests/tokenizers_/test_deepseek_v4.py | 323 ++++++++++++++++++ vllm/entrypoints/chat_utils.py | 11 + .../openai/chat_completion/batch_serving.py | 6 +- .../openai/chat_completion/protocol.py | 120 ++++++- .../openai/chat_completion/serving.py | 8 +- vllm/entrypoints/openai/engine/protocol.py | 9 + vllm/entrypoints/serve/render/serving.py | 28 +- vllm/reasoning/__init__.py | 2 +- vllm/tokenizers/deepseek_v4_encoding.py | 11 +- 10 files changed, 536 insertions(+), 20 deletions(-) diff --git a/tests/reasoning/test_deepseekv3_reasoning_parser.py b/tests/reasoning/test_deepseekv3_reasoning_parser.py index f5b37194f927..a013cf1a8775 100644 --- a/tests/reasoning/test_deepseekv3_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv3_reasoning_parser.py @@ -4,16 +4,35 @@ import pytest from transformers import AutoTokenizer +from vllm.config.reasoning import ReasoningConfig from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.reasoning import ReasoningParserManager from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser -from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser +from vllm.reasoning.deepseek_v3_reasoning_parser import ( + DeepSeekV3ReasoningParser, + DeepSeekV3ReasoningWithThinkingParser, +) from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-V3.1" +class FakeReasoningTokenizer: + + def get_vocab(self) -> dict[str, int]: + return {"": 100, "": 101} + + def encode( + self, + text: str, + add_special_tokens: bool = False, + **kwargs, + ) -> list[int]: + assert add_special_tokens is False + return [self.get_vocab()[text]] + + @pytest.fixture(scope="module") def tokenizer(): return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) @@ -37,7 +56,22 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type): def test_deepseek_v4_reasoning_parser_alias(): parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4") - assert parser_cls is DeepSeekV3ReasoningParser + assert parser_cls is DeepSeekV3ReasoningWithThinkingParser + + +def test_deepseek_v4_auto_reasoning_config_initializes_budget_tokens(monkeypatch): + monkeypatch.setattr( + "vllm.config.reasoning.cached_tokenizer_from_config", + lambda model_config: FakeReasoningTokenizer(), + ) + + config = ReasoningConfig(reasoning_parser="deepseek_v4") + + config.initialize_token_ids(model_config=None) + + assert config.enabled is True + assert config.reasoning_start_token_ids == [100] + assert config.reasoning_end_token_ids == [101] def test_identity_reasoning_parser_basic(tokenizer): diff --git a/tests/tokenizers_/test_deepseek_v4.py b/tests/tokenizers_/test_deepseek_v4.py index 358732eabf40..0a3253957add 100644 --- a/tests/tokenizers_/test_deepseek_v4.py +++ b/tests/tokenizers_/test_deepseek_v4.py @@ -8,8 +8,14 @@ import pytest from vllm.entrypoints.chat_utils import parse_chat_messages +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatMessage, +) +from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.renderers.registry import RENDERER_REGISTRY from vllm.tokenizers.deepseek_v4 import get_deepseek_v4_tokenizer +from vllm.tokenizers.deepseek_v4_encoding import encode_arguments_to_dsml from vllm.tokenizers.registry import TokenizerRegistry FIXTURES_DIR = Path(__file__).parent / "fixtures" / "deepseek_v4" @@ -96,6 +102,130 @@ def test_deepseek_v4_enables_thinking_with_compatible_kwargs(kwargs): assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") +def test_deepseek_v4_honors_official_thinking_request_field(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + prompt = _tokenizer().apply_chat_template( + request.messages, + tokenize=False, + **chat_kwargs, + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") + + +def test_deepseek_v4_defaults_to_official_thinking_for_openai_request(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + + +def test_deepseek_v4_preserves_official_reasoning_content_alias(): + messages = [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "reasoning_content": "because", "content": "A1"}, + {"role": "user", "content": "Q2"}, + ] + + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + + assert conversation[1]["reasoning"] == "because" + assert conversation[1]["reasoning_content"] == "because" + + +def test_deepseek_v4_response_messages_expose_reasoning_content_alias(): + message = ChatMessage(role="assistant", reasoning="because", content="answer") + delta = DeltaMessage(reasoning="because") + + assert message.reasoning_content == "because" + assert delta.reasoning_content == "because" + assert ( + ChatMessage( + role="assistant", + reasoning_content="because", + content="answer", + ).reasoning + == "because" + ) + + +def test_deepseek_v4_preserves_official_prefix_assistant_message(): + messages = [ + {"role": "user", "content": "Please write quick sort code"}, + {"role": "assistant", "content": "```python\n", "prefix": True}, + ] + + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + prompt = _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tokenize=False, + ) + + assert conversation[1]["prefix"] is True + assert conversation[1]["wo_eos"] is True + assert prompt.endswith("<|Assistant|>```python\n") + assert not prompt.endswith("<|end▁of▁sentence|>") + + +def test_deepseek_v4_thinking_ignores_sampling_controls(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert sampling_params.temperature == 1.0 + assert sampling_params.top_p == 1.0 + assert sampling_params.top_k == 0 + assert sampling_params.presence_penalty == 0.0 + assert sampling_params.frequency_penalty == 0.0 + + def test_deepseek_v4_uses_v4_tool_prompt_from_request_tools(): tools = [ { @@ -183,6 +313,66 @@ def test_deepseek_v4_renders_parsed_history_tool_arguments(): assert 'parameter name="arguments"' not in prompt +@pytest.mark.parametrize( + ("tool_call", "expected_parameter"), + [ + ({"name": "refresh", "arguments": None}, None), + ({"name": "refresh"}, None), + ({"name": "refresh", "arguments": ""}, None), + ( + {"name": "refresh", "arguments": '{"target": "cache"}'}, + '<|DSML|parameter name="target" string="true">cache', + ), + ( + {"name": "refresh", "arguments": {"target": "cache"}}, + '<|DSML|parameter name="target" string="true">cache', + ), + ], +) +def test_deepseek_v4_encodes_empty_history_tool_arguments( + tool_call, expected_parameter +): + prompt = encode_arguments_to_dsml(tool_call) + + if expected_parameter is None: + assert prompt == "" + else: + assert expected_parameter in prompt + + +def test_deepseek_v4_renders_openai_history_tool_call_with_null_arguments(): + messages = [ + {"role": "user", "content": "Refresh state"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "refresh", + "arguments": None, + }, + } + ], + }, + ] + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + + prompt = _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tokenize=False, + ) + + assert '<|DSML|invoke name="refresh">' in prompt + assert "<|DSML|parameter" not in prompt + + @pytest.mark.parametrize("reasoning_effort", ["minimal", "low", "medium", "high"]) def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort): prompt = _tokenizer().apply_chat_template( @@ -288,3 +478,136 @@ def test_deepseek_v4_matches_reference_golden_fixtures(case_id, kwargs): expected = (FIXTURES_DIR / f"test_output_{case_id}.txt").read_text() assert prompt == expected + + +@pytest.mark.parametrize( + "model", + [ + "deepseek-ai/DeepSeek-V4-Flash", + "deepseek-ai/DeepSeek-V4-Pro", + ], +) +def test_deepseek_v4_official_api_defaults_to_thinking_for_v4_family(model): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": model, + "messages": [{"role": "user", "content": "Hello"}], + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + + +def test_deepseek_v4_official_api_uses_model_config_for_family_detection(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "local-ds4-alias", + "messages": [{"role": "user", "content": "Hello"}], + "temperature": 0.2, + } + ) + model_config = SimpleNamespace( + hf_config=SimpleNamespace(model_type="deepseek_v4", architectures=[]), + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs, + model_config=model_config, + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + model_config=model_config, + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + assert sampling_params.temperature == 1.0 + + +def test_deepseek_v4_official_api_sampling_override_can_be_disabled(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "deepseek_v4_sampling_override": False, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "min_p": 0.05, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert sampling_params.temperature == 0.2 + assert sampling_params.top_p == 0.3 + assert sampling_params.top_k == 4 + assert sampling_params.min_p == 0.05 + assert sampling_params.presence_penalty == 1.5 + assert sampling_params.frequency_penalty == 1.25 + + +def test_deepseek_v4_official_api_sampling_override_is_v4_only(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-R1", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "min_p": 0.05, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert "thinking" not in chat_kwargs + assert "enable_thinking" not in chat_kwargs + assert sampling_params.temperature == 0.2 + assert sampling_params.top_p == 0.3 + assert sampling_params.top_k == 4 + assert sampling_params.min_p == 0.05 + assert sampling_params.presence_penalty == 1.5 + assert sampling_params.frequency_penalty == 1.25 diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index cfe0857b679e..78de5bf2ef1e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -399,6 +399,12 @@ class ConversationMessage(TypedDict, total=False): reasoning_content: str | None """Deprecated: The reasoning content for interleaved thinking.""" + prefix: bool + """Whether this assistant message is a prefix for continuation.""" + + wo_eos: bool + """Whether this message should be rendered without an EOS marker.""" + tools: list[ChatCompletionFunctionToolParam] | None """The tools for developer role.""" @@ -1751,6 +1757,8 @@ def _parse_chat_message_content( role = message["role"] content = message.get("content") reasoning = message.get("reasoning") + if reasoning is None: + reasoning = message.get("reasoning_content") if content is None: content = [] @@ -1780,6 +1788,9 @@ def _parse_chat_message_content( result_msg["reasoning_content"] = cast( str, reasoning ) # keep compatibility + if parsed_msg.get("prefix"): + result_msg["prefix"] = True + result_msg["wo_eos"] = True elif role == "tool": parsed_msg = _ToolParser(message) if "tool_call_id" in parsed_msg: diff --git a/vllm/entrypoints/openai/chat_completion/batch_serving.py b/vllm/entrypoints/openai/chat_completion/batch_serving.py index cc49909b8361..df6aff4a45c3 100644 --- a/vllm/entrypoints/openai/chat_completion/batch_serving.py +++ b/vllm/entrypoints/openai/chat_completion/batch_serving.py @@ -159,8 +159,12 @@ async def create_batch_chat_completion( self.override_max_tokens, ) single_request = single_requests[i] + chat_template_kwargs = self._effective_chat_template_kwargs(single_request) sampling_params = single_request.to_sampling_params( - max_tokens, self.default_sampling_params + max_tokens, + self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, ) self._log_inputs( sub_request_id, diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index c92cc13da01f..b0a73770da35 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -62,6 +62,15 @@ class ChatMessage(OpenAIBaseModel): # vLLM-specific fields that are not in OpenAI spec reasoning: str | None = None + reasoning_content: str | None = None + + @model_validator(mode="after") + def _populate_reasoning_content_alias(self) -> "ChatMessage": + if self.reasoning_content is None and self.reasoning is not None: + self.reasoning_content = self.reasoning + elif self.reasoning is None and self.reasoning_content is not None: + self.reasoning = self.reasoning_content + return self class ChatCompletionLogProb(OpenAIBaseModel): @@ -164,6 +173,10 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): type: Literal["function"] = "function" +class DeepSeekThinkingParam(OpenAIBaseModel): + type: Literal["enabled", "disabled"] = "enabled" + + class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create @@ -209,6 +222,15 @@ class ChatCompletionRequest(OpenAIBaseModel): "part of the standard OpenAI API specification." ), ) + thinking: DeepSeekThinkingParam | None = None + deepseek_v4_sampling_override: bool = Field( + default=True, + description=( + "Apply DeepSeek V4 official sampling defaults when thinking is " + "enabled. This only affects the DeepSeek V4 family and can be " + "disabled per request." + ), + ) thinking_token_budget: int | None = None include_reasoning: bool = True parallel_tool_calls: bool | None = True @@ -435,21 +457,79 @@ def build_chat_params( default_template: str | None, default_template_content_format: ChatTemplateContentFormatOption, ) -> ChatParams: + chat_kwargs = merge_kwargs( + self.chat_template_kwargs, + dict( + add_generation_prompt=self.add_generation_prompt, + continue_final_message=self.continue_final_message, + documents=self.documents, + reasoning_effort=self.reasoning_effort, + ), + ) return ChatParams( chat_template=self.chat_template or default_template, chat_template_content_format=default_template_content_format, - chat_template_kwargs=merge_kwargs( - self.chat_template_kwargs, - dict( - add_generation_prompt=self.add_generation_prompt, - continue_final_message=self.continue_final_message, - documents=self.documents, - reasoning_effort=self.reasoning_effort, - ), - ), + chat_template_kwargs=chat_kwargs, media_io_kwargs=self.media_io_kwargs, ) + def _is_deepseek_v4_model(self, model_config: ModelConfig | None = None) -> bool: + hf_config = getattr(model_config, "hf_config", None) + if getattr(hf_config, "model_type", None) == "deepseek_v4": + return True + + architectures = getattr(hf_config, "architectures", None) or () + if any("deepseekv4" in str(arch).replace("_", "").lower() + for arch in architectures): + return True + + model = (self.model or "").lower().replace("_", "-") + return "deepseek-v4" in model + + def apply_chat_template_kwargs( + self, + chat_template_kwargs: dict[str, Any], + *, + model_config: ModelConfig | None = None, + ) -> dict[str, Any]: + """Apply request-level DeepSeek API compatibility knobs. + + DeepSeek's OpenAI-compatible API exposes ``thinking`` as a top-level + request field, while vLLM's DeepSeek tokenizer consumes it as a chat + template kwarg. Keep the translation at the protocol boundary so the + tokenizer and reasoning parser see the same effective state. + """ + chat_template_kwargs = dict(chat_template_kwargs) + if not self._is_deepseek_v4_model(model_config): + return chat_template_kwargs + + if self.thinking is not None: + enabled = self.thinking.type == "enabled" + chat_template_kwargs["thinking"] = enabled + chat_template_kwargs["enable_thinking"] = enabled + elif ( + "thinking" not in chat_template_kwargs + and "enable_thinking" not in chat_template_kwargs + ): + chat_template_kwargs["thinking"] = True + chat_template_kwargs["enable_thinking"] = True + + return chat_template_kwargs + + def _use_deepseek_v4_sampling_override(self) -> bool: + return self.deepseek_v4_sampling_override + + @staticmethod + def _is_thinking_enabled( + chat_template_kwargs: dict[str, Any] | None, + ) -> bool: + if chat_template_kwargs is None: + return False + return bool( + chat_template_kwargs.get("thinking") + or chat_template_kwargs.get("enable_thinking") + ) + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: if self.max_completion_tokens is not None: max_output_tokens: int | None = self.max_completion_tokens @@ -499,6 +579,9 @@ def to_sampling_params( self, max_tokens: int, default_sampling_params: dict, + *, + chat_template_kwargs: dict[str, Any] | None = None, + model_config: ModelConfig | None = None, ) -> SamplingParams: # Default parameters if (repetition_penalty := self.repetition_penalty) is None: @@ -523,6 +606,21 @@ def to_sampling_params( "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] ) + if ( + self._is_deepseek_v4_model(model_config) + and self._use_deepseek_v4_sampling_override() + and self._is_thinking_enabled(chat_template_kwargs) + ): + temperature = self._DEFAULT_SAMPLING_PARAMS["temperature"] + top_p = self._DEFAULT_SAMPLING_PARAMS["top_p"] + top_k = self._DEFAULT_SAMPLING_PARAMS["top_k"] + min_p = self._DEFAULT_SAMPLING_PARAMS["min_p"] + presence_penalty = 0.0 + frequency_penalty = 0.0 + else: + presence_penalty = self.presence_penalty or 0.0 + frequency_penalty = self.frequency_penalty or 0.0 + prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs @@ -565,8 +663,8 @@ def to_sampling_params( extra_args["kv_transfer_params"] = self.kv_transfer_params return SamplingParams.from_optional( n=self.n, - presence_penalty=self.presence_penalty, - frequency_penalty=self.frequency_penalty, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, repetition_penalty=repetition_penalty, temperature=temperature, top_p=top_p, diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 694ff80047c7..2dad7f948b5b 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -190,7 +190,7 @@ def warmup(self) -> None: def _effective_chat_template_kwargs( self, request: ChatCompletionRequest ) -> dict[str, Any]: - return ( + chat_template_kwargs = ( request.build_chat_params( self.chat_template, self.chat_template_content_format, @@ -198,6 +198,10 @@ def _effective_chat_template_kwargs( .with_defaults(self.default_chat_template_kwargs) .chat_template_kwargs ) + return request.apply_chat_template_kwargs( + chat_template_kwargs, + model_config=self.model_config, + ) async def render_chat_request( self, @@ -300,6 +304,8 @@ async def create_chat_completion( sampling_params = request.to_sampling_params( max_tokens, self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, ) self._log_inputs( diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py index 890af0300efc..8c531c6d5a93 100644 --- a/vllm/entrypoints/openai/engine/protocol.py +++ b/vllm/entrypoints/openai/engine/protocol.py @@ -268,8 +268,17 @@ class DeltaMessage(OpenAIBaseModel): role: str | None = None content: str | None = None reasoning: str | None = None + reasoning_content: str | None = None tool_calls: list[DeltaToolCall] = Field(default_factory=list) + @model_validator(mode="after") + def _populate_reasoning_content_alias(self) -> "DeltaMessage": + if self.reasoning_content is None and self.reasoning is not None: + self.reasoning_content = self.reasoning + elif self.reasoning is None and self.reasoning_content is not None: + self.reasoning = self.reasoning_content + return self + class GenerationError(Exception): """raised when finish_reason indicates internal server error (500)""" diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 967899229ada..d6da0dba6f4e 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence +from dataclasses import replace from http import HTTPStatus from typing import Any, cast @@ -165,7 +166,21 @@ async def render_chat_request( self.default_sampling_params, self.override_max_tokens, ) - params = request.to_sampling_params(max_tokens, self.default_sampling_params) + chat_template_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params( + self.chat_template, + self.chat_template_content_format, + ) + .with_defaults(self.default_chat_template_kwargs) + .chat_template_kwargs, + model_config=self.model_config, + ) + params = request.to_sampling_params( + max_tokens, + self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, + ) request_id = f"chatcmpl-{random_uuid()}" @@ -556,6 +571,17 @@ async def preprocess_chat( default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), ) + apply_chat_template_kwargs = getattr( + request, "apply_chat_template_kwargs", None + ) + if apply_chat_template_kwargs is not None: + chat_params = replace( + chat_params, + chat_template_kwargs=apply_chat_template_kwargs( + chat_params.chat_template_kwargs, + model_config=self.model_config, + ), + ) (conversation,), (engine_input,) = await renderer.render_chat_async( [messages], diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index cd51f106503a..e5962c7c0664 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -30,7 +30,7 @@ ), "deepseek_v4": ( "deepseek_v3_reasoning_parser", - "DeepSeekV3ReasoningParser", + "DeepSeekV3ReasoningWithThinkingParser", ), "poolside_v1": ( "poolside_v1_reasoning_parser", diff --git a/vllm/tokenizers/deepseek_v4_encoding.py b/vllm/tokenizers/deepseek_v4_encoding.py index 6895771e2f59..74f01ce017e4 100644 --- a/vllm/tokenizers/deepseek_v4_encoding.py +++ b/vllm/tokenizers/deepseek_v4_encoding.py @@ -155,10 +155,15 @@ def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str: p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}' P_dsml_strs = [] - if isinstance(tool_call["arguments"], str): - arguments = json.loads(tool_call["arguments"]) + raw_arguments = tool_call.get("arguments") + if raw_arguments is None or raw_arguments == "": + arguments = {} + elif isinstance(raw_arguments, str): + arguments = json.loads(raw_arguments) + if arguments is None: + arguments = {} else: - arguments = tool_call["arguments"] + arguments = raw_arguments for k, v in arguments.items(): p_dsml_str = p_dsml_template.format( From 04af72fb881b3850df5ff8a7dffb7ec2b1f2d4a3 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 02:25:00 +0800 Subject: [PATCH 03/11] Add vLLM logprobs oracle comparator Co-authored-by: OpenAI Codex Signed-off-by: jasl --- .../test_compare_vllm_http_logprobs_oracle.py | 115 +++++ tools/compare_vllm_http_logprobs_oracle.py | 431 ++++++++++++++++++ 2 files changed, 546 insertions(+) create mode 100644 tests/tools/test_compare_vllm_http_logprobs_oracle.py create mode 100755 tools/compare_vllm_http_logprobs_oracle.py diff --git a/tests/tools/test_compare_vllm_http_logprobs_oracle.py b/tests/tools/test_compare_vllm_http_logprobs_oracle.py new file mode 100644 index 000000000000..729c76659d53 --- /dev/null +++ b/tests/tools/test_compare_vllm_http_logprobs_oracle.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib.util +from pathlib import Path + +SCRIPT_PATH = ( + Path(__file__).parents[2] / "tools" / "compare_vllm_http_logprobs_oracle.py" +) +spec = importlib.util.spec_from_file_location( + "compare_vllm_http_logprobs_oracle", SCRIPT_PATH +) +assert spec is not None +oracle_compare = importlib.util.module_from_spec(spec) +assert spec.loader is not None +spec.loader.exec_module(oracle_compare) + + +def _response(tokens, top_logprobs): + token_ids = [ + int(token.split(":", 1)[1]) + for token in tokens + if isinstance(token, str) and token.startswith("token_id:") + ] + return { + "choices": [ + { + "text": "", + "logprobs": { + "tokens": tokens, + "token_logprobs": [-0.1 * (i + 1) for i in range(len(tokens))], + "top_logprobs": top_logprobs, + }, + "token_ids": token_ids, + "prompt_token_ids": [1, 2, 3], + } + ], + "usage": {"prompt_tokens": 3, "completion_tokens": len(tokens)}, + } + + +def test_compare_response_accepts_identical_top_logprobs(): + top_logprobs = [ + {"token_id:10": -0.1, "token_id:20": -1.0}, + {"token_id:11": -0.2, "token_id:21": -1.2}, + ] + report = oracle_compare.compare_response( + "case0", + _response(["token_id:10", "token_id:11"], top_logprobs), + _response(["token_id:10", "token_id:11"], top_logprobs), + top_n=2, + ) + + assert report["tokens_match"] is True + assert report["first_token_mismatch"] is None + assert report["top1_matches"] == 2 + assert report["topk_overlap_mean"] == 1.0 + assert report["max_common_logprob_abs_error"] == 0.0 + + +def test_compare_response_reports_first_generated_token_divergence(): + oracle = _response( + ["token_id:10", "token_id:11"], + [ + {"token_id:10": -0.1, "token_id:20": -1.0}, + {"token_id:11": -0.2, "token_id:21": -1.2}, + ], + ) + actual = _response( + ["token_id:10", "token_id:99"], + [ + {"token_id:10": -0.1, "token_id:20": -1.0}, + {"token_id:99": -0.2, "token_id:21": -1.2}, + ], + ) + + report = oracle_compare.compare_response("case0", oracle, actual, top_n=2) + + assert report["tokens_match"] is False + assert report["first_token_mismatch"] == { + "step": 1, + "oracle": "token_id:11", + "actual": "token_id:99", + } + assert report["matching_prefix_tokens"] == 1 + assert report["top1_matches"] == 1 + + +def test_compare_response_can_decode_oracle_token_id_keys(): + normalizer = oracle_compare.TokenNormalizer( + lambda token_id: {10: '","', 11: "title", 20: " What"}[token_id] + ) + oracle = _response( + ["token_id:10", "token_id:11"], + [ + {"token_id:10": -0.1, "token_id:20": -1.0}, + {"token_id:11": -0.2, "token_id:20": -1.2}, + ], + ) + actual = _response( + ['","', "title"], + [ + {'","': -0.11, " What": -1.1}, + {"title": -0.19, " What": -1.3}, + ], + ) + + report = oracle_compare.compare_response( + "case0", oracle, actual, top_n=2, normalizer=normalizer + ) + + assert report["tokens_match"] is True + assert report["top1_matches"] == 2 + assert report["topk_overlap_mean"] == 1.0 + assert report["max_common_logprob_abs_error"] == 0.10000000000000009 diff --git a/tools/compare_vllm_http_logprobs_oracle.py b/tools/compare_vllm_http_logprobs_oracle.py new file mode 100755 index 000000000000..85a3d6aade55 --- /dev/null +++ b/tools/compare_vllm_http_logprobs_oracle.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compare current vLLM HTTP logprobs with a captured oracle bundle. + +The expected oracle format is a directory with request_*.json and matching +response_*.json files produced from /v1/completions with token-id logprob keys. +""" + +from __future__ import annotations + +import argparse +import json +import sys +import urllib.error +import urllib.request +from pathlib import Path +from typing import Any + +Json = dict[str, Any] + + +class TokenNormalizer: + """Normalize token-id placeholders to decoded token strings.""" + + def __init__(self, decode_token_id): + self._decode_token_id = decode_token_id + + def token_key(self, token: Any) -> str: + key = _token_key(token) + if not key.startswith("token_id:"): + return key + try: + token_id = int(key.removeprefix("token_id:")) + except ValueError: + return key + return self._decode_token_id(token_id) + + +def _token_key(token: Any) -> str: + if isinstance(token, str): + return token + if isinstance(token, int): + return f"token_id:{token}" + return str(token) + + +def _choice(response: Json) -> Json: + choices = response.get("choices") + if not isinstance(choices, list) or not choices: + raise ValueError("response has no choices[0]") + choice = choices[0] + if not isinstance(choice, dict): + raise ValueError("response choices[0] is not an object") + return choice + + +def _normalize_token(token: Any, normalizer: TokenNormalizer | None) -> str: + if normalizer is None: + return _token_key(token) + return normalizer.token_key(token) + + +def _generated_tokens( + response: Json, normalizer: TokenNormalizer | None = None +) -> list[str]: + choice = _choice(response) + logprobs = choice.get("logprobs") or {} + tokens = logprobs.get("tokens") + if isinstance(tokens, list) and tokens: + return [_normalize_token(token, normalizer) for token in tokens] + token_ids = choice.get("token_ids") + if isinstance(token_ids, list): + return [_normalize_token(token, normalizer) for token in token_ids] + return [] + + +def _prompt_token_ids(response: Json) -> list[int] | None: + token_ids = _choice(response).get("prompt_token_ids") + if not isinstance(token_ids, list): + return None + return [int(token) for token in token_ids] + + +def _token_logprobs(response: Json) -> list[float | None]: + logprobs = _choice(response).get("logprobs") or {} + values = logprobs.get("token_logprobs") + if not isinstance(values, list): + return [] + out: list[float | None] = [] + for value in values: + out.append(float(value) if value is not None else None) + return out + + +def _top_logprobs( + response: Json, normalizer: TokenNormalizer | None = None +) -> list[dict[str, float]]: + logprobs = _choice(response).get("logprobs") or {} + values = logprobs.get("top_logprobs") + if not isinstance(values, list): + return [] + out: list[dict[str, float]] = [] + for step in values: + if not isinstance(step, dict): + out.append({}) + continue + out.append( + { + _normalize_token(token, normalizer): float(logprob) + for token, logprob in step.items() + } + ) + return out + + +def _first_mismatch(oracle_tokens: list[str], actual_tokens: list[str]) -> Json | None: + for step, (oracle_token, actual_token) in enumerate( + zip(oracle_tokens, actual_tokens) + ): + if oracle_token != actual_token: + return {"step": step, "oracle": oracle_token, "actual": actual_token} + if len(oracle_tokens) != len(actual_tokens): + step = min(len(oracle_tokens), len(actual_tokens)) + return { + "step": step, + "oracle": oracle_tokens[step] if step < len(oracle_tokens) else None, + "actual": actual_tokens[step] if step < len(actual_tokens) else None, + } + return None + + +def _matching_prefix(oracle_tokens: list[str], actual_tokens: list[str]) -> int: + count = 0 + for oracle_token, actual_token in zip(oracle_tokens, actual_tokens): + if oracle_token != actual_token: + break + count += 1 + return count + + +def _top_keys(top_logprobs: dict[str, float], top_n: int) -> list[str]: + return list(top_logprobs.keys())[:top_n] + + +def _mean(values: list[float]) -> float | None: + return sum(values) / len(values) if values else None + + +def _max(values: list[float]) -> float | None: + return max(values) if values else None + + +def compare_response( + case_name: str, + oracle_response: Json, + actual_response: Json, + *, + top_n: int = 50, + normalizer: TokenNormalizer | None = None, +) -> Json: + oracle_tokens = _generated_tokens(oracle_response, normalizer) + actual_tokens = _generated_tokens(actual_response, normalizer) + oracle_top = _top_logprobs(oracle_response, normalizer) + actual_top = _top_logprobs(actual_response, normalizer) + oracle_token_logprobs = _token_logprobs(oracle_response) + actual_token_logprobs = _token_logprobs(actual_response) + + steps = min( + len(oracle_tokens), len(actual_tokens), len(oracle_top), len(actual_top) + ) + top1_matches = 0 + topk_overlaps: list[float] = [] + common_logprob_errors: list[float] = [] + chosen_token_logprob_errors: list[float] = [] + + for step in range(steps): + oracle_keys = _top_keys(oracle_top[step], top_n) + actual_keys = _top_keys(actual_top[step], top_n) + if oracle_keys and actual_keys and oracle_keys[0] == actual_keys[0]: + top1_matches += 1 + + oracle_set = set(oracle_keys) + actual_set = set(actual_keys) + if oracle_set: + topk_overlaps.append(len(oracle_set & actual_set) / len(oracle_set)) + for token in oracle_set & actual_set: + common_logprob_errors.append( + abs(oracle_top[step][token] - actual_top[step][token]) + ) + + if ( + oracle_tokens[step] == actual_tokens[step] + and step < len(oracle_token_logprobs) + and step < len(actual_token_logprobs) + and oracle_token_logprobs[step] is not None + and actual_token_logprobs[step] is not None + ): + chosen_token_logprob_errors.append( + abs(oracle_token_logprobs[step] - actual_token_logprobs[step]) + ) + + oracle_prompt_ids = _prompt_token_ids(oracle_response) + actual_prompt_ids = _prompt_token_ids(actual_response) + prompt_ids_match = ( + None + if oracle_prompt_ids is None or actual_prompt_ids is None + else oracle_prompt_ids == actual_prompt_ids + ) + + first_mismatch = _first_mismatch(oracle_tokens, actual_tokens) + return { + "case": case_name, + "tokens_match": first_mismatch is None, + "prompt_token_ids_match": prompt_ids_match, + "first_token_mismatch": first_mismatch, + "matching_prefix_tokens": _matching_prefix(oracle_tokens, actual_tokens), + "oracle_token_count": len(oracle_tokens), + "actual_token_count": len(actual_tokens), + "compared_steps": steps, + "top1_matches": top1_matches, + "top1_match_rate": top1_matches / steps if steps else None, + "topk_overlap_mean": _mean(topk_overlaps), + "topk_overlap_min": min(topk_overlaps) if topk_overlaps else None, + "max_common_logprob_abs_error": _max(common_logprob_errors), + "mean_common_logprob_abs_error": _mean(common_logprob_errors), + "max_chosen_token_logprob_abs_error": _max(chosen_token_logprob_errors), + "mean_chosen_token_logprob_abs_error": _mean(chosen_token_logprob_errors), + } + + +def _load_json(path: Path) -> Json: + with path.open(encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict): + raise ValueError(f"{path} is not a JSON object") + return data + + +def load_oracle_cases(oracle_dir: Path) -> list[tuple[str, Json, Json]]: + cases: list[tuple[str, Json, Json]] = [] + request_paths = sorted(oracle_dir.glob("request_*.json")) + if not request_paths: + raise ValueError(f"{oracle_dir} has no request_*.json files") + for request_path in request_paths: + suffix = request_path.stem.removeprefix("request_") + response_path = oracle_dir / f"response_{suffix}.json" + if not response_path.exists(): + raise ValueError(f"missing {response_path.name} for {request_path.name}") + cases.append((suffix, _load_json(request_path), _load_json(response_path))) + return cases + + +def post_completion(base_url: str, payload: Json, timeout: float) -> Json: + url = f"{base_url.rstrip('/')}/v1/completions" + encoded = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url, + data=encoded, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + body = response.read().decode("utf-8") + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8", errors="replace") + raise RuntimeError(f"HTTP {exc.code} from {url}: {body}") from exc + data = json.loads(body) + if not isinstance(data, dict): + raise ValueError(f"{url} returned non-object JSON") + return data + + +def load_token_normalizer( + tokenizer: str, + *, + tokenizer_mode: str, + trust_remote_code: bool, +) -> TokenNormalizer: + from vllm.tokenizers import get_tokenizer + + hf_tokenizer = get_tokenizer( + tokenizer, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + ) + + def decode_token_id(token_id: int) -> str: + return hf_tokenizer.decode([token_id]) + + return TokenNormalizer(decode_token_id) + + +def summarize_reports(reports: list[Json]) -> Json: + return { + "case_count": len(reports), + "all_tokens_match": all(report["tokens_match"] for report in reports), + "all_prompt_token_ids_match": all( + report["prompt_token_ids_match"] is not False for report in reports + ), + "min_top1_match_rate": min( + ( + report["top1_match_rate"] + for report in reports + if report["top1_match_rate"] is not None + ), + default=None, + ), + "min_topk_overlap_mean": min( + ( + report["topk_overlap_mean"] + for report in reports + if report["topk_overlap_mean"] is not None + ), + default=None, + ), + "max_common_logprob_abs_error": max( + ( + report["max_common_logprob_abs_error"] + for report in reports + if report["max_common_logprob_abs_error"] is not None + ), + default=None, + ), + "max_chosen_token_logprob_abs_error": max( + ( + report["max_chosen_token_logprob_abs_error"] + for report in reports + if report["max_chosen_token_logprob_abs_error"] is not None + ), + default=None, + ), + } + + +def _fails_thresholds( + summary: Json, reports: list[Json], args: argparse.Namespace +) -> bool: + failed = False + if args.strict_tokens and not summary["all_tokens_match"]: + failed = True + if args.strict_prompt_token_ids and not summary["all_prompt_token_ids_match"]: + failed = True + if args.min_top1_match_rate is not None: + value = summary["min_top1_match_rate"] + failed = failed or value is None or value < args.min_top1_match_rate + if args.min_topk_overlap_mean is not None: + value = summary["min_topk_overlap_mean"] + failed = failed or value is None or value < args.min_topk_overlap_mean + if args.max_common_logprob_abs_error is not None: + value = summary["max_common_logprob_abs_error"] + failed = failed or value is None or value > args.max_common_logprob_abs_error + if args.max_chosen_token_logprob_abs_error is not None: + value = summary["max_chosen_token_logprob_abs_error"] + failed = ( + failed or value is None or value > args.max_chosen_token_logprob_abs_error + ) + if args.fail_on_first_mismatch: + failed = failed or any(report["first_token_mismatch"] for report in reports) + return failed + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--oracle-dir", required=True, type=Path) + parser.add_argument("--base-url", default="http://127.0.0.1:8000") + parser.add_argument("--timeout", type=float, default=240.0) + parser.add_argument("--top-n", type=int, default=50) + parser.add_argument("--output", type=Path) + parser.add_argument("--model", help="Override the model field from request_*.json") + parser.add_argument( + "--tokenizer", + help=( + "Tokenizer used to decode token_id: logprob keys before " + "comparing them with text token keys." + ), + ) + parser.add_argument("--tokenizer-mode", default="auto") + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--strict-tokens", action="store_true") + parser.add_argument("--strict-prompt-token-ids", action="store_true") + parser.add_argument("--fail-on-first-mismatch", action="store_true") + parser.add_argument("--min-top1-match-rate", type=float) + parser.add_argument("--min-topk-overlap-mean", type=float) + parser.add_argument("--max-common-logprob-abs-error", type=float) + parser.add_argument("--max-chosen-token-logprob-abs-error", type=float) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + cases = load_oracle_cases(args.oracle_dir) + normalizer = None + if args.tokenizer: + normalizer = load_token_normalizer( + args.tokenizer, + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) + reports: list[Json] = [] + for suffix, request_payload, oracle_response in cases: + payload = dict(request_payload) + if args.model: + payload["model"] = args.model + actual_response = post_completion(args.base_url, payload, args.timeout) + reports.append( + compare_response( + f"request_{suffix}", + oracle_response, + actual_response, + top_n=args.top_n, + normalizer=normalizer, + ) + ) + + summary = summarize_reports(reports) + result: Json = {"summary": summary, "cases": reports} + text = json.dumps(result, indent=2, sort_keys=True) + if args.output: + args.output.write_text(text + "\n", encoding="utf-8") + print(text) + return 1 if _fails_thresholds(summary, reports, args) else 0 + + +if __name__ == "__main__": + try: + raise SystemExit(main()) + except (OSError, RuntimeError, ValueError, json.JSONDecodeError) as exc: + print(f"error: {exc}", file=sys.stderr) + raise SystemExit(2) from exc From e21b9cc63eff9b0c1363ce9a8861815c87aa1dd9 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 02:25:00 +0800 Subject: [PATCH 04/11] Forward DeepSeek V4 MoE clamp limit Co-authored-by: OpenAI Codex Signed-off-by: jasl --- tests/kernels/moe/test_moe.py | 60 +++++++++++++++++++ .../layers/fused_moe/fused_marlin_moe.py | 1 + 2 files changed, 61 insertions(+) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index ebc3256b548f..a5d2c33b47eb 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -29,10 +29,15 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, int4_w4a16_moe_quant_config, int8_w8a16_moe_quant_config, + mxfp4_w4a16_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, batched_fused_marlin_moe, fused_marlin_moe, ) @@ -1007,6 +1012,61 @@ def test_fused_marlin_moe( torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0) +def test_marlin_experts_apply_forwards_gemm1_clamp_limit(monkeypatch): + captured: dict[str, float | None] = {} + + def fake_fused_marlin_moe(**kwargs): + captured["clamp_limit"] = kwargs.get("clamp_limit") + kwargs["output"].zero_() + + monkeypatch.setattr( + "vllm.model_executor.layers.fused_moe.fused_marlin_moe." + "fused_marlin_moe", + fake_fused_marlin_moe, + ) + + clamp_limit = 10.0 + moe_config = FusedMoEConfig( + num_experts=2, + experts_per_token=1, + hidden_dim=16, + intermediate_size_per_partition=8, + num_local_experts=2, + num_logical_experts=2, + activation=MoEActivation.SILU, + device="cpu", + routing_method=RoutingMethodType.Default, + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + in_dtype=torch.bfloat16, + ) + quant_config = mxfp4_w4a16_moe_quant_config( + w1_scale=torch.empty(0), + w2_scale=torch.empty(0), + gemm1_clamp_limit=clamp_limit, + ) + experts = MarlinExperts(moe_config=moe_config, quant_config=quant_config) + + experts.apply( + output=torch.empty((1, 16), dtype=torch.bfloat16), + hidden_states=torch.empty((1, 16), dtype=torch.bfloat16), + w1=torch.empty((2, 1, 1), dtype=torch.int32), + w2=torch.empty((2, 1, 1), dtype=torch.int32), + topk_weights=torch.ones((1, 1), dtype=torch.float32), + topk_ids=torch.zeros((1, 1), dtype=torch.int32), + activation=MoEActivation.SILU, + global_num_experts=2, + expert_map=None, + a1q_scale=None, + a2_scale=None, + workspace13=torch.empty(0), + workspace2=torch.empty(0), + expert_tokens_meta=None, + apply_router_weight_on_input=False, + ) + + assert captured["clamp_limit"] == clamp_limit + + @pytest.mark.flaky(reruns=2) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.parametrize("m", [1, 256]) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 3487ac1766e6..416d871e24f4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -769,6 +769,7 @@ def apply( sort_indices2=self.w2_g_idx_sort_indices, is_k_full=self.is_k_full, input_dtype=self.input_dtype, + clamp_limit=self.gemm1_clamp_limit, ) return From a5ce0d7d068fb86d391dcdc4a7c8fc2d76145566 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 11:14:36 +0800 Subject: [PATCH 05/11] Fix DeepSeek V4 MLA prefix cache reuse Protect hybrid-aligned DeepSeek V4 MLA prompt cache blocks so they survive decode and unrelated long-session cache churn. Keep common-prefix accounting aware of the extra protection reference and cover compressor-state SlidingWindowMLA groups in a regression test. Co-authored-by: OpenAI Codex Signed-off-by: jasl --- tests/v1/core/test_prefix_caching.py | 120 ++++++++++++++++ vllm/v1/core/kv_cache_coordinator.py | 2 + vllm/v1/core/single_type_kv_cache_manager.py | 140 ++++++++++++++++++- 3 files changed, 259 insertions(+), 3 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index c35c38911a1a..0db48d81187a 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -36,6 +36,8 @@ KVCacheConfig, KVCacheGroupSpec, MambaSpec, + MLAAttentionSpec, + SlidingWindowMLASpec, SlidingWindowSpec, ) @@ -2573,6 +2575,124 @@ def test_can_fit_full_sequence_swa_cap_admits_long_prompt(): ) +def test_deepseek_v4_mla_keeps_hybrid_aligned_prompt_blocks_after_decode(): + hash_block_size = 2 + full_block_size = 8 + swa_block_size = 2 + prompt_tokens = 35 + chunk_tokens = 4 * full_block_size + expected_hit_tokens = ( + (prompt_tokens - 1) // full_block_size * full_block_size + ) + + config = KVCacheConfig( + num_blocks=70, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=full_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_0"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_1"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_compressor_state"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=2 * swa_block_size, + ), + ), + ], + ) + manager = KVCacheManager( + config, + max_model_len=128, + max_num_batched_tokens=chunk_tokens, + enable_caching=True, + hash_block_size=hash_block_size, + ) + + def run_request(request: Request, num_decode_tokens: int) -> int: + computed_blocks, num_computed_tokens = manager.get_computed_blocks(request) + computed_so_far = num_computed_tokens + remaining_prompt_tokens = request.num_prompt_tokens - num_computed_tokens + first_chunk = True + while remaining_prompt_tokens > 0: + num_new_tokens = min(chunk_tokens, remaining_prompt_tokens) + allocated = manager.allocate_slots( + request, + num_new_tokens, + num_computed_tokens if first_chunk else 0, + computed_blocks if first_chunk else None, + ) + assert allocated is not None + computed_so_far += num_new_tokens + request.num_computed_tokens = computed_so_far + remaining_prompt_tokens -= num_new_tokens + first_chunk = False + + for i in range(num_decode_tokens): + request.append_output_token_ids(10_000 + i) + allocated = manager.allocate_slots(request, 1) + assert allocated is not None + computed_so_far += 1 + request.num_computed_tokens = computed_so_far + return num_computed_tokens + + prompt_a = list(range(prompt_tokens)) + req_a = make_request("a", prompt_a, hash_block_size, sha256) + assert run_request(req_a, num_decode_tokens=0) == 0 + manager.free(req_a) + + warm_a = make_request("warm_a", prompt_a, hash_block_size, sha256) + assert run_request(warm_a, num_decode_tokens=8) == expected_hit_tokens + assert manager.get_num_common_prefix_blocks("warm_a")[0] >= ( + expected_hit_tokens // full_block_size + ) + manager.free(warm_a) + + pressure_blocks = manager.block_pool.get_new_blocks( + manager.block_pool.get_num_free_blocks() + ) + manager.block_pool.free_blocks(reversed(pressure_blocks)) + + req_a_again = make_request("a_again", prompt_a, hash_block_size, sha256) + _, num_computed_tokens = manager.get_computed_blocks(req_a_again) + assert num_computed_tokens == expected_hit_tokens + + def test_can_fit_full_sequence_full_attention_still_gates_oversized(): """The cap only loosens the SWA group; a prompt that exceeds the full-attention pool capacity must still be rejected.""" diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 65993e804153..ac88957afe11 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -475,6 +475,8 @@ def verify_and_split_kv_cache_groups(self) -> None: # block cache hit yet. block_sizes = [spec.block_size for spec, _, _ in attention_groups] self.lcm_block_size = lcm(*block_sizes) + for manager in self.single_type_managers: + manager.cache_alignment_tokens = self.lcm_block_size # Attention-group indices (into ``self.attention_groups``) that # contain at least one EAGLE/MTP KV cache group. diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e8d3a6f75688..63a6b4f1f267 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from abc import ABC, abstractmethod -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Sequence from vllm.utils.math_utils import cdiv @@ -42,6 +42,7 @@ def __init__( dcp_world_size: int = 1, pcp_world_size: int = 1, max_admission_blocks_per_request: int | None = None, + max_model_len: int | None = None, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -65,6 +66,8 @@ def __init__( self.block_pool = block_pool self.enable_caching = enable_caching self._max_admission_blocks_per_request = max_admission_blocks_per_request + self.max_model_len = max_model_len + self.cache_alignment_tokens = self.block_size self.new_block_ids: list[int] = [] # Mapping from request ID to blocks to track the blocks allocated @@ -80,6 +83,8 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + self._protected_prompt_block_ids: set[int] = set() + self._protected_prompt_block_queue: deque[int] = deque() @classmethod def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]): @@ -274,6 +279,48 @@ def take_new_block_ids(self) -> list[int]: self.new_block_ids = [] return ids + def _max_protected_prompt_blocks(self) -> int | None: + if self.max_model_len is None: + return None + return 2 * cdiv(max(1, self.max_model_len), self.block_size) + + def _protect_prompt_blocks(self, blocks: Sequence[KVCacheBlock]) -> None: + if not self.enable_caching: + return + + protected: list[KVCacheBlock] = [] + for block in blocks: + if ( + block.is_null + or block.block_hash is None + or block.block_id in self._protected_prompt_block_ids + ): + continue + protected.append(block) + self._protected_prompt_block_ids.add(block.block_id) + self._protected_prompt_block_queue.append(block.block_id) + + if not protected: + return + + # Keep an extra reference for prompt blocks that must survive after + # their request releases its normal runtime reference. Later request + # reuse increments/decrements the runtime reference as usual. + self.block_pool.touch(protected) + self._trim_protected_prompt_blocks() + + def _trim_protected_prompt_blocks(self) -> None: + max_blocks = self._max_protected_prompt_blocks() + if max_blocks is None: + return + + while len(self._protected_prompt_block_ids) > max_blocks: + block_id = self._protected_prompt_block_queue.popleft() + if block_id not in self._protected_prompt_block_ids: + continue + self._protected_prompt_block_ids.remove(block_id) + self.block_pool.free_blocks([self.block_pool.blocks[block_id]]) + def cache_blocks(self, request: Request, num_tokens: int) -> None: """ Cache the blocks for the request. @@ -504,6 +551,54 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return num_common_blocks +class MLAAttentionManager(FullAttentionManager): + """KV cache manager for DeepSeek V4 compressed MLA cache.""" + + def _should_protect_prompt_blocks(self) -> bool: + return ( + self.kv_cache_spec.model_version == "deepseek_v4" + or self.kv_cache_spec.cache_dtype_str == "fp8_ds_mla" + or self.kv_cache_spec.compress_ratio > 1 + ) + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + super().cache_blocks(request, num_tokens) + if ( + not self._should_protect_prompt_blocks() + or num_tokens < request.num_prompt_tokens + or request.num_prompt_tokens <= 1 + ): + return + + max_cache_hit_length = request.num_prompt_tokens - 1 + aligned_cache_hit_length = ( + max_cache_hit_length + // self.cache_alignment_tokens + * self.cache_alignment_tokens + ) + num_hit_blocks = aligned_cache_hit_length // self.block_size + if num_hit_blocks == 0: + return + + self._protect_prompt_blocks( + self.req_to_blocks[request.request_id][:num_hit_blocks] + ) + + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: + blocks = self.req_to_blocks[running_request_id] + num_common_blocks = 0 + expected_ref_cnt = len(self.req_to_blocks) + for block in blocks: + ref_cnt = block.ref_cnt + if block.block_id in self._protected_prompt_block_ids: + ref_cnt -= 1 + if ref_cnt == expected_ref_cnt: + num_common_blocks += 1 + else: + break + return num_common_blocks + + class SlidingWindowManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) @@ -641,6 +736,42 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return 0 +class SlidingWindowMLAManager(SlidingWindowManager): + """KV cache manager for DeepSeek V4's sliding-window MLA cache. + + During decode, the live sliding window can move past the prompt boundary. + The blocks around the hybrid-aligned prompt boundary are still the suffix + needed for a future prefix-cache hit of the same prompt. + """ + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + super().cache_blocks(request, num_tokens) + if not self.enable_caching or num_tokens < request.num_prompt_tokens: + return + if request.num_prompt_tokens <= 1: + return + + max_cache_hit_length = request.num_prompt_tokens - 1 + aligned_cache_hit_length = ( + max_cache_hit_length + // self.cache_alignment_tokens + * self.cache_alignment_tokens + ) + if aligned_cache_hit_length <= 0: + return + + aligned_num_hit_blocks = aligned_cache_hit_length // self.block_size + last_full_prompt_block = max_cache_hit_length // self.block_size + contiguous_blocks = cdiv(self.sliding_window - 1, self.block_size) + first_protected_block = max(0, aligned_num_hit_blocks - contiguous_blocks) + last_protected_block = max(aligned_num_hit_blocks, last_full_prompt_block) + blocks = self.req_to_blocks[request.request_id] + protected_blocks = blocks[ + first_protected_block : min(last_protected_block, len(blocks)) + ] + self._protect_prompt_blocks(protected_blocks) + + class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) @@ -1124,6 +1255,7 @@ def __init__( kv_cache_group_id: int, dcp_world_size: int = 1, pcp_world_size: int = 1, + max_model_len: int | None = None, ): super().__init__( kv_cache_spec, @@ -1132,6 +1264,7 @@ def __init__( kv_cache_group_id, dcp_world_size, pcp_world_size, + max_model_len=max_model_len, ) sink_len = kv_cache_spec.sink_len assert sink_len is not None and sink_len > 0 and sink_len % self.block_size == 0 @@ -1142,9 +1275,9 @@ def __init__( spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, TQFullAttentionSpec: FullAttentionManager, - MLAAttentionSpec: FullAttentionManager, + MLAAttentionSpec: MLAAttentionManager, SlidingWindowSpec: SlidingWindowManager, - SlidingWindowMLASpec: SlidingWindowManager, + SlidingWindowMLASpec: SlidingWindowMLAManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, CrossAttentionSpec: CrossAttentionManager, @@ -1159,6 +1292,7 @@ def get_manager_for_kv_cache_spec( **kwargs, ) -> SingleTypeKVCacheManager: manager_class = spec_manager_map[type(kv_cache_spec)] + kwargs["max_model_len"] = max_model_len # SlidingWindow / ChunkedLocalAttention managers recycle blocks across # chunks; the runtime admission cap must match the recycling-aware bound # the startup pool sizer uses (single source of truth: the spec method). From 1d6f5c4ebcd11d1150047132c2ffc4f0e68e6b8a Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 15:22:45 +0800 Subject: [PATCH 06/11] Reserve DeepSeek V4 prefill workspace during profiling Co-authored-by: OpenAI Codex Signed-off-by: jasl --- .../test_deepseek_v4_sparse_mla_reference.py | 95 +++++++++++++++++++ .../layers/deepseek_v4_attention.py | 85 +++++++++++++++++ 2 files changed, 180 insertions(+) diff --git a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py index f2e43c59ce61..5a8c045f9f9e 100644 --- a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py +++ b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py @@ -144,6 +144,101 @@ def get_simultaneous(self, *shapes_and_dtypes): assert output.dtype == torch.bfloat16 +def test_dummy_attention_impl_reserves_prefill_workspace(monkeypatch) -> None: + class FakeMLAAttn: + def __init__(self) -> None: + self.reserved = False + + def _reserve_prefill_workspace(self) -> None: + self.reserved = True + + def __call__(self, *args, **kwargs) -> None: + raise AssertionError("dummy run must not execute real attention") + + mla_attn = FakeMLAAttn() + layer = object.__new__( + deepseek_v4_attention_module.DeepseekV4MultiHeadLatentAttentionWrapper + ) + layer.q_lora_rank = 2 + layer.head_dim = 4 + layer.n_local_heads = 2 + layer.padded_heads = 64 + layer.indexer = None + layer.compressor = None + layer.wq_b = lambda qr: torch.ones(qr.shape[0], 8) + layer.q_norm = SimpleNamespace(weight=SimpleNamespace(data=torch.empty(0))) + layer.kv_norm = SimpleNamespace(weight=SimpleNamespace(data=torch.empty(0))) + layer.eps = 1e-6 + layer.mla_attn = mla_attn + layer.attn_gemm_parallel_execute = lambda hidden_states: ( + torch.zeros(hidden_states.shape[0], 6), + None, + None, + None, + ) + layer._fused_qnorm_rope_kv_insert = lambda *args, **kwargs: None + + monkeypatch.setattr( + deepseek_v4_attention_module, + "get_forward_context", + lambda: SimpleNamespace(attn_metadata=None), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "fused_q_kv_rmsnorm", + lambda qr, kv, *args, **kwargs: (qr, kv), + ) + + out = torch.ones((3, 64, 4)) + layer.attention_impl( + hidden_states=torch.zeros((3, 6)), + positions=torch.arange(3), + out=out, + ) + + assert mla_attn.reserved is True + assert torch.count_nonzero(out) == 0 + + +def test_prefill_workspace_reservation_specs_match_forward_prefill_bounds( + monkeypatch, +) -> None: + attn = SimpleNamespace( + max_model_len=16_384, + max_num_batched_tokens=8192, + compress_ratio=4, + window_size=128, + head_dim=512, + num_heads=64, + topk_indices_buffer=torch.empty((8192, 2048), dtype=torch.int32), + indexer=None, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "is_triton_sparse_mla_enabled_for_platform", + lambda: True, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "triton_sparse_mla_query_chunk_size", + lambda: 256, + ) + + specs = ( + deepseek_v4_attention_module.DeepseekV4MLAAttention. + _prefill_workspace_reservation_specs(attn) + ) + + assert specs == ( + ((4, 12_415, 512), torch.bfloat16), + ((8192, 2176), torch.int32), + ((8192,), torch.int32), + ((256, 64), torch.float32), + ((256, 64), torch.float32), + ((256, 64, 512), torch.float32), + ) + + def test_triton_sparse_mla_default_topk_chunk_size(monkeypatch) -> None: monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 446935c7c739..613a765019b9 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -84,6 +84,7 @@ from vllm.v1.attention.backends.mla.sparse_mla_env import ( disable_triton_sparse_mla_cudagraphs_if_enabled, is_triton_sparse_mla_enabled, + is_triton_sparse_mla_enabled_for_platform, triton_sparse_mla_matmul_decode_enabled, triton_sparse_mla_query_chunk_size, triton_sparse_mla_topk_chunk_size, @@ -128,6 +129,21 @@ def _sparse_mla_prefill_workspace_bounds( return compressed_region_size, compressed_region_size + max_gather_len +def _sparse_mla_prefill_gather_len_upper_bound( + *, + max_model_len: int, + max_num_batched_tokens: int, + window_size: int, +) -> tuple[int, int]: + max_query_chunk_tokens = max(1, min(max_model_len, max_num_batched_tokens)) + max_prefix_len = max(max_model_len - max_query_chunk_tokens, 0) + max_gather_len = max_query_chunk_tokens + min( + max_prefix_len, + max(window_size - 1, 0), + ) + return max_query_chunk_tokens, max_gather_len + + def _deepseek_v4_fp8_einsum_config( capability_major: int, ) -> tuple[tuple[int, int, int], bool]: @@ -176,6 +192,7 @@ def _allocate_deepseek_v4_wo_a_output( # workspace allocated at _forward_prefill (and the matching profile-time # reservation in attention_impl's dummy-run branch). PREFILL_CHUNK_SIZE = 4 +_DEFAULT_SPARSE_MLA_TOPK_TOKENS = 2048 @dataclass @@ -589,6 +606,7 @@ def wq_b_kv_insert() -> torch.Tensor: # Handle dummy run (no metadata). if not isinstance(attn_metadata, dict): out.zero_() + self.mla_attn._reserve_prefill_workspace() return # Pad q to FlashMLA-required head count (64 or 128) @@ -870,6 +888,73 @@ def __init__( self.kv_cache = torch.tensor([]) + def _prefill_workspace_topk_bound(self) -> int: + if self.compress_ratio <= 1: + return 0 + if ( + self.topk_indices_buffer is not None + and self.topk_indices_buffer.ndim > 0 + and self.topk_indices_buffer.shape[-1] > 0 + ): + return int(self.topk_indices_buffer.shape[-1]) + indexer_topk = getattr(self.indexer, "topk_tokens", None) + if indexer_topk is not None: + return int(indexer_topk) + return _DEFAULT_SPARSE_MLA_TOPK_TOKENS + + def _prefill_workspace_reservation_specs( + self, + ) -> tuple[tuple[tuple[int, ...], torch.dtype], ...]: + max_model_len = max(1, int(self.max_model_len)) + max_num_batched_tokens = max(1, int(self.max_num_batched_tokens)) + window_size = max(1, int(self.window_size)) + compress_ratio = max(1, int(self.compress_ratio)) + head_dim = int(self.head_dim) + num_heads = int(self.num_heads) + + max_query_chunk_tokens, max_gather_len = ( + _sparse_mla_prefill_gather_len_upper_bound( + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + window_size=window_size, + ) + ) + if compress_ratio <= 1: + m_bound = max_gather_len + else: + compressed_region_size = max_model_len // compress_ratio + m_bound = compressed_region_size + max_gather_len + + combined_topk = sparse_prefill_combined_topk_size( + DeepseekV4MLAAttention._prefill_workspace_topk_bound(self), + window_size, + ) + specs: list[tuple[tuple[int, ...], torch.dtype]] = [ + ((PREFILL_CHUNK_SIZE, m_bound, head_dim), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ] + if is_triton_sparse_mla_enabled_for_platform(): + query_chunk_size = min( + max_query_chunk_tokens, + triton_sparse_mla_query_chunk_size(), + ) + specs.extend( + [ + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads, head_dim), torch.float32), + ] + ) + return tuple(specs) + + def _reserve_prefill_workspace(self) -> None: + try: + workspace_manager = current_workspace_manager() + except AssertionError: + return + workspace_manager.get_simultaneous(*self._prefill_workspace_reservation_specs()) + def get_attn_backend(self) -> type[AttentionBackend]: return DeepseekV4FlashMLASparseBackend From e734ace5ff5a71a682d357cee0fff5b9b826e909 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 18:50:37 +0800 Subject: [PATCH 07/11] Release DeepSeek V4 protected prompt refs under pressure Co-authored-by: OpenAI Codex Signed-off-by: jasl --- tests/v1/core/test_prefix_caching.py | 91 ++++++++++++++++++++ vllm/v1/core/kv_cache_coordinator.py | 11 +++ vllm/v1/core/kv_cache_manager.py | 11 ++- vllm/v1/core/single_type_kv_cache_manager.py | 24 +++++- 4 files changed, 134 insertions(+), 3 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 0db48d81187a..c9f9fdd58fe2 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -2693,6 +2693,97 @@ def run_request(request: Request, num_decode_tokens: int) -> int: assert num_computed_tokens == expected_hit_tokens +def test_deepseek_v4_mla_protected_prompt_blocks_do_not_block_admission(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + protected_blocks_per_prompt = (prompt_tokens - 1) // block_size + num_prompts = 10 + num_blocks = 80 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + mla_manager = manager.coordinator.single_type_managers[0] + + for i in range(num_prompts): + prompt = list(range(i * 1000, i * 1000 + prompt_tokens)) + req = make_request(f"protected_{i}", prompt, block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert len(mla_manager._protected_prompt_block_ids) == ( + num_prompts * protected_blocks_per_prompt + ) + assert manager.block_pool.get_num_free_blocks() < 64 + + long_req = make_request( + "long", + list(range(100_000, 100_000 + 64 * block_size)), + block_size, + sha256, + ) + assert ( + manager.allocate_slots(long_req, block_size, full_sequence_must_fit=True) + is not None + ) + + +def test_reset_prefix_cache_releases_deepseek_v4_mla_protected_blocks(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + + req = make_request("protected", list(range(prompt_tokens)), block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert manager.coordinator.single_type_managers[0]._protected_prompt_block_ids + assert manager.reset_prefix_cache() + + def test_can_fit_full_sequence_full_attention_still_gates_oversized(): """The cap only loosens the SWA group; a prompt that exceeds the full-attention pool capacity must still be rejected.""" diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index ac88957afe11..3f2ac7cdedb7 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -250,6 +250,17 @@ def remove_skipped_blocks( for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, total_computed_tokens) + def release_protected_prompt_blocks( + self, target_free_blocks: int | None = None + ) -> None: + for manager in self.single_type_managers: + if ( + target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + manager.release_protected_prompt_blocks(target_free_blocks) + def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: """ Get the blocks for the request. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 431776870cf4..2f3c5abf3066 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -345,7 +345,7 @@ def allocate_slots( num_tokens_main_model=full_num_tokens, apply_admission_cap=True, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks(num_blocks_to_allocate): return None num_tokens_main_model = total_computed_tokens + num_new_tokens @@ -373,7 +373,7 @@ def allocate_slots( num_tokens_main_model=num_tokens_main_model, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks(num_blocks_to_allocate): # Cannot allocate new blocks return None @@ -446,6 +446,12 @@ def evict_blocks(self, block_ids: set[int]) -> None: """ self.block_pool.evict_blocks(block_ids) + def _has_enough_free_blocks(self, num_blocks: int) -> bool: + if num_blocks <= self.block_pool.get_num_free_blocks(): + return True + self.coordinator.release_protected_prompt_blocks(num_blocks) + return num_blocks <= self.block_pool.get_num_free_blocks() + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalidate prefix caching after the weights are updated, @@ -455,6 +461,7 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. """ + self.coordinator.release_protected_prompt_blocks() if not self.block_pool.reset_prefix_cache(): return False if self.log_stats: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 63a6b4f1f267..651f52f6b19d 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -315,11 +315,33 @@ def _trim_protected_prompt_blocks(self) -> None: return while len(self._protected_prompt_block_ids) > max_blocks: + if not self._release_one_protected_prompt_block(): + return + + def _release_one_protected_prompt_block(self) -> bool: + while self._protected_prompt_block_queue: block_id = self._protected_prompt_block_queue.popleft() if block_id not in self._protected_prompt_block_ids: continue + self._protected_prompt_block_ids.remove(block_id) - self.block_pool.free_blocks([self.block_pool.blocks[block_id]]) + block = self.block_pool.blocks[block_id] + if block.ref_cnt > 0: + self.block_pool.free_blocks([block]) + return True + return False + + def release_protected_prompt_blocks( + self, target_free_blocks: int | None = None + ) -> None: + while self._protected_prompt_block_ids: + if ( + target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + if not self._release_one_protected_prompt_block(): + return def cache_blocks(self, request: Request, num_tokens: int) -> None: """ From bd9fde75ea4ece787859256c14e06467d1298e16 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 23:39:35 +0800 Subject: [PATCH 08/11] Rename DeepSeek V4 SM12x FP8 einsum path Signed-off-by: jasl --- .../test_deepseek_v4_sparse_mla_reference.py | 27 +++++++++++-------- .../layers/deepseek_v4_attention.py | 8 +++--- .../ops/deepseek_v4_ops/fp8_einsum.py | 12 ++++----- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py index 5a8c045f9f9e..5a809cd09890 100644 --- a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py +++ b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py @@ -64,8 +64,9 @@ dequantize_combined_sparse_mla_decode_kv, dequantize_global_slots_k_cache, ) +from vllm.v1.attention.ops.deepseek_v4_ops import fp8_einsum as fp8_einsum_module from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import ( - deepseek_v4_sm12_fp8_einsum, + deepseek_v4_sm12x_fp8_einsum, ) from vllm.v1.kv_cache_interface import MLAAttentionSpec, SlidingWindowMLASpec @@ -224,10 +225,8 @@ def test_prefill_workspace_reservation_specs_match_forward_prefill_bounds( lambda: 256, ) - specs = ( - deepseek_v4_attention_module.DeepseekV4MLAAttention. - _prefill_workspace_reservation_specs(attn) - ) + attention_cls = deepseek_v4_attention_module.DeepseekV4MLAAttention + specs = attention_cls._prefill_workspace_reservation_specs(attn) assert specs == ( ((4, 12_415, 512), torch.bfloat16), @@ -487,9 +486,15 @@ def test_deepseek_v4_fp8_einsum_config_for_sm12x( ) +def test_deepseek_v4_fp8_einsum_uses_sm12x_names() -> None: + assert hasattr(fp8_einsum_module, "_deepseek_v4_sm12x_fp8_einsum_kernel") + assert hasattr(fp8_einsum_module, "deepseek_v4_sm12x_fp8_einsum") + assert not hasattr(fp8_einsum_module, "_deepseek_v4_sm12_fp8_einsum_kernel") + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") @pytest.mark.parametrize("use_e8m0_scale", [False, True]) -def test_deepseek_v4_sm12_triton_fp8_einsum_matches_deepgemm_reference( +def test_deepseek_v4_sm12x_triton_fp8_einsum_matches_deepgemm_reference( use_e8m0_scale: bool, ) -> None: if use_e8m0_scale and not hasattr(torch, "float8_e8m0fnu"): @@ -602,7 +607,7 @@ def fake_sm12_fp8_einsum( ) monkeypatch.setattr( deepseek_v4_attention_module, - "deepseek_v4_sm12_fp8_einsum", + "deepseek_v4_sm12x_fp8_einsum", fake_sm12_fp8_einsum, ) @@ -644,7 +649,7 @@ def fake_sm12_fp8_einsum( @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") -def test_deepseek_v4_sm12_triton_fp8_einsum_primitive_matches_reference() -> None: +def test_deepseek_v4_sm12x_triton_fp8_einsum_primitive_matches_reference() -> None: torch.manual_seed(0) num_tokens = 17 num_groups = 4 @@ -684,14 +689,14 @@ def test_deepseek_v4_sm12_triton_fp8_einsum_primitive_matches_reference() -> Non actual = torch.empty_like(expected) fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe) - deepseek_v4_sm12_fp8_einsum(a, a_scale, b, b_scale, actual) + deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, actual) _assert_fp8_einsum_close(actual, expected) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") @pytest.mark.parametrize("num_groups", [1, 2, 4]) -def test_deepseek_v4_sm12_triton_fp8_einsum_supports_tp_local_group_counts( +def test_deepseek_v4_sm12x_triton_fp8_einsum_supports_tp_local_group_counts( num_groups: int, ) -> None: torch.manual_seed(18 + num_groups) @@ -732,7 +737,7 @@ def test_deepseek_v4_sm12_triton_fp8_einsum_supports_tp_local_group_counts( actual = torch.empty_like(expected) fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe) - deepseek_v4_sm12_fp8_einsum(a, a_scale, b, b_scale, actual) + deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, actual) _assert_fp8_einsum_close(actual, expected) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 613a765019b9..ba520ce476a0 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -32,7 +32,7 @@ sparse_prefill_combined_topk_size, ) from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import ( - deepseek_v4_sm12_fp8_einsum, + deepseek_v4_sm12x_fp8_einsum, ) from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( rocm_forward_decode_fallback, @@ -152,7 +152,7 @@ def _deepseek_v4_fp8_einsum_config( return (1, 128, 128), False -def _use_deepseek_v4_sm12_triton_fp8_einsum( +def _use_deepseek_v4_sm12x_triton_fp8_einsum( equation: str, recipe: list[int], b_scale: torch.Tensor, @@ -750,8 +750,8 @@ def deepseek_v4_fp8_einsum( if b_groups != num_groups: b_scale = b_scale.narrow(0, group_start, num_groups) - if _use_deepseek_v4_sm12_triton_fp8_einsum(equation, recipe, b_scale): - deepseek_v4_sm12_fp8_einsum(a, a_scale, b, b_scale, out) + if _use_deepseek_v4_sm12x_triton_fp8_einsum(equation, recipe, b_scale): + deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, out) return fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py index 71a6199e1d9d..652dc7be7907 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py @@ -14,7 +14,7 @@ def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: @triton.jit -def _deepseek_v4_sm12_fp8_einsum_kernel( +def _deepseek_v4_sm12x_fp8_einsum_kernel( a_ptr, a_scale_ptr, b_ptr, @@ -68,8 +68,7 @@ def _deepseek_v4_sm12_fp8_einsum_kernel( + group * b_stride_group + out_offsets[None, :] * b_stride_out + hidden[:, None] * b_stride_hidden, - mask=(out_offsets[None, :] < out_rank) - & (hidden[:, None] < hidden_size), + mask=(out_offsets[None, :] < out_rank) & (hidden[:, None] < hidden_size), other=0.0, ) raw = tl.dot(a, b, out_dtype=tl.float32) @@ -98,12 +97,11 @@ def _deepseek_v4_sm12_fp8_einsum_kernel( + group * out_stride_group + out_offsets[None, :] * out_stride_rank, accum, - mask=(token_offsets[:, None] < num_tokens) - & (out_offsets[None, :] < out_rank), + mask=(token_offsets[:, None] < num_tokens) & (out_offsets[None, :] < out_rank), ) -def deepseek_v4_sm12_fp8_einsum( +def deepseek_v4_sm12x_fp8_einsum( a: torch.Tensor, a_scale: torch.Tensor, b: torch.Tensor, @@ -144,7 +142,7 @@ def deepseek_v4_sm12_fp8_einsum( triton.cdiv(out_rank, block_out), num_groups, ) - _deepseek_v4_sm12_fp8_einsum_kernel[grid]( + _deepseek_v4_sm12x_fp8_einsum_kernel[grid]( a, a_scale, b, From 0e186bc170bde9b9b4378385d47de836a9373ada Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 02:16:58 +0800 Subject: [PATCH 09/11] Add dense sparse MLA decode golden tests Signed-off-by: jasl --- .../test_deepseek_v4_sparse_mla_reference.py | 281 ++++++++++++++++++ 1 file changed, 281 insertions(+) diff --git a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py index 5a809cd09890..dbdc6f4e689c 100644 --- a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py +++ b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py @@ -874,6 +874,64 @@ def _write_fp8_ds_mla_token( return torch.cat([expected_nope, rope.float()]).to(torch.bfloat16) +def _materialize_global_fp8_ds_mla_slots( + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + block_size: int, +) -> torch.Tensor: + gathered = torch.zeros( + *slot_ids.shape, + 512, + dtype=torch.bfloat16, + device=k_cache.device, + ) + for token_idx, row in enumerate(slot_ids.detach().cpu().tolist()): + for candidate_idx, slot in enumerate(row): + if slot >= 0: + gathered[token_idx, candidate_idx] = _write_fp8_ds_mla_token( + k_cache, + slot, + block_size, + ) + return gathered + + +def _materialize_paged_fp8_ds_mla_window( + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, +) -> torch.Tensor: + seq_lens_cpu = seq_lens.detach().cpu().tolist() + gather_lens_cpu = gather_lens.detach().cpu().tolist() + block_table_cpu = block_table.detach().cpu().tolist() + max_gather_len = max(gather_lens_cpu) + gathered = torch.zeros( + len(seq_lens_cpu), + max_gather_len, + 512, + dtype=torch.bfloat16, + device=k_cache.device, + ) + for token_idx, (seq_len, gather_len) in enumerate( + zip(seq_lens_cpu, gather_lens_cpu) + ): + start_pos = seq_len - gather_len + for gather_idx in range(gather_len): + logical_pos = start_pos + gather_idx + logical_block = logical_pos // block_size + block_offset = logical_pos % block_size + physical_block = block_table_cpu[token_idx][logical_block] + physical_slot = physical_block * block_size + block_offset + gathered[token_idx, gather_idx] = _write_fp8_ds_mla_token( + k_cache, + physical_slot, + block_size, + ) + return gathered + + def test_reference_attention_no_sink_matches_logsumexp() -> None: torch.manual_seed(0) scale = 0.25 @@ -2096,6 +2154,105 @@ def test_triton_fp8ds_global_paged_attention_with_sink_direct_matches_state_path torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_global_paged_decode_matches_dense_golden_long_offsets() -> None: + torch.manual_seed(71) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 8, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 8, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + compressed_slot_ids = torch.tensor( + [ + [0, 5, -1, 11, 17, 2], + [23, -1, 8, 0, 19, 31], + [4, -1, -1, 6, 7, 12], + ], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([5, 6, 0], dtype=torch.int32, device="cuda") + seq_lens = torch.tensor([17, 30, 7], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([5, 6, 4], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [ + [4, 0, 6, 1, 3, 5, 2, 7], + [7, 2, 4, 0, 6, 1, 5, 3], + [1, 3, 0, 2, 4, 5, 6, 7], + ], + dtype=torch.int32, + device="cuda", + ) + q = torch.randn(3, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + active_heads = 5 + sink = torch.linspace(-0.75, 0.75, active_heads, device="cuda") + scale = 0.0625 + + compressed_kv = _materialize_global_fp8_ds_mla_slots( + compressed_cache, + compressed_slot_ids, + compressed_block_size, + ) + swa_kv = _materialize_paged_fp8_ds_mla_window( + swa_cache, + seq_lens, + gather_lens, + block_table, + swa_block_size, + ) + compressed_offsets = torch.arange( + compressed_slot_ids.shape[1], + device="cuda", + dtype=torch.int32, + ) + swa_offsets = torch.arange(swa_kv.shape[1], device="cuda", dtype=torch.int32) + compressed_valid = (compressed_offsets[None, :] < topk_lens[:, None]) & ( + compressed_slot_ids >= 0 + ) + swa_valid = swa_offsets[None, :] < gather_lens[:, None] + expected = _golden_sink_attention( + q[:, 0, :active_heads], + torch.cat([compressed_kv, swa_kv], dim=1), + torch.cat([compressed_valid, swa_valid], dim=1), + scale, + sink, + ).to(torch.bfloat16) + + actual = torch.empty(3, active_heads, 512, device="cuda", dtype=torch.bfloat16) + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_cache, + slot_ids=compressed_slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + swa_block_size=swa_block_size, + num_compressed_candidates=compressed_slot_ids.shape[1], + num_swa_candidates=swa_kv.shape[1], + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=2, + num_heads=active_heads, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") def test_matmul_sparse_mla_attention_with_sink_matches_reference() -> None: torch.manual_seed(41) @@ -2136,6 +2293,130 @@ def test_matmul_sparse_mla_attention_with_sink_matches_reference() -> None: torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_mtp_matmul_global_slots_decode_matches_dense_golden_long_offsets() -> None: + torch.manual_seed(73) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 8, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 8, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + compressed_slot_ids = torch.tensor( + [ + [0, 5, -1, 11, 17, 2], + [23, -1, 8, 0, 19, 31], + [4, -1, -1, 6, 7, 12], + ], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([5, 6, 0], dtype=torch.int32, device="cuda") + swa_slot_ids = torch.tensor( + [ + [3, 4, 13, 14, 15], + [28, 29, 30, 31, -1], + [7, 6, 5, -1, -1], + ], + dtype=torch.int32, + device="cuda", + ) + swa_lens = torch.tensor([5, 4, 3], dtype=torch.int32, device="cuda") + q = torch.randn(3, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + active_heads = 5 + sink = torch.linspace(-0.5, 0.5, active_heads, device="cuda") + scale = 0.0625 + + compressed_kv = _materialize_global_fp8_ds_mla_slots( + compressed_cache, + compressed_slot_ids, + compressed_block_size, + ) + swa_kv = _materialize_global_fp8_ds_mla_slots( + swa_cache, + swa_slot_ids, + swa_block_size, + ) + combined_kv = torch.empty( + 3, + compressed_slot_ids.shape[1] + swa_slot_ids.shape[1], + 512, + dtype=torch.bfloat16, + device="cuda", + ) + dequantize_global_slots_k_cache( + combined_kv[:, : compressed_slot_ids.shape[1]], + compressed_cache, + compressed_slot_ids, + compressed_block_size, + ) + dequantize_global_slots_k_cache( + combined_kv[:, compressed_slot_ids.shape[1] :], + swa_cache, + swa_slot_ids, + swa_block_size, + ) + valid_tokens = torch.empty( + combined_kv.shape[:2], + dtype=torch.bool, + device="cuda", + ) + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + + compressed_offsets = torch.arange( + compressed_slot_ids.shape[1], + device="cuda", + dtype=torch.int32, + ) + swa_offsets = torch.arange(swa_slot_ids.shape[1], device="cuda", dtype=torch.int32) + compressed_valid = (compressed_offsets[None, :] < topk_lens[:, None]) & ( + compressed_slot_ids >= 0 + ) + swa_valid = swa_offsets[None, :] < swa_lens[:, None] + expected_kv = torch.cat([compressed_kv, swa_kv], dim=1) + expected_valid = torch.cat([compressed_valid, swa_valid], dim=1) + expected = _golden_sink_attention( + q[:, 0, :active_heads], + expected_kv, + expected_valid, + scale, + sink, + ).to(torch.bfloat16) + + torch.testing.assert_close(combined_kv.float(), expected_kv.float(), rtol=0, atol=0) + torch.testing.assert_close(valid_tokens, expected_valid) + + actual = torch.empty_like(expected) + matmul_sparse_mla_attention_with_sink( + q, + combined_kv, + valid_tokens, + scale, + sink, + actual, + num_heads=active_heads, + value_block_size=512, + candidate_block_size=128, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") def test_matmul_sparse_mla_attention_accepts_bf16_score_buffer() -> None: torch.manual_seed(67) From f8950c33fdf7995834968c1b47467c978fc10812 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 05:22:01 +0800 Subject: [PATCH 10/11] Clarify Ray cluster device warning Signed-off-by: jasl --- tests/v1/executor/test_ray_utils.py | 46 +++++++++++++++++++++++++++++ vllm/v1/executor/ray_utils.py | 45 ++++++++++++---------------- 2 files changed, 65 insertions(+), 26 deletions(-) diff --git a/tests/v1/executor/test_ray_utils.py b/tests/v1/executor/test_ray_utils.py index 8da9d5459e73..a83b513baca8 100644 --- a/tests/v1/executor/test_ray_utils.py +++ b/tests/v1/executor/test_ray_utils.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + import numpy as np +from vllm.v1.executor import ray_utils from vllm.v1.executor.ray_utils import detach_zero_copy_from_model_runner_output from vllm.v1.outputs import LogprobsLists, LogprobsTensors, ModelRunnerOutput @@ -52,3 +55,46 @@ def test_detach_zero_copy_from_model_runner_output_copies_only_numpy_views(): assert detached_logprobs.sampled_token_ranks.flags.writeable assert detached_logprobs.cu_num_generated_tokens is cu_num_generated_tokens assert output.prompt_logprobs_dict["req-0"] is prompt_logprobs + + +def test_cluster_device_warning_uses_ray_cluster_resources(monkeypatch): + warnings = [] + + monkeypatch.setattr( + ray_utils, + "ray", + SimpleNamespace(cluster_resources=lambda: {"GPU": 2}), + ) + monkeypatch.setattr( + ray_utils.logger, + "warning", + lambda *args, **kwargs: warnings.append(args), + ) + + ray_utils._warn_if_insufficient_cluster_devices( + SimpleNamespace(world_size=2), "GPU" + ) + + assert warnings == [] + + +def test_cluster_device_warning_reports_cluster_shortage(monkeypatch): + warnings = [] + + monkeypatch.setattr( + ray_utils, + "ray", + SimpleNamespace(cluster_resources=lambda: {"GPU": 1}), + ) + monkeypatch.setattr( + ray_utils.logger, + "warning", + lambda *args, **kwargs: warnings.append(args), + ) + + ray_utils._warn_if_insufficient_cluster_devices( + SimpleNamespace(world_size=2), "GPU" + ) + + assert len(warnings) == 1 + assert "distributed world size" in warnings[0][0] diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py index 1541b24deaaf..5d3fe0dbd166 100644 --- a/vllm/v1/executor/ray_utils.py +++ b/vllm/v1/executor/ray_utils.py @@ -273,6 +273,22 @@ def assert_ray_available(): ) +def _warn_if_insufficient_cluster_devices( + parallel_config: ParallelConfig, device_str: str +) -> None: + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + if parallel_config.world_size > num_devices_in_cluster: + logger.warning( + "The requested distributed world size (%d) exceeds the total " + "number of available %ss (%d) in the Ray cluster. This may result " + "in Ray placement group allocation failures. Check `ray status` " + "and `ray list nodes`, or reduce tensor/pipeline parallel size.", + parallel_config.world_size, + device_str, + num_devices_in_cluster, + ) + + def _verify_bundles( placement_group: "PlacementGroup", parallel_config: ParallelConfig, @@ -549,21 +565,6 @@ def initialize_ray_cluster( if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1": os.environ["RAY_USAGE_STATS_ENABLED"] = "0" - # Prevalidate GPU requirements before Ray processing - if current_platform.is_cuda() and parallel_config.world_size > 1: - available_gpus = current_platform.device_count() - if parallel_config.world_size > available_gpus: - logger.warning( - "Tensor parallel size (%d) exceeds available GPUs (%d). " - "This may result in Ray placement group allocation failures. " - "Consider reducing tensor_parallel_size to %d or less, " - "or ensure your Ray cluster has %d GPUs available.", - parallel_config.world_size, - available_gpus, - available_gpus, - parallel_config.world_size, - ) - if ray.is_initialized(): logger.info("Ray is already initialized. Skipping Ray initialization.") elif current_platform.is_rocm() or current_platform.is_xpu(): @@ -589,6 +590,9 @@ def initialize_ray_cluster( f"current platform {current_platform.device_name} does not support ray." ) + if parallel_config.world_size > 1: + _warn_if_insufficient_cluster_devices(parallel_config, device_str) + # Create or get the placement group for worker processes if parallel_config.placement_group: current_placement_group = parallel_config.placement_group @@ -619,17 +623,6 @@ def initialize_ray_cluster( ) else: logger.info("No current placement group found. Creating a new placement group.") - num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) - # Log a warning message and delay resource allocation failure response. - # Avoid immediate rejection to allow user-initiated placement group - # created and wait cluster to be ready - if parallel_config.world_size > num_devices_in_cluster: - logger.warning( - "The number of required %ss exceeds the total " - "number of available %ss in the placement group.", - device_str, - device_str, - ) # Create a new placement group placement_group_specs: list[dict[str, float]] = [ {device_str: 1.0} for _ in range(parallel_config.world_size) From 0789bc94e6e5360088607fda8d22b66be5f1f7dd Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 23:33:30 +0800 Subject: [PATCH 11/11] Keep SM12x paged MQA off DeepGEMM metadata Route SM12x sparse MLA decode metadata around DeepGEMM scheduler metadata instead of returning placeholder metadata. Let get_paged_mqa_logits_metadata call the backend normally so unexpected SM12x metadata calls fail through the backend. Also keep SM12x FP8 MQA and paged MQA local fallback dispatch from initializing DeepGEMM before the SM12x guard runs. Co-authored-by: OpenAI Codex Signed-off-by: jasl --- .../test_sm120_deepgemm_fallbacks.py | 114 ++++++++++++++++++ vllm/utils/deep_gemm.py | 8 +- vllm/v1/attention/backends/mla/indexer.py | 11 +- 3 files changed, 126 insertions(+), 7 deletions(-) diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py index 88d337ea5998..3fae48146f2d 100644 --- a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -12,6 +12,7 @@ ) from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.mla import indexer as mla_indexer def test_decode_logits_width_uses_active_context_bound(): @@ -50,6 +51,119 @@ def test_non_sm120_cuda_sparse_indexer_still_requires_deep_gemm(monkeypatch): assert _sparse_indexer_requires_deep_gemm() is True +def test_sm120_paged_mqa_metadata_uses_backend_impl(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + lazy_init_calls = [] + monkeypatch.setattr( + deep_gemm_utils, "_lazy_init", lambda: lazy_init_calls.append(1) + ) + expected = torch.tensor([[1, 2]], dtype=torch.int32) + + def fake_deep_gemm_metadata(context_lens, block_size, num_sms): + assert context_lens.shape == (2, 1) + assert block_size == 256 + assert num_sms == 4 + return expected + + monkeypatch.setattr( + deep_gemm_utils, + "_get_paged_mqa_logits_metadata_impl", + fake_deep_gemm_metadata, + ) + context_lens = torch.tensor([[1], [3]], dtype=torch.int32) + + metadata = deep_gemm_utils.get_paged_mqa_logits_metadata( + context_lens, block_size=256, num_sms=4 + ) + + assert metadata is expected + assert lazy_init_calls == [1] + + +def test_sm120_mla_indexer_skips_deep_gemm_scheduler_metadata(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True) + + assert not mla_indexer._uses_deep_gemm_scheduler_metadata() + + +def test_cuda_mla_indexer_uses_deep_gemm_scheduler_metadata_off_sm12x(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: False, + ) + monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True) + + assert mla_indexer._uses_deep_gemm_scheduler_metadata() + + +def test_sm120_fp8_mqa_fallbacks_do_not_initialize_deep_gemm(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + + def fail_lazy_init(): + raise AssertionError("SM120 FP8 MQA should not initialize DeepGEMM") + + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", fail_lazy_init) + + mqa_result = torch.empty(1) + paged_result = torch.empty(1) + calls = [] + + def fake_mqa_fallback(*args, **kwargs): + calls.append("mqa") + return mqa_result + + def fake_paged_fallback(*args, **kwargs): + calls.append("paged") + return paged_result + + monkeypatch.setattr(deep_gemm_utils, "_fp8_mqa_logits_sm12x", fake_mqa_fallback) + monkeypatch.setattr( + deep_gemm_utils, "_fp8_paged_mqa_logits_sm12x", fake_paged_fallback + ) + + assert ( + deep_gemm_utils.fp8_fp4_mqa_logits( + (torch.empty(1, 1, 1), None), + (torch.empty(1, 1), torch.empty(1)), + torch.empty(1, 1), + torch.empty(1, dtype=torch.int32), + torch.empty(1, dtype=torch.int32), + clean_logits=False, + ) + is mqa_result + ) + assert ( + deep_gemm_utils.fp8_fp4_paged_mqa_logits( + (torch.empty(1, 1, 1, 1), None), + torch.empty(1, 1, 1, 5, dtype=torch.uint8), + torch.empty(1, 1), + torch.empty(1, 1, dtype=torch.int32), + torch.empty(1, 1, dtype=torch.int32), + torch.empty(1, dtype=torch.int32), + max_model_len=1, + clean_logits=False, + ) + is paged_result + ) + assert calls == ["mqa", "paged"] + + @pytest.mark.skipif( not current_platform.is_device_capability_family(120), reason="SM120 only" ) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 77ea8e9779ee..7df618e58e1a 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -547,7 +547,6 @@ def fp8_fp4_mqa_topk_indices( topk_indices: torch.Tensor, ) -> bool: """Write SM120 FP8 MQA top-k indices without materializing full logits.""" - _lazy_init() if not ( current_platform.is_cuda() and current_platform.is_device_capability_family(120) @@ -618,11 +617,11 @@ def fp8_fp4_mqa_logits( Returns: Logits tensor of shape [M, N], dtype `torch.float32`. """ - _lazy_init() if current_platform.is_device_capability_family(120) and q[1] is None: return _fp8_mqa_logits_sm12x( q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits ) + _lazy_init() if _fp8_fp4_mqa_logits_impl is None: return _missing() return _fp8_fp4_mqa_logits_impl( @@ -762,7 +761,6 @@ def fp8_fp4_paged_mqa_topk_indices( topk_indices: torch.Tensor, ) -> bool: """Write SM120 FP8 paged MQA top-k indices without full logits.""" - _lazy_init() q_values, q_scale = q if not ( current_platform.is_cuda() @@ -905,11 +903,11 @@ def fp8_fp4_paged_mqa_logits( Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. """ - _lazy_init() if current_platform.is_device_capability_family(120) and q[1] is None: return _fp8_paged_mqa_logits_sm12x( q, kv_cache, weights, context_lens, block_tables, max_model_len ) + _lazy_init() if _fp8_fp4_paged_mqa_logits_impl is None: return _missing() return _fp8_fp4_paged_mqa_logits_impl( @@ -984,9 +982,9 @@ def tf32_hc_prenorm_gemm( See the caller function for shape requirement """ - _lazy_init() if current_platform.is_device_capability_family(120): return _tf32_hc_prenorm_gemm_sm12x(x, fn, out, sqrsum, num_split) + _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: return _missing() return _tf32_hc_prenorm_gemm_impl( diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index eb0ea8f528b9..3f43b65d26fb 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -46,6 +46,14 @@ def sparse_indexer_max_logits_bytes(is_sm12x: bool | None = None) -> int: return default_mb * 1024 * 1024 +def _uses_deep_gemm_scheduler_metadata() -> bool: + return ( + current_platform.is_cuda() + and has_deep_gemm() + and not current_platform.is_device_capability_family(120) + ) + + @triton.jit def _prepare_uniform_decode_kernel( seq_lens_ptr, @@ -623,8 +631,7 @@ def build( if seq_lens.dim() == 1: seq_lens = seq_lens.unsqueeze(-1) - # DeepGEMM is required for the paged MQA logits on CUDA devices - if current_platform.is_cuda() and has_deep_gemm(): + if _uses_deep_gemm_scheduler_metadata(): self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( seq_lens, self.kv_cache_spec.storage_block_size,