diff --git a/csrc/libtorch_stable/moe/marlin_moe_wna16/ops.cu b/csrc/libtorch_stable/moe/marlin_moe_wna16/ops.cu index 177eefa2c6f0..8fd63140e6d8 100644 --- a/csrc/libtorch_stable/moe/marlin_moe_wna16/ops.cu +++ b/csrc/libtorch_stable/moe/marlin_moe_wna16/ops.cu @@ -437,6 +437,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); STD_TORCH_CHECK(max_shared_mem > 0); + int device_max_shared_mem = max_shared_mem; int major_capability, minor_capability; cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, @@ -527,10 +528,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_mem); + device_max_shared_mem); // avoid ">>>" being formatted to "> > >" // clang-format off - kernel<<>>( + kernel<<>>( A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m, @@ -708,9 +709,8 @@ torch::stable::Tensor moe_wna16_marlin_gemm( torch::stable::Tensor c_tmp; if (use_fp32_reduce && !use_atomic_add) { // max num of threadblocks is sms * 4 - long max_c_tmp_size = min( - (long)size_n * sorted_token_ids.size(0), - (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n); + long max_c_tmp_size = + (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n; if (moe_block_size == 8) max_c_tmp_size *= 2; c_tmp = torch::stable::new_empty(a, {max_c_tmp_size}, kFloat); } else { diff --git a/tests/compile/passes/test_functionalization.py b/tests/compile/passes/test_functionalization.py index 31bf225d4135..3feb90fe0c56 100644 --- a/tests/compile/passes/test_functionalization.py +++ b/tests/compile/passes/test_functionalization.py @@ -251,12 +251,70 @@ def 
ops_not_in_model(self): return [] +class TestFusedDeepseekV4QnormRopeKvInsert(torch.nn.Module): + OP_REGISTERED = False + + def __init__(self): + super().__init__() + self.register_test_custom_op() + + @classmethod + def register_test_custom_op(cls): + if not cls.OP_REGISTERED: + + def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl( + q: torch.Tensor, + kv: torch.Tensor, + k_cache: torch.Tensor, + ) -> None: + q.add_(kv) + k_cache.add_(kv) + + def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake( + q: torch.Tensor, + kv: torch.Tensor, + k_cache: torch.Tensor, + ) -> None: + return None + + direct_register_custom_op( + op_name="fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", + op_func=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl, + mutates_args=["q", "k_cache"], + fake_impl=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake, + ) + + cls.OP_REGISTERED = True + + def forward( + self, q: torch.Tensor, kv: torch.Tensor, k_cache: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(q, kv, k_cache) + return q, k_cache + + def example_inputs(self, num_tokens=32, hidden_size=128): + return ( + torch.rand(num_tokens, hidden_size), + torch.rand(num_tokens, hidden_size), + torch.rand(num_tokens, hidden_size), + ) + + def ops_in_model(self, do_fusion): + return [ + torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default + ] + + def ops_not_in_model(self): + return [] + + MODELS_AND_DO_FUSION = { TestSiluMul: [True, False], TestFusedAddRMSNorm: [True, False], TestRotaryEmbedding: [False], TestRotaryEmbeddingSliceScatter: [False], TestFunctionWithMutatedArgsAndReturn: [False], + TestFusedDeepseekV4QnormRopeKvInsert: [False], } diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index d822b68c5036..be9318b32263 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -468,6 +468,38 @@ def test_cudagraph_sizes_post_init( ) 
+def test_spec_decode_cudagraph_sizes_keep_small_full_decode_batches_exact(): + config = CompilationConfig( + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 40, + 48, + 56, + 64, + 72, + 80, + 88, + 96, + ], + max_cudagraph_capture_size=96, + ) + + config.adjust_cudagraph_sizes_for_spec_decode( + uniform_decode_query_len=3, + tensor_parallel_size=1, + ) + + for num_reqs in range(1, 33): + assert 3 * num_reqs in config.cudagraph_capture_sizes + + @pytest.mark.skipif( not current_platform.support_static_graph_mode(), reason="Skip if not cudagraph mode supported", diff --git a/tests/config/test_deepseek_v4_cudagraph_config.py b/tests/config/test_deepseek_v4_cudagraph_config.py new file mode 100644 index 000000000000..2989a3f17d3d --- /dev/null +++ b/tests/config/test_deepseek_v4_cudagraph_config.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +from vllm.config.vllm import _should_auto_enable_breakable_cudagraph + + +def _model_config(*architectures: str): + return SimpleNamespace(architectures=list(architectures)) + + +def test_deepseek_v4_does_not_auto_enable_breakable_cudagraph(): + # Breakable cudagraph disables torch.compile and is 1.5-3.8x slower for MTP + # decode on SM12x (measured); DeepSeek-V4 defaults to FULL_AND_PIECEWISE. + assert not _should_auto_enable_breakable_cudagraph( + _model_config("DeepseekV4ForCausalLM") + ) + assert not _should_auto_enable_breakable_cudagraph( + _model_config("DeepSeekV4MTPModel") + ) + + +def test_minimax_m3_auto_enables_breakable_cudagraph(): + # MiniMax M3 retains upstream's unconditional auto-enable. 
+ assert _should_auto_enable_breakable_cudagraph( + _model_config("MiniMaxM3SparseForCausalLM") + ) + assert _should_auto_enable_breakable_cudagraph( + _model_config("MiniMaxM3SparseForConditionalGeneration") + ) + + +def test_other_models_do_not_auto_enable_breakable_cudagraph(): + assert not _should_auto_enable_breakable_cudagraph( + _model_config("Qwen3ForCausalLM") + ) diff --git a/tests/kernels/moe/test_flashinfer_cutlass_mxfp4_config.py b/tests/kernels/moe/test_flashinfer_cutlass_mxfp4_config.py new file mode 100644 index 000000000000..6bb12fde710b --- /dev/null +++ b/tests/kernels/moe/test_flashinfer_cutlass_mxfp4_config.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import sys +import types + +import torch + +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, + mxfp4_mxfp8_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import ( + FlashInferExperts, +) +from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + Mxfp4MoeBackend, + convert_weight_to_mxfp4_moe_kernel_format, +) + + +def _make_moe_config() -> FusedMoEConfig: + return FusedMoEConfig( + num_experts=2, + experts_per_token=1, + hidden_dim=16, + intermediate_size_per_partition=16, + num_local_experts=2, + num_logical_experts=2, + activation=MoEActivation.SILU, + device="cpu", + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + in_dtype=torch.bfloat16, + routing_method=RoutingMethodType.TopK, + max_num_tokens=16, + ) + + +def _make_experts( + *, + gemm1_alpha: float | None = None, + gemm1_beta: float | None = None, + gemm1_clamp_limit: float | None = None, +) -> FlashInferExperts: + quant_config = mxfp4_mxfp8_moe_quant_config( + 
w1_scale=torch.ones((2, 32, 1), dtype=torch.float8_e4m3fn), + w2_scale=torch.ones((2, 16, 1), dtype=torch.float8_e4m3fn), + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=gemm1_clamp_limit, + ) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return FlashInferExperts( + moe_config=_make_moe_config(), + quant_config=quant_config, + ) + + +def test_mxfp4_swiglu_parameters_stay_unset_without_quant_config() -> None: + experts = _make_experts() + + assert experts.gemm1_alpha is None + assert experts.gemm1_beta is None + assert experts.gemm1_clamp_limit is None + + +def test_mxfp4_swiglu_parameters_follow_quant_config() -> None: + experts = _make_experts( + gemm1_alpha=1.25, + gemm1_beta=0.75, + gemm1_clamp_limit=5.5, + ) + + torch.testing.assert_close(experts.gemm1_alpha, torch.tensor([1.25, 1.25])) + torch.testing.assert_close(experts.gemm1_beta, torch.tensor([0.75, 0.75])) + torch.testing.assert_close( + experts.gemm1_clamp_limit, + torch.tensor([5.5, 5.5]), + ) + + +def test_cutlass_mxfp8_kernel_format_converts_gate_up_layout(monkeypatch) -> None: + monkeypatch.setitem( + sys.modules, + "flashinfer", + types.SimpleNamespace(block_scale_interleave=lambda x: x.contiguous()), + ) + + num_experts = 1 + intermediate_size = 64 + hidden_size = 64 + packed_hidden_size = hidden_size // 2 + sf_block_size = 32 + + w13_weight = torch.arange( + num_experts * 2 * intermediate_size * packed_hidden_size, + dtype=torch.uint8, + ).reshape(num_experts, 2 * intermediate_size, packed_hidden_size) + w2_weight = torch.arange( + num_experts * hidden_size * (intermediate_size // 2), + dtype=torch.uint8, + ).reshape(num_experts, hidden_size, intermediate_size // 2) + w13_scale_u8 = torch.arange( + num_experts * 2 * intermediate_size * (hidden_size // sf_block_size), + dtype=torch.uint8, + ).reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size) + w2_scale_u8 = torch.arange( + num_experts * 
hidden_size * (intermediate_size // sf_block_size), + dtype=torch.uint8, + ).reshape(num_experts, hidden_size, intermediate_size // sf_block_size) + w13_bias = torch.arange( + num_experts * 2 * intermediate_size, + dtype=torch.bfloat16, + ).reshape(num_experts, 2 * intermediate_size) + w2_bias = torch.arange( + num_experts * hidden_size, + dtype=torch.bfloat16, + ).reshape(num_experts, hidden_size) + + ( + out_w13, + out_w2, + out_w13_scale, + out_w2_scale, + out_w13_bias, + out_w2_bias, + ) = convert_weight_to_mxfp4_moe_kernel_format( + mxfp4_backend=Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + layer=torch.nn.Module(), + w13_weight=w13_weight, + w2_weight=w2_weight, + w13_weight_scale=w13_scale_u8.view(torch.float8_e4m3fn), + w2_weight_scale=w2_scale_u8.view(torch.float8_e4m3fn), + w13_bias=w13_bias, + w2_bias=w2_bias, + ) + + expected_w13 = torch.cat( + [ + w13_weight[:, intermediate_size:, :], + w13_weight[:, :intermediate_size, :], + ], + dim=1, + ) + expected_w13_scale = torch.cat( + [ + w13_scale_u8[:, intermediate_size:, :], + w13_scale_u8[:, :intermediate_size, :], + ], + dim=1, + ) + expected_w13_bias = torch.cat( + [ + w13_bias[:, intermediate_size:], + w13_bias[:, :intermediate_size], + ], + dim=1, + ) + + assert out_w13.is_contiguous() + assert out_w2.is_contiguous() + torch.testing.assert_close(out_w13, expected_w13) + torch.testing.assert_close(out_w2, w2_weight) + torch.testing.assert_close(out_w13_scale, expected_w13_scale) + torch.testing.assert_close(out_w2_scale, w2_scale_u8) + torch.testing.assert_close(out_w13_bias, expected_w13_bias) + torch.testing.assert_close(out_w2_bias, w2_bias) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 45cd17b3b115..750a99b20bba 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -1009,6 +1009,119 @@ def test_fused_marlin_moe( torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0) +@pytest.mark.skipif(current_platform.is_rocm(), 
reason="Skip for rocm") +def test_fused_marlin_moe_cuda_graph(): + """Test that GPTQ Marlin MoE works correctly with CUDA graphs. + + Regression test for https://github.com/vllm-project/vllm/issues/36811 + The bug was caused by: + 1. cudaFuncSetAttribute(MaxDynamicSharedMemorySize) being overwritten + by later CUDA graph captures with different blocks_per_sm values. + 2. Batch-dependent c_tmp buffer sizing via sorted_token_ids.size(0). + """ + torch.cuda.manual_seed(42) + + # Use 64 experts like Qwen3.5-35B-A3B to trigger the bug + e, topk = 64, 8 + n, k = 1024, 1024 + group_size = 128 + quant_type = scalar_types.uint4b8 + dtype = torch.half + + # Batch sizes matching vLLM's CUDA graph capture sizes + batch_sizes = [1, 2, 4, 8, 16, 24, 32] + + # Create weights (shared across all batch sizes) + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + w1_data = MarlinMoEWeightData.make( + w=w1, + quant_type=quant_type, + group_size=group_size, + act_order=False, + ) + w2_data = MarlinMoEWeightData.make( + w=w2, + quant_type=quant_type, + group_size=group_size, + act_order=False, + ) + + # Capture a CUDA graph for each batch size + graphs = {} + static_inputs = {} + static_outputs = {} + + for m in batch_sizes: + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + + # Static input tensors for graph replay + a_static = a.clone() + topk_weights_static = topk_weights.clone() + topk_ids_static = topk_ids.clone() + + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.stream(stream), torch.cuda.graph(graph): + out = fused_marlin_moe( + a_static, + w1_data.qweight, + w2_data.qweight, + None, + None, + w1_data.scales, + w2_data.scales, + topk_weights_static, + topk_ids_static, + global_num_experts=e, + quant_type_id=quant_type.id, + 
is_k_full=True, + ) + + graphs[m] = graph + static_inputs[m] = (a_static, topk_weights_static, topk_ids_static) + static_outputs[m] = out + + torch.accelerator.synchronize() + + # Replay each graph and compare against eager mode. + # The bug manifested when replaying small batch size graphs (e.g., M=1) + # after larger batch sizes had overwritten cudaFuncSetAttribute. + for m in batch_sizes: + a_static, topk_weights_static, topk_ids_static = static_inputs[m] + + # Compute eager reference with the same inputs + eager_out = fused_marlin_moe( + a_static, + w1_data.qweight, + w2_data.qweight, + None, + None, + w1_data.scales, + w2_data.scales, + topk_weights_static, + topk_ids_static, + global_num_experts=e, + quant_type_id=quant_type.id, + is_k_full=True, + ) + + # Replay the captured graph + static_outputs[m].zero_() + graphs[m].replay() + torch.accelerator.synchronize() + + torch.testing.assert_close( + static_outputs[m], + eager_out, + atol=1e-2, + rtol=1e-2, + ) + + @pytest.mark.flaky(reruns=2) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.usefixtures("default_vllm_config") diff --git a/tests/kernels/test_mhc_kernels.py b/tests/kernels/test_mhc_kernels.py index 2bdce9f9c14d..067e7cfd0a43 100644 --- a/tests/kernels/test_mhc_kernels.py +++ b/tests/kernels/test_mhc_kernels.py @@ -7,6 +7,7 @@ from vllm.model_executor.kernels.mhc.tilelang import ( _tilelang_hc_prenorm_gemm, _torch_hc_prenorm_gemm, + _use_tf32_hc_prenorm_gemm, ) from vllm.model_executor.layers.mhc import HAS_TILELANG_MHC from vllm.platforms import current_platform @@ -96,6 +97,20 @@ def hc_head_ref( return torch.sum(pre_mix.unsqueeze(-1) * residual.float(), dim=-2).bfloat16() +def test_sm120_uses_tf32_hc_prenorm_gemm_without_deepgemm(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + monkeypatch.setattr( + "vllm.utils.deep_gemm.is_deep_gemm_supported", + lambda: False, + ) + + assert 
_use_tf32_hc_prenorm_gemm() + + @pytest.mark.skipif( not HAS_TILELANG_MHC, reason="TileLang MHC support required", diff --git a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py index da974131f65a..cb4c4d492fd2 100644 --- a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py +++ b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py @@ -8,6 +8,9 @@ import pytest import torch +import vllm.model_executor.model_loader.weight_utils as weight_utils +from vllm.config.load import LoadConfig +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_weights_from_hf, fastsafetensors_weights_iterator, @@ -16,6 +19,77 @@ from vllm.platforms import current_platform +def test_default_loader_filters_fastsafetensors_before_materializing(monkeypatch): + class FakeProcessGroup: + def size(self): + return 1 + + class FakeFileBuffer: + def __init__(self): + self.key_to_rank_lidx = { + "model.layers.0.self_attn.q_proj.weight": (0, 0), + "model.layers.0.mlp.experts.0.gate_proj.weight": (0, 1), + "model.layers.0.mlp.experts.1.gate_proj.weight": (0, 2), + "model.mtp.0.weight": (0, 3), + } + self.loaded_keys: list[str] = [] + self.closed = False + + def get_tensor(self, key: str): + self.loaded_keys.append(key) + return torch.tensor([len(self.loaded_keys)]) + + def close(self): + self.closed = True + + class FakeLoader: + def __init__(self, file_buffer): + self.file_buffer = file_buffer + self.closed = False + + def copy_files_to_device(self): + return self.file_buffer + + def close(self): + self.closed = True + + file_buffer = FakeFileBuffer() + loader = FakeLoader(file_buffer) + + model_loader = DefaultModelLoader(LoadConfig(load_format="fastsafetensors")) + model_loader.local_expert_ids = {0} + monkeypatch.setattr( + model_loader, + 
"_prepare_weights", + lambda *_args: ("/weights", ["model.safetensors"], True), + ) + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False) + monkeypatch.setattr(weight_utils, "SingleGroup", FakeProcessGroup) + monkeypatch.setattr( + weight_utils, + "_init_fastsafetensors_loader", + lambda *_args, **_kwargs: loader, + ) + + loaded = dict( + model_loader._get_weights_iterator( + DefaultModelLoader.Source("model", revision=None), + weight_name_filter=lambda name: "model.mtp." in name, + ) + ) + + assert set(loaded) == { + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.mlp.experts.0.gate_proj.weight", + } + assert file_buffer.loaded_keys == [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.mlp.experts.0.gate_proj.weight", + ] + assert file_buffer.closed + assert loader.closed + + @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="fastsafetensors requires NVIDIA/AMD GPUs", diff --git a/tests/model_executor/model_loader/test_ep_weight_filter.py b/tests/model_executor/model_loader/test_ep_weight_filter.py index 2ac38192a4b0..d032cd569019 100644 --- a/tests/model_executor/model_loader/test_ep_weight_filter.py +++ b/tests/model_executor/model_loader/test_ep_weight_filter.py @@ -319,6 +319,20 @@ def test_ep2_rank0_gets_half_experts(self, synthetic_moe_files): assert "model.layers.0.input_layernorm.weight" in loaded assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded + def test_weight_name_filter_skips_dense_weights(self, synthetic_moe_files): + files, _ = synthetic_moe_files + loaded = dict( + safetensors_weights_iterator( + files, + False, + weight_name_filter=lambda name: "self_attn.q_proj" in name, + ) + ) + + assert "model.layers.0.self_attn.q_proj.weight" not in loaded + assert "model.embed_tokens.weight" in loaded + assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded + def test_ep2_rank1_gets_other_half(self, synthetic_moe_files): files, expected = synthetic_moe_files 
local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=1) diff --git a/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py b/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py new file mode 100644 index 000000000000..0955306a0f4b --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import torch + +from vllm.models.deepseek_v4.nvidia import flashmla as flashmla_mod +from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLAAttention + + +def _make_layer(compress_ratio: int) -> DeepseekV4FlashMLAAttention: + layer = object.__new__(DeepseekV4FlashMLAAttention) + layer.compress_ratio = compress_ratio + layer.swa_cache_layer = SimpleNamespace(kv_cache=torch.empty(1, 1, 512)) + layer.n_local_heads = 2 + layer.scale = 1.0 + layer.attn_sink = torch.zeros(2) + return layer + + +def _make_swa_metadata() -> SimpleNamespace: + return SimpleNamespace( + num_decodes=1, + num_decode_tokens=1, + decode_swa_indices=torch.zeros((1, 4), dtype=torch.int32), + decode_swa_lens=torch.ones(1, dtype=torch.int32), + is_valid_token=torch.ones(1, dtype=torch.bool), + token_to_req_indices=torch.zeros(1, dtype=torch.int32), + block_table=torch.zeros((1, 1), dtype=torch.int32), + block_size=64, + seq_lens=torch.full((1,), 4, dtype=torch.int32), + tile_sched_swaonly=None, + tile_sched_c4a=None, + tile_sched_c128a=None, + ) + + +def test_swa_decode_uses_triton_path_without_flashmla_tile_sched(monkeypatch): + layer = _make_layer(compress_ratio=1) + metadata = _make_swa_metadata() + calls = [] + + def fake_decode( + cls, + layer, + q, + swa_k_cache, + swa_metadata, + output, + ): + calls.append((layer, q.shape, swa_k_cache.shape, swa_metadata, output.shape)) + + monkeypatch.setattr(flashmla_mod, "is_triton_sparse_mla_enabled", lambda _: True) + 
monkeypatch.setattr( + DeepseekV4FlashMLAAttention, + "_forward_sparse_mla_swa_decode_triton", + classmethod(fake_decode), + ) + + q = torch.empty(1, 2, 512) + output = torch.empty_like(q) + + layer._forward_decode( + q=q, + kv_cache=None, + swa_metadata=metadata, + attn_metadata=None, + swa_only=True, + output=output, + ) + + assert calls == [(layer, (1, 1, 2, 512), (1, 1, 512), metadata, (1, 2, 512))] + + +def test_compressed_decode_uses_triton_path_without_flashmla_tile_sched(monkeypatch): + layer = _make_layer(compress_ratio=128) + metadata = _make_swa_metadata() + attn_metadata = SimpleNamespace( + block_size=256, + c128a_global_decode_topk_indices=torch.zeros((1, 1, 2), dtype=torch.int32), + c128a_decode_topk_lens=torch.ones(1, dtype=torch.int32), + ) + kv_cache = torch.empty(1, 1, 512) + calls = [] + + def fake_decode( + cls, + layer, + q, + compressed_k_cache, + swa_k_cache, + topk_indices, + topk_lens, + swa_metadata, + attn_metadata, + output, + ): + calls.append( + ( + layer, + q.shape, + compressed_k_cache.shape, + swa_k_cache.shape, + topk_indices.shape, + topk_lens.shape, + swa_metadata, + attn_metadata, + output.shape, + ) + ) + + monkeypatch.setattr(flashmla_mod, "is_triton_sparse_mla_enabled", lambda _: True) + monkeypatch.setattr( + DeepseekV4FlashMLAAttention, + "_forward_sparse_mla_compressed_decode_triton", + classmethod(fake_decode), + ) + + q = torch.empty(1, 2, 512) + output = torch.empty_like(q) + + layer._forward_decode( + q=q, + kv_cache=kv_cache, + swa_metadata=metadata, + attn_metadata=attn_metadata, + swa_only=False, + output=output, + ) + + assert calls == [ + ( + layer, + (1, 1, 2, 512), + (1, 1, 512), + (1, 1, 512), + (1, 1, 2), + (1,), + metadata, + attn_metadata, + (1, 2, 512), + ) + ] diff --git a/tests/model_executor/test_deepseek_v4_kernel_warmup.py b/tests/model_executor/test_deepseek_v4_kernel_warmup.py new file mode 100644 index 000000000000..6747107d59d2 --- /dev/null +++ 
b/tests/model_executor/test_deepseek_v4_kernel_warmup.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +from vllm.model_executor.warmup import kernel_warmup + + +def _mtp_runner(query_len: int = 3): + return SimpleNamespace( + speculative_config=SimpleNamespace(method="mtp"), + num_spec_tokens=query_len - 1, + uniform_decode_query_len=query_len, + ) + + +def test_deepseek_v4_mtp_uniform_decode_warmup_covers_c256(): + requests = kernel_warmup._deepseek_v4_mtp_uniform_decode_warmup_requests( + _mtp_runner(), + max_tokens=4096, + max_reqs=256, + ) + + assert requests == (1, 2, 4, 8, 16, 24, 32, 256) + + +def test_deepseek_v4_mtp_uniform_decode_warmup_still_respects_limits(): + assert kernel_warmup._deepseek_v4_mtp_uniform_decode_warmup_requests( + _mtp_runner(), + max_tokens=4096, + max_reqs=24, + ) == (1, 2, 4, 8, 16, 24) + assert kernel_warmup._deepseek_v4_mtp_uniform_decode_warmup_requests( + _mtp_runner(), + max_tokens=96, + max_reqs=256, + ) == (1, 2, 4, 8, 16, 24, 32) diff --git a/tests/model_executor/test_deepseek_v4_moe_metadata.py b/tests/model_executor/test_deepseek_v4_moe_metadata.py new file mode 100644 index 000000000000..25a177a9792f --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_moe_metadata.py @@ -0,0 +1,159 @@ +from types import SimpleNamespace + +from vllm.model_executor.layers.fused_moe import RoutedExperts +from vllm.models.deepseek_v4 import quant_config as deepseek_v4_quant_config +from vllm.models.deepseek_v4.nvidia import model as deepseek_v4_model +from vllm.models.deepseek_v4.nvidia.model import ( + DeepseekV4MixtureOfExperts, + DeepseekV4MoE, +) + + +def test_deepseek_v4_fused_moe_metadata_is_available_to_mixture(): + moe = object.__new__(DeepseekV4MoE) + moe.n_routed_experts = 256 + moe.n_shared_experts = 1 + moe.experts = SimpleNamespace( + logical_num_experts=256, + global_num_experts=256, + 
local_num_experts=128, + ) + moe._sync_fused_moe_metadata() + + mixture = object.__new__(DeepseekV4MixtureOfExperts) + mixture.extract_moe_parameters(moe) + + assert mixture.num_logical_experts == 256 + assert mixture.num_physical_experts == 256 + assert mixture.num_local_physical_experts == 128 + assert mixture.num_routed_experts == 256 + assert mixture.num_shared_experts == 1 + assert mixture.num_redundant_experts == 0 + + +def test_deepseek_v4_fused_moe_metadata_handles_moe_runner_shape(): + moe = object.__new__(DeepseekV4MoE) + moe.n_routed_experts = 256 + moe.n_shared_experts = 1 + moe.experts = SimpleNamespace( + moe_config=SimpleNamespace( + num_logical_experts=256, + num_experts=256, + num_local_experts=128, + ), + routed_experts=SimpleNamespace( + global_num_experts=256, + local_num_experts=128, + ), + ) + + moe._sync_fused_moe_metadata() + + assert moe.n_logical_experts == 256 + assert moe.n_physical_experts == 256 + assert moe.n_local_physical_experts == 128 + assert moe.n_local_experts == 128 + assert moe.n_redundant_experts == 0 + + +def test_deepseek_v4_fp4_quant_config_handles_routed_experts_after_moe_refactor( + monkeypatch, +): + class FakeMxfp4MoEMethod: + def __init__(self, moe_config): + self.moe_config = moe_config + + quant_config = deepseek_v4_quant_config.DeepseekV4FP8Config( + is_checkpoint_fp8_serialized=True, + weight_block_size=[128, 128], + ) + layer = object.__new__(RoutedExperts) + layer.moe_config = object() + + monkeypatch.setattr( + deepseek_v4_quant_config, + "Mxfp4MoEMethod", + FakeMxfp4MoEMethod, + ) + monkeypatch.setattr( + deepseek_v4_quant_config, + "get_current_vllm_config", + lambda: SimpleNamespace( + model_config=SimpleNamespace( + hf_config=SimpleNamespace( + expert_dtype="fp4", + quantization_config={}, + ) + ) + ), + ) + + method = quant_config.get_quant_method(layer, "model.layers.3.mlp.experts") + + assert isinstance(method, FakeMxfp4MoEMethod) + assert method.moe_config is layer.moe_config + + +def 
test_deepseek_v4_fused_moe_init_exports_moe_metadata(monkeypatch): + class FakeGate: + def __init__(self, *args, **kwargs): + self.e_score_correction_bias = None + self.tid2eid = None + + class FakeFusedMoE: + logical_num_experts = 256 + global_num_experts = 256 + local_num_experts = 128 + + def __init__(self, **kwargs): + self.kwargs = kwargs + + class FakeMLP: + def __init__(self, *args, **kwargs): + pass + + monkeypatch.setattr(deepseek_v4_model, "GateLinear", FakeGate) + monkeypatch.setattr(deepseek_v4_model, "DeepseekV4MLP", FakeMLP) + monkeypatch.setattr(deepseek_v4_model, "FusedMoE", FakeFusedMoE) + monkeypatch.setattr( + deepseek_v4_model, + "get_tensor_model_parallel_world_size", + lambda: 2, + ) + monkeypatch.setattr( + deepseek_v4_model, + "get_tensor_model_parallel_rank", + lambda: 1, + ) + + config = SimpleNamespace( + n_routed_experts=256, + n_shared_experts=1, + num_experts_per_tok=8, + hidden_size=7168, + moe_intermediate_size=2048, + swiglu_limit=7.0, + hidden_act="silu", + norm_topk_prob=True, + num_hash_layers=0, + vocab_size=128000, + ) + vllm_config = SimpleNamespace( + model_config=SimpleNamespace(hf_config=config), + quant_config=None, + kernel_config=SimpleNamespace(moe_backend="auto"), + parallel_config=SimpleNamespace( + enable_expert_parallel=True, + enable_eplb=False, + eplb_config=SimpleNamespace(num_redundant_experts=0), + ), + ) + + moe = DeepseekV4MoE(vllm_config, prefix="model.layers.3.mlp") + + assert moe.n_logical_experts == 256 + assert moe.n_physical_experts == 256 + assert moe.n_local_physical_experts == 128 + assert moe.n_local_experts == 128 + assert moe.n_shared_experts == 1 + assert moe.n_redundant_experts == 0 diff --git a/tests/model_executor/test_deepseek_v4_o_proj.py b/tests/model_executor/test_deepseek_v4_o_proj.py new file mode 100644 index 000000000000..09b32e43b370 --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_o_proj.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project + +import pytest + +from vllm.models.deepseek_v4.nvidia.ops.o_proj import compute_fp8_einsum_recipe +from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability + + +@pytest.mark.parametrize( + ("capability", "expected_recipe", "expected_tma_aligned"), + [ + ((9, 0), (1, 128, 128), False), + ((10, 0), (1, 1, 128), True), + ((12, 0), (1, 128, 128), False), + ((12, 1), (1, 128, 128), False), + ], +) +def test_deepseek_v4_o_proj_recipe_is_arch_specific( + monkeypatch: pytest.MonkeyPatch, + capability: tuple[int, int], + expected_recipe: tuple[int, int, int], + expected_tma_aligned: bool, +): + monkeypatch.setattr( + current_platform, + "get_device_capability", + lambda device_id=0: DeviceCapability(*capability), + ) + + assert compute_fp8_einsum_recipe() == (expected_recipe, expected_tma_aligned) diff --git a/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py b/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py new file mode 100644 index 000000000000..8d26aa7fea24 --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.models.deepseek_v4 import sparse_mla +from vllm.models.deepseek_v4.nvidia import flashmla + + +def test_c128a_effective_topk_width_uses_current_positions() -> None: + assert ( + sparse_mla._c128a_effective_topk_width( + positions=torch.tensor([0, 126], dtype=torch.int64), + compress_ratio=128, + max_compressed_tokens=4096, + alignment=128, + ) + == 128 + ) + assert ( + sparse_mla._c128a_effective_topk_width( + positions=torch.tensor([127, 1023], dtype=torch.int64), + compress_ratio=128, + max_compressed_tokens=4096, + alignment=128, + ) + == 128 + ) + assert ( + sparse_mla._c128a_effective_topk_width( + positions=torch.tensor([524287], dtype=torch.int64), + compress_ratio=128, 
+ max_compressed_tokens=8192, + alignment=128, + ) + == 4096 + ) + assert ( + sparse_mla._c128a_effective_topk_width( + positions=torch.tensor([1048575], dtype=torch.int64), + compress_ratio=128, + max_compressed_tokens=8192, + alignment=128, + ) + == 8192 + ) + + +def test_indexed_d512_split_topk_keeps_small_c128a_prefills() -> None: + assert not flashmla._is_indexed_d512_split_topk(128) + assert flashmla._is_indexed_d512_split_topk(256) + assert flashmla._is_indexed_d512_split_topk(512) + assert flashmla._is_indexed_d512_split_topk(1152) + assert not flashmla._is_indexed_d512_split_topk(1280) diff --git a/tests/model_executor/test_fp8_marlin_kernel_selection.py b/tests/model_executor/test_fp8_marlin_kernel_selection.py new file mode 100644 index 000000000000..0e3cdbddf5cc --- /dev/null +++ b/tests/model_executor/test_fp8_marlin_kernel_selection.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.model_executor.kernels.linear.scaled_mm import ( + FP8ScaledMMLinearLayerConfig, + MarlinFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + QuantKey, + ScaleDesc, +) + + +def _fp8_linear_config( + activation_group_shape: GroupShape, +) -> FP8ScaledMMLinearLayerConfig: + return FP8ScaledMMLinearLayerConfig( + weight_quant_key=QuantKey( + dtype=torch.float8_e4m3fn, + scale=ScaleDesc( + dtype=torch.float32, + static=True, + group_shape=GroupShape.PER_CHANNEL, + ), + ), + activation_quant_key=QuantKey( + dtype=torch.float8_e4m3fn, + scale=ScaleDesc( + dtype=torch.float32, + static=False, + group_shape=activation_group_shape, + ), + ), + weight_shape=(128, 128), + input_dtype=torch.bfloat16, + out_dtype=torch.bfloat16, + ) + + +def test_marlin_fp8_refuses_block_fp8_activation_scales(): + can_implement, reason = MarlinFP8ScaledMMLinearKernel.can_implement( + _fp8_linear_config(GroupShape(1, 128)) + ) + + 
assert not can_implement + assert reason is not None + assert "block-FP8" in reason + + +def test_marlin_fp8_keeps_non_block_fp8_layers_available(): + can_implement, reason = MarlinFP8ScaledMMLinearKernel.can_implement( + _fp8_linear_config(GroupShape.PER_TOKEN) + ) + + assert can_implement + assert reason is None diff --git a/tests/models/test_deepseek_v4_mega_moe.py b/tests/models/test_deepseek_v4_mega_moe.py index 3daae242d459..93a68377e310 100644 --- a/tests/models/test_deepseek_v4_mega_moe.py +++ b/tests/models/test_deepseek_v4_mega_moe.py @@ -6,12 +6,14 @@ import pytest import torch +from vllm.config import CompilationConfig from vllm.models.deepseek_v4.nvidia.model import ( DeepseekV4MegaMoEExperts, make_deepseek_v4_expert_params_mapping, ) from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.platforms import current_platform +from vllm.utils import deep_gemm as deep_gemm_utils pytestmark = pytest.mark.skipif( not current_platform.is_cuda(), @@ -46,7 +48,8 @@ def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float(): def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): vllm_config = SimpleNamespace( - scheduler_config=SimpleNamespace(max_num_batched_tokens=4) + compilation_config=CompilationConfig(), + scheduler_config=SimpleNamespace(max_num_batched_tokens=4), ) experts = DeepseekV4MegaMoEExperts( vllm_config, @@ -106,12 +109,75 @@ def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): assert torch.count_nonzero(experts.w13_weight[1]) == 0 +def test_deepseek_v4_mega_moe_finalize_uses_deep_gemm_wrapper(monkeypatch): + vllm_config = SimpleNamespace( + compilation_config=CompilationConfig(), + scheduler_config=SimpleNamespace(max_num_batched_tokens=4), + ) + experts = DeepseekV4MegaMoEExperts( + vllm_config, + num_experts=4, + num_local_experts=2, + experts_start_idx=0, + top_k=2, + hidden_size=128, + intermediate_size=128, + ) + + transformed = (object(), object()) + calls: list[object] 
= [] + + def fake_runtime_check(self): + calls.append("runtime_check") + + def fake_transform_sf_into_required_layout( + scale, rows, cols, block_shape, num_local_experts + ): + calls.append((rows, cols, block_shape, num_local_experts)) + return scale + + def fake_transform_weights_for_mega_moe(w13_weight, w2_weight): + calls.append((w13_weight[0].shape, w2_weight[0].shape)) + return transformed + + monkeypatch.setattr( + DeepseekV4MegaMoEExperts, + "_check_runtime_supported", + fake_runtime_check, + ) + monkeypatch.setattr( + deep_gemm_utils, + "transform_sf_into_required_layout", + fake_transform_sf_into_required_layout, + ) + monkeypatch.setattr( + deep_gemm_utils, + "transform_weights_for_mega_moe", + fake_transform_weights_for_mega_moe, + ) + + experts.finalize_weights() + + assert experts._transformed_l1_weights is transformed[0] + assert experts._transformed_l2_weights is transformed[1] + assert experts.w13_weight is None + assert experts.w13_weight_scale is None + assert experts.w2_weight is None + assert experts.w2_weight_scale is None + assert calls == [ + "runtime_check", + (256, 128, (1, 32), 2), + (128, 128, (1, 32), 2), + (torch.Size([2, 256, 64]), torch.Size([2, 128, 64])), + ] + + @pytest.mark.skipif( not torch.cuda.is_available(), reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.", ) def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact(): - from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8 + deep_gemm_utils = pytest.importorskip("deep_gemm.utils") device = torch.device("cuda") num_tokens = 7 @@ -150,7 +216,7 @@ def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact(): generator=generator, ) - ref_x, ref_x_sf = per_token_cast_to_fp8( + ref_x, ref_x_sf = deep_gemm_utils.per_token_cast_to_fp8( hidden_states, use_ue8m0=True, gran_k=32, diff --git a/tests/quantization/test_sm12x_tuned_config_lookup.py b/tests/quantization/test_sm12x_tuned_config_lookup.py new file mode 100644 index 
000000000000..aa85120fdde4 --- /dev/null +++ b/tests/quantization/test_sm12x_tuned_config_lookup.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.utils import fp8_utils +from vllm.platforms import current_platform + + +def _get_fused_moe_configs(e, n, block_shape): + if block_shape is None: + return fused_moe.get_moe_configs(e, n, "fp8_w8a8") + block_n, block_k = block_shape + return fused_moe.get_moe_configs(e, n, "fp8_w8a8", block_n, block_k) + + +def test_rtx_pro_6000_variants_reuse_workstation_tuned_configs(monkeypatch): + monkeypatch.setattr( + current_platform, + "get_device_name", + lambda: "NVIDIA RTX PRO 6000 Blackwell Server Edition", + ) + monkeypatch.setattr(fused_moe.envs, "VLLM_BATCH_INVARIANT", False) + fp8_utils.get_w8a8_block_fp8_configs.cache_clear() + fused_moe.get_moe_configs.cache_clear() + + assert fp8_utils.get_w8a8_block_fp8_configs(1536, 4096, 128, 128) is not None + assert _get_fused_moe_configs(256, 384, (128, 128)) is not None diff --git a/tests/reasoning/test_deepseekv3_reasoning_parser.py b/tests/reasoning/test_deepseekv3_reasoning_parser.py index f5b37194f927..4e19ca3811a8 100644 --- a/tests/reasoning/test_deepseekv3_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv3_reasoning_parser.py @@ -4,16 +4,35 @@ import pytest from transformers import AutoTokenizer +from vllm.config.reasoning import ReasoningConfig from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.reasoning import ReasoningParserManager from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser -from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser +from vllm.reasoning.deepseek_v3_reasoning_parser import ( + 
DeepSeekV3ReasoningParser, + DeepSeekV3ReasoningWithThinkingParser, +) from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-V3.1" +class FakeReasoningTokenizer: + + def get_vocab(self) -> dict[str, int]: + return {"": 100, "": 101} + + def encode( + self, + text: str, + add_special_tokens: bool = False, + **kwargs, + ) -> list[int]: + assert add_special_tokens is False + return [self.get_vocab()[text]] + + @pytest.fixture(scope="module") def tokenizer(): return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) @@ -34,10 +53,21 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type): assert isinstance(parser._parser, expected_parser_type) -def test_deepseek_v4_reasoning_parser_alias(): +def test_deepseek_v4_reasoning_parser_registration(): + """``deepseek_v4`` now resolves to its own parser (a defensive + extension of ``DeepSeekV3ReasoningParser``) rather than reusing the + V3 parser directly. See ``tests/reasoning/test_deepseekv4_reasoning_parser.py``. + """ + from vllm.reasoning.deepseek_v4_reasoning_parser import ( + DeepSeekV4ReasoningParser, + ) + parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4") - assert parser_cls is DeepSeekV3ReasoningParser + assert parser_cls is DeepSeekV4ReasoningParser + # The V3 alias must remain pointed at V3. + v3_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v3") + assert v3_cls is DeepSeekV3ReasoningParser def test_identity_reasoning_parser_basic(tokenizer): diff --git a/tests/reasoning/test_deepseekv4_reasoning_parser.py b/tests/reasoning/test_deepseekv4_reasoning_parser.py new file mode 100644 index 000000000000..fd3dff449f80 --- /dev/null +++ b/tests/reasoning/test_deepseekv4_reasoning_parser.py @@ -0,0 +1,407 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Tests for the DeepSeek V4 reasoning parser. 
+ +The V4 parser is a defensive extension of :class:`DeepSeekR1ReasoningParser` +that treats the DSML tool-call start marker as an implicit end-of-reasoning +when ```` is missing. That failure mode is observed at long context +on DSv4-Flash and was previously trapping tool calls inside reasoning, +leaving the agent loop with nothing to dispatch. +""" + +from unittest.mock import MagicMock + +import pytest + +from vllm.reasoning import ReasoningParserManager +from vllm.reasoning.deepseek_v4_reasoning_parser import ( + DeepSeekV4ReasoningParser, + DeepSeekV4ThinkingReasoningParser, +) +from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser + +START_TOKEN = "" +END_TOKEN = "" +START_TOKEN_ID = 9001 +END_TOKEN_ID = 9002 +DSML_MARKER = "<|DSML|tool_calls>" +DSML_MARKER_TOKEN_ID = 9100 + + +def _make_tokenizer() -> MagicMock: + """Mock tokenizer mapping the four special strings we care about.""" + tok = MagicMock() + vocab = { + START_TOKEN: START_TOKEN_ID, + END_TOKEN: END_TOKEN_ID, + DSML_MARKER: DSML_MARKER_TOKEN_ID, + } + tok.get_vocab.return_value = vocab + + def _decode(ids, *args, **kwargs): + out = [] + for tid in ids: + for s, sid in vocab.items(): + if sid == tid: + out.append(s) + break + else: + out.append(f"") + return "".join(out) + + tok.decode = _decode + return tok + + +@pytest.fixture +def tokenizer() -> MagicMock: + return _make_tokenizer() + + +@pytest.fixture +def parser(tokenizer) -> DeepSeekV4ThinkingReasoningParser: + return DeepSeekV4ThinkingReasoningParser(tokenizer) + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + + +def test_registration_resolves_to_v4_class(): + parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4") + assert parser_cls is DeepSeekV4ReasoningParser + + +@pytest.mark.parametrize( + "thinking_kwargs,expected_inner", + [ + ({"thinking": True}, 
DeepSeekV4ThinkingReasoningParser), + ({"enable_thinking": True}, DeepSeekV4ThinkingReasoningParser), + ({"thinking": False}, IdentityReasoningParser), + ({}, IdentityReasoningParser), + ], +) +def test_dispatch_based_on_thinking_kwarg(tokenizer, thinking_kwargs, expected_inner): + parser = DeepSeekV4ReasoningParser(tokenizer, chat_template_kwargs=thinking_kwargs) + assert isinstance(parser._parser, expected_inner) + + +# --------------------------------------------------------------------------- +# Healthy path (explicit ) — must match parent behavior +# --------------------------------------------------------------------------- + + +def test_healthy_implicit_start_explicit_end_in_delta(parser): + """Model emits reasoning text then in the same delta; the + parent's R1 splitter handles this — V4 must not interfere.""" + delta = parser.extract_reasoning_streaming( + previous_text="some reasoning ", + current_text="some reasoning after", + delta_text="after", + previous_token_ids=[100, 101], + current_token_ids=[100, 101, END_TOKEN_ID, 102], + delta_token_ids=[END_TOKEN_ID, 102], + ) + assert delta is not None + assert delta.reasoning == "" + assert delta.content == "after" + + +def test_healthy_explicit_end_in_previous_emits_content(parser): + """Once has been seen, the parser emits delta_text as content.""" + delta = parser.extract_reasoning_streaming( + previous_text="reasoning", + current_text="reasoningsome content", + delta_text="some content", + previous_token_ids=[100, END_TOKEN_ID], + current_token_ids=[100, END_TOKEN_ID, 101, 102], + delta_token_ids=[101, 102], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "some content" + + +def test_healthy_explicit_start_in_delta(parser): + """Model emits both and in the same delta.""" + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="quick thoughtanswer", + delta_text="quick thoughtanswer", + previous_token_ids=[], + 
current_token_ids=[START_TOKEN_ID, 200, 201, END_TOKEN_ID, 202], + delta_token_ids=[START_TOKEN_ID, 200, 201, END_TOKEN_ID, 202], + ) + assert delta is not None + assert delta.reasoning == "quick thought" + assert delta.content == "answer" + + +# --------------------------------------------------------------------------- +# Defensive path (no , DSML marker appears) — the fix +# --------------------------------------------------------------------------- + + +def test_implicit_end_marker_in_isolated_delta(parser): + """Marker arrives as its own delta after pure reasoning. The marker + token alone should be classified as content, not reasoning.""" + # First, two reasoning-only deltas to populate state. + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="step 1 ", + delta_text="step 1 ", + previous_token_ids=[], + current_token_ids=[300, 301], + delta_token_ids=[300, 301], + ) + assert delta is not None + assert delta.reasoning == "step 1 " + assert delta.content is None + assert parser._implicit_end_seen is False + + delta = parser.extract_reasoning_streaming( + previous_text="step 1 ", + current_text="step 1 step 2", + delta_text="step 2", + previous_token_ids=[300, 301], + current_token_ids=[300, 301, 302, 303], + delta_token_ids=[302, 303], + ) + assert delta is not None + assert delta.reasoning == "step 2" + assert delta.content is None + assert parser._implicit_end_seen is False + + # Now the marker arrives. 
+ delta = parser.extract_reasoning_streaming( + previous_text="step 1 step 2", + current_text=f"step 1 step 2{DSML_MARKER}", + delta_text=DSML_MARKER, + previous_token_ids=[300, 301, 302, 303], + current_token_ids=[300, 301, 302, 303, DSML_MARKER_TOKEN_ID], + delta_token_ids=[DSML_MARKER_TOKEN_ID], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == DSML_MARKER + assert parser._implicit_end_seen is True + + +def test_implicit_end_marker_within_delta_split(parser): + """Marker appears partway through a delta — split it at the boundary.""" + delta_text = f'tail of reasoning{DSML_MARKER}\n<|DSML|invoke name="w"' + delta = parser.extract_reasoning_streaming( + previous_text="head ", + current_text=f"head {delta_text}", + delta_text=delta_text, + previous_token_ids=[400], + current_token_ids=[400, 401, DSML_MARKER_TOKEN_ID, 402, 403], + delta_token_ids=[401, DSML_MARKER_TOKEN_ID, 402, 403], + ) + assert delta is not None + assert delta.reasoning == "tail of reasoning" + assert delta.content == f'{DSML_MARKER}\n<|DSML|invoke name="w"' + assert parser._implicit_end_seen is True + + +def test_partial_implicit_end_marker_prefix_is_buffered(parser): + """A DSML marker split across text deltas must not leak into reasoning.""" + delta = parser.extract_reasoning_streaming( + previous_text="reasoning ", + current_text="reasoning <|DSML|tool", + delta_text="<|DSML|tool", + previous_token_ids=[450], + current_token_ids=[450, 451], + delta_token_ids=[451], + ) + assert delta is None + assert parser._implicit_end_seen is False + + delta = parser.extract_reasoning_streaming( + previous_text="reasoning <|DSML|tool", + current_text=f"reasoning {DSML_MARKER}\n<|DSML|invoke", + delta_text="_calls>\n<|DSML|invoke", + previous_token_ids=[450, 451], + current_token_ids=[450, 451, 452, 453], + delta_token_ids=[452, 453], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == f"{DSML_MARKER}\n<|DSML|invoke" + assert 
parser._implicit_end_seen is True + + +def test_partial_implicit_end_marker_with_reasoning_prefix(parser): + """Reasoning before a split DSML marker is emitted, but the marker waits.""" + delta = parser.extract_reasoning_streaming( + previous_text="head ", + current_text="head tail reasoning<|DSML|tool", + delta_text="tail reasoning<|DSML|tool", + previous_token_ids=[460], + current_token_ids=[460, 461, 462], + delta_token_ids=[461, 462], + ) + assert delta is not None + assert delta.reasoning == "tail reasoning" + assert delta.content is None + assert parser._implicit_end_seen is False + + delta = parser.extract_reasoning_streaming( + previous_text="head tail reasoning<|DSML|tool", + current_text=f"head tail reasoning{DSML_MARKER}", + delta_text="_calls>", + previous_token_ids=[460, 461, 462], + current_token_ids=[460, 461, 462, 463], + delta_token_ids=[463], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == DSML_MARKER + assert parser._implicit_end_seen is True + + +def test_partial_implicit_end_marker_can_span_multiple_deltas(parser): + """The buffered marker prefix can grow across more than two chunks.""" + first = parser.extract_reasoning_streaming( + previous_text="r", + current_text="r<|DSML|", + delta_text="<|DSML|", + previous_token_ids=[470], + current_token_ids=[470, 471], + delta_token_ids=[471], + ) + assert first is None + + second = parser.extract_reasoning_streaming( + previous_text="r<|DSML|", + current_text="r<|DSML|tool", + delta_text="tool", + previous_token_ids=[470, 471], + current_token_ids=[470, 471, 472], + delta_token_ids=[472], + ) + assert second is None + assert parser._implicit_end_seen is False + + third = parser.extract_reasoning_streaming( + previous_text="r<|DSML|tool", + current_text=f"r{DSML_MARKER}", + delta_text="_calls>", + previous_token_ids=[470, 471, 472], + current_token_ids=[470, 471, 472, 473], + delta_token_ids=[473], + ) + assert third is not None + assert third.reasoning is None 
+ assert third.content == DSML_MARKER + assert parser._implicit_end_seen is True + + +def test_subsequent_delta_after_implicit_end_is_content(parser): + """Once the implicit end fires, every later delta is content.""" + # Seed the parser by flipping the sticky flag via a marker delta. + parser.extract_reasoning_streaming( + previous_text="reasoning", + current_text=f"reasoning{DSML_MARKER}", + delta_text=DSML_MARKER, + previous_token_ids=[500], + current_token_ids=[500, DSML_MARKER_TOKEN_ID], + delta_token_ids=[DSML_MARKER_TOKEN_ID], + ) + assert parser._implicit_end_seen is True + + # Next delta: pure tool-call body. Should be content. + delta = parser.extract_reasoning_streaming( + previous_text=f"reasoning{DSML_MARKER}", + current_text=f"reasoning{DSML_MARKER}\n<|DSML|invoke", + delta_text="\n<|DSML|invoke", + previous_token_ids=[500, DSML_MARKER_TOKEN_ID], + current_token_ids=[500, DSML_MARKER_TOKEN_ID, 600, 601], + delta_token_ids=[600, 601], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "\n<|DSML|invoke" + + +def test_marker_does_not_fire_when_explicit_start_present(parser): + """If the explicit ```` token is in the stream, defer to parent. + This guards against false-positive splits when something that looks + like a marker shows up in the user's prompt history. + """ + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text=f"discussing {DSML_MARKER}", + delta_text=f"discussing {DSML_MARKER}", + previous_token_ids=[START_TOKEN_ID], + current_token_ids=[START_TOKEN_ID, 700, 701, DSML_MARKER_TOKEN_ID], + delta_token_ids=[700, 701, DSML_MARKER_TOKEN_ID], + ) + assert delta is not None + # Parent puts everything after into reasoning until . 
+ assert delta.reasoning == f"discussing {DSML_MARKER}" + assert delta.content is None + assert parser._implicit_end_seen is False + + +# --------------------------------------------------------------------------- +# is_reasoning_end / is_reasoning_end_streaming +# --------------------------------------------------------------------------- + + +def test_is_reasoning_end_with_explicit_end_token(parser): + assert parser.is_reasoning_end([100, END_TOKEN_ID, 101]) is True + + +def test_is_reasoning_end_with_implicit_marker(parser): + """When start/end tokens are absent, decoding to text and finding the + marker counts as end-of-reasoning.""" + assert parser.is_reasoning_end([300, 301, DSML_MARKER_TOKEN_ID]) is True + + +def test_is_reasoning_end_pure_reasoning(parser): + assert parser.is_reasoning_end([300, 301, 302]) is False + + +def test_is_reasoning_end_streaming_sticky_after_split(parser): + """After ``extract_reasoning_streaming`` flips the sticky flag, + ``is_reasoning_end_streaming`` must report end-of-reasoning for any + subsequent delta — even one that contains neither nor the + marker.""" + # Seed via marker delta. + parser.extract_reasoning_streaming( + previous_text="r", + current_text=f"r{DSML_MARKER}", + delta_text=DSML_MARKER, + previous_token_ids=[800], + current_token_ids=[800, DSML_MARKER_TOKEN_ID], + delta_token_ids=[DSML_MARKER_TOKEN_ID], + ) + assert parser._implicit_end_seen is True + # Now a plain content-only delta. 
+ assert ( + parser.is_reasoning_end_streaming([800, DSML_MARKER_TOKEN_ID, 900], [900]) + is True + ) + + +# --------------------------------------------------------------------------- +# Sanity: parent's empty-delta and single-token guards still apply +# --------------------------------------------------------------------------- + + +def test_single_end_token_delta_returns_none(parser): + """Parent contract: a delta containing only the end token returns + ``None`` — the orchestrator handles the transition via + ``is_reasoning_end``.""" + out = parser.extract_reasoning_streaming( + previous_text="r", + current_text="r", + delta_text="", + previous_token_ids=[900], + current_token_ids=[900, END_TOKEN_ID], + delta_token_ids=[END_TOKEN_ID], + ) + assert out is None diff --git a/tests/tokenizers_/test_deepseek_v4.py b/tests/tokenizers_/test_deepseek_v4.py index 358732eabf40..0a3253957add 100644 --- a/tests/tokenizers_/test_deepseek_v4.py +++ b/tests/tokenizers_/test_deepseek_v4.py @@ -8,8 +8,14 @@ import pytest from vllm.entrypoints.chat_utils import parse_chat_messages +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatMessage, +) +from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.renderers.registry import RENDERER_REGISTRY from vllm.tokenizers.deepseek_v4 import get_deepseek_v4_tokenizer +from vllm.tokenizers.deepseek_v4_encoding import encode_arguments_to_dsml from vllm.tokenizers.registry import TokenizerRegistry FIXTURES_DIR = Path(__file__).parent / "fixtures" / "deepseek_v4" @@ -96,6 +102,130 @@ def test_deepseek_v4_enables_thinking_with_compatible_kwargs(kwargs): assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") +def test_deepseek_v4_honors_official_thinking_request_field(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + } + ) + 
chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + prompt = _tokenizer().apply_chat_template( + request.messages, + tokenize=False, + **chat_kwargs, + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") + + +def test_deepseek_v4_defaults_to_official_thinking_for_openai_request(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + + +def test_deepseek_v4_preserves_official_reasoning_content_alias(): + messages = [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "reasoning_content": "because", "content": "A1"}, + {"role": "user", "content": "Q2"}, + ] + + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + + assert conversation[1]["reasoning"] == "because" + assert conversation[1]["reasoning_content"] == "because" + + +def test_deepseek_v4_response_messages_expose_reasoning_content_alias(): + message = ChatMessage(role="assistant", reasoning="because", content="answer") + delta = DeltaMessage(reasoning="because") + + assert message.reasoning_content == "because" + assert delta.reasoning_content == "because" + assert ( + ChatMessage( + role="assistant", + reasoning_content="because", + content="answer", + ).reasoning + == "because" + ) + + +def test_deepseek_v4_preserves_official_prefix_assistant_message(): + messages = [ + {"role": "user", "content": "Please write quick sort code"}, + {"role": "assistant", "content": "```python\n", "prefix": True}, + ] + + conversation, _, _ = parse_chat_messages( + 
messages, + _model_config(), + content_format="string", + ) + prompt = _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tokenize=False, + ) + + assert conversation[1]["prefix"] is True + assert conversation[1]["wo_eos"] is True + assert prompt.endswith("<|Assistant|>```python\n") + assert not prompt.endswith("<|end▁of▁sentence|>") + + +def test_deepseek_v4_thinking_ignores_sampling_controls(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert sampling_params.temperature == 1.0 + assert sampling_params.top_p == 1.0 + assert sampling_params.top_k == 0 + assert sampling_params.presence_penalty == 0.0 + assert sampling_params.frequency_penalty == 0.0 + + def test_deepseek_v4_uses_v4_tool_prompt_from_request_tools(): tools = [ { @@ -183,6 +313,66 @@ def test_deepseek_v4_renders_parsed_history_tool_arguments(): assert 'parameter name="arguments"' not in prompt +@pytest.mark.parametrize( + ("tool_call", "expected_parameter"), + [ + ({"name": "refresh", "arguments": None}, None), + ({"name": "refresh"}, None), + ({"name": "refresh", "arguments": ""}, None), + ( + {"name": "refresh", "arguments": '{"target": "cache"}'}, + '<|DSML|parameter name="target" string="true">cache', + ), + ( + {"name": "refresh", "arguments": {"target": "cache"}}, + '<|DSML|parameter name="target" string="true">cache', + ), + ], +) +def test_deepseek_v4_encodes_empty_history_tool_arguments( + tool_call, expected_parameter +): + prompt = encode_arguments_to_dsml(tool_call) + + if 
expected_parameter is None: + assert prompt == "" + else: + assert expected_parameter in prompt + + +def test_deepseek_v4_renders_openai_history_tool_call_with_null_arguments(): + messages = [ + {"role": "user", "content": "Refresh state"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "refresh", + "arguments": None, + }, + } + ], + }, + ] + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + + prompt = _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tokenize=False, + ) + + assert '<|DSML|invoke name="refresh">' in prompt + assert "<|DSML|parameter" not in prompt + + @pytest.mark.parametrize("reasoning_effort", ["minimal", "low", "medium", "high"]) def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort): prompt = _tokenizer().apply_chat_template( @@ -288,3 +478,136 @@ def test_deepseek_v4_matches_reference_golden_fixtures(case_id, kwargs): expected = (FIXTURES_DIR / f"test_output_{case_id}.txt").read_text() assert prompt == expected + + +@pytest.mark.parametrize( + "model", + [ + "deepseek-ai/DeepSeek-V4-Flash", + "deepseek-ai/DeepSeek-V4-Pro", + ], +) +def test_deepseek_v4_official_api_defaults_to_thinking_for_v4_family(model): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": model, + "messages": [{"role": "user", "content": "Hello"}], + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + + +def test_deepseek_v4_official_api_uses_model_config_for_family_detection(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + 
{ + "model": "local-ds4-alias", + "messages": [{"role": "user", "content": "Hello"}], + "temperature": 0.2, + } + ) + model_config = SimpleNamespace( + hf_config=SimpleNamespace(model_type="deepseek_v4", architectures=[]), + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs, + model_config=model_config, + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + model_config=model_config, + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + assert sampling_params.temperature == 1.0 + + +def test_deepseek_v4_official_api_sampling_override_can_be_disabled(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "deepseek_v4_sampling_override": False, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "min_p": 0.05, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert sampling_params.temperature == 0.2 + assert sampling_params.top_p == 0.3 + assert sampling_params.top_k == 4 + assert sampling_params.min_p == 0.05 + assert sampling_params.presence_penalty == 1.5 + assert sampling_params.frequency_penalty == 1.25 + + +def test_deepseek_v4_official_api_sampling_override_is_v4_only(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-R1", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": 
"enabled"}, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "min_p": 0.05, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert "thinking" not in chat_kwargs + assert "enable_thinking" not in chat_kwargs + assert sampling_params.temperature == 0.2 + assert sampling_params.top_p == 0.3 + assert sampling_params.top_k == 4 + assert sampling_params.min_p == 0.05 + assert sampling_params.presence_penalty == 1.5 + assert sampling_params.frequency_penalty == 1.25 diff --git a/tests/tool_parsers/test_structural_tag_registry.py b/tests/tool_parsers/test_structural_tag_registry.py index bd84b2cbbfac..0a554bdc5cd0 100644 --- a/tests/tool_parsers/test_structural_tag_registry.py +++ b/tests/tool_parsers/test_structural_tag_registry.py @@ -91,6 +91,40 @@ def test_get_model_structural_tag_supports_all_xgrammar_builtins( assert isinstance(tag, StructuralTag) +def test_deepseek_v4_chat_structural_tag_allows_stray_think_end( + sample_tools: list[ChatCompletionToolsParam], +): + tag = get_model_structural_tag( + model="deepseek_v4", + tools=sample_tools, + tool_choice="auto", + reasoning=False, + ) + + dumped = tag.model_dump() + excludes = dumped["format"]["excludes"] + assert "" in excludes + assert "" not in excludes + + +def test_deepseek_v4_reasoning_structural_tag_is_unchanged( + sample_tools: list[ChatCompletionToolsParam], +): + tag = get_model_structural_tag( + model="deepseek_v4", + tools=sample_tools, + tool_choice="auto", + reasoning=True, + ) + + dumped = tag.model_dump() + assert dumped["format"]["type"] == "sequence" + assert dumped["format"]["elements"][1]["excludes"] == [ + "", + "", + ] + + def test_get_model_structural_tag_supports_vllm_hermes( sample_tools: list[ChatCompletionToolsParam], ): diff --git 
a/tests/utils_/test_import_utils.py b/tests/utils_/test_import_utils.py index 464f209f0f25..dc1e55bcd0e4 100644 --- a/tests/utils_/test_import_utils.py +++ b/tests/utils_/test_import_utils.py @@ -1,12 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import builtins +import importlib.util from unittest.mock import MagicMock, patch import pytest +from vllm.utils import import_utils from vllm.utils.import_utils import PlaceholderModule, _has_module +def _clear_import_utils_caches(): + import_utils._has_module.cache_clear() + if hasattr(import_utils.has_cutedsl, "cache_clear"): + import_utils.has_cutedsl.cache_clear() + + def _raises_module_not_found(): return pytest.raises(ModuleNotFoundError, match="No module named") @@ -101,3 +110,27 @@ def test_result_is_cached(self): result = _has_module("json") # should hit cache mock_spec.assert_not_called() assert result is True + + +def test_has_cutedsl_requires_importable_cutlass(monkeypatch: pytest.MonkeyPatch): + real_find_spec = importlib.util.find_spec + real_import = builtins.__import__ + + def fake_find_spec(name, *args, **kwargs): + if name == "cutlass": + return object() + return real_find_spec(name, *args, **kwargs) + + def fake_import(name, *args, **kwargs): + if name == "cutlass": + raise ImportError("broken CUTLASS DSL") + return real_import(name, *args, **kwargs) + + _clear_import_utils_caches() + monkeypatch.setattr(import_utils.importlib.util, "find_spec", fake_find_spec) + monkeypatch.setattr(builtins, "__import__", fake_import) + + try: + assert import_utils.has_cutedsl() is False + finally: + _clear_import_utils_caches() diff --git a/tests/v1/attention/test_deepseek_v4_sparse_swa.py b/tests/v1/attention/test_deepseek_v4_sparse_swa.py new file mode 100644 index 000000000000..d2f7c1d5041f --- /dev/null +++ b/tests/v1/attention/test_deepseek_v4_sparse_swa.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project + +import torch + +from tests.v1.attention.utils import create_vllm_config +from vllm.config import SpeculativeConfig +from vllm.v1.attention.backends.mla.sparse_swa import ( + DeepseekSparseSWAMetadataBuilder, +) +from vllm.v1.kv_cache_interface import MLAAttentionSpec + + +def test_sparse_swa_reorder_threshold_matches_spec_decode_threshold(): + vllm_config = create_vllm_config( + block_size=256, + hf_config_override={ + "sliding_window": 128, + "compress_ratios": [1, 4, 128], + }, + ) + vllm_config.speculative_config = SpeculativeConfig( + method="ngram", + num_speculative_tokens=2, + ) + kv_cache_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.bfloat16, + compress_ratio=4, + ) + + builder = DeepseekSparseSWAMetadataBuilder( + kv_cache_spec=kv_cache_spec, + layer_names=["dummy"], + vllm_config=vllm_config, + device=torch.device("cpu"), + ) + + assert builder.decode_threshold == 3 + assert builder.reorder_batch_threshold == builder.decode_threshold diff --git a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py index 159bb8af3fb9..eb535598fda6 100644 --- a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py +++ b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py @@ -4,12 +4,70 @@ import pytest import torch +import vllm.v1.attention.backends.mla.indexer as indexer_module from tests.v1.attention.utils import create_vllm_config from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadataBuilder from vllm.v1.kv_cache_interface import MLAAttentionSpec +@pytest.mark.parametrize( + ("is_prefilling", "expected_treat_short_extends_as_decodes"), + [ + (torch.tensor([False, False]), True), + (torch.tensor([False, True]), False), + ], +) +def test_indexer_builder_keeps_short_prefill_continuations_as_prefills( + monkeypatch, + 
is_prefilling, + expected_treat_short_extends_as_decodes, +): + builder = object.__new__(DeepseekV32IndexerMetadataBuilder) + builder.reorder_batch_threshold = 1 + builder.use_flattening = False + + captured = {} + + def fake_split_decodes_and_prefills( + common_attn_metadata, + *, + decode_threshold=1, + require_uniform=False, + treat_short_extends_as_decodes=True, + ): + captured["treat_short_extends_as_decodes"] = treat_short_extends_as_decodes + raise RuntimeError("stop after split_decodes_and_prefills") + + monkeypatch.setattr( + indexer_module, + "split_decodes_and_prefills", + fake_split_decodes_and_prefills, + ) + query_start_loc = torch.tensor([0, 1, 2], dtype=torch.int32) + metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc.clone(), + seq_lens=torch.tensor([128, 129], dtype=torch.int32), + num_reqs=2, + num_actual_tokens=2, + max_query_len=1, + max_seq_len=129, + block_table_tensor=torch.zeros((2, 1), dtype=torch.int32), + slot_mapping=torch.arange(2, dtype=torch.int64), + is_prefilling=is_prefilling, + seq_lens_cpu_upper_bound=torch.tensor([128, 129], dtype=torch.int32), + ) + + with pytest.raises(RuntimeError, match="stop after"): + builder.build(common_prefix_len=0, common_attn_metadata=metadata) + + assert ( + captured["treat_short_extends_as_decodes"] + is expected_treat_short_extends_as_decodes + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size(): """Regression test: DeepseekV4 compression path must compute slot_mapping from diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py new file mode 100644 index 000000000000..178983449b2c --- /dev/null +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the 
vLLM project + +import pytest +import torch + +import vllm.utils.deep_gemm as deep_gemm_utils +from vllm.model_executor.layers.sparse_attn_indexer import ( + _decode_logits_width, + _decode_topk_logits_width, + _sparse_indexer_requires_deep_gemm, +) +from vllm.models.deepseek_v4.nvidia.ops import ( + sm12x_deep_gemm_fallbacks, + sm12x_mqa, +) +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv + + +def _make_indexer_kv_cache( + kv_fp8: torch.Tensor, + kv_scale: torch.Tensor, +) -> torch.Tensor: + num_blocks, block_size, num_kv_heads, head_dim = kv_fp8.shape + assert num_kv_heads == 1 + fused_kv = torch.empty( + num_blocks, + block_size, + 1, + head_dim + torch.float32.itemsize, + device=kv_fp8.device, + dtype=torch.uint8, + ) + block_stride = fused_kv.stride(0) + kv_values = torch.as_strided( + fused_kv, + size=kv_fp8.shape, + stride=(block_stride, head_dim, head_dim, 1), + ) + kv_scales = torch.as_strided( + fused_kv, + size=(num_blocks, block_size, 1, torch.float32.itemsize), + stride=(block_stride, torch.float32.itemsize, torch.float32.itemsize, 1), + storage_offset=block_size * head_dim, + ) + kv_values.copy_(kv_fp8.view(torch.uint8)) + kv_scales.copy_(kv_scale.contiguous().view(torch.uint8)) + return fused_kv + + +def _reference_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_fp8: torch.Tensor, + kv_scale: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + batch_size, next_n, _, _ = q_fp8.shape + _, block_size, _, _ = kv_fp8.shape + logits = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device=q_fp8.device, + dtype=torch.float32, + ) + q = q_fp8.float() + kv = kv_fp8.float() * kv_scale.float() + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + context_len = min( + int(context_lens[batch_idx, next_idx].item()), + max_model_len, + ) + for token_idx in 
range(context_len): + block_idx = block_tables[batch_idx, token_idx // block_size] + block_offset = token_idx % block_size + k = kv[block_idx, block_offset, 0] + scores = (q[batch_idx, next_idx] * k).sum(dim=-1).relu() + logits[row, token_idx] = (scores * weights[row]).sum() + return logits + + +def _assert_topk_indices_match_values( + logits: torch.Tensor, + actual: torch.Tensor, + topk_tokens: int, +) -> None: + for row_idx in range(logits.shape[0]): + valid_count = int(torch.isfinite(logits[row_idx]).sum().item()) + row_topk = min(topk_tokens, valid_count) + expected = torch.topk(logits[row_idx], row_topk, dim=0).indices.to(torch.int32) + actual_row = actual[row_idx, :row_topk] + assert torch.all(actual_row >= 0) + assert torch.all(torch.isfinite(logits[row_idx, actual_row.long()])) + actual_values = torch.sort(logits[row_idx, actual_row.long()]).values + expected_values = torch.sort(logits[row_idx, expected.long()]).values + torch.testing.assert_close( + actual_values, + expected_values, + rtol=0, + atol=0, + ) + if row_topk < topk_tokens: + torch.testing.assert_close( + actual[row_idx, row_topk:], + torch.full_like(actual[row_idx, row_topk:], -1), + rtol=0, + atol=0, + ) + + +def test_decode_logits_width_uses_active_context_bound(): + assert _decode_logits_width(262144, 1024) == 1024 + assert _decode_logits_width(4096, 8192) == 4096 + assert _decode_logits_width(4096, 0) == 4096 + assert _decode_logits_width(0, 1024) == 0 + + +def test_decode_topk_logits_width_keeps_topk_kernel_width(): + assert _decode_topk_logits_width(262144, 1024, 512) == 1024 + assert _decode_topk_logits_width(262144, 128, 512) == 512 + assert _decode_topk_logits_width(300, 128, 512) == 300 + assert _decode_topk_logits_width(0, 128, 512) == 0 + + +def test_sparse_indexer_requires_deep_gemm_for_sm120_fp4_cache(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_cuda", + lambda: True, + ) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda family: 
family == 120, + ) + + assert not _sparse_indexer_requires_deep_gemm(use_fp4_cache=False) + assert _sparse_indexer_requires_deep_gemm(use_fp4_cache=True) + + +def test_sparse_indexer_requires_deep_gemm_for_other_cuda_arches(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_cuda", + lambda: True, + ) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda family: False, + ) + + assert _sparse_indexer_requires_deep_gemm(use_fp4_cache=False) + + +def test_sm120_direct_mqa_logits_block_m_prefers_short_prefill_tile(): + assert sm12x_mqa._fp8_mqa_logits_block_m(1024, 1024) == 16 + assert sm12x_mqa._fp8_mqa_logits_block_m(4096, 4096) == 16 + assert sm12x_mqa._fp8_mqa_logits_block_m(16384, 16384) == 16 + assert sm12x_mqa._fp8_mqa_logits_block_m(65536, 65536) == 64 + + +def test_tf32_hc_prenorm_gemm_does_not_specialize_prefill_token_count(): + kernel = sm12x_mqa._tf32_hc_prenorm_gemm_kernel + constexpr_names = {kernel.arg_names[index] for index in kernel.constexprs} + + assert "M" not in constexpr_names + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_mqa_direct_topk_uses_triton_logits_when_logits_fit( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(11) + num_q, num_heads, head_dim = 8, 4, 64 + seq_len_kv, topk_tokens = 17, 5 + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES", + num_q * seq_len_kv * torch.float32.itemsize, + ) + + q = torch.randn(num_q, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + kv_scale = kv.abs().float().amax(dim=-1).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()[:, None]).to(torch.float8_e4m3fn) + weights = torch.randn(num_q, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(num_q, device="cuda", 
dtype=torch.int32) % 3 + cu_seqlen_ke = torch.full((num_q,), seq_len_kv, device="cuda", dtype=torch.int32) + out = torch.empty(num_q, topk_tokens, device="cuda", dtype=torch.int32) + + original_triton = sm12x_mqa.fp8_mqa_logits_triton + calls = 0 + + def wrapped_triton(*args, **kwargs): + nonlocal calls + calls += 1 + return original_triton(*args, **kwargs) + + monkeypatch.setattr(sm12x_mqa, "fp8_mqa_logits_triton", wrapped_triton) + original_topk_op = sm12x_deep_gemm_fallbacks._top_k_per_row_prefill_op() + if original_topk_op is None: + pytest.skip("top_k_per_row_prefill op is unavailable") + topk_calls = 0 + + def wrapped_topk_op(*args, **kwargs): + nonlocal topk_calls + topk_calls += 1 + return original_topk_op(*args, **kwargs) + + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_top_k_per_row_prefill_op", + lambda: wrapped_topk_op, + ) + + assert deep_gemm_utils.fp8_fp4_mqa_topk_indices( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + out, + ) + assert calls == 1 + assert topk_calls == 1 + + reference_logits = sm12x_deep_gemm_fallbacks._fp8_mqa_logits_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + _assert_topk_indices_match_values(reference_logits, out, topk_tokens) + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_mqa_direct_topk_uses_triton_chunks_when_logits_do_not_fit( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(13) + num_q, num_heads, head_dim = 9, 4, 64 + seq_len_kv, topk_tokens = 23, 6 + logits_bytes = num_q * seq_len_kv * torch.float32.itemsize + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES", + logits_bytes - 1, + ) + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_SM120_MQA_TRITON_CHUNKED_TOPK_CHUNK_SIZE", + 7, + ) + + q = torch.randn(num_q, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + q_fp8 
= q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + kv_scale = kv.abs().float().amax(dim=-1).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()[:, None]).to(torch.float8_e4m3fn) + weights = torch.randn(num_q, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(num_q, device="cuda", dtype=torch.int32) % 4 + cu_seqlen_ke = torch.full((num_q,), seq_len_kv, device="cuda", dtype=torch.int32) + out = torch.empty(num_q, topk_tokens, device="cuda", dtype=torch.int32) + + original_triton = sm12x_mqa.fp8_mqa_logits_triton + calls = 0 + + def wrapped_triton(*args, **kwargs): + nonlocal calls + calls += 1 + return original_triton(*args, **kwargs) + + monkeypatch.setattr(sm12x_mqa, "fp8_mqa_logits_triton", wrapped_triton) + + assert deep_gemm_utils.fp8_fp4_mqa_topk_indices( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + out, + ) + assert calls == cdiv(seq_len_kv, 7) + + reference_logits = sm12x_deep_gemm_fallbacks._fp8_mqa_logits_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + _assert_topk_indices_match_values(reference_logits, out, topk_tokens) + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_paged_mqa_direct_topk_matches_truncated_decode_width( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(7) + batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32 + block_size, max_model_len, num_blocks = 4, 64, 16 + active_max_len = 13 + topk_tokens = 6 + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_SM120_PAGED_MQA_TOPK_CHUNK_SIZE", + 7, + ) + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn( + 
num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = _make_indexer_kv_cache(kv_fp8, kv_scale) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor( + [[7, active_max_len], [9, 12]], device="cuda", dtype=torch.int32 + ) + block_tables = ( + torch.arange( + batch_size * cdiv(max_model_len, block_size), + device="cuda", + dtype=torch.int32, + ).reshape(batch_size, -1) + % num_blocks + ) + + full_width_topk = torch.empty( + batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32 + ) + truncated_width_topk = torch.empty_like(full_width_topk) + + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + max_model_len, + full_width_topk, + ) + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + active_max_len, + truncated_width_topk, + ) + + reference_logits = _reference_paged_mqa_logits( + q_fp8, + kv_fp8, + kv_scale, + weights, + context_lens, + block_tables, + active_max_len, + ) + expected_topk = torch.topk(reference_logits, topk_tokens, dim=1).indices.to( + torch.int32 + ) + + torch.testing.assert_close(truncated_width_topk, full_width_topk, rtol=0, atol=0) + torch.testing.assert_close(truncated_width_topk, expected_topk, rtol=0, atol=0) diff --git a/tests/v1/attention/test_sparse_mla_env.py b/tests/v1/attention/test_sparse_mla_env.py new file mode 100644 index 000000000000..b50d5f1b9dd8 --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_env.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for sparse MLA environment helpers.""" + +from 
vllm.v1.attention.backends.mla import sparse_mla_env + + +def test_prefill_topk_uses_sm12x_multi_request_guard(monkeypatch): + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) + monkeypatch.setattr( + sparse_mla_env.current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=1152, + compress_ratio=128, + request_count=2, + ) + == 256 + ) + + +def test_prefill_topk_relaxes_sm12x_single_request_c128a(monkeypatch): + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) + monkeypatch.setattr( + sparse_mla_env.current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=1152, + compress_ratio=128, + request_count=1, + ) + == 1024 + ) + + +def test_prefill_topk_keeps_default_for_other_lower_risk_shapes(monkeypatch): + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) + monkeypatch.setattr( + sparse_mla_env.current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=640, + compress_ratio=4, + request_count=2, + ) + == 512 + ) + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=128, + compress_ratio=1, + request_count=2, + ) + == 128 + ) + + +def test_prefill_topk_honors_explicit_env_override(monkeypatch): + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", "512") + monkeypatch.setattr( + sparse_mla_env.current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=1152, + compress_ratio=128, + request_count=2, + ) + == 512 + ) diff --git 
a/tests/v1/attention/test_sparse_mla_indexed_d512.py b/tests/v1/attention/test_sparse_mla_indexed_d512.py new file mode 100644 index 000000000000..8486de516281 --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_indexed_d512.py @@ -0,0 +1,291 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_indexed_d512_chunked_sparse_mla_attention, + accumulate_indexed_d512_split_sparse_mla_attention, + accumulate_indexed_sparse_mla_attention_chunk, +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_indexed_d512_split_sparse_mla_matches_indexed_accumulate(): + torch.cuda.set_device(0) + device = torch.device("cuda:0") + torch.manual_seed(17) + num_tokens = 64 + num_heads = 8 + head_dim = 512 + num_candidates = 640 + kv_tokens = 4096 + scale = head_dim**-0.5 + + q = torch.randn( + num_tokens, + num_heads, + head_dim, + device=device, + dtype=torch.bfloat16, + ) + kv = torch.randn(kv_tokens, head_dim, device=device, dtype=torch.bfloat16) + indices = torch.randint( + 0, + kv_tokens, + (num_tokens, num_candidates), + device=device, + dtype=torch.int32, + ) + lens = torch.randint( + num_candidates // 2, + num_candidates + 1, + (num_tokens,), + device=device, + dtype=torch.int32, + ) + + current_max = torch.full( + (num_tokens, num_heads), + -float("inf"), + device=device, + dtype=torch.float32, + ) + current_denom = torch.zeros_like(current_max) + current_acc = torch.zeros( + num_tokens, num_heads, head_dim, device=device, dtype=torch.float32 + ) + split_max = torch.empty_like(current_max) + split_denom = torch.empty_like(current_denom) + split_acc = torch.empty_like(current_acc) + split_scores = torch.empty( + num_tokens, + num_heads, + num_candidates, + device=device, + dtype=torch.float32, + ) + + accumulate_indexed_sparse_mla_attention_chunk( + q=q, + kv_flat=kv, + 
indices=indices, + lens=lens, + scale=scale, + max_score=current_max, + denom=current_denom, + acc=current_acc, + ) + accumulate_indexed_d512_split_sparse_mla_attention( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=split_max, + denom=split_denom, + acc=split_acc, + scores=split_scores, + ) + torch.cuda.synchronize() + + current = current_acc / current_denom[:, :, None] + split = split_acc / split_denom[:, :, None] + torch.testing.assert_close(split_max, current_max, atol=2e-5, rtol=2e-5) + torch.testing.assert_close(split_denom, current_denom, atol=2e-3, rtol=2e-3) + torch.testing.assert_close(split, current, atol=2e-3, rtol=2e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_indexed_d512_split_sparse_mla_matches_c128_combined_width(): + torch.cuda.set_device(0) + device = torch.device("cuda:0") + torch.manual_seed(23) + num_tokens = 64 + num_heads = 8 + head_dim = 512 + num_candidates = 1152 + kv_tokens = 4096 + scale = head_dim**-0.5 + + q = torch.randn( + num_tokens, + num_heads, + head_dim, + device=device, + dtype=torch.bfloat16, + ) + kv = torch.randn(kv_tokens, head_dim, device=device, dtype=torch.bfloat16) + indices = torch.randint( + 0, + kv_tokens, + (num_tokens, num_candidates), + device=device, + dtype=torch.int32, + ) + lens = torch.randint( + 128, + 1097, + (num_tokens,), + device=device, + dtype=torch.int32, + ) + + current_max = torch.full( + (num_tokens, num_heads), + -float("inf"), + device=device, + dtype=torch.float32, + ) + current_denom = torch.zeros_like(current_max) + current_acc = torch.zeros( + num_tokens, num_heads, head_dim, device=device, dtype=torch.float32 + ) + split_max = torch.empty_like(current_max) + split_denom = torch.empty_like(current_denom) + split_acc = torch.empty_like(current_acc) + split_scores = torch.empty( + num_tokens, + num_heads, + num_candidates, + device=device, + dtype=torch.float32, + ) + + 
accumulate_indexed_sparse_mla_attention_chunk( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=current_max, + denom=current_denom, + acc=current_acc, + ) + accumulate_indexed_d512_split_sparse_mla_attention( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=split_max, + denom=split_denom, + acc=split_acc, + scores=split_scores, + ) + torch.cuda.synchronize() + + current = current_acc / current_denom[:, :, None] + split = split_acc / split_denom[:, :, None] + torch.testing.assert_close(split_max, current_max, atol=2e-5, rtol=2e-5) + torch.testing.assert_close(split_denom, current_denom, atol=2e-3, rtol=2e-3) + torch.testing.assert_close(split, current, atol=2e-3, rtol=2e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_indexed_d512_chunked_sparse_mla_matches_wide_indexed_accumulate(): + torch.cuda.set_device(0) + device = torch.device("cuda:0") + torch.manual_seed(29) + num_tokens = 64 + num_heads = 8 + head_dim = 512 + num_candidates = 2048 + chunk_candidates = 1152 + kv_tokens = 8192 + scale = head_dim**-0.5 + + q = torch.randn( + num_tokens, + num_heads, + head_dim, + device=device, + dtype=torch.bfloat16, + ) + kv = torch.randn(kv_tokens, head_dim, device=device, dtype=torch.bfloat16) + indices = torch.randint( + 0, + kv_tokens, + (num_tokens, num_candidates), + device=device, + dtype=torch.int32, + ) + lens = torch.randint( + num_candidates // 2, + num_candidates + 1, + (num_tokens,), + device=device, + dtype=torch.int32, + ) + lens[0] = 0 + + current_max = torch.full( + (num_tokens, num_heads), + -float("inf"), + device=device, + dtype=torch.float32, + ) + current_denom = torch.zeros_like(current_max) + current_acc = torch.zeros( + num_tokens, num_heads, head_dim, device=device, dtype=torch.float32 + ) + chunked_max = torch.empty_like(current_max) + chunked_denom = torch.empty_like(current_denom) + chunked_acc = torch.empty_like(current_acc) + chunk_max = 
torch.empty_like(current_max) + chunk_denom = torch.empty_like(current_denom) + chunk_acc = torch.empty_like(current_acc) + chunk_scores = torch.empty( + num_tokens, + num_heads, + chunk_candidates, + device=device, + dtype=torch.float32, + ) + + accumulate_indexed_sparse_mla_attention_chunk( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=current_max, + denom=current_denom, + acc=current_acc, + ) + accumulate_indexed_d512_chunked_sparse_mla_attention( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=chunked_max, + denom=chunked_denom, + acc=chunked_acc, + scores=chunk_scores, + chunk_max_score=chunk_max, + chunk_denom=chunk_denom, + chunk_acc=chunk_acc, + ) + torch.cuda.synchronize() + + valid_rows = lens > 0 + current = current_acc[valid_rows] / current_denom[valid_rows, :, None] + chunked = chunked_acc[valid_rows] / chunked_denom[valid_rows, :, None] + torch.testing.assert_close( + chunked_max[valid_rows], + current_max[valid_rows], + atol=2e-5, + rtol=2e-5, + ) + torch.testing.assert_close( + chunked_denom[valid_rows], + current_denom[valid_rows], + atol=2e-3, + rtol=2e-3, + ) + torch.testing.assert_close(chunked, current, atol=2e-3, rtol=2e-3) + assert torch.isneginf(chunked_max[~valid_rows]).all() + assert torch.count_nonzero(chunked_denom[~valid_rows]) == 0 + assert torch.count_nonzero(chunked_acc[~valid_rows]) == 0 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3e375b8720ec..6a4734b26781 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -42,6 +42,7 @@ KVCacheSpecKind, MambaSpec, MLAAttentionSpec, + SlidingWindowMLASpec, SlidingWindowSpec, ) @@ -126,6 +127,27 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: ) +def make_mla_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + 
kv_cache_groups=[ + KVCacheGroupSpec( + ["layer"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=512, + dtype=torch.float8_e4m3fn, + cache_dtype_str="fp8_ds_mla", + compress_ratio=2, + model_version="deepseek_v4", + ), + ) + ], + ) + + def make_kv_cache_config_hybrid_model( block_size: int, num_blocks: int, @@ -1655,6 +1677,38 @@ def test_cache_blocks_multi_group(): ) +def test_deepseek_v4_mla_prompt_protection_scales_with_max_num_seqs(): + block_size = 4 + manager = make_kv_cache_manager( + make_mla_kv_cache_config(block_size=block_size, num_blocks=64), + max_model_len=16, + max_num_batched_tokens=16, + max_num_seqs=4, + enable_caching=True, + hash_block_size=block_size, + ) + + mla_manager = manager.coordinator.single_type_managers[0] + + assert mla_manager._max_protected_prompt_blocks() == 16 + + +def test_deepseek_v4_mla_prompt_protection_leaves_allocation_headroom(): + block_size = 4 + manager = make_kv_cache_manager( + make_mla_kv_cache_config(block_size=block_size, num_blocks=15), + max_model_len=16, + max_num_batched_tokens=16, + max_num_seqs=4, + enable_caching=True, + hash_block_size=block_size, + ) + + mla_manager = manager.coordinator.single_type_managers[0] + + assert mla_manager._max_protected_prompt_blocks() == 10 + + def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. 
@@ -2036,6 +2090,34 @@ def test_maybe_evict_cached_block(): assert pool.cached_block_hash_to_block._cache == {} +@pytest.mark.parametrize("cache_state", ["missing", "different_block"]) +def test_maybe_evict_cached_block_resets_stale_hash_on_miss(cache_state: str): + pool = BlockPool( + num_gpu_blocks=3, + enable_caching=True, + hash_block_size=16, + enable_kv_cache_events=True, + ) + block = pool.blocks[1] + other_block = pool.blocks[2] + block_hash = make_block_hash_with_group_id(BlockHash(b"stale"), 0) + + block.block_hash = block_hash + if cache_state == "different_block": + other_block.block_hash = block_hash + pool.cached_block_hash_to_block.insert(block_hash, other_block) + + assert pool._maybe_evict_cached_block(block) is False + assert block.block_hash is None + assert pool.kv_event_queue == [] + + if cache_state == "different_block": + assert pool.cached_block_hash_to_block._cache == {block_hash: other_block} + assert other_block.block_hash == block_hash + else: + assert pool.cached_block_hash_to_block._cache == {} + + @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) def test_kv_cache_events(blocks_to_cache: int): block_size = 16 @@ -2925,12 +3007,14 @@ def test_hybrid_cache_blocks_swa_tail_window_only(): ) -def test_hybrid_cache_blocks_clamped_to_lcm(): - """HybridKVCacheCoordinator.cache_blocks() clamps to scheduler_block_size. - Chunks past the last lcm-aligned boundary can never participate in a - cache hit (find_longest_cache_hit always returns lcm-aligned hits), so - caching them only pollutes the prefix-cache hash map and keeps blocks - on the LRU list that could otherwise return to the free pool.""" +def test_hybrid_cache_blocks_keeps_tail_blocks_but_hits_stay_lcm_aligned(): + """HybridKVCacheCoordinator.cache_blocks() keeps complete tail blocks. + + find_longest_cache_hit still returns lcm-aligned hits, so caching the + complete blocks past the last lcm boundary must not produce partial hybrid + cache hits. 
Keeping those blocks lets later turns complete a reusable + lcm-aligned segment instead of permanently dropping the tail. + """ block_size = 16 # Full attn block_size=32, SWA block_size=16 -> lcm=32. kv_cache_config = KVCacheConfig( @@ -2965,8 +3049,9 @@ def test_hybrid_cache_blocks_clamped_to_lcm(): hash_block_size=block_size, ) - # 7 hash-blocks of 16 tokens (112 tokens). With lcm=32 the clamp truncates - # to 96 tokens — SWA caches 6 hashes, full-attn caches 3. + # 7 hash-blocks of 16 tokens (112 tokens). The trailing SWA block is a + # complete block and should be cached, but a later hit is still capped by + # the full-attention group at the last 32-token lcm boundary. token_ids = [i for i in range(7) for _ in range(block_size)] req = make_request("0", token_ids, block_size, sha256) computed_blocks, _ = manager.get_computed_blocks(req) @@ -2980,16 +3065,18 @@ def test_hybrid_cache_blocks_clamped_to_lcm(): assert len(req.block_hashes) == 7 pool = manager.block_pool - # SWA group_id=1: hashes 0..5 cached (6 blocks * 16 tokens = 96), hash 6 - # spans tokens [96, 112) past the lcm boundary and must NOT be cached. - for i in range(6): + # SWA group_id=1: all complete SWA blocks should be cached. 
+ for i in range(7): assert ( pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[1]) is not None ), f"SWA hash {i} should be cached" - assert pool.get_cached_block(req.block_hashes[6], kv_cache_group_ids=[1]) is None, ( - "SWA hash 6 spans tokens past the lcm boundary; should not be cached" - ) + + warm = make_request("1", token_ids, block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(warm) + assert len(computed_blocks.blocks[0]) == 3 + assert len(computed_blocks.blocks[1]) == 6 + assert num_computed_tokens == 6 * block_size def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch): @@ -3470,6 +3557,272 @@ def test_can_fit_full_sequence_swa_cap_admits_long_prompt(): ) +def test_deepseek_v4_mla_prompt_cache_survives_decode_pressure(): + hash_block_size = 2 + full_block_size = 8 + swa_block_size = 2 + prompt_tokens = 35 + chunk_tokens = 4 * full_block_size + expected_hit_tokens = (prompt_tokens - 1) // full_block_size * full_block_size + + config = KVCacheConfig( + num_blocks=70, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=full_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_0"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_1"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_compressor_state"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + 
dtype=torch.float32, + sliding_window=2 * swa_block_size, + ), + ), + ], + ) + manager = KVCacheManager( + config, + max_model_len=128, + scheduler_block_size=full_block_size, + max_num_batched_tokens=chunk_tokens, + enable_caching=True, + hash_block_size=hash_block_size, + ) + + def run_request(request: Request, num_decode_tokens: int) -> int: + computed_blocks, num_computed_tokens = manager.get_computed_blocks(request) + computed_so_far = num_computed_tokens + remaining_prompt_tokens = request.num_prompt_tokens - num_computed_tokens + first_chunk = True + while remaining_prompt_tokens > 0: + num_new_tokens = min(chunk_tokens, remaining_prompt_tokens) + allocated = manager.allocate_slots( + request, + num_new_tokens, + num_computed_tokens if first_chunk else 0, + computed_blocks if first_chunk else None, + ) + assert allocated is not None + computed_so_far += num_new_tokens + request.num_computed_tokens = computed_so_far + remaining_prompt_tokens -= num_new_tokens + first_chunk = False + + for i in range(num_decode_tokens): + request.append_output_token_ids(10_000 + i) + allocated = manager.allocate_slots(request, 1) + assert allocated is not None + computed_so_far += 1 + request.num_computed_tokens = computed_so_far + return num_computed_tokens + + prompt_a = list(range(prompt_tokens)) + req_a = make_request("a", prompt_a, hash_block_size, sha256) + assert run_request(req_a, num_decode_tokens=0) == 0 + manager.free(req_a) + + warm_a = make_request("warm_a", prompt_a, hash_block_size, sha256) + assert run_request(warm_a, num_decode_tokens=8) == expected_hit_tokens + assert manager.get_num_common_prefix_blocks("warm_a")[0] >= ( + expected_hit_tokens // full_block_size + ) + manager.free(warm_a) + + pressure_blocks = manager.block_pool.get_new_blocks( + manager.block_pool.get_num_free_blocks() + ) + manager.block_pool.free_blocks(reversed(pressure_blocks)) + + req_a_again = make_request("a_again", prompt_a, hash_block_size, sha256) + _, num_computed_tokens = 
manager.get_computed_blocks(req_a_again) + assert num_computed_tokens == expected_hit_tokens + + +def test_deepseek_v4_mla_cached_prompts_do_not_block_admission(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + num_prompts = 10 + num_blocks = 80 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + scheduler_block_size=block_size, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + + for i in range(num_prompts): + prompt = list(range(i * 1000, i * 1000 + prompt_tokens)) + req = make_request(f"protected_{i}", prompt, block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert manager.block_pool.get_num_free_blocks() < 64 + + long_req = make_request( + "long", + list(range(100_000, 100_000 + 64 * block_size)), + block_size, + sha256, + ) + assert ( + manager.allocate_slots(long_req, block_size, full_sequence_must_fit=True) + is not None + ) + + +@pytest.mark.skip_global_cleanup +def test_deepseek_v4_mla_prefix_hit_under_pressure_does_not_overallocate(): + block_size = 8 + prompt_tokens = 5 * block_size + manager = KVCacheManager( + KVCacheConfig( + num_blocks=12, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + scheduler_block_size=block_size, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + + warm_req = make_request("warm", list(range(prompt_tokens)), 
block_size, sha256) + assert manager.allocate_slots(warm_req, prompt_tokens) is not None + warm_req.num_computed_tokens = prompt_tokens + manager.free(warm_req) + + pressure_blocks = manager.block_pool.get_new_blocks( + manager.block_pool.get_num_free_blocks() - 2 + ) + + reuse_req = make_request( + "reuse", + list(range(prompt_tokens)) + list(range(10_000, 10_000 + 2 * block_size)), + block_size, + sha256, + ) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(reuse_req) + + try: + assert ( + manager.allocate_slots( + reuse_req, + 2 * block_size, + num_new_computed_tokens=num_computed_tokens, + new_computed_blocks=computed_blocks, + ) + is None + ) + req_blocks = manager.coordinator.single_type_managers[0].req_to_blocks + assert "reuse" not in req_blocks + finally: + manager.block_pool.free_blocks(reversed(pressure_blocks)) + + +def test_reset_prefix_cache_after_deepseek_v4_mla_prompt_cache(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + scheduler_block_size=block_size, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + + req = make_request("protected", list(range(prompt_tokens)), block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert manager.reset_prefix_cache() + + def test_can_fit_full_sequence_full_attention_still_gates_oversized(): """The cap only loosens the SWA group; a prompt that exceeds the full-attention pool capacity must still be rejected.""" diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 
dcfbfd5b1b30..b92dc8ef4a32 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -903,6 +903,80 @@ def test_schedule_order(enable_chunked_prefill: bool): assert len(scheduler_output1.scheduled_new_reqs) == 1 +def _run_request_to_completion(scheduler: Scheduler, request: Request) -> None: + while request.request_id in scheduler.requests: + scheduler_output = scheduler.schedule() + assert scheduler_output.num_scheduled_tokens + req_ids = list(scheduler_output.num_scheduled_tokens) + sampled_token_ids = [] + for req_id in req_ids: + if ( + req_id == request.request_id + and request.num_computed_tokens >= request.num_prompt_tokens + ): + sampled_token_ids.append([0]) + else: + sampled_token_ids.append([]) + scheduler.update_from_output( + scheduler_output, + ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)}, + sampled_token_ids=sampled_token_ids, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + +def test_prefix_cache_peek_does_not_record_stats(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=2048, + max_num_seqs=2, + enable_chunked_prefill=True, + enable_prefix_caching=True, + ) + warm_req = create_requests( + num_requests=1, + num_tokens=600, + max_tokens=1, + same_prompt=True, + req_ids=["warm_prefix"], + )[0] + replay_req = create_requests( + num_requests=1, + num_tokens=620, + same_prompt=True, + req_ids=["replay_prefix"], + )[0] + + scheduler.add_request(warm_req) + _run_request_to_completion(scheduler, warm_req) + stats = scheduler.kv_cache_manager.make_prefix_cache_stats() + assert stats is not None + + _, peeked_tokens = scheduler.kv_cache_manager.get_computed_blocks( + replay_req, + record_stats=False, + ) + assert peeked_tokens > 0 + stats = scheduler.kv_cache_manager.make_prefix_cache_stats() + assert stats is not None + assert stats.requests == 0 + assert stats.queries == 0 + assert stats.hits == 0 + + _, 
recorded_tokens = scheduler.kv_cache_manager.get_computed_blocks(replay_req) + assert recorded_tokens == peeked_tokens + stats = scheduler.kv_cache_manager.make_prefix_cache_stats() + assert stats is not None + assert stats.requests == 1 + assert stats.queries == replay_req.num_tokens + assert stats.hits == peeked_tokens + + def test_preempt_during_execution(): # NOTE(woosuk): The actual number of available blocks is 10 instead of 11 # because block 0 is reserved as the null block. diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index c10835821f58..94dcfe1f5d59 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -267,6 +267,39 @@ def test_get_capture_descs_empty_when_not_initialized(self): assert dispatcher.get_capture_descs() == [] + def test_deepseek_v4_mtp_spec_decode_keeps_full_and_piecewise_graphs(self): + comp_config = CompilationConfig( + cudagraph_mode="FULL_AND_PIECEWISE", + mode=CompilationMode.VLLM_COMPILE, + cudagraph_capture_sizes=[3, 6, 12, 18], + ) + config = _create_vllm_config(comp_config, max_num_seqs=8) + config.speculative_config = MagicMock() + config.speculative_config.method = "mtp" + config.speculative_config.num_speculative_tokens = 2 + config.model_config = MagicMock() + config.model_config.hf_config = MagicMock() + config.model_config.hf_config.model_type = "deepseek_v4" + config.model_config.hf_config.architectures = ["DeepseekV4ForCausalLM"] + + dispatcher = CudagraphDispatcher(config) + dispatcher.initialize_cudagraph_keys( + cudagraph_mode=comp_config.cudagraph_mode, + uniform_decode_query_len=3, + ) + + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) > 0 + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) > 0 + + rt_mode, key = dispatcher.dispatch( + num_tokens=12, + uniform_decode=True, + has_lora=False, + ) + + assert rt_mode == CUDAGraphMode.FULL + assert key == 
BatchDescriptor(num_tokens=12, num_reqs=4, uniform=True) + @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") class TestCUDAGraphWrapper: diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index e334371f6d8a..fa5416ae3edb 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -25,13 +25,49 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.sample.logits_processor import LogitsProcessors +from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.llm_base_proposer import compute_probs_and_sample_next_token mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base" DEVICE_TYPE = current_platform.device_type -def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: +def _create_sampling_metadata( + all_greedy: bool, + batch_size: int, + top_k: torch.Tensor | None = None, + top_p: torch.Tensor | None = None, +) -> SamplingMetadata: + temperature = None + if not all_greedy: + temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE) + return SamplingMetadata( + temperature=temperature, + all_greedy=all_greedy, + all_random=not all_greedy, + top_p=top_p, + top_k=top_k, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=torch.tensor([], device=DEVICE_TYPE), + presence_penalties=torch.tensor([], device=DEVICE_TYPE), + repetition_penalties=torch.tensor([], device=DEVICE_TYPE), + output_token_ids=[], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessors(), + spec_token_ids=[], + ) + + +def _create_mtp_proposer( + num_speculative_tokens: int, + parallel_drafting: bool = False, +) -> EagleProposer: """Create an MTP proposer with unified model configuration.""" 
model_config = ModelConfig( model=mimo_7b_dir, runner="generate", max_model_len=100, trust_remote_code=True @@ -43,7 +79,10 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: model=mimo_7b_dir, method="mtp", num_speculative_tokens=num_speculative_tokens, + parallel_drafting=parallel_drafting, ) + if parallel_drafting: + speculative_config.draft_model_config.hf_config.ptd_token_id = 0 vllm_config = VllmConfig( model_config=model_config, @@ -219,3 +258,240 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): assert model_mock.called # Verify output shape assert result.shape == (batch_size, num_speculative_tokens) + + +def test_mtp_propose_random_sampling_records_draft_probs(): + device = torch.device(DEVICE_TYPE) + batch_size = 2 + seq_lens = [3, 2] + total_tokens = sum(seq_lens) + vocab_size = 4 + + proposer = _create_mtp_proposer(num_speculative_tokens=1) + assert proposer._enable_probabilistic_draft_probs + hidden_size = proposer.hidden_size + + model_mock = mock.MagicMock() + model_mock.return_value = torch.zeros(total_tokens, hidden_size, device=device) + logits = torch.tensor([[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0]], device=device) + model_mock.compute_logits.return_value = logits.clone() + proposer.model = model_mock + proposer._draft_attn_layer_names = {"layer.0"} + + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=list(proposer._draft_attn_layer_names), + vllm_config=proposer.vllm_config, + device=device, + ) + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder + mock_attn_group.layer_names = 
list(proposer._draft_attn_layer_names) + mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] + + result = proposer.propose( + target_token_ids=torch.randint(0, vocab_size, (total_tokens,), device=device), + target_positions=torch.arange(total_tokens, device=device), + target_hidden_states=torch.randn(total_tokens, hidden_size, device=device), + next_token_ids=torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ), + token_indices_to_sample=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=_create_sampling_metadata( + all_greedy=False, batch_size=batch_size + ), + ) + + assert result.shape == (batch_size, 1) + assert proposer._last_draft_probs is not None + assert proposer._last_draft_probs.shape == (batch_size, 1, vocab_size) + expected_probs = torch.softmax(logits, dim=-1).view(batch_size, 1, vocab_size) + assert torch.allclose(proposer._last_draft_probs, expected_probs) + # ``take_last_draft_probs`` is the upstream-side accessor that + # ``GPUModelRunner`` uses to plumb probs into the rejection sampler. 
+ assert torch.equal(proposer.take_last_draft_probs(), proposer._last_draft_probs) + + +def test_mtp_sequential_drafting_passes_spec_step_indices(): + device = torch.device(DEVICE_TYPE) + batch_size = 2 + seq_lens = [3, 2] + total_tokens = sum(seq_lens) + vocab_size = 4 + num_spec_tokens = 2 + + proposer = _create_mtp_proposer(num_speculative_tokens=num_spec_tokens) + proposer.block_size = 16 + hidden_size = proposer.hidden_size + + model_mock = mock.MagicMock() + model_mock.side_effect = [ + torch.zeros(total_tokens, hidden_size, device=device), + torch.zeros(batch_size, hidden_size, device=device), + ] + + def logits_for_token(token_id: int): + logits = torch.full((batch_size, vocab_size), -100.0, device=device) + logits[:, token_id] = 100.0 + return logits + + model_mock.compute_logits.side_effect = [ + logits_for_token(1), + logits_for_token(2), + ] + proposer.model = model_mock + proposer._draft_attn_layer_names = {"layer.0"} + + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=list(proposer._draft_attn_layer_names), + vllm_config=proposer.vllm_config, + device=device, + ) + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder + mock_attn_group.layer_names = list(proposer._draft_attn_layer_names) + mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] + + result = proposer.propose( + target_token_ids=torch.randint(0, vocab_size, (total_tokens,), device=device), + target_positions=torch.arange(total_tokens, device=device), + target_hidden_states=torch.randn(total_tokens, hidden_size, device=device), 
+ next_token_ids=torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ), + token_indices_to_sample=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=_create_sampling_metadata( + all_greedy=True, batch_size=batch_size + ), + ) + + assert torch.equal( + result, + torch.tensor([[1, 2], [1, 2]], device=device), + ) + assert [ + call.kwargs.get("spec_step_idx", 0) + for call in model_mock.compute_logits.call_args_list + ] == [0, 1] + assert [ + call.kwargs.get("spec_step_idx", 0) for call in model_mock.call_args_list + ] == [0, 1] + + +def test_mtp_draft_sampling_applies_top_k_to_draft_probs(): + logits = torch.tensor([[0.0, 1.0, 2.0, 3.0]], device=DEVICE_TYPE) + top_k = torch.tensor([2], dtype=torch.int32, device=DEVICE_TYPE) + + _token_ids, draft_probs = compute_probs_and_sample_next_token( + logits, + _create_sampling_metadata(all_greedy=False, batch_size=1, top_k=top_k), + ) + + expected_logits = torch.tensor( + [[-float("inf"), -float("inf"), 2.0, 3.0]], device=DEVICE_TYPE + ) + expected_probs = torch.softmax(expected_logits, dim=-1, dtype=torch.float32) + assert torch.allclose(draft_probs, expected_probs) + + +def test_mtp_parallel_drafting_random_sampling_records_draft_probs(): + device = torch.device(DEVICE_TYPE) + batch_size = 2 + num_spec_tokens = 2 + seq_lens = [2, 2] + total_tokens = sum(seq_lens) + vocab_size = 4 + + proposer = _create_mtp_proposer( + num_speculative_tokens=num_spec_tokens, + parallel_drafting=True, + ) + assert proposer._enable_probabilistic_draft_probs + proposer.block_size = 16 + hidden_size = proposer.hidden_size + + model_mock = mock.MagicMock() + model_mock.return_value = torch.zeros( + total_tokens + batch_size, + hidden_size, + dtype=proposer.dtype, + device=device, + ) + logits = torch.tensor( + [ + [0.0, 1.0, 2.0, 3.0], + [3.0, 2.0, 1.0, 0.0], + [0.0, 0.5, 1.0, 1.5], + [1.5, 1.0, 0.5, 0.0], + ], + device=device, + ) + model_mock.compute_logits.return_value = logits.clone() + 
proposer.model = model_mock + proposer._draft_attn_layer_names = {"layer.0"} + + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=list(proposer._draft_attn_layer_names), + vllm_config=proposer.vllm_config, + device=device, + ) + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder + mock_attn_group.layer_names = list(proposer._draft_attn_layer_names) + mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] + + result = proposer.propose( + target_token_ids=torch.randint(0, vocab_size, (total_tokens,), device=device), + target_positions=torch.arange(total_tokens, device=device), + target_hidden_states=torch.randn( + total_tokens, + hidden_size, + dtype=proposer.dtype, + device=device, + ), + next_token_ids=torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ), + token_indices_to_sample=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=_create_sampling_metadata( + all_greedy=False, batch_size=batch_size + ), + ) + + assert result.shape == (batch_size, num_spec_tokens) + assert proposer._last_draft_probs is not None + assert proposer._last_draft_probs.shape == (batch_size, num_spec_tokens, vocab_size) + assert torch.allclose( + proposer._last_draft_probs, + torch.softmax(logits, dim=-1).view(batch_size, num_spec_tokens, vocab_size), + ) + assert torch.equal(proposer.take_last_draft_probs(), proposer._last_draft_probs) diff --git a/tests/v1/worker/test_ubatch_utils.py b/tests/v1/worker/test_ubatch_utils.py new file mode 100644 index 000000000000..f0bc987a7c73 
--- /dev/null +++ b/tests/v1/worker/test_ubatch_utils.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.worker.ubatch_utils import UBatchSlice, split_attn_metadata + + +def test_split_attn_metadata_preserves_position_and_prefill_flags(): + query_start_loc = torch.tensor([0, 3, 7, 9], dtype=torch.int32) + metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc.clone(), + seq_lens=torch.tensor([16, 24, 32], dtype=torch.int32), + num_reqs=3, + num_actual_tokens=9, + max_query_len=4, + max_seq_len=32, + block_table_tensor=torch.arange(12, dtype=torch.int32).reshape(3, 4), + slot_mapping=torch.arange(9, dtype=torch.int64), + positions=torch.arange(100, 109, dtype=torch.int64), + is_prefilling=torch.tensor([False, True, True]), + seq_lens_cpu_upper_bound=torch.tensor([16, 24, 32], dtype=torch.int32), + ) + + split = split_attn_metadata( + [UBatchSlice(request_slice=slice(1, 3), token_slice=slice(3, 9))], + metadata, + )[0] + + torch.testing.assert_close(split.positions, metadata.positions[3:9]) + torch.testing.assert_close(split.is_prefilling, metadata.is_prefilling[1:3]) diff --git a/vllm/compilation/breakable_cudagraph.py b/vllm/compilation/breakable_cudagraph.py index 6da3ec717861..787cc16b4d23 100644 --- a/vllm/compilation/breakable_cudagraph.py +++ b/vllm/compilation/breakable_cudagraph.py @@ -53,6 +53,13 @@ def is_breakable_cudagraph_enabled() -> bool: return bool(envs.VLLM_USE_BREAKABLE_CUDAGRAPH) +# Set True (once, at wrapper init) when the engine runs speculative decoding. +# The DeepSeek-V4 sparse-MLA decode is per-token and not FULL-cudagraph-safe for +# spec-decode batches; when this is set we eager-break the DSv4 attention out of +# the FULL graph (see eager_break_during_capture). 
+_BREAK_DSV4_ATTN_UNDER_FULL_FOR_SPEC = False + + F = TypeVar("F", bound=Callable[..., Any]) @@ -99,7 +106,16 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return fn(*args, **kwargs) if is_forward_context_available(): mode = get_forward_context().cudagraph_runtime_mode - if mode == CUDAGraphMode.FULL: + # Under spec-decode, the per-token DeepSeek-V4 sparse-MLA decode + # cross-contaminates requests when captured into a FULL monolithic + # cudagraph (token_to_req_indices -> per-request block_table/topk + # gather). Eager-break the DSv4 attention (the nested indexer then + # runs eagerly too) for the q=1 draft and the q>1 verify forwards; + # keep FULL capture for non-spec decode and every non-DSv4 op. + if mode == CUDAGraphMode.FULL and not ( + _BREAK_DSV4_ATTN_UNDER_FULL_FOR_SPEC + and "deepseek_v4.attention" in (getattr(fn, "__module__", "") or "") + ): return fn(*args, **kwargs) # Weak-ref args: strong refs in the replay lambda pin cudagraph-pool @@ -278,6 +294,10 @@ def __init__( # BatchDescriptor which already encodes batch shape / uniformity. 
self.runnable = runnable self.vllm_config = vllm_config + _sc = getattr(vllm_config, "speculative_config", None) + if _sc is not None and (getattr(_sc, "num_speculative_tokens", 0) or 0) > 0: + global _BREAK_DSV4_ATTN_UNDER_FULL_FOR_SPEC + _BREAK_DSV4_ATTN_UNDER_FULL_FOR_SPEC = True self.compilation_config = vllm_config.compilation_config self.graph_pool = current_platform.get_global_graph_pool() self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py index c0643a916b38..75994ba58039 100644 --- a/vllm/compilation/passes/utility/fix_functionalization.py +++ b/vllm/compilation/passes/utility/fix_functionalization.py @@ -39,11 +39,20 @@ def __call__(self, graph: torch.fx.Graph) -> None: count = 0 rope_targets = [torch.ops._C.rotary_embedding.default] + fused_deepseek_v4_mla_targets = [] if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"): rope_targets.append( torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default ) + if hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert"): + fused_deepseek_v4_mla_targets.append( + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default + ) + if hasattr(torch.ops.vllm, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert"): + fused_deepseek_v4_mla_targets.append( + torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default + ) for node in graph.nodes: if not is_func(node, auto_functionalized): @@ -181,6 +190,9 @@ def __call__(self, graph: torch.fx.Graph) -> None: 2: "key", } self.defunctionalize(graph, node, mutated_args=mutated_args) + elif at_target in fused_deepseek_v4_mla_targets: + mutated_args = {1: "q", 2: "k_cache"} + self.defunctionalize(graph, node, mutated_args) elif ( hasattr(torch.ops.vllm, "fused_rope_unified_mla_kv_cache_update") and at_target diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 
bc38ec6a8a8a..fe6416de3c0b 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -757,6 +757,7 @@ class CompilationConfig: "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", "vllm::deepseek_v4_attention", + "vllm::deepseek_v4_fp8_einsum", ] def compute_hash(self) -> str: @@ -1489,6 +1490,16 @@ def adjust_cudagraph_sizes_for_spec_decode( if round_up(size, multiple_of) <= self.max_cudagraph_capture_size ) ) + # Spec decode uniform decode graphs are shaped by request count + # (`num_reqs * (1 + num_speculative_tokens)`). Keep the small + # interactive request counts exact so FULL decode graphs do not replay + # against padded virtual requests. + small_decode_sizes = { + multiple_of * num_reqs + for num_reqs in range(1, 33) + if multiple_of * num_reqs <= self.max_cudagraph_capture_size + } + rounded_sizes = sorted(set(rounded_sizes) | small_decode_sizes) if len(rounded_sizes) == 0 and multiple_of <= self.max_cudagraph_capture_size: # if one valid but would be round_down use that diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ba7d26c93b23..c7ddeb5f787f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -76,6 +76,13 @@ } ) +_BREAKABLE_CUDAGRAPH_AUTO_ENABLE_ARCHITECTURES = frozenset( + { + "MiniMaxM3SparseForCausalLM", + "MiniMaxM3SparseForConditionalGeneration", + } +) + class OptimizationLevel(IntEnum): """Optimization level enum.""" @@ -104,6 +111,24 @@ class OptimizationLevel(IntEnum): # See https://github.com/vllm-project/vllm/issues/25689. +def _should_auto_enable_breakable_cudagraph( + model_config: ModelConfig, +) -> bool: + # Auto-enable breakable cudagraph only for architectures that lack + # @support_torch_compile and are known-good under it (MiniMax M3 retains its + # upstream unconditional auto-enable). 
DeepSeek-V4 is deliberately excluded: + # breakable mode disables the torch.compile pipeline (equivalent to + # -cc.mode=none) and runs attention eagerly every decode step, which on SM12x + # is 1.5-3.8x SLOWER for MTP decode and degrades with output length (measured + # on RTX PRO 6000 / SM120 and 2x GB10 / SM121). FULL_AND_PIECEWISE + + # torch.compile is correct (GSM8K parity, bare-prompt clean) and faster, so + # it is the DSv4 default. Opt in with VLLM_USE_BREAKABLE_CUDAGRAPH=1. + return any( + arch in _BREAKABLE_CUDAGRAPH_AUTO_ENABLE_ARCHITECTURES + for arch in model_config.architectures + ) + + def enable_norm_fusion(cfg: "VllmConfig") -> bool: """Enable if either RMS norm or quant FP8 custom op is active; otherwise Inductor handles fusion.""" @@ -1073,22 +1098,15 @@ def __post_init__(self): ) self.compilation_config.mode = CompilationMode.NONE - # For model classes don't carry @support_torch_compile — - # the breakable cudagraph is the supported PIECEWISE path. Auto-enable - # it unless the user has explicitly opted out via the env var. + # Some model classes don't carry @support_torch_compile and rely on the + # breakable cudagraph PIECEWISE path; auto-enable it for those unless the + # user explicitly opted out. DeepSeek-V4 is deliberately excluded (see + # _should_auto_enable_breakable_cudagraph) — it is faster on + # FULL_AND_PIECEWISE + torch.compile. 
if ( self.model_config is not None and "VLLM_USE_BREAKABLE_CUDAGRAPH" not in os.environ - and any( - a - in ( - "DeepseekV4ForCausalLM", - "DeepSeekV4MTPModel", - "MiniMaxM3SparseForCausalLM", - "MiniMaxM3SparseForConditionalGeneration", - ) - for a in self.model_config.architectures - ) + and _should_auto_enable_breakable_cudagraph(self.model_config) ): os.environ["VLLM_USE_BREAKABLE_CUDAGRAPH"] = "1" logger.info_once( diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f0b56da3432d..625a65d733a4 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -382,6 +382,12 @@ class ConversationMessage(TypedDict, total=False): reasoning_content: str | None """Deprecated: The reasoning content for interleaved thinking.""" + prefix: bool + """Whether this assistant message is a prefix for continuation.""" + + wo_eos: bool + """Whether this message should be rendered without an EOS marker.""" + tools: list[ChatCompletionFunctionToolParam] | None """The tools for developer role.""" @@ -1750,6 +1756,8 @@ def _parse_chat_message_content( role = message["role"] content = message.get("content") reasoning = message.get("reasoning") + if reasoning is None: + reasoning = message.get("reasoning_content") if content is None: content = [] @@ -1779,6 +1787,9 @@ def _parse_chat_message_content( result_msg["reasoning_content"] = cast( str, reasoning ) # keep compatibility + if parsed_msg.get("prefix"): + result_msg["prefix"] = True + result_msg["wo_eos"] = True elif role == "tool": parsed_msg = _ToolParser(message) if "tool_call_id" in parsed_msg: diff --git a/vllm/entrypoints/openai/chat_completion/batch_serving.py b/vllm/entrypoints/openai/chat_completion/batch_serving.py index 96ed7dcb777b..5ee929976a19 100644 --- a/vllm/entrypoints/openai/chat_completion/batch_serving.py +++ b/vllm/entrypoints/openai/chat_completion/batch_serving.py @@ -160,8 +160,12 @@ async def create_batch_chat_completion( self.override_max_tokens, ) 
single_request = single_requests[i] + chat_template_kwargs = self._effective_chat_template_kwargs(single_request) sampling_params = single_request.to_sampling_params( - max_tokens, self.default_sampling_params + max_tokens, + self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, ) self._log_inputs( sub_request_id, diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 3457aa12f4ac..a4f0b54b74c5 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -65,6 +65,15 @@ class ChatMessage(OpenAIBaseModel): # vLLM-specific fields that are not in OpenAI spec reasoning: str | None = None + reasoning_content: str | None = None + + @model_validator(mode="after") + def _populate_reasoning_content_alias(self) -> "ChatMessage": + if self.reasoning_content is None and self.reasoning is not None: + self.reasoning_content = self.reasoning + elif self.reasoning is None and self.reasoning_content is not None: + self.reasoning = self.reasoning_content + return self class ChatCompletionLogProb(OpenAIBaseModel): @@ -183,6 +192,10 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): type: Literal["function"] = "function" +class DeepSeekThinkingParam(OpenAIBaseModel): + type: Literal["enabled", "disabled"] = "enabled" + + class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create @@ -228,6 +241,15 @@ class ChatCompletionRequest(OpenAIBaseModel): "part of the standard OpenAI API specification." ), ) + thinking: DeepSeekThinkingParam | None = None + deepseek_v4_sampling_override: bool = Field( + default=True, + description=( + "Apply DeepSeek V4 official sampling defaults when thinking is " + "enabled. This only affects the DeepSeek V4 family and can be " + "disabled per request." 
+ ), + ) thinking_token_budget: ThinkingTokenBudget = None include_reasoning: bool = True parallel_tool_calls: bool | None = True @@ -500,6 +522,63 @@ def build_chat_params( media_io_kwargs=self.media_io_kwargs, ) + def _is_deepseek_v4_model(self, model_config: ModelConfig | None = None) -> bool: + hf_config = getattr(model_config, "hf_config", None) + if getattr(hf_config, "model_type", None) == "deepseek_v4": + return True + + architectures = getattr(hf_config, "architectures", None) or () + if any("deepseekv4" in str(arch).replace("_", "").lower() + for arch in architectures): + return True + + model = (self.model or "").lower().replace("_", "-") + return "deepseek-v4" in model + + def apply_chat_template_kwargs( + self, + chat_template_kwargs: dict[str, Any], + *, + model_config: ModelConfig | None = None, + ) -> dict[str, Any]: + """Apply request-level DeepSeek API compatibility knobs. + + DeepSeek's OpenAI-compatible API exposes ``thinking`` as a top-level + request field, while vLLM's DeepSeek tokenizer consumes it as a chat + template kwarg. Keep the translation at the protocol boundary so the + tokenizer and reasoning parser see the same effective state. 
+ """ + chat_template_kwargs = dict(chat_template_kwargs) + if not self._is_deepseek_v4_model(model_config): + return chat_template_kwargs + + if self.thinking is not None: + enabled = self.thinking.type == "enabled" + chat_template_kwargs["thinking"] = enabled + chat_template_kwargs["enable_thinking"] = enabled + elif ( + "thinking" not in chat_template_kwargs + and "enable_thinking" not in chat_template_kwargs + ): + chat_template_kwargs["thinking"] = True + chat_template_kwargs["enable_thinking"] = True + + return chat_template_kwargs + + def _use_deepseek_v4_sampling_override(self) -> bool: + return self.deepseek_v4_sampling_override + + @staticmethod + def _is_thinking_enabled( + chat_template_kwargs: dict[str, Any] | None, + ) -> bool: + if chat_template_kwargs is None: + return False + return bool( + chat_template_kwargs.get("thinking") + or chat_template_kwargs.get("enable_thinking") + ) + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: if self.max_completion_tokens is not None: max_output_tokens: int | None = self.max_completion_tokens @@ -550,6 +629,9 @@ def to_sampling_params( self, max_tokens: int, default_sampling_params: dict, + *, + chat_template_kwargs: dict[str, Any] | None = None, + model_config: ModelConfig | None = None, ) -> SamplingParams: # Default parameters if (repetition_penalty := self.repetition_penalty) is None: @@ -574,6 +656,21 @@ def to_sampling_params( "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] ) + if ( + self._is_deepseek_v4_model(model_config) + and self._use_deepseek_v4_sampling_override() + and self._is_thinking_enabled(chat_template_kwargs) + ): + temperature = self._DEFAULT_SAMPLING_PARAMS["temperature"] + top_p = self._DEFAULT_SAMPLING_PARAMS["top_p"] + top_k = self._DEFAULT_SAMPLING_PARAMS["top_k"] + min_p = self._DEFAULT_SAMPLING_PARAMS["min_p"] + presence_penalty = 0.0 + frequency_penalty = 0.0 + else: + presence_penalty = self.presence_penalty or 0.0 + frequency_penalty = 
self.frequency_penalty or 0.0 + prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs @@ -616,8 +713,8 @@ def to_sampling_params( extra_args["kv_transfer_params"] = self.kv_transfer_params return SamplingParams.from_optional( n=self.n, - presence_penalty=self.presence_penalty, - frequency_penalty=self.frequency_penalty, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, repetition_penalty=repetition_penalty, temperature=temperature, top_p=top_p, diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 911421029c3b..ecb798f3e5c1 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -196,7 +196,7 @@ def warmup(self) -> None: def _effective_chat_template_kwargs( self, request: ChatCompletionRequest ) -> dict[str, Any]: - return ( + chat_template_kwargs = ( request.build_chat_params( self.chat_template, self.chat_template_content_format, @@ -204,6 +204,10 @@ def _effective_chat_template_kwargs( .with_defaults(self.default_chat_template_kwargs) .chat_template_kwargs ) + return request.apply_chat_template_kwargs( + chat_template_kwargs, + model_config=self.model_config, + ) async def render_chat_request( self, @@ -319,6 +323,8 @@ async def _create_chat_completion( sampling_params = request.to_sampling_params( max_tokens, self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, ) self._log_inputs( diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py index d86c77561dbd..6817f01b6b80 100644 --- a/vllm/entrypoints/openai/engine/protocol.py +++ b/vllm/entrypoints/openai/engine/protocol.py @@ -351,8 +351,17 @@ class DeltaMessage(OpenAIBaseModel): role: str | None = None content: str | None = None reasoning: str | None = None + reasoning_content: str | None = 
None tool_calls: list[DeltaToolCall] = Field(default_factory=list) + @model_validator(mode="after") + def _populate_reasoning_content_alias(self) -> "DeltaMessage": + if self.reasoning_content is None and self.reasoning is not None: + self.reasoning_content = self.reasoning + elif self.reasoning is None and self.reasoning_content is not None: + self.reasoning = self.reasoning_content + return self + class GenerationError(Exception): """raised when finish_reason indicates internal server error (500)""" diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 612ff6d35e03..77ebb2353cab 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time from collections.abc import Sequence +from dataclasses import replace from http import HTTPStatus from typing import Any, cast @@ -303,7 +304,21 @@ async def render_chat_request( self.override_max_tokens, truncate_prompt_tokens=request.truncate_prompt_tokens, ) - params = request.to_sampling_params(max_tokens, self.default_sampling_params) + chat_template_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params( + self.chat_template, + self.chat_template_content_format, + ) + .with_defaults(self.default_chat_template_kwargs) + .chat_template_kwargs, + model_config=self.model_config, + ) + params = request.to_sampling_params( + max_tokens, + self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, + ) request_id = f"chatcmpl-{random_uuid()}" @@ -913,6 +928,17 @@ async def preprocess_chat( default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), ) + apply_chat_template_kwargs = getattr( + request, "apply_chat_template_kwargs", None + ) + if apply_chat_template_kwargs is not None: + 
chat_params = replace( + chat_params, + chat_template_kwargs=apply_chat_template_kwargs( + chat_params.chat_template_kwargs, + model_config=self.model_config, + ), + ) (conversation,), (engine_input,) = await renderer.render_chat_async( [messages], diff --git a/vllm/envs.py b/vllm/envs.py index 190b15667dd1..7f55c414c2c6 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -179,6 +179,18 @@ VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True + VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True + VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = True + VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS: int = 4096 + VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP: bool = True + VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL: bool = True + VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE: bool = False + VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL: bool = False + VLLM_TRITON_MLA_SPARSE: bool | None = None + VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 + VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 + VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE: int | None = None + VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE: bool | None = None VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", "full", @@ -1445,6 +1457,57 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) ), + "VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP": lambda: bool( + int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP", "1")) + ), + "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL", "1")) + ), + # Minimum prefill sequence length to admit a row to the indexed-D512 fast + # prefill path. Lower = more (shorter / early-chunk) prefills use the fast + # kernel. 
4096 measured +9-59% short/medium-prefill tok/s (biggest at the + # 4-8k band), GSM8K-clean and KV-cache-neutral on SM12x; set 8192 to revert + # to the prior conservative threshold. + "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS": lambda: int( + os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS", "4096") + ), + # Pre-compile the D512-split sparse-MLA prefill Triton kernels at startup + # (one per 128-aligned combined_topk in [256, 1152]) so the first long + # prefill does not pay a first-use JIT compile that blocks the engine step. + "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP", "1")) + ), + "VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL", "1")) + ), + "VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE", "0")) + ), + "VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL", "0")) + ), + # Experimental sparse MLA fallback controls. + # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse + # is unavailable; set 0/1 to force-disable/force-enable the fallback. 
+ "VLLM_TRITON_MLA_SPARSE": lambda: ( + None + if os.getenv("VLLM_TRITON_MLA_SPARSE") is None + else bool(int(os.getenv("VLLM_TRITON_MLA_SPARSE", "0"))) + ), + "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", "512") + ), + "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", "256") + ), + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE": lambda: maybe_convert_int( + os.getenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE") + ), + "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE": lambda: ( + None + if os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE") is None + else bool(int(os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE", "0"))) + ), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 9f69ab0c7377..0b679792267c 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -9,6 +9,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -30,6 +33,20 @@ ) +def _is_sm12x_compute_capability(compute_capability) -> bool: + if compute_capability is None: + return current_platform.is_device_capability_family(120) + + if isinstance(compute_capability, tuple): + return compute_capability[0] == 12 + + to_int = getattr(compute_capability, "to_int", None) + if callable(to_int): + return to_int() // 
10 == 12 + + return int(compute_capability) // 10 == 12 + + class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def is_supported( @@ -286,6 +303,9 @@ def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None: @classmethod def is_supported(cls, compute_capability=None): + if _is_sm12x_compute_capability(compute_capability): + return False, "CUTLASS block-scaled FP8 GEMM is not supported on SM12x." + if not CUTLASS_BLOCK_FP8_SUPPORTED: return ( False, @@ -309,6 +329,31 @@ def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): ) return True, None + def process_weights_after_loading(self, layer: torch.nn.Module): + super().process_weights_after_loading(layer) + params = self._get_layer_params(layer) + weight_scale = ( + params.weight_scale + if params.weight_scale_inv is None + else params.weight_scale_inv + ) + scale_attr_name = ( + params.WEIGHT_SCALE + if params.weight_scale_inv is None + else params.WEIGHT_SCALE_INV + ) + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and weight_scale is not None + and weight_scale.dtype == e8m0_dtype + ): + replace_parameter( + layer, + scale_attr_name, + _upcast_e8m0_to_fp32(weight_scale), + ) + def apply_block_scaled_mm( self, A: torch.Tensor, diff --git a/vllm/model_executor/kernels/linear/scaled_mm/marlin.py b/vllm/model_executor/kernels/linear/scaled_mm/marlin.py index 66a03b4d205b..59427e0d3897 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/marlin.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/marlin.py @@ -57,6 +57,22 @@ def is_supported( @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + # Marlin's per-channel scale layout cannot serve block-FP8 layers + # (weight_scale_inv shape [N_blocks_n, N_blocks_k]). + # csrc/quantization/marlin/marlin.cu fires + # ``TORCH_CHECK(b_scales.size(1) == size_n, ...)`` for any layer + # where the block-quant scales arrive in their raw 2-D layout. 
+ # Even when VLLM_TEST_FORCE_FP8_MARLIN=1 is set (which downstream + # operators do for NVFP4-MoE on SM 12.0 / RTX PRO 6000 Blackwell), + # block-FP8 attention compressor layers like DSv4 fused_wqa_wkv + # must fall through to the next supported kernel in + # ``_POSSIBLE_FP8_BLOCK_KERNELS[CUDA]`` (typically the Triton + # block-FP8 path). + if c.activation_quant_key.scale.group_shape.is_per_group(): + return False, ( + "MarlinFP8 cannot serve block-FP8 layers; falling " + "through to the next kernel in the priority list." + ) return True, None def __init__( diff --git a/vllm/model_executor/kernels/mhc/tilelang.py b/vllm/model_executor/kernels/mhc/tilelang.py index e0007141d53d..0876dc88602d 100644 --- a/vllm/model_executor/kernels/mhc/tilelang.py +++ b/vllm/model_executor/kernels/mhc/tilelang.py @@ -2,9 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op +def _use_tf32_hc_prenorm_gemm() -> bool: + from vllm.utils.deep_gemm import is_deep_gemm_supported + + return current_platform.is_device_capability_family(120) or is_deep_gemm_supported() + + def _torch_hc_prenorm_gemm( x: torch.Tensor, fn: torch.Tensor, @@ -162,10 +169,8 @@ def mhc_pre_tilelang( residual_flat = residual.view(-1, hc_mult, hidden_size) num_tokens = residual_flat.shape[0] - from vllm.utils.deep_gemm import is_deep_gemm_supported - - use_deep_gemm = is_deep_gemm_supported() - if use_deep_gemm: + use_tf32_hc_prenorm_gemm = _use_tf32_hc_prenorm_gemm() + if use_tf32_hc_prenorm_gemm: # these numbers are from deepgemm kernel impl block_k = 64 block_m = 64 @@ -191,7 +196,7 @@ def mhc_pre_tilelang( ) residual_2d = residual_flat.view(num_tokens, hc_mult * hidden_size) - if use_deep_gemm: + if use_tf32_hc_prenorm_gemm: tf32_hc_prenorm_gemm( residual_2d, fn, @@ -405,16 +410,14 @@ def mhc_fused_post_pre_tilelang( post_layer_mix_flat = post_layer_mix.view(num_tokens, 
hc_mult) comb_res_mix_flat = comb_res_mix.view(num_tokens, hc_mult, hc_mult) - from vllm.utils.deep_gemm import is_deep_gemm_supported - - use_deep_gemm = is_deep_gemm_supported() + use_tf32_hc_prenorm_gemm = _use_tf32_hc_prenorm_gemm() use_small_fma = num_tokens <= 16 if use_small_fma: # TODO(gnovack): investigate autotuning these heuristics tile_n = 2 if num_tokens < 8 else 3 n_splits = 8 if (num_tokens < 8 and hidden_size <= 4096) else 4 else: - if use_deep_gemm: + if use_tf32_hc_prenorm_gemm: # these number are from deepgemm kernel impl block_k = 64 block_m = 64 @@ -485,7 +488,7 @@ def mhc_fused_post_pre_tilelang( ) residual_cur_2d = residual_cur.view(num_tokens, hc_mult * hidden_size) - if use_deep_gemm: + if use_tf32_hc_prenorm_gemm: from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm tf32_hc_prenorm_gemm( diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..06431d4d355e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..75dfc52cb46d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json 
b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..75dfc52cb46d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 
32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json similarity index 100% rename from vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json rename to vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..8b78f87e7f73 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 
4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..4b98ba105c14 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json rename to vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..bcec61632e3e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + 
}, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json rename to vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..705ca33d594b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..952b908297b7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json similarity index 100% rename from vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json rename to vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..9c2ebaddd83f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..36d0361b926f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + 
"num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py index 76cd15ff5a0c..cd0b7ac4a54f 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py @@ -103,20 +103,24 @@ def __init__( ) if quant_config.weight_quant_dtype == "mxfp4": - # This value is used specifically for gpt-oss, - # Need to revisit this for other models - self.gemm1_alpha = torch.tensor( - [1.702] * self.num_experts, dtype=torch.float32, device=self.device - ) - self.gemm1_beta = torch.tensor( - [1.0] * self.num_experts, dtype=torch.float32, device=self.device + self.gemm1_alpha = ( + torch.tensor( + [quant_config.gemm1_alpha] * self.num_experts, + dtype=torch.float32, + device=self.device, + ) + if quant_config.gemm1_alpha is not None + else None ) - if self.gemm1_clamp_limit is None: - self.gemm1_clamp_limit = torch.tensor( - [7.0] * self.num_experts, + self.gemm1_beta = ( + torch.tensor( + [quant_config.gemm1_beta] * self.num_experts, dtype=torch.float32, device=self.device, ) + if quant_config.gemm1_beta is not None + else None + ) if quant_config.quant_dtype == "mxfp8": self.fake_input_scale = torch.ones( self.num_experts, @@ -325,9 +329,6 @@ def apply( elif self.weight_quant_dtype == "mxfp4": assert self.w1_scale is not None and self.w2_scale is not None assert w1.is_contiguous() and w2.is_contiguous() - assert self.gemm1_alpha is not None - assert self.gemm1_beta is not None - assert self.gemm1_clamp_limit is not None assert topk_ids.is_contiguous() fc1_expert_biases = self.w1_bias diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 49957c8f5e36..3636c7a6dcd1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -35,6 +35,23 @@ 
logger = init_logger(__name__) +_SM12X_TUNED_CONFIG_DEVICE_ALIASES = { + "NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition": ( + "NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition", + ), + "NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition": ( + "NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition", + ), +} + + +def _tuned_config_device_names() -> tuple[str, ...]: + device_name = current_platform.get_device_name().replace(" ", "_") + if "H200" in device_name.split("_"): + return ("NVIDIA_H200",) + return (device_name, *_SM12X_TUNED_CONFIG_DEVICE_ALIASES.get(device_name, ())) + + @triton.jit def write_zeros_to_output( c_ptr, @@ -999,10 +1016,30 @@ def zero_experts_compute_triton( def get_config_file_name( E: int, N: int, dtype: str | None, block_shape: list[int] | None = None ) -> str: - device_name = current_platform.get_device_name().replace(" ", "_") - # Set device_name to H200 if a device from the H200 family is detected - if "H200" in device_name.split("_"): - device_name = "NVIDIA_H200" + device_name = _tuned_config_device_names()[0] + dtype_selector = "" if not dtype else f",dtype={dtype}" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ).replace(" ", "") + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 + + +def _get_config_file_names( + E: int, N: int, dtype: str | None, block_shape: list[int] | None = None +) -> tuple[str, ...]: + return tuple( + _get_config_file_name_for_device(E, N, dtype, device_name, block_shape) + for device_name in _tuned_config_device_names() + ) + + +def _get_config_file_name_for_device( + E: int, + N: int, + dtype: str | None, + device_name: str, + block_shape: list[int] | None = None, +) -> str: dtype_selector = "" if not dtype else f",dtype={dtype}" block_shape_selector = ( "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" @@ -1035,22 +1072,26 @@ def get_moe_configs( # First look 
up if an optimized configuration is available in the configs # directory block_shape = [block_n, block_k] if block_n and block_k else None - json_file_name = get_config_file_name(E, N, dtype, block_shape) + json_file_names = _get_config_file_names(E, N, dtype, block_shape) - config_file_paths = [] + config_file_paths: list[str] = [] # note that we prioritize user defined config user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER if user_defined_config_folder is not None: - user_defined_config_file_path = os.path.join( - user_defined_config_folder, json_file_name + config_file_paths.extend( + os.path.join(user_defined_config_folder, json_file_name) + for json_file_name in json_file_names ) - config_file_paths.append(user_defined_config_file_path) - default_config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + config_file_paths.extend( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "configs", + json_file_name, + ) + for json_file_name in json_file_names ) - config_file_paths.append(default_config_file_path) for config_file_path in config_file_paths: if os.path.exists(config_file_path): diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index 5d94d82c01c6..6dcdfe6add12 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -1228,7 +1228,7 @@ def convert_weight_to_mxfp4_moe_kernel_format( ]: """Convert loaded weights into backend-specific kernel format. - Supports DeepGEMM, TRTLLM MXFP8, Triton and Marlin backends. + Supports DeepGEMM, TRTLLM MXFP8, CUTLASS MXFP8, Triton and Marlin backends. 
""" if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4: @@ -1284,6 +1284,47 @@ def convert_weight_to_mxfp4_moe_kernel_format( sf_block_size = 32 # mxfp4 block size + if mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8: + from flashinfer import block_scale_interleave + + w13_weight = w13_weight.data + w2_weight = w2_weight.data + w13_weight_scale = w13_weight_scale.data + w2_weight_scale = w2_weight_scale.data + + # FlashInfer CUTLASS expects [up, gate], while vLLM stores [gate, up]. + w1_weight = w13_weight[:, :intermediate_size, :] + w3_weight = w13_weight[:, intermediate_size:, :] + w13_weight = torch.cat([w3_weight, w1_weight], dim=1).contiguous() + + w1_scale = w13_weight_scale[:, :intermediate_size, :] + w3_scale = w13_weight_scale[:, intermediate_size:, :] + w13_weight_scale = torch.cat([w3_scale, w1_scale], dim=1).contiguous() + + if w13_bias is not None: + b1 = w13_bias[:, :intermediate_size] + b3 = w13_bias[:, intermediate_size:] + w13_bias = torch.cat([b3, b1], dim=1).contiguous() + + w13_scale_shape = w13_weight_scale.shape + w13_weight_scale = block_scale_interleave( + w13_weight_scale.view(torch.uint8) + ).reshape(w13_scale_shape) + + w2_scale_shape = w2_weight_scale.shape + w2_weight_scale = block_scale_interleave( + w2_weight_scale.view(torch.uint8) + ).reshape(w2_scale_shape) + + return ( + w13_weight, + w2_weight.contiguous(), + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + if mxfp4_backend in TRTLLM_BACKENDS: assert _cache_permute_indices is not None from flashinfer.fp4_quantization import nvfp4_block_scale_interleave diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 93bc81c22be4..2f98e9285977 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -174,6 +174,7 @@ def select_nvfp4_moe_backend( NVFP4_BACKENDS_WITH_CLAMP = { NvFp4MoeBackend.FLASHINFER_TRTLLM, + 
NvFp4MoeBackend.FLASHINFER_CUTLASS, } if config.swiglu_limit is not None: diff --git a/vllm/model_executor/layers/fused_moe/routed_experts.py b/vllm/model_executor/layers/fused_moe/routed_experts.py index 669d1d376902..0a4a1a37d958 100644 --- a/vllm/model_executor/layers/fused_moe/routed_experts.py +++ b/vllm/model_executor/layers/fused_moe/routed_experts.py @@ -275,6 +275,19 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: # Weight Loading Methods # + @staticmethod + def _normalize_loaded_weight_for_copy( + expert_data: torch.Tensor, loaded_weight: torch.Tensor + ) -> torch.Tensor: + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and expert_data.dtype == torch.uint8 + and loaded_weight.dtype == e8m0_dtype + ): + return loaded_weight.view(torch.uint8) + return loaded_weight + def _load_per_tensor_weight_scale( self, shard_id: str, @@ -288,10 +301,12 @@ def _load_per_tensor_weight_scale( # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. 
idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight + target = param_data[expert_id][idx] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) # If we are in the row parallel case (down_proj) elif shard_id == "w2": - param_data[expert_id] = loaded_weight + target = param_data[expert_id] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) def _load_combined_w13_weight_scale( self, @@ -308,7 +323,7 @@ def _load_combined_w13_weight_scale( loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size ) - param.copy_(loaded_weight) + param.copy_(self._normalize_loaded_weight_for_copy(param, loaded_weight)) def _load_model_weight_or_group_weight_scale( self, @@ -366,7 +381,9 @@ def _load_per_channel_weight_scale( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) elif shard_id in ("w1", "w3"): self._load_w13( shard_id=shard_id, @@ -482,7 +499,9 @@ def _load_w13( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_w2( self, @@ -517,7 +536,9 @@ def _load_w2( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_single_value( self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 1b2a8a74bdcb..a205dd612a21 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -37,6 +37,14 @@ logger = init_logger(__name__) +def _clear_mxfp4_moe_weight_loading_cache() -> None: + # WORKAROUND: SM12x/GB10 can hit 
driver instability while loading MXFP4 + # MoE weights under tight VRAM pressure. Release PyTorch's staging cache + # after the backend kernel has taken ownership of the converted weights. + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + class Mxfp4Config(QuantizationConfig): """Canonical base config for MXFP4 quantization. @@ -389,6 +397,8 @@ def process_weights_after_loading(self, layer: RoutedExperts) -> None: return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) + del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias + _clear_mxfp4_moe_weight_loading_cache() def get_fused_moe_quant_config( self, layer: RoutedExperts @@ -733,6 +743,8 @@ def process_weights_after_loading(self, layer): return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) + del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias + _clear_mxfp4_moe_weight_loading_cache() def get_fused_moe_quant_config( self, diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1024,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1024,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..46a4cc726f7c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1024,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": 
{ + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..306fdae8639e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 
1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..387b572731b6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..1cb973e5c383 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + 
"16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..ac91f525e96b --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + 
"num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..2d655a2debac --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..ac1053b588c5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { 
+ "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..2026a21038b9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 
16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..be96d80c51f1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..afeb347e229a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, 
+ "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..1e84c847ff37 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + 
"num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..5163bc4f3da1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..6059ab794a4f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..36a2e6621f2e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": 
{ + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..5a6f1de61395 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + 
"num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 32a2d86899cc..2b123c22a9db 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -38,6 +38,21 @@ logger = init_logger(__name__) +_SM12X_TUNED_CONFIG_DEVICE_ALIASES = { + "NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition": ( + "NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition", + ), + "NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition": ( + "NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition", + ), +} + + +def _tuned_config_device_names() -> tuple[str, ...]: + device_name = current_platform.get_device_name().replace(" ", "_") + return (device_name, *_SM12X_TUNED_CONFIG_DEVICE_ALIASES.get(device_name, ())) + + def is_fp8(x: torch.dtype | torch.Tensor) -> bool: if isinstance(x, torch.Tensor): x = x.dtype @@ -864,31 +879,63 @@ def get_w8a8_block_fp8_configs( # First look up if an optimized configuration is available in the configs # directory - device_name = current_platform.get_device_name().replace(" ", "_") - json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: 
E501 - - config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name - ) - if os.path.exists(config_file_path): - with open(config_file_path) as f: - logger.info( - "Using configuration from %s for W8A8 Block FP8 kernel.", - config_file_path, - ) - # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} + config_file_paths = [ + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "configs", + f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json", # noqa: E501 + ) + for device_name in _tuned_config_device_names() + ] + for config_file_path in config_file_paths: + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + "Using configuration from %s for W8A8 Block FP8 kernel.", + config_file_path, + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration logger.warning( "Using default W8A8 Block FP8 kernel config. Performance might " - "be sub-optimal! Config file not found at %s", - config_file_path, + "be sub-optimal! Config files not found at %s", + config_file_paths, ) return None +def _get_default_w8a8_block_fp8_config( + M: int, + block_n: int, + block_k: int, +) -> dict[str, Any]: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_n and + # BLOCK_SIZE_K must be divisible by block_k. + # M-aware tuning for low-M decode: BLOCK_SIZE_M=64 wastes most of the + # M-dim for single-request decode and short MTP-style draft batches. SM12x + # keeps benefiting from the low-M tile through M=32 on DeepSeek V4 shapes. 
+ capability = current_platform.get_device_capability() + capability_major = getattr(capability, "major", None) + if capability_major is None and capability is not None: + capability_major = capability[0] + low_m_limit = 32 if capability_major == 12 else 8 + if low_m_limit >= M: + block_m, num_stages = 16, (2 if current_platform.is_rocm() else 3) + else: + block_m, num_stages = 64, 2 + return { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": num_stages, + } + + def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -933,6 +980,12 @@ def w8a8_triton_block_scaled_mm( N, K = B.shape assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if e8m0_dtype is not None: + if As.dtype == e8m0_dtype: + As = _upcast_e8m0_to_fp32(As) + if Bs.dtype == e8m0_dtype: + Bs = _upcast_e8m0_to_fp32(Bs) C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) @@ -942,17 +995,7 @@ def w8a8_triton_block_scaled_mm( # Get the optimal config if there is one config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: - # Default config - # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] - # BLOCK_SIZE_K must be divisible by block_size[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_size[0], - "BLOCK_SIZE_K": block_size[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 2, - } + config = _get_default_w8a8_block_fp8_config(M, block_size[0], block_size[1]) def grid(META): return ( @@ -1303,6 +1346,8 @@ def create_fp8_scale_parameter( if dtype == torch.float32: scale[:] = torch.finfo(torch.float32).min + elif dtype == getattr(torch, "float8_e8m0fnu", None): + scale[:] = 0 set_weight_attrs(scale, {"scale_type": "weight_scale"}) return scale diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py 
b/vllm/model_executor/layers/sparse_attn_indexer.py index 45c5d5f78198..a8372ed7d43a 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -4,7 +4,6 @@ import torch -import vllm.envs as envs from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.compilation.breakable_cudagraph import eager_break_during_capture @@ -14,7 +13,9 @@ from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( fp8_fp4_mqa_logits, + fp8_fp4_mqa_topk_indices, fp8_fp4_paged_mqa_logits, + fp8_fp4_paged_mqa_topk_indices, has_deep_gemm, ) from vllm.utils.torch_utils import ( @@ -25,6 +26,7 @@ ) from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, + sparse_indexer_max_logits_bytes, ) from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.worker.workspace import current_workspace_manager @@ -32,11 +34,64 @@ logger = init_logger(__name__) RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 +SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH = 4096 +SM120_SHORT_ROW_TOPK_MAX_WIDTH = 12288 # MXFP4 layout: 2 values packed per byte, ue8m0 (1-byte) scale per block of 32. 
MXFP4_BLOCK_SIZE = 32 +def _should_use_sm120_short_row_topk_decode( + topk_tokens: int, + logits_width: int, + is_cuda_sm120: bool, +) -> bool: + if not is_cuda_sm120 or topk_tokens != 512: + return False + if logits_width <= SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH: + return True + return logits_width < SM120_SHORT_ROW_TOPK_MAX_WIDTH + + +def _use_sm120_short_row_topk_decode( + logits: torch.Tensor, + topk_tokens: int, +) -> bool: + return _should_use_sm120_short_row_topk_decode( + topk_tokens, + logits.shape[1], + current_platform.is_cuda() + and current_platform.is_device_capability_family(120), + ) + + +def _decode_logits_width(max_model_len: int, max_seq_len: int) -> int: + if max_model_len <= 0: + return 0 + if max_seq_len <= 0: + return max_model_len + return min(max_model_len, max_seq_len) + + +def _decode_topk_logits_width( + max_model_len: int, max_seq_len: int, topk_tokens: int +) -> int: + logits_width = _decode_logits_width(max_model_len, max_seq_len) + return min(max_model_len, max(logits_width, topk_tokens)) + + +def _sparse_indexer_requires_deep_gemm(use_fp4_cache: bool = False) -> bool: + if not current_platform.is_cuda(): + return False + if current_platform.is_device_capability_family(120): + # The SM120 fallback path covers FP8-Q sparse indexer calls. FP4-Q + # indexer calls still route through DeepGEMM's fp8_fp4 kernels, so + # fail during construction instead of letting the first forward hit + # the generic DeepGEMM missing-dependency error. + return use_fp4_cache + return True + + def _gather_workspace_shapes( total_seq_lens: int, head_dim: int, @@ -116,7 +171,7 @@ def sparse_attn_indexer( # Dummy allocation to simulate for peak logits tensor memory during inference. 
# FP8 elements so elements == bytes - max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_elems = sparse_indexer_max_logits_bytes() _ = torch.empty( max_logits_elems, dtype=torch.uint8, device=hidden_states.device ) @@ -218,6 +273,19 @@ def sparse_attn_indexer( q_slice_cast = q_slice k_quant_cast = k_quant k_scale_cast = k_scale.view(torch.float32).squeeze(-1) + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + if not current_platform.is_xpu() and fp8_fp4_mqa_topk_indices( + (q_slice_cast, q_scale_slice), + (k_quant_cast, k_scale_cast), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + ): + continue + if current_platform.is_xpu(): if q_scale_slice is not None: raise RuntimeError("XPU fp8_mqa_logits does not support FP4 Q") @@ -240,10 +308,6 @@ def sparse_attn_indexer( ) num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[ - chunk.token_start : chunk.token_end, :topk_tokens - ] - ops.top_k_per_row_prefill( logits, chunk.cu_seqlen_ks, @@ -296,7 +360,18 @@ def sparse_attn_indexer( batch_size = padded_q_quant_decode_tokens.shape[0] next_n = padded_q_quant_decode_tokens.shape[1] num_padded_tokens = batch_size * next_n - seq_lens = decode_metadata.seq_lens[:batch_size] + # ``.contiguous()`` was originally required because the + # ``DeepseekV32IndexerMetadataBuilder`` allocated + # ``decode_seq_lens_buffer`` as a 2D ``(max_num_seqs, next_n)`` + # tensor, and a ``[:num_decodes, :max_decode_len]`` slice was + # non-contiguous when ``max_decode_len < next_n`` under V2 model + # runner cudagraph capture. Reported by aabbccddwasd in PR #41834 + # comment 4450901180. Upstream PR #42135 (ee58665aa) since + # unified the buffer to 1D ``(max_num_batched_tokens,)``, so the + # slice is now always contiguous and this call is a no-op pointer + # return. 
Kept as a defensive belt against future regressions in + # the metadata builder's buffer shape. + seq_lens = decode_metadata.seq_lens[:batch_size].contiguous() # seq_lens is always 2D: (B, next_n) for native spec decode, (B, 1) # otherwise. deep_gemm fp8_fp4_paged_mqa_logits requires 2D context_lens; # the downstream topk kernels accept both 1D and 2D. @@ -305,60 +380,93 @@ def sparse_attn_indexer( if use_fp4_cache else padded_q_quant_decode_tokens ) - if current_platform.is_xpu(): - if padded_q_scale is not None: - raise RuntimeError("XPU fp8_paged_mqa_logits does not support FP4 Q") - seq_lens_xpu = ( - seq_lens[:, -1].contiguous() if seq_lens.ndim == 2 else seq_lens - ) - logits = torch.ops.vllm.xpu_fp8_paged_mqa_logits( - padded_q_quant_cast, - kv_cache, - weights[:num_padded_tokens], - seq_lens_xpu, - decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len, - ) - else: - logits = fp8_fp4_paged_mqa_logits( + topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] + logits_width = _decode_topk_logits_width( + max_model_len, attn_metadata_narrowed.max_seq_len, topk_tokens + ) + logits_bytes = num_padded_tokens * logits_width * torch.float32.itemsize + used_direct_topk = False + if ( + not current_platform.is_xpu() + and logits_bytes > sparse_indexer_max_logits_bytes() + ): + used_direct_topk = fp8_fp4_paged_mqa_topk_indices( (padded_q_quant_cast, padded_q_scale), kv_cache, weights[:num_padded_tokens], seq_lens, decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len=max_model_len, - clean_logits=False, - ) - num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - - if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048): - workspace_manager = current_workspace_manager() - (topk_workspace,) = workspace_manager.get_simultaneous( - ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), - ) - torch.ops._C.persistent_topk( - logits, - seq_lens, - topk_indices, 
- topk_workspace, - topk_tokens, - attn_metadata_narrowed.max_seq_len, - ) - else: - ops.top_k_per_row_decode( - logits, - next_n, - seq_lens, + logits_width, topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, ) + if not used_direct_topk: + if current_platform.is_xpu(): + if padded_q_scale is not None: + raise RuntimeError( + "XPU fp8_paged_mqa_logits does not support FP4 Q" + ) + seq_lens_xpu = ( + seq_lens[:, -1].contiguous() if seq_lens.ndim == 2 else seq_lens + ) + logits = torch.ops.vllm.xpu_fp8_paged_mqa_logits( + padded_q_quant_cast, + kv_cache, + weights[:num_padded_tokens], + seq_lens_xpu, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len, + ) + else: + logits = fp8_fp4_paged_mqa_logits( + (padded_q_quant_cast, padded_q_scale), + kv_cache, + weights[:num_padded_tokens], + seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=logits_width, + clean_logits=False, + ) + num_rows = logits.shape[0] + + if _use_sm120_short_row_topk_decode(logits, topk_tokens): + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + elif current_platform.is_cuda() and topk_tokens in (512, 1024, 2048): + workspace_manager = current_workspace_manager() + (topk_workspace,) = workspace_manager.get_simultaneous( + ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), + ) + torch.ops._C.persistent_topk( + logits, + seq_lens, + topk_indices, + topk_workspace, + topk_tokens, + logits_width, + ) + else: + ops.top_k_per_row_decode( + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens @@ -440,7 +548,7 @@ def __init__( self.topk_indices_buffer = topk_indices_buffer self.skip_k_cache_insert = skip_k_cache_insert 
self.use_fp4_cache = use_fp4_cache - if current_platform.is_cuda() and not has_deep_gemm(): + if _sparse_indexer_requires_deep_gemm(use_fp4_cache) and not has_deep_gemm(): raise RuntimeError( "Sparse Attention Indexer CUDA op requires DeepGEMM support in " "the current vLLM environment." diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 3ea76f4d9b3a..25682c1085dc 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -4,7 +4,7 @@ import glob import os import time -from collections.abc import Generator, Iterable +from collections.abc import Callable, Generator, Iterable from typing import cast import torch @@ -242,7 +242,9 @@ def _prepare_weights( return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, source: "Source" + self, + source: "Source", + weight_name_filter: Callable[[str], bool] | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config @@ -268,6 +270,8 @@ def _get_weights_iterator( weights_iterator = fastsafetensors_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + local_expert_ids=self.local_expert_ids, + weight_name_filter=weight_name_filter, ) elif self.load_config.load_format == "instanttensor": weights_iterator = instanttensor_weights_iterator( @@ -289,6 +293,7 @@ def _get_weights_iterator( self.load_config.use_tqdm_on_load, self.load_config.safetensors_load_strategy, local_expert_ids=self.local_expert_ids, + weight_name_filter=weight_name_filter, safetensors_prefetch_num_threads=( self.load_config.safetensors_prefetch_num_threads ), @@ -323,6 +328,9 @@ def get_all_weights( model_config: ModelConfig, model: nn.Module, ) -> Generator[tuple[str, torch.Tensor], None, None]: + weight_name_filter = getattr(model, 
"skip_weight_name_before_load", None) + if not callable(weight_name_filter): + weight_name_filter = None primary_weights = DefaultModelLoader.Source( model_config.model, model_config.revision, @@ -330,14 +338,14 @@ def get_all_weights( fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None), ) - yield from self._get_weights_iterator(primary_weights) + yield from self._get_weights_iterator(primary_weights, weight_name_filter) secondary_weights = cast( Iterable[DefaultModelLoader.Source], getattr(model, "secondary_weights", ()), ) for source in secondary_weights: - yield from self._get_weights_iterator(source) + yield from self._get_weights_iterator(source, weight_name_filter) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights( diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 47c6c02be6ab..7037778b7f97 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -817,11 +817,22 @@ def _run_prefetch() -> None: threading.Thread(target=_run_prefetch, daemon=True).start() +def _should_skip_safetensors_weight( + weight_name: str, + local_expert_ids: set[int] | None, + weight_name_filter: Callable[[str], bool] | None, +) -> bool: + if should_skip_weight(weight_name, local_expert_ids): + return True + return weight_name_filter is not None and weight_name_filter(weight_name) + + def safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, safetensors_load_strategy: str | None = None, local_expert_ids: set[int] | None = None, + weight_name_filter: Callable[[str], bool] | None = None, *, safetensors_prefetch_num_threads: int = DEFAULT_SAFETENSORS_PREFETCH_NUM_THREADS, safetensors_prefetch_block_size: int = DEFAULT_SAFETENSORS_PREFETCH_BLOCK_SIZE, @@ -831,6 +842,9 @@ def safetensors_weights_iterator( When 
*local_expert_ids* is provided, expert weights not belonging to this rank are skipped **before** reading from disk, which drastically reduces storage I/O for MoE models under EP. + + When *weight_name_filter* is provided, names for which the callback returns + ``True`` are also skipped before tensor materialization. """ loading_desc = "Loading safetensors checkpoint shards" if safetensors_load_strategy == "eager": @@ -913,7 +927,9 @@ def safetensors_weights_iterator( with open(st_file, "rb") as f: state_dict = load(f.read()) for name, param in state_dict.items(): - if not should_skip_weight(name, local_expert_ids): + if not _should_skip_safetensors_weight( + name, local_expert_ids, weight_name_filter + ): yield name, param elif safetensors_load_strategy == "torchao": # we can't load flattened torchao tensor subclasses directly into the model @@ -930,7 +946,9 @@ def safetensors_weights_iterator( with safe_open(st_file, framework="pt") as f: state_dict = {} for name in f.keys(): # noqa: SIM118 - if should_skip_weight(name, local_expert_ids): + if _should_skip_safetensors_weight( + name, local_expert_ids, weight_name_filter + ): continue state_dict[name] = f.get_tensor(name) @@ -948,7 +966,9 @@ def safetensors_weights_iterator( else: with safe_open(st_file, framework="pt") as f: for name in f.keys(): # noqa: SIM118 - if should_skip_weight(name, local_expert_ids): + if _should_skip_safetensors_weight( + name, local_expert_ids, weight_name_filter + ): continue param = f.get_tensor(name) yield name, param @@ -1024,6 +1044,8 @@ def runai_safetensors_weights_iterator( def fastsafetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, + local_expert_ids: set[int] | None = None, + weight_name_filter: Callable[[str], bool] | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files using fastsafetensor library. 
@@ -1072,6 +1094,12 @@ def _make_loader(nogds: bool) -> "ParallelLoader": pl = _make_loader(nogds) for name, tensor in pl.iterate_weights(): yielded = True + if _should_skip_safetensors_weight( + name, + local_expert_ids, + weight_name_filter, + ): + continue yield name, tensor except RuntimeError as e: if nogds or yielded or "gds" not in str(e): @@ -1084,7 +1112,14 @@ def _make_loader(nogds: bool) -> "ParallelLoader": if pl is not None: pl.close() pl = _make_loader(nogds=True) - yield from pl.iterate_weights() + for name, tensor in pl.iterate_weights(): + if _should_skip_safetensors_weight( + name, + local_expert_ids, + weight_name_filter, + ): + continue + yield name, tensor finally: if pl is not None: pl.close() diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 754270e6525b..b97368269348 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -6,8 +6,10 @@ happen during model execution. """ +from types import SimpleNamespace from typing import TYPE_CHECKING +import numpy as np import torch import vllm.envs as envs @@ -27,6 +29,8 @@ from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.structured_output.utils import apply_grammar_bitmask if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -34,6 +38,624 @@ logger = init_logger(__name__) +_DEEPSEEK_V4_SPARSE_MLA_BACKENDS = frozenset( + { + "V4_FLASHMLA_SPARSE", + "DEEPSEEK_SPARSE_SWA", + } +) +_DEEPSEEK_V4_SPARSE_MLA_MIXED_WARMUP_TOKENS = 16 +# Cap warmup at the largest single-chunk prefill the scheduler will ever +# issue (max_num_batched_tokens). 
8192 covers the canonical SM12x serve +# (max_num_batched_tokens=8192); larger scheduler caps clamp to this +# value via _clamp_warmup_tokens at the call site, smaller caps clamp +# down naturally. +_DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS = 8192 +# Steady-state MTP decode shapes to warm. Keep this bounded to high-concurrency +# SM12x gates while still avoiding the scheduler's raw max_num_seqs (often 1024), +# which can consume multiple GiB of temporary workspace on long-context serves +# before the first request. +_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS = (1, 2, 4, 8, 16, 24, 32) +_DEEPSEEK_V4_MTP_UNIFORM_DECODE_MAX_WARMUP_REQUESTS = 256 +_DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS = tuple(range(1, 17)) + ( + 32, + 64, + 128, + 256, + 512, +) + + +def _attention_backend_name(backend: object) -> str | None: + get_name = getattr(backend, "get_name", None) + if get_name is None: + return None + try: + return get_name() + except NotImplementedError: + return None + + +def _has_deepseek_v4_sparse_mla_backend(runner: "GPUModelRunner") -> bool: + for groups in getattr(runner, "attn_groups", []) or (): + for group in groups: + name = _attention_backend_name(getattr(group, "backend", None)) + if name in _DEEPSEEK_V4_SPARSE_MLA_BACKENDS: + return True + return False + + +def _clamp_warmup_tokens(num_tokens: int, max_tokens: int) -> int: + return max(0, min(num_tokens, max_tokens)) + + +def _is_deepseek_v4_mtp_spec_decode(runner: "GPUModelRunner") -> bool: + spec_config = getattr(runner, "speculative_config", None) + return ( + getattr(spec_config, "method", None) == "mtp" + and getattr(runner, "num_spec_tokens", 0) > 0 + ) + + +def _deepseek_v4_mtp_uniform_decode_warmup_requests( + runner: "GPUModelRunner", + max_tokens: int, + max_reqs: int, +) -> tuple[int, ...]: + if not _is_deepseek_v4_mtp_spec_decode(runner): + return () + + query_len = getattr( + runner, + "uniform_decode_query_len", + 1 + getattr(runner, "num_spec_tokens", 0), + ) + if query_len <= 0: + 
return () + + max_warmup_reqs = min( + max_reqs, + max_tokens // query_len, + _DEEPSEEK_V4_MTP_UNIFORM_DECODE_MAX_WARMUP_REQUESTS, + ) + candidates = sorted( + set(_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS) | {max_warmup_reqs} + ) + return tuple(reqs for reqs in candidates if reqs <= max_warmup_reqs) + + +def _deepseek_v4_slot_mapping_warmup(runner: "GPUModelRunner") -> None: + max_tokens = getattr(runner, "max_num_tokens", 1) + block_table = runner.input_batch.block_table + + # Snapshot the runner buffers we mutate so warmup never leaks state into + # the first real request. + saved_query_start_loc_np: np.ndarray | None = None + saved_query_start_loc_gpu: torch.Tensor | None = None + if hasattr(runner, "query_start_loc"): + saved_query_start_loc_np = runner.query_start_loc.np[:2].copy() + saved_query_start_loc_gpu = runner.query_start_loc.gpu[:2].clone() + + try: + for requested_tokens in _DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS: + num_tokens = _clamp_warmup_tokens(requested_tokens, max_tokens) + if num_tokens <= 0: + continue + + positions_source = torch.arange( + num_tokens, dtype=torch.int64, device=runner.device + ) + if hasattr(runner, "query_start_loc"): + runner.query_start_loc.np[0] = 0 + runner.query_start_loc.np[1] = num_tokens + runner.query_start_loc.copy_to_gpu(2) + query_start_loc = runner.query_start_loc.gpu[:2] + else: + query_start_loc = torch.tensor( + [0, num_tokens], dtype=torch.int32, device=runner.device + ) + + if hasattr(runner, "positions"): + saved_positions: torch.Tensor | None = runner.positions[ + :num_tokens + ].clone() + runner.positions[:num_tokens].copy_(positions_source) + positions = runner.positions[:num_tokens] + else: + saved_positions = None + positions = positions_source + + try: + block_table.commit_block_table(1) + block_table.compute_slot_mapping(1, query_start_loc, positions) + finally: + if saved_positions is not None: + runner.positions[:num_tokens].copy_(saved_positions) + finally: + if saved_query_start_loc_np 
is not None: + runner.query_start_loc.np[:2] = saved_query_start_loc_np + assert saved_query_start_loc_gpu is not None + runner.query_start_loc.gpu[:2].copy_(saved_query_start_loc_gpu) + + +def _deepseek_v4_structured_output_bitmask_warmup( + runner: "GPUModelRunner", +) -> None: + vocab_size = runner.model_config.get_vocab_size() + if vocab_size <= 0: + return + + dtypes = [torch.float32] + model_dtype = getattr(runner.model_config, "dtype", None) + if isinstance(model_dtype, torch.dtype) and model_dtype not in dtypes: + dtypes.append(model_dtype) + + bitmask_width = (vocab_size + 31) // 32 + req_id = "_deepseek_v4_warmup_" + grammar_bitmask = np.full((1, bitmask_width), fill_value=-1, dtype=np.int32) + grammar_output = GrammarOutput( + structured_output_request_ids=[req_id], grammar_bitmask=grammar_bitmask + ) + + for dtype in dtypes: + for req_ids in ([req_id], [req_id, "_deepseek_v4_warmup_unmasked_"]): + logits = torch.zeros( + (len(req_ids), vocab_size), dtype=dtype, device=runner.device + ) + input_batch = SimpleNamespace(req_ids=req_ids) + apply_grammar_bitmask( + SchedulerOutput.make_empty(), + grammar_output, + input_batch, # type: ignore[arg-type] + logits, + ) + + +@torch.inference_mode() +def _deepseek_v4_request_prep_warmup(worker: "Worker") -> None: + if not envs.VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: + return + + runner = worker.model_runner + if runner.is_pooling_model or not _has_deepseek_v4_sparse_mla_backend(runner): + return + if not current_platform.is_cuda_alike(): + return + + logger.info("Warming up DeepSeek V4 request preparation kernels.") + _deepseek_v4_slot_mapping_warmup(runner) + + if getattr(runner, "is_last_pp_rank", True): + try: + _deepseek_v4_structured_output_bitmask_warmup(runner) + except ImportError: + logger.debug( + "Skipping DeepSeek V4 structured output bitmask warmup because " + "xgrammar is unavailable." 
+ ) + + torch.accelerator.synchronize() + + +def _run_deepseek_v4_mtp_spec_decode_warmup_kernels( + *, + device: torch.device, + num_reqs: int, + num_spec_tokens: int, + vocab_size: int, + block_size: int, + max_model_len: int, +) -> None: + from vllm.v1.sample.logits_processor import LogitsProcessors + from vllm.v1.sample.metadata import SamplingMetadata + from vllm.v1.sample.rejection_sampler import rejection_sample + from vllm.v1.spec_decode.utils import ( + eagle_prepare_inputs_padded_kernel, + eagle_prepare_next_token_padded_kernel, + eagle_step_update_slot_mapping_and_metadata, + next_power_of_2, + ) + + num_sampled_tokens = num_spec_tokens + 1 + sampled_token_ids = torch.arange( + num_reqs * num_sampled_tokens, dtype=torch.int32, device=device + ).reshape(num_reqs, num_sampled_tokens) + sampled_token_ids.remainder_(vocab_size) + discard_request_mask = torch.zeros(num_reqs, dtype=torch.bool, device=device) + backup_next_token_ids = torch.zeros(num_reqs, dtype=torch.int32, device=device) + next_token_ids = torch.empty(num_reqs, dtype=torch.int32, device=device) + valid_sampled_tokens_count = torch.empty(num_reqs, dtype=torch.int32, device=device) + eagle_prepare_next_token_padded_kernel[(num_reqs,)]( + sampled_token_ids, + discard_request_mask, + backup_next_token_ids, + next_token_ids, + valid_sampled_tokens_count, + vocab_size, + num_sampled_tokens, + num_reqs, + sampled_token_ids.stride(0), + BLOCK_SIZE_TOKENS=next_power_of_2(num_sampled_tokens), + ) + + cu_num_draft_tokens = torch.arange( + num_spec_tokens, + num_reqs * num_spec_tokens + 1, + num_spec_tokens, + dtype=torch.int32, + device=device, + ) + query_start_loc = torch.arange( + 0, + (num_reqs + 1) * num_sampled_tokens, + num_sampled_tokens, + dtype=torch.int32, + device=device, + ) + token_indices_to_sample = torch.empty(num_reqs, dtype=torch.int32, device=device) + num_rejected_tokens = torch.empty(num_reqs, dtype=torch.int32, device=device) + eagle_prepare_inputs_padded_kernel[(num_reqs,)]( + 
cu_num_draft_tokens, + valid_sampled_tokens_count, + query_start_loc, + token_indices_to_sample, + num_rejected_tokens, + num_reqs, + ) + + positions = torch.arange(num_reqs, dtype=torch.int64, device=device) + block_table_tensor = torch.zeros((num_reqs, 1), dtype=torch.int32, device=device) + seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device) + out_clamped_positions = torch.empty_like(positions) + out_slot_mapping = torch.empty(num_reqs, dtype=torch.int64, device=device) + eagle_step_update_slot_mapping_and_metadata( + positions, + block_table_tensor, + seq_lens, + block_size, + max_model_len, + out_clamped_positions, + out_slot_mapping, + input_batch_size=num_reqs, + ) + + total_draft_tokens = num_reqs * num_spec_tokens + draft_token_ids = torch.arange(total_draft_tokens, dtype=torch.int32, device=device) + draft_token_ids.remainder_(vocab_size) + draft_probs = torch.rand( + total_draft_tokens, vocab_size, dtype=torch.float32, device=device + ) + draft_probs = draft_probs / draft_probs.sum(dim=-1, keepdim=True) + target_logits = torch.randn( + total_draft_tokens, vocab_size, dtype=torch.float32, device=device + ) + bonus_token_ids = torch.zeros((num_reqs, 1), dtype=torch.int32, device=device) + sampling_metadata = SamplingMetadata( + temperature=torch.full((num_reqs,), 0.7, dtype=torch.float32, device=device), + all_greedy=False, + all_random=True, + top_p=None, + top_k=None, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=torch.empty(0, device=device), + presence_penalties=torch.empty(0, device=device), + repetition_penalties=torch.empty(0, device=device), + output_token_ids=[[] for _ in range(num_reqs)], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessors(), + logprob_token_ids=None, + spec_token_ids=[[] for _ in range(num_reqs)], + ) + rejection_sample( + draft_token_ids=draft_token_ids, + num_draft_tokens=[num_spec_tokens] * num_reqs, + 
max_spec_len=num_spec_tokens, + cu_num_draft_tokens=cu_num_draft_tokens, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) + + +def _deepseek_v4_indexed_d512_split_prefill_warmup(runner: "GPUModelRunner") -> None: + """Force-compile the DeepSeek-V4 D512-split sparse-MLA prefill kernels. + + The split path (``_use_indexed_d512_split_prefill`` -> + ``accumulate_indexed_d512_split_sparse_mla_attention``) bottoms out in three + plain ``@triton.jit`` kernels whose compile key is the constexpr set -- + chiefly ``num_candidates`` (= the per-chunk ``combined_topk``) plus the + workspace buffer strides. ``combined_topk`` is 128-aligned + (``_SPARSE_PREFILL_TOPK_ALIGNMENT``) and the split path is gated to + ``[256, 1152]`` (``_is_indexed_d512_split_topk``), so the complete + specialization set is the eight widths {256, 384, ..., 1152}. The kernels + never see ``compress_ratio``, so one warm per width covers cr=4 and cr=128. + + Without this, the first long-prefill request JIT-compiles these kernels + inside the engine step (~20s), parking EngineCore in shm_broadcast and + surfacing as a "sample_tokens RPC timed out" wedge (PR #41834). + + Triton compilation is data-independent, so synthetic zero tensors compile + the same cubin a real request uses -- provided every constexpr matches. Two + non-obvious constexprs (verified against the live jit_monitor compile key): + the per-chunk ``scores``/``indices`` workspaces are sized to that chunk's own + ``combined_topk`` (contiguous at width C, so ``stride_scores_h == C`` and + ``stride_indices_t == C`` -- NOT a slice of a wider buffer), and the prefill + ``q`` buffer is padded to the FP8-decode head count (``padded_heads``), so + ``stride_q_t == padded_heads * head_dim`` even though the kernel reads only + ``n_local_heads``. The synthetic tensors mirror both. + + Scope: only the split path (``combined_topk <= 1152``) is warmed. 
DeepSeek-V4 + -Flash caps ``combined_topk`` at ``sparse_prefill_combined_topk_size( + index_topk=512, 128) = 640`` for every context length, so that is complete + coverage. A variant whose ``combined_topk`` can exceed 1152 routes onto the + chunked path (extra split-stride and merge kernels) which is not pre-warmed + here; that case is warned at startup rather than left as a silent gap. + """ + if not ( + envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + ): + return + + try: + from vllm.models.deepseek_v4.common.ops.cache_utils import ( + sparse_prefill_combined_topk_size, + ) + from vllm.models.deepseek_v4.nvidia.flashmla import ( + _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, + _INDEXED_D512_SPLIT_PREFILL_MIN_TOPK, + DeepseekV4FlashMLAAttention, + ) + from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_query_chunk_size, + ) + from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_indexed_d512_split_sparse_mla_attention, + ) + except ImportError as exc: + # The early gate above already confirmed the warmup is requested, so a + # failed import here is not a benign "kernels unavailable" case — it is + # usually a renamed symbol (it silently disabled this warmup for weeks). + # Surface it at WARNING so a future rename does not no-op the warmup. + logger.warning( + "Skipping DeepSeek V4 D512-split prefill warmup: a required symbol " + "failed to import (%s). 
The split kernels are likely present but a " + "helper was renamed; the first long prefill will JIT them mid-inference.", + exc, + ) + return + + try: + if not is_triton_sparse_mla_enabled_for_platform(): + return + if ( + getattr(runner, "max_model_len", 0) + < envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS + ): + return + + # The split kernel never sees compress_ratio, so any cr in (4, 128) + # layer yields identical strides; the first one is representative. + layer = None + for module in runner.get_model().modules(): + if ( + isinstance(module, DeepseekV4FlashMLAAttention) + and module.compress_ratio in (4, 128) + ): + layer = module + break + if layer is None: + return + + head_dim = int(layer.head_dim) + if head_dim != 512: + return + num_heads = int(layer.n_local_heads) + window_size = max(1, int(layer.window_size)) + device = layer.attn_sink.device + + # Upper bound on the per-chunk combined_topk a real request can reach: + # the runtime sizes its prefill workspace from the same expression, so a + # request cannot exceed it. The split path is gated to <= 1152. + topk_bound = DeepseekV4FlashMLAAttention._prefill_workspace_topk_bound(layer) + max_reachable_topk = sparse_prefill_combined_topk_size(topk_bound, window_size) + # Non-silent gap: if combined_topk can exceed the split ceiling, the + # request routes onto the chunked path whose kernels we do not pre-warm. + # DSv4-Flash caps at 640 so this never fires for it; warn for variants. + if max_reachable_topk > _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK: + logger.warning( + "DeepSeek V4 D512 prefill: combined_topk can reach %d (> %d); the " + "chunked-prefill kernels are NOT pre-warmed and may JIT on the " + "first very-long prefill. 
Only the split path is warmed.", + max_reachable_topk, + _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, + ) + max_topk = min(_INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, max_reachable_topk) + topk_widths = list( + range(_INDEXED_D512_SPLIT_PREFILL_MIN_TOPK, max_topk + 1, 128) + ) + if not topk_widths: + return + + # The real prefill q buffer is padded to the FP8-decode head count; the + # split kernel reads only n_local_heads, but stride_q_t (a constexpr in + # the compile key) reflects the padded width, so match it. + padded_heads = int( + getattr(layer, "padded_heads", 0) + or DeepseekV4FlashMLAAttention.get_padded_num_q_heads(num_heads) + ) + # T sizes only the launch grid -- the cubin is T-independent -- so keep + # it small to bound the transient footprint. + num_tokens = max(1, min(triton_sparse_mla_query_chunk_size(), 32)) + + logger.info( + "Warming up DeepSeek V4 D512-split sparse-MLA prefill kernels for " + "combined_topk widths=%s (heads=%d, padded_q_heads=%d).", + topk_widths, + num_heads, + padded_heads, + ) + + # Throwaway tensors -- never the shared workspace, so warmup can't grow + # or leak steady-state memory. q/kv/state are width-independent; scores + # and indices are contiguous at each per-chunk width so their constexpr + # strides (stride_scores_h == width, stride_indices_t == width) match the + # runtime per-chunk workspace exactly. 
+ q = torch.zeros( + (num_tokens, padded_heads, head_dim), dtype=torch.bfloat16, device=device + ) + kv_flat = torch.zeros( + (max_topk, head_dim), dtype=torch.bfloat16, device=device + ) + max_score = torch.zeros( + (num_tokens, num_heads), dtype=torch.float32, device=device + ) + denom = torch.zeros( + (num_tokens, num_heads), dtype=torch.float32, device=device + ) + acc = torch.zeros( + (num_tokens, num_heads, head_dim), dtype=torch.float32, device=device + ) + lens = torch.zeros((num_tokens,), dtype=torch.int32, device=device) + + for width in topk_widths: + # indices=0 (valid row) + lens=width keep every candidate active so + # the full kernel body, including the tl.dot MMA, compiles rather + # than an early-return stub. + indices = torch.zeros( + (num_tokens, width), dtype=torch.int32, device=device + ) + scores = torch.zeros( + (num_tokens, num_heads, width), dtype=torch.float32, device=device + ) + lens.fill_(width) + accumulate_indexed_d512_split_sparse_mla_attention( + q=q, + kv_flat=kv_flat, + indices=indices, + lens=lens, + scale=layer.scale, + scores=scores, + max_score=max_score, + denom=denom, + acc=acc, + ) + torch.accelerator.synchronize() + except Exception as exc: # noqa: BLE001 - warmup must never break startup + # Warn (not debug): a swallowed failure here silently leaves the split + # kernels uncompiled, so the first long prefill pays the JIT stall again. 
def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None:
    """Run dummy batches so the DeepSeek V4 sparse-MLA kernels compile at startup.

    The function is a no-op when ``VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP``
    is falsy, when the runner is a pooling model, when
    ``_has_deepseek_v4_sparse_mla_backend`` rejects the runner, or when every
    warmup size resolves to nothing.

    Otherwise it issues, in order: a mixed-batch dummy run, two single-prefill
    dummy runs (the second with ``profile_seq_lens`` doubled to imitate a later
    chunk of a chunked prefill), the direct D512-split prefill warmup, one
    uniform-decode dummy run per request count, and finally the MTP spec-decode
    kernel warmup on CUDA-alike platforms.

    Args:
        worker: The worker whose ``model_runner`` and ``scheduler_config``
            drive the warmup shapes.
    """
    if not envs.VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP:
        return

    runner = worker.model_runner
    if runner.is_pooling_model or not _has_deepseek_v4_sparse_mla_backend(runner):
        return

    max_tokens = worker.scheduler_config.max_num_batched_tokens
    mixed_tokens = _clamp_warmup_tokens(
        _DEEPSEEK_V4_SPARSE_MLA_MIXED_WARMUP_TOKENS, max_tokens
    )
    prefill_tokens = _clamp_warmup_tokens(
        _DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS, max_tokens
    )
    uniform_decode_reqs = _deepseek_v4_mtp_uniform_decode_warmup_requests(
        runner,
        max_tokens=max_tokens,
        max_reqs=worker.scheduler_config.max_num_seqs,
    )
    if mixed_tokens <= 0 and prefill_tokens <= 0 and not uniform_decode_reqs:
        return

    logger.info(
        "Warming up DeepSeek V4 sparse MLA attention "
        "for mixed tokens=%s, prefill tokens=%s, and MTP uniform decode "
        "requests=%s.",
        mixed_tokens,
        prefill_tokens,
        list(uniform_decode_reqs),
    )
    if mixed_tokens > 0:
        runner._dummy_run(
            num_tokens=mixed_tokens,
            skip_eplb=True,
            is_profile=True,
            force_attention=True,
            create_mixed_batch=True,
        )
    if prefill_tokens > 0:
        runner._dummy_run(
            num_tokens=prefill_tokens,
            skip_eplb=True,
            is_profile=True,
            force_attention=True,
            create_single_prefill=True,
        )
        # Simulate the second-and-later chunk of a chunked prefill so
        # `_build_prefill_chunk_metadata_kernel` and the alt-shape
        # `_w8a8_triton_block_scaled_mm` configs that fire when the
        # indexer sees prior context get JIT-compiled here, not on the
        # first user request that exceeds `max_num_batched_tokens`.
        runner._dummy_run(
            num_tokens=prefill_tokens,
            skip_eplb=True,
            is_profile=True,
            force_attention=True,
            create_single_prefill=True,
            profile_seq_lens=prefill_tokens * 2,
        )
        # Do not synthesize multi-request prefill here: that dummy shape
        # overflows the CUTeDSL KV-gather workspace on SM12x. Revisit only
        # with a real buffer-sizing fix for that warmup path.

    # The prefill dummies above never drive the C128A indexer, so the
    # D512-split prefill kernels stay uncompiled until the first long request
    # (PR #41834 wedge). Compile them directly with synthetic inputs.
    _deepseek_v4_indexed_d512_split_prefill_warmup(runner)

    # `uniform_decode_query_len` may be absent (default 0) or unset. A
    # non-positive query length would turn every run below into a zero-token
    # dummy batch, so skip the uniform-decode runs in that case.
    query_len = getattr(runner, "uniform_decode_query_len", 0) or 0
    if query_len > 0:
        for num_reqs in uniform_decode_reqs:
            runner._dummy_run(
                num_tokens=num_reqs * query_len,
                skip_eplb=True,
                is_profile=True,
                force_attention=True,
                uniform_decode=True,
            )

    if uniform_decode_reqs and current_platform.is_cuda_alike():
        vocab_size = runner.model_config.get_vocab_size()
        # Fall back to 16 when the cache config exposes no (or a zero) block size.
        block_size = getattr(runner.cache_config, "block_size", None) or 16
        logger.info(
            "Warming up DeepSeek V4 MTP spec-decode kernels for request "
            "counts=%s and %d draft tokens.",
            list(uniform_decode_reqs),
            runner.num_spec_tokens,
        )
        for num_reqs in uniform_decode_reqs:
            _run_deepseek_v4_mtp_spec_decode_warmup_kernels(
                device=runner.device,
                num_reqs=num_reqs,
                num_spec_tokens=runner.num_spec_tokens,
                vocab_size=vocab_size,
                block_size=block_size,
                max_model_len=runner.max_model_len,
            )
        # Wait for the warmup launches before startup continues.
        torch.accelerator.synchronize()
@triton.jit
def _dequantize_global_slots_k_kernel(
    out_ptr,
    out_stride_token,
    out_stride_slot,
    k_cache_ptr,
    slot_ids_ptr,
    slot_ids_stride_token,
    slot_ids_stride_slot,
    cache_block_size: tl.constexpr,
    token_data_size: tl.constexpr,
    block_stride: tl.constexpr,
    fp8_dim: tl.constexpr,
    bf16_dim: tl.constexpr,
    scale_dim: tl.constexpr,
    quant_block: tl.constexpr,
    output_dim: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # One program per (token, top-k slot) pair. It dequantizes the cache row
    # named by slot_ids[token, slot] into out[token, slot, :].
    #
    # Row layout as read below: `fp8_dim` FP8 bytes followed by `bf16_dim`
    # bf16 values. The per-token scale bytes (`scale_dim` per token) live after
    # all `cache_block_size * token_data_size` data bytes of the cache block.
    #
    # Program ids are widened to int64 before being multiplied by strides.
    # The block offset below is already computed in int64; without the same
    # treatment here, `token_idx * out_stride_token` can wrap in 32-bit
    # arithmetic once the output holds 2**31 or more elements.
    token_idx = tl.program_id(0).to(tl.int64)
    topk_idx = tl.program_id(1).to(tl.int64)

    slot_id = tl.load(
        slot_ids_ptr
        + token_idx * slot_ids_stride_token
        + topk_idx * slot_ids_stride_slot
    )
    offsets = tl.arange(0, BLOCK_D)
    output_row = out_ptr + token_idx * out_stride_token + topk_idx * out_stride_slot

    # Negative slot ids mark "no candidate": emit an all-zero row and stop.
    if slot_id < 0:
        tl.store(
            output_row + offsets,
            tl.zeros((BLOCK_D,), dtype=tl.float32).to(tl.bfloat16),
            mask=offsets < output_dim,
        )
        return

    # Split the global slot id into (cache block, position inside the block).
    block_idx = slot_id // cache_block_size
    pos_in_block = slot_id % cache_block_size
    cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride
    token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
    token_scale_ptr = (
        cache_block_ptr + cache_block_size * token_data_size + pos_in_block * scale_dim
    )

    # FP8 section: reinterpret the raw bytes as float8e4nv and widen to fp32.
    # The range 512 is the power-of-two cover of fp8_dim; the mask trims it.
    fp8_offsets = tl.arange(0, 512)
    fp8_mask = fp8_offsets < fp8_dim
    x_uint8 = tl.load(token_data_ptr + fp8_offsets, mask=fp8_mask, other=0)
    x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
    x_float = x_fp8.to(tl.float32)

    # One scale byte per `quant_block` elements, decoded as 2 ** (byte - 127).
    # Masked lanes read 127, i.e. a scale of 1.0.
    scale_offsets = fp8_offsets // quant_block
    encoded_scale = tl.load(token_scale_ptr + scale_offsets, mask=fp8_mask, other=127)
    scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0)
    x_dequant = x_float * scale
    tl.store(output_row + fp8_offsets, x_dequant.to(tl.bfloat16), mask=fp8_mask)

    # BF16 section: copied through unchanged, placed right after the FP8 part.
    bf16_offsets = tl.arange(0, 64)
    bf16_cache_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16))
    bf16_vals = tl.load(bf16_cache_ptr + bf16_offsets, mask=bf16_offsets < bf16_dim)
    tl.store(
        output_row + fp8_dim + bf16_offsets,
        bf16_vals,
        mask=bf16_offsets < bf16_dim,
    )
def dequantize_combined_sparse_mla_decode_kv(
    combined_kv: torch.Tensor,
    compressed_k_cache: torch.Tensor,
    compressed_slot_ids: torch.Tensor,
    compressed_block_size: int,
    swa_k_cache: torch.Tensor,
    seq_lens: torch.Tensor,
    swa_lens: torch.Tensor,
    block_table: torch.Tensor,
    swa_block_size: int,
) -> None:
    """Write compressed candidates, then SWA candidates, into ``combined_kv``.

    The leading ``compressed_slot_ids.shape[-1]`` entries along dim 1 receive
    the rows dequantized from ``compressed_k_cache``. Whatever remains of dim 1
    is filled by ``dequantize_and_gather_k_cache`` from ``swa_k_cache``. The
    buffer is modified in place and nothing is returned.
    """
    assert combined_kv.dim() == 3
    num_compressed = compressed_slot_ids.shape[-1]
    assert combined_kv.shape[0] == compressed_slot_ids.shape[0]
    assert combined_kv.shape[-1] == 512
    assert combined_kv.dtype == torch.bfloat16
    assert combined_kv.shape[1] >= num_compressed

    # Compressed region: the first `num_compressed` columns.
    compressed_region = combined_kv[:, :num_compressed]
    dequantize_global_slots_k_cache(
        out=compressed_region,
        k_cache=compressed_k_cache,
        slot_ids=compressed_slot_ids,
        block_size=compressed_block_size,
    )

    # SWA region: everything after the compressed columns (possibly empty).
    swa_region = combined_kv[:, num_compressed:]
    if swa_region.shape[1] != 0:
        dequantize_and_gather_k_cache(
            swa_region,
            swa_k_cache,
            seq_lens=seq_lens,
            gather_lens=swa_lens,
            block_table=block_table,
            block_size=swa_block_size,
            offset=0,
        )
def sparse_prefill_combined_topk_size(topk: int, window_size: int) -> int:
    """Return ``topk + window_size`` rounded up to the prefill top-k alignment.

    The alignment is the module-level ``_SPARSE_PREFILL_TOPK_ALIGNMENT``.
    """
    alignment = _SPARSE_PREFILL_TOPK_ALIGNMENT
    total_candidates = topk + window_size
    # Ceiling division, then scale back up to a whole number of aligned groups.
    aligned_groups = (total_candidates + alignment - 1) // alignment
    return aligned_groups * alignment
_SPARSE_PREFILL_TOPK_ALIGNMENT - ) - combined_indices = torch.full( - (num_tokens, combined_topk), - fill_value=-1, - dtype=torch.int32, - device=topk_indices.device, - ) - combined_lens = torch.empty( - num_tokens, dtype=torch.int32, device=topk_indices.device - ) + combined_topk = sparse_prefill_combined_topk_size(topk, window_size) + if combined_indices is None: + combined_indices = torch.full( + (num_tokens, combined_topk), + fill_value=-1, + dtype=torch.int32, + device=topk_indices.device, + ) + else: + assert combined_indices.shape[0] >= num_tokens + assert combined_indices.shape[1] >= combined_topk + assert combined_indices.dtype == torch.int32 + assert combined_indices.device == topk_indices.device + combined_indices = combined_indices[:num_tokens, :combined_topk] + combined_indices.fill_(-1) + if combined_lens is None: + combined_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert combined_lens.shape[0] >= num_tokens + assert combined_lens.dtype == torch.int32 + assert combined_lens.device == topk_indices.device + combined_lens = combined_lens[:num_tokens] NUM_WORKERS = 128 _combine_topk_swa_indices_kernel[(num_reqs, NUM_WORKERS)]( diff --git a/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py new file mode 100644 index 000000000000..547c2380f6d8 --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py @@ -0,0 +1,470 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""DeepSeek V4 FlashInfer packed sparse-MLA SM120 decode (gated). + +Subclasses the FlashMLA V4 attention to reuse its packed ``fp8_ds_mla`` KV cache, +sparse-index metadata, and packed prefill, and overrides only the decode path to +use FlashInfer's official SM120 packed sparse-MLA decode kernel (FlashInfer +PR3395, merged in flashinfer >= 0.6.13). 
That kernel scales better at high +concurrency in the MTP speculative-verify (multi-query) decode shape than the +FlashMLA decode kernel, which is the root cause of the C8-C64 ctx0 decode gap. + +Implementation note: flashinfer main exposes this kernel through the +``trtllm_batch_decode_sparse_mla_dsv4`` wrapper, but that wrapper re-validates +inputs and -- critically -- carves the split-K ``mid_out``/``mid_lse`` scratch +from a fixed workspace only for ``num_tokens <= 64``, falling back to a fresh +``torch.empty`` of hundreds of MB on every decode step above that. The MTP +multi-query decode shape routinely exceeds 64 tokens (C32/C64), so that per-step +allocation dominates and makes the wrapper materially slower than the FlashMLA +path. We instead drive the same kernel through its low-level +``_SparseMLAPagedAttentionRunner``, constructed once and fed graph-stable scratch +from vLLM's workspace manager -- so the scratch is reserved during warmup and +reused, never reallocated per step. + +Gated behind ``VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE``; selected only on SM12x +when the official packed kernel is importable (see ``_select_dsv4_attn_cls``). +Default off; gate-off behavior is identical to the FlashMLA decode path. 
+""" + +from typing import TYPE_CHECKING + +import torch + +from vllm.config import get_current_vllm_config +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.models.deepseek_v4.common.ops import compute_global_topk_indices_and_lens +from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLAAttention +from vllm.v1.worker.workspace import current_workspace_manager + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.models.deepseek_v4.sparse_mla import DeepseekV4FlashMLAMetadata + from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata + +# Split-K decode scratch sizing, mirrored from the FlashInfer sparse-sm120 +# kernel (``_BI`` = 64 candidates per partition tile): one tile per SWA top-k +# plus one per compressed (extra) top-k. Cap the warmup reservation at the +# largest single-graph decode batch. +_DECODE_MAX_TOKENS = 64 +_DECODE_SPLIT_TILE = 64 +_C128A_TOPK_ALIGNMENT = 128 + + +def _cdiv(x: int, y: int) -> int: + return (int(x) + int(y) - 1) // int(y) + + +def _max_decode_workspace_tokens(max_num_batched_tokens: int) -> int: + return min(int(max_num_batched_tokens), _DECODE_MAX_TOKENS) + + +def _decode_num_splits(topk: int, extra_topk: int = 0) -> int: + return _cdiv(topk, _DECODE_SPLIT_TILE) + _cdiv(extra_topk, _DECODE_SPLIT_TILE) + + +def _c128a_max_compressed(max_model_len: int, compress_ratio: int) -> int: + return ( + _cdiv(_cdiv(max_model_len, compress_ratio), _C128A_TOPK_ALIGNMENT) + * _C128A_TOPK_ALIGNMENT + ) + + +def _get_decode_scratch( + num_tokens: int, + num_heads: int, + head_dim: int, + topk: int, + extra_topk: int = 0, +) -> tuple[torch.Tensor, torch.Tensor]: + num_splits = _decode_num_splits(topk, extra_topk) + mid_out, mid_lse = current_workspace_manager().get_simultaneous( + ((num_tokens, num_heads, num_splits, head_dim), torch.bfloat16), + ((num_tokens, num_heads, num_splits), torch.float32), + ) + return mid_out, mid_lse + + +def 
_as_sparse_sm120_cache(kv_cache: torch.Tensor) -> torch.Tensor: + if kv_cache.dtype == torch.float8_e4m3fn: + kv_cache = kv_cache.view(torch.uint8) + if kv_cache.dim() == 4: + return kv_cache + return kv_cache.unsqueeze(-2) + + +def _get_prefill_swa_scratch( + num_tokens: int, window_size: int +) -> tuple[torch.Tensor, torch.Tensor]: + # Graph-stable per-token SWA window indices + lengths for the prefill tokens. + swa_indices, swa_lens = current_workspace_manager().get_simultaneous( + ((num_tokens, 1, window_size), torch.int32), + ((num_tokens,), torch.int32), + ) + return swa_indices, swa_lens + + +class DeepseekV4FlashInferSM120DecodeAttention(DeepseekV4FlashMLAAttention): + """FlashMLA V4 attention with the official FlashInfer SM120 packed decode. + + Reuses every FlashMLA V4 behavior (packed ``fp8_ds_mla`` cache, metadata + pipeline, packed prefill); only :meth:`_forward_decode` differs. + """ + + @classmethod + def get_padded_num_q_heads(cls, num_heads: int) -> int: + if num_heads <= 16: + return 16 + if num_heads <= 32: + return 32 + if num_heads <= 64: + return 64 + if num_heads <= 128: + return 128 + raise ValueError( + f"DeepseekV4 FlashInfer sparse-sm120 decode does not support " + f"{num_heads} heads (SM120 kernel requires h_q in {{16, 32, 64, 128}})." + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + from vllm.utils.flashinfer import has_flashinfer_trtllm_sparse_mla_dsv4 + + if not has_flashinfer_trtllm_sparse_mla_dsv4(): + raise RuntimeError( + "VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE requires FlashInfer's " + "SM120 packed sparse-MLA decode kernel " + "(trtllm_batch_decode_sparse_mla_dsv4, PR3395, " + "flashinfer >= 0.6.13)." 
def _reserve_sm120_decode_workspace(self) -> None:
    """Request the split-K decode scratch once, sized for this layer's maximum.

    The compressed ("extra") top-k width depends on ``compress_ratio``: zero
    for ratios <= 1, the last dim of ``topk_indices_buffer`` for ratio 4, and
    ``_c128a_max_compressed(...)`` for ratio 128. The tensors returned by
    ``_get_decode_scratch`` are discarded; only the request itself matters.

    Raises:
        ValueError: For any other ``compress_ratio``.
    """
    ratio = self.compress_ratio
    if ratio <= 1:
        extra_topk = 0
    elif ratio == 4:
        assert self.topk_indices_buffer is not None
        extra_topk = self.topk_indices_buffer.shape[-1]
    elif ratio == 128:
        extra_topk = _c128a_max_compressed(self.max_model_len, self.compress_ratio)
    else:
        raise ValueError(
            f"Unsupported compress_ratio={self.compress_ratio}; "
            "expected 1, 4, or 128."
        )

    # Token count is capped by `_max_decode_workspace_tokens`.
    reserved_tokens = _max_decode_workspace_tokens(self.max_num_batched_tokens)
    _get_decode_scratch(
        num_tokens=reserved_tokens,
        num_heads=self.padded_heads,
        head_dim=self.head_dim,
        topk=self.window_size,
        extra_topk=extra_topk,
    )
def forward_mqa(
    self,
    q: torch.Tensor,
    kv: torch.Tensor,
    positions: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Dispatch one MQA forward, or reserve workspace when no metadata is set.

    With attention metadata present the call is forwarded unchanged to the
    parent implementation. Without it (the warmup branch) this reserves both
    the prefill workspace and the SM120 split-K decode scratch, then zeroes
    ``output`` in place.
    """
    has_metadata = get_forward_context().attn_metadata is not None
    if has_metadata:
        # Real step: the parent splits prefill/decode and calls our overrides.
        super().forward_mqa(q, kv, positions, output)
        return

    # Warmup step: only reserve buffers; no attention is computed.
    self._reserve_prefill_workspace(self)
    self._reserve_sm120_decode_workspace()
    output.zero_()
+ topk_indices = None + topk_lens = None + if not swa_only: + assert attn_metadata is not None + assert swa_metadata.is_valid_token is not None + block_size = attn_metadata.block_size // self.compress_ratio + is_valid = swa_metadata.is_valid_token[:num_decode_tokens] + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None + global_indices, topk_lens = compute_global_topk_indices_and_lens( + self.topk_indices_buffer[:num_decode_tokens], + swa_metadata.token_to_req_indices, + attn_metadata.block_table[:num_decodes], + block_size, + is_valid, + ) + topk_indices = global_indices.view(num_decode_tokens, 1, -1) + else: + topk_indices = attn_metadata.c128a_global_decode_topk_indices + topk_lens = attn_metadata.c128a_decode_topk_lens + # The sparse-sm120 kernel asserts the extra (compressed) index + # tensor is int32 and contiguous; current's metadata builder can hand + # back a non-contiguous view, so normalize before the launch. + if topk_indices is not None: + topk_indices = topk_indices.contiguous() + + swa_indices = swa_metadata.decode_swa_indices + swa_lens = swa_metadata.decode_swa_lens + assert swa_indices is not None + assert swa_lens is not None + + extra_topk = topk_indices.shape[-1] if topk_indices is not None else 0 + mid_out, mid_lse = _get_decode_scratch( + num_decode_tokens, + output.shape[1], + output.shape[-1], + swa_indices.shape[-1], + extra_topk, + ) + + # Each decode token is a one-token query: [num_decode_tokens, 1, h, d]; + # the runner squeezes the singleton s_q dim internally. 
def _forward_prefill(
    self,
    q: torch.Tensor,
    positions: torch.Tensor,
    compressed_k_cache: torch.Tensor | None,
    swa_k_cache: torch.Tensor,
    output: torch.Tensor,
    attn_metadata: "DeepseekV4FlashMLAMetadata | None",
    swa_metadata: "DeepseekSparseSWAMetadata",
) -> None:
    """Run prefill attention through the SM120 runner, or defer to the parent.

    When ``VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL`` is falsy the call is
    forwarded unchanged to the parent ``_forward_prefill``. Otherwise this:

    1. returns immediately if the batch has no prefill tokens;
    2. takes the SWA window indices/lengths from ``swa_metadata`` or, when
       they are ``None``, computes them into workspace scratch;
    3. builds the compressed ("extra") top-k indices unless
       ``attn_metadata`` is ``None`` (SWA-only);
    4. calls ``self._sm120_runner.run`` with the result written to ``output``
       (sliced to the prefill-token count when it has more rows).

    Split-K scratch (``mid_out``/``mid_lse``) is requested only when the
    prefill-token count is at most ``_DECODE_MAX_TOKENS``; otherwise both are
    passed as ``None``.
    """
    import vllm.envs as envs

    # Packed prefill is an independent opt-in on top of the decode port; when
    # off, defer to the FlashMLA indexed-D512 prefill path byte-for-byte.
    if not envs.VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL:
        super()._forward_prefill(
            q,
            positions,
            compressed_k_cache,
            swa_k_cache,
            output,
            attn_metadata,
            swa_metadata,
        )
        return

    # No attention metadata means only the SWA cache contributes candidates.
    swa_only = attn_metadata is None
    num_decodes = swa_metadata.num_decodes
    num_decode_tokens = swa_metadata.num_decode_tokens
    num_prefills = swa_metadata.num_prefills
    num_prefill_tokens = swa_metadata.num_prefill_tokens
    num_reqs = num_decodes + num_prefills
    num_tokens = num_decode_tokens + num_prefill_tokens
    if num_prefill_tokens == 0:
        return

    assert swa_metadata.is_valid_token is not None
    assert swa_metadata.query_start_loc is not None
    assert swa_metadata.seq_lens is not None
    assert swa_metadata.token_to_req_indices is not None
    assert swa_metadata.block_table is not None

    # --- Prefill SWA window indices. The metadata builder hoists this once per
    # step (DeepseekSparseSWAMetadataBuilder.build widens its decode-SWA launch
    # over the prefill tail), so steady-state just reads the precomputed views
    # and skips ~60 redundant per-layer kernel launches. The builder deliberately
    # leaves them None on the warmup/profile dummy (its all-(-1) slot_mapping
    # makes is_valid_token all-False over the prefill tail) and during CUDA-graph
    # capture; we then self-compute exactly as before, keeping those paths
    # byte-identical to the validated v1 (the prefill packed kernel is a no-op
    # over the all-invalid dummy: every token gets swa_len=0).
    swa_indices = swa_metadata.prefill_swa_indices
    swa_lens = swa_metadata.prefill_swa_lens
    if swa_indices is None or swa_lens is None:
        from vllm.v1.attention.backends.mla.sparse_swa import (
            _compute_swa_indices_and_lens_kernel,
        )

        # Scratch covers every token in the batch (decode + prefill); the
        # kernel is launched with one program per token.
        swa_idx_full, swa_len_full = _get_prefill_swa_scratch(
            num_tokens, self.window_size
        )
        _compute_swa_indices_and_lens_kernel[(num_tokens,)](
            swa_idx_full,
            swa_idx_full.stride(0),
            swa_len_full,
            self.window_size,
            swa_metadata.query_start_loc,
            swa_metadata.seq_lens,
            swa_metadata.token_to_req_indices,
            swa_metadata.is_valid_token,
            swa_metadata.block_table,
            swa_metadata.block_table.stride(0),
            swa_metadata.block_size,
            token_offset=0,
            TRITON_BLOCK_SIZE=1024,
        )
        # Keep only the prefill tail; decode tokens occupy the leading rows.
        swa_indices = swa_idx_full[num_decode_tokens:num_tokens]
        swa_lens = swa_len_full[num_decode_tokens:num_tokens]

    # --- Compressed (extra) prefill indices, mirroring the FlashMLA prefill
    # construction but converted to global slots for the packed kernel.
    topk_indices = None
    topk_lens = None
    if not swa_only:
        assert attn_metadata is not None
        block_size = attn_metadata.block_size // self.compress_ratio
        if self.compress_ratio == 4:
            assert self.topk_indices_buffer is not None
            prefill_local = self.topk_indices_buffer[num_decode_tokens:num_tokens]
            # Rebase the indexer's BATCH-GLOBAL compressed top-k positions
            # (cu_seqlen_ks = exclusive cumsum of seq_len // compress_ratio; see
            # indexer.py) to per-request-local so block_table[req] maps them
            # in-range. Without this, req>0 positions overflow into the wrong
            # request's physical blocks. No-op at num_prefills==1 (cu_base[0]==0).
            comp_lens = (
                swa_metadata.seq_lens[num_decodes:num_reqs] // self.compress_ratio
            )
            # Exclusive prefix sum: inclusive cumsum minus each element.
            cu_base = (torch.cumsum(comp_lens, dim=0) - comp_lens).to(torch.int32)
            # Request index of each prefill token, relative to the first prefill.
            req_local = (
                swa_metadata.token_to_req_indices[num_decode_tokens:num_tokens]
                - num_decodes
            ).long()
            base_per_token = cu_base[req_local].unsqueeze(1)
            # Negative entries are left untouched (not rebased).
            prefill_local = torch.where(
                prefill_local >= 0, prefill_local - base_per_token, prefill_local
            )
            global_indices, topk_lens = compute_global_topk_indices_and_lens(
                prefill_local,
                swa_metadata.token_to_req_indices[num_decode_tokens:num_tokens],
                attn_metadata.block_table[:num_reqs],
                block_size,
                swa_metadata.is_valid_token[num_decode_tokens:num_tokens],
            )
            topk_indices = global_indices.view(num_prefill_tokens, 1, -1)
        else:
            # NOTE(review): this branch leaves topk_lens as None, so the runner
            # receives extra_topk_length=None -- confirm that is intended.
            assert attn_metadata.c128a_prefill_topk_indices is not None
            topk_indices = attn_metadata.c128a_prefill_topk_indices.view(
                num_prefill_tokens, 1, -1
            )
        topk_indices = topk_indices.contiguous()

    # --- Launch the packed prefill kernel via the runner. num_tokens > 64
    # auto-dispatches the prefill kernel; mid_out/mid_lse are decode-only and
    # only needed for the (rare) <=64-token prefill chunk.
    query = self._prepare_sm120_query(q, output)
    # Bug-C guard: under CUDA-graph padding or MTP-draft, q can carry more rows
    # than the real prefill-token count; the runner sizes its writes by the
    # query row count, so slice to num_prefill_tokens to match output/indices/
    # scratch (no-op in the common, unpadded case).
    if query.shape[0] > num_prefill_tokens:
        query = query[:num_prefill_tokens]
    # The packed kernel hard-asserts output.size(0) == num_tokens (derived from
    # the sliced query row count). The output buffer can still carry padded
    # prefill rows (output[num_decode_tokens:] of a CUDA-graph / MTP-draft padded
    # batch), so slice it the same way as the query/indices/scratch above. It is
    # a view into the same storage and the padded tail rows are never read or
    # written downstream, so this only narrows what the kernel writes; no-op in
    # the common unpadded case. Without it the kernel aborts (84 vs 83).
    out = (
        output[:num_prefill_tokens]
        if output.shape[0] > num_prefill_tokens
        else output
    )
    swa_cache = _as_sparse_sm120_cache(swa_k_cache)
    extra_cache = (
        _as_sparse_sm120_cache(compressed_k_cache)
        if (compressed_k_cache is not None and not swa_only)
        else None
    )
    mid_out = None
    mid_lse = None
    if num_prefill_tokens <= _DECODE_MAX_TOKENS:
        extra_topk = topk_indices.shape[-1] if topk_indices is not None else 0
        # NOTE(review): when swa_indices/swa_lens were self-computed above they
        # live in workspace-manager scratch, and this is a second, separate
        # get_simultaneous() request. Confirm the workspace manager cannot hand
        # back storage overlapping the still-live SWA scratch; if it can, both
        # requests need to be merged into a single get_simultaneous() call.
        mid_out, mid_lse = _get_decode_scratch(
            num_prefill_tokens,
            output.shape[1],
            output.shape[-1],
            swa_indices.shape[-1],
            extra_topk,
        )
    self._sm120_runner.run(
        query,
        swa_cache,
        swa_indices,
        out,
        self.scale,
        topk_length=swa_lens,
        attn_sink=self.attn_sink,
        extra_kv_cache=extra_cache,
        extra_indices=topk_indices,
        extra_topk_length=topk_lens,
        mid_out=mid_out,
        mid_lse=mid_lse,
    )
b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -5,12 +5,16 @@ import torch +import vllm.envs as envs from vllm.forward_context import get_forward_context from vllm.models.deepseek_v4.attention import DeepseekV4Attention from vllm.models.deepseek_v4.common.ops import ( combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + dequantize_global_slots_k_cache, + sparse_prefill_combined_topk_size, ) from vllm.models.deepseek_v4.nvidia.ops.o_proj import ( compute_fp8_einsum_recipe, @@ -20,6 +24,27 @@ DeepseekV4FlashMLABackend, DeepseekV4FlashMLAMetadata, ) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled, + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_matmul_decode_enabled, + triton_sparse_mla_prefill_topk_chunk_size, + triton_sparse_mla_query_chunk_size, + triton_sparse_mla_topk_chunk_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_indexed_d512_chunked_sparse_mla_attention, + accumulate_indexed_d512_split_sparse_mla_attention, + accumulate_indexed_sparse_mla_attention_chunk, + build_combined_sparse_mla_decode_valid_mask, + finish_sparse_mla_attention_with_sink, + finish_two_sparse_mla_attention_states_with_sink, + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead, + fp8ds_paged_sparse_mla_attention_with_sink_multihead, + matmul_sparse_mla_attention_with_sink, + sparse_mla_decode_head_block_size, +) from vllm.v1.attention.ops.flashmla import ( flash_mla_sparse_fwd, flash_mla_with_kvcache, @@ -30,6 +55,76 @@ from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata +_INDEXED_D512_SPLIT_PREFILL_MIN_TOPK = 256 +_INDEXED_D512_SPLIT_PREFILL_MAX_TOPK = 1152 + + +def _use_indexed_d512_split_prefill( + *, + compress_ratio: int, + head_dim: int, + num_prefills: int, + combined_topk: int, + 
max_prefill_seq_len: int, + swa_only: bool, +) -> bool: + return ( + envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + and not swa_only + and compress_ratio in (4, 128) + and head_dim == 512 + and num_prefills == 1 + and _is_indexed_d512_split_topk(combined_topk) + and max_prefill_seq_len + >= envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS + ) + + +def _is_indexed_d512_split_topk(combined_topk: int) -> bool: + return ( + _INDEXED_D512_SPLIT_PREFILL_MIN_TOPK + <= combined_topk + <= _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + ) + + +def _use_indexed_d512_chunked_prefill( + *, + compress_ratio: int, + head_dim: int, + num_prefills: int, + combined_topk: int, + max_prefill_seq_len: int, + swa_only: bool, +) -> bool: + return ( + envs.VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + and not swa_only + and compress_ratio in (4, 128) + and head_dim == 512 + and num_prefills == 1 + and combined_topk > _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + and max_prefill_seq_len + >= envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS + ) + + +def _sparse_mla_prefill_gather_len_upper_bound( + *, + max_model_len: int, + max_num_batched_tokens: int, + window_size: int, +) -> tuple[int, int]: + max_query_chunk_tokens = max(1, min(max_model_len, max_num_batched_tokens)) + max_prefix_len = max(max_model_len - max_query_chunk_tokens, 0) + max_gather_len = max_query_chunk_tokens + min( + max_prefix_len, + max(window_size - 1, 0), + ) + return max_query_chunk_tokens, max_gather_len + + class DeepseekV4FlashMLAAttention(DeepseekV4Attention): """FlashMLA sparse MLA attention layer for DeepSeek V4 (CUDA).""" @@ -65,6 +160,119 @@ def get_padded_num_q_heads(cls, num_heads: int) -> int: ) return 64 if num_heads <= 64 else 128 + @classmethod + def _prefill_workspace_topk_bound( + cls, + layer: "DeepseekV4FlashMLAAttention", + ) -> int: + if layer.compress_ratio <= 1: + return 0 + if ( + layer.topk_indices_buffer is not None + and 
layer.topk_indices_buffer.ndim > 0 + and layer.topk_indices_buffer.shape[-1] > 0 + ): + return int(layer.topk_indices_buffer.shape[-1]) + indexer_topk = getattr(layer.indexer, "topk_tokens", None) + if indexer_topk is not None: + return int(indexer_topk) + return 2048 + + @classmethod + def _prefill_workspace_reservation_specs( + cls, + layer: "DeepseekV4FlashMLAAttention", + ) -> tuple[tuple[tuple[int, ...], torch.dtype], ...]: + max_model_len = max(1, int(layer.max_model_len)) + max_num_batched_tokens = max(1, int(layer.max_num_batched_tokens)) + window_size = max(1, int(layer.window_size)) + compress_ratio = max(1, int(layer.compress_ratio)) + head_dim = int(layer.head_dim) + num_heads = int(layer.n_local_heads) + + max_query_chunk_tokens, max_gather_len = ( + _sparse_mla_prefill_gather_len_upper_bound( + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + window_size=window_size, + ) + ) + if compress_ratio <= 1: + m_bound = max_gather_len + else: + compressed_region_size = max_model_len // compress_ratio + m_bound = compressed_region_size + max_gather_len + + combined_topk = sparse_prefill_combined_topk_size( + cls._prefill_workspace_topk_bound(layer), + window_size, + ) + specs: list[tuple[tuple[int, ...], torch.dtype]] = [ + ((layer.PREFILL_CHUNK_SIZE, m_bound, head_dim), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ] + if is_triton_sparse_mla_enabled_for_platform(): + query_chunk_size = min( + max_query_chunk_tokens, + triton_sparse_mla_query_chunk_size(), + ) + specs.extend( + [ + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads, head_dim), torch.float32), + ] + ) + if _use_indexed_d512_split_prefill( + compress_ratio=compress_ratio, + head_dim=head_dim, + num_prefills=1, + combined_topk=combined_topk, + max_prefill_seq_len=max_model_len, + swa_only=False, + ): + 
specs.append( + ((query_chunk_size, num_heads, combined_topk), torch.float32) + ) + elif _use_indexed_d512_chunked_prefill( + compress_ratio=compress_ratio, + head_dim=head_dim, + num_prefills=1, + combined_topk=combined_topk, + max_prefill_seq_len=max_model_len, + swa_only=False, + ): + chunked_score_width = min( + combined_topk, + _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, + ) + specs.extend( + ( + ( + (query_chunk_size, num_heads, chunked_score_width), + torch.float32, + ), + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads, head_dim), torch.float32), + ) + ) + return tuple(specs) + + @classmethod + def _reserve_prefill_workspace( + cls, + layer: "DeepseekV4FlashMLAAttention", + ) -> None: + try: + workspace_manager = current_workspace_manager() + except AssertionError: + return + workspace_manager.get_simultaneous( + *cls._prefill_workspace_reservation_specs(layer) + ) + def forward_mqa( self, q: torch.Tensor, @@ -84,20 +292,9 @@ def forward_mqa( attn_metadata = forward_context.attn_metadata if attn_metadata is None: - # Warmup dummy run: no real metadata. Reserve the same bf16 - # gather workspace _forward_prefill would; the dequantize / topk - # / sparse_fwd kernels are skipped this step. - swa_only = self.compress_ratio <= 1 - N = ( - 0 - if swa_only - else (self.max_model_len + self.compress_ratio - 1) - // self.compress_ratio - ) - M = N + self.window_size + self.max_num_batched_tokens - current_workspace_manager().get_simultaneous( - ((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - ) + # Warmup dummy run: no real metadata. Reserve the same graph-stable + # workspace shapes _forward_prefill can use, but skip real kernels. 
+ self._reserve_prefill_workspace(self) output.zero_() return @@ -142,6 +339,417 @@ def forward_mqa( output=output[:num_decode_tokens], ) + @classmethod + def _forward_sparse_mla_swa_decode_triton( + cls, + layer: "DeepseekV4FlashMLAAttention", + q: torch.Tensor, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + output: torch.Tensor, + ) -> None: + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + # Decode metadata is unconditionally populated when num_decode_tokens > 0, + # which is the only path that reaches the decode kernels. + assert swa_metadata.decode_swa_lens is not None + assert swa_metadata.decode_swa_indices is not None + assert swa_metadata.seq_lens is not None + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if not mtp_decode: + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=max_swa_len, + scale=layer.scale, + attn_sink=layer.attn_sink, + output=output, + head_block_size=head_block_size, + num_heads=layer.n_local_heads, + ) + if output.shape[1] > layer.n_local_heads: + output[:, layer.n_local_heads :].zero_() + return + + ( + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads, q.shape[-1]), torch.float32), + ) + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + 
accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=layer.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_sparse_mla_attention_with_sink( + swa_max_score, + swa_denom, + swa_acc, + layer.attn_sink, + output=output, + ) + if output.shape[1] > layer.n_local_heads: + output[:, layer.n_local_heads :].zero_() + + @classmethod + def _forward_sparse_mla_compressed_decode_triton( + cls, + layer: "DeepseekV4FlashMLAAttention", + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + swa_k_cache: torch.Tensor, + topk_indices: torch.Tensor, + topk_lens: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: DeepseekV4FlashMLAMetadata, + output: torch.Tensor, + ) -> None: + if layer.compress_ratio not in (4, 128): + raise NotImplementedError( + "Triton sparse MLA compressed decode currently supports " + f"compress_ratio=4 or 128, got {layer.compress_ratio}" + ) + + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + # Decode metadata is unconditionally populated when num_decode_tokens > 0, + # which is the only path that reaches the decode kernels. 
+ assert swa_metadata.decode_swa_lens is not None + assert swa_metadata.decode_swa_indices is not None + assert swa_metadata.seq_lens is not None + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + compressed_block_size = attn_metadata.block_size // layer.compress_ratio + compressed_topk = topk_indices.shape[-1] + topk_chunk_size = min( + compressed_topk, + triton_sparse_mla_topk_chunk_size(), + ) + compressed_slot_ids = topk_indices[:, 0, :] + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if ( + compressed_topk <= topk_chunk_size + and triton_sparse_mla_matmul_decode_enabled() + ): + total_candidates = compressed_topk + max_swa_len + ( + combined_kv, + valid_tokens, + score_buffer, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, total_candidates, q.shape[-1]), torch.bfloat16), + ((num_decode_tokens, total_candidates), torch.bool), + ( + (num_decode_tokens, layer.n_local_heads, total_candidates), + torch.bfloat16, + ), + ) + if mtp_decode: + dequantize_global_slots_k_cache( + combined_kv[:, :compressed_topk], + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + ) + dequantize_global_slots_k_cache( + combined_kv[:, compressed_topk:], + swa_k_cache, + swa_indices, + swa_metadata.block_size, + ) + else: + dequantize_combined_sparse_mla_decode_kv( + combined_kv, + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + swa_k_cache, + swa_metadata.seq_lens[:num_decodes], + swa_lens, + swa_metadata.block_table[:num_decodes], + swa_metadata.block_size, + ) + + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + use_dot_finish = num_decode_tokens <= 16 + matmul_sparse_mla_attention_with_sink( + q=q, + kv=combined_kv, + valid_tokens=valid_tokens, + scale=layer.scale, + 
attn_sink=layer.attn_sink, + output=output, + num_heads=layer.n_local_heads, + score_buffer=score_buffer, + value_block_size=512 if use_dot_finish else 256, + candidate_block_size=128 if use_dot_finish else None, + ) + return + + if not mtp_decode and compressed_topk <= topk_chunk_size: + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + swa_block_size=swa_metadata.block_size, + num_compressed_candidates=compressed_topk, + num_swa_candidates=max_swa_len, + scale=layer.scale, + attn_sink=layer.attn_sink, + output=output, + head_block_size=head_block_size, + num_heads=layer.n_local_heads, + ) + if output.shape[1] > layer.n_local_heads: + output[:, layer.n_local_heads :].zero_() + return + + ( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads, q.shape[-1]), torch.float32), + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads, q.shape[-1]), torch.float32), + ) + comp_max_score.fill_(float("-inf")) + comp_denom.zero_() + comp_acc.zero_() + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + + for chunk_start in range(0, compressed_topk, topk_chunk_size): + chunk_end = min(chunk_start + topk_chunk_size, compressed_topk) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids[:, chunk_start:chunk_end], 
+ lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=chunk_start, + scale=layer.scale, + max_score=comp_max_score, + denom=comp_denom, + acc=comp_acc, + head_block_size=head_block_size, + ) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=layer.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_two_sparse_mla_attention_states_with_sink( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + layer.attn_sink, + output=output, + ) + if output.shape[1] > layer.n_local_heads: + output[:, layer.n_local_heads :].zero_() + + @classmethod + def _forward_sparse_mla_prefill_triton( + cls, + layer: "DeepseekV4FlashMLAAttention", + q: torch.Tensor, + kv: torch.Tensor, + combined_indices: torch.Tensor, + combined_lens: torch.Tensor, + output: torch.Tensor, + state_buffers: tuple[torch.Tensor, ...] 
| None = None, + ) -> None: + kv_flat = kv.reshape(-1, q.shape[-1]) + topk_chunk_size = triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=combined_indices.shape[-1], + compress_ratio=int(layer.compress_ratio), + request_count=kv.shape[0], + ) + query_chunk_size = min( + q.shape[0], + triton_sparse_mla_query_chunk_size(), + ) + if state_buffers is None: + ( + max_score_buffer, + denom_buffer, + output_buffer, + ) = current_workspace_manager().get_simultaneous( + ((query_chunk_size, layer.n_local_heads), torch.float32), + ((query_chunk_size, layer.n_local_heads), torch.float32), + ((query_chunk_size, layer.n_local_heads, q.shape[-1]), torch.float32), + ) + else: + max_score_buffer, denom_buffer, output_buffer = state_buffers[:3] + indexed_d512_scores = None + indexed_d512_chunked_buffers = None + if ( + state_buffers is not None + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + and layer.compress_ratio in (4, 128) + and q.shape[-1] == 512 + and kv.shape[0] == 1 + and _is_indexed_d512_split_topk(combined_indices.shape[-1]) + and len(state_buffers) == 4 + ): + indexed_d512_scores = state_buffers[3] + elif ( + state_buffers is not None + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + and layer.compress_ratio in (4, 128) + and q.shape[-1] == 512 + and kv.shape[0] == 1 + and combined_indices.shape[-1] > _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + and len(state_buffers) == 7 + ): + indexed_d512_chunked_buffers = state_buffers[3:7] + + for token_start in range(0, q.shape[0], query_chunk_size): + token_end = min(token_start + query_chunk_size, q.shape[0]) + q_chunk = q[token_start:token_end] + indices_chunk_full = combined_indices[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + num_tokens = token_end - token_start + max_score = max_score_buffer[:num_tokens] + denom = denom_buffer[:num_tokens] + subset_acc = output_buffer[:num_tokens] + can_use_indexed_d512_scores = ( 
+ indexed_d512_scores is not None + and indexed_d512_scores.shape[0] >= num_tokens + and indexed_d512_scores.shape[2] >= combined_indices.shape[-1] + ) + can_use_indexed_d512_chunked = ( + indexed_d512_chunked_buffers is not None + and indexed_d512_chunked_buffers[0].shape[0] >= num_tokens + ) + if can_use_indexed_d512_scores: + assert indexed_d512_scores is not None + accumulate_indexed_d512_split_sparse_mla_attention( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full, + lens=lens_chunk, + scale=layer.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + scores=indexed_d512_scores[ + :num_tokens, :, : combined_indices.shape[-1] + ], + ) + elif can_use_indexed_d512_chunked: + assert indexed_d512_chunked_buffers is not None + ( + indexed_d512_scores, + chunk_max_score, + chunk_denom, + chunk_acc, + ) = indexed_d512_chunked_buffers + accumulate_indexed_d512_chunked_sparse_mla_attention( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full, + lens=lens_chunk, + scale=layer.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + scores=indexed_d512_scores[:num_tokens], + chunk_max_score=chunk_max_score[:num_tokens], + chunk_denom=chunk_denom[:num_tokens], + chunk_acc=chunk_acc[:num_tokens], + ) + else: + max_score.fill_(float("-inf")) + denom.zero_() + subset_acc.zero_() + + for index_start in range( + 0, combined_indices.shape[-1], topk_chunk_size + ): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=layer.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + ) + + finish_sparse_mla_attention_with_sink( + max_score, + denom, + subset_acc, + layer.attn_sink, + output=output[token_start:token_end], + ) + if output.shape[1] > layer.n_local_heads: + output[token_start:token_end, 
layer.n_local_heads :].zero_() + def _forward_decode( self, q: torch.Tensor, @@ -189,9 +797,38 @@ def _forward_decode( # Use unsqueeze to preserve strides (handles padded blocks correctly) swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2) # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) + compressed_k_cache = kv_cache if kv_cache is not None: kv_cache = kv_cache.unsqueeze(-2) + if is_triton_sparse_mla_enabled(q.device): + if swa_only: + self._forward_sparse_mla_swa_decode_triton( + layer=self, + q=q, + swa_k_cache=self.swa_cache_layer.kv_cache, + swa_metadata=swa_metadata, + output=output, + ) + return + if self.compress_ratio in (4, 128): + assert compressed_k_cache is not None + assert attn_metadata is not None + assert topk_indices is not None + assert topk_lens is not None + self._forward_sparse_mla_compressed_decode_triton( + layer=self, + q=q, + compressed_k_cache=compressed_k_cache, + swa_k_cache=self.swa_cache_layer.kv_cache, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + output=output, + ) + return + # One FlashMLASchedMeta per layer type, shared across all same-type # layers within this decode step. The first forward call per type # triggers the in-kernel planner (allocating tile_scheduler_metadata @@ -246,6 +883,7 @@ def _forward_prefill( ) -> None: swa_only = attn_metadata is None + num_prefills = swa_metadata.num_prefills num_prefill_tokens = swa_metadata.num_prefill_tokens num_decodes = swa_metadata.num_decodes num_decode_tokens = swa_metadata.num_decode_tokens @@ -253,8 +891,12 @@ def _forward_prefill( # Use pre-computed prefill metadata. 
seq_lens = swa_metadata.prefill_seq_lens gather_lens = swa_metadata.prefill_gather_lens + seq_lens_cpu = swa_metadata.prefill_seq_lens_cpu + gather_lens_cpu = swa_metadata.prefill_gather_lens_cpu assert seq_lens is not None assert gather_lens is not None + assert seq_lens_cpu is not None + assert gather_lens_cpu is not None # Derive prefill-local token offsets from the full query_start_loc_cpu. query_start_loc_cpu = swa_metadata.query_start_loc_cpu @@ -268,6 +910,29 @@ def _forward_prefill( assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[num_decode_tokens:] topk_indices = topk_indices[:num_prefill_tokens] + # Rebase the indexer's BATCH-GLOBAL compressed top-k positions to + # per-request-local. combine_topk_swa_indices maps a position p of + # the k-th in-chunk request to gathered slot p + M*k, which is only + # correct when p is request-local; the indexer writes cu_seqlen_ks- + # cumulative (batch-global) positions, so without this rebase non- + # first prefill requests index past their gathered slot and read + # stale workspace (latent C4A multi-request prefill bug). torch.where + # preserves -1 sentinels; no-op at num_prefills==1 (cu_base[0]==0). + assert swa_metadata.token_to_req_indices is not None + _comp_lens = seq_lens // self.compress_ratio + _cu_base = (torch.cumsum(_comp_lens, dim=0) - _comp_lens).to( + torch.int32 + ) + _req_local = ( + swa_metadata.token_to_req_indices[ + num_decode_tokens : num_decode_tokens + num_prefill_tokens + ] + - num_decodes + ).long() + _base = _cu_base[_req_local].unsqueeze(1) + topk_indices = torch.where( + topk_indices >= 0, topk_indices - _base, topk_indices + ) else: # C128A: pre-computed during metadata build. 
assert attn_metadata is not None @@ -278,17 +943,128 @@ def _forward_prefill( assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[num_decode_tokens:] top_k = 0 + + # Adaptive prefill chunk plan (#45061): pack as many requests as fit the + # workspace-area bound into each chunk, with per-chunk compressed (chunk_N) + # and total (chunk_M) widths. Replaces the fixed PREFILL_CHUNK_SIZE + # chunking with batch-wide M/N. chunk_plan = swa_metadata.get_prefill_chunk_plan( - compress_ratio=self.compress_ratio, + compress_ratio=int(self.compress_ratio), prefill_chunk_size=self.PREFILL_CHUNK_SIZE, ) assert chunk_plan, "prefill chunk plan must be non-empty when num_prefills > 0" + + max_query_chunk_tokens = 0 + for chunk_start, chunk_end, _chunk_n, _chunk_m in chunk_plan: + query_start = ( + query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base + ) + query_end = ( + query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base + ) + max_query_chunk_tokens = max( + max_query_chunk_tokens, int(query_end - query_start) + ) + combined_topk = sparse_prefill_combined_topk_size(top_k, self.window_size) + workspace_manager = current_workspace_manager() - for chunk_start, chunk_end, chunk_N, chunk_M in chunk_plan: + triton_sparse_mla_enabled = is_triton_sparse_mla_enabled(q.device) + indexed_d512_split_prefill = False + indexed_d512_chunked_prefill = False + extra_specs: list[tuple[tuple[int, ...], torch.dtype]] = [] + if triton_sparse_mla_enabled: + query_chunk_size = min( + max_query_chunk_tokens, + triton_sparse_mla_query_chunk_size(), + ) + indexed_d512_split_prefill = _use_indexed_d512_split_prefill( + compress_ratio=int(self.compress_ratio), + head_dim=int(self.head_dim), + num_prefills=int(num_prefills), + combined_topk=int(combined_topk), + max_prefill_seq_len=int(seq_lens_cpu.max().item()), + swa_only=swa_only, + ) + if not indexed_d512_split_prefill: + indexed_d512_chunked_prefill = _use_indexed_d512_chunked_prefill( + 
compress_ratio=int(self.compress_ratio), + head_dim=int(self.head_dim), + num_prefills=int(num_prefills), + combined_topk=int(combined_topk), + max_prefill_seq_len=int(seq_lens_cpu.max().item()), + swa_only=swa_only, + ) + if indexed_d512_split_prefill: + extra_specs.append( + ( + (query_chunk_size, self.n_local_heads, combined_topk), + torch.float32, + ) + ) + elif indexed_d512_chunked_prefill: + chunked_score_width = min( + combined_topk, + _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, + ) + extra_specs.extend( + ( + ( + (query_chunk_size, self.n_local_heads, chunked_score_width), + torch.float32, + ), + ((query_chunk_size, self.n_local_heads), torch.float32), + ((query_chunk_size, self.n_local_heads), torch.float32), + ( + (query_chunk_size, self.n_local_heads, q.shape[-1]), + torch.float32, + ), + ) + ) + + # Per-chunk workspace allocation (#45061): the kv buffer width is this + # chunk's compressed+gather width (chunk_m), keeping the area bounded by + # the planner's max_workspace_area instead of the batch-wide worst case. 
+ for chunk_start, chunk_end, chunk_n, chunk_m in chunk_plan: chunk_size = chunk_end - chunk_start - kv = workspace_manager.get_simultaneous( - ((chunk_size, chunk_M, q.shape[-1]), torch.bfloat16), - )[0] + if triton_sparse_mla_enabled: + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + max_score_buffer, + denom_buffer, + output_buffer, + *extra_state_buffers, + ) = workspace_manager.get_simultaneous( + ((chunk_size, chunk_m, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ((query_chunk_size, self.n_local_heads), torch.float32), + ((query_chunk_size, self.n_local_heads), torch.float32), + ( + (query_chunk_size, self.n_local_heads, q.shape[-1]), + torch.float32, + ), + *extra_specs, + ) + prefill_state_buffers = ( + max_score_buffer, + denom_buffer, + output_buffer, + *extra_state_buffers, + ) + else: + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + ) = workspace_manager.get_simultaneous( + ((chunk_size, chunk_m, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ) + prefill_state_buffers = None + if not swa_only: # Gather compressed KV assert attn_metadata is not None @@ -312,7 +1088,7 @@ def _forward_prefill( gather_lens=gather_lens[chunk_start:chunk_end], block_table=swa_block_table[chunk_start:chunk_end], block_size=swa_metadata.block_size, - offset=chunk_N, + offset=chunk_n, ) # Combine the topk indices and SWA indices for gathered KV cache @@ -333,15 +1109,28 @@ def _forward_prefill( self.window_size, self.compress_ratio, top_k, - chunk_M, - chunk_N, - ) - flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], + chunk_m, + chunk_n, + combined_indices=combined_indices_buffer, + 
combined_lens=combined_lens_buffer, ) + if triton_sparse_mla_enabled: + self._forward_sparse_mla_prefill_triton( + self, + q=q[query_start:query_end], + kv=kv[:chunk_size], + combined_indices=combined_indices, + combined_lens=combined_lens, + output=output[query_start:query_end], + state_buffers=prefill_state_buffers, + ) + else: + flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index aa60ad34ce32..a3227a5d8bb1 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -66,6 +66,7 @@ from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import deep_gemm from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -310,10 +311,6 @@ def finalize_weights(self) -> None: return self._check_runtime_supported() - from vllm.utils.deep_gemm import _import_deep_gemm - - deep_gemm = _import_deep_gemm() - w13_scale = deep_gemm.transform_sf_into_required_layout( self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(), 2 * self.intermediate_size, @@ -346,10 +343,6 @@ def finalize_weights(self) -> None: self.w2_weight_scale = None def get_symm_buffer(self): - from vllm.utils.deep_gemm import _import_deep_gemm - - deep_gemm = _import_deep_gemm() - group = get_ep_group().device_group device = torch.accelerator.current_device_index() key = ( @@ -436,10 +429,6 @@ def forward( ) y = torch.empty_like(hidden_states, dtype=torch.bfloat16) - from vllm.utils.deep_gemm import _import_deep_gemm - - deep_gemm = _import_deep_gemm() - symm_buffer = 
self.get_symm_buffer() num_tokens = hidden_states.shape[0] @@ -575,6 +564,7 @@ def __init__( prefix=f"{prefix}.shared_experts", ) + self.n_shared_experts = config.n_shared_experts or 0 if self.use_mega_moe: self._init_mega_moe_experts(vllm_config, config, prefix) else: @@ -593,7 +583,6 @@ def _init_mega_moe_experts( eplb_config = vllm_config.parallel_config.eplb_config self.n_redundant_experts = eplb_config.num_redundant_experts self.n_routed_experts = config.n_routed_experts - self.n_shared_experts = config.n_shared_experts or 0 self.n_logical_experts = self.n_routed_experts self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts assert self.n_physical_experts % self.ep_size == 0, ( @@ -667,6 +656,43 @@ def _init_fused_moe_experts( enable_eplb=parallel_config.enable_eplb, num_redundant_experts=eplb_config.num_redundant_experts, ) + self._sync_fused_moe_metadata() + + def _sync_fused_moe_metadata(self) -> None: + experts = self.experts + moe_config = getattr(experts, "moe_config", None) + routed_experts = getattr(experts, "routed_experts", experts) + + def get_optional_attr(obj, name: str): + return None if obj is None else getattr(obj, name, None) + + def first_defined(*values): + return next((value for value in values if value is not None), None) + + self.n_logical_experts = first_defined( + get_optional_attr(experts, "logical_num_experts"), + get_optional_attr(moe_config, "num_logical_experts"), + ) + self.n_physical_experts = first_defined( + get_optional_attr(routed_experts, "global_num_experts"), + get_optional_attr(experts, "global_num_experts"), + get_optional_attr(moe_config, "num_experts"), + ) + self.n_local_physical_experts = first_defined( + get_optional_attr(routed_experts, "local_num_experts"), + get_optional_attr(experts, "local_num_experts"), + get_optional_attr(moe_config, "num_local_experts"), + ) + if ( + self.n_logical_experts is None + or self.n_physical_experts is None + or self.n_local_physical_experts is None + ): + 
raise AttributeError( + "DeepseekV4MoE FusedMoE metadata is incomplete after construction." + ) + self.n_local_experts = self.n_local_physical_experts + self.n_redundant_experts = self.n_physical_experts - self.n_logical_experts def forward( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None @@ -738,10 +764,20 @@ def finalize_mega_moe_weights(self) -> None: def _select_dsv4_attn_cls(vllm_config: VllmConfig) -> type[DeepseekV4Attention]: """Pick the CUDA sparse-MLA attention class for the configured backend. - The generic CUDA backend selector does not instantiate DSv4 layers directly, - so map generic sparse-MLA choices to the DSv4-specialized attention class. - Without an explicit backend, SM12 defaults to FlashInfer while the other - CUDA arches keep the FlashMLA path. + An explicit ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` selects the + FlashInfer TRTLLM-gen path; the generic ``FLASHMLA_SPARSE*`` choices map to + the DSv4-specialized FlashMLA class. Without an explicit backend the FlashMLA + class is used; on SM12x it runs through the stock Triton sparse-MLA route, so + it serves on released deps without FlashInfer's unmerged SM120 sparse-MLA + fork. + + When ``VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE`` is set and the runtime is + SM12x with FlashInfer's packed sparse-MLA decode kernel available, decode is + routed through the official ``trtllm_batch_decode_sparse_mla_dsv4`` SM120 + kernel (FlashInfer PR3395, released in flashinfer >= 0.6.13) instead of the + FlashMLA decode kernel; everything else (packed ``fp8_ds_mla`` cache, + metadata, prefill) is unchanged. Availability-gated (silent FlashMLA fallback + when the kernel is absent). Default off. 
""" backend = vllm_config.attention_config.backend device_capability = current_platform.get_device_capability() @@ -764,8 +800,29 @@ def _select_dsv4_attn_cls(vllm_config: VllmConfig) -> type[DeepseekV4Attention]: ): return DeepseekV4FlashMLAAttention - if device_capability is not None and device_capability.major == 12: - return DeepseekV4FlashInferSM120Attention + # Opt-in: route SM12x decode through FlashInfer's official packed sparse-MLA + # decode kernel (PR3395, released in flashinfer >= 0.6.13) when present. + # Availability-gated, so stock installs without that kernel fall through to + # the FlashMLA/Triton-sparse default below instead of raising. We deliberately + # do NOT hard-default SM12 to the FlashInfer sparse-MLA class -- that path + # needs the unmerged FlashInfer SM120 sparse-MLA fork and raises on released + # deps (see #43477). + import vllm.envs as envs + + if envs.VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE: + from vllm.utils.flashinfer import has_flashinfer_trtllm_sparse_mla_dsv4 + + if ( + device_capability is not None + and device_capability.major == 12 + and has_flashinfer_trtllm_sparse_mla_dsv4() + ): + from vllm.models.deepseek_v4.nvidia.flashinfer_sm120_decode import ( + DeepseekV4FlashInferSM120DecodeAttention, + ) + + return DeepseekV4FlashInferSM120DecodeAttention + return DeepseekV4FlashMLAAttention @@ -1396,6 +1453,10 @@ def get_mtp_target_hidden_states(self) -> torch.Tensor | None: forward(); valid after each target step.""" return getattr(self.model, "_mtp_hidden_buffer", None) + def skip_weight_name_before_load(self, name: str) -> bool: + mapped = self.hf_to_vllm_mapper._map_name(name) + return mapped is None or "mtp." 
in mapped + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_substrs=["mtp."]) loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/models/deepseek_v4/nvidia/mtp.py b/vllm/models/deepseek_v4/nvidia/mtp.py index 64715deae99b..3f9253fe826d 100644 --- a/vllm/models/deepseek_v4/nvidia/mtp.py +++ b/vllm/models/deepseek_v4/nvidia/mtp.py @@ -158,7 +158,12 @@ def forward( inputs_embeds ).unsqueeze(-2) hidden_states, residual, post_mix, res_mix = self.mtp_block( - positions=positions, x=hidden_states, input_ids=None + x=hidden_states, + positions=positions, + input_ids=input_ids, + post_mix=None, + res_mix=None, + residual=None, ) hidden_states = mhc_post_tilelang(hidden_states, residual, post_mix, res_mix) # Return the flat pre-hc_head residual so it can be re-fed as the diff --git a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py new file mode 100644 index 000000000000..09c0f4aa3f5e --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SM12x Triton FP8 einsum kernels for DeepSeek V4.""" + +import torch + +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, +) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import fp8_einsum + + +@triton.jit +def _deepseek_v4_sm12x_fp8_einsum_kernel( + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_groups: tl.constexpr, + out_rank: tl.constexpr, + hidden_size: tl.constexpr, + a_stride_token: tl.constexpr, + a_stride_group: tl.constexpr, + a_stride_hidden: tl.constexpr, + a_scale_stride_token: tl.constexpr, + 
a_scale_stride_group: tl.constexpr, + a_scale_stride_hidden: tl.constexpr, + b_stride_group: tl.constexpr, + b_stride_out: tl.constexpr, + b_stride_hidden: tl.constexpr, + b_scale_stride_group: tl.constexpr, + b_scale_stride_out: tl.constexpr, + b_scale_stride_hidden: tl.constexpr, + out_stride_token: tl.constexpr, + out_stride_group: tl.constexpr, + out_stride_rank: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, + BLOCK_OUT: tl.constexpr, + BLOCK_HIDDEN: tl.constexpr, +) -> None: + token_block = tl.program_id(0) + out_block = tl.program_id(1) + group = tl.program_id(2) + + token_offsets = token_block * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) + out_offsets = out_block * BLOCK_OUT + tl.arange(0, BLOCK_OUT) + hidden_offsets = tl.arange(0, BLOCK_HIDDEN) + accum = tl.zeros((BLOCK_TOKENS, BLOCK_OUT), dtype=tl.float32) + + for hidden_start in range(0, hidden_size, BLOCK_HIDDEN): + hidden = hidden_start + hidden_offsets + a = tl.load( + a_ptr + + token_offsets[:, None] * a_stride_token + + group * a_stride_group + + hidden[None, :] * a_stride_hidden, + mask=(token_offsets[:, None] < num_tokens) + & (hidden[None, :] < hidden_size), + other=0.0, + ) + b = tl.load( + b_ptr + + group * b_stride_group + + out_offsets[None, :] * b_stride_out + + hidden[:, None] * b_stride_hidden, + mask=(out_offsets[None, :] < out_rank) & (hidden[:, None] < hidden_size), + other=0.0, + ) + raw = tl.dot(a, b, out_dtype=tl.float32) + hidden_scale_block = hidden_start // BLOCK_HIDDEN + a_scale = tl.load( + a_scale_ptr + + token_offsets * a_scale_stride_token + + group * a_scale_stride_group + + hidden_scale_block * a_scale_stride_hidden, + mask=token_offsets < num_tokens, + other=0.0, + ) + b_scale = tl.load( + b_scale_ptr + + group * b_scale_stride_group + + (out_offsets // 128) * b_scale_stride_out + + hidden_scale_block * b_scale_stride_hidden, + mask=out_offsets < out_rank, + other=0.0, + ) + accum += raw * a_scale[:, None] * b_scale[None, :] + + tl.store( + out_ptr + + token_offsets[:, None] * 
out_stride_token + + group * out_stride_group + + out_offsets[None, :] * out_stride_rank, + accum, + mask=(token_offsets[:, None] < num_tokens) & (out_offsets[None, :] < out_rank), + ) + + +def deepseek_v4_sm12x_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + """Compute ``bhr,hdr->bhd`` with FP32 block scales on SM12x. + + ``a`` is the transposed output of ``fused_inv_rope_fp8_quant`` with shape + ``[tokens, groups, hidden]``. ``b`` is ``wo_a`` reshaped to + ``[groups, out_rank, hidden]``. + """ + num_tokens, num_groups, hidden_size = a.shape + b_groups, out_rank, b_hidden_size = b.shape + assert b_groups == num_groups + assert b_hidden_size == hidden_size + assert out.shape == (num_tokens, num_groups, out_rank) + assert hidden_size % 128 == 0 + assert out_rank % 128 == 0 + assert a.dtype == torch.float8_e4m3fn + assert b.dtype == torch.float8_e4m3fn + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if a_scale.dtype == e8m0_dtype: + a_scale = _upcast_e8m0_to_fp32(a_scale) + if b_scale.dtype == e8m0_dtype: + b_scale = _upcast_e8m0_to_fp32(b_scale) + assert a_scale.dtype == torch.float32 + assert b_scale.dtype == torch.float32 + + if num_tokens == 0: + return + + block_tokens = 16 + block_out = 128 + block_hidden = 128 + grid = ( + triton.cdiv(num_tokens, block_tokens), + triton.cdiv(out_rank, block_out), + num_groups, + ) + _deepseek_v4_sm12x_fp8_einsum_kernel[grid]( + a, + a_scale, + b, + b_scale, + out, + num_tokens, + num_groups, + out_rank, + hidden_size, + a.stride(0), + a.stride(1), + a.stride(2), + a_scale.stride(0), + a_scale.stride(1), + a_scale.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + b_scale.stride(0), + b_scale.stride(1), + b_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_TOKENS=block_tokens, + BLOCK_OUT=block_out, + BLOCK_HIDDEN=block_hidden, + num_warps=4, + num_stages=3, + ) + + +def deepseek_v4_fp8_einsum_config( 
+ capability_major: int, +) -> tuple[tuple[int, int, int], bool]: + if capability_major == 10: + return (1, 1, 128), True + return (1, 128, 128), False + + +def _use_deepseek_v4_sm12x_triton_fp8_einsum( + equation: str, + recipe: list[int], + b_scale: torch.Tensor, +) -> bool: + capability = current_platform.get_device_capability() + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + return ( + capability is not None + and capability.major == 12 + and equation == "bhr,hdr->bhd" + and tuple(recipe) == (1, 128, 128) + and b_scale.dtype in (torch.float32, e8m0_dtype) + ) + + +def deepseek_v4_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: list[int], +) -> None: + if equation == "bhr,hdr->bhd" and b.dim() == 2: + num_groups = out.shape[1] + out_rank = out.shape[2] + hidden_size = a.shape[2] + if b.shape[0] % out_rank != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight rows must be divisible by " + f"out_rank={out_rank}, got {b.shape[0]}" + ) + b_groups = b.shape[0] // out_rank + group_start = 0 + if b_groups != num_groups: + if b_groups % num_groups != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight groups must match the " + "TP-local output groups or be an integer multiple of " + f"them, got weight_groups={b_groups}, " + f"output_groups={num_groups}" + ) + group_partitions = b_groups // num_groups + group_start = ( + get_tensor_model_parallel_rank() % group_partitions + ) * num_groups + b = b.view(b_groups, out_rank, hidden_size) + if group_start != 0 or b_groups != num_groups: + b = b.narrow(0, group_start, num_groups) + + if b_scale.dim() == 2: + scale_mn = recipe[1] + scale_k_pack = 4 if b_scale.dtype == torch.int32 else 1 + scale_k = recipe[2] * scale_k_pack + scale_out_blocks = (out_rank + scale_mn - 1) // scale_mn + scale_hidden_blocks = (hidden_size + scale_k - 1) // scale_k + if b_scale.shape[0] % scale_out_blocks != 0: + raise 
RuntimeError( + "DeepSeek V4 fp8 einsum scale rows must be divisible by " + f"scale_out_blocks={scale_out_blocks}, " + f"got {b_scale.shape[0]}" + ) + scale_groups = b_scale.shape[0] // scale_out_blocks + if scale_groups not in (num_groups, b_groups): + raise RuntimeError( + "DeepSeek V4 fp8 einsum scale groups must match the " + "TP-local output groups or weight groups, got " + f"scale_groups={scale_groups}, output_groups={num_groups}, " + f"weight_groups={b_groups}" + ) + b_scale = b_scale.view( + scale_groups, + scale_out_blocks, + scale_hidden_blocks, + ) + if scale_groups == b_groups and scale_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + elif b_scale.dim() == 3 and b_scale.shape[0] == b_groups: + if b_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + + if _use_deepseek_v4_sm12x_triton_fp8_einsum(equation, recipe, b_scale): + deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, out) + return + + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) diff --git a/vllm/models/deepseek_v4/nvidia/ops/o_proj.py b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py index 18e3b10562bd..1b6f3db7de5c 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/o_proj.py +++ b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py @@ -6,23 +6,26 @@ from vllm.models.deepseek_v4.common.ops.fused_inv_rope_fp8_quant import ( fused_inv_rope_fp8_quant, ) +from vllm.models.deepseek_v4.nvidia.ops.fp8_einsum import ( + deepseek_v4_fp8_einsum, + deepseek_v4_fp8_einsum_config, +) from vllm.platforms import current_platform -from vllm.utils.deep_gemm import fp8_einsum def compute_fp8_einsum_recipe() -> tuple[tuple[int, int, int], bool]: """fp8_einsum recipe + scale layout for the current GPU arch. SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128. - SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1. + SM100/SM110: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1. 
+ SM12x: RTX PRO / GB10 does not expose the same TMA/TCGEN05 path, so keep + the legacy FP32 block-scale layout expected by DeepGEMM. Returns ``(einsum_recipe, tma_aligned_scales)`` for ``deep_gemm_fp8_o_proj``. """ cap = current_platform.get_device_capability() assert cap is not None, "DeepseekV4 attention requires a CUDA device" - einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) - tma_aligned_scales = cap.major >= 10 - return einsum_recipe, tma_aligned_scales + return deepseek_v4_fp8_einsum_config(cap.major) def deep_gemm_fp8_o_proj( @@ -60,11 +63,18 @@ def deep_gemm_fp8_o_proj( device=o.device, dtype=torch.bfloat16, ) - fp8_einsum( - "bhr,hdr->bhd", - (o_fp8, o_scale), - (wo_a.weight, wo_a.weight_scale_inv), + # MarlinFP8.process_weights_after_loading renames block-FP8 scales to + # weight_scale_inv. Non-Marlin kernels keep the on-disk weight_scale name. + wo_a_scale = getattr(wo_a, "weight_scale_inv", None) + if wo_a_scale is None: + wo_a_scale = wo_a.weight_scale + deepseek_v4_fp8_einsum( + o_fp8, + o_scale, + wo_a.weight, + wo_a_scale, z, - recipe=einsum_recipe, + "bhr,hdr->bhd", + list(einsum_recipe), ) return wo_b(z.flatten(1)) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py new file mode 100644 index 000000000000..7ed1ecc4c05c --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -0,0 +1,721 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SM12x fallback implementations for DeepGEMM-only interfaces.""" + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +_SM120_MQA_LOGITS_MAX_SCORE_BYTES = 64 * 1024 * 1024 +_SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES = 512 * 1024 * 1024 +_SM120_MQA_TRITON_CHUNKED_TOPK_CHUNK_SIZE = 32768 +_SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192 + + +def 
_top_k_per_row_prefill_op(): + try: + from vllm import _custom_ops as _custom_ops # noqa: F401 + + return torch.ops._C.top_k_per_row_prefill + except (AttributeError, ImportError, RuntimeError): + return None + + +def _fp8_mqa_logits_head_chunk_size( + seq_len: int, + seq_len_kv: int, + num_heads: int, +) -> int: + # The SM120 torch path is used on long prefill paths where materializing + # [head_chunk, M, N] scores can otherwise allocate multiple GiB. Keep the + # transient score tensor bounded, while still using larger head chunks for + # short prompts where they are faster. + score_elems_per_head = max(1, seq_len * seq_len_kv) + max_heads = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_head * 4) + return max(1, min(8, num_heads, max_heads)) + + +def _fp8_mqa_logits_k_chunk_size( + seq_len: int, + seq_len_kv: int, + head_chunk_size: int, +) -> int: + score_elems_per_key = max(1, seq_len * head_chunk_size) + max_keys = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_key * 4) + return max(1, min(seq_len_kv, max_keys)) + + +def _fp8_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA logits torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + logits = torch.zeros( + (seq_len, seq_len_kv), device=q_values.device, dtype=torch.float32 + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, 
head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + k_chunk_size = _fp8_mqa_logits_k_chunk_size( + seq_len, seq_len_kv, head_end - head_start + ) + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + logits[:, k_start:k_end].add_( + scores[0] if scores.shape[0] == 1 else scores.sum(dim=0) + ) + + if clean_logits: + offsets = torch.arange(seq_len_kv, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + logits = logits.masked_fill(~valid, float("-inf")) + + return logits + + +def _fp8_mqa_logits_topk_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_tokens: int, + out: torch.Tensor | None = None, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA top-k torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + if out is None: + out = torch.empty( + (seq_len, topk_tokens), device=q_values.device, dtype=torch.int32 + ) + else: + assert out.shape == (seq_len, topk_tokens) + assert out.dtype == torch.int32 + out.fill_(-1) + + best_values = torch.full( + (seq_len, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + k_chunk_size = _fp8_mqa_logits_k_chunk_size(seq_len, seq_len_kv, head_chunk_size) + max_chunk_topk 
= min(topk_tokens, k_chunk_size) + chunk_values_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (seq_len, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + chunk_logits = torch.zeros( + (seq_len, k_end - k_start), + device=q_values.device, + dtype=torch.float32, + ) + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + chunk_logits.add_(scores[0] if scores.shape[0] == 1 else scores.sum(dim=0)) + + offsets = torch.arange(k_start, k_end, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + chunk_logits.masked_fill_(~valid, float("-inf")) + + chunk_topk = min(topk_tokens, k_end - k_start) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, 
:chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(k_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(out) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=out) + best_values, next_best_values = next_best_values, best_values + out.masked_fill_(~torch.isfinite(best_values), -1) + + return out + + +def _fp8_mqa_logits_topk_triton( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor, +) -> bool: + q_values, q_scale = q + k_values, _ = kv + if not (q_scale is None and q_values.dim() == 3 and k_values.dim() == 2): + return False + + logits_bytes = q_values.shape[0] * k_values.shape[0] * torch.float32.itemsize + if logits_bytes > _SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES: + return False + + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + fp8_mqa_logits_triton, + ) + + logits = fp8_mqa_logits_triton(q_values, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + topk_tokens = out.shape[1] + select_k = min(topk_tokens, logits.shape[1]) + out.fill_(-1) + if select_k == 0: + return True + + selected = out[:, :select_k] + topk_op = _top_k_per_row_prefill_op() + if topk_op is not None: + # top_k_per_row_prefill writes its output as a contiguous [M, select_k] + # buffer (it is given the logits strides, not the output strides). When + # select_k < out.shape[1] -- i.e. 
the compressed-KV count is below the + # topk width, which happens for short prompts and the early queries of + # long prompts -- out[:, :select_k] is non-contiguous (row stride = + # out.shape[1]), so writing it as contiguous silently corrupts later + # rows and drops their top-k (all -1). Hand the op a fresh contiguous + # buffer and copy back. The buffer is left uninitialized (new_empty) + # rather than copied (.contiguous()) because the op overwrites every + # element, so copying the placeholder -1s in would be wasted work. + work = ( + selected if selected.is_contiguous() else selected.new_empty(selected.shape) + ) + topk_op( + logits, + cu_seqlen_ks, + cu_seqlen_ke, + work, + logits.shape[0], + logits.stride(0), + logits.stride(1), + select_k, + ) + work.add_(cu_seqlen_ks[:, None]) + valid = (work >= cu_seqlen_ks[:, None]) & (work < cu_seqlen_ke[:, None]) + work.masked_fill_(~valid, -1) + if work is not selected: + selected.copy_(work) + else: + values, indices = torch.topk(logits, select_k, dim=1) + selected.copy_(indices.to(torch.int32)) + selected.masked_fill_(~torch.isfinite(values), -1) + return True + + +def _fp8_mqa_logits_topk_triton_chunked( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor, +) -> bool: + q_values, q_scale = q + k_values, k_scales = kv + if not (q_scale is None and q_values.dim() == 3 and k_values.dim() == 2): + return False + + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + fp8_mqa_logits_triton, + ) + + seq_len = q_values.shape[0] + seq_len_kv = k_values.shape[0] + topk_tokens = out.shape[1] + out.fill_(-1) + if seq_len == 0 or seq_len_kv == 0 or topk_tokens == 0: + return True + + chunk_size = max(1, _SM120_MQA_TRITON_CHUNKED_TOPK_CHUNK_SIZE) + best_values = torch.full( + (seq_len, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + 
max_chunk_topk = min(topk_tokens, chunk_size) + chunk_values_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (seq_len, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + for k_start in range(0, seq_len_kv, chunk_size): + k_end = min(k_start + chunk_size, seq_len_kv) + local_width = k_end - k_start + local_ks = torch.clamp(cu_seqlen_ks - k_start, min=0, max=local_width) + local_ke = torch.clamp(cu_seqlen_ke - k_start, min=0, max=local_width) + chunk_logits = fp8_mqa_logits_triton( + q_values, + (k_values[k_start:k_end], k_scales[k_start:k_end]), + weights, + local_ks, + local_ke, + ) + chunk_topk = min(topk_tokens, local_width) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(k_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(out) + candidate_indices_view[:, 
topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=out) + best_values, next_best_values = next_best_values, best_values + out.masked_fill_(~torch.isfinite(best_values), -1) + + return True + + +def fp8_fp4_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 MQA top-k indices without materializing full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + if _fp8_mqa_logits_topk_triton( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices, + ): + return True + if _fp8_mqa_logits_topk_triton_chunked( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices, + ): + return True + _fp8_mqa_logits_topk_torch( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices.shape[1], + out=topk_indices, + ) + return True + + +def _fp8_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q + if clean_logits and q_scale is None and q_values.dim() == 3 and kv[0].dim() == 2: + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + fp8_mqa_logits_triton, + ) + + return fp8_mqa_logits_triton(q_values, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + return _fp8_mqa_logits_torch( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) + + +def _fp8_paged_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: 
torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 paged MQA torch path only supports FP8 Q") + + batch_size, next_n, num_heads, head_dim = q_values.shape + head_dim_with_scale = kv_cache.shape[-1] + assert head_dim_with_scale > head_dim + assert weights.shape == (batch_size * next_n, num_heads) + assert context_lens.shape == (batch_size, next_n) + + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + _view_packed_fp8_paged_mqa_kv_cache, + ) + + kv_values, kv_scales = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_kv, _, _ = kv_values.shape + logits = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + + q_f32 = q_values.float() + score_bytes = _SM120_MQA_LOGITS_MAX_SCORE_BYTES + max_tokens_per_chunk = max(1, score_bytes // max(1, num_heads * 4)) + token_offsets_cache: dict[int, torch.Tensor] = {} + + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + context_len = int(context_lens[batch_idx, next_idx].item()) + if context_len <= 0: + continue + + q_row = q_f32[batch_idx, next_idx] + row_weights = weights[row] + for token_start in range(0, context_len, max_tokens_per_chunk): + token_end = min(context_len, token_start + max_tokens_per_chunk) + chunk_len = token_end - token_start + token_offsets = token_offsets_cache.get(chunk_len) + if token_offsets is None or token_offsets.device != q_values.device: + token_offsets = torch.arange( + chunk_len, device=q_values.device, dtype=torch.long + ) + token_offsets_cache[chunk_len] = token_offsets + token_ids = token_start + token_offsets + logical_blocks = token_ids // block_kv + token_in_block = token_ids - logical_blocks * block_kv + physical_blocks = block_tables[batch_idx, logical_blocks] + kv_chunk = kv_values[physical_blocks, token_in_block, 
0].float() + scale_chunk = kv_scales[physical_blocks, token_in_block, 0].squeeze(-1) + kv_chunk.mul_(scale_chunk[:, None]) + scores = torch.matmul(q_row, kv_chunk.T) + scores.relu_() + scores.mul_(row_weights[:, None]) + logits[row, token_start:token_end] = scores.sum(dim=0) + + return logits + + +def _fp8_paged_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if ( + q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + fp8_paged_mqa_logits_triton, + ) + + return fp8_paged_mqa_logits_triton( + q_values, kv_cache, weights, context_lens, block_tables, max_model_len + ) + logger.warning_once( + "SM12x paged-MQA falling back to the torch reference path " + "(q_scale=%s, q.dim=%s, kv_cache.dtype=%s, kv_cache.shape[-1]=%s, " + "q_values.shape[-1]=%s). 
This path is intended for correctness checks " + "and is not graph-compatible; expect a large per-step latency.", + "set" if q_scale is not None else "None", + q_values.dim(), + kv_cache.dtype, + kv_cache.shape[-1] if kv_cache.dim() else None, + q_values.shape[-1], + ) + return _fp8_paged_mqa_logits_torch( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) + + +def fp8_fp4_paged_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 paged MQA top-k indices without full logits.""" + q_values, q_scale = q + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + return False + + num_rows = q_values.shape[0] * q_values.shape[1] + topk_tokens = topk_indices.shape[1] + assert topk_indices.shape == (num_rows, topk_tokens) + assert topk_indices.dtype == torch.int32 + topk_indices.fill_(-1) + if num_rows == 0 or topk_tokens == 0 or max_model_len == 0: + return True + + best_values = torch.full( + (num_rows, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + chunk_size = max(1, _SM120_PAGED_MQA_TOPK_CHUNK_SIZE) + max_chunk_topk = min(topk_tokens, chunk_size) + chunk_values_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + 
candidate_indices = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (num_rows, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + fp8_paged_mqa_logits_triton, + ) + + for token_start in range(0, max_model_len, chunk_size): + token_count = min(chunk_size, max_model_len - token_start) + chunk_logits = fp8_paged_mqa_logits_triton( + q_values, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + token_start=token_start, + token_count=token_count, + ) + chunk_topk = min(topk_tokens, token_count) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(token_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(topk_indices) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=topk_indices) + best_values, next_best_values = next_best_values, best_values + topk_indices.masked_fill_(~torch.isfinite(best_values), -1) + + return True + + +def _tf32_hc_prenorm_gemm_torch( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + """Portable SM12x 
HyperConnection prenorm GEMM fallback. + + DeepGEMM's split ABI only requires that downstream consumers recover the + full result by summing over the split dimension. Keep the implementation + simple by writing the full product to split zero and clearing the rest. + """ + del num_split + product = x.float() @ fn.float().T + norm = x.float().square().sum(dim=-1) + + if out.dim() == 3: + out.zero_() + sqrsum.zero_() + out[0].copy_(product) + sqrsum[0].copy_(norm) + else: + out.copy_(product) + sqrsum.copy_(norm) + return out + + +def _tf32_hc_prenorm_gemm_sm12x( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + if out.dim() == 3 and sqrsum.dim() == 2: + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + tf32_hc_prenorm_gemm_triton, + ) + + tf32_hc_prenorm_gemm_triton(x, fn, out, sqrsum, num_split) + return out + + return _tf32_hc_prenorm_gemm_torch(x, fn, out, sqrsum, num_split) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py new file mode 100644 index 000000000000..ac15de394957 --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -0,0 +1,726 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Triton fallback kernels used by the local DeepSeek V4 path.""" + +import torch + +from vllm.triton_utils import tl, triton + + +def _view_packed_fp8_paged_mqa_kv_cache( + kv_cache: torch.Tensor, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Return FP8 values and fp32 scales from indexer cache block storage.""" + if kv_cache.dtype != torch.uint8: + raise TypeError(f"Expected uint8 kv_cache, got {kv_cache.dtype}") + if kv_cache.dim() == 3: + num_blocks, block_size, head_dim_with_scale = kv_cache.shape + num_kv_heads = 1 + elif kv_cache.dim() == 4: + num_blocks, block_size, num_kv_heads, head_dim_with_scale = kv_cache.shape + else: + raise 
ValueError(f"Expected 3D or 4D kv_cache, got {kv_cache.dim()} dimensions") + if num_kv_heads != 1: + raise ValueError(f"Expected one KV head, got {num_kv_heads}") + + scale_bytes = head_dim_with_scale - head_dim + if scale_bytes <= 0 or scale_bytes % torch.float32.itemsize != 0: + raise ValueError( + "Expected kv_cache last dimension to contain FP8 values followed " + f"by fp32 scale bytes; got head_dim={head_dim}, " + f"last_dim={head_dim_with_scale}" + ) + + block_stride = kv_cache.stride(0) + base_storage_offset = kv_cache.storage_offset() + scale_elems = scale_bytes // torch.float32.itemsize + kv_values = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, head_dim), + stride=(block_stride, head_dim, head_dim, 1), + storage_offset=base_storage_offset, + ).view(torch.float8_e4m3fn) + kv_scale = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, scale_bytes), + stride=(block_stride, scale_bytes, scale_bytes, 1), + storage_offset=base_storage_offset + block_size * head_dim, + ).view(torch.float32) + return kv_values, kv_scale[..., :scale_elems] + + +@triton.jit +def _fp8_mqa_logits_kernel( + q_ptr, + k_ptr, + scale_ptr, + weights_ptr, + cu_seqlen_ks_ptr, + cu_seqlen_ke_ptr, + logits_ptr, + num_q: tl.constexpr, + seq_len_kv: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kn: tl.constexpr, + stride_kd: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_q + valid_n = offs_n < seq_len_kv + seq_start = tl.load(cu_seqlen_ks_ptr + offs_m, mask=valid_m, other=0) + seq_end = 
tl.load(cu_seqlen_ke_ptr + offs_m, mask=valid_m, other=0) + seq_mask = (offs_n[None, :] >= seq_start[:, None]) & ( + offs_n[None, :] < seq_end[:, None] + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + offs_m[:, None] * stride_qm + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + k_ptr + offs_n[:, None] * stride_kn + d[None, :] * stride_kd, + mask=valid_n[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.dot(q, tl.trans(k), input_precision="tf32") + scale = tl.load(scale_ptr + offs_n, mask=valid_n, other=0.0) + weighted = tl.maximum(scores * scale[None, :], 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(seq_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_n[None, :] * stride_ln, + logits, + mask=store_mask, + ) + + +def fp8_mqa_logits_triton( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + k_fp8, scale = kv + num_q, num_heads, head_dim = q.shape + seq_len_kv = k_fp8.shape[0] + logits = torch.empty( + (num_q, seq_len_kv), + device=q.device, + dtype=torch.float32, + ) + if num_q == 0 or seq_len_kv == 0: + return logits + + block_m = _fp8_mqa_logits_block_m(num_q, seq_len_kv) + grid = (triton.cdiv(num_q, block_m), triton.cdiv(seq_len_kv, 128)) + _fp8_mqa_logits_kernel[grid]( + q, + k_fp8, + scale, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + logits, + num_q, + seq_len_kv, + num_heads, + head_dim, + 
q.stride(0), + q.stride(1), + q.stride(2), + k_fp8.stride(0), + k_fp8.stride(1), + weights.stride(0), + weights.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=block_m, + BLOCK_N=128, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +def _fp8_mqa_logits_block_m(num_q: int, seq_len_kv: int) -> int: + if seq_len_kv <= 16 * 1024: + return 16 + return 64 + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + q_ptr, + kv_ptr, + scale_ptr, + weights_ptr, + context_lens_ptr, + block_tables_ptr, + logits_ptr, + token_start, + num_rows: tl.constexpr, + logits_width: tl.constexpr, + next_n: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + block_size: tl.constexpr, + stride_qb: tl.constexpr, + stride_qn: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvb: tl.constexpr, + stride_kvs: tl.constexpr, + stride_kvd: tl.constexpr, + stride_sb: tl.constexpr, + stride_ss: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_clb: tl.constexpr, + stride_cln: tl.constexpr, + stride_btb: tl.constexpr, + stride_btk: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = token_start + offs_local_n + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_rows + valid_n = offs_local_n < logits_width + batch = offs_m // next_n + q_pos = offs_m - batch * next_n + context_len = tl.load( + context_lens_ptr + batch * stride_clb + q_pos * stride_cln, + mask=valid_m, + other=0, + ) + context_mask = valid_n[None, :] & (offs_n[None, :] < context_len[:, None]) + + block_rank = offs_n // block_size + block_offset = offs_n - block_rank * block_size + block_idx = tl.load( + block_tables_ptr + + batch[:, None] * stride_btb + + block_rank[None, :] * 
stride_btk, + mask=valid_m[:, None] & valid_n[None, :], + other=0, + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + scale = tl.load( + scale_ptr + block_idx.to(tl.int64) * stride_sb + block_offset[None, :] * stride_ss, + mask=context_mask, + other=0.0, + ) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + batch[:, None] * stride_qb + + q_pos[:, None] * stride_qn + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + kv_ptr + + block_idx[:, :, None].to(tl.int64) * stride_kvb + + block_offset[None, :, None] * stride_kvs + + d[None, None, :] * stride_kvd, + mask=context_mask[:, :, None] & (d[None, None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.sum(q[:, None, :] * k, axis=2) + weighted = tl.maximum(scores * scale, 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(context_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_local_n[None, :] * stride_ln, + logits, + mask=store_mask, + ) + + +@triton.jit +def _fp8_paged_mqa_logits_rowwise_kernel( + q_ptr, + kv_ptr, + scale_ptr, + weights_ptr, + context_lens_ptr, + block_tables_ptr, + logits_ptr, + token_start, + num_rows: tl.constexpr, + logits_width: tl.constexpr, + next_n: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + block_size: tl.constexpr, + stride_qb: tl.constexpr, + stride_qn: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvb: tl.constexpr, + stride_kvs: tl.constexpr, + stride_kvd: tl.constexpr, + stride_sb: tl.constexpr, + stride_ss: tl.constexpr, + stride_wm: tl.constexpr, + 
stride_wh: tl.constexpr, + stride_clb: tl.constexpr, + stride_cln: tl.constexpr, + stride_btb: tl.constexpr, + stride_btk: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_H: tl.constexpr, +): + """Per-row paged-MQA logits kernel optimised for long ``token_count``. + + Each Triton program handles one logical row (``batch * next_n + q_pos``) + across a ``BLOCK_N``-wide window of token positions. Q is loaded once per + head tile and reused for every K element in the window, which preserves + L2 / register locality and avoids the M-axis padding waste of the + generic 2D-tiled kernel at long contexts (mt-bench c=1 MTP=2 num_rows=3 + with token_count=131072 launches 12k programs of 128 logits each rather + than 8k programs of 64 logits with 25 % M-axis waste). + """ + row = tl.program_id(0) + pid_n = tl.program_id(1) + offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = token_start + offs_local_n + offs_d = tl.arange(0, BLOCK_D) + + valid_row = row < num_rows + valid_n = offs_local_n < logits_width + batch = row // next_n + q_pos = row - batch * next_n + context_len = tl.load( + context_lens_ptr + batch * stride_clb + q_pos * stride_cln, + mask=valid_row, + other=0, + ) + if token_start + pid_n * BLOCK_N >= context_len: + logits = tl.full((BLOCK_N,), float("-inf"), dtype=tl.float32) + tl.store( + logits_ptr + row * stride_lm + offs_local_n * stride_ln, + logits, + mask=valid_row & valid_n, + ) + return + context_mask = valid_n & (offs_n < context_len) + + block_rank = offs_n // block_size + block_offset = offs_n - block_rank * block_size + block_idx = tl.load( + block_tables_ptr + batch * stride_btb + block_rank * stride_btk, + mask=valid_row & context_mask, + other=0, + ) + + scale = tl.load( + scale_ptr + block_idx.to(tl.int64) * stride_sb + block_offset * stride_ss, + mask=context_mask, + other=0.0, + ) + logits = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for h0 in tl.range(0, 
num_heads, BLOCK_H): + heads = h0 + tl.arange(0, BLOCK_H) + valid_h = heads < num_heads + scores = tl.zeros((BLOCK_H, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + batch * stride_qb + + q_pos * stride_qn + + heads[:, None] * stride_qh + + d[None, :] * stride_qd, + mask=valid_row & valid_h[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + kv_ptr + + block_idx[None, :].to(tl.int64) * stride_kvb + + block_offset[None, :] * stride_kvs + + d[:, None] * stride_kvd, + mask=context_mask[None, :] & (d[:, None] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.dot(q, k, input_precision="tf32") + + weighted = tl.maximum(scores * scale[None, :], 0.0) + weight = tl.load( + weights_ptr + row * stride_wm + heads * stride_wh, + mask=valid_row & valid_h, + other=0.0, + ) + logits += tl.sum(weighted * weight[:, None], axis=0) + + logits = tl.where(context_mask & valid_row, logits, float("-inf")) + tl.store( + logits_ptr + row * stride_lm + offs_local_n * stride_ln, + logits, + mask=valid_row & valid_n, + ) + + +def fp8_paged_mqa_logits_rowwise_triton( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + token_start: int = 0, + token_count: int | None = None, +) -> torch.Tensor: + """Rowwise paged-MQA logits wrapper. + + Pre-condition: ``head_dim % 64 == 0`` and ``num_heads % 4 == 0`` so the + ``tl.dot`` inside ``_fp8_paged_mqa_logits_rowwise_kernel`` lands on + tensor-core friendly tile shapes. DSv4-Flash (head_dim=128, + num_heads=64) satisfies both and is the only model that exercises this + path today; the generic 2D kernel below remains the fallback for + misaligned shapes. 
+ """ + batch_size, next_n, num_heads, head_dim = q.size() + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_size, _, _ = kv_values.size() + num_rows = batch_size * next_n + if token_count is None: + token_count = max_model_len - token_start + assert token_start >= 0 + assert token_count >= 0 + assert token_start + token_count <= max_model_len + logits = torch.empty( + (num_rows, token_count), + device=q.device, + dtype=torch.float32, + ) + if num_rows == 0 or token_count == 0: + return logits + + context_lens_2d = context_lens.reshape(batch_size, -1) + if context_lens_2d.shape[1] == 1 and next_n != 1: + context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() + block_n = 128 + grid = (num_rows, triton.cdiv(token_count, block_n)) + _fp8_paged_mqa_logits_rowwise_kernel[grid]( + q, + kv_values, + kv_scale, + weights, + context_lens_2d, + block_tables, + logits, + token_start, + num_rows, + token_count, + next_n, + num_heads, + head_dim, + block_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv_values.stride(0), + kv_values.stride(1), + kv_values.stride(3), + kv_scale.stride(0), + kv_scale.stride(1), + weights.stride(0), + weights.stride(1), + context_lens_2d.stride(0), + context_lens_2d.stride(1), + block_tables.stride(0), + block_tables.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_N=block_n, + BLOCK_D=64, + BLOCK_H=8, + num_warps=4, + ) + return logits + + +def fp8_paged_mqa_logits_triton( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + token_start: int = 0, + token_count: int | None = None, +) -> torch.Tensor: + batch_size, next_n, num_heads, head_dim = q.size() + # Aligned head shapes (DSv4-Flash and any future MQA model with + # ``head_dim % 64 == 0`` and ``num_heads % 4 == 0``) get the rowwise + # kernel, which keeps long-context decode (>100K tokens) on a per-row + 
# grid that re-uses Q across the full token window. The generic 2D + # kernel below still handles misaligned shapes and remains the canonical + # reference for the rowwise variant. + if head_dim % 64 == 0 and num_heads % 4 == 0: + return fp8_paged_mqa_logits_rowwise_triton( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + token_start=token_start, + token_count=token_count, + ) + + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_size, _, _ = kv_values.size() + num_rows = batch_size * next_n + if token_count is None: + token_count = max_model_len - token_start + assert token_start >= 0 + assert token_count >= 0 + assert token_start + token_count <= max_model_len + logits = torch.empty( + (num_rows, token_count), + device=q.device, + dtype=torch.float32, + ) + if num_rows == 0 or token_count == 0: + return logits + + context_lens_2d = context_lens.reshape(batch_size, -1) + if context_lens_2d.shape[1] == 1 and next_n != 1: + context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() + # Adaptive BLOCK_M: the kernel masks off positions >= num_rows, so a fixed + # BLOCK_M=4 wastes ~75% of M-axis work in the common single-stream decode + # case (num_rows=1). Pick the smallest power-of-2 tile that still covers + # num_rows so we keep one grid-program for typical decode while still + # benefiting from larger tiles when batch / MTP push num_rows higher. 
+ if num_rows <= 1: + block_m = 1 + elif num_rows <= 2: + block_m = 2 + elif num_rows <= 4: + block_m = 4 + else: + block_m = 8 + grid = (triton.cdiv(num_rows, block_m), triton.cdiv(token_count, 64)) + _fp8_paged_mqa_logits_kernel[grid]( + q, + kv_values, + kv_scale, + weights, + context_lens_2d, + block_tables, + logits, + token_start, + num_rows, + token_count, + next_n, + num_heads, + head_dim, + block_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv_values.stride(0), + kv_values.stride(1), + kv_values.stride(3), + kv_scale.stride(0), + kv_scale.stride(1), + weights.stride(0), + weights.stride(1), + context_lens_2d.stride(0), + context_lens_2d.stride(1), + block_tables.stride(0), + block_tables.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=block_m, + BLOCK_N=64, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +@triton.jit +def _tf32_hc_prenorm_gemm_kernel( + x_ptr, + fn_ptr, + out_ptr, + sqrsum_ptr, + M, + K: tl.constexpr, + N: tl.constexpr, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_fnn: tl.constexpr, + stride_fnk: tl.constexpr, + stride_outs: tl.constexpr, + stride_outm: tl.constexpr, + stride_outn: tl.constexpr, + stride_sqs: tl.constexpr, + stride_sqm: tl.constexpr, + NUM_SPLIT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_s = tl.program_id(2) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + split_k = tl.cdiv(K, NUM_SPLIT) + split_begin = pid_s * split_k + split_end = tl.minimum(split_begin + split_k, K) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + sq = tl.zeros((BLOCK_M,), dtype=tl.float32) + + for k0 in tl.range(0, split_k, BLOCK_K): + k = split_begin + k0 + offs_k + k_mask = k < split_end + x = tl.load( + x_ptr + offs_m[:, None] * stride_xm + k[None, :] * stride_xk, + mask=(offs_m[:, 
None] < M) & k_mask[None, :], + other=0.0, + ).to(tl.float32) + fn = tl.load( + fn_ptr + offs_n[None, :] * stride_fnn + k[:, None] * stride_fnk, + mask=(offs_n[None, :] < N) & k_mask[:, None], + other=0.0, + ).to(tl.float32) + + acc += tl.dot(x, fn, input_precision="tf32", out_dtype=tl.float32) + sq += tl.sum(x * x, axis=1) + + tl.store( + out_ptr + + pid_s * stride_outs + + offs_m[:, None] * stride_outm + + offs_n[None, :] * stride_outn, + acc, + mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), + ) + + if pid_n == 0: + tl.store( + sqrsum_ptr + pid_s * stride_sqs + offs_m * stride_sqm, + sq, + mask=offs_m < M, + ) + + +def tf32_hc_prenorm_gemm_triton( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> None: + assert x.dim() == 2 + assert fn.dim() == 2 + assert out.dim() == 3 + assert sqrsum.dim() == 2 + + m, k = x.shape + n = fn.shape[0] + assert fn.shape[1] == k + assert out.shape == (num_split, m, n) + assert sqrsum.shape == (num_split, m) + + if m == 0: + return + + block_m = 16 + block_n = triton.next_power_of_2(n) + block_n = min(max(block_n, 16), 32) + block_k = 64 + grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n), num_split) + _tf32_hc_prenorm_gemm_kernel[grid]( + x, + fn, + out, + sqrsum, + m, + k, + n, + x.stride(0), + x.stride(1), + fn.stride(0), + fn.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + sqrsum.stride(0), + sqrsum.stride(1), + num_split, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=4, + ) diff --git a/vllm/models/deepseek_v4/sparse_mla.py b/vllm/models/deepseek_v4/sparse_mla.py index 136a96a45da5..4acfb9692b63 100644 --- a/vllm/models/deepseek_v4/sparse_mla.py +++ b/vllm/models/deepseek_v4/sparse_mla.py @@ -324,11 +324,31 @@ def build_c128a_topk_metadata( num_tokens = positions.shape[0] num_prefill_tokens = num_tokens - num_decode_tokens - global_decode = global_decode_buffer[:num_decode_tokens] + if num_tokens == 0: + global_decode = 
global_decode_buffer[:num_decode_tokens, :0] + decode_lens = decode_lens_buffer[:num_decode_tokens] + prefill_local = prefill_buffer[:num_prefill_tokens, :0] + return global_decode, decode_lens, prefill_local + + KERNEL_BLOCK_SIZE = 1024 + effective_topk = _c128a_effective_topk_width( + positions=positions, + compress_ratio=compress_ratio, + max_compressed_tokens=max_compressed_tokens, + alignment=_C128A_TOPK_ALIGNMENT, + ) + + global_decode = global_decode_buffer[:num_decode_tokens, :effective_topk] decode_lens = decode_lens_buffer[:num_decode_tokens] - prefill_local = prefill_buffer[:num_prefill_tokens] + prefill_local = prefill_buffer[:num_prefill_tokens, :effective_topk] - if num_tokens == 0: + if num_decode_tokens > 0: + global_decode.fill_(-1) + decode_lens.zero_() + if num_prefill_tokens > 0: + prefill_local.fill_(-1) + + if effective_topk == 0: return global_decode, decode_lens, prefill_local _build_c128a_topk_metadata_kernel[(num_tokens,)]( @@ -339,18 +359,41 @@ def build_c128a_topk_metadata( prefill_buffer.stride(0), positions, compress_ratio, - max_compressed_tokens, + effective_topk, num_decode_tokens, token_to_req_indices, block_table, block_table.stride(0), block_size, slot_mapping, - BLOCK_SIZE=1024, + BLOCK_SIZE=KERNEL_BLOCK_SIZE, ) return global_decode, decode_lens, prefill_local +def _c128a_effective_topk_width( + *, + positions: torch.Tensor, + compress_ratio: int, + max_compressed_tokens: int, + alignment: int, +) -> int: + """Return the aligned C128A top-k width needed by current tokens.""" + if positions.numel() == 0: + return 0 + max_pos = int(positions.max().item()) + max_num_compressed = min( + max((max_pos + 1) // int(compress_ratio), 0), + int(max_compressed_tokens), + ) + if max_num_compressed == 0: + return min(int(max_compressed_tokens), int(alignment)) + return min( + int(max_compressed_tokens), + cdiv(max_num_compressed, int(alignment)) * int(alignment), + ) + + @triton.jit def _build_c128a_topk_metadata_kernel( # Decode outputs @@ 
-363,7 +406,7 @@ def _build_c128a_topk_metadata_kernel( # Inputs positions_ptr, compress_ratio, - max_compressed_tokens, + effective_topk, num_decode_tokens, token_to_req_indices_ptr, block_table_ptr, @@ -375,7 +418,7 @@ def _build_c128a_topk_metadata_kernel( token_idx = tl.program_id(0) position = tl.load(positions_ptr + token_idx) num_compressed = (position + 1) // compress_ratio - num_compressed = tl.minimum(num_compressed, max_compressed_tokens) + num_compressed = tl.minimum(num_compressed, effective_topk) is_decode = token_idx < num_decode_tokens if is_decode: @@ -383,9 +426,9 @@ def _build_c128a_topk_metadata_kernel( is_valid_token = tl.load(slot_mapping_ptr + token_idx) >= 0 req_idx = tl.load(token_to_req_indices_ptr + token_idx) count = tl.zeros((), dtype=tl.int32) - for i in range(0, max_compressed_tokens, BLOCK_SIZE): + for i in range(0, effective_topk, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < max_compressed_tokens + mask = offset < effective_topk is_valid = offset < num_compressed block_indices = offset // block_size @@ -410,9 +453,9 @@ def _build_c128a_topk_metadata_kernel( else: # --- Prefill: write local indices --- pfx_idx = token_idx - num_decode_tokens - for i in range(0, max_compressed_tokens, BLOCK_SIZE): + for i in range(0, effective_topk, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < max_compressed_tokens + mask = offset < effective_topk tl.store( prefill_local_ptr + pfx_idx * prefill_local_stride + offset, tl.where(offset < num_compressed, offset, -1), diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index dabe6058e42f..abd86e518ded 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -661,12 +661,17 @@ def support_static_graph_mode(cls) -> bool: @classmethod def support_deep_gemm(cls) -> bool: - """Currently, only Hopper and Blackwell GPUs are supported.""" - return ( - cls.is_device_capability(90) - or cls.is_device_capability_family(100) - or 
cls.is_device_capability_family(120) - ) + """Currently, only Hopper and Blackwell SM90/SM100 GPUs are supported. + + SM120 (family-120, consumer Blackwell) is intentionally excluded: the + DeepGEMM SM120 MXFP4 kernels require the still-unmerged DeepGEMM PR #324, + and the released/pinned DeepGEMM ref aborts on SM120 with a scale-factor + layout assertion (`sf.size(-2) == ceil_div(mn, gran_mn)`). DSv4 on SM120 + runs the Marlin / cutlass + sm12x DeepGEMM-fallback path instead, which + serves on stock deps. (#43477 enabled family-120 here; re-enable once + DeepGEMM #324 lands.) + """ + return cls.is_device_capability(90) or cls.is_device_capability_family(100) @classmethod def is_integrated_gpu(cls, device_id: int = 0) -> bool: diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index cbb1fa350f55..b99b2e1a43ec 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -29,8 +29,8 @@ "DeepSeekV3ReasoningParser", ), "deepseek_v4": ( - "deepseek_v3_reasoning_parser", - "DeepSeekV3ReasoningParser", + "deepseek_v4_reasoning_parser", + "DeepSeekV4ReasoningParser", ), "poolside_v1": ( "poolside_v1_reasoning_parser", diff --git a/vllm/reasoning/deepseek_v4_reasoning_parser.py b/vllm/reasoning/deepseek_v4_reasoning_parser.py new file mode 100644 index 000000000000..fe38ca84c03b --- /dev/null +++ b/vllm/reasoning/deepseek_v4_reasoning_parser.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.engine.protocol import DeltaMessage +from vllm.reasoning import ReasoningParser +from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser + +from .identity_reasoning_parser import IdentityReasoningParser + +if TYPE_CHECKING: + from vllm.entrypoints.openai.chat_completion.protocol import 
ChatCompletionRequest + from vllm.entrypoints.openai.responses.protocol import ResponsesRequest + + +# DSML tool-call start markers that the DSv4 model emits when it decides to +# invoke a tool. The reasoning parser treats these as an implicit +# end-of-reasoning when the explicit end-of-reasoning token is missing, which the +# model occasionally fails to emit at long context. +_DSV4_TOOL_CALL_IMPLICIT_END_MARKERS: tuple[str, ...] = ("<|DSML|tool_calls>",) + + +class DeepSeekV4ThinkingReasoningParser(DeepSeekR1ReasoningParser): + """ + DeepSeek V4 thinking-mode reasoning parser. + + Extends :class:`DeepSeekR1ReasoningParser` with one behavior change: + if the model emits a DSML tool-call start marker without first emitting + the end-of-reasoning token, treat the marker as an implicit end-of-reasoning so the + tool call is correctly handed off to the tool parser. + + Works around a model behavior observed at long context (~95k–100k + input tokens) where DSv4-Flash sometimes skips the end-of-reasoning token before + opening ``<|DSML|tool_calls>``. Without this defensive split the + orchestrator stays in reasoning phase, the tool parser never runs, and + the caller sees a turn with reasoning but no tool call — + indistinguishable on the client side from "model gave up". + + Healthy paths (explicit end-of-reasoning token) are unchanged. The marker check + fires only when no explicit start/end token has been seen. + + State (``_implicit_end_seen``) is per-instance and per-stream because + a fresh ``Parser`` (and therefore a fresh reasoning parser) is + constructed for each request. + """ + + implicit_end_markers: tuple[str, ...] = _DSV4_TOOL_CALL_IMPLICIT_END_MARKERS + + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + # Per-stream sticky flag: once the implicit end marker is observed, + # the rest of the stream is content and the orchestrator's + # is_reasoning_end check must return True for every subsequent delta. 
+ self._implicit_end_seen: bool = False + # Text deltas may split the DSML marker itself, for example + # "<|DSML|tool" then "_calls>". Hold such suffixes back so they do + # not leak into reasoning before the next delta disambiguates them. + self._pending_implicit_marker_prefix: str = "" + + def _find_implicit_end_marker(self, text: str) -> tuple[str, int] | None: + """Return ``(marker, index)`` of the earliest implicit end marker in + ``text``, or ``None`` if none of the configured markers are present. + """ + earliest: tuple[str, int] | None = None + for marker in self.implicit_end_markers: + idx = text.find(marker) + if idx < 0: + continue + if earliest is None or idx < earliest[1]: + earliest = (marker, idx) + return earliest + + def _partial_implicit_end_marker_suffix(self, text: str) -> str: + """Return the longest suffix of ``text`` that prefixes a DSML marker.""" + best = "" + for marker in self.implicit_end_markers: + max_len = min(len(marker) - 1, len(text)) + for length in range(max_len, 0, -1): + suffix = text[-length:] + if marker.startswith(suffix): + if length > len(best): + best = suffix + break + return best + + def _extract_pending_implicit_marker_delta( + self, delta_text: str + ) -> DeltaMessage | None: + pending = self._pending_implicit_marker_prefix + combined = pending + delta_text + + marker = self._find_implicit_end_marker(combined) + if marker is not None and marker[1] == 0: + self._pending_implicit_marker_prefix = "" + self._implicit_end_seen = True + return DeltaMessage(content=combined) + + if any(marker.startswith(combined) for marker in self.implicit_end_markers): + self._pending_implicit_marker_prefix = combined + return None + + # The withheld suffix was a false alarm. Emit it as reasoning, while + # still guarding against a new marker prefix at the end of this delta. 
+ partial = self._partial_implicit_end_marker_suffix(combined) + if partial: + self._pending_implicit_marker_prefix = partial + reasoning = combined[: -len(partial)] + else: + self._pending_implicit_marker_prefix = "" + reasoning = combined + return DeltaMessage(reasoning=reasoning) if reasoning else None + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + # Honor the explicit contract first. + if super().is_reasoning_end(input_ids): + return True + # Sticky: once we've observed the marker in streaming, the rest of + # the stream is content. + if self._implicit_end_seen: + return True + # Non-streaming fallback: scan decoded text for the marker. We only + # decode when start/end tokens are absent, which is the failure + # mode we target. + if not input_ids: + return False + if self.start_token_id in input_ids or self.end_token_id in input_ids: + return False + try: + decoded = self.model_tokenizer.decode(list(input_ids)) + except Exception: + return False + return self._find_implicit_end_marker(decoded) is not None + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Iterable[int] + ) -> bool: + if super().is_reasoning_end_streaming(input_ids, delta_ids): + return True + return bool(self._implicit_end_seen) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + ret = super().extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) + + # Parent emitted content (explicit </think> in or before this delta). + # Healthy path — nothing to do. + if ret is not None and ret.content is not None: + return ret + + # Defensive check is only meaningful in implicit-reasoning mode + # (no explicit <think> / </think> in the stream). If any explicit + # token is present, defer to parent.
+ if ( + self.start_token_id in previous_token_ids + or self.start_token_id in delta_token_ids + or self.end_token_id in previous_token_ids + or self.end_token_id in delta_token_ids + ): + return ret + + if self._pending_implicit_marker_prefix: + return self._extract_pending_implicit_marker_delta(delta_text) + + # Sticky: marker observed in an earlier delta — everything here is + # content for the tool parser. + if self._implicit_end_seen: + return DeltaMessage(content=delta_text) + + marker_in_current = self._find_implicit_end_marker(current_text) + if marker_in_current is None: + partial = self._partial_implicit_end_marker_suffix(current_text) + if partial: + partial_start = len(current_text) - len(partial) + if partial_start >= len(previous_text): + self._pending_implicit_marker_prefix = partial + reasoning = delta_text[: -len(partial)] or None + return DeltaMessage(reasoning=reasoning) if reasoning else None + # No marker or holdable marker prefix; parent's classification stands. + return ret + + # First sighting of the implicit end marker. + self._implicit_end_seen = True + self._pending_implicit_marker_prefix = "" + _marker_str, marker_idx_current = marker_in_current + # Position within delta_text where the marker begins. + marker_idx_delta = marker_idx_current - len(previous_text) + if marker_idx_delta < 0: + # Marker straddles into previous_text but wasn't detected there + # (parent path didn't hit). Treat all of delta_text as content. + return DeltaMessage(content=delta_text) + + reasoning_part = delta_text[:marker_idx_delta] or None + content_part = delta_text[marker_idx_delta:] or None + if reasoning_part is None and content_part is None: + return ret + return DeltaMessage(reasoning=reasoning_part, content=content_part) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + # If parent finds </think>, use its behavior.
+ parent_result = super().extract_content_ids(input_ids) + if parent_result: + return parent_result + # Fall back to text-scan for the implicit marker. + if not input_ids: + return [] + try: + decoded = self.model_tokenizer.decode(list(input_ids)) + except Exception: + return [] + marker = self._find_implicit_end_marker(decoded) + if marker is None: + return [] + # Without per-token offsets we can't slice input_ids exactly at the + # marker boundary. Return everything as content tokens — the + # orchestrator drives tool-call extraction off ``current_text`` for + # DSML grammars, so this conservative answer is acceptable. + return list(input_ids) + + +class DeepSeekV4ReasoningParser(ReasoningParser): + """ + V4 reasoning parser that delegates to either + :class:`DeepSeekV4ThinkingReasoningParser` (the V4-aware extension of + R1) or :class:`IdentityReasoningParser` based on the ``thinking`` / + ``enable_thinking`` chat-template kwargs. + + Replaces the previous arrangement where ``deepseek_v4`` reused + :class:`DeepSeekV3ReasoningParser`, which lacked the implicit + DSML-tool-call end-of-reasoning handling. 
+ """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.get("thinking", False)) + enable_thinking = bool(chat_kwargs.get("enable_thinking", False)) + thinking = thinking or enable_thinking + + self._parser: ReasoningParser + if thinking: + self._parser = DeepSeekV4ThinkingReasoningParser(tokenizer, *args, **kwargs) + else: + self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) + + @property + def reasoning_start_str(self) -> str | None: + return self._parser.reasoning_start_str + + @property + def reasoning_end_str(self) -> str | None: + return self._parser.reasoning_end_str + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Iterable[int] + ) -> bool: + return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest" + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> "DeltaMessage | None": + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/vllm/tokenizers/deepseek_v4_encoding.py b/vllm/tokenizers/deepseek_v4_encoding.py index 6895771e2f59..74f01ce017e4 100644 --- a/vllm/tokenizers/deepseek_v4_encoding.py +++ 
b/vllm/tokenizers/deepseek_v4_encoding.py @@ -155,10 +155,15 @@ def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str: p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}' P_dsml_strs = [] - if isinstance(tool_call["arguments"], str): - arguments = json.loads(tool_call["arguments"]) + raw_arguments = tool_call.get("arguments") + if raw_arguments is None or raw_arguments == "": + arguments = {} + elif isinstance(raw_arguments, str): + arguments = json.loads(raw_arguments) + if arguments is None: + arguments = {} else: - arguments = tool_call["arguments"] + arguments = raw_arguments for k, v in arguments.items(): p_dsml_str = p_dsml_template.format( diff --git a/vllm/tool_parsers/structural_tag_registry.py b/vllm/tool_parsers/structural_tag_registry.py index 99c92f8f0a2e..0dfc879518b0 100644 --- a/vllm/tool_parsers/structural_tag_registry.py +++ b/vllm/tool_parsers/structural_tag_registry.py @@ -128,12 +128,26 @@ def get_model_structural_tag( supported = sorted(SUPPORTED_STRUCTURAL_TAG_MODELS) raise ValueError(f"Unknown format type: {model}, supported types: {supported}") - return get_xgrammar_model_structural_tag( + structural_tag = get_xgrammar_model_structural_tag( model=model, tools=dumped_tools, tool_choice=dumped_tool_choice, reasoning=reasoning, ) + if model == "deepseek_v4" and not reasoning: + _allow_deepseek_v4_chat_think_end(structural_tag) + return structural_tag + + +def _allow_deepseek_v4_chat_think_end(structural_tag: StructuralTag) -> None: + """Allow DSv4 chat-mode outputs to emit an extra ``</think>``.""" + fmt = structural_tag.format + if getattr(fmt, "type", None) != "triggered_tags": + return + excludes = getattr(fmt, "excludes", None) + if not isinstance(excludes, list): + return + fmt.excludes = [exclude for exclude in excludes if exclude != ""] def _dump_tool_for_xgrammar( diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 0a1644bcc5c6..3fa4eaaea75c 100644 --- a/vllm/utils/deep_gemm.py +++
b/vllm/utils/deep_gemm.py @@ -149,6 +149,9 @@ def _missing(*_: Any, **__: Any) -> NoReturn: None ) _transform_sf_into_required_layout_impl: Callable[..., Any] | None = None +_transform_weights_for_mega_moe_impl: Callable[..., Any] | None = None +_get_symm_buffer_for_mega_moe_impl: Callable[..., Any] | None = None +_fp8_fp4_mega_moe_impl: Callable[..., Any] | None = None _pack_ue8m0_to_int_impl: Callable[..., Any] | None = None _get_mn_major_tma_aligned_packed_ue8m0_tensor_impl: Callable[..., Any] | None = None _get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor_impl: ( @@ -220,6 +223,8 @@ def _lazy_init() -> None: global _get_mk_alignment_for_contiguous_layout_impl global _get_theoretical_mk_alignment_for_contiguous_layout_impl global _transform_sf_into_required_layout_impl + global _transform_weights_for_mega_moe_impl + global _get_symm_buffer_for_mega_moe_impl, _fp8_fp4_mega_moe_impl global _pack_ue8m0_to_int_impl global _get_mn_major_tma_aligned_packed_ue8m0_tensor_impl global _get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor_impl @@ -237,6 +242,9 @@ def _lazy_init() -> None: or _tf32_hc_prenorm_gemm_impl is not None or _get_mk_alignment_for_contiguous_layout_impl is not None or _transform_sf_into_required_layout_impl is not None + or _transform_weights_for_mega_moe_impl is not None + or _get_symm_buffer_for_mega_moe_impl is not None + or _fp8_fp4_mega_moe_impl is not None or _pack_ue8m0_to_int_impl is not None or _get_mn_major_tma_aligned_packed_ue8m0_tensor_impl is not None or _get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor_impl is not None @@ -286,6 +294,13 @@ def _lazy_init() -> None: _transform_sf_into_required_layout_impl = getattr( _dg, "transform_sf_into_required_layout", None ) + _transform_weights_for_mega_moe_impl = getattr( + _dg, "transform_weights_for_mega_moe", None + ) + _get_symm_buffer_for_mega_moe_impl = getattr( + _dg, "get_symm_buffer_for_mega_moe", None + ) + _fp8_fp4_mega_moe_impl = getattr(_dg, "fp8_fp4_mega_moe", None) 
_pack_ue8m0_to_int_impl = getattr(_dg, "pack_ue8m0_to_int", None) _get_mn_major_tma_aligned_packed_ue8m0_tensor_impl = getattr( _dg, "get_mn_major_tma_aligned_packed_ue8m0_tensor", None @@ -496,6 +511,69 @@ def transform_sf_into_required_layout(*args, **kwargs): ) +def transform_weights_for_mega_moe(*args, **kwargs): + _lazy_init() + if _transform_weights_for_mega_moe_impl is None: + return _missing(*args, **kwargs) + return _transform_weights_for_mega_moe_impl(*args, **kwargs) + + +def get_symm_buffer_for_mega_moe(*args, **kwargs): + _lazy_init() + if _get_symm_buffer_for_mega_moe_impl is None: + return _missing(*args, **kwargs) + return _get_symm_buffer_for_mega_moe_impl(*args, **kwargs) + + +def fp8_fp4_mega_moe(*args, **kwargs): + _lazy_init() + if _fp8_fp4_mega_moe_impl is None: + return _missing(*args, **kwargs) + return _fp8_fp4_mega_moe_impl(*args, **kwargs) + + +def fp8_fp4_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 MQA top-k indices without materializing full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks.fp8_fp4_mqa_topk_indices( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices, + ) + + +def _fp8_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks._fp8_mqa_logits_sm12x( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) + 
+ def fp8_fp4_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv: tuple[torch.Tensor, torch.Tensor], @@ -528,6 +606,10 @@ def fp8_fp4_mqa_logits( Returns: Logits tensor of shape [M, N], dtype `torch.float32`. """ + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_mqa_logits_sm12x( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) _lazy_init() if _fp8_fp4_mqa_logits_impl is None: return _missing() @@ -562,6 +644,50 @@ def get_paged_mqa_logits_metadata( return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) +def _fp8_paged_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks._fp8_paged_mqa_logits_sm12x( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) + + +def fp8_fp4_paged_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 paged MQA top-k indices without full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks.fp8_fp4_paged_mqa_topk_indices( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + topk_indices, + ) + + def fp8_fp4_paged_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv_cache: torch.Tensor, @@ -583,9 +709,10 @@ def fp8_fp4_paged_mqa_logits( [B, next_n, H, D] float8_e4m3fn and q_scale is None. 
FP4 path: q_values is packed uint8 and q_scale is the companion block-scale tensor. - kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1, - D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos) - storing the float dequant scale. + kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, D+4] + or [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. Within + each block, the D-byte FP8 values for every token are stored first, + followed by per-token fp32 scale bytes. weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. context_lens: Tensor of shape [B], dtype int32; effective context length for each batch element. @@ -600,6 +727,10 @@ def fp8_fp4_paged_mqa_logits( Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. """ + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_paged_mqa_logits_sm12x( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) _lazy_init() if _fp8_fp4_paged_mqa_logits_impl is None: return _missing() @@ -615,6 +746,20 @@ def fp8_fp4_paged_mqa_logits( ) +def _tf32_hc_prenorm_gemm_sm12x( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks._tf32_hc_prenorm_gemm_sm12x( + x, fn, out, sqrsum, num_split + ) + + def tf32_hc_prenorm_gemm( x: torch.Tensor, fn: torch.Tensor, @@ -629,6 +774,8 @@ def tf32_hc_prenorm_gemm( See the caller function for shape requirement """ + if current_platform.is_device_capability_family(120): + return _tf32_hc_prenorm_gemm_sm12x(x, fn, out, sqrsum, num_split) _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: return _missing() @@ -728,7 +875,9 @@ def should_use_deepgemm_for_fp8_linear( "m_grouped_fp8_fp4_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "fp8_fp4_mqa_logits", + "fp8_fp4_mqa_topk_indices", 
"fp8_fp4_paged_mqa_logits", + "fp8_fp4_paged_mqa_topk_indices", "get_paged_mqa_logits_metadata", "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index c8e6a8419be7..9e8dbfc5f572 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -213,6 +213,21 @@ def has_flashinfer_moe() -> bool: @functools.cache +def has_flashinfer_trtllm_sparse_mla_dsv4() -> bool: + """Return ``True`` if FlashInfer's official SM120 packed sparse-MLA decode + kernel (``trtllm_batch_decode_sparse_mla_dsv4``, PR3395, merged in + flashinfer >= 0.6.13) is available.""" + if not has_flashinfer(): + return False + try: + from flashinfer.mla import ( # noqa: F401 + trtllm_batch_decode_sparse_mla_dsv4, + ) + except ImportError: + return False + return True + + def has_flashinfer_sparse_mla_sm120() -> bool: """Return ``True`` if FlashInfer sparse MLA decode support is available.""" if not has_flashinfer(): diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index 043798a584b8..18c83af71e32 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -539,9 +539,14 @@ def has_fbgemm_gpu() -> bool: return _has_module("fbgemm_gpu") +@cache def has_cutedsl() -> bool: - """Whether the optional `cutelass` package is available.""" - return _has_module("cutlass") + """Whether the optional `cutlass` package is available and importable.""" + try: + import cutlass # noqa: F401 + except Exception: + return False + return True def has_humming() -> bool: diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 0bc7ca7aa414..150ea1d10555 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import dataclass import torch -import vllm.envs as envs +from vllm 
import envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -32,6 +33,26 @@ logger = init_logger(__name__) +def sparse_indexer_max_logits_bytes(is_sm12x: bool | None = None) -> int: + if is_sm12x is None: + is_sm12x = ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + ) + if "VLLM_SPARSE_INDEXER_MAX_LOGITS_MB" in os.environ: + return envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + default_mb = 256 if is_sm12x else 512 + return default_mb * 1024 * 1024 + + +def _uses_deep_gemm_scheduler_metadata() -> bool: + return ( + current_platform.is_cuda() + and has_deep_gemm() + and not current_platform.is_device_capability_family(120) + ) + + @triton.jit def _prepare_uniform_decode_kernel( seq_lens_ptr, @@ -481,12 +502,17 @@ def build( seq_lens = common_attn_metadata.seq_lens slot_mapping = common_attn_metadata.slot_mapping block_table = common_attn_metadata.block_table_tensor + has_prefilling_rows = ( + common_attn_metadata.is_prefilling is not None + and torch.any(common_attn_metadata.is_prefilling).item() + ) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold, require_uniform=not self.use_flattening, + treat_short_extends_as_decodes=not has_prefilling_rows, ) ) @@ -522,7 +548,7 @@ def build( prefill_query_lens_cpu = torch.diff( query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1] ) - max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_bytes = sparse_indexer_max_logits_bytes() # Upper bound is exact for prefill rows (the `[num_decodes:]` # slice below). 
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None @@ -612,8 +638,7 @@ def build( if seq_lens.dim() == 1: seq_lens = seq_lens.unsqueeze(-1) - # DeepGEMM is required for the paged MQA logits on CUDA devices - if current_platform.is_cuda() and has_deep_gemm(): + if _uses_deep_gemm_scheduler_metadata(): self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( seq_lens, self.kv_cache_spec.storage_block_size, diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py new file mode 100644 index 000000000000..92273b55111f --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Environment controls for the portable Triton sparse MLA path.""" + +import os + +import torch + +import vllm.envs as envs +from vllm.platforms import current_platform + + +def _is_sm12x_device(device: torch.device) -> bool: + if not current_platform.is_cuda(): + return False + index = ( + device.index + if device.index is not None + else torch.accelerator.current_device_index() + ) + capability = current_platform.get_device_capability(device_id=index) + return capability is not None and capability[0] == 12 + + +def triton_sparse_mla_configured() -> bool | None: + return envs.VLLM_TRITON_MLA_SPARSE + + +def is_triton_sparse_mla_enabled_for_platform() -> bool: + configured = triton_sparse_mla_configured() + if configured is not None: + return configured + return current_platform.is_device_capability_family(120) + + +def is_triton_sparse_mla_enabled(device: torch.device) -> bool: + configured = triton_sparse_mla_configured() + if configured is not None: + return configured + return _is_sm12x_device(device) + + +def triton_sparse_mla_topk_chunk_size() -> int: + return envs.VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE + + +def triton_sparse_mla_prefill_topk_chunk_size( + *, + 
combined_topk_size: int, + compress_ratio: int, + request_count: int, +) -> int: + """Choose the Triton sparse MLA prefill topk chunk size. + + Keep explicit user overrides authoritative. The auto path uses a larger + chunk for SM12x C128A single-request prefill to reduce per-request loop + overhead, but keeps a smaller chunk for the multi-request shape that is + unstable near 128K context. + """ + + configured_topk = triton_sparse_mla_topk_chunk_size() + if os.getenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE") is not None: + return min(combined_topk_size, configured_topk) + if current_platform.is_device_capability_family(120) and compress_ratio == 128: + if request_count > 1 and combined_topk_size > 1024: + configured_topk = min(configured_topk, 256) + elif request_count == 1 and combined_topk_size > 1024: + configured_topk = max(configured_topk, 1024) + return min(combined_topk_size, configured_topk) + + +def triton_sparse_mla_query_chunk_size() -> int: + return envs.VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE + + +def triton_sparse_mla_head_block_size() -> int | None: + value = envs.VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE + if value in (1, 2, 4): + return value + return None + + +def triton_sparse_mla_matmul_decode_enabled() -> bool: + configured = envs.VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE + if configured is not None: + return configured + return current_platform.is_device_capability_family(120) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py new file mode 100644 index 000000000000..5dfa75da3210 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -0,0 +1,3521 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Portable sparse MLA Triton kernels.""" + +import torch + +from vllm.triton_utils import tl, triton +from vllm.utils.math_utils import next_power_of_2 +from vllm.v1.attention.backends.mla.sparse_mla_env 
import ( + triton_sparse_mla_head_block_size, +) + + +def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: + """Choose the SM12x sparse MLA head grouping for decode kernels. + + Single-token decode is latency sensitive and does best with one head per + program. Once there are enough query tokens, grouping heads lets the kernel + reuse each dequantized KV row across multiple heads. + """ + + configured_head_block_size = triton_sparse_mla_head_block_size() + if configured_head_block_size is not None: + return configured_head_block_size + if num_decode_tokens <= 4: + return 1 + if num_decode_tokens < 16: + return 2 + return 4 + + +@triton.jit +def _merge_two_subsets_with_sink_kernel( + out0_ptr, + lse0_ptr, + out1_ptr, + lse1_ptr, + sink_ptr, + output_ptr, + stride_out0_t: tl.constexpr, + stride_out0_h: tl.constexpr, + stride_out0_d: tl.constexpr, + stride_lse0_t: tl.constexpr, + stride_lse0_h: tl.constexpr, + stride_out1_t: tl.constexpr, + stride_out1_h: tl.constexpr, + stride_out1_d: tl.constexpr, + stride_lse1_t: tl.constexpr, + stride_lse1_h: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + lse0 = tl.load(lse0_ptr + token_idx * stride_lse0_t + head_idx * stride_lse0_h) + lse1 = tl.load(lse1_ptr + token_idx * stride_lse1_t + head_idx * stride_lse1_h) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(tl.maximum(lse0, lse1), sink) + + weight0 = tl.exp(lse0 - merge_max) + weight1 = tl.exp(lse1 - merge_max) + weight_sink = tl.exp(sink - merge_max) + denom = weight0 + weight1 + weight_sink + + out0 = tl.load( + out0_ptr + + token_idx * stride_out0_t + + head_idx * 
stride_out0_h + + offsets * stride_out0_d, + mask=mask, + other=0.0, + ).to(tl.float32) + out1 = tl.load( + out1_ptr + + token_idx * stride_out1_t + + head_idx * stride_out1_h + + offsets * stride_out1_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = (out0 * weight0 + out1 * weight1) / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_two_sparse_mla_subsets_with_sink( + subset0_output: torch.Tensor, + subset0_lse: torch.Tensor, + subset1_output: torch.Tensor, + subset1_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset0_output.shape == subset1_output.shape + assert subset0_output.shape == output.shape + assert subset0_lse.shape == subset1_lse.shape + assert subset0_lse.shape == subset0_output.shape[:2] + assert attn_sink.shape[0] == subset0_output.shape[1] + assert subset0_output.is_cuda + assert subset1_output.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset0_output.shape + block_d = min(128, next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_two_subsets_with_sink_kernel[grid]( + subset0_output, + subset0_lse, + subset1_output, + subset1_lse, + attn_sink, + output, + subset0_output.stride(0), + subset0_output.stride(1), + subset0_output.stride(2), + subset0_lse.stride(0), + subset0_lse.stride(1), + subset1_output.stride(0), + subset1_output.stride(1), + subset1_output.stride(2), + subset1_lse.stride(0), + subset1_lse.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _merge_single_subset_with_sink_kernel( + subset_output_ptr, + subset_lse_ptr, + sink_ptr, + output_ptr, + stride_subset_t: tl.constexpr, + stride_subset_h: tl.constexpr, + stride_subset_d: tl.constexpr, + stride_lse_t: tl.constexpr, + stride_lse_h: 
tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + subset_lse = tl.load( + subset_lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h + ) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(subset_lse, sink) + + subset_weight = tl.exp(subset_lse - merge_max) + sink_weight = tl.exp(sink - merge_max) + denom = subset_weight + sink_weight + subset_output = tl.load( + subset_output_ptr + + token_idx * stride_subset_t + + head_idx * stride_subset_h + + offsets * stride_subset_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = subset_output * subset_weight / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_sparse_mla_subset_with_sink( + subset_output: torch.Tensor, + subset_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset_output.shape == output.shape + assert subset_lse.shape == subset_output.shape[:2] + assert attn_sink.shape[0] == subset_output.shape[1] + assert subset_output.is_cuda + assert subset_lse.is_cuda + assert attn_sink.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset_output.shape + block_d = min(128, next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_single_subset_with_sink_kernel[grid]( + subset_output, + subset_lse, + attn_sink, + output, + subset_output.stride(0), + subset_output.stride(1), + subset_output.stride(2), + subset_lse.stride(0), + subset_lse.stride(1), + output.stride(0), + output.stride(1), + 
output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _build_combined_decode_valid_mask_kernel( + output_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_lens_ptr, + stride_output_t: tl.constexpr, + stride_output_c: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + num_compressed_candidates: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_C: tl.constexpr, +): + token_idx = tl.program_id(0) + offsets = tl.arange(0, BLOCK_C) + candidate_mask = offsets < num_candidates + + topk_lens = tl.load(topk_lens_ptr + token_idx) + swa_lens = tl.load(swa_lens_ptr + token_idx) + is_compressed = offsets < num_compressed_candidates + swa_offsets = offsets - num_compressed_candidates + slot_ids = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + offsets * stride_slot_c, + mask=is_compressed, + other=-1, + ) + valid_compressed = is_compressed & (offsets < topk_lens) & (slot_ids >= 0) + valid_swa = (~is_compressed) & (swa_offsets < swa_lens) + valid = valid_compressed | valid_swa + tl.store( + output_ptr + token_idx * stride_output_t + offsets * stride_output_c, + valid, + mask=candidate_mask, + ) + + +def build_combined_sparse_mla_decode_valid_mask( + output: torch.Tensor, + compressed_slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + swa_lens: torch.Tensor, +) -> None: + """Build `[compressed, SWA]` validity mask for SM12x decode.""" + if compressed_slot_ids.dim() == 3: + assert compressed_slot_ids.shape[1] == 1 + compressed_slot_ids = compressed_slot_ids[:, 0, :] + + assert output.dim() == 2 + assert output.dtype == torch.bool + assert compressed_slot_ids.dim() == 2 + assert output.shape[0] == compressed_slot_ids.shape[0] + assert output.shape[0] == topk_lens.shape[0] + assert output.shape[0] == swa_lens.shape[0] + assert output.shape[1] >= compressed_slot_ids.shape[1] + assert output.is_cuda + assert compressed_slot_ids.is_cuda + assert topk_lens.is_cuda + assert swa_lens.is_cuda + + 
num_candidates = output.shape[1] + block_c = next_power_of_2(num_candidates) + _build_combined_decode_valid_mask_kernel[(output.shape[0],)]( + output, + compressed_slot_ids, + topk_lens, + swa_lens, + output.stride(0), + output.stride(1), + compressed_slot_ids.stride(0), + compressed_slot_ids.stride(1), + compressed_slot_ids.shape[1], + num_candidates, + BLOCK_C=block_c, + num_warps=4, + ) + + +def matmul_sparse_mla_attention_with_sink( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + num_heads: int | None = None, + score_buffer: torch.Tensor | None = None, + head_block_size: int = 1, + value_block_size: int | None = None, + candidate_block_size: int | None = None, +) -> None: + """Compute sink-aware sparse MLA over materialized BF16 KV. + + This path intentionally dequantizes/gathers KV once, computes scores with + batched matrix multiplication, and finishes the sink-aware value reduction + in Triton. It is useful for the SM12x decode path where the direct Triton + kernel otherwise repeats fp8_ds_mla dequantization once per head group. 
+ """ + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert valid_tokens.shape == kv.shape[:2] + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert q.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + q_active = q[:, :active_heads] + num_tokens = q.shape[0] + num_candidates = kv.shape[1] + if score_buffer is None: + score_buffer = torch.empty( + (num_tokens, active_heads, num_candidates), + dtype=torch.float32, + device=q.device, + ) + assert score_buffer.shape == (num_tokens, active_heads, num_candidates) + assert score_buffer.device == q.device + assert score_buffer.dtype in (torch.float32, torch.bfloat16) + if score_buffer.dtype == torch.float32: + q_score = q_active.float() + kv_score = kv.float() + else: + q_score = q_active.to(score_buffer.dtype) + kv_score = kv.to(score_buffer.dtype) + torch.bmm(q_score, kv_score.transpose(1, 2), out=score_buffer) + score_buffer.mul_(scale) + finish_materialized_sparse_mla_scores_with_sink( + score_buffer, + kv, + valid_tokens, + attn_sink, + output, + num_heads=active_heads, + head_block_size=head_block_size, + value_block_size=value_block_size, + candidate_block_size=candidate_block_size, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + 
stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + running_max = tl.load(attn_sink_ptr + head_offsets, mask=head_mask, other=0.0).to( + tl.float32 + ) + running_denom = tl.full((HEAD_BLOCK,), 1.0, tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + for candidate_idx in range(0, num_candidates): + is_valid = tl.load( + valid_tokens_ptr + + token_idx * stride_valid_t + + candidate_idx * stride_valid_c + ) + if is_valid: + score = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_offsets * stride_scores_h + + candidate_idx * stride_scores_c, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + result = running_acc / running_denom[:, None] + tl.store( + output_ptr + + token_idx * stride_out_t + + head_offsets[:, None] * stride_out_h + + dim_offsets[None, :] * stride_out_d, + result, + mask=matrix_mask, + ) + + +@triton.jit +def 
_finish_materialized_scores_with_sink_candidate_block_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + dim_block_idx = tl.program_id(2) + candidate_offsets = tl.arange(0, BLOCK_K) + dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = dim_offsets < head_dim + + max_score = tl.load(attn_sink_ptr + head_idx).to(tl.float32) + for candidate_start in range(0, num_candidates, BLOCK_K): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < num_candidates + is_valid = tl.load( + valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c, + mask=candidate_mask, + other=0, + ).to(tl.int1) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidates * stride_scores_c, + mask=candidate_mask & is_valid, + other=-float("inf"), + ).to(tl.float32) + max_score = tl.maximum(max_score, tl.max(scores, axis=0)) + + denom = tl.exp(tl.load(attn_sink_ptr + head_idx).to(tl.float32) - max_score) + acc = tl.zeros((BLOCK_D,), tl.float32) + for candidate_start in range(0, num_candidates, BLOCK_K): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < num_candidates + is_valid = tl.load( + valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c, + mask=candidate_mask, + other=0, + ).to(tl.int1) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + 
candidates * stride_scores_c, + mask=candidate_mask & is_valid, + other=-float("inf"), + ).to(tl.float32) + weights = tl.exp(scores - max_score) + denom += tl.sum(weights, axis=0) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidates[:, None] * stride_kv_c + + dim_offsets[None, :] * stride_kv_d, + mask=(candidate_mask & is_valid)[:, None] & dim_mask[None, :], + other=0.0, + ) + acc += tl.sum(kv.to(tl.float32) * weights[:, None], axis=0) + + tl.store( + output_ptr + + token_idx * stride_out_t + + head_idx * stride_out_h + + dim_offsets * stride_out_d, + acc / denom, + mask=dim_mask, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_value_block_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + dim_block_idx = tl.program_id(2) + dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = dim_offsets < head_dim + + running_max = tl.load(attn_sink_ptr + head_idx).to(tl.float32) + running_denom = tl.full((), 1.0, tl.float32) + running_acc = tl.zeros((BLOCK_D,), tl.float32) + + for candidate_idx in range(0, num_candidates): + is_valid = tl.load( + valid_tokens_ptr + + token_idx * stride_valid_t + + candidate_idx * stride_valid_c + ) + if is_valid: + score = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidate_idx * stride_scores_c + ).to(tl.float32) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + dim_offsets * stride_kv_d, + 
mask=dim_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + result = running_acc / running_denom + tl.store( + output_ptr + + token_idx * stride_out_t + + head_idx * stride_out_h + + dim_offsets * stride_out_d, + result, + mask=dim_mask, + ) + + +def finish_materialized_sparse_mla_scores_with_sink( + scores: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, + num_heads: int | None = None, + head_block_size: int = 1, + value_block_size: int | None = None, + candidate_block_size: int | None = None, +) -> None: + assert scores.dim() == 3 + assert kv.dim() == 3 + assert valid_tokens.shape == kv.shape[:2] + assert scores.shape[0] == kv.shape[0] + assert scores.shape[2] == kv.shape[1] + assert output.shape[0] == kv.shape[0] + assert output.shape[2] == kv.shape[2] + assert scores.dtype in (torch.float32, torch.bfloat16) + assert head_block_size in (1, 2, 4) + if value_block_size is not None: + assert value_block_size in (64, 128, 256, 512) + if candidate_block_size is not None: + assert candidate_block_size in (16, 32, 64, 128) + assert scores.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= scores.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + num_tokens, _, num_candidates = scores.shape + head_dim = kv.shape[2] + + # Clamp BLOCK_D to the smallest power-of-2 >= head_dim (among the allowed + # {64, 128, 256, 512}). 
Without this, a caller-supplied value_block_size + # larger than head_dim wastes work on masked-off positions — e.g. DSv4 + # head_dim=192 with value_block_size=512 masks off 62.5% of D-axis work + # in every program. Smaller caller values (intentional fine-grained + # splits along D) are respected. + def _smallest_block_d_covering(hd: int) -> int: + for cand in (64, 128, 256, 512): + if cand >= hd: + return cand + return 512 # head_dim > 512: BLOCK_D=512, grid splits along D + + if candidate_block_size is not None: + target_block_d = _smallest_block_d_covering(head_dim) + if value_block_size is None: + block_d = target_block_d + else: + block_d = min(value_block_size, target_block_d) + candidate_grid = (num_tokens, active_heads, triton.cdiv(head_dim, block_d)) + _finish_materialized_scores_with_sink_candidate_block_kernel[candidate_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + head_dim, + num_candidates, + BLOCK_K=candidate_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + return + + if value_block_size is not None and value_block_size < head_dim: + value_grid = ( + num_tokens, + active_heads, + triton.cdiv(head_dim, value_block_size), + ) + _finish_materialized_scores_with_sink_value_block_kernel[value_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + head_dim, + num_candidates, + BLOCK_D=value_block_size, + num_warps=4, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + return + + block_d = min(1024, 
next_power_of_2(head_dim)) + head_grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _finish_materialized_scores_with_sink_kernel[head_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + active_heads, + head_dim, + num_candidates, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + + +@triton.jit +def _accumulate_gathered_attention_chunk_kernel( + q_ptr, + kv_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HAS_SLOT_IDS: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = 
tl.load(lens_ptr + token_idx) + # Per-token early-loop-exit (see indexed kernel comment). + local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) + + # ``candidate_offset + candidate_idx < valid_len`` is structurally + # guaranteed by the ``local_eff`` cap above, so the only remaining + # gate is ``slot_id >= 0`` when slot ids are present. + for candidate_idx in range(0, local_eff): + if HAS_SLOT_IDS: + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = slot_id >= 0 + else: + is_valid = True + + if is_valid: + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_gathered_sparse_mla_attention_chunk( + q: torch.Tensor, + kv: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + slot_ids: torch.Tensor | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == 
max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + if slot_ids is not None: + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + assert slot_ids.dim() == 2 + assert slot_ids.shape == kv.shape[:2] + assert slot_ids.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = kv.shape[1] + block_d = min(1024, next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_gathered_attention_chunk_kernel[grid]( + q, + kv, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + slot_ids.stride(0) if slot_ids is not None else 0, + slot_ids.stride(1) if slot_ids is not None else 0, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HAS_SLOT_IDS=slot_ids is not None, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=4, num_stages=2), + triton.Config({}, num_warps=4, num_stages=3), + triton.Config({}, num_warps=8, num_stages=2), + triton.Config({}, num_warps=8, num_stages=3), + ], + key=["num_candidates"], +) +@triton.jit +def _accumulate_indexed_attention_chunk_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: 
tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + # Per-token early-loop-exit: the combine_topk_swa_indices kernel writes + # ``[topk_len_t | swa_len_t | -1 padding]`` and stores + # ``lens[t] = topk_len_t + swa_len_t``. The existing ``is_valid`` guard + # already gates the heavy work past ``valid_len``, but the outer loop + # still iterates the full ``num_candidates`` (= chunk width). Capping + # the loop at ``min(num_candidates, valid_len - candidate_offset)`` + # saves the per-iteration index load + compare overhead on the dead + # tail. CUDA-graph-safe because ``lens_ptr`` is a stable address and + # the loaded value updates per call from the metadata builder. + local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) + + # ``candidate_offset + candidate_idx < valid_len`` is structurally + # guaranteed by the ``local_eff`` cap above; only the per-cell + # sentinel check (``kv_index >= 0``) is still meaningful. 
+ for candidate_idx in range(0, local_eff): + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = kv_index >= 0 + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +_PREFILL_INDEXED_HEAD_BLOCK = 8 + + +@triton.jit +def _accumulate_indexed_attention_chunk_multihead_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=head_mask[:, None] & 
dim_mask[None, :], + other=0.0, + ).to(tl.float32) + + state_base = token_idx * stride_state_t + running_max = tl.load( + max_score_ptr + state_base + head_offsets * stride_state_h, + mask=head_mask, + other=float("-inf"), + ) + running_denom = tl.load( + denom_ptr + state_base + head_offsets * stride_state_h, + mask=head_mask, + other=0.0, + ) + acc_base = token_idx * stride_acc_t + running_acc = tl.load( + acc_ptr + + acc_base + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(tl.float32) + + valid_len = tl.load(lens_ptr + token_idx) + local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) + + for candidate_idx in range(0, local_eff): + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = kv_index >= 0 + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + scores = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, scores) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(scores - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store( + max_score_ptr + state_base + head_offsets * stride_state_h, + running_max, + mask=head_mask, + ) + tl.store( + denom_ptr + state_base + head_offsets * stride_state_h, + running_denom, + mask=head_mask, + ) + tl.store( + acc_ptr + + acc_base + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + running_acc, + mask=head_mask[:, None] & dim_mask[None, :], + ) + + +def accumulate_indexed_sparse_mla_attention_chunk( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: 
torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = indices.shape[1] + block_d = min(1024, next_power_of_2(head_dim)) + head_block = _PREFILL_INDEXED_HEAD_BLOCK + + if num_heads >= head_block: + grid = (num_tokens, triton.cdiv(num_heads, head_block)) + _accumulate_indexed_attention_chunk_multihead_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block, + BLOCK_D=block_d, + num_warps=4, + num_stages=2, + ) + else: + grid = (num_tokens, num_heads) + _accumulate_indexed_attention_chunk_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), 
+ max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + ) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=4, num_stages=2), + triton.Config({}, num_warps=4, num_stages=3), + triton.Config({}, num_warps=8, num_stages=2), + triton.Config({}, num_warps=8, num_stages=3), + ], + key=["num_candidates"], +) +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + # Per-token early-loop-exit (see indexed kernel comment). 
+ local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) + + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) + + # ``candidate_offset + candidate_idx < valid_len`` is structurally + # guaranteed by the ``local_eff`` cap above; only the per-cell + # sentinel check (``slot_id >= 0``) is still meaningful. + for candidate_idx in range(0, local_eff): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = slot_id >= 0 + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + 
tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_fp8ds_global_slots_attention_chunk_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + 
num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + # num_warps / num_stages supplied by @triton.autotune above. + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + 
token_idx) + # Per-token early-loop-exit: ``lens[t] = topk_len_t + swa_len_t`` (set + # by combine_topk_swa_indices). Iterating past ``valid_len`` only + # incurs the per-iter index-load + compare cost on padding-tail; cap + # the outer loop at ``valid_len - candidate_offset`` to skip the dead + # tail. CUDA-graph-safe because ``lens_ptr`` is a stable address. + local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + # ``candidate_offset + candidate_idx < valid_len`` is structurally + # guaranteed by the ``local_eff`` cap above; only the per-cell + # sentinel check (``slot_id >= 0``) is still meaningful. + for candidate_idx in range(0, local_eff): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = slot_id >= 0 + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], 
axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + 
num_candidates = slot_ids.shape[1] + block_d = min(1024, next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _indexed_d512_split_score_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + scores_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + num_heads: tl.constexpr, + num_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_C: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + candidate_block = tl.program_id(2) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + candidate_offsets = candidate_block * BLOCK_C + tl.arange(0, BLOCK_C) + dim_offsets = tl.arange(0, HEAD_DIM) + head_mask = head_offsets < num_heads + valid_len = tl.load(lens_ptr + token_idx) + if candidate_block * BLOCK_C >= tl.minimum(valid_len, num_candidates): + return + candidate_mask = candidate_offsets < tl.minimum(valid_len, num_candidates) + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=head_mask[:, None], + 
other=0.0, + ) + kv_indices = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_offsets * stride_indices_c, + mask=candidate_mask, + other=-1, + ) + valid_kv = kv_indices >= 0 + kv = tl.load( + kv_flat_ptr + + kv_indices[None, :].to(tl.int64) * stride_kv_t + + dim_offsets[:, None] * stride_kv_d, + mask=valid_kv[None, :], + other=0.0, + ) + scores = tl.dot(q, kv) * scale + tl.store( + scores_ptr + + token_idx * stride_scores_t + + head_offsets[:, None] * stride_scores_h + + candidate_offsets[None, :] * stride_scores_c, + scores, + mask=head_mask[:, None] & candidate_mask[None, :], + ) + + +@triton.jit +def _indexed_d512_split_stats_kernel( + scores_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_C: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + candidate_offsets = tl.arange(0, BLOCK_C) + valid_len = tl.load(lens_ptr + token_idx) + candidate_mask = candidate_offsets < tl.minimum(valid_len, num_candidates) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidate_offsets * stride_scores_c, + mask=candidate_mask, + other=-float("inf"), + ).to(tl.float32) + running_max = tl.max(scores, axis=0) + safe_max = tl.where(valid_len > 0, running_max, 0.0) + weights = tl.where(candidate_mask, tl.exp(scores - safe_max), 0.0) + running_denom = tl.sum(weights, axis=0) + + tl.store( + max_score_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + running_max, + ) + tl.store( + denom_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + running_denom, + ) + + +@triton.jit +def _indexed_d512_split_value_kernel( + scores_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + acc_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + 
stride_scores_c: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + num_candidates: tl.constexpr, + head_dim: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + dim_block = tl.program_id(2) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + candidate_offsets = tl.arange(0, BLOCK_C) + dim_offsets = dim_block * BLOCK_D + tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + valid_len = tl.load(lens_ptr + token_idx) + max_score = tl.load( + max_score_ptr + token_idx * stride_state_t + head_offsets * stride_state_h, + mask=head_mask, + other=0.0, + ).to(tl.float32) + safe_max = tl.where(valid_len > 0, max_score, 0.0) + acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + for candidate_start in range(0, num_candidates, BLOCK_C): + if candidate_start < tl.minimum(valid_len, num_candidates): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < tl.minimum(valid_len, num_candidates) + kv_indices = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidates * stride_indices_c, + mask=candidate_mask, + other=-1, + ) + valid_kv = kv_indices >= 0 + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_offsets[:, None] * stride_scores_h + + candidates[None, :] * stride_scores_c, + mask=head_mask[:, None] & candidate_mask[None, :], + other=-float("inf"), + ).to(tl.float32) + weights = tl.where( + candidate_mask[None, :], + tl.exp(scores - safe_max[:, None]), + 0.0, + ) + values = tl.load( + kv_flat_ptr + + kv_indices[:, None].to(tl.int64) * stride_kv_t + + dim_offsets[None, :] * 
stride_kv_d, + mask=valid_kv[:, None] & dim_mask[None, :], + other=0.0, + ) + acc += tl.dot(weights.to(tl.bfloat16), values) + + tl.store( + acc_ptr + + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + acc, + mask=head_mask[:, None] & dim_mask[None, :], + ) + + +def accumulate_indexed_d512_split_sparse_mla_attention( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + scores: torch.Tensor, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + head_block_size: int = 32, + candidate_block_size: int = 64, + value_block_size: int = 128, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert q.shape[-1] == 512 + assert scores.shape == (q.shape[0], max_score.shape[1], indices.shape[1]) + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert scores.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert scores.is_cuda and max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert head_block_size in (8, 16, 32) + assert candidate_block_size in (32, 64, 128) + assert value_block_size in (32, 64, 128) + assert indices.shape[1] <= 1152 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = indices.shape[1] + score_grid = ( + num_tokens, + triton.cdiv(num_heads, head_block_size), + triton.cdiv(num_candidates, candidate_block_size), + ) + _indexed_d512_split_score_kernel[score_grid]( + q, + kv_flat, + indices, + lens, + scores, + 
q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + scores.stride(0), + scores.stride(1), + scores.stride(2), + num_heads, + num_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_C=candidate_block_size, + HEAD_DIM=head_dim, + num_warps=8, + num_stages=3, + ) + + stats_grid = (num_tokens, num_heads) + stats_block_c = next_power_of_2(num_candidates) + _indexed_d512_split_stats_kernel[stats_grid]( + scores, + lens, + max_score, + denom, + scores.stride(0), + scores.stride(1), + scores.stride(2), + max_score.stride(0), + max_score.stride(1), + num_candidates, + BLOCK_C=stats_block_c, + num_warps=4, + num_stages=3, + ) + + value_grid = ( + num_tokens, + triton.cdiv(num_heads, head_block_size), + triton.cdiv(head_dim, value_block_size), + ) + _indexed_d512_split_value_kernel[value_grid]( + scores, + kv_flat, + indices, + lens, + max_score, + acc, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + num_candidates, + head_dim, + HEAD_BLOCK=head_block_size, + BLOCK_C=candidate_block_size, + BLOCK_D=value_block_size, + num_warps=4, + num_stages=3, + ) + + +@triton.jit +def _indexed_d512_chunked_merge_acc_kernel( + max_score_ptr, + acc_ptr, + chunk_max_score_ptr, + chunk_acc_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + dim_block = tl.program_id(2) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = dim_block * BLOCK_D + tl.arange(0, BLOCK_D) + head_mask = head_offsets 
< num_heads + dim_mask = dim_offsets < head_dim + + running_max = tl.load( + max_score_ptr + token_idx * stride_state_t + head_offsets * stride_state_h, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + chunk_max = tl.load( + chunk_max_score_ptr + + token_idx * stride_state_t + + head_offsets * stride_state_h, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + next_max = tl.maximum(running_max, chunk_max) + running_valid = running_max != -float("inf") + chunk_valid = chunk_max != -float("inf") + running_scale = tl.where(running_valid, tl.exp(running_max - next_max), 0.0) + chunk_scale = tl.where(chunk_valid, tl.exp(chunk_max - next_max), 0.0) + + running_acc = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(tl.float32) + chunk_acc = tl.load( + chunk_acc_ptr + + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(tl.float32) + merged_acc = running_acc * running_scale[:, None] + chunk_acc * chunk_scale[:, None] + tl.store( + acc_ptr + + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + merged_acc, + mask=head_mask[:, None] & dim_mask[None, :], + ) + + +@triton.jit +def _indexed_d512_chunked_merge_state_kernel( + max_score_ptr, + denom_ptr, + chunk_max_score_ptr, + chunk_denom_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + num_heads: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + head_mask = head_idx < num_heads + + running_max = tl.load( + max_score_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + running_denom = tl.load( + denom_ptr + token_idx * stride_state_t + head_idx * stride_state_h, 
+ mask=head_mask, + other=0.0, + ).to(tl.float32) + chunk_max = tl.load( + chunk_max_score_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + chunk_denom = tl.load( + chunk_denom_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + mask=head_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, chunk_max) + running_valid = running_max != -float("inf") + chunk_valid = chunk_max != -float("inf") + running_scale = tl.where(running_valid, tl.exp(running_max - next_max), 0.0) + chunk_scale = tl.where(chunk_valid, tl.exp(chunk_max - next_max), 0.0) + next_denom = running_denom * running_scale + chunk_denom * chunk_scale + + tl.store( + max_score_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + next_max, + mask=head_mask, + ) + tl.store( + denom_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + next_denom, + mask=head_mask, + ) + + +def accumulate_indexed_d512_chunked_sparse_mla_attention( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + scores: torch.Tensor, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + chunk_max_score: torch.Tensor, + chunk_denom: torch.Tensor, + chunk_acc: torch.Tensor, + head_block_size: int = 32, + candidate_block_size: int = 64, + value_block_size: int = 128, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert q.shape[-1] == 512 + assert scores.shape[0] == q.shape[0] + assert scores.shape[1] == max_score.shape[1] + assert 0 < scores.shape[2] <= 1152 + assert max_score.shape == denom.shape == chunk_max_score.shape == chunk_denom.shape + assert acc.shape == 
chunk_acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert chunk_max_score.dtype == torch.float32 + assert chunk_denom.dtype == torch.float32 + assert chunk_acc.dtype == torch.float32 + assert scores.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert scores.is_cuda and max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert chunk_max_score.is_cuda and chunk_denom.is_cuda and chunk_acc.is_cuda + assert head_block_size in (8, 16, 32) + assert candidate_block_size in (32, 64, 128) + assert value_block_size in (32, 64, 128) + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + chunk_size = scores.shape[2] + max_score.fill_(float("-inf")) + denom.zero_() + acc.zero_() + + merge_acc_grid = ( + num_tokens, + triton.cdiv(num_heads, head_block_size), + triton.cdiv(head_dim, value_block_size), + ) + merge_state_grid = (num_tokens, num_heads) + for candidate_start in range(0, indices.shape[1], chunk_size): + candidate_end = min(candidate_start + chunk_size, indices.shape[1]) + chunk_candidates = candidate_end - candidate_start + chunk_lens = torch.clamp( + lens - candidate_start, + min=0, + max=chunk_candidates, + ) + accumulate_indexed_d512_split_sparse_mla_attention( + q=q, + kv_flat=kv_flat, + indices=indices[:, candidate_start:candidate_end], + lens=chunk_lens, + scale=scale, + scores=scores[:, :, :chunk_candidates], + max_score=chunk_max_score, + denom=chunk_denom, + acc=chunk_acc, + head_block_size=head_block_size, + candidate_block_size=candidate_block_size, + value_block_size=value_block_size, + ) + _indexed_d512_chunked_merge_acc_kernel[merge_acc_grid]( + max_score, + acc, + chunk_max_score, + chunk_acc, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + HEAD_BLOCK=head_block_size, + 
BLOCK_D=value_block_size, + num_warps=4, + num_stages=3, + ) + _indexed_d512_chunked_merge_state_kernel[merge_state_grid]( + max_score, + denom, + chunk_max_score, + chunk_denom, + max_score.stride(0), + max_score.stride(1), + num_heads, + num_warps=4, + num_stages=3, + ) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=4, num_stages=2), + triton.Config({}, num_warps=4, num_stages=3), + triton.Config({}, num_warps=8, num_stages=2), + triton.Config({}, num_warps=8, num_stages=3), + ], + key=["num_candidates"], +) +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + + seq_len = tl.load(seq_lens_ptr + token_idx) + 
gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) + # Per-token early-loop-exit (see indexed kernel comment). + local_eff = tl.minimum( + num_candidates, + tl.maximum(gather_len - candidate_offset, 0), + ) + + # ``gather_idx < gather_len`` is structurally guaranteed by the + # ``local_eff`` cap above; the body is unconditional. + for candidate_idx in range(0, local_eff): + gather_idx = candidate_offset + candidate_idx + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + 
candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int, + num_candidates: int, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_fp8ds_paged_attention_chunk_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, 
+ k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + # num_warps / num_stages supplied by @triton.autotune above. + ) + + +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, 
other=0.0).to( + tl.float32 + ) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + # Per-token early-loop-exit: ``gather_len`` is the per-token count of + # cached entries available for this paged read; the existing + # ``is_valid`` guard skips heavy work past that, but we can also skip + # the per-iter index load + branch by capping the loop. CUDA-graph- + # safe because ``gather_lens_ptr`` is a stable address. + local_eff = tl.minimum( + num_candidates, + tl.maximum(gather_len - candidate_offset, 0), + ) + + # ``gather_idx < gather_len`` is structurally guaranteed by the + # ``local_eff`` cap above; the body is unconditional. + for candidate_idx in range(0, local_eff): + gather_idx = candidate_offset + candidate_idx + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = 
tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int, + num_candidates: int, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + 
token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _fp8ds_paged_attention_with_sink_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + sink_ptr, + output_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + candidate_offset: tl.constexpr, + num_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * 
stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + # Per-token early-loop-exit: ``gather_len`` is the per-token count of + # cached entries available for this paged read; the existing + # ``is_valid`` guard skips heavy work past that, but we can also skip + # the per-iter index load + branch by capping the loop. CUDA-graph- + # safe because ``gather_lens_ptr`` is a stable address. + local_eff = tl.minimum( + num_candidates, + tl.maximum(gather_len - candidate_offset, 0), + ) + + # ``gather_idx < gather_len`` is structurally guaranteed by the + # ``local_eff`` cap above; the body is unconditional. 
+ for candidate_idx in range(0, local_eff): + gather_idx = candidate_offset + candidate_idx + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 
0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) + + +def fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + candidate_offset: int, + num_candidates: int, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert head_block_size in (1, 2, 4) + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] 
+ assert active_heads <= attn_sink.shape[0] + block_d = min(1024, next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_paged_attention_with_sink_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + attn_sink, + output, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + output.stride(0), + output.stride(1), + output.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + active_heads, + head_dim, + candidate_offset, + num_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _fp8ds_global_paged_attention_with_sink_multihead_kernel( + q_ptr, + compressed_k_cache_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + sink_ptr, + output_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_block_table_t, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + compressed_cache_block_size: tl.constexpr, + compressed_block_stride: tl.constexpr, + swa_cache_block_size: tl.constexpr, + swa_block_stride: tl.constexpr, + token_data_size: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_compressed_candidates: tl.constexpr, + num_swa_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = 
tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + topk_len = tl.load(topk_lens_ptr + token_idx) + + for candidate_idx in range(0, num_compressed_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = (candidate_idx < topk_len) & (slot_id >= 0) + if is_valid: + block_idx = slot_id // compressed_cache_block_size + pos_in_block = slot_id % compressed_cache_block_size + cache_block_ptr = ( + compressed_k_cache_ptr + + block_idx.to(tl.int64) * compressed_block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + compressed_cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = 
tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + for candidate_idx in range(0, num_swa_candidates): + is_valid = candidate_idx < gather_len + if is_valid: + pos = start_pos + candidate_idx + block_in_seq = pos // swa_cache_block_size + pos_in_block = pos % swa_cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = ( + swa_k_cache_ptr + physical_block.to(tl.int64) * swa_block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + swa_cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = 
next_max + + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) + + +def fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + compressed_block_size: int, + swa_k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + swa_block_size: int, + num_compressed_candidates: int, + num_swa_candidates: int, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert topk_lens.shape[0] == q.shape[0] + assert seq_lens.shape[0] 
== q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert head_block_size in (1, 2, 4) + assert compressed_k_cache.dtype == torch.uint8 + assert swa_k_cache.dtype == torch.uint8 + assert q.is_cuda and compressed_k_cache.is_cuda and swa_k_cache.is_cuda + assert slot_ids.is_cuda and topk_lens.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + block_d = min(1024, next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_global_paged_attention_with_sink_multihead_kernel[grid]( + q, + compressed_k_cache, + slot_ids, + topk_lens, + swa_k_cache, + seq_lens, + gather_lens, + block_table, + attn_sink, + output, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + block_table.stride(0), + output.stride(0), + output.stride(1), + output.stride(2), + compressed_block_size, + compressed_k_cache.stride(0), + swa_block_size, + swa_k_cache.stride(0), + token_data_size, + token_fp8_dim, + token_scale_dim, + quant_block_size, + active_heads, + head_dim, + num_compressed_candidates, + num_swa_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _finish_attention_state_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + output_ptr, + lse_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + 
stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + stride_lse_t: tl.constexpr, + stride_lse_h: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + is_valid = running_denom > 0.0 + inv_denom = tl.where(is_valid, 1.0 / running_denom, 0.0) + subset_lse = tl.where( + is_valid, + running_max + tl.log(running_denom), + -float("inf"), + ) + + acc = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + subset_output = acc * inv_denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + subset_output, + mask=dim_mask, + ) + if block_d == 0: + tl.store( + lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h, + subset_lse, + ) + + +def finish_gathered_sparse_mla_attention( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + output: torch.Tensor, + lse: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape == acc.shape + assert lse.shape == max_score.shape + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert output.dtype == torch.float32 + assert lse.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert output.is_cuda and lse.is_cuda + + num_tokens, num_heads, head_dim = 
acc.shape + block_d = min(128, next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_kernel[grid]( + max_score, + denom, + acc, + output, + lse, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + lse.stride(0), + lse.stride(1), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _finish_attention_state_with_sink_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + sink_ptr, + output_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + sink = tl.load(sink_ptr + head_idx) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + subset_weight = running_denom * subset_scale + sink_weight = tl.where(has_sink, 
tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = subset_weight + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc_values = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc_values = tl.where(has_tokens, acc_values, 0.0) + output = acc_values * subset_scale * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +@triton.jit +def _finish_two_attention_states_with_sink_kernel( + max0_ptr, + denom0_ptr, + acc0_ptr, + max1_ptr, + denom1_ptr, + acc1_ptr, + sink_ptr, + output_ptr, + stride_state0_t: tl.constexpr, + stride_state0_h: tl.constexpr, + stride_acc0_t: tl.constexpr, + stride_acc0_h: tl.constexpr, + stride_acc0_d: tl.constexpr, + stride_state1_t: tl.constexpr, + stride_state1_h: tl.constexpr, + stride_acc1_t: tl.constexpr, + stride_acc1_h: tl.constexpr, + stride_acc1_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state0_offset = token_idx * stride_state0_t + head_idx * stride_state0_h + state1_offset = token_idx * stride_state1_t + head_idx * stride_state1_h + max0 = tl.load(max0_ptr + state0_offset) + denom0 = tl.load(denom0_ptr + state0_offset) + max1 = tl.load(max1_ptr + state1_offset) + denom1 = tl.load(denom1_ptr + state1_offset) + sink = tl.load(sink_ptr + head_idx) + + has0 = denom0 > 0.0 + has1 = denom1 > 0.0 + has_sink = sink > -float("inf") + valid_max0 = tl.where(has0, max0, -float("inf")) + valid_max1 = 
tl.where(has1, max1, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(tl.maximum(valid_max0, valid_max1), valid_sink) + has_any = has0 | has1 | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_max0 = tl.where(has0, max0, safe_merge_max) + safe_max1 = tl.where(has1, max1, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + scale0 = tl.where(has0, tl.exp(safe_max0 - safe_merge_max), 0.0) + scale1 = tl.where(has1, tl.exp(safe_max1 - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = denom0 * scale0 + denom1 * scale1 + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc0 = tl.load( + acc0_ptr + + token_idx * stride_acc0_t + + head_idx * stride_acc0_h + + offsets * stride_acc0_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc1 = tl.load( + acc1_ptr + + token_idx * stride_acc1_t + + head_idx * stride_acc1_h + + offsets * stride_acc1_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc0 = tl.where(has0, acc0, 0.0) + acc1 = tl.where(has1, acc1, 0.0) + output = (acc0 * scale0 + acc1 * scale1) * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +def finish_two_sparse_mla_attention_states_with_sink( + max_score0: torch.Tensor, + denom0: torch.Tensor, + acc0: torch.Tensor, + max_score1: torch.Tensor, + denom1: torch.Tensor, + acc1: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score0.shape == denom0.shape + assert max_score1.shape == denom1.shape + assert max_score0.shape == max_score1.shape + assert acc0.shape == acc1.shape + assert acc0.shape[:2] == max_score0.shape + assert output.shape[0] == acc0.shape[0] + assert output.shape[1] >= acc0.shape[1] + assert output.shape[2] == acc0.shape[2] + assert attn_sink.shape[0] >= 
acc0.shape[1] + assert max_score0.dtype == torch.float32 + assert denom0.dtype == torch.float32 + assert acc0.dtype == torch.float32 + assert max_score1.dtype == torch.float32 + assert denom1.dtype == torch.float32 + assert acc1.dtype == torch.float32 + assert max_score0.is_cuda and denom0.is_cuda and acc0.is_cuda + assert max_score1.is_cuda and denom1.is_cuda and acc1.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc0.shape + block_d = min(128, next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_two_attention_states_with_sink_kernel[grid]( + max_score0, + denom0, + acc0, + max_score1, + denom1, + acc1, + attn_sink, + output, + max_score0.stride(0), + max_score0.stride(1), + acc0.stride(0), + acc0.stride(1), + acc0.stride(2), + max_score1.stride(0), + max_score1.stride(1), + acc1.stride(0), + acc1.stride(1), + acc1.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +def finish_sparse_mla_attention_with_sink( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape[0] == acc.shape[0] + assert output.shape[1] >= acc.shape[1] + assert output.shape[2] == acc.shape[2] + assert attn_sink.shape[0] >= acc.shape[1] + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_with_sink_kernel[grid]( + max_score, + denom, + acc, + attn_sink, + output, + max_score.stride(0), 
+ max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index df23f34378e2..0e54a27ba30a 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -5,6 +5,7 @@ import torch +import vllm.envs as envs from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.platforms import current_platform @@ -17,6 +18,9 @@ CommonAttentionMetadata, MultipleOf, ) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled, +) from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata from vllm.v1.kv_cache_interface import ( @@ -178,8 +182,10 @@ class DeepseekSparseSWAMetadata: # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. prefill_seq_lens: torch.Tensor | None = None - prefill_seq_lens_cpu: torch.Tensor | None = None prefill_gather_lens: torch.Tensor | None = None + prefill_seq_lens_cpu: torch.Tensor | None = None + prefill_gather_lens_cpu: torch.Tensor | None = None + # Inputs to the adaptive prefill chunk planner (#45061). prefill_query_lens_cpu: torch.Tensor | None = None prefill_window_size: int = 0 prefill_max_model_len: int = 0 @@ -300,27 +306,20 @@ def __init__(self, *args, **kwargs): self.head_size = mla_spec.head_size # Already considered quantization. 
self.compress_ratio = mla_spec.compress_ratio self.block_size = mla_spec.block_size - self.max_model_len = self.vllm_config.model_config.max_model_len - self.max_num_batched_tokens = ( - self.vllm_config.scheduler_config.max_num_batched_tokens - ) - # Handle MTP: adjust decode_threshold like the indexer does - self.num_speculative_tokens = ( - self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config - else 0 - ) - # With MTP, decode can have query_len up to 1 + num_speculative_tokens. - # Must match the threshold used by the indexer and flashmla_sparse so - # that all backends agree on the decode/prefill split. - self.decode_threshold = ( - self.reorder_batch_threshold + self.num_speculative_tokens - ) + # Handle MTP: classify single-token queries plus speculative tokens as + # decodes, matching the runner-side batch reorder threshold. + self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) + assert self.reorder_batch_threshold is not None + self.decode_threshold = self.reorder_batch_threshold hf_config = self.vllm_config.model_config.hf_config assert hasattr(hf_config, "sliding_window") self.window_size = hf_config.sliding_window + self.max_model_len = self.vllm_config.model_config.max_model_len + self.max_num_batched_tokens = ( + self.vllm_config.scheduler_config.max_num_batched_tokens + ) # Detect which DeepseekV4 layer types this model uses so we only build a # FlashMLA tile-scheduler plan for types that will actually be called. 
@@ -384,7 +383,6 @@ def build( """ num_reqs = common_attn_metadata.num_reqs seq_lens = common_attn_metadata.seq_lens - seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu block_table = common_attn_metadata.block_table_tensor @@ -407,6 +405,7 @@ def build( is_valid_token = self.is_valid_token[: slot_mapping.shape[0]] is_valid_token.copy_(slot_mapping >= 0) + # Compute SWA window indices/lens for the decode rows once per step. if num_decode_tokens > 0: self.decode_swa_lens[num_decode_tokens:] = 0 _compute_swa_indices_and_lens_kernel[(num_decode_tokens,)]( @@ -425,10 +424,18 @@ def build( TRITON_BLOCK_SIZE=1024, ) - # Prefill SWA indices live in paged coordinates. `token_offset` lets - # the kernel read is_valid_token / token_to_req_indices at absolute - # prefill positions while writing output starting at index 0. - if num_prefill_tokens > 0: + # Prefill SWA indices (paged coords; `token_offset` lets the kernel read at + # absolute prefill positions while writing from index 0) are consumed ONLY by + # the FlashInfer SM120 sparse-MLA fork path. The stock FlashMLA/Triton prefill + # self-computes and never reads them, so gate the launch behind + # VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL (default off). Running it + # unconditionally faulted `_compute_swa_indices_and_lens_kernel` over 32k + # prefill rows (unclamped block_table address arithmetic on masked-off lanes + # -> cudaErrorLaunchFailure under concurrent load). 
+ want_prefill_swa = ( + num_prefill_tokens > 0 and envs.VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL + ) + if want_prefill_swa: prefill_swa_indices = self.prefill_swa_indices[:num_prefill_tokens] prefill_swa_lens = self.prefill_swa_lens[:num_prefill_tokens] _compute_swa_indices_and_lens_kernel[(num_prefill_tokens,)]( @@ -452,9 +459,9 @@ def build( num_decodes, num_prefills, seq_lens, - seq_lens_cpu, query_start_loc, query_start_loc_cpu, + common_attn_metadata.seq_lens_cpu_upper_bound, ) # Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta @@ -475,12 +482,12 @@ def build( decode_swa_lens=self.decode_swa_lens[:num_decode_tokens], prefill_swa_indices=( self.prefill_swa_indices[:num_prefill_tokens] - if num_prefill_tokens > 0 + if want_prefill_swa else None ), prefill_swa_lens=( self.prefill_swa_lens[:num_prefill_tokens] - if num_prefill_tokens > 0 + if want_prefill_swa else None ), block_size=self.block_size, @@ -520,6 +527,8 @@ def build_tile_scheduler( or current_platform.is_device_capability_family(120) ): return out + if is_triton_sparse_mla_enabled(self.device): + return out for layer_type in self._layer_types: # get_mla_metadata() is the official FlashMLA entry point that # returns a fresh empty FlashMLASchedMeta; using it keeps this @@ -533,9 +542,9 @@ def _build_deepseek_v4_metadata( num_decodes: int, num_prefills: int, seq_lens: torch.Tensor, - seq_lens_cpu: torch.Tensor | None, query_start_loc: torch.Tensor, query_start_loc_cpu: torch.Tensor, + seq_lens_cpu_upper_bound: torch.Tensor | None, ) -> dict[str, torch.Tensor | int | None]: """Pre-compute DeepseekV4 prefill metadata during the metadata build phase. 
@@ -549,7 +558,6 @@ def _build_deepseek_v4_metadata( # --- Prefill query metadata (single Triton kernel + CPU slicing) --- if num_prefills > 0: - assert seq_lens_cpu is not None pfx_gather_lens = torch.empty( num_prefills, dtype=torch.int32, device=seq_lens.device ) @@ -563,13 +571,26 @@ def _build_deepseek_v4_metadata( BLOCK_SIZE=triton.next_power_of_2(num_prefills), ) - result["prefill_seq_lens"] = seq_lens[num_decodes:] - result["prefill_seq_lens_cpu"] = seq_lens_cpu[num_decodes:] - result["prefill_gather_lens"] = pfx_gather_lens - result["prefill_query_lens_cpu"] = ( + assert seq_lens_cpu_upper_bound is not None + seq_lens_cpu = seq_lens_cpu_upper_bound + prefill_seq_lens_cpu = seq_lens_cpu[ + num_decodes : num_decodes + num_prefills + ] + query_lens_cpu = ( query_start_loc_cpu[num_decodes + 1 : num_decodes + num_prefills + 1] - query_start_loc_cpu[num_decodes : num_decodes + num_prefills] - ).to(dtype=torch.int32) + ) + prefix_lens_cpu = prefill_seq_lens_cpu - query_lens_cpu + prefill_gather_lens_cpu = query_lens_cpu + torch.minimum( + prefix_lens_cpu, + torch.full_like(prefix_lens_cpu, self.window_size - 1), + ) + + result["prefill_seq_lens"] = seq_lens[num_decodes:] + result["prefill_gather_lens"] = pfx_gather_lens + result["prefill_seq_lens_cpu"] = prefill_seq_lens_cpu + result["prefill_gather_lens_cpu"] = prefill_gather_lens_cpu + result["prefill_query_lens_cpu"] = query_lens_cpu.to(dtype=torch.int32) result["prefill_window_size"] = self.window_size result["prefill_max_model_len"] = self.max_model_len result["prefill_max_num_batched_tokens"] = self.max_num_batched_tokens @@ -652,8 +673,14 @@ def _compute_swa_indices_and_lens_kernel( pos_offset = start_pos + offset block_indices = pos_offset // block_size + # Clamp masked-off lanes before the address add: SM12x + Triton 3.6 raises + # IMA on out-of-bounds address arithmetic even when the load mask gates the + # read (same hazard the sibling _compute_prefill_metadata_kernel clamps via + # safe_offset). 
Over a deep prefill row the tail lanes index past the + # request's block_table row -> cudaErrorLaunchFailure without this. + safe_block_indices = tl.where(pos_offset < end_pos, block_indices, 0) block_numbers = tl.load( - block_table_ptr + req_idx * block_table_stride + block_indices, + block_table_ptr + req_idx * block_table_stride + safe_block_indices, mask=pos_offset < end_pos, ) block_offsets = pos_offset % block_size diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 376f65f6697a..96d74d0bb950 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from collections.abc import Sequence +from math import lcm from typing import NamedTuple from vllm import envs @@ -74,6 +75,7 @@ def __init__( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ): self.kv_cache_config = kv_cache_config @@ -114,6 +116,7 @@ def __init__( dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, scheduler_block_size=self.scheduler_block_size, + max_num_seqs=max_num_seqs, ) for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) ) @@ -342,6 +345,21 @@ def remove_skipped_blocks( for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, total_computed_tokens) + def release_protected_prompt_blocks( + self, + target_free_blocks: int | None = None, + block_ids_to_skip: set[int] | None = None, + ) -> None: + for manager in self.single_type_managers: + if ( + target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + manager.release_protected_prompt_blocks( + target_free_blocks, block_ids_to_skip + ) + def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: """ Get the 
blocks for the request. @@ -384,6 +402,7 @@ def __init__( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( @@ -397,6 +416,7 @@ def __init__( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) self.num_single_type_manager = len(self.single_type_managers) @@ -434,6 +454,7 @@ def __init__( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( @@ -447,6 +468,7 @@ def __init__( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec @@ -520,6 +542,7 @@ def __init__( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( @@ -533,6 +556,7 @@ def __init__( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) # hash_block_size: the block size used to compute block hashes. @@ -590,31 +614,32 @@ def verify_and_split_kv_cache_groups(self) -> None: for gid in group.group_ids: self.single_type_managers[gid].use_eagle = True + # The LCM of the block sizes of all attention types. + # The cache hit length must be a multiple of the LCM of the block sizes + # to make sure the cache hit length is a multiple of the block size of + # each attention type. Requiring this because we don't support partial + # block cache hit yet. 
+ block_sizes = [group.spec.block_size for group in self.attention_groups] + self.lcm_block_size = lcm(*block_sizes) + for manager in self.single_type_managers: + manager.cache_alignment_tokens = self.lcm_block_size + + # Attention-group indices (into ``self.attention_groups``) that + # contain at least one EAGLE/MTP KV cache group. + self.eagle_attn_group_indices: set[int] = { + i for i, group in enumerate(self.attention_groups) if group.use_eagle + } + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: # Cache hits in this coordinator are always a multiple of # ``scheduler_block_size`` tokens (see ``find_longest_cache_hit``). - # Within an aligned region, SWA groups may only consult a subset of blocks - # per ``scheduler_block_size``-segment so the unused blocks also stay - # out of the prefix-cache hash map. - aligned_num_computed_tokens = ( - num_computed_tokens // self.scheduler_block_size * self.scheduler_block_size - ) + # Managers may still cache complete tail blocks after the last aligned + # boundary; ``find_longest_cache_hit`` keeps returned hits aligned. for manager in self.single_type_managers: - num_tokens_to_cache = aligned_num_computed_tokens - # EAGLE groups match one block past each aligned boundary and drop - # it, so make that lookahead block eligible to be cached. - if manager.use_eagle and aligned_num_computed_tokens > 0: - num_tokens_to_cache = min( - num_computed_tokens, - aligned_num_computed_tokens + manager.block_size, - ) - # The manager already knows the fine hit granularity - # (``scheduler_block_size``); retention is passed separately so it - # can keep both the coarse segment tails and the fine replay - # boundary (which needs the fine value). 
manager.cache_blocks( request, - num_tokens_to_cache, + num_computed_tokens, + alignment_tokens=self.scheduler_block_size, retention_interval=self.retention_interval, ) @@ -781,6 +806,7 @@ def get_kv_cache_coordinator( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ) -> KVCacheCoordinator: if not enable_caching: @@ -794,6 +820,7 @@ def get_kv_cache_coordinator( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) if len(kv_cache_config.kv_cache_groups) == 1: @@ -808,6 +835,7 @@ def get_kv_cache_coordinator( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) return HybridKVCacheCoordinator( @@ -821,5 +849,6 @@ def get_kv_cache_coordinator( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b0f6655bf957..ac782954ce56 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -121,6 +121,7 @@ def __init__( enable_kv_cache_events: bool = False, dcp_world_size: int = 1, pcp_world_size: int = 1, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, watermark: float = 0.0, ) -> None: @@ -151,6 +152,7 @@ def __init__( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=self.metrics_collector, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) @@ -199,12 +201,18 @@ def make_prefix_cache_stats(self) -> PrefixCacheStats | None: 
self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: + def get_computed_blocks( + self, + request: Request, + *, + record_stats: bool = True, + ) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. Args: request: The request to get the computed blocks. + record_stats: Whether to record prefix-cache stats for this lookup. Returns: A tuple containing: @@ -231,7 +239,7 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: ) ) - if self.log_stats: + if self.log_stats and record_stats: assert self.prefix_cache_stats is not None self.prefix_cache_stats.record( num_tokens=request.num_tokens, @@ -359,6 +367,9 @@ def allocate_slots( num_local_computed_tokens + num_external_computed_tokens, self.max_model_len, ) + block_ids_to_skip_releasing = self._block_ids_to_skip_releasing( + new_computed_block_list + ) watermark_blocks = 0 # The watermark is applied to waiting/preempted requests only, and only @@ -382,8 +393,11 @@ def allocate_slots( num_tokens_main_model=full_num_tokens, apply_admission_cap=True, ) - required_blocks = num_blocks_to_allocate + watermark_blocks - if required_blocks > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks( + num_blocks_to_allocate, + block_ids_to_skip_releasing, + watermark_blocks=watermark_blocks, + ): return None num_tokens_main_model = total_computed_tokens + num_new_tokens @@ -411,11 +425,12 @@ def allocate_slots( num_tokens_main_model=num_tokens_main_model, ) - # Keep `reserved_blocks` free for other in-flight sequences, and an - # additional watermark of headroom for waiting/preempted admissions. 
- available_blocks = self.block_pool.get_num_free_blocks() - reserved_blocks - required_blocks = num_blocks_to_allocate + watermark_blocks - if required_blocks > available_blocks: + if not self._has_enough_free_blocks( + num_blocks_to_allocate, + block_ids_to_skip_releasing, + reserved_blocks=reserved_blocks, + watermark_blocks=watermark_blocks, + ): # Cannot allocate new blocks return None @@ -501,6 +516,32 @@ def evict_blocks(self, block_ids: set[int]) -> None: """ self.block_pool.evict_blocks(block_ids) + @staticmethod + def _block_ids_to_skip_releasing( + blocks: tuple[Sequence[KVCacheBlock], ...], + ) -> set[int]: + return { + block.block_id + for group_blocks in blocks + for block in group_blocks + if not block.is_null + } + + def _has_enough_free_blocks( + self, + num_blocks: int, + block_ids_to_skip_releasing: set[int] | None = None, + reserved_blocks: int = 0, + watermark_blocks: int = 0, + ) -> bool: + required_free_blocks = num_blocks + reserved_blocks + watermark_blocks + if required_free_blocks <= self.block_pool.get_num_free_blocks(): + return True + self.coordinator.release_protected_prompt_blocks( + required_free_blocks, block_ids_to_skip_releasing + ) + return required_free_blocks <= self.block_pool.get_num_free_blocks() + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalidate prefix caching after the weights are updated, @@ -510,6 +551,7 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. 
""" + self.coordinator.release_protected_prompt_blocks() if not self.block_pool.reset_prefix_cache(): return False if self.log_stats: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 90d93a110cc0..737fee6d2998 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -263,6 +263,7 @@ def __init__( pcp_world_size=self.pcp_world_size, scheduler_block_size=self.block_size, hash_block_size=hash_block_size, + max_num_seqs=self.scheduler_config.max_num_seqs, metrics_collector=self.kv_metrics_collector, watermark=self.scheduler_config.watermark, ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index c98c59017c52..0105ad0f482f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from abc import ABC, abstractmethod -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Sequence from vllm.utils.math_utils import cdiv @@ -45,6 +45,8 @@ def __init__( dcp_world_size: int = 1, pcp_world_size: int = 1, max_admission_blocks_per_request: int | None = None, + max_model_len: int | None = None, + max_num_seqs: int | None = None, ) -> None: """ Initializes the SingleTypeKVCacheManager. 
@@ -72,6 +74,9 @@ def __init__( self.block_pool = block_pool self.enable_caching = enable_caching self._max_admission_blocks_per_request = max_admission_blocks_per_request + self.max_model_len = max_model_len + self.max_num_seqs = max_num_seqs + self.cache_alignment_tokens = self.block_size self.new_block_ids: list[int] = [] # Mapping from request ID to blocks to track the blocks allocated @@ -87,6 +92,8 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + self._protected_prompt_block_ids: set[int] = set() + self._protected_prompt_block_queue: deque[int] = deque() # Whether this group's prefix-cache hits drop the EAGLE/MTP lookahead # block. Only consulted by managers whose hit logic is sparse within an @@ -315,10 +322,84 @@ def take_new_block_ids(self) -> list[int]: self.new_block_ids = [] return ids + def _max_protected_prompt_blocks(self) -> int | None: + if self.max_model_len is None: + return None + return 2 * cdiv(max(1, self.max_model_len), self.block_size) + + def _protect_prompt_blocks(self, blocks: Sequence[KVCacheBlock]) -> None: + if not self.enable_caching: + return + + protected: list[KVCacheBlock] = [] + for block in blocks: + if ( + block.is_null + or block.block_hash is None + or block.block_id in self._protected_prompt_block_ids + ): + continue + protected.append(block) + self._protected_prompt_block_ids.add(block.block_id) + self._protected_prompt_block_queue.append(block.block_id) + + if not protected: + return + + # Keep an extra reference for prompt blocks that must survive after + # their request releases its normal runtime reference. Later request + # reuse increments/decrements the runtime reference as usual. 
+ self.block_pool.touch(protected) + self._trim_protected_prompt_blocks() + + def _trim_protected_prompt_blocks(self) -> None: + max_blocks = self._max_protected_prompt_blocks() + if max_blocks is None: + return + + while len(self._protected_prompt_block_ids) > max_blocks: + if not self._release_one_protected_prompt_block(): + return + + def _release_one_protected_prompt_block( + self, block_ids_to_skip: set[int] | None = None + ) -> bool: + attempts = len(self._protected_prompt_block_queue) + while attempts: + block_id = self._protected_prompt_block_queue.popleft() + attempts -= 1 + if block_id not in self._protected_prompt_block_ids: + continue + if block_ids_to_skip is not None and block_id in block_ids_to_skip: + self._protected_prompt_block_queue.append(block_id) + continue + + self._protected_prompt_block_ids.remove(block_id) + block = self.block_pool.blocks[block_id] + if block.ref_cnt > 0: + self.block_pool.free_blocks([block]) + return True + return False + + def release_protected_prompt_blocks( + self, + target_free_blocks: int | None = None, + block_ids_to_skip: set[int] | None = None, + ) -> None: + while self._protected_prompt_block_ids: + if ( + target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + if not self._release_one_protected_prompt_block(block_ids_to_skip): + return + def cache_blocks( self, request: Request, num_tokens: int, + alignment_tokens: int | None = None, retention_interval: int | None = None, ) -> None: """ @@ -328,6 +409,8 @@ def cache_blocks( request: The request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). + alignment_tokens: The prefix-cache hit alignment in tokens. + ``None`` uses this manager's scheduler block size. retention_interval: Sparse local-checkpoint granularity. 
``None`` keeps dense checkpointing; ``0`` keeps only the latest replay boundary; a positive multiple of ``scheduler_block_size`` keeps @@ -339,10 +422,13 @@ def cache_blocks( if num_cached_blocks >= num_full_blocks: return + if alignment_tokens is None: + alignment_tokens = self.scheduler_block_size + block_mask = self.reachable_block_mask( start_block=num_cached_blocks, end_block=num_full_blocks, - alignment_tokens=self.scheduler_block_size, + alignment_tokens=alignment_tokens, kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, retention_interval=retention_interval, @@ -598,6 +684,92 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return num_common_blocks +class MLAAttentionManager(FullAttentionManager): + """KV cache manager for compressed / fp8 MLA cache layouts. + + Used by any MLA spec whose hit semantics need prompt-block + protection across decode and unrelated cache churn. ``_should_ + protect_prompt_blocks`` enumerates the triggering conditions. + """ + + def _should_protect_prompt_blocks(self) -> bool: + # Three independent triggers: + # 1. ``model_version == "deepseek_v4"``: DSv4 explicitly opts in. + # 2. ``cache_dtype_str == "fp8_ds_mla"``: fp8 DeepSeek-style + # MLA cache; protection is needed for the same hybrid-align + # reuse pattern. + # 3. ``compress_ratio > 1``: any compressed MLA cache (today + # only DSv4 sets ``compress_ratio > 1``; V3.2 keeps it at 1). 
+ return ( + getattr(self.kv_cache_spec, "model_version", None) == "deepseek_v4" + or getattr(self.kv_cache_spec, "cache_dtype_str", None) == "fp8_ds_mla" + or getattr(self.kv_cache_spec, "compress_ratio", 1) > 1 + ) + + def _max_protected_prompt_blocks(self) -> int | None: + if self.max_num_seqs is None: + return super()._max_protected_prompt_blocks() + if self.max_model_len is None: + return None + + prompt_blocks = cdiv(max(1, self.max_model_len), self.block_size) + target_reqs = max(2, self.max_num_seqs) + target_blocks = target_reqs * prompt_blocks + + # Keep one max-length request worth of blocks available for new work + # before the generic allocation path has to release protected prompts. + pool_blocks = max(0, self.block_pool.num_gpu_blocks - 1) + if pool_blocks <= prompt_blocks: + return pool_blocks + return min(target_blocks, pool_blocks - prompt_blocks) + + def cache_blocks( + self, + request: Request, + num_tokens: int, + alignment_tokens: int | None = None, + retention_interval: int | None = None, + ) -> None: + super().cache_blocks( + request, + num_tokens, + alignment_tokens=alignment_tokens, + retention_interval=retention_interval, + ) + if not self._should_protect_prompt_blocks() or request.num_prompt_tokens <= 1: + return + + max_cache_hit_length = request.num_prompt_tokens - 1 + aligned_cache_hit_length = ( + max_cache_hit_length + // self.cache_alignment_tokens + * self.cache_alignment_tokens + ) + if aligned_cache_hit_length <= 0 or num_tokens < aligned_cache_hit_length: + return + num_hit_blocks = aligned_cache_hit_length // self.block_size + if num_hit_blocks == 0: + return + + self._protect_prompt_blocks( + self.req_to_blocks[request.request_id][:num_hit_blocks] + ) + + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: + blocks = self.req_to_blocks[running_request_id] + num_common_blocks = 0 + expected_ref_cnt = len(self.req_to_blocks) + for block in blocks: + ref_cnt = block.ref_cnt + if block.block_id in 
self._protected_prompt_block_ids: + ref_cnt -= 1 + if ref_cnt == expected_ref_cnt: + num_common_blocks += 1 + else: + break + return num_common_blocks + + class SlidingWindowManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) @@ -805,6 +977,51 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return 0 +class SlidingWindowMLAManager(SlidingWindowManager): + """KV cache manager for DeepSeek V4's sliding-window MLA cache. + + During decode, the live sliding window can move past the prompt boundary. + The blocks around the hybrid-aligned prompt boundary are still the suffix + needed for a future prefix-cache hit of the same prompt. + """ + + def cache_blocks( + self, + request: Request, + num_tokens: int, + alignment_tokens: int | None = None, + retention_interval: int | None = None, + ) -> None: + super().cache_blocks( + request, + num_tokens, + alignment_tokens=alignment_tokens, + retention_interval=retention_interval, + ) + if not self.enable_caching or request.num_prompt_tokens <= 1: + return + + max_cache_hit_length = request.num_prompt_tokens - 1 + aligned_cache_hit_length = ( + max_cache_hit_length + // self.cache_alignment_tokens + * self.cache_alignment_tokens + ) + if aligned_cache_hit_length <= 0 or num_tokens < aligned_cache_hit_length: + return + + aligned_num_hit_blocks = aligned_cache_hit_length // self.block_size + last_full_prompt_block = max_cache_hit_length // self.block_size + contiguous_blocks = cdiv(self.sliding_window - 1, self.block_size) + first_protected_block = max(0, aligned_num_hit_blocks - contiguous_blocks) + last_protected_block = max(aligned_num_hit_blocks, last_full_prompt_block) + blocks = self.req_to_blocks[request.request_id] + protected_blocks = blocks[ + first_protected_block : min(last_protected_block, len(blocks)) + ] + self._protect_prompt_blocks(protected_blocks) + + class 
ChunkedLocalAttentionManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) @@ -1215,10 +1432,16 @@ def cache_blocks( self, request: Request, num_tokens: int, + alignment_tokens: int | None = None, retention_interval: int | None = None, ) -> None: num_cached_blocks_before = self.num_cached_block.get(request.request_id, 0) - super().cache_blocks(request, num_tokens, retention_interval=retention_interval) + super().cache_blocks( + request, + num_tokens, + alignment_tokens=alignment_tokens, + retention_interval=retention_interval, + ) num_cached_blocks_after = self.num_cached_block.get(request.request_id, 0) if num_cached_blocks_after > num_cached_blocks_before: for block in self.req_to_blocks[request.request_id][ @@ -1260,6 +1483,7 @@ def cache_blocks( self, request: Request, num_tokens: int, + alignment_tokens: int | None = None, retention_interval: int | None = None, ) -> None: # We do not cache blocks for cross-attention to be shared between @@ -1303,16 +1527,22 @@ def __init__( block_pool: BlockPool, enable_caching: bool, kv_cache_group_id: int, + scheduler_block_size: int, dcp_world_size: int = 1, pcp_world_size: int = 1, + max_model_len: int | None = None, + max_num_seqs: int | None = None, ): super().__init__( kv_cache_spec, block_pool, enable_caching, kv_cache_group_id, + scheduler_block_size, dcp_world_size, pcp_world_size, + max_model_len=max_model_len, + max_num_seqs=max_num_seqs, ) sink_len = kv_cache_spec.sink_len assert sink_len is not None and sink_len > 0 and sink_len % self.block_size == 0 @@ -1324,6 +1554,7 @@ def get_manager_for_kv_cache_spec( kv_cache_spec: KVCacheSpec, max_num_batched_tokens: int, max_model_len: int, + max_num_seqs: int | None = None, **kwargs, ) -> SingleTypeKVCacheManager: """ @@ -1344,6 +1575,8 @@ def get_manager_for_kv_cache_spec( assert manager_class is not None, ( f"No manager registered for KVCacheSpec 
{type(kv_cache_spec)}" ) + kwargs["max_model_len"] = max_model_len + kwargs["max_num_seqs"] = max_num_seqs # SlidingWindow / ChunkedLocalAttention managers recycle blocks across # chunks; the runtime admission cap must match the recycling-aware bound # the startup pool sizer uses (single source of truth: the spec method). @@ -1373,7 +1606,7 @@ def register_all_kvcache_specs(vllm_config): ) KVCacheSpecRegistry.register( SlidingWindowMLASpec, - SlidingWindowManager, + SlidingWindowMLAManager, uniform_type_base_spec=SlidingWindowMLASpec, ) @@ -1398,7 +1631,7 @@ def register_all_kvcache_specs(vllm_config): uniform_type_base_spec=FullAttentionSpec, ) KVCacheSpecRegistry.register( - MLAAttentionSpec, FullAttentionManager, uniform_type_base_spec=FullAttentionSpec + MLAAttentionSpec, MLAAttentionManager, uniform_type_base_spec=FullAttentionSpec ) # NOTE(Mengqing): HiddenStateCacheSpec won't take part in # grouping, thus the uniform_type_base_spec is just a diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py index 9f46cbd24239..e8e74144b2d9 100644 --- a/vllm/v1/spec_decode/llm_base_proposer.py +++ b/vllm/v1/spec_decode/llm_base_proposer.py @@ -35,6 +35,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p, empty_exponential_noise_like, sample_with_exponential_noise, ) @@ -236,7 +237,14 @@ def __init__( self._enable_probabilistic_draft_probs = ( self.speculative_config.rejection_sample_method == "standard" and self.speculative_config.draft_sample_method == "probabilistic" - ) + ) or self.method == "mtp" + # MTP drafts benefit from probabilistic sampling at ``temperature > 0``: + # the draft model is a near-copy of the target, so sampling from + # ``softmax(draft_logits)`` covers the target distribution much better + # than argmax and lifts acceptance by ~9 pp on DSv4-Flash 
mt-bench + # (measured 58.9% → 67.8%). The ``all_greedy`` fast path inside + # ``_sample_draft_tokens`` still falls back to argmax for ``temp==0`` + # requests so this override only changes behaviour where it helps. self._last_draft_probs: torch.Tensor | None = None self._slot_mapping_buffer = torch.zeros( @@ -408,11 +416,53 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) - def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor: + def _get_effective_spec_step_idx(self, spec_step_idx: int) -> int: + if self.method != "mtp": + return 0 + + config = getattr(self.model, "config", None) + num_mtp_layers = getattr(config, "num_nextn_predict_layers", None) + if not isinstance(num_mtp_layers, int): + inner_model = getattr(self.model, "model", None) + num_mtp_layers = getattr(inner_model, "num_mtp_layers", None) + + if isinstance(num_mtp_layers, int) and num_mtp_layers > 0: + return spec_step_idx % num_mtp_layers + return spec_step_idx + + def _compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if self.method == "mtp": + return self.model.compute_logits( + hidden_states, + spec_step_idx=self._get_effective_spec_step_idx(spec_step_idx), + ) + return self.model.compute_logits(hidden_states) + + def _add_spec_step_idx( + self, + model_kwargs: dict[str, torch.Tensor | None], + spec_step_idx: int, + ) -> dict[str, torch.Tensor | int | None]: + if self.method != "mtp": + return model_kwargs + return { + **model_kwargs, + "spec_step_idx": self._get_effective_spec_step_idx(spec_step_idx), + } + + def _greedy_sample( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor: """Greedy-sample draft tokens from hidden states.""" if self.use_local_argmax_reduction: return self.model.get_top_tokens(hidden_states) - return self.model.compute_logits(hidden_states).argmax(dim=-1) + return 
self._compute_logits(hidden_states, spec_step_idx).argmax(dim=-1) def _sample_from_logits( self, @@ -431,10 +481,25 @@ def _sample_draft_tokens( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Sample draft token ids (and optionally their probabilities). + + ``spec_step_idx`` is forwarded to ``_compute_logits`` so that + MTP-style drafters can route the logits computation through the + correct draft layer (DSv4 cycles through ``num_nextn_predict_layers`` + layers). For non-MTP drafters this argument is ignored. + + Probabilities are only returned when + :attr:`_enable_probabilistic_draft_probs` is enabled (set by the + speculative config via ``rejection_sample_method="standard"`` and + ``draft_sample_method="probabilistic"``). In every other case we + fall back to argmax, matching the upstream + ``_sample_draft_tokens`` contract. + """ if not self._enable_probabilistic_draft_probs or sampling_metadata.all_greedy: - return self._greedy_sample(hidden_states), None - logits = self.model.compute_logits(hidden_states) + return self._greedy_sample(hidden_states, spec_step_idx), None + logits = self._compute_logits(hidden_states, spec_step_idx) return self._sample_from_logits(logits, sampling_metadata) def take_last_draft_probs(self) -> torch.Tensor | None: @@ -521,7 +586,7 @@ def propose( slot_mapping_size, common_attn_metadata.slot_mapping ), ): - ret_hidden_states = self.model(**model_kwargs) + ret_hidden_states = self.model(**self._add_spec_step_idx(model_kwargs, 0)) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states @@ -548,8 +613,11 @@ def propose( # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1 or self.parallel_drafting: + # ``spec_step_idx=0`` selects the first MTP layer for DSv4 (and + # is ignored for non-MTP drafters). 
Probabilities only flow + # through when ``_enable_probabilistic_draft_probs`` is set. draft_token_ids, draft_probs = self._sample_draft_tokens( - sample_hidden_states, sampling_metadata + sample_hidden_states, sampling_metadata, spec_step_idx=0 ) if draft_probs is not None: self._last_draft_probs = draft_probs.view( @@ -570,9 +638,8 @@ def propose( self.positions[:batch_size] = positions draft_token_ids, draft_probs = self._sample_draft_tokens( - sample_hidden_states, sampling_metadata + sample_hidden_states, sampling_metadata, spec_step_idx=0 ) - draft_probs_list = None if draft_probs is None else [draft_probs] if self.allowed_attn_types is not None: for group_md in per_group_attn_metadata: @@ -586,6 +653,7 @@ def propose( # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] + draft_probs_list = [] if draft_probs is None else [draft_probs] cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = ( self._determine_batch_execution_and_padding(batch_size) @@ -664,7 +732,9 @@ def propose( cudagraph_runtime_mode=cudagraph_runtime_mode, slot_mapping=self._get_slot_mapping(input_batch_size), ): - ret_hidden_states = self.model(**model_kwargs) + ret_hidden_states = self.model( + **self._add_spec_step_idx(model_kwargs, token_index + 1) + ) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = ret_hidden_states @@ -673,16 +743,17 @@ def propose( hidden_states = hidden_states[:batch_size] draft_token_ids, draft_probs = self._sample_draft_tokens( - last_hidden_states[:batch_size], sampling_metadata + last_hidden_states[:batch_size], + sampling_metadata, + spec_step_idx=token_index + 1, ) + draft_token_ids_list.append(draft_token_ids) if draft_probs is not None: - assert draft_probs_list is not None draft_probs_list.append(draft_probs) - draft_token_ids_list.append(draft_token_ids) # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - if draft_probs_list is not 
None: + if draft_probs_list: self._last_draft_probs = torch.stack(draft_probs_list, dim=1).contiguous() return draft_token_ids @@ -1702,12 +1773,10 @@ def _determine_batch_execution_and_padding( return cudagraph_mode, num_tokens_padded, num_tokens_across_dp -# NOTE(woosuk): Currently, the below code is not used and we always use argmax -# to sample the draft tokens. We will use this after we find a way to manage -# the draft prob tensor. -# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details. -# FIXME(woosuk): The logic here is duplicated with the main sampling code. -# We should refactor this to reuse the same sampling implementation. +# NOTE(woosuk): This duplicates part of the main sampling code because MTP +# needs both the sampled draft token ids and the draft probability tensor for +# rejection sampling. Refactor this once the sampler exposes a reusable helper +# that returns both values without extra packing. def compute_probs_and_sample_next_token( logits: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -1724,22 +1793,38 @@ def compute_probs_and_sample_next_token( # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0) # consistent with sampler.py's _SAMPLING_EPS threshold - temperature = sampling_metadata.temperature + num_tokens = logits.shape[0] + # The triton top-k/top-p sampler (apply_top_k_top_p) asserts float32 logits, + # matching the main sampler (sampler.py). The MTP draft head emits bf16 logits, + # so without this cast MTP draft sampling with any top-k/top-p (i.e. ordinary + # non-greedy chat traffic) trips that assertion and kills the engine. Greedy + # requests return above and never reach here. + logits = logits.to(torch.float32) + temperature = _expand_draft_sampling_tensor( + sampling_metadata.temperature, + num_tokens, + ) + assert temperature is not None # Avoid division by zero if there are greedy requests. 
if not sampling_metadata.all_random: is_greedy = temperature < _SAMPLING_EPS temperature = torch.where(is_greedy, 1.0, temperature) logits.div_(temperature.view(-1, 1)) + top_k = _expand_draft_sampling_tensor(sampling_metadata.top_k, num_tokens) + top_p = _expand_draft_sampling_tensor(sampling_metadata.top_p, num_tokens) + logits = apply_top_k_top_p(logits, top_k, top_p) probs = logits.softmax(dim=-1, dtype=torch.float32) - # NOTE(woosuk): Currently, we ignore most of the sampling parameters in - # generating the draft tokens. We only use the temperature. While this - # could degrade the acceptance rate, it does not affect the distribution - # of the generated tokens after rejection sampling. - - # TODO(woosuk): Consider seeds. + generators = _expand_draft_sampling_generators( + sampling_metadata.generators, + sampling_metadata.temperature.shape[0], + num_tokens, + ) q = empty_exponential_noise_like(probs, use_fp64_gumbel) - q.exponential_() + if len(generators) != num_tokens: + q.exponential_() + for i, generator in generators.items(): + q[i].exponential_(generator=generator) # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs # will be used later for rejection sampling. next_token_ids = sample_with_exponential_noise(probs.clone(), q) @@ -1747,3 +1832,41 @@ def compute_probs_and_sample_next_token( greedy_token_ids = probs.argmax(dim=-1) next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids) return next_token_ids, probs + + +def _expand_draft_sampling_tensor( + tensor: torch.Tensor | None, + num_tokens: int, +) -> torch.Tensor | None: + if tensor is None or tensor.shape[0] == num_tokens: + return tensor + + batch_size = tensor.shape[0] + if num_tokens % batch_size != 0: + raise ValueError( + "Draft sampling metadata must either match the draft logits row " + "count or evenly divide it." 
+ ) + return tensor.repeat_interleave(num_tokens // batch_size, dim=0) + + +def _expand_draft_sampling_generators( + generators: dict[int, torch.Generator], + batch_size: int, + num_tokens: int, +) -> dict[int, torch.Generator]: + if not generators or batch_size == num_tokens: + return generators + + if num_tokens % batch_size != 0: + raise ValueError( + "Draft sampling generators must either match the draft logits row " + "count or evenly divide it." + ) + + repeat = num_tokens // batch_size + return { + req_idx * repeat + offset: generator + for req_idx, generator in generators.items() + for offset in range(repeat) + } diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 74938a823d9f..43d8408565cd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5677,6 +5677,7 @@ def _dummy_run( skip_eplb: bool = False, is_profile: bool = False, create_mixed_batch: bool = False, + create_single_prefill: bool = False, remove_lora: bool = True, is_graph_capturing: bool = False, num_active_loras: int = 0, @@ -5702,6 +5703,8 @@ def _dummy_run( is_profile: If True, this is a profile run. create_mixed_batch: If True, create a mixed batch with both decode (1 token) and prefill (multiple tokens) requests. + create_single_prefill: If True, create one prefill request with + ``num_tokens`` prompt tokens. remove_lora: If False, dummy LoRAs are not destroyed after the run num_active_loras: Number of distinct active LoRAs to capture for. LoRA is activated when num_active_loras > 0. @@ -5740,7 +5743,13 @@ def _dummy_run( # has num_tokens in total. 
assert num_tokens <= self.max_num_tokens max_num_reqs = self.scheduler_config.max_num_seqs - if create_mixed_batch: + if create_single_prefill: + assert not uniform_decode + assert not create_mixed_batch + num_reqs = 1 + num_scheduled_tokens_list = [num_tokens] + max_query_len = num_tokens + elif create_mixed_batch: assert not uniform_decode # Create mixed batch: # first half decode tokens, second half one prefill @@ -5754,6 +5763,7 @@ def _dummy_run( max_query_len = num_prefill_tokens elif uniform_decode: assert not create_mixed_batch + assert not create_single_prefill num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index f4a76529023c..0140a7afe550 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -231,6 +231,16 @@ def _make_metadata_with_slice( block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] + positions = ( + attn_metadata.positions[token_slice] + if attn_metadata.positions is not None + else None + ) + is_prefilling = ( + attn_metadata.is_prefilling[request_slice] + if attn_metadata.is_prefilling is not None + else None + ) return CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -242,6 +252,8 @@ def _make_metadata_with_slice( max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, + positions=positions, + is_prefilling=is_prefilling, seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound, _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=num_computed_tokens_cpu,