From c7b19daa0af84c385f5b21dce4cc0f16eeb73934 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 23:54:24 +0800 Subject: [PATCH 001/131] Fix DeepSeek V4 MLA prefix cache reuse Protect hybrid-aligned DeepSeek V4 MLA prompt cache blocks so they survive decode and unrelated cache churn. Release those protected references under admission pressure and before prefix-cache reset so they do not starve the block pool. Add regression coverage for reuse after decode pressure, admission under protected refs, and reset cleanup. Signed-off-by: jasl --- tests/v1/core/test_prefix_caching.py | 208 +++++++++++++++++++ vllm/v1/core/kv_cache_coordinator.py | 34 +++ vllm/v1/core/kv_cache_manager.py | 48 ++++- vllm/v1/core/single_type_kv_cache_manager.py | 183 +++++++++++++++- 4 files changed, 463 insertions(+), 10 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 0871a15d08d3..8cc28f6cd09e 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -42,6 +42,7 @@ KVCacheSpecKind, MambaSpec, MLAAttentionSpec, + SlidingWindowMLASpec, SlidingWindowSpec, ) @@ -3470,6 +3471,213 @@ def test_can_fit_full_sequence_swa_cap_admits_long_prompt(): ) +def test_deepseek_v4_mla_keeps_prompt_blocks_after_decode_pressure(): + hash_block_size = 2 + full_block_size = 8 + swa_block_size = 2 + prompt_tokens = 35 + chunk_tokens = 4 * full_block_size + expected_hit_tokens = (prompt_tokens - 1) // full_block_size * full_block_size + + config = KVCacheConfig( + num_blocks=70, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=full_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_0"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_1"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_compressor_state"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=2 * swa_block_size, + ), + ), + ], + ) + manager = KVCacheManager( + config, + max_model_len=128, + max_num_batched_tokens=chunk_tokens, + enable_caching=True, + hash_block_size=hash_block_size, + ) + + def run_request(request: Request, num_decode_tokens: int) -> int: + computed_blocks, num_computed_tokens = manager.get_computed_blocks(request) + computed_so_far = num_computed_tokens + remaining_prompt_tokens = request.num_prompt_tokens - num_computed_tokens + first_chunk = True + while remaining_prompt_tokens > 0: + num_new_tokens = min(chunk_tokens, remaining_prompt_tokens) + allocated = manager.allocate_slots( + request, + num_new_tokens, + num_computed_tokens if first_chunk else 0, + computed_blocks if first_chunk else None, + ) + assert allocated is not None + computed_so_far += num_new_tokens + request.num_computed_tokens = computed_so_far + remaining_prompt_tokens -= num_new_tokens + first_chunk = False + + for i in range(num_decode_tokens): + request.append_output_token_ids(10_000 + i) + allocated = manager.allocate_slots(request, 1) + assert allocated is not None + computed_so_far += 1 + request.num_computed_tokens = computed_so_far + return num_computed_tokens + + prompt_a = list(range(prompt_tokens)) + req_a = make_request("a", prompt_a, hash_block_size, sha256) + assert run_request(req_a, num_decode_tokens=0) == 0 + manager.free(req_a) + + warm_a = make_request("warm_a", prompt_a, hash_block_size, sha256) + assert run_request(warm_a, num_decode_tokens=8) == expected_hit_tokens + assert manager.get_num_common_prefix_blocks("warm_a")[0] >= ( + expected_hit_tokens // full_block_size + ) + manager.free(warm_a) + + pressure_blocks = manager.block_pool.get_new_blocks( + manager.block_pool.get_num_free_blocks() + ) + manager.block_pool.free_blocks(reversed(pressure_blocks)) + + req_a_again = make_request("a_again", prompt_a, hash_block_size, sha256) + _, num_computed_tokens = manager.get_computed_blocks(req_a_again) + assert num_computed_tokens == expected_hit_tokens + + +def test_deepseek_v4_mla_protected_prompt_blocks_do_not_block_admission(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + protected_blocks_per_prompt = (prompt_tokens - 1) // block_size + num_prompts = 10 + num_blocks = 80 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + mla_manager = manager.coordinator.single_type_managers[0] + + for i in range(num_prompts): + prompt = list(range(i * 1000, i * 1000 + prompt_tokens)) + req = make_request(f"protected_{i}", prompt, block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert len(mla_manager._protected_prompt_block_ids) == ( + num_prompts * protected_blocks_per_prompt + ) + assert manager.block_pool.get_num_free_blocks() < 64 + + long_req = make_request( + "long", + list(range(100_000, 100_000 + 64 * block_size)), + block_size, + sha256, + ) + assert ( + manager.allocate_slots(long_req, block_size, full_sequence_must_fit=True) + is not None + ) + + +def test_reset_prefix_cache_releases_deepseek_v4_mla_protected_blocks(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + + req = make_request("protected", list(range(prompt_tokens)), block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert manager.coordinator.single_type_managers[0]._protected_prompt_block_ids + assert manager.reset_prefix_cache() + + def test_can_fit_full_sequence_full_attention_still_gates_oversized(): """The cap only loosens the SWA group; a prompt that exceeds the full-attention pool capacity must still be rejected.""" diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 376f65f6697a..452146f9a97a 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from collections.abc import Sequence +from math import lcm from typing import NamedTuple from vllm import envs @@ -342,6 +343,21 @@ def remove_skipped_blocks( for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, total_computed_tokens) + def release_protected_prompt_blocks( + self, + target_free_blocks: int | None = None, + block_ids_to_skip: set[int] | None = None, + ) -> None: + for manager in self.single_type_managers: + if ( + target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + manager.release_protected_prompt_blocks( + target_free_blocks, block_ids_to_skip + ) + def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: """ Get the blocks for the request. @@ -590,6 +606,24 @@ def verify_and_split_kv_cache_groups(self) -> None: for gid in group.group_ids: self.single_type_managers[gid].use_eagle = True + # The LCM of the block sizes of all attention types. + # The cache hit length must be a multiple of the LCM of the block sizes + # to make sure the cache hit length is a multiple of the block size of + # each attention type. Requiring this because we don't support partial + # block cache hit yet. + block_sizes = [group.spec.block_size for group in self.attention_groups] + self.lcm_block_size = lcm(*block_sizes) + for manager in self.single_type_managers: + manager.cache_alignment_tokens = self.lcm_block_size + + # Attention-group indices (into ``self.attention_groups``) that + # contain at least one EAGLE/MTP KV cache group. + self.eagle_attn_group_indices: set[int] = { + i + for i, group in enumerate(self.attention_groups) + if group.use_eagle + } + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: # Cache hits in this coordinator are always a multiple of # ``scheduler_block_size`` tokens (see ``find_longest_cache_hit``). diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index b0f6655bf957..c829b845803e 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -359,6 +359,9 @@ def allocate_slots( num_local_computed_tokens + num_external_computed_tokens, self.max_model_len, ) + block_ids_to_skip_releasing = self._block_ids_to_skip_releasing( + new_computed_block_list + ) watermark_blocks = 0 # The watermark is applied to waiting/preempted requests only, and only @@ -382,8 +385,11 @@ def allocate_slots( num_tokens_main_model=full_num_tokens, apply_admission_cap=True, ) - required_blocks = num_blocks_to_allocate + watermark_blocks - if required_blocks > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks( + num_blocks_to_allocate, + block_ids_to_skip_releasing, + watermark_blocks=watermark_blocks, + ): return None num_tokens_main_model = total_computed_tokens + num_new_tokens @@ -411,11 +417,12 @@ def allocate_slots( num_tokens_main_model=num_tokens_main_model, ) - # Keep `reserved_blocks` free for other in-flight sequences, and an - # additional watermark of headroom for waiting/preempted admissions. - available_blocks = self.block_pool.get_num_free_blocks() - reserved_blocks - required_blocks = num_blocks_to_allocate + watermark_blocks - if required_blocks > available_blocks: + if not self._has_enough_free_blocks( + num_blocks_to_allocate, + block_ids_to_skip_releasing, + reserved_blocks=reserved_blocks, + watermark_blocks=watermark_blocks, + ): # Cannot allocate new blocks return None @@ -501,6 +508,32 @@ def evict_blocks(self, block_ids: set[int]) -> None: """ self.block_pool.evict_blocks(block_ids) + @staticmethod + def _block_ids_to_skip_releasing( + blocks: tuple[Sequence[KVCacheBlock], ...], + ) -> set[int]: + return { + block.block_id + for group_blocks in blocks + for block in group_blocks + if not block.is_null + } + + def _has_enough_free_blocks( + self, + num_blocks: int, + block_ids_to_skip_releasing: set[int] | None = None, + reserved_blocks: int = 0, + watermark_blocks: int = 0, + ) -> bool: + required_free_blocks = num_blocks + reserved_blocks + watermark_blocks + if required_free_blocks <= self.block_pool.get_num_free_blocks(): + return True + self.coordinator.release_protected_prompt_blocks( + required_free_blocks, block_ids_to_skip_releasing + ) + return required_free_blocks <= self.block_pool.get_num_free_blocks() + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalidate prefix caching after the weights are updated, @@ -510,6 +543,7 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. """ + self.coordinator.release_protected_prompt_blocks() if not self.block_pool.reset_prefix_cache(): return False if self.log_stats: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index c98c59017c52..b4ecff696de0 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from abc import ABC, abstractmethod -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Sequence from vllm.utils.math_utils import cdiv @@ -45,6 +45,7 @@ def __init__( dcp_world_size: int = 1, pcp_world_size: int = 1, max_admission_blocks_per_request: int | None = None, + max_model_len: int | None = None, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -72,6 +73,8 @@ def __init__( self.block_pool = block_pool self.enable_caching = enable_caching self._max_admission_blocks_per_request = max_admission_blocks_per_request + self.max_model_len = max_model_len + self.cache_alignment_tokens = self.block_size self.new_block_ids: list[int] = [] # Mapping from request ID to blocks to track the blocks allocated @@ -87,6 +90,8 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + self._protected_prompt_block_ids: set[int] = set() + self._protected_prompt_block_queue: deque[int] = deque() # Whether this group's prefix-cache hits drop the EAGLE/MTP lookahead # block. Only consulted by managers whose hit logic is sparse within an @@ -315,6 +320,79 @@ def take_new_block_ids(self) -> list[int]: self.new_block_ids = [] return ids + def _max_protected_prompt_blocks(self) -> int | None: + if self.max_model_len is None: + return None + return 2 * cdiv(max(1, self.max_model_len), self.block_size) + + def _protect_prompt_blocks(self, blocks: Sequence[KVCacheBlock]) -> None: + if not self.enable_caching: + return + + protected: list[KVCacheBlock] = [] + for block in blocks: + if ( + block.is_null + or block.block_hash is None + or block.block_id in self._protected_prompt_block_ids + ): + continue + protected.append(block) + self._protected_prompt_block_ids.add(block.block_id) + self._protected_prompt_block_queue.append(block.block_id) + + if not protected: + return + + # Keep an extra reference for prompt blocks that must survive after + # their request releases its normal runtime reference. Later request + # reuse increments/decrements the runtime reference as usual. + self.block_pool.touch(protected) + self._trim_protected_prompt_blocks() + + def _trim_protected_prompt_blocks(self) -> None: + max_blocks = self._max_protected_prompt_blocks() + if max_blocks is None: + return + + while len(self._protected_prompt_block_ids) > max_blocks: + if not self._release_one_protected_prompt_block(): + return + + def _release_one_protected_prompt_block( + self, block_ids_to_skip: set[int] | None = None + ) -> bool: + attempts = len(self._protected_prompt_block_queue) + while attempts: + block_id = self._protected_prompt_block_queue.popleft() + attempts -= 1 + if block_id not in self._protected_prompt_block_ids: + continue + if block_ids_to_skip is not None and block_id in block_ids_to_skip: + self._protected_prompt_block_queue.append(block_id) + continue + + self._protected_prompt_block_ids.remove(block_id) + block = self.block_pool.blocks[block_id] + if block.ref_cnt > 0: + self.block_pool.free_blocks([block]) + return True + return False + + def release_protected_prompt_blocks( + self, + target_free_blocks: int | None = None, + block_ids_to_skip: set[int] | None = None, + ) -> None: + while self._protected_prompt_block_ids: + if ( + target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + if not self._release_one_protected_prompt_block(block_ids_to_skip): + return + def cache_blocks( self, request: Request, @@ -598,6 +676,59 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return num_common_blocks +class MLAAttentionManager(FullAttentionManager): + """KV cache manager for DeepSeek V4 compressed MLA cache.""" + + def _should_protect_prompt_blocks(self) -> bool: + return ( + self.kv_cache_spec.model_version == "deepseek_v4" + or self.kv_cache_spec.cache_dtype_str == "fp8_ds_mla" + or self.kv_cache_spec.compress_ratio > 1 + ) + + def cache_blocks( + self, + request: Request, + num_tokens: int, + alignment_tokens: int | None = None, + ) -> None: + super().cache_blocks(request, num_tokens, alignment_tokens=alignment_tokens) + if ( + not self._should_protect_prompt_blocks() + or num_tokens < request.num_prompt_tokens + or request.num_prompt_tokens <= 1 + ): + return + + max_cache_hit_length = request.num_prompt_tokens - 1 + aligned_cache_hit_length = ( + max_cache_hit_length + // self.cache_alignment_tokens + * self.cache_alignment_tokens + ) + num_hit_blocks = aligned_cache_hit_length // self.block_size + if num_hit_blocks == 0: + return + + self._protect_prompt_blocks( + self.req_to_blocks[request.request_id][:num_hit_blocks] + ) + + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: + blocks = self.req_to_blocks[running_request_id] + num_common_blocks = 0 + expected_ref_cnt = len(self.req_to_blocks) + for block in blocks: + ref_cnt = block.ref_cnt + if block.block_id in self._protected_prompt_block_ids: + ref_cnt -= 1 + if ref_cnt == expected_ref_cnt: + num_common_blocks += 1 + else: + break + return num_common_blocks + + class SlidingWindowManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) @@ -805,6 +936,47 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return 0 +class SlidingWindowMLAManager(SlidingWindowManager): + """KV cache manager for DeepSeek V4's sliding-window MLA cache. + + During decode, the live sliding window can move past the prompt boundary. + The blocks around the hybrid-aligned prompt boundary are still the suffix + needed for a future prefix-cache hit of the same prompt. + """ + + def cache_blocks( + self, + request: Request, + num_tokens: int, + alignment_tokens: int | None = None, + ) -> None: + super().cache_blocks(request, num_tokens, alignment_tokens=alignment_tokens) + if not self.enable_caching or num_tokens < request.num_prompt_tokens: + return + if request.num_prompt_tokens <= 1: + return + + max_cache_hit_length = request.num_prompt_tokens - 1 + aligned_cache_hit_length = ( + max_cache_hit_length + // self.cache_alignment_tokens + * self.cache_alignment_tokens + ) + if aligned_cache_hit_length <= 0: + return + + aligned_num_hit_blocks = aligned_cache_hit_length // self.block_size + last_full_prompt_block = max_cache_hit_length // self.block_size + contiguous_blocks = cdiv(self.sliding_window - 1, self.block_size) + first_protected_block = max(0, aligned_num_hit_blocks - contiguous_blocks) + last_protected_block = max(aligned_num_hit_blocks, last_full_prompt_block) + blocks = self.req_to_blocks[request.request_id] + protected_blocks = blocks[ + first_protected_block : min(last_protected_block, len(blocks)) + ] + self._protect_prompt_blocks(protected_blocks) + + class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) @@ -1303,16 +1475,20 @@ def __init__( block_pool: BlockPool, enable_caching: bool, kv_cache_group_id: int, + scheduler_block_size: int, dcp_world_size: int = 1, pcp_world_size: int = 1, + max_model_len: int | None = None, ): super().__init__( kv_cache_spec, block_pool, enable_caching, kv_cache_group_id, + scheduler_block_size, dcp_world_size, pcp_world_size, + max_model_len=max_model_len, ) sink_len = kv_cache_spec.sink_len assert sink_len is not None and sink_len > 0 and sink_len % self.block_size == 0 @@ -1344,6 +1520,7 @@ def get_manager_for_kv_cache_spec( assert manager_class is not None, ( f"No manager registered for KVCacheSpec {type(kv_cache_spec)}" ) + kwargs["max_model_len"] = max_model_len # SlidingWindow / ChunkedLocalAttention managers recycle blocks across # chunks; the runtime admission cap must match the recycling-aware bound # the startup pool sizer uses (single source of truth: the spec method). @@ -1373,7 +1550,7 @@ def register_all_kvcache_specs(vllm_config): ) KVCacheSpecRegistry.register( SlidingWindowMLASpec, - SlidingWindowManager, + SlidingWindowMLAManager, uniform_type_base_spec=SlidingWindowMLASpec, ) @@ -1398,7 +1575,7 @@ def register_all_kvcache_specs(vllm_config): uniform_type_base_spec=FullAttentionSpec, ) KVCacheSpecRegistry.register( - MLAAttentionSpec, FullAttentionManager, uniform_type_base_spec=FullAttentionSpec + MLAAttentionSpec, MLAAttentionManager, uniform_type_base_spec=FullAttentionSpec ) # NOTE(Mengqing): HiddenStateCacheSpec won't take part in # grouping, thus the uniform_type_base_spec is just a From 0cc91bf73d86be8da09b4655d6048d234d5f325c Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 03:40:04 +0800 Subject: [PATCH 002/131] Add Blackwell tuning config aliases Signed-off-by: jasl --- .../test_sm12x_tuned_config_lookup.py | 52 +++++++ ...evice_name=NVIDIA_GB10,dtype=fp8_w8a8.json | 146 +++++++++++++++++ ...-Q_Workstation_Edition,dtype=fp8_w8a8.json | 146 +++++++++++++++++ ...ackwell_Server_Edition,dtype=fp8_w8a8.json | 146 +++++++++++++++++ ...evice_name=NVIDIA_GB10,dtype=fp8_w8a8.json | 146 +++++++++++++++++ ...-Q_Workstation_Edition,dtype=fp8_w8a8.json | 146 +++++++++++++++++ ...ackwell_Server_Edition,dtype=fp8_w8a8.json | 146 +++++++++++++++++ ...evice_name=NVIDIA_GB10,dtype=fp8_w8a8.json | 147 ++++++++++++++++++ ...-Q_Workstation_Edition,dtype=fp8_w8a8.json | 147 ++++++++++++++++++ ...ll_Workstation_Edition,dtype=fp8_w8a8.json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 42 +++++ 43 files changed, 3700 insertions(+) create mode 100644 tests/quantization/test_sm12x_tuned_config_lookup.py create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/tests/quantization/test_sm12x_tuned_config_lookup.py b/tests/quantization/test_sm12x_tuned_config_lookup.py new file mode 100644 index 000000000000..c27d28e45580 --- /dev/null +++ b/tests/quantization/test_sm12x_tuned_config_lookup.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.quantization.utils import fp8_utils +from vllm.platforms import current_platform + +GB10_BLOCK_FP8_SHAPES = ( + (1536, 4096), + (16384, 1024), + (2048, 4096), + (4096, 1024), + (4096, 4096), + (8192, 1024), +) + +GB10_FUSED_MOE_SHAPES = ( + (128, 704, None), + (129, 704, None), + (20, 1536, None), + (256, 384, (128, 128)), + (256, 512, (128, 128)), + (64, 1536, (128, 128)), +) + + +def _get_fused_moe_configs(e, n, block_shape): + if block_shape is None: + return fused_moe.get_moe_configs(e, n, "fp8_w8a8") + block_n, block_k = block_shape + return fused_moe.get_moe_configs(e, n, "fp8_w8a8", block_n, block_k) + + +def test_gb10_tuned_configs_cover_dense_and_fused_moe(monkeypatch): + monkeypatch.setattr(current_platform, "get_device_name", lambda: "NVIDIA GB10") + monkeypatch.setattr(fused_moe.envs, "VLLM_BATCH_INVARIANT", False) + fp8_utils.get_w8a8_block_fp8_configs.cache_clear() + fused_moe.get_moe_configs.cache_clear() + + missing_dense = [ + (n, k) + for n, k in GB10_BLOCK_FP8_SHAPES + if fp8_utils.get_w8a8_block_fp8_configs(n, k, 128, 128) is None + ] + assert not missing_dense + + missing_moe = [ + (e, n, block_shape) + for e, n, block_shape in GB10_FUSED_MOE_SHAPES + if _get_fused_moe_configs(e, n, block_shape) is None + ] + assert not missing_moe diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..75dfc52cb46d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..75dfc52cb46d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..75dfc52cb46d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..75dfc52cb46d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..75dfc52cb46d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..75dfc52cb46d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..8b78f87e7f73 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..8b78f87e7f73 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..8b78f87e7f73 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..bcec61632e3e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..bcec61632e3e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..bcec61632e3e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..705ca33d594b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..705ca33d594b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..705ca33d594b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..9c2ebaddd83f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..9c2ebaddd83f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..9c2ebaddd83f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..50dc6d9575e7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..34d0f8699583 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..8ad3a0197412 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..8ad3a0197412 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..8ad3a0197412 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..cd7e6e91e663 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..8ad3a0197412 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..8ad3a0197412 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..217028c5412e --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..da4016483cc6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,42 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} From b7c81236a19fbef2c93d46b258da6879b1e6ffd6 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 17:49:06 +0800 Subject: [PATCH 003/131] Add portable sparse MLA Triton kernels Signed-off-by: jasl --- .../attention/backends/mla/sparse_mla_env.py | 62 ++ .../backends/mla/sparse_mla_kernels.py | 838 ++++++++++++++++++ 2 files changed, 900 insertions(+) create mode 100644 vllm/v1/attention/backends/mla/sparse_mla_env.py create mode 100644 vllm/v1/attention/backends/mla/sparse_mla_kernels.py diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py new file mode 100644 index 000000000000..ff8c3a5c4b56 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Platform controls for the portable Triton sparse MLA path.""" + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE = 512 +_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE = 256 + +logger = init_logger(__name__) + + +def _is_sm12x_device(device: torch.device) -> bool: + if not torch.cuda.is_available(): + return False + index = device.index if device.index is not None else torch.cuda.current_device() + return torch.cuda.get_device_capability(index)[0] == 12 + + +def is_triton_sparse_mla_enabled_for_platform() -> bool: + return current_platform.is_device_capability_family(120) + + +def is_triton_sparse_mla_enabled(device: torch.device) -> bool: + return _is_sm12x_device(device) + + +def disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) -> None: + if not is_triton_sparse_mla_enabled_for_platform(): + return + + from vllm.config.compilation import CompilationMode, CUDAGraphMode + + compilation_config = vllm_config.compilation_config + if ( + compilation_config.mode == CompilationMode.NONE + and compilation_config.cudagraph_mode == CUDAGraphMode.NONE + ): + return + + logger.warning_once( + "Disabling vLLM compile and CUDA graphs for the DeepSeek V4 Triton " + "sparse MLA path because the current Triton sparse MLA path is not " + "compile/graph-safe yet." + ) + compilation_config.mode = CompilationMode.NONE + compilation_config.compile_sizes = [] + compilation_config.compile_ranges_endpoints = [] + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + compilation_config.cudagraph_capture_sizes = [] + compilation_config.max_cudagraph_capture_size = 0 + + +def triton_sparse_mla_topk_chunk_size() -> int: + return _TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE + + +def triton_sparse_mla_query_chunk_size() -> int: + return _TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py new file mode 100644 index 000000000000..36c41e4b3bad --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -0,0 +1,838 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Portable sparse MLA Triton kernels.""" + +import torch + +from vllm.triton_utils import tl, triton + + +def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: + """Choose the SM12x sparse MLA head grouping for decode kernels. + + Single-token decode is latency sensitive and does best with one head per + program. Once there are enough query tokens, grouping heads lets the kernel + reuse each dequantized KV row across multiple heads. + """ + + if num_decode_tokens <= 4: + return 1 + if num_decode_tokens < 16: + return 2 + return 4 + + +@triton.jit +def _accumulate_indexed_attention_chunk_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (kv_index >= 0) + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_indexed_sparse_mla_attention_chunk( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = indices.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_indexed_attention_chunk_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int, + num_candidates: int, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _finish_attention_state_with_sink_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + sink_ptr, + output_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + sink = tl.load(sink_ptr + head_idx) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + subset_weight = running_denom * subset_scale + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = subset_weight + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc_values = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc_values = tl.where(has_tokens, acc_values, 0.0) + output = acc_values * subset_scale * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +@triton.jit +def _finish_two_attention_states_with_sink_kernel( + max0_ptr, + denom0_ptr, + acc0_ptr, + max1_ptr, + denom1_ptr, + acc1_ptr, + sink_ptr, + output_ptr, + stride_state0_t: tl.constexpr, + stride_state0_h: tl.constexpr, + stride_acc0_t: tl.constexpr, + stride_acc0_h: tl.constexpr, + stride_acc0_d: tl.constexpr, + stride_state1_t: tl.constexpr, + stride_state1_h: tl.constexpr, + stride_acc1_t: tl.constexpr, + stride_acc1_h: tl.constexpr, + stride_acc1_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state0_offset = token_idx * stride_state0_t + head_idx * stride_state0_h + state1_offset = token_idx * stride_state1_t + head_idx * stride_state1_h + max0 = tl.load(max0_ptr + state0_offset) + denom0 = tl.load(denom0_ptr + state0_offset) + max1 = tl.load(max1_ptr + state1_offset) + denom1 = tl.load(denom1_ptr + state1_offset) + sink = tl.load(sink_ptr + head_idx) + + has0 = denom0 > 0.0 + has1 = denom1 > 0.0 + has_sink = sink > -float("inf") + valid_max0 = tl.where(has0, max0, -float("inf")) + valid_max1 = tl.where(has1, max1, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(tl.maximum(valid_max0, valid_max1), valid_sink) + has_any = has0 | has1 | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_max0 = tl.where(has0, max0, safe_merge_max) + safe_max1 = tl.where(has1, max1, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + scale0 = tl.where(has0, tl.exp(safe_max0 - safe_merge_max), 0.0) + scale1 = tl.where(has1, tl.exp(safe_max1 - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = denom0 * scale0 + denom1 * scale1 + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc0 = tl.load( + acc0_ptr + + token_idx * stride_acc0_t + + head_idx * stride_acc0_h + + offsets * stride_acc0_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc1 = tl.load( + acc1_ptr + + token_idx * stride_acc1_t + + head_idx * stride_acc1_h + + offsets * stride_acc1_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc0 = tl.where(has0, acc0, 0.0) + acc1 = tl.where(has1, acc1, 0.0) + output = (acc0 * scale0 + acc1 * scale1) * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +def finish_two_sparse_mla_attention_states_with_sink( + max_score0: torch.Tensor, + denom0: torch.Tensor, + acc0: torch.Tensor, + max_score1: torch.Tensor, + denom1: torch.Tensor, + acc1: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score0.shape == denom0.shape + assert max_score1.shape == denom1.shape + assert max_score0.shape == max_score1.shape + assert acc0.shape == acc1.shape + assert acc0.shape[:2] == max_score0.shape + assert output.shape[0] == acc0.shape[0] + assert output.shape[1] >= acc0.shape[1] + assert output.shape[2] == acc0.shape[2] + assert attn_sink.shape[0] >= acc0.shape[1] + assert max_score0.dtype == torch.float32 + assert denom0.dtype == torch.float32 + assert acc0.dtype == torch.float32 + assert max_score1.dtype == torch.float32 + assert denom1.dtype == torch.float32 + assert acc1.dtype == torch.float32 + assert max_score0.is_cuda and denom0.is_cuda and acc0.is_cuda + assert max_score1.is_cuda and denom1.is_cuda and acc1.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc0.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_two_attention_states_with_sink_kernel[grid]( + max_score0, + denom0, + acc0, + max_score1, + denom1, + acc1, + attn_sink, + output, + max_score0.stride(0), + max_score0.stride(1), + acc0.stride(0), + acc0.stride(1), + acc0.stride(2), + max_score1.stride(0), + max_score1.stride(1), + acc1.stride(0), + acc1.stride(1), + acc1.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +def finish_sparse_mla_attention_with_sink( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape[0] == acc.shape[0] + assert output.shape[1] >= acc.shape[1] + assert output.shape[2] == acc.shape[2] + assert attn_sink.shape[0] >= acc.shape[1] + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_with_sink_kernel[grid]( + max_score, + denom, + acc, + attn_sink, + output, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) From d3c37395207dfaca57422d6e8666ce6a2bb2dc76 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 17:49:12 +0800 Subject: [PATCH 004/131] Add DeepSeek V4 SM12x fallback ops Signed-off-by: jasl --- .../models/deepseek_v4/common/ops/__init__.py | 2 + .../deepseek_v4/common/ops/cache_utils.py | 68 ++- .../deepseek_v4/nvidia/ops/fp8_einsum.py | 297 ++++++++++ .../nvidia/ops/sm12x_deep_gemm_fallbacks.py | 522 ++++++++++++++++++ .../deepseek_v4/nvidia/ops/sm12x_mqa.py | 481 ++++++++++++++++ 5 files changed, 1353 insertions(+), 17 deletions(-) create mode 100644 vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py create mode 100644 vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py create mode 100644 vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py diff --git a/vllm/models/deepseek_v4/common/ops/__init__.py b/vllm/models/deepseek_v4/common/ops/__init__.py index ff6ee22996d6..f4ff348091dd 100644 --- a/vllm/models/deepseek_v4/common/ops/__init__.py +++ b/vllm/models/deepseek_v4/common/ops/__init__.py @@ -7,6 +7,7 @@ compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, quantize_and_insert_k_cache, + sparse_prefill_combined_topk_size, ) from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant @@ -27,4 +28,5 @@ "mtp_shared_head_rmsnorm", "quantize_and_insert_k_cache", "save_partial_states", + "sparse_prefill_combined_topk_size", ] diff --git a/vllm/models/deepseek_v4/common/ops/cache_utils.py b/vllm/models/deepseek_v4/common/ops/cache_utils.py index ffaec528aa86..5bdad86de066 100644 --- a/vllm/models/deepseek_v4/common/ops/cache_utils.py +++ b/vllm/models/deepseek_v4/common/ops/cache_utils.py @@ -429,6 +429,8 @@ def compute_global_topk_indices_and_lens( block_table: torch.Tensor, block_size: int, is_valid_token: torch.Tensor, + global_topk_indices: torch.Tensor | None = None, + topk_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Map local topk indices to global KV cache slots and count valid entries. @@ -438,8 +440,20 @@ def compute_global_topk_indices_and_lens( 3. Masking padding tokens to length 0 """ num_tokens = topk_indices.shape[0] - global_topk_indices = torch.empty_like(topk_indices) - topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=topk_indices.device) + if global_topk_indices is None: + global_topk_indices = torch.empty_like(topk_indices) + else: + assert global_topk_indices.shape == topk_indices.shape + assert global_topk_indices.dtype == topk_indices.dtype + assert global_topk_indices.device == topk_indices.device + if topk_lens is None: + topk_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert topk_lens.shape == (num_tokens,) + assert topk_lens.dtype == torch.int32 + assert topk_lens.device == topk_indices.device _compute_global_topk_indices_and_lens_kernel[(num_tokens,)]( global_topk_indices, global_topk_indices.stride(0), @@ -486,7 +500,7 @@ def _compute_global_topk_indices_and_lens_kernel( mask=mask, other=-1, ) - is_valid = local_idx >= 0 + is_valid = (local_idx >= 0) & is_valid_token block_indices = local_idx // block_size block_numbers = tl.load( @@ -516,6 +530,14 @@ def _compute_global_topk_indices_and_lens_kernel( _SPARSE_PREFILL_TOPK_ALIGNMENT = 128 +def sparse_prefill_combined_topk_size(topk: int, window_size: int) -> int: + return ( + (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1) + // _SPARSE_PREFILL_TOPK_ALIGNMENT + * _SPARSE_PREFILL_TOPK_ALIGNMENT + ) + + def combine_topk_swa_indices( topk_indices: torch.Tensor, query_start_loc: torch.Tensor, @@ -526,23 +548,35 @@ def combine_topk_swa_indices( topk: int, M: int, N: int, + combined_indices: torch.Tensor | None = None, + combined_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: num_tokens = topk_indices.shape[0] num_reqs = seq_lens.shape[0] - combined_topk = ( - (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1) - // _SPARSE_PREFILL_TOPK_ALIGNMENT - * _SPARSE_PREFILL_TOPK_ALIGNMENT - ) - combined_indices = torch.full( - (num_tokens, combined_topk), - fill_value=-1, - dtype=torch.int32, - device=topk_indices.device, - ) - combined_lens = torch.empty( - num_tokens, dtype=torch.int32, device=topk_indices.device - ) + combined_topk = sparse_prefill_combined_topk_size(topk, window_size) + if combined_indices is None: + combined_indices = torch.full( + (num_tokens, combined_topk), + fill_value=-1, + dtype=torch.int32, + device=topk_indices.device, + ) + else: + assert combined_indices.shape[0] >= num_tokens + assert combined_indices.shape[1] >= combined_topk + assert combined_indices.dtype == torch.int32 + assert combined_indices.device == topk_indices.device + combined_indices = combined_indices[:num_tokens, :combined_topk] + combined_indices.fill_(-1) + if combined_lens is None: + combined_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert combined_lens.shape[0] >= num_tokens + assert combined_lens.dtype == torch.int32 + assert combined_lens.device == topk_indices.device + combined_lens = combined_lens[:num_tokens] NUM_WORKERS = 128 _combine_topk_swa_indices_kernel[(num_reqs, NUM_WORKERS)]( diff --git a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py new file mode 100644 index 000000000000..68cf39061ed8 --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py @@ -0,0 +1,297 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SM12x Triton FP8 einsum kernels for DeepSeek V4.""" + +import torch + +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.custom_op import direct_register_custom_op +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import fp8_einsum + + +def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + exp_bits = scale.view(torch.uint8).to(torch.int32) + fp32_bits = exp_bits << 23 + return fp32_bits.view(torch.float32) + + +@triton.jit +def _deepseek_v4_sm12x_fp8_einsum_kernel( + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_groups: tl.constexpr, + out_rank: tl.constexpr, + hidden_size: tl.constexpr, + a_stride_token: tl.constexpr, + a_stride_group: tl.constexpr, + a_stride_hidden: tl.constexpr, + a_scale_stride_token: tl.constexpr, + a_scale_stride_group: tl.constexpr, + a_scale_stride_hidden: tl.constexpr, + b_stride_group: tl.constexpr, + b_stride_out: tl.constexpr, + b_stride_hidden: tl.constexpr, + b_scale_stride_group: tl.constexpr, + b_scale_stride_out: tl.constexpr, + b_scale_stride_hidden: tl.constexpr, + out_stride_token: tl.constexpr, + out_stride_group: tl.constexpr, + out_stride_rank: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, + BLOCK_OUT: tl.constexpr, + BLOCK_HIDDEN: tl.constexpr, +) -> None: + token_block = tl.program_id(0) + out_block = tl.program_id(1) + group = tl.program_id(2) + + token_offsets = token_block * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) + out_offsets = out_block * BLOCK_OUT + tl.arange(0, BLOCK_OUT) + hidden_offsets = tl.arange(0, BLOCK_HIDDEN) + accum = tl.zeros((BLOCK_TOKENS, BLOCK_OUT), dtype=tl.float32) + + for hidden_start in range(0, hidden_size, BLOCK_HIDDEN): + hidden = hidden_start + hidden_offsets + a = tl.load( + a_ptr + + token_offsets[:, None] * a_stride_token + + group * a_stride_group + + hidden[None, :] * a_stride_hidden, + mask=(token_offsets[:, None] < num_tokens) + & (hidden[None, :] < hidden_size), + other=0.0, + ) + b = tl.load( + b_ptr + + group * b_stride_group + + out_offsets[None, :] * b_stride_out + + hidden[:, None] * b_stride_hidden, + mask=(out_offsets[None, :] < out_rank) & (hidden[:, None] < hidden_size), + other=0.0, + ) + raw = tl.dot(a, b, out_dtype=tl.float32) + hidden_scale_block = hidden_start // BLOCK_HIDDEN + a_scale = tl.load( + a_scale_ptr + + token_offsets * a_scale_stride_token + + group * a_scale_stride_group + + hidden_scale_block * a_scale_stride_hidden, + mask=token_offsets < num_tokens, + other=0.0, + ) + b_scale = tl.load( + b_scale_ptr + + group * b_scale_stride_group + + (out_offsets // 128) * b_scale_stride_out + + hidden_scale_block * b_scale_stride_hidden, + mask=out_offsets < out_rank, + other=0.0, + ) + accum += raw * a_scale[:, None] * b_scale[None, :] + + tl.store( + out_ptr + + token_offsets[:, None] * out_stride_token + + group * out_stride_group + + out_offsets[None, :] * out_stride_rank, + accum, + mask=(token_offsets[:, None] < num_tokens) & (out_offsets[None, :] < out_rank), + ) + + +def deepseek_v4_sm12x_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + """Compute ``bhr,hdr->bhd`` with FP32 block scales on SM12x. + + ``a`` is the transposed output of ``fused_inv_rope_fp8_quant`` with shape + ``[tokens, groups, hidden]``. ``b`` is ``wo_a`` reshaped to + ``[groups, out_rank, hidden]``. + """ + num_tokens, num_groups, hidden_size = a.shape + b_groups, out_rank, b_hidden_size = b.shape + assert b_groups == num_groups + assert b_hidden_size == hidden_size + assert out.shape == (num_tokens, num_groups, out_rank) + assert hidden_size % 128 == 0 + assert out_rank % 128 == 0 + assert a.dtype == torch.float8_e4m3fn + assert b.dtype == torch.float8_e4m3fn + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if a_scale.dtype == e8m0_dtype: + a_scale = _upcast_e8m0_to_fp32(a_scale) + if b_scale.dtype == e8m0_dtype: + b_scale = _upcast_e8m0_to_fp32(b_scale) + assert a_scale.dtype == torch.float32 + assert b_scale.dtype == torch.float32 + + if num_tokens == 0: + return + + block_tokens = 16 + block_out = 128 + block_hidden = 128 + grid = ( + triton.cdiv(num_tokens, block_tokens), + triton.cdiv(out_rank, block_out), + num_groups, + ) + _deepseek_v4_sm12x_fp8_einsum_kernel[grid]( + a, + a_scale, + b, + b_scale, + out, + num_tokens, + num_groups, + out_rank, + hidden_size, + a.stride(0), + a.stride(1), + a.stride(2), + a_scale.stride(0), + a_scale.stride(1), + a_scale.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + b_scale.stride(0), + b_scale.stride(1), + b_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_TOKENS=block_tokens, + BLOCK_OUT=block_out, + BLOCK_HIDDEN=block_hidden, + num_warps=4, + num_stages=3, + ) + + +def deepseek_v4_fp8_einsum_config( + capability_major: int, +) -> tuple[tuple[int, int, int], bool]: + if capability_major == 10: + return (1, 1, 128), True + return (1, 128, 128), False + + +def _use_deepseek_v4_sm12x_triton_fp8_einsum( + equation: str, + recipe: list[int], + b_scale: torch.Tensor, +) -> bool: + capability = current_platform.get_device_capability() + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + return ( + capability is not None + and capability.major == 12 + and equation == "bhr,hdr->bhd" + and tuple(recipe) == (1, 128, 128) + and b_scale.dtype in (torch.float32, e8m0_dtype) + ) + + +def deepseek_v4_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: list[int], +) -> None: + if equation == "bhr,hdr->bhd" and b.dim() == 2: + num_groups = out.shape[1] + out_rank = out.shape[2] + hidden_size = a.shape[2] + if b.shape[0] % out_rank != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight rows must be divisible by " + f"out_rank={out_rank}, got {b.shape[0]}" + ) + b_groups = b.shape[0] // out_rank + group_start = 0 + if b_groups != num_groups: + if b_groups % num_groups != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight groups must match the " + "TP-local output groups or be an integer multiple of " + f"them, got weight_groups={b_groups}, " + f"output_groups={num_groups}" + ) + group_partitions = b_groups // num_groups + group_start = ( + get_tensor_model_parallel_rank() % group_partitions + ) * num_groups + b = b.view(b_groups, out_rank, hidden_size) + if group_start != 0 or b_groups != num_groups: + b = b.narrow(0, group_start, num_groups) + + if b_scale.dim() == 2: + scale_mn = recipe[1] + scale_k_pack = 4 if b_scale.dtype == torch.int32 else 1 + scale_k = recipe[2] * scale_k_pack + scale_out_blocks = (out_rank + scale_mn - 1) // scale_mn + scale_hidden_blocks = (hidden_size + scale_k - 1) // scale_k + if b_scale.shape[0] % scale_out_blocks != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum scale rows must be divisible by " + f"scale_out_blocks={scale_out_blocks}, " + f"got {b_scale.shape[0]}" + ) + scale_groups = b_scale.shape[0] // scale_out_blocks + if scale_groups not in (num_groups, b_groups): + raise RuntimeError( + "DeepSeek V4 fp8 einsum scale groups must match the " + "TP-local output groups or weight groups, got " + f"scale_groups={scale_groups}, output_groups={num_groups}, " + f"weight_groups={b_groups}" + ) + b_scale = b_scale.view( + scale_groups, + scale_out_blocks, + scale_hidden_blocks, + ) + if scale_groups == b_groups and scale_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + elif b_scale.dim() == 3 and b_scale.shape[0] == b_groups: + if b_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + + if _use_deepseek_v4_sm12x_triton_fp8_einsum(equation, recipe, b_scale): + deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, out) + return + + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + + +def deepseek_v4_fp8_einsum_fake( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: list[int], +) -> None: + return None + + +direct_register_custom_op( + op_name="deepseek_v4_fp8_einsum", + op_func=deepseek_v4_fp8_einsum, + mutates_args=["out"], + fake_impl=deepseek_v4_fp8_einsum_fake, +) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py new file mode 100644 index 000000000000..8fa9b2697c91 --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -0,0 +1,522 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SM12x fallback implementations for DeepGEMM-only interfaces.""" + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +_SM120_MQA_LOGITS_MAX_SCORE_BYTES = 64 * 1024 * 1024 +_SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192 + + +def _fp8_mqa_logits_head_chunk_size( + seq_len: int, + seq_len_kv: int, + num_heads: int, +) -> int: + # The SM120 torch path is used on long prefill paths where materializing + # [head_chunk, M, N] scores can otherwise allocate multiple GiB. Keep the + # transient score tensor bounded, while still using larger head chunks for + # short prompts where they are faster. + score_elems_per_head = max(1, seq_len * seq_len_kv) + max_heads = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_head * 4) + return max(1, min(8, num_heads, max_heads)) + + +def _fp8_mqa_logits_k_chunk_size( + seq_len: int, + seq_len_kv: int, + head_chunk_size: int, +) -> int: + score_elems_per_key = max(1, seq_len * head_chunk_size) + max_keys = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_key * 4) + return max(1, min(seq_len_kv, max_keys)) + + +def _fp8_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA logits torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + logits = torch.zeros( + (seq_len, seq_len_kv), device=q_values.device, dtype=torch.float32 + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + k_chunk_size = _fp8_mqa_logits_k_chunk_size( + seq_len, seq_len_kv, head_end - head_start + ) + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + logits[:, k_start:k_end].add_( + scores[0] if scores.shape[0] == 1 else scores.sum(dim=0) + ) + + if clean_logits: + offsets = torch.arange(seq_len_kv, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + logits = logits.masked_fill(~valid, float("-inf")) + + return logits + + +def _fp8_mqa_logits_topk_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_tokens: int, + out: torch.Tensor | None = None, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA top-k torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + if out is None: + out = torch.empty( + (seq_len, topk_tokens), device=q_values.device, dtype=torch.int32 + ) + else: + assert out.shape == (seq_len, topk_tokens) + assert out.dtype == torch.int32 + out.fill_(-1) + + best_values = torch.full( + (seq_len, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + k_chunk_size = _fp8_mqa_logits_k_chunk_size(seq_len, seq_len_kv, head_chunk_size) + max_chunk_topk = min(topk_tokens, k_chunk_size) + chunk_values_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (seq_len, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + chunk_logits = torch.zeros( + (seq_len, k_end - k_start), + device=q_values.device, + dtype=torch.float32, + ) + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + chunk_logits.add_(scores[0] if scores.shape[0] == 1 else scores.sum(dim=0)) + + offsets = torch.arange(k_start, k_end, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + chunk_logits.masked_fill_(~valid, float("-inf")) + + chunk_topk = min(topk_tokens, k_end - k_start) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(k_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(out) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=out) + best_values, next_best_values = next_best_values, best_values + out.masked_fill_(~torch.isfinite(best_values), -1) + + return out + + +def fp8_fp4_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 MQA top-k indices without materializing full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + _fp8_mqa_logits_topk_torch( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices.shape[1], + out=topk_indices, + ) + return True + + +def _fp8_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q + if clean_logits and q_scale is None and q_values.dim() == 3 and kv[0].dim() == 2: + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + fp8_mqa_logits_triton, + ) + + return fp8_mqa_logits_triton(q_values, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + return _fp8_mqa_logits_torch( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) + + +def _fp8_paged_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 paged MQA torch path only supports FP8 Q") + + batch_size, next_n, num_heads, head_dim = q_values.shape + head_dim_with_scale = kv_cache.shape[-1] + assert head_dim_with_scale > head_dim + assert weights.shape == (batch_size * next_n, num_heads) + assert context_lens.shape == (batch_size, next_n) + + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + _view_packed_fp8_paged_mqa_kv_cache, + ) + + kv_values, kv_scales = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_kv, _, _ = kv_values.shape + logits = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + + q_f32 = q_values.float() + score_bytes = _SM120_MQA_LOGITS_MAX_SCORE_BYTES + max_tokens_per_chunk = max(1, score_bytes // max(1, num_heads * 4)) + token_offsets_cache: dict[int, torch.Tensor] = {} + + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + context_len = int(context_lens[batch_idx, next_idx].item()) + if context_len <= 0: + continue + + q_row = q_f32[batch_idx, next_idx] + row_weights = weights[row] + for token_start in range(0, context_len, max_tokens_per_chunk): + token_end = min(context_len, token_start + max_tokens_per_chunk) + chunk_len = token_end - token_start + token_offsets = token_offsets_cache.get(chunk_len) + if token_offsets is None or token_offsets.device != q_values.device: + token_offsets = torch.arange( + chunk_len, device=q_values.device, dtype=torch.long + ) + token_offsets_cache[chunk_len] = token_offsets + token_ids = token_start + token_offsets + logical_blocks = token_ids // block_kv + token_in_block = token_ids - logical_blocks * block_kv + physical_blocks = block_tables[batch_idx, logical_blocks] + kv_chunk = kv_values[physical_blocks, token_in_block, 0].float() + scale_chunk = kv_scales[physical_blocks, token_in_block, 0].squeeze(-1) + kv_chunk.mul_(scale_chunk[:, None]) + scores = torch.matmul(q_row, kv_chunk.T) + scores.relu_() + scores.mul_(row_weights[:, None]) + logits[row, token_start:token_end] = scores.sum(dim=0) + + return logits + + +def _fp8_paged_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if ( + q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + fp8_paged_mqa_logits_triton, + ) + + return fp8_paged_mqa_logits_triton( + q_values, kv_cache, weights, context_lens, block_tables, max_model_len + ) + logger.warning_once( + "SM12x paged-MQA falling back to the torch reference path " + "(q_scale=%s, q.dim=%s, kv_cache.dtype=%s, kv_cache.shape[-1]=%s, " + "q_values.shape[-1]=%s). This path is intended for correctness checks " + "and is not graph-compatible; expect a large per-step latency.", + "set" if q_scale is not None else "None", + q_values.dim(), + kv_cache.dtype, + kv_cache.shape[-1] if kv_cache.dim() else None, + q_values.shape[-1], + ) + return _fp8_paged_mqa_logits_torch( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) + + +def fp8_fp4_paged_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 paged MQA top-k indices without full logits.""" + q_values, q_scale = q + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + return False + + num_rows = q_values.shape[0] * q_values.shape[1] + topk_tokens = topk_indices.shape[1] + assert topk_indices.shape == (num_rows, topk_tokens) + assert topk_indices.dtype == torch.int32 + topk_indices.fill_(-1) + if num_rows == 0 or topk_tokens == 0 or max_model_len == 0: + return True + + best_values = torch.full( + (num_rows, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + chunk_size = max(1, _SM120_PAGED_MQA_TOPK_CHUNK_SIZE) + max_chunk_topk = min(topk_tokens, chunk_size) + chunk_values_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (num_rows, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + fp8_paged_mqa_logits_triton, + ) + + for token_start in range(0, max_model_len, chunk_size): + token_count = min(chunk_size, max_model_len - token_start) + chunk_logits = fp8_paged_mqa_logits_triton( + q_values, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + token_start=token_start, + token_count=token_count, + ) + chunk_topk = min(topk_tokens, token_count) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(token_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(topk_indices) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=topk_indices) + best_values, next_best_values = next_best_values, best_values + topk_indices.masked_fill_(~torch.isfinite(best_values), -1) + + return True + + +def _tf32_hc_prenorm_gemm_torch( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + """Portable SM12x HyperConnection prenorm GEMM fallback. + + DeepGEMM's split ABI only requires that downstream consumers recover the + full result by summing over the split dimension. Keep the implementation + simple by writing the full product to split zero and clearing the rest. + """ + del num_split + product = x.float() @ fn.float().T + norm = x.float().square().sum(dim=-1) + + if out.dim() == 3: + out.zero_() + sqrsum.zero_() + out[0].copy_(product) + sqrsum[0].copy_(norm) + else: + out.copy_(product) + sqrsum.copy_(norm) + return out + + +def _tf32_hc_prenorm_gemm_sm12x( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + if out.dim() == 3 and sqrsum.dim() == 2: + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + tf32_hc_prenorm_gemm_triton, + ) + + tf32_hc_prenorm_gemm_triton(x, fn, out, sqrsum, num_split) + return out + + return _tf32_hc_prenorm_gemm_torch(x, fn, out, sqrsum, num_split) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py new file mode 100644 index 000000000000..85dab1f6d9a1 --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -0,0 +1,481 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Triton fallback kernels used by the local DeepSeek V4 path.""" + +import torch + +from vllm.triton_utils import tl, triton + + +def _view_packed_fp8_paged_mqa_kv_cache( + kv_cache: torch.Tensor, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Return FP8 values and fp32 scales from indexer cache block storage.""" + if kv_cache.dtype != torch.uint8: + raise TypeError(f"Expected uint8 kv_cache, got {kv_cache.dtype}") + if kv_cache.dim() == 3: + num_blocks, block_size, head_dim_with_scale = kv_cache.shape + num_kv_heads = 1 + elif kv_cache.dim() == 4: + num_blocks, block_size, num_kv_heads, head_dim_with_scale = kv_cache.shape + else: + raise ValueError( + f"Expected 3D or 4D kv_cache, got {kv_cache.dim()} dimensions" + ) + if num_kv_heads != 1: + raise ValueError(f"Expected one KV head, got {num_kv_heads}") + + scale_bytes = head_dim_with_scale - head_dim + if scale_bytes <= 0 or scale_bytes % torch.float32.itemsize != 0: + raise ValueError( + "Expected kv_cache last dimension to contain FP8 values followed " + f"by fp32 scale bytes; got head_dim={head_dim}, " + f"last_dim={head_dim_with_scale}" + ) + + block_stride = kv_cache.stride(0) + base_storage_offset = kv_cache.storage_offset() + scale_elems = scale_bytes // torch.float32.itemsize + kv_values = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, head_dim), + stride=(block_stride, head_dim, head_dim, 1), + storage_offset=base_storage_offset, + ).view(torch.float8_e4m3fn) + kv_scale = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, scale_bytes), + stride=(block_stride, scale_bytes, scale_bytes, 1), + storage_offset=base_storage_offset + block_size * head_dim, + ).view(torch.float32) + return kv_values, kv_scale[..., :scale_elems] + + +@triton.jit +def _fp8_mqa_logits_kernel( + q_ptr, + k_ptr, + scale_ptr, + weights_ptr, + cu_seqlen_ks_ptr, + cu_seqlen_ke_ptr, + logits_ptr, + num_q: tl.constexpr, + seq_len_kv: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kn: tl.constexpr, + stride_kd: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_q + valid_n = offs_n < seq_len_kv + seq_start = tl.load(cu_seqlen_ks_ptr + offs_m, mask=valid_m, other=0) + seq_end = tl.load(cu_seqlen_ke_ptr + offs_m, mask=valid_m, other=0) + seq_mask = (offs_n[None, :] >= seq_start[:, None]) & ( + offs_n[None, :] < seq_end[:, None] + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + offs_m[:, None] * stride_qm + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + k_ptr + offs_n[:, None] * stride_kn + d[None, :] * stride_kd, + mask=valid_n[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.dot(q, tl.trans(k), input_precision="tf32") + scale = tl.load(scale_ptr + offs_n, mask=valid_n, other=0.0) + weighted = tl.maximum(scores * scale[None, :], 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(seq_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_n[None, :] * stride_ln, + logits, + mask=store_mask, + ) + + +def fp8_mqa_logits_triton( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + k_fp8, scale = kv + num_q, num_heads, head_dim = q.shape + seq_len_kv = k_fp8.shape[0] + logits = torch.empty( + (num_q, seq_len_kv), + device=q.device, + dtype=torch.float32, + ) + if num_q == 0 or seq_len_kv == 0: + return logits + + grid = (triton.cdiv(num_q, 8), triton.cdiv(seq_len_kv, 64)) + _fp8_mqa_logits_kernel[grid]( + q, + k_fp8, + scale, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + logits, + num_q, + seq_len_kv, + num_heads, + head_dim, + q.stride(0), + q.stride(1), + q.stride(2), + k_fp8.stride(0), + k_fp8.stride(1), + weights.stride(0), + weights.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=8, + BLOCK_N=64, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + q_ptr, + kv_ptr, + scale_ptr, + weights_ptr, + context_lens_ptr, + block_tables_ptr, + logits_ptr, + token_start, + num_rows: tl.constexpr, + logits_width: tl.constexpr, + next_n: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + block_size: tl.constexpr, + stride_qb: tl.constexpr, + stride_qn: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvb: tl.constexpr, + stride_kvs: tl.constexpr, + stride_kvd: tl.constexpr, + stride_sb: tl.constexpr, + stride_ss: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_clb: tl.constexpr, + stride_cln: tl.constexpr, + stride_btb: tl.constexpr, + stride_btk: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = token_start + offs_local_n + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_rows + valid_n = offs_local_n < logits_width + batch = offs_m // next_n + q_pos = offs_m - batch * next_n + context_len = tl.load( + context_lens_ptr + batch * stride_clb + q_pos * stride_cln, + mask=valid_m, + other=0, + ) + context_mask = valid_n[None, :] & (offs_n[None, :] < context_len[:, None]) + + block_rank = offs_n // block_size + block_offset = offs_n - block_rank * block_size + block_idx = tl.load( + block_tables_ptr + + batch[:, None] * stride_btb + + block_rank[None, :] * stride_btk, + mask=valid_m[:, None] & valid_n[None, :], + other=0, + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + scale = tl.load( + scale_ptr + block_idx * stride_sb + block_offset[None, :] * stride_ss, + mask=context_mask, + other=0.0, + ) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + batch[:, None] * stride_qb + + q_pos[:, None] * stride_qn + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + kv_ptr + + block_idx[:, :, None] * stride_kvb + + block_offset[None, :, None] * stride_kvs + + d[None, None, :] * stride_kvd, + mask=context_mask[:, :, None] & (d[None, None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.sum(q[:, None, :] * k, axis=2) + weighted = tl.maximum(scores * scale, 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(context_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_local_n[None, :] * stride_ln, + logits, + mask=store_mask, + ) + + +def fp8_paged_mqa_logits_triton( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + token_start: int = 0, + token_count: int | None = None, +) -> torch.Tensor: + batch_size, next_n, num_heads, head_dim = q.size() + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_size, _, _ = kv_values.size() + num_rows = batch_size * next_n + if token_count is None: + token_count = max_model_len - token_start + assert token_start >= 0 + assert token_count >= 0 + assert token_start + token_count <= max_model_len + logits = torch.empty( + (num_rows, token_count), + device=q.device, + dtype=torch.float32, + ) + if num_rows == 0 or token_count == 0: + return logits + + context_lens_2d = context_lens.reshape(batch_size, -1) + if context_lens_2d.shape[1] == 1 and next_n != 1: + context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() + grid = (triton.cdiv(num_rows, 4), triton.cdiv(token_count, 64)) + _fp8_paged_mqa_logits_kernel[grid]( + q, + kv_values, + kv_scale, + weights, + context_lens_2d, + block_tables, + logits, + token_start, + num_rows, + token_count, + next_n, + num_heads, + head_dim, + block_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv_values.stride(0), + kv_values.stride(1), + kv_values.stride(3), + kv_scale.stride(0), + kv_scale.stride(1), + weights.stride(0), + weights.stride(1), + context_lens_2d.stride(0), + context_lens_2d.stride(1), + block_tables.stride(0), + block_tables.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=4, + BLOCK_N=64, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +@triton.jit +def _tf32_hc_prenorm_gemm_kernel( + x_ptr, + fn_ptr, + out_ptr, + sqrsum_ptr, + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_fnn: tl.constexpr, + stride_fnk: tl.constexpr, + stride_outs: tl.constexpr, + stride_outm: tl.constexpr, + stride_outn: tl.constexpr, + stride_sqs: tl.constexpr, + stride_sqm: tl.constexpr, + NUM_SPLIT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_s = tl.program_id(2) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + split_k = tl.cdiv(K, NUM_SPLIT) + split_begin = pid_s * split_k + split_end = tl.minimum(split_begin + split_k, K) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + sq = tl.zeros((BLOCK_M,), dtype=tl.float32) + + for k0 in tl.range(0, split_k, BLOCK_K): + k = split_begin + k0 + offs_k + k_mask = k < split_end + x = tl.load( + x_ptr + offs_m[:, None] * stride_xm + k[None, :] * stride_xk, + mask=(offs_m[:, None] < M) & k_mask[None, :], + other=0.0, + ).to(tl.float32) + fn = tl.load( + fn_ptr + offs_n[None, :] * stride_fnn + k[:, None] * stride_fnk, + mask=(offs_n[None, :] < N) & k_mask[:, None], + other=0.0, + ).to(tl.float32) + + acc += tl.dot(x, fn, input_precision="tf32", out_dtype=tl.float32) + sq += tl.sum(x * x, axis=1) + + tl.store( + out_ptr + + pid_s * stride_outs + + offs_m[:, None] * stride_outm + + offs_n[None, :] * stride_outn, + acc, + mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), + ) + + if pid_n == 0: + tl.store( + sqrsum_ptr + pid_s * stride_sqs + offs_m * stride_sqm, + sq, + mask=offs_m < M, + ) + + +def tf32_hc_prenorm_gemm_triton( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> None: + assert x.dim() == 2 + assert fn.dim() == 2 + assert out.dim() == 3 + assert sqrsum.dim() == 2 + + m, k = x.shape + n = fn.shape[0] + assert fn.shape[1] == k + assert out.shape == (num_split, m, n) + assert sqrsum.shape == (num_split, m) + + if m == 0: + return + + block_m = 16 + block_n = triton.next_power_of_2(n) + block_n = min(max(block_n, 16), 32) + block_k = 64 + grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n), num_split) + _tf32_hc_prenorm_gemm_kernel[grid]( + x, + fn, + out, + sqrsum, + m, + k, + n, + x.stride(0), + x.stride(1), + fn.stride(0), + fn.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + sqrsum.stride(0), + sqrsum.stride(1), + num_split, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=4, + ) From f8f18ce66b5378e00124436e5b7b183b8269a602 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 17:49:21 +0800 Subject: [PATCH 005/131] Route SM12x DeepGEMM fallbacks Signed-off-by: jasl --- .../test_sm120_deepgemm_fallbacks.py | 184 ++++++++++++++++ .../kernels/linear/scaled_mm/cutlass.py | 45 ++++ .../layers/fused_moe/routed_experts.py | 33 ++- .../layers/quantization/utils/fp8_utils.py | 49 ++++- .../layers/sparse_attn_indexer.py | 197 +++++++++++++----- vllm/utils/deep_gemm.py | 119 ++++++++++- 6 files changed, 554 insertions(+), 73 deletions(-) create mode 100644 tests/v1/attention/test_sm120_deepgemm_fallbacks.py diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py new file mode 100644 index 000000000000..fa8a31b3676b --- /dev/null +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm.utils.deep_gemm as deep_gemm_utils +from vllm.model_executor.layers.sparse_attn_indexer import ( + _decode_logits_width, + _decode_topk_logits_width, +) +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv +from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + +def _make_indexer_kv_cache( + kv_fp8: torch.Tensor, + kv_scale: torch.Tensor, +) -> torch.Tensor: + num_blocks, block_size, num_kv_heads, head_dim = kv_fp8.shape + assert num_kv_heads == 1 + fused_kv = torch.empty( + num_blocks, + block_size, + 1, + head_dim + torch.float32.itemsize, + device=kv_fp8.device, + dtype=torch.uint8, + ) + block_stride = fused_kv.stride(0) + kv_values = torch.as_strided( + fused_kv, + size=kv_fp8.shape, + stride=(block_stride, head_dim, head_dim, 1), + ) + kv_scales = torch.as_strided( + fused_kv, + size=(num_blocks, block_size, 1, torch.float32.itemsize), + stride=(block_stride, torch.float32.itemsize, torch.float32.itemsize, 1), + storage_offset=block_size * head_dim, + ) + kv_values.copy_(kv_fp8.view(torch.uint8)) + kv_scales.copy_(kv_scale.contiguous().view(torch.uint8)) + return fused_kv + + +def _reference_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_fp8: torch.Tensor, + kv_scale: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + batch_size, next_n, _, _ = q_fp8.shape + _, block_size, _, _ = kv_fp8.shape + logits = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device=q_fp8.device, + dtype=torch.float32, + ) + q = q_fp8.float() + kv = kv_fp8.float() * kv_scale.float() + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + context_len = min( + int(context_lens[batch_idx, next_idx].item()), + max_model_len, + ) + for token_idx in range(context_len): + block_idx = block_tables[batch_idx, token_idx // block_size] + block_offset = token_idx % block_size + k = kv[block_idx, block_offset, 0] + scores = (q[batch_idx, next_idx] * k).sum(dim=-1).relu() + logits[row, token_idx] = (scores * weights[row]).sum() + return logits + + +def test_decode_logits_width_uses_active_context_bound(): + assert _decode_logits_width(262144, 1024) == 1024 + assert _decode_logits_width(4096, 8192) == 4096 + assert _decode_logits_width(4096, 0) == 4096 + assert _decode_logits_width(0, 1024) == 0 + + +def test_decode_topk_logits_width_keeps_topk_kernel_width(): + assert _decode_topk_logits_width(262144, 1024, 512) == 1024 + assert _decode_topk_logits_width(262144, 128, 512) == 512 + assert _decode_topk_logits_width(300, 128, 512) == 300 + assert _decode_topk_logits_width(0, 128, 512) == 0 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_paged_mqa_direct_topk_matches_truncated_decode_width( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(7) + batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32 + block_size, max_model_len, num_blocks = 4, 64, 16 + active_max_len = 13 + topk_tokens = 6 + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_SM120_PAGED_MQA_TOPK_CHUNK_SIZE", + 7, + ) + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = _make_indexer_kv_cache(kv_fp8, kv_scale) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor( + [[7, active_max_len], [9, 12]], device="cuda", dtype=torch.int32 + ) + block_tables = ( + torch.arange( + batch_size * cdiv(max_model_len, block_size), + device="cuda", + dtype=torch.int32, + ).reshape(batch_size, -1) + % num_blocks + ) + + full_width_topk = torch.empty( + batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32 + ) + truncated_width_topk = torch.empty_like(full_width_topk) + + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + max_model_len, + full_width_topk, + ) + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + active_max_len, + truncated_width_topk, + ) + + reference_logits = _reference_paged_mqa_logits( + q_fp8, + kv_fp8, + kv_scale, + weights, + context_lens, + block_tables, + active_max_len, + ) + expected_topk = torch.topk(reference_logits, topk_tokens, dim=1).indices.to( + torch.int32 + ) + + torch.testing.assert_close(truncated_width_topk, full_width_topk, rtol=0, atol=0) + torch.testing.assert_close(truncated_width_topk, expected_topk, rtol=0, atol=0) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 9f69ab0c7377..0b679792267c 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -9,6 +9,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, @@ -30,6 +33,20 @@ ) +def _is_sm12x_compute_capability(compute_capability) -> bool: + if compute_capability is None: + return current_platform.is_device_capability_family(120) + + if isinstance(compute_capability, tuple): + return compute_capability[0] == 12 + + to_int = getattr(compute_capability, "to_int", None) + if callable(to_int): + return to_int() // 10 == 12 + + return int(compute_capability) // 10 == 12 + + class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def is_supported( @@ -286,6 +303,9 @@ def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None: @classmethod def is_supported(cls, compute_capability=None): + if _is_sm12x_compute_capability(compute_capability): + return False, "CUTLASS block-scaled FP8 GEMM is not supported on SM12x." + if not CUTLASS_BLOCK_FP8_SUPPORTED: return ( False, @@ -309,6 +329,31 @@ def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): ) return True, None + def process_weights_after_loading(self, layer: torch.nn.Module): + super().process_weights_after_loading(layer) + params = self._get_layer_params(layer) + weight_scale = ( + params.weight_scale + if params.weight_scale_inv is None + else params.weight_scale_inv + ) + scale_attr_name = ( + params.WEIGHT_SCALE + if params.weight_scale_inv is None + else params.WEIGHT_SCALE_INV + ) + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and weight_scale is not None + and weight_scale.dtype == e8m0_dtype + ): + replace_parameter( + layer, + scale_attr_name, + _upcast_e8m0_to_fp32(weight_scale), + ) + def apply_block_scaled_mm( self, A: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/routed_experts.py b/vllm/model_executor/layers/fused_moe/routed_experts.py index 669d1d376902..0a4a1a37d958 100644 --- a/vllm/model_executor/layers/fused_moe/routed_experts.py +++ b/vllm/model_executor/layers/fused_moe/routed_experts.py @@ -275,6 +275,19 @@ def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: # Weight Loading Methods # + @staticmethod + def _normalize_loaded_weight_for_copy( + expert_data: torch.Tensor, loaded_weight: torch.Tensor + ) -> torch.Tensor: + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and expert_data.dtype == torch.uint8 + and loaded_weight.dtype == e8m0_dtype + ): + return loaded_weight.view(torch.uint8) + return loaded_weight + def _load_per_tensor_weight_scale( self, shard_id: str, @@ -288,10 +301,12 @@ def _load_per_tensor_weight_scale( # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight + target = param_data[expert_id][idx] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) # If we are in the row parallel case (down_proj) elif shard_id == "w2": - param_data[expert_id] = loaded_weight + target = param_data[expert_id] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) def _load_combined_w13_weight_scale( self, @@ -308,7 +323,7 @@ def _load_combined_w13_weight_scale( loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size ) - param.copy_(loaded_weight) + param.copy_(self._normalize_loaded_weight_for_copy(param, loaded_weight)) def _load_model_weight_or_group_weight_scale( self, @@ -366,7 +381,9 @@ def _load_per_channel_weight_scale( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) elif shard_id in ("w1", "w3"): self._load_w13( shard_id=shard_id, @@ -482,7 +499,9 @@ def _load_w13( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_w2( self, @@ -517,7 +536,9 @@ def _load_w2( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_single_value( self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index be1167332ed8..da91124e44ae 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -889,6 +889,35 @@ def get_w8a8_block_fp8_configs( return None +def _get_default_w8a8_block_fp8_config( + M: int, + block_n: int, + block_k: int, +) -> dict[str, Any]: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_n and + # BLOCK_SIZE_K must be divisible by block_k. + # M-aware tuning for low-M decode: BLOCK_SIZE_M=64 wastes most of the + # M-dim for single-request decode and short MTP-style draft batches. SM12x + # keeps benefiting from the low-M tile through M=32 on DeepSeek V4 shapes. + capability = current_platform.get_device_capability() + capability_major = getattr(capability, "major", None) + if capability_major is None and capability is not None: + capability_major = capability[0] + low_m_limit = 32 if capability_major == 12 else 8 + if low_m_limit >= M: + block_m, num_stages = 16, (2 if current_platform.is_rocm() else 3) + else: + block_m, num_stages = 64, 2 + return { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": num_stages, + } + + def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -933,6 +962,12 @@ def w8a8_triton_block_scaled_mm( N, K = B.shape assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if e8m0_dtype is not None: + if As.dtype == e8m0_dtype: + As = _upcast_e8m0_to_fp32(As) + if Bs.dtype == e8m0_dtype: + Bs = _upcast_e8m0_to_fp32(Bs) C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) @@ -942,17 +977,7 @@ def w8a8_triton_block_scaled_mm( # Get the optimal config if there is one config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: - # Default config - # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] - # BLOCK_SIZE_K must be divisible by block_size[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_size[0], - "BLOCK_SIZE_K": block_size[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 2, - } + config = _get_default_w8a8_block_fp8_config(M, block_size[0], block_size[1]) def grid(META): return ( @@ -1289,6 +1314,8 @@ def create_fp8_scale_parameter( if dtype == torch.float32: scale[:] = torch.finfo(torch.float32).min + elif dtype == getattr(torch, "float8_e8m0fnu", None): + scale[:] = 0 set_weight_attrs(scale, {"scale_type": "weight_scale"}) return scale diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 45c5d5f78198..000241b46b08 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -4,7 +4,6 @@ import torch -import vllm.envs as envs from vllm import _custom_ops as ops from vllm._aiter_ops import rocm_aiter_ops from vllm.compilation.breakable_cudagraph import eager_break_during_capture @@ -14,7 +13,9 @@ from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( fp8_fp4_mqa_logits, + fp8_fp4_mqa_topk_indices, fp8_fp4_paged_mqa_logits, + fp8_fp4_paged_mqa_topk_indices, has_deep_gemm, ) from vllm.utils.torch_utils import ( @@ -25,6 +26,7 @@ ) from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, + sparse_indexer_max_logits_bytes, ) from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.worker.workspace import current_workspace_manager @@ -32,11 +34,58 @@ logger = init_logger(__name__) RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 +SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH = 4096 +SM120_SHORT_ROW_TOPK_MAX_WIDTH = 12288 # MXFP4 layout: 2 values packed per byte, ue8m0 (1-byte) scale per block of 32. MXFP4_BLOCK_SIZE = 32 +def _should_use_sm120_short_row_topk_decode( + topk_tokens: int, + logits_width: int, + is_cuda_sm120: bool, +) -> bool: + if not is_cuda_sm120 or topk_tokens != 512: + return False + if logits_width <= SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH: + return True + return logits_width < SM120_SHORT_ROW_TOPK_MAX_WIDTH + + +def _use_sm120_short_row_topk_decode( + logits: torch.Tensor, + topk_tokens: int, +) -> bool: + return _should_use_sm120_short_row_topk_decode( + topk_tokens, + logits.shape[1], + current_platform.is_cuda() + and current_platform.is_device_capability_family(120), + ) + + +def _decode_logits_width(max_model_len: int, max_seq_len: int) -> int: + if max_model_len <= 0: + return 0 + if max_seq_len <= 0: + return max_model_len + return min(max_model_len, max_seq_len) + + +def _decode_topk_logits_width( + max_model_len: int, max_seq_len: int, topk_tokens: int +) -> int: + logits_width = _decode_logits_width(max_model_len, max_seq_len) + return min(max_model_len, max(logits_width, topk_tokens)) + + +def _sparse_indexer_requires_deep_gemm() -> bool: + return current_platform.is_cuda() and not ( + current_platform.is_device_capability_family(120) + ) + + def _gather_workspace_shapes( total_seq_lens: int, head_dim: int, @@ -116,7 +165,7 @@ def sparse_attn_indexer( # Dummy allocation to simulate for peak logits tensor memory during inference. # FP8 elements so elements == bytes - max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_elems = sparse_indexer_max_logits_bytes() _ = torch.empty( max_logits_elems, dtype=torch.uint8, device=hidden_states.device ) @@ -218,6 +267,19 @@ def sparse_attn_indexer( q_slice_cast = q_slice k_quant_cast = k_quant k_scale_cast = k_scale.view(torch.float32).squeeze(-1) + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + if not current_platform.is_xpu() and fp8_fp4_mqa_topk_indices( + (q_slice_cast, q_scale_slice), + (k_quant_cast, k_scale_cast), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + ): + continue + if current_platform.is_xpu(): if q_scale_slice is not None: raise RuntimeError("XPU fp8_mqa_logits does not support FP4 Q") @@ -240,10 +302,6 @@ def sparse_attn_indexer( ) num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[ - chunk.token_start : chunk.token_end, :topk_tokens - ] - ops.top_k_per_row_prefill( logits, chunk.cu_seqlen_ks, @@ -305,60 +363,93 @@ def sparse_attn_indexer( if use_fp4_cache else padded_q_quant_decode_tokens ) - if current_platform.is_xpu(): - if padded_q_scale is not None: - raise RuntimeError("XPU fp8_paged_mqa_logits does not support FP4 Q") - seq_lens_xpu = ( - seq_lens[:, -1].contiguous() if seq_lens.ndim == 2 else seq_lens - ) - logits = torch.ops.vllm.xpu_fp8_paged_mqa_logits( - padded_q_quant_cast, - kv_cache, - weights[:num_padded_tokens], - seq_lens_xpu, - decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len, - ) - else: - logits = fp8_fp4_paged_mqa_logits( + topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] + logits_width = _decode_topk_logits_width( + max_model_len, attn_metadata_narrowed.max_seq_len, topk_tokens + ) + logits_bytes = num_padded_tokens * logits_width * torch.float32.itemsize + used_direct_topk = False + if ( + not current_platform.is_xpu() + and logits_bytes > sparse_indexer_max_logits_bytes() + ): + used_direct_topk = fp8_fp4_paged_mqa_topk_indices( (padded_q_quant_cast, padded_q_scale), kv_cache, weights[:num_padded_tokens], seq_lens, decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len=max_model_len, - clean_logits=False, - ) - num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - - if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048): - workspace_manager = current_workspace_manager() - (topk_workspace,) = workspace_manager.get_simultaneous( - ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), - ) - torch.ops._C.persistent_topk( - logits, - seq_lens, - topk_indices, - topk_workspace, - topk_tokens, - attn_metadata_narrowed.max_seq_len, - ) - else: - ops.top_k_per_row_decode( - logits, - next_n, - seq_lens, + logits_width, topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, ) + if not used_direct_topk: + if current_platform.is_xpu(): + if padded_q_scale is not None: + raise RuntimeError( + "XPU fp8_paged_mqa_logits does not support FP4 Q" + ) + seq_lens_xpu = ( + seq_lens[:, -1].contiguous() if seq_lens.ndim == 2 else seq_lens + ) + logits = torch.ops.vllm.xpu_fp8_paged_mqa_logits( + padded_q_quant_cast, + kv_cache, + weights[:num_padded_tokens], + seq_lens_xpu, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len, + ) + else: + logits = fp8_fp4_paged_mqa_logits( + (padded_q_quant_cast, padded_q_scale), + kv_cache, + weights[:num_padded_tokens], + seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=logits_width, + clean_logits=False, + ) + num_rows = logits.shape[0] + + if _use_sm120_short_row_topk_decode(logits, topk_tokens): + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + elif current_platform.is_cuda() and topk_tokens in (512, 1024, 2048): + workspace_manager = current_workspace_manager() + (topk_workspace,) = workspace_manager.get_simultaneous( + ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), + ) + torch.ops._C.persistent_topk( + logits, + seq_lens, + topk_indices, + topk_workspace, + topk_tokens, + logits_width, + ) + else: + ops.top_k_per_row_decode( + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens @@ -440,7 +531,7 @@ def __init__( self.topk_indices_buffer = topk_indices_buffer self.skip_k_cache_insert = skip_k_cache_insert self.use_fp4_cache = use_fp4_cache - if current_platform.is_cuda() and not has_deep_gemm(): + if _sparse_indexer_requires_deep_gemm() and not has_deep_gemm(): raise RuntimeError( "Sparse Attention Indexer CUDA op requires DeepGEMM support in " "the current vLLM environment." diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 1ddc93ff5e74..6b6fac63e261 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -359,6 +359,48 @@ def transform_sf_into_required_layout(*args, **kwargs): ) +def fp8_fp4_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 MQA top-k indices without materializing full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks.fp8_fp4_mqa_topk_indices( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices, + ) + + +def _fp8_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks._fp8_mqa_logits_sm12x( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) + + def fp8_fp4_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv: tuple[torch.Tensor, torch.Tensor], @@ -391,6 +433,10 @@ def fp8_fp4_mqa_logits( Returns: Logits tensor of shape [M, N], dtype `torch.float32`. """ + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_mqa_logits_sm12x( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) _lazy_init() if _fp8_fp4_mqa_logits_impl is None: return _missing() @@ -425,6 +471,50 @@ def get_paged_mqa_logits_metadata( return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) +def _fp8_paged_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks._fp8_paged_mqa_logits_sm12x( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) + + +def fp8_fp4_paged_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 paged MQA top-k indices without full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks.fp8_fp4_paged_mqa_topk_indices( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + topk_indices, + ) + + def fp8_fp4_paged_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv_cache: torch.Tensor, @@ -446,9 +536,10 @@ def fp8_fp4_paged_mqa_logits( [B, next_n, H, D] float8_e4m3fn and q_scale is None. FP4 path: q_values is packed uint8 and q_scale is the companion block-scale tensor. - kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1, - D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos) - storing the float dequant scale. + kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, D+4] + or [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. Within + each block, the D-byte FP8 values for every token are stored first, + followed by per-token fp32 scale bytes. weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. context_lens: Tensor of shape [B], dtype int32; effective context length for each batch element. @@ -463,6 +554,10 @@ def fp8_fp4_paged_mqa_logits( Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. """ + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_paged_mqa_logits_sm12x( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) _lazy_init() if _fp8_fp4_paged_mqa_logits_impl is None: return _missing() @@ -478,6 +573,20 @@ def fp8_fp4_paged_mqa_logits( ) +def _tf32_hc_prenorm_gemm_sm12x( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks._tf32_hc_prenorm_gemm_sm12x( + x, fn, out, sqrsum, num_split + ) + + def tf32_hc_prenorm_gemm( x: torch.Tensor, fn: torch.Tensor, @@ -492,6 +601,8 @@ def tf32_hc_prenorm_gemm( See the caller function for shape requirement """ + if current_platform.is_device_capability_family(120): + return _tf32_hc_prenorm_gemm_sm12x(x, fn, out, sqrsum, num_split) _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: return _missing() @@ -591,7 +702,9 @@ def should_use_deepgemm_for_fp8_linear( "m_grouped_fp8_fp4_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "fp8_fp4_mqa_logits", + "fp8_fp4_mqa_topk_indices", "fp8_fp4_paged_mqa_logits", + "fp8_fp4_paged_mqa_topk_indices", "get_paged_mqa_logits_metadata", "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", From cbb6995ee997bd8bd61bf9e53c2df00c24e245bf Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 17:49:25 +0800 Subject: [PATCH 006/131] Wire SM12x sparse MLA into DeepSeek V4 Signed-off-by: jasl --- vllm/config/compilation.py | 1 + .../attention/backends/mla/flashmla_sparse.py | 17 +++ vllm/v1/attention/backends/mla/indexer.py | 28 +++- vllm/v1/attention/backends/mla/sparse_swa.py | 144 ++++++------------ 4 files changed, 88 insertions(+), 102 deletions(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index bc38ec6a8a8a..c641fdddf405 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -757,6 +757,7 @@ class CompilationConfig: "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", "vllm::deepseek_v4_attention", + "vllm::deepseek_v4_fp8_einsum", ] def compute_hash(self) -> str: diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 6d8dfe13128d..2d6231051686 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -27,6 +27,10 @@ MultipleOf, SparseMLAAttentionImpl, ) +from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled_for_platform, +) from vllm.v1.attention.backends.mla.sparse_utils import ( triton_convert_req_index_to_global_index, ) @@ -231,6 +235,19 @@ def get_prefill_workspace_size(max_model_len: int): class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_triton_sparse_mla_enabled_for_platform() + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + def __init__( self, kv_cache_spec: AttentionSpec, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 0bc7ca7aa414..b2719cbbe7dc 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import dataclass import torch -import vllm.envs as envs +from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -32,6 +33,26 @@ logger = init_logger(__name__) +def sparse_indexer_max_logits_bytes(is_sm12x: bool | None = None) -> int: + if is_sm12x is None: + is_sm12x = ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + ) + if "VLLM_SPARSE_INDEXER_MAX_LOGITS_MB" in os.environ: + return envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + default_mb = 256 if is_sm12x else 512 + return default_mb * 1024 * 1024 + + +def _uses_deep_gemm_scheduler_metadata() -> bool: + return ( + current_platform.is_cuda() + and has_deep_gemm() + and not current_platform.is_device_capability_family(120) + ) + + @triton.jit def _prepare_uniform_decode_kernel( seq_lens_ptr, @@ -522,7 +543,7 @@ def build( prefill_query_lens_cpu = torch.diff( query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1] ) - max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_bytes = sparse_indexer_max_logits_bytes() # Upper bound is exact for prefill rows (the `[num_decodes:]` # slice below). assert common_attn_metadata.seq_lens_cpu_upper_bound is not None @@ -612,8 +633,7 @@ def build( if seq_lens.dim() == 1: seq_lens = seq_lens.unsqueeze(-1) - # DeepGEMM is required for the paged MQA logits on CUDA devices - if current_platform.is_cuda() and has_deep_gemm(): + if _uses_deep_gemm_scheduler_metadata(): self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( seq_lens, self.kv_cache_spec.storage_block_size, diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index a3fd39bed79a..fb6e0f9dbcc9 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -9,7 +9,6 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -17,9 +16,14 @@ CommonAttentionMetadata, MultipleOf, ) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled, + is_triton_sparse_mla_enabled_for_platform, +) from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata from vllm.v1.kv_cache_interface import ( + AttentionSpec, KVCacheSpec, MLAAttentionSpec, SlidingWindowMLASpec, @@ -173,12 +177,9 @@ class DeepseekSparseSWAMetadata: # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. prefill_seq_lens: torch.Tensor | None = None - prefill_seq_lens_cpu: torch.Tensor | None = None prefill_gather_lens: torch.Tensor | None = None - prefill_query_lens_cpu: torch.Tensor | None = None - prefill_window_size: int = 0 - prefill_max_model_len: int = 0 - prefill_max_num_batched_tokens: int = 0 + prefill_seq_lens_cpu: torch.Tensor | None = None + prefill_gather_lens_cpu: torch.Tensor | None = None # Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta # per present DeepseekV4 layer type, shared across all ~60 layers of that type @@ -197,79 +198,6 @@ class DeepseekSparseSWAMetadata: default_factory=dict ) - def get_prefill_chunk_plan( - self, compress_ratio: int, prefill_chunk_size: int - ) -> list[tuple[int, int, int, int]]: - if self.num_prefills == 0: - return [] - - assert self.prefill_seq_lens_cpu is not None - assert self.prefill_query_lens_cpu is not None - - # query_len <= max_num_batched_tokens and - # gather_len = query_len + min(prefix_len, window_size - 1), so the - # worst-case gathered width is bounded by - # max_num_batched_tokens + window_size - 1. The compressed prefix pool - # is bounded by ceil(max_model_len / compress_ratio). - max_workspace_area = prefill_chunk_size * ( - ( - 0 - if compress_ratio <= 1 - else cdiv(self.prefill_max_model_len, compress_ratio) - ) - + self.prefill_window_size - + self.prefill_max_num_batched_tokens - ) - prefix_lens_cpu = self.prefill_seq_lens_cpu - self.prefill_query_lens_cpu - gather_lens_cpu = self.prefill_query_lens_cpu + torch.clamp( - prefix_lens_cpu, min=0, max=self.prefill_window_size - 1 - ) - compressed_lens_cpu = ( - torch.zeros_like(self.prefill_seq_lens_cpu) - if compress_ratio <= 1 - else torch.div( - self.prefill_seq_lens_cpu, - compress_ratio, - rounding_mode="floor", - ) - ) - - chunk_plan: list[tuple[int, int, int, int]] = [] - chunk_start = 0 - while chunk_start < self.num_prefills: - chunk_max_compressed = int(compressed_lens_cpu[chunk_start].item()) - chunk_max_gather = int(gather_lens_cpu[chunk_start].item()) - chunk_end = chunk_start + 1 - - while chunk_end < self.num_prefills: - candidate_max_compressed = max( - chunk_max_compressed, - int(compressed_lens_cpu[chunk_end].item()), - ) - candidate_max_gather = max( - chunk_max_gather, - int(gather_lens_cpu[chunk_end].item()), - ) - candidate_width = candidate_max_compressed + candidate_max_gather - candidate_area = (chunk_end - chunk_start + 1) * candidate_width - if candidate_area > max_workspace_area: - break - chunk_max_compressed = candidate_max_compressed - chunk_max_gather = candidate_max_gather - chunk_end += 1 - - chunk_plan.append( - ( - chunk_start, - chunk_end, - chunk_max_compressed, - chunk_max_compressed + chunk_max_gather, - ) - ) - chunk_start = chunk_end - - return chunk_plan - class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder): """Builds metadata for DeepseekV4 SWA cache. @@ -288,6 +216,19 @@ class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder): reorder_batch_threshold: int = 1 _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_triton_sparse_mla_enabled_for_platform() + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec) @@ -295,10 +236,6 @@ def __init__(self, *args, **kwargs): self.head_size = mla_spec.head_size # Already considered quantization. self.compress_ratio = mla_spec.compress_ratio self.block_size = mla_spec.block_size - self.max_model_len = self.vllm_config.model_config.max_model_len - self.max_num_batched_tokens = ( - self.vllm_config.scheduler_config.max_num_batched_tokens - ) # Handle MTP: adjust decode_threshold like the indexer does self.num_speculative_tokens = ( @@ -365,7 +302,6 @@ def build( """ num_reqs = common_attn_metadata.num_reqs seq_lens = common_attn_metadata.seq_lens - seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu block_table = common_attn_metadata.block_table_tensor @@ -410,9 +346,9 @@ def build( num_decodes, num_prefills, seq_lens, - seq_lens_cpu, query_start_loc, query_start_loc_cpu, + common_attn_metadata.seq_lens_cpu_upper_bound, ) # Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta @@ -439,7 +375,7 @@ def build( tile_sched_swaonly=tile_sched[_LAYER_TYPE_SWAONLY], tile_sched_c4a=tile_sched[_LAYER_TYPE_C4A], tile_sched_c128a=tile_sched[_LAYER_TYPE_C128A], - **deepseek_v4_fields, # type: ignore[arg-type] + **deepseek_v4_fields, ) def build_tile_scheduler( @@ -467,6 +403,8 @@ def build_tile_scheduler( or current_platform.is_xpu() ): return out + if is_triton_sparse_mla_enabled(self.device): + return out for layer_type in self._layer_types: # get_mla_metadata() is the official FlashMLA entry point that # returns a fresh empty FlashMLASchedMeta; using it keeps this @@ -480,10 +418,10 @@ def _build_deepseek_v4_metadata( num_decodes: int, num_prefills: int, seq_lens: torch.Tensor, - seq_lens_cpu: torch.Tensor | None, query_start_loc: torch.Tensor, query_start_loc_cpu: torch.Tensor, - ) -> dict[str, torch.Tensor | int | None]: + seq_lens_cpu_upper_bound: torch.Tensor | None, + ) -> dict[str, torch.Tensor | None]: """Pre-compute DeepseekV4 prefill metadata during the metadata build phase. Returns a dict of keyword arguments to pass to the @@ -492,11 +430,10 @@ def _build_deepseek_v4_metadata( Note: C128A topk indices are computed by the FlashMLASparse builder (which owns the C128A block_table), not here. """ - result: dict[str, torch.Tensor | int | None] = {} + result: dict[str, torch.Tensor | None] = {} # --- Prefill query metadata (single Triton kernel + CPU slicing) --- if num_prefills > 0: - assert seq_lens_cpu is not None pfx_gather_lens = torch.empty( num_prefills, dtype=torch.int32, device=seq_lens.device ) @@ -510,16 +447,27 @@ def _build_deepseek_v4_metadata( BLOCK_SIZE=triton.next_power_of_2(num_prefills), ) + assert seq_lens_cpu_upper_bound is not None + seq_lens_cpu = seq_lens_cpu_upper_bound + prefill_seq_lens_cpu = seq_lens_cpu[ + num_decodes : num_decodes + num_prefills + ] + query_lens_cpu = ( + query_start_loc_cpu[ + num_decodes + 1 : num_decodes + num_prefills + 1 + ] + - query_start_loc_cpu[num_decodes : num_decodes + num_prefills] + ) + prefix_lens_cpu = prefill_seq_lens_cpu - query_lens_cpu + prefill_gather_lens_cpu = query_lens_cpu + torch.minimum( + prefix_lens_cpu, + torch.full_like(prefix_lens_cpu, self.window_size - 1), + ) + result["prefill_seq_lens"] = seq_lens[num_decodes:] - result["prefill_seq_lens_cpu"] = seq_lens_cpu[num_decodes:] result["prefill_gather_lens"] = pfx_gather_lens - result["prefill_query_lens_cpu"] = ( - query_start_loc_cpu[num_decodes + 1 : num_decodes + num_prefills + 1] - - query_start_loc_cpu[num_decodes : num_decodes + num_prefills] - ).to(dtype=torch.int32) - result["prefill_window_size"] = self.window_size - result["prefill_max_model_len"] = self.max_model_len - result["prefill_max_num_batched_tokens"] = self.max_num_batched_tokens + result["prefill_seq_lens_cpu"] = prefill_seq_lens_cpu + result["prefill_gather_lens_cpu"] = prefill_gather_lens_cpu return result From b32ed2056f11859da2072338cea158ba426a9b2a Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 19:51:42 +0800 Subject: [PATCH 007/131] Reduce DeepSeek V4 load overhead on GB10 Fix the SM12x fp8 einsum custom-op registration import, skip unused DeepSeek V4 MTP checkpoint tensors before safetensors materialization, and release MXFP4 setup temporaries after kernel setup. Signed-off-by: jasl --- .../model_loader/test_ep_weight_filter.py | 14 ++++++ .../layers/quantization/mxfp4.py | 6 +++ .../model_loader/default_loader.py | 14 ++++-- .../model_loader/weight_utils.py | 43 +++++++++++++++++-- vllm/models/deepseek_v4/nvidia/model.py | 6 +++ .../deepseek_v4/nvidia/ops/fp8_einsum.py | 2 +- 6 files changed, 76 insertions(+), 9 deletions(-) diff --git a/tests/model_executor/model_loader/test_ep_weight_filter.py b/tests/model_executor/model_loader/test_ep_weight_filter.py index 2ac38192a4b0..d032cd569019 100644 --- a/tests/model_executor/model_loader/test_ep_weight_filter.py +++ b/tests/model_executor/model_loader/test_ep_weight_filter.py @@ -319,6 +319,20 @@ def test_ep2_rank0_gets_half_experts(self, synthetic_moe_files): assert "model.layers.0.input_layernorm.weight" in loaded assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded + def test_weight_name_filter_skips_dense_weights(self, synthetic_moe_files): + files, _ = synthetic_moe_files + loaded = dict( + safetensors_weights_iterator( + files, + False, + weight_name_filter=lambda name: "self_attn.q_proj" in name, + ) + ) + + assert "model.layers.0.self_attn.q_proj.weight" not in loaded + assert "model.embed_tokens.weight" in loaded + assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded + def test_ep2_rank1_gets_other_half(self, synthetic_moe_files): files, expected = synthetic_moe_files local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=1) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 1b2a8a74bdcb..92c4b23f54bc 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -389,6 +389,9 @@ def process_weights_after_loading(self, layer: RoutedExperts) -> None: return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) + del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias + if torch.cuda.is_available(): + torch.cuda.empty_cache() def get_fused_moe_quant_config( self, layer: RoutedExperts @@ -733,6 +736,9 @@ def process_weights_after_loading(self, layer): return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) + del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias + if torch.cuda.is_available(): + torch.cuda.empty_cache() def get_fused_moe_quant_config( self, diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 3ea76f4d9b3a..bf8bf657a3c3 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -4,7 +4,7 @@ import glob import os import time -from collections.abc import Generator, Iterable +from collections.abc import Callable, Generator, Iterable from typing import cast import torch @@ -242,7 +242,9 @@ def _prepare_weights( return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, source: "Source" + self, + source: "Source", + weight_name_filter: Callable[[str], bool] | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config @@ -289,6 +291,7 @@ def _get_weights_iterator( self.load_config.use_tqdm_on_load, self.load_config.safetensors_load_strategy, local_expert_ids=self.local_expert_ids, + weight_name_filter=weight_name_filter, safetensors_prefetch_num_threads=( self.load_config.safetensors_prefetch_num_threads ), @@ -323,6 +326,9 @@ def get_all_weights( model_config: ModelConfig, model: nn.Module, ) -> Generator[tuple[str, torch.Tensor], None, None]: + weight_name_filter = getattr(model, "skip_weight_name_before_load", None) + if not callable(weight_name_filter): + weight_name_filter = None primary_weights = DefaultModelLoader.Source( model_config.model, model_config.revision, @@ -330,14 +336,14 @@ def get_all_weights( fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None), ) - yield from self._get_weights_iterator(primary_weights) + yield from self._get_weights_iterator(primary_weights, weight_name_filter) secondary_weights = cast( Iterable[DefaultModelLoader.Source], getattr(model, "secondary_weights", ()), ) for source in secondary_weights: - yield from self._get_weights_iterator(source) + yield from self._get_weights_iterator(source, weight_name_filter) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights( diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 47c6c02be6ab..7037778b7f97 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -817,11 +817,22 @@ def _run_prefetch() -> None: threading.Thread(target=_run_prefetch, daemon=True).start() +def _should_skip_safetensors_weight( + weight_name: str, + local_expert_ids: set[int] | None, + weight_name_filter: Callable[[str], bool] | None, +) -> bool: + if should_skip_weight(weight_name, local_expert_ids): + return True + return weight_name_filter is not None and weight_name_filter(weight_name) + + def safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, safetensors_load_strategy: str | None = None, local_expert_ids: set[int] | None = None, + weight_name_filter: Callable[[str], bool] | None = None, *, safetensors_prefetch_num_threads: int = DEFAULT_SAFETENSORS_PREFETCH_NUM_THREADS, safetensors_prefetch_block_size: int = DEFAULT_SAFETENSORS_PREFETCH_BLOCK_SIZE, @@ -831,6 +842,9 @@ def safetensors_weights_iterator( When *local_expert_ids* is provided, expert weights not belonging to this rank are skipped **before** reading from disk, which drastically reduces storage I/O for MoE models under EP. + + When *weight_name_filter* is provided, names for which the callback returns + ``True`` are also skipped before tensor materialization. """ loading_desc = "Loading safetensors checkpoint shards" if safetensors_load_strategy == "eager": @@ -913,7 +927,9 @@ def safetensors_weights_iterator( with open(st_file, "rb") as f: state_dict = load(f.read()) for name, param in state_dict.items(): - if not should_skip_weight(name, local_expert_ids): + if not _should_skip_safetensors_weight( + name, local_expert_ids, weight_name_filter + ): yield name, param elif safetensors_load_strategy == "torchao": # we can't load flattened torchao tensor subclasses directly into the model @@ -930,7 +946,9 @@ def safetensors_weights_iterator( with safe_open(st_file, framework="pt") as f: state_dict = {} for name in f.keys(): # noqa: SIM118 - if should_skip_weight(name, local_expert_ids): + if _should_skip_safetensors_weight( + name, local_expert_ids, weight_name_filter + ): continue state_dict[name] = f.get_tensor(name) @@ -948,7 +966,9 @@ def safetensors_weights_iterator( else: with safe_open(st_file, framework="pt") as f: for name in f.keys(): # noqa: SIM118 - if should_skip_weight(name, local_expert_ids): + if _should_skip_safetensors_weight( + name, local_expert_ids, weight_name_filter + ): continue param = f.get_tensor(name) yield name, param @@ -1024,6 +1044,8 @@ def runai_safetensors_weights_iterator( def fastsafetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, + local_expert_ids: set[int] | None = None, + weight_name_filter: Callable[[str], bool] | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files using fastsafetensor library. @@ -1072,6 +1094,12 @@ def _make_loader(nogds: bool) -> "ParallelLoader": pl = _make_loader(nogds) for name, tensor in pl.iterate_weights(): yielded = True + if _should_skip_safetensors_weight( + name, + local_expert_ids, + weight_name_filter, + ): + continue yield name, tensor except RuntimeError as e: if nogds or yielded or "gds" not in str(e): @@ -1084,7 +1112,14 @@ def _make_loader(nogds: bool) -> "ParallelLoader": if pl is not None: pl.close() pl = _make_loader(nogds=True) - yield from pl.iterate_weights() + for name, tensor in pl.iterate_weights(): + if _should_skip_safetensors_weight( + name, + local_expert_ids, + weight_name_filter, + ): + continue + yield name, tensor finally: if pl is not None: pl.close() diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 868fc3f5fdb2..ea09509bd4c2 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -1374,6 +1374,12 @@ def get_mtp_target_hidden_states(self) -> torch.Tensor | None: forward(); valid after each target step.""" return getattr(self.model, "_mtp_hidden_buffer", None) + def skip_weight_name_before_load(self, name: str) -> bool: + mapped_names = self.hf_to_vllm_mapper.apply_list([name]) + if not mapped_names: + return True + return all("mtp." in mapped_name for mapped_name in mapped_names) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_substrs=["mtp."]) loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py index 68cf39061ed8..a9f52e767b20 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py +++ b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py @@ -5,10 +5,10 @@ import torch from vllm.distributed import get_tensor_model_parallel_rank -from vllm.model_executor.custom_op import direct_register_custom_op from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import fp8_einsum +from vllm.utils.torch_utils import direct_register_custom_op def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: From e46bd0df02e3b411baf00db986b9f707e35490a0 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 7 May 2026 00:26:22 +0800 Subject: [PATCH 008/131] Apply weight filter to fast safetensors loading Forward model skip_weight_name_before_load filters into the fastsafetensors iterator and skip filtered keys before materializing tensors. This keeps DeepSeek V4 non-MTP loads from reading MTP-only weights when users select --load-format fastsafetensors. Keep the regression coverage at behavior level by checking the DefaultModelLoader path and pruning private implementation-field assertions from the adjacent DeepSeek V4 prefix-cache tests. Co-authored-by: OpenAI Codex Signed-off-by: jasl --- .../test_weight_utils.py | 74 +++++++++++++++++++ tests/v1/core/test_prefix_caching.py | 12 +-- .../model_loader/default_loader.py | 2 + 3 files changed, 79 insertions(+), 9 deletions(-) diff --git a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py index da974131f65a..cb4c4d492fd2 100644 --- a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py +++ b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py @@ -8,6 +8,9 @@ import pytest import torch +import vllm.model_executor.model_loader.weight_utils as weight_utils +from vllm.config.load import LoadConfig +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_weights_from_hf, fastsafetensors_weights_iterator, @@ -16,6 +19,77 @@ from vllm.platforms import current_platform +def test_default_loader_filters_fastsafetensors_before_materializing(monkeypatch): + class FakeProcessGroup: + def size(self): + return 1 + + class FakeFileBuffer: + def __init__(self): + self.key_to_rank_lidx = { + "model.layers.0.self_attn.q_proj.weight": (0, 0), + "model.layers.0.mlp.experts.0.gate_proj.weight": (0, 1), + "model.layers.0.mlp.experts.1.gate_proj.weight": (0, 2), + "model.mtp.0.weight": (0, 3), + } + self.loaded_keys: list[str] = [] + self.closed = False + + def get_tensor(self, key: str): + self.loaded_keys.append(key) + return torch.tensor([len(self.loaded_keys)]) + + def close(self): + self.closed = True + + class FakeLoader: + def __init__(self, file_buffer): + self.file_buffer = file_buffer + self.closed = False + + def copy_files_to_device(self): + return self.file_buffer + + def close(self): + self.closed = True + + file_buffer = FakeFileBuffer() + loader = FakeLoader(file_buffer) + + model_loader = DefaultModelLoader(LoadConfig(load_format="fastsafetensors")) + model_loader.local_expert_ids = {0} + monkeypatch.setattr( + model_loader, + "_prepare_weights", + lambda *_args: ("/weights", ["model.safetensors"], True), + ) + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False) + monkeypatch.setattr(weight_utils, "SingleGroup", FakeProcessGroup) + monkeypatch.setattr( + weight_utils, + "_init_fastsafetensors_loader", + lambda *_args, **_kwargs: loader, + ) + + loaded = dict( + model_loader._get_weights_iterator( + DefaultModelLoader.Source("model", revision=None), + weight_name_filter=lambda name: "model.mtp." in name, + ) + ) + + assert set(loaded) == { + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.mlp.experts.0.gate_proj.weight", + } + assert file_buffer.loaded_keys == [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.mlp.experts.0.gate_proj.weight", + ] + assert file_buffer.closed + assert loader.closed + + @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="fastsafetensors requires NVIDIA/AMD GPUs", diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 8cc28f6cd09e..d2f231d034b0 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -3471,7 +3471,7 @@ def test_can_fit_full_sequence_swa_cap_admits_long_prompt(): ) -def test_deepseek_v4_mla_keeps_prompt_blocks_after_decode_pressure(): +def test_deepseek_v4_mla_prompt_cache_survives_decode_pressure(): hash_block_size = 2 full_block_size = 8 swa_block_size = 2 @@ -3587,10 +3587,9 @@ def run_request(request: Request, num_decode_tokens: int) -> int: assert num_computed_tokens == expected_hit_tokens -def test_deepseek_v4_mla_protected_prompt_blocks_do_not_block_admission(): +def test_deepseek_v4_mla_cached_prompts_do_not_block_admission(): block_size = 8 prompt_tokens = 4 * block_size + 3 - protected_blocks_per_prompt = (prompt_tokens - 1) // block_size num_prompts = 10 num_blocks = 80 manager = KVCacheManager( @@ -3616,7 +3615,6 @@ def test_deepseek_v4_mla_protected_prompt_blocks_do_not_block_admission(): enable_caching=True, hash_block_size=block_size, ) - mla_manager = manager.coordinator.single_type_managers[0] for i in range(num_prompts): prompt = list(range(i * 1000, i * 1000 + prompt_tokens)) @@ -3625,9 +3623,6 @@ def test_deepseek_v4_mla_protected_prompt_blocks_do_not_block_admission(): req.num_computed_tokens = prompt_tokens manager.free(req) - assert len(mla_manager._protected_prompt_block_ids) == ( - num_prompts * protected_blocks_per_prompt - ) assert manager.block_pool.get_num_free_blocks() < 64 long_req = make_request( @@ -3642,7 +3637,7 @@ def test_deepseek_v4_mla_protected_prompt_blocks_do_not_block_admission(): ) -def test_reset_prefix_cache_releases_deepseek_v4_mla_protected_blocks(): +def test_reset_prefix_cache_after_deepseek_v4_mla_prompt_cache(): block_size = 8 prompt_tokens = 4 * block_size + 3 manager = KVCacheManager( @@ -3674,7 +3669,6 @@ def test_reset_prefix_cache_releases_deepseek_v4_mla_protected_blocks(): req.num_computed_tokens = prompt_tokens manager.free(req) - assert manager.coordinator.single_type_managers[0]._protected_prompt_block_ids assert manager.reset_prefix_cache() diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index bf8bf657a3c3..25682c1085dc 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -270,6 +270,8 @@ def _get_weights_iterator( weights_iterator = fastsafetensors_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + local_expert_ids=self.local_expert_ids, + weight_name_filter=weight_name_filter, ) elif self.load_config.load_format == "instanttensor": weights_iterator = instanttensor_weights_iterator( From f809e1431fbec8c741908e3bce177b741b0d23da Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 6 May 2026 01:45:55 +0800 Subject: [PATCH 009/131] Warm DeepSeek V4 startup kernels Import the production-preview warmups for DeepSeek V4 request preparation, sparse MLA attention, and mHC TileLang kernels while leaving the old warmup test fixture out of the preview branch. Cherry-picked-from: 0dca30b461d2cce278caacf80ad8b1919a1bb341 Cherry-picked-from: 5959aadef1c63831713c0333567680a55994469d Cherry-picked-from: 7cf6f1d96f2dd1c324cd74a5f724c9795041a086 Signed-off-by: jasl --- vllm/envs.py | 22 ++ .../warmup/deepseek_v4_mhc_warmup.py | 251 ++++++++++++++++++ vllm/model_executor/warmup/kernel_warmup.py | 208 +++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 12 +- 4 files changed, 492 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/warmup/deepseek_v4_mhc_warmup.py diff --git a/vllm/envs.py b/vllm/envs.py index a94e084ab628..79dfecd947c0 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -179,6 +179,9 @@ VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True + VLLM_ENABLE_DEEPSEEK_V4_MHC_WARMUP: bool = True + VLLM_DEEPSEEK_V4_MHC_WARMUP_TOKEN_SIZES: list[int] | None = None + VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", "full", @@ -302,6 +305,13 @@ def maybe_convert_int(value: str | None) -> int | None: return int(value) +def maybe_convert_int_list(value: str | None) -> list[int] | None: + if value is None: + return None + values = [int(item.strip()) for item in value.split(",") if item.strip()] + return values or None + + def maybe_convert_bool(value: str | None) -> bool | None: if value is None: return None @@ -1444,6 +1454,18 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) ), + # DeepSeek V4 mHC / hc_head TileLang kernels JIT on first use. Enable + # startup warmup by default to avoid first-request latency spikes; set to + # 0 to keep the old lazy-JIT behavior. + "VLLM_ENABLE_DEEPSEEK_V4_MHC_WARMUP": lambda: bool( + int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_MHC_WARMUP", "1")) + ), + "VLLM_DEEPSEEK_V4_MHC_WARMUP_TOKEN_SIZES": lambda: maybe_convert_int_list( + os.getenv("VLLM_DEEPSEEK_V4_MHC_WARMUP_TOKEN_SIZES") + ), + "VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP": lambda: bool( + int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP", "1")) + ), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine diff --git a/vllm/model_executor/warmup/deepseek_v4_mhc_warmup.py b/vllm/model_executor/warmup/deepseek_v4_mhc_warmup.py new file mode 100644 index 000000000000..9189db0c4739 --- /dev/null +++ b/vllm/model_executor/warmup/deepseek_v4_mhc_warmup.py @@ -0,0 +1,251 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Warm up DeepSeek V4 mHC TileLang kernels before serving requests.""" + +import time +from collections.abc import Iterable + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.tracing import instrument +from vllm.utils.math_utils import cdiv + +logger = init_logger(__name__) + +_AUTO_WARMUP_MAX_TOKENS = 16_384 +_DEFAULT_TOKEN_SIZE_CANDIDATES = ( + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16_384, +) + + +def _compute_mhc_pre_num_split( + *, + num_tokens: int, + hidden_size: int, + hc_mult: int, + num_sms: int, +) -> int: + block_k = 64 + block_m = 64 + k = hc_mult * hidden_size + grid_size = cdiv(num_tokens, block_m) + split_k = num_sms // grid_size + num_block_k = cdiv(k, block_k) + split_k = min(split_k, num_block_k // 4) + return max(split_k, 1) + + +def _normalize_token_sizes( + token_sizes: Iterable[int], + *, + max_tokens: int, +) -> list[int]: + return sorted({size for size in token_sizes if 1 <= size <= max_tokens}) + + +def _select_mhc_warmup_token_sizes( + *, + max_tokens: int, + hidden_size: int, + hc_mult: int, + num_sms: int, + requested_token_sizes: list[int] | None, + cudagraph_capture_sizes: list[int], +) -> list[int]: + if max_tokens <= 0: + return [] + + if requested_token_sizes is None: + max_auto_tokens = min(max_tokens, _AUTO_WARMUP_MAX_TOKENS) + candidates = list(_DEFAULT_TOKEN_SIZE_CANDIDATES) + candidates.extend(cudagraph_capture_sizes) + candidates.append(max_auto_tokens) + candidates = _normalize_token_sizes(candidates, max_tokens=max_auto_tokens) + else: + candidates = _normalize_token_sizes( + requested_token_sizes, + max_tokens=max_tokens, + ) + + return candidates + + +def _find_first_mhc_layer(model: torch.nn.Module) -> torch.nn.Module | None: + for module in model.modules(): + if module.__class__.__name__ != "DeepseekV4DecoderLayer": + continue + if all( + hasattr(module, attr) + for attr in ( + "hc_pre", + "hc_post", + "hc_attn_fn", + "hc_attn_scale", + "hc_attn_base", + "hc_ffn_fn", + "hc_ffn_scale", + "hc_ffn_base", + ) + ): + return module + return None + + +def _find_deepseek_v4_model(model: torch.nn.Module) -> torch.nn.Module | None: + for module in model.modules(): + if module.__class__.__name__ != "DeepseekV4Model": + continue + if all( + hasattr(module, attr) + for attr in ("hc_head_fn", "hc_head_scale", "hc_head_base") + ): + return module + return None + + +def _get_cuda_num_sms(device: torch.device) -> int: + index = device.index + if index is None: + index = torch.accelerator.current_device_index() + return torch.cuda.get_device_properties(index).multi_processor_count + + +def _warmup_layer_mhc( + layer: torch.nn.Module, + token_sizes: list[int], +) -> None: + max_tokens = max(token_sizes) + hidden_size = int(layer.hidden_size) + hc_mult = int(layer.hc_mult) + device = layer.hc_attn_fn.device + residual = torch.zeros( + max_tokens, + hc_mult, + hidden_size, + dtype=torch.bfloat16, + device=device, + ) + + for size in token_sizes: + residual_slice = residual[:size] + for fn, scale, base in ( + (layer.hc_attn_fn, layer.hc_attn_scale, layer.hc_attn_base), + (layer.hc_ffn_fn, layer.hc_ffn_scale, layer.hc_ffn_base), + ): + layer_input, post_mix, comb_mix = layer.hc_pre( + residual_slice, + fn, + scale, + base, + ) + layer.hc_post(layer_input, residual_slice, post_mix, comb_mix) + + +def _warmup_hc_head( + model: torch.nn.Module, + token_sizes: list[int], +) -> None: + if not hasattr(model, "_mtp_hidden_buffer"): + return + + # Upstream (a8887c208 "[DSV4] aiter mhc support (ROCm)") refactored + # ``hc_head`` from a free function exported by ``deepseek_v4`` into the + # ``HCHeadOp`` :class:`CustomOp` instance attached to the model as + # ``hc_head_op``. We call through that instance so the warmup exercises + # the same dispatched implementation as the inference path. + hc_head_op = getattr(model, "hc_head_op", None) + if hc_head_op is None: + return + + max_tokens = max(token_sizes) + hidden_size = int(model.config.hidden_size) + hc_mult = int(model.hc_mult) + device = model.hc_head_fn.device + hidden_states = torch.zeros( + max_tokens, + hc_mult, + hidden_size, + dtype=torch.bfloat16, + device=device, + ) + + for size in token_sizes: + hc_head_op( + hidden_states[:size], + model.hc_head_fn, + model.hc_head_scale, + model.hc_head_base, + model.rms_norm_eps, + model.hc_eps, + ) + + +@instrument(span_name="DeepSeek V4 mHC warmup") +def deepseek_v4_mhc_warmup( + model: torch.nn.Module, + *, + max_tokens: int, + cudagraph_capture_sizes: list[int] | None = None, +) -> None: + if not envs.VLLM_ENABLE_DEEPSEEK_V4_MHC_WARMUP: + return + + # Cheap model-type gate before walking ``model.modules()``. The class + # walk below is O(num_layers) and shows up in startup time on very + # large checkpoints; bail out for any model that is not DeepSeek V4. + config = getattr(model, "config", None) + model_type = getattr(config, "model_type", None) if config is not None else None + if model_type is not None and model_type != "deepseek_v4": + return + + layer = _find_first_mhc_layer(model) + if layer is None: + return + + device = layer.hc_attn_fn.device + if device.type != "cuda": + return + + deepseek_model = _find_deepseek_v4_model(model) + num_sms = _get_cuda_num_sms(device) + token_sizes = _select_mhc_warmup_token_sizes( + max_tokens=max_tokens, + hidden_size=int(layer.hidden_size), + hc_mult=int(layer.hc_mult), + num_sms=num_sms, + requested_token_sizes=envs.VLLM_DEEPSEEK_V4_MHC_WARMUP_TOKEN_SIZES, + cudagraph_capture_sizes=cudagraph_capture_sizes or [], + ) + if not token_sizes: + return + + started = time.perf_counter() + logger.info( + "Warming up DeepSeek V4 mHC TileLang kernels for token sizes: %s", + token_sizes, + ) + with torch.inference_mode(): + _warmup_layer_mhc(layer, token_sizes) + if deepseek_model is not None: + _warmup_hc_head(deepseek_model, token_sizes) + torch.accelerator.synchronize() + logger.info( + "DeepSeek V4 mHC TileLang warmup finished in %.2f seconds.", + time.perf_counter() - started, + ) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 61d2376abb8c..3698582cfa0b 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -8,17 +8,24 @@ import hashlib from pathlib import Path +from types import SimpleNamespace from typing import TYPE_CHECKING +import numpy as np import torch import vllm.envs as envs from vllm.compilation.caching import aot_compile_hash_factors from vllm.logger import init_logger from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup +from vllm.model_executor.warmup.deepseek_v4_mhc_warmup import ( + deepseek_v4_mhc_warmup, +) from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer +from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput +from vllm.v1.structured_output.utils import apply_grammar_bitmask if TYPE_CHECKING: from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -26,6 +33,197 @@ logger = init_logger(__name__) +_DEEPSEEK_V4_SPARSE_MLA_BACKENDS = frozenset( + { + "V4_FLASHMLA_SPARSE", + "DEEPSEEK_SPARSE_SWA", + } +) +_DEEPSEEK_V4_SPARSE_MLA_MIXED_WARMUP_TOKENS = 16 +_DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS = 1024 +_DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS = tuple(range(1, 17)) + ( + 32, + 64, + 128, + 256, + 512, +) + + +def _attention_backend_name(backend: object) -> str | None: + get_name = getattr(backend, "get_name", None) + if get_name is None: + return None + try: + return get_name() + except NotImplementedError: + return None + + +def _has_deepseek_v4_sparse_mla_backend(runner: "GPUModelRunner") -> bool: + for groups in getattr(runner, "attn_groups", []) or (): + for group in groups: + name = _attention_backend_name(getattr(group, "backend", None)) + if name in _DEEPSEEK_V4_SPARSE_MLA_BACKENDS: + return True + return False + + +def _clamp_warmup_tokens(num_tokens: int, max_tokens: int) -> int: + return max(0, min(num_tokens, max_tokens)) + + +def _deepseek_v4_slot_mapping_warmup(runner: "GPUModelRunner") -> None: + max_tokens = getattr(runner, "max_num_tokens", 1) + block_table = runner.input_batch.block_table + + # Snapshot the runner buffers we mutate so warmup never leaks state into + # the first real request. + saved_query_start_loc_np: np.ndarray | None = None + saved_query_start_loc_gpu: torch.Tensor | None = None + if hasattr(runner, "query_start_loc"): + saved_query_start_loc_np = runner.query_start_loc.np[:2].copy() + saved_query_start_loc_gpu = runner.query_start_loc.gpu[:2].clone() + + try: + for requested_tokens in _DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS: + num_tokens = _clamp_warmup_tokens(requested_tokens, max_tokens) + if num_tokens <= 0: + continue + + positions_source = torch.arange( + num_tokens, dtype=torch.int64, device=runner.device + ) + if hasattr(runner, "query_start_loc"): + runner.query_start_loc.np[0] = 0 + runner.query_start_loc.np[1] = num_tokens + runner.query_start_loc.copy_to_gpu(2) + query_start_loc = runner.query_start_loc.gpu[:2] + else: + query_start_loc = torch.tensor( + [0, num_tokens], dtype=torch.int32, device=runner.device + ) + + if hasattr(runner, "positions"): + saved_positions: torch.Tensor | None = ( + runner.positions[:num_tokens].clone() + ) + runner.positions[:num_tokens].copy_(positions_source) + positions = runner.positions[:num_tokens] + else: + saved_positions = None + positions = positions_source + + try: + block_table.commit_block_table(1) + block_table.compute_slot_mapping(1, query_start_loc, positions) + finally: + if saved_positions is not None: + runner.positions[:num_tokens].copy_(saved_positions) + finally: + if saved_query_start_loc_np is not None: + runner.query_start_loc.np[:2] = saved_query_start_loc_np + assert saved_query_start_loc_gpu is not None + runner.query_start_loc.gpu[:2].copy_(saved_query_start_loc_gpu) + + +def _deepseek_v4_structured_output_bitmask_warmup( + runner: "GPUModelRunner", +) -> None: + vocab_size = runner.model_config.get_vocab_size() + if vocab_size <= 0: + return + + dtypes = [torch.float32] + model_dtype = getattr(runner.model_config, "dtype", None) + if isinstance(model_dtype, torch.dtype) and model_dtype not in dtypes: + dtypes.append(model_dtype) + + bitmask_width = (vocab_size + 31) // 32 + req_id = "_deepseek_v4_warmup_" + grammar_bitmask = np.full((1, bitmask_width), fill_value=-1, dtype=np.int32) + grammar_output = GrammarOutput( + structured_output_request_ids=[req_id], grammar_bitmask=grammar_bitmask + ) + + for dtype in dtypes: + for req_ids in ([req_id], [req_id, "_deepseek_v4_warmup_unmasked_"]): + logits = torch.zeros( + (len(req_ids), vocab_size), dtype=dtype, device=runner.device + ) + input_batch = SimpleNamespace(req_ids=req_ids) + apply_grammar_bitmask( + SchedulerOutput.make_empty(), grammar_output, input_batch, logits + ) + + +@torch.inference_mode() +def _deepseek_v4_request_prep_warmup(worker: "Worker") -> None: + if not envs.VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: + return + + runner = worker.model_runner + if runner.is_pooling_model or not _has_deepseek_v4_sparse_mla_backend(runner): + return + if not current_platform.is_cuda_alike(): + return + + logger.info("Warming up DeepSeek V4 request preparation kernels.") + _deepseek_v4_slot_mapping_warmup(runner) + + if getattr(runner, "is_last_pp_rank", True): + try: + _deepseek_v4_structured_output_bitmask_warmup(runner) + except ImportError: + logger.debug( + "Skipping DeepSeek V4 structured output bitmask warmup because " + "xgrammar is unavailable." + ) + + torch.accelerator.synchronize() + + +def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None: + if not envs.VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: + return + + runner = worker.model_runner + if runner.is_pooling_model or not _has_deepseek_v4_sparse_mla_backend(runner): + return + + max_tokens = worker.scheduler_config.max_num_batched_tokens + mixed_tokens = _clamp_warmup_tokens( + _DEEPSEEK_V4_SPARSE_MLA_MIXED_WARMUP_TOKENS, max_tokens + ) + prefill_tokens = _clamp_warmup_tokens( + _DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS, max_tokens + ) + if mixed_tokens <= 0 and prefill_tokens <= 0: + return + + logger.info( + "Warming up DeepSeek V4 sparse MLA attention " + "for mixed tokens=%s and prefill tokens=%s.", + mixed_tokens, + prefill_tokens, + ) + if mixed_tokens > 0: + runner._dummy_run( + num_tokens=mixed_tokens, + skip_eplb=True, + is_profile=True, + force_attention=True, + create_mixed_batch=True, + ) + if prefill_tokens > 0: + runner._dummy_run( + num_tokens=prefill_tokens, + skip_eplb=True, + is_profile=True, + force_attention=True, + create_single_prefill=True, + ) + def _flashinfer_autotune_cache_hash(runner: "GPUModelRunner") -> str: factors = aot_compile_hash_factors(runner.vllm_config) @@ -69,6 +267,16 @@ def kernel_warmup(worker: "Worker"): deep_gemm_warmup(model, max_tokens) minimax_m3_msa_warmup(worker) + deepseek_v4_mhc_warmup( + worker.get_model(), + max_tokens=worker.scheduler_config.max_num_batched_tokens, + cudagraph_capture_sizes=( + worker.vllm_config.compilation_config.cudagraph_capture_sizes or [] + ), + ) + + _deepseek_v4_sparse_mla_attention_warmup(worker) + _deepseek_v4_request_prep_warmup(worker) enable_flashinfer_autotune = ( worker.vllm_config.kernel_config.enable_flashinfer_autotune diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3221dc46c63c..8d575d971ddb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5666,6 +5666,7 @@ def _dummy_run( skip_eplb: bool = False, is_profile: bool = False, create_mixed_batch: bool = False, + create_single_prefill: bool = False, remove_lora: bool = True, is_graph_capturing: bool = False, num_active_loras: int = 0, @@ -5691,6 +5692,8 @@ def _dummy_run( is_profile: If True, this is a profile run. create_mixed_batch: If True, create a mixed batch with both decode (1 token) and prefill (multiple tokens) requests. + create_single_prefill: If True, create one prefill request with + ``num_tokens`` prompt tokens. remove_lora: If False, dummy LoRAs are not destroyed after the run num_active_loras: Number of distinct active LoRAs to capture for. LoRA is activated when num_active_loras > 0. @@ -5729,7 +5732,13 @@ def _dummy_run( # has num_tokens in total. assert num_tokens <= self.max_num_tokens max_num_reqs = self.scheduler_config.max_num_seqs - if create_mixed_batch: + if create_single_prefill: + assert not uniform_decode + assert not create_mixed_batch + num_reqs = 1 + num_scheduled_tokens_list = [num_tokens] + max_query_len = num_tokens + elif create_mixed_batch: assert not uniform_decode # Create mixed batch: # first half decode tokens, second half one prefill @@ -5743,6 +5752,7 @@ def _dummy_run( max_query_len = num_prefill_tokens elif uniform_decode: assert not create_mixed_batch + assert not create_single_prefill num_reqs = min(max_num_reqs, cdiv(num_tokens, max_query_len)) num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: From e47bdd1a4debbb5a8640a7e4e779acc45e10e102 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 7 May 2026 03:09:12 +0800 Subject: [PATCH 010/131] Add SM12x sparse MLA direct decode kernels Signed-off-by: jasl --- vllm/envs.py | 31 + .../models/deepseek_v4/common/ops/__init__.py | 4 + .../deepseek_v4/common/ops/cache_utils.py | 152 + .../attention/backends/mla/sparse_mla_env.py | 72 +- .../backends/mla/sparse_mla_kernels.py | 2501 +++++++++++++++-- 5 files changed, 2558 insertions(+), 202 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 79dfecd947c0..721fa31410c6 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -182,6 +182,12 @@ VLLM_ENABLE_DEEPSEEK_V4_MHC_WARMUP: bool = True VLLM_DEEPSEEK_V4_MHC_WARMUP_TOKEN_SIZES: list[int] | None = None VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True + VLLM_TRITON_MLA_SPARSE: bool | None = None + VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 + VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 + VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE: int | None = None + VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE: bool | None = None + VLLM_TRITON_MLA_SPARSE_SPLITKV_DECODE: bool = False VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", "full", @@ -1466,6 +1472,31 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP": lambda: bool( int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP", "1")) ), + # Experimental sparse MLA fallback controls. + # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse + # is unavailable; set 0/1 to force-disable/force-enable the fallback. + "VLLM_TRITON_MLA_SPARSE": lambda: ( + None + if os.getenv("VLLM_TRITON_MLA_SPARSE") is None + else bool(int(os.getenv("VLLM_TRITON_MLA_SPARSE", "0"))) + ), + "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", "512") + ), + "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE": lambda: int( + os.getenv("VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", "256") + ), + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE": lambda: maybe_convert_int( + os.getenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE") + ), + "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE": lambda: ( + None + if os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE") is None + else bool(int(os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE", "0"))) + ), + "VLLM_TRITON_MLA_SPARSE_SPLITKV_DECODE": lambda: bool( + int(os.getenv("VLLM_TRITON_MLA_SPARSE_SPLITKV_DECODE", "0")) + ), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine diff --git a/vllm/models/deepseek_v4/common/ops/__init__.py b/vllm/models/deepseek_v4/common/ops/__init__.py index f4ff348091dd..9c15ce832cc5 100644 --- a/vllm/models/deepseek_v4/common/ops/__init__.py +++ b/vllm/models/deepseek_v4/common/ops/__init__.py @@ -6,6 +6,8 @@ combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + dequantize_global_slots_k_cache, quantize_and_insert_k_cache, sparse_prefill_combined_topk_size, ) @@ -21,6 +23,8 @@ "combine_topk_swa_indices", "compute_global_topk_indices_and_lens", "dequantize_and_gather_k_cache", + "dequantize_combined_sparse_mla_decode_kv", + "dequantize_global_slots_k_cache", "fused_indexer_q_rope_quant", "fused_inv_rope_fp8_quant", "fused_mtp_input_rmsnorm", diff --git a/vllm/models/deepseek_v4/common/ops/cache_utils.py b/vllm/models/deepseek_v4/common/ops/cache_utils.py index 5bdad86de066..32c6850b10ca 100644 --- a/vllm/models/deepseek_v4/common/ops/cache_utils.py +++ b/vllm/models/deepseek_v4/common/ops/cache_utils.py @@ -423,6 +423,158 @@ def dequantize_and_gather_k_cache( ) +@triton.jit +def _dequantize_global_slots_k_kernel( + out_ptr, + out_stride_token, + out_stride_slot, + k_cache_ptr, + slot_ids_ptr, + slot_ids_stride_token, + slot_ids_stride_slot, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + bf16_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + output_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + topk_idx = tl.program_id(1) + + slot_id = tl.load( + slot_ids_ptr + + token_idx * slot_ids_stride_token + + topk_idx * slot_ids_stride_slot + ) + offsets = tl.arange(0, BLOCK_D) + output_row = out_ptr + token_idx * out_stride_token + topk_idx * out_stride_slot + + if slot_id < 0: + tl.store( + output_row + offsets, + tl.zeros((BLOCK_D,), dtype=tl.float32).to(tl.bfloat16), + mask=offsets < output_dim, + ) + return + + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + cache_block_size * token_data_size + pos_in_block * scale_dim + ) + + fp8_offsets = tl.arange(0, 512) + fp8_mask = fp8_offsets < fp8_dim + x_uint8 = tl.load(token_data_ptr + fp8_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + + scale_offsets = fp8_offsets // quant_block + encoded_scale = tl.load(token_scale_ptr + scale_offsets, mask=fp8_mask, other=127) + scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * scale + tl.store(output_row + fp8_offsets, x_dequant.to(tl.bfloat16), mask=fp8_mask) + + bf16_offsets = tl.arange(0, 64) + bf16_cache_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + bf16_vals = tl.load(bf16_cache_ptr + bf16_offsets, mask=bf16_offsets < bf16_dim) + tl.store( + output_row + fp8_dim + bf16_offsets, + bf16_vals, + mask=bf16_offsets < bf16_dim, + ) + + +def dequantize_global_slots_k_cache( + out: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + block_size: int, +) -> None: + """Dequantize fp8_ds_mla cache rows addressed by physical global slot ids.""" + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0, :] + assert slot_ids.dim() == 2, ( + f"slot_ids must be [num_tokens, topk], got {slot_ids.shape}" + ) + assert out.shape[:2] == slot_ids.shape + assert out.shape[-1] == 512 + assert out.dtype == torch.bfloat16 + assert k_cache.dtype == torch.uint8 + + TOKEN_FP8_DIM = 448 + TOKEN_BF16_DIM = 64 + TOKEN_SCALE_DIM = 8 + QUANT_BLOCK_SIZE = 64 + TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2 + + grid = slot_ids.shape + _dequantize_global_slots_k_kernel[grid]( + out, + out.stride(0), + out.stride(1), + k_cache, + slot_ids, + slot_ids.stride(0), + slot_ids.stride(1), + cache_block_size=block_size, + token_data_size=TOKEN_DATA_SIZE, + block_stride=k_cache.stride(0), + fp8_dim=TOKEN_FP8_DIM, + bf16_dim=TOKEN_BF16_DIM, + scale_dim=TOKEN_SCALE_DIM, + quant_block=QUANT_BLOCK_SIZE, + output_dim=512, + BLOCK_D=triton.next_power_of_2(512), + ) + + +def dequantize_combined_sparse_mla_decode_kv( + combined_kv: torch.Tensor, + compressed_k_cache: torch.Tensor, + compressed_slot_ids: torch.Tensor, + compressed_block_size: int, + swa_k_cache: torch.Tensor, + seq_lens: torch.Tensor, + swa_lens: torch.Tensor, + block_table: torch.Tensor, + swa_block_size: int, +) -> None: + """Fill `[compressed, SWA]` decode candidates into one output buffer.""" + assert combined_kv.dim() == 3 + compressed_topk = compressed_slot_ids.shape[-1] + assert combined_kv.shape[0] == compressed_slot_ids.shape[0] + assert combined_kv.shape[-1] == 512 + assert combined_kv.dtype == torch.bfloat16 + assert combined_kv.shape[1] >= compressed_topk + + dequantize_global_slots_k_cache( + combined_kv[:, :compressed_topk], + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + ) + swa_out = combined_kv[:, compressed_topk:] + if swa_out.shape[1] == 0: + return + dequantize_and_gather_k_cache( + swa_out, + swa_k_cache, + seq_lens=seq_lens, + gather_lens=swa_lens, + block_table=block_table, + block_size=swa_block_size, + offset=0, + ) + + def compute_global_topk_indices_and_lens( topk_indices: torch.Tensor, token_to_req_indices: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py index ff8c3a5c4b56..4567eca9dd83 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_env.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -1,62 +1,64 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Platform controls for the portable Triton sparse MLA path.""" +"""Environment controls for the portable Triton sparse MLA path.""" import torch -from vllm.logger import init_logger +import vllm.envs as envs from vllm.platforms import current_platform -_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE = 512 -_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE = 256 - -logger = init_logger(__name__) - def _is_sm12x_device(device: torch.device) -> bool: - if not torch.cuda.is_available(): + if not current_platform.is_cuda(): return False - index = device.index if device.index is not None else torch.cuda.current_device() - return torch.cuda.get_device_capability(index)[0] == 12 + index = ( + device.index + if device.index is not None + else torch.accelerator.current_device_index() + ) + capability = current_platform.get_device_capability(device_id=index) + return capability is not None and capability[0] == 12 + + +def triton_sparse_mla_configured() -> bool | None: + return envs.VLLM_TRITON_MLA_SPARSE def is_triton_sparse_mla_enabled_for_platform() -> bool: + configured = triton_sparse_mla_configured() + if configured is not None: + return configured return current_platform.is_device_capability_family(120) def is_triton_sparse_mla_enabled(device: torch.device) -> bool: + configured = triton_sparse_mla_configured() + if configured is not None: + return configured return _is_sm12x_device(device) -def disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) -> None: - if not is_triton_sparse_mla_enabled_for_platform(): - return +def triton_sparse_mla_topk_chunk_size() -> int: + return envs.VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE + - from vllm.config.compilation import CompilationMode, CUDAGraphMode +def triton_sparse_mla_query_chunk_size() -> int: + return envs.VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE - compilation_config = vllm_config.compilation_config - if ( - compilation_config.mode == CompilationMode.NONE - and compilation_config.cudagraph_mode == CUDAGraphMode.NONE - ): - return - logger.warning_once( - "Disabling vLLM compile and CUDA graphs for the DeepSeek V4 Triton " - "sparse MLA path because the current Triton sparse MLA path is not " - "compile/graph-safe yet." - ) - compilation_config.mode = CompilationMode.NONE - compilation_config.compile_sizes = [] - compilation_config.compile_ranges_endpoints = [] - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - compilation_config.cudagraph_capture_sizes = [] - compilation_config.max_cudagraph_capture_size = 0 +def triton_sparse_mla_head_block_size() -> int | None: + value = envs.VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE + if value in (1, 2, 4): + return value + return None -def triton_sparse_mla_topk_chunk_size() -> int: - return _TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE +def triton_sparse_mla_matmul_decode_enabled() -> bool: + configured = envs.VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE + if configured is not None: + return configured + return current_platform.is_device_capability_family(120) -def triton_sparse_mla_query_chunk_size() -> int: - return _TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE +def triton_sparse_mla_splitkv_decode_enabled() -> bool: + return envs.VLLM_TRITON_MLA_SPARSE_SPLITKV_DECODE diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 36c41e4b3bad..867eaf298149 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -2,47 +2,1780 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Portable sparse MLA Triton kernels.""" +import math + import torch -from vllm.triton_utils import tl, triton +from vllm.triton_utils import LOG2E, LOGE2, tl, triton +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + triton_sparse_mla_head_block_size, +) + +_SPLITKV_HEAD_BLOCK = 16 +_SPLITKV_MERGE_HEAD_BLOCK = 1 +_SPLITKV_BLOCK_N = 32 +_SPLITKV_MERGE_BLOCK_D = 128 +_SPLITKV_MIN_CANDIDATES_PER_SPLIT = 128 +_SPLITKV_MEDIUM_BATCH_MIN_TOKENS = 16 +_SPLITKV_MEDIUM_BATCH_CANDIDATES_PER_SPLIT = 512 +_SPLITKV_MEDIUM_BATCH_MAX_SPLITS = 8 +_SPLITKV_MAX_OCCUPANCY = 4 + + +def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: + """Choose the SM12x sparse MLA head grouping for decode kernels. + + Single-token decode is latency sensitive and does best with one head per + program. Once there are enough query tokens, grouping heads lets the kernel + reuse each dequantized KV row across multiple heads. + """ + + configured_head_block_size = triton_sparse_mla_head_block_size() + if configured_head_block_size is not None: + return configured_head_block_size + if num_decode_tokens <= 4: + return 1 + if num_decode_tokens < 16: + return 2 + return 4 + + +def _next_power_of_2(value: int) -> int: + return 1 << max(0, value - 1).bit_length() + + +def choose_sparse_mla_splitkv_splits( + num_tokens: int, + num_heads: int, + num_candidates: int, + sm_count: int, + head_block_size: int = _SPLITKV_HEAD_BLOCK, +) -> int: + if ( + num_tokens <= 0 + or num_heads <= 0 + or num_candidates <= 0 + or sm_count <= 0 + or head_block_size <= 0 + ): + return 1 + + num_head_groups = math.ceil(num_heads / min(head_block_size, num_heads)) + baseline = num_tokens * num_head_groups + if baseline == 0: + return 1 + + ideal = _next_power_of_2( + max(1, num_candidates // _SPLITKV_MIN_CANDIDATES_PER_SPLIT) + ) + max_splits = max(1, (sm_count * _SPLITKV_MAX_OCCUPANCY) // baseline) + max_splits = 1 << (max_splits.bit_length() - 1) + num_splits = min(ideal, max_splits) + if ( + num_tokens >= _SPLITKV_MEDIUM_BATCH_MIN_TOKENS + and baseline <= sm_count * _SPLITKV_MAX_OCCUPANCY + ): + medium_batch_splits = _next_power_of_2( + max(1, num_candidates // _SPLITKV_MEDIUM_BATCH_CANDIDATES_PER_SPLIT) + ) + medium_batch_splits = min( + ideal, medium_batch_splits, _SPLITKV_MEDIUM_BATCH_MAX_SPLITS + ) + num_splits = max(num_splits, medium_batch_splits) + while num_splits > 1 and num_candidates % num_splits != 0: + num_splits //= 2 + return max(1, num_splits) + + +@triton.jit +def _splitkv_sparse_mla_stage1_kernel( + q_ptr, + kv_ptr, + valid_ptr, + mid_ptr, + stride_qt: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvt: tl.constexpr, + stride_kvc: tl.constexpr, + stride_kvd: tl.constexpr, + stride_vt: tl.constexpr, + stride_vc: tl.constexpr, + stride_mt: tl.constexpr, + stride_mh: tl.constexpr, + stride_ms: tl.constexpr, + num_heads: tl.constexpr, + num_candidates: tl.constexpr, + scale: tl.constexpr, + num_splits: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + LOGE2_VALUE: tl.constexpr, +): + token_id = tl.program_id(0) + head_group = tl.program_id(1) + split_id = tl.program_id(2) + + offs_h = head_group * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + mask_h = offs_h < num_heads + offs_d = tl.arange(0, BLOCK_D) + + q = tl.load( + q_ptr + + token_id * stride_qt + + offs_h[:, None] * stride_qh + + offs_d[None, :] * stride_qd, + mask=mask_h[:, None], + other=0.0, + ) + + split_size: tl.constexpr = tl.cdiv(num_candidates, num_splits) + split_start = split_id * split_size + split_end = tl.minimum(split_start + split_size, num_candidates) + + neg_large = -1.0e30 + e_max = tl.full((HEAD_BLOCK,), neg_large, dtype=tl.float32) + e_sum = tl.zeros((HEAD_BLOCK,), dtype=tl.float32) + acc = tl.zeros((HEAD_BLOCK, BLOCK_D), dtype=tl.float32) + + for cand_start in range(split_start, split_end, BLOCK_N): + offs_c = cand_start + tl.arange(0, BLOCK_N) + mask_c = offs_c < split_end + valid = tl.load( + valid_ptr + token_id * stride_vt + offs_c * stride_vc, + mask=mask_c, + other=0, + ) + mask_kv = mask_c & valid + k = tl.load( + kv_ptr + + token_id * stride_kvt + + offs_c[:, None] * stride_kvc + + offs_d[None, :] * stride_kvd, + mask=mask_kv[:, None], + other=0.0, + ) + qk = tl.dot(q, tl.trans(k.to(q.dtype))) * scale + qk = tl.where(mask_h[:, None] & mask_kv[None, :], qk, neg_large) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp2(e_max - n_e_max) + p = tl.exp2(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(k.dtype), k) + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + e_sum_safe = tl.where(e_sum > 0, e_sum, 1.0) + mid_base = ( + mid_ptr + + token_id * stride_mt + + offs_h[:, None] * stride_mh + + split_id * stride_ms + ) + tl.store( + mid_base + offs_d[None, :], + acc / e_sum_safe[:, None], + mask=mask_h[:, None], + ) + tl.store( + mid_ptr + + token_id * stride_mt + + offs_h * stride_mh + + split_id * stride_ms + + BLOCK_D, + (e_max + tl.log2(e_sum)) * LOGE2_VALUE, + mask=mask_h, + ) + + +@triton.jit +def _splitkv_sparse_mla_merge_kernel( + mid_ptr, + sink_ptr, + output_ptr, + stride_mt: tl.constexpr, + stride_mh: tl.constexpr, + stride_ms: tl.constexpr, + stride_out_t: tl.constexpr, + stride_oh: tl.constexpr, + stride_od: tl.constexpr, + num_heads: tl.constexpr, + num_splits: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_D_TILE: tl.constexpr, +): + token_id = tl.program_id(0) + head_group = tl.program_id(1) + d_tile = tl.program_id(2) + + offs_h = head_group * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + mask_h = offs_h < num_heads + offs_d = d_tile * BLOCK_D_TILE + tl.arange(0, BLOCK_D_TILE) + mask_d = offs_d < BLOCK_D + + e_max = tl.full((HEAD_BLOCK,), -float("inf"), dtype=tl.float32) + e_sum = tl.zeros((HEAD_BLOCK,), dtype=tl.float32) + acc = tl.zeros((HEAD_BLOCK, BLOCK_D_TILE), dtype=tl.float32) + mid_base = mid_ptr + token_id * stride_mt + offs_h[:, None] * stride_mh + mid_lse = mid_ptr + token_id * stride_mt + offs_h * stride_mh + BLOCK_D + + for split_id in range(num_splits): + part = tl.load( + mid_base + split_id * stride_ms + offs_d[None, :], + mask=mask_h[:, None] & mask_d[None, :], + other=0.0, + ) + lse = tl.load( + mid_lse + split_id * stride_ms, + mask=mask_h, + other=-float("inf"), + ) + n_e_max = tl.maximum(lse, e_max) + old_scale = tl.exp(e_max - n_e_max) + part_scale = tl.exp(lse - n_e_max) + acc = acc * old_scale[:, None] + part * part_scale[:, None] + e_sum = e_sum * old_scale + part_scale + e_max = n_e_max + + sink = tl.load(sink_ptr + offs_h, mask=mask_h, other=-float("inf")) + n_e_max = tl.maximum(sink, e_max) + value_scale = tl.exp(e_max - n_e_max) + sink_scale = tl.exp(sink - n_e_max) + denom = e_sum * value_scale + sink_scale + denom = tl.where(denom > 0, denom, 1.0) + merged = acc * value_scale[:, None] / denom[:, None] + + tl.store( + output_ptr + + token_id * stride_out_t + + offs_h[:, None] * stride_oh + + offs_d[None, :] * stride_od, + merged, + mask=mask_h[:, None] & mask_d[None, :], + ) + + +def splitkv_sparse_mla_attention_with_sink( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + mid: torch.Tensor, + num_splits: int, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert valid_tokens.shape == kv.shape[:2] + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert q.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda and mid.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + num_tokens, _, head_dim = q.shape + num_candidates = kv.shape[1] + assert mid.shape == (num_tokens, active_heads, num_splits, head_dim + 1) + num_head_groups = triton.cdiv(active_heads, _SPLITKV_HEAD_BLOCK) + _splitkv_sparse_mla_stage1_kernel[(num_tokens, num_head_groups, num_splits)]( + q, + kv, + valid_tokens, + mid, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + mid.stride(0), + mid.stride(1), + mid.stride(2), + active_heads, + num_candidates, + scale * LOG2E, + num_splits, + HEAD_BLOCK=_SPLITKV_HEAD_BLOCK, + BLOCK_N=_SPLITKV_BLOCK_N, + BLOCK_D=head_dim, + LOGE2_VALUE=LOGE2, + num_warps=4, + ) + _splitkv_sparse_mla_merge_kernel[ + (num_tokens, active_heads, triton.cdiv(head_dim, _SPLITKV_MERGE_BLOCK_D)) + ]( + mid, + attn_sink, + output, + mid.stride(0), + mid.stride(1), + mid.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + active_heads, + num_splits, + HEAD_BLOCK=_SPLITKV_MERGE_HEAD_BLOCK, + BLOCK_D=head_dim, + BLOCK_D_TILE=_SPLITKV_MERGE_BLOCK_D, + num_warps=2, + ) + + +@triton.jit +def _merge_two_subsets_with_sink_kernel( + out0_ptr, + lse0_ptr, + out1_ptr, + lse1_ptr, + sink_ptr, + output_ptr, + stride_out0_t: tl.constexpr, + stride_out0_h: tl.constexpr, + stride_out0_d: tl.constexpr, + stride_lse0_t: tl.constexpr, + stride_lse0_h: tl.constexpr, + stride_out1_t: tl.constexpr, + stride_out1_h: tl.constexpr, + stride_out1_d: tl.constexpr, + stride_lse1_t: tl.constexpr, + stride_lse1_h: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + lse0 = tl.load(lse0_ptr + token_idx * stride_lse0_t + head_idx * stride_lse0_h) + lse1 = tl.load(lse1_ptr + token_idx * stride_lse1_t + head_idx * stride_lse1_h) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(tl.maximum(lse0, lse1), sink) + + weight0 = tl.exp(lse0 - merge_max) + weight1 = tl.exp(lse1 - merge_max) + weight_sink = tl.exp(sink - merge_max) + denom = weight0 + weight1 + weight_sink + + out0 = tl.load( + out0_ptr + + token_idx * stride_out0_t + + head_idx * stride_out0_h + + offsets * stride_out0_d, + mask=mask, + other=0.0, + ).to(tl.float32) + out1 = tl.load( + out1_ptr + + token_idx * stride_out1_t + + head_idx * stride_out1_h + + offsets * stride_out1_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = (out0 * weight0 + out1 * weight1) / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_two_sparse_mla_subsets_with_sink( + subset0_output: torch.Tensor, + subset0_lse: torch.Tensor, + subset1_output: torch.Tensor, + subset1_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset0_output.shape == subset1_output.shape + assert subset0_output.shape == output.shape + assert subset0_lse.shape == subset1_lse.shape + assert subset0_lse.shape == subset0_output.shape[:2] + assert attn_sink.shape[0] == subset0_output.shape[1] + assert subset0_output.is_cuda + assert subset1_output.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset0_output.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_two_subsets_with_sink_kernel[grid]( + subset0_output, + subset0_lse, + subset1_output, + subset1_lse, + attn_sink, + output, + subset0_output.stride(0), + subset0_output.stride(1), + subset0_output.stride(2), + subset0_lse.stride(0), + subset0_lse.stride(1), + subset1_output.stride(0), + subset1_output.stride(1), + subset1_output.stride(2), + subset1_lse.stride(0), + subset1_lse.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _merge_single_subset_with_sink_kernel( + subset_output_ptr, + subset_lse_ptr, + sink_ptr, + output_ptr, + stride_subset_t: tl.constexpr, + stride_subset_h: tl.constexpr, + stride_subset_d: tl.constexpr, + stride_lse_t: tl.constexpr, + stride_lse_h: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + subset_lse = tl.load( + subset_lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h + ) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(subset_lse, sink) + + subset_weight = tl.exp(subset_lse - merge_max) + sink_weight = tl.exp(sink - merge_max) + denom = subset_weight + sink_weight + subset_output = tl.load( + subset_output_ptr + + token_idx * stride_subset_t + + head_idx * stride_subset_h + + offsets * stride_subset_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = subset_output * subset_weight / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_sparse_mla_subset_with_sink( + subset_output: torch.Tensor, + subset_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset_output.shape == output.shape + assert subset_lse.shape == subset_output.shape[:2] + assert attn_sink.shape[0] == subset_output.shape[1] + assert subset_output.is_cuda + assert subset_lse.is_cuda + assert attn_sink.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset_output.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_single_subset_with_sink_kernel[grid]( + subset_output, + subset_lse, + attn_sink, + output, + subset_output.stride(0), + subset_output.stride(1), + subset_output.stride(2), + subset_lse.stride(0), + subset_lse.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _build_combined_decode_valid_mask_kernel( + output_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_lens_ptr, + stride_output_t: tl.constexpr, + stride_output_c: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + num_compressed_candidates: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_C: tl.constexpr, +): + token_idx = tl.program_id(0) + offsets = tl.arange(0, BLOCK_C) + candidate_mask = offsets < num_candidates + + topk_lens = tl.load(topk_lens_ptr + token_idx) + swa_lens = tl.load(swa_lens_ptr + token_idx) + is_compressed = offsets < num_compressed_candidates + swa_offsets = offsets - num_compressed_candidates + slot_ids = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + offsets * stride_slot_c, + mask=is_compressed, + other=-1, + ) + valid_compressed = is_compressed & (offsets < topk_lens) & (slot_ids >= 0) + valid_swa = (~is_compressed) & (swa_offsets < swa_lens) + valid = valid_compressed | valid_swa + tl.store( + output_ptr + token_idx * stride_output_t + offsets * stride_output_c, + valid, + mask=candidate_mask, + ) + + +def build_combined_sparse_mla_decode_valid_mask( + output: torch.Tensor, + compressed_slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + swa_lens: torch.Tensor, +) -> None: + """Build `[compressed, SWA]` validity mask for SM12x decode.""" + if compressed_slot_ids.dim() == 3: + assert compressed_slot_ids.shape[1] == 1 + compressed_slot_ids = compressed_slot_ids[:, 0, :] + + assert output.dim() == 2 + assert output.dtype == torch.bool + assert compressed_slot_ids.dim() == 2 + assert output.shape[0] == compressed_slot_ids.shape[0] + assert output.shape[0] == topk_lens.shape[0] + assert output.shape[0] == swa_lens.shape[0] + assert output.shape[1] >= compressed_slot_ids.shape[1] + assert output.is_cuda + assert compressed_slot_ids.is_cuda + assert topk_lens.is_cuda + assert swa_lens.is_cuda + + num_candidates = output.shape[1] + block_c = triton.next_power_of_2(num_candidates) + _build_combined_decode_valid_mask_kernel[(output.shape[0],)]( + output, + compressed_slot_ids, + topk_lens, + swa_lens, + output.stride(0), + output.stride(1), + compressed_slot_ids.stride(0), + compressed_slot_ids.stride(1), + compressed_slot_ids.shape[1], + num_candidates, + BLOCK_C=block_c, + num_warps=4, + ) + + +def matmul_sparse_mla_attention_with_sink( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + num_heads: int | None = None, + score_buffer: torch.Tensor | None = None, + head_block_size: int = 1, + value_block_size: int | None = None, + candidate_block_size: int | None = None, +) -> None: + """Compute sink-aware sparse MLA over materialized BF16 KV. + + This path intentionally dequantizes/gathers KV once, computes scores with + batched matrix multiplication, and finishes the sink-aware value reduction + in Triton. It is useful for the SM12x decode path where the direct Triton + kernel otherwise repeats fp8_ds_mla dequantization once per head group. + """ + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert valid_tokens.shape == kv.shape[:2] + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert q.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + q_active = q[:, :active_heads] + num_tokens = q.shape[0] + num_candidates = kv.shape[1] + if score_buffer is None: + score_buffer = torch.empty( + (num_tokens, active_heads, num_candidates), + dtype=torch.float32, + device=q.device, + ) + assert score_buffer.shape == (num_tokens, active_heads, num_candidates) + assert score_buffer.device == q.device + assert score_buffer.dtype in (torch.float32, torch.bfloat16) + if score_buffer.dtype == torch.float32: + q_score = q_active.float() + kv_score = kv.float() + else: + q_score = q_active.to(score_buffer.dtype) + kv_score = kv.to(score_buffer.dtype) + torch.bmm(q_score, kv_score.transpose(1, 2), out=score_buffer) + score_buffer.mul_(scale) + finish_materialized_sparse_mla_scores_with_sink( + score_buffer, + kv, + valid_tokens, + attn_sink, + output, + num_heads=active_heads, + head_block_size=head_block_size, + value_block_size=value_block_size, + candidate_block_size=candidate_block_size, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + running_max = tl.load(attn_sink_ptr + head_offsets, mask=head_mask, other=0.0).to( + tl.float32 + ) + running_denom = tl.full((HEAD_BLOCK,), 1.0, tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + for candidate_idx in range(0, num_candidates): + is_valid = tl.load( + valid_tokens_ptr + + token_idx * stride_valid_t + + candidate_idx * stride_valid_c + ) + if is_valid: + score = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_offsets * stride_scores_h + + candidate_idx * stride_scores_c, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + result = running_acc / running_denom[:, None] + tl.store( + output_ptr + + token_idx * stride_out_t + + head_offsets[:, None] * stride_out_h + + dim_offsets[None, :] * stride_out_d, + result, + mask=matrix_mask, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_candidate_block_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + dim_block_idx = tl.program_id(2) + candidate_offsets = tl.arange(0, BLOCK_K) + dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = dim_offsets < head_dim + + max_score = tl.load(attn_sink_ptr + head_idx).to(tl.float32) + for candidate_start in range(0, num_candidates, BLOCK_K): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < num_candidates + is_valid = tl.load( + valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c, + mask=candidate_mask, + other=0, + ).to(tl.int1) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidates * stride_scores_c, + mask=candidate_mask & is_valid, + other=-float("inf"), + ).to(tl.float32) + max_score = tl.maximum(max_score, tl.max(scores, axis=0)) + + denom = tl.exp(tl.load(attn_sink_ptr + head_idx).to(tl.float32) - max_score) + acc = tl.zeros((BLOCK_D,), tl.float32) + for candidate_start in range(0, num_candidates, BLOCK_K): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < num_candidates + is_valid = tl.load( + valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c, + mask=candidate_mask, + other=0, + ).to(tl.int1) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidates * stride_scores_c, + mask=candidate_mask & is_valid, + other=-float("inf"), + ).to(tl.float32) + weights = tl.exp(scores - max_score) + denom += tl.sum(weights, axis=0) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidates[:, None] * stride_kv_c + + dim_offsets[None, :] * stride_kv_d, + mask=(candidate_mask & is_valid)[:, None] & dim_mask[None, :], + other=0.0, + ) + acc += tl.sum(kv.to(tl.float32) * weights[:, None], axis=0) + + tl.store( + output_ptr + + token_idx * stride_out_t + + head_idx * stride_out_h + + dim_offsets * stride_out_d, + acc / denom, + mask=dim_mask, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_value_block_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + dim_block_idx = tl.program_id(2) + dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = dim_offsets < head_dim + + running_max = tl.load(attn_sink_ptr + head_idx).to(tl.float32) + running_denom = tl.full((), 1.0, tl.float32) + running_acc = tl.zeros((BLOCK_D,), tl.float32) + + for candidate_idx in range(0, num_candidates): + is_valid = tl.load( + valid_tokens_ptr + + token_idx * stride_valid_t + + candidate_idx * stride_valid_c + ) + if is_valid: + score = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidate_idx * stride_scores_c + ).to(tl.float32) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + result = running_acc / running_denom + tl.store( + output_ptr + + token_idx * stride_out_t + + head_idx * stride_out_h + + dim_offsets * stride_out_d, + result, + mask=dim_mask, + ) + + +def finish_materialized_sparse_mla_scores_with_sink( + scores: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, + num_heads: int | None = None, + head_block_size: int = 1, + value_block_size: int | None = None, + candidate_block_size: int | None = None, +) -> None: + assert scores.dim() == 3 + assert kv.dim() == 3 + assert valid_tokens.shape == kv.shape[:2] + assert scores.shape[0] == kv.shape[0] + assert scores.shape[2] == kv.shape[1] + assert output.shape[0] == kv.shape[0] + assert output.shape[2] == kv.shape[2] + assert scores.dtype in (torch.float32, torch.bfloat16) + assert head_block_size in (1, 2, 4) + if value_block_size is not None: + assert value_block_size in (64, 128, 256, 512) + if candidate_block_size is not None: + assert candidate_block_size in (16, 32, 64, 128) + assert scores.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= scores.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + num_tokens, _, num_candidates = scores.shape + head_dim = kv.shape[2] + if candidate_block_size is not None: + block_d = value_block_size if value_block_size is not None else 128 + candidate_grid = (num_tokens, active_heads, triton.cdiv(head_dim, block_d)) + _finish_materialized_scores_with_sink_candidate_block_kernel[candidate_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + head_dim, + num_candidates, + BLOCK_K=candidate_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + return + + if value_block_size is not None and value_block_size < head_dim: + value_grid = ( + num_tokens, + active_heads, + triton.cdiv(head_dim, value_block_size), + ) + _finish_materialized_scores_with_sink_value_block_kernel[value_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + head_dim, + num_candidates, + BLOCK_D=value_block_size, + num_warps=4, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + return + + block_d = min(1024, triton.next_power_of_2(head_dim)) + head_grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _finish_materialized_scores_with_sink_kernel[head_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + active_heads, + head_dim, + num_candidates, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + + +@triton.jit +def _accumulate_gathered_attention_chunk_kernel( + q_ptr, + kv_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HAS_SLOT_IDS: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + is_valid = (candidate_offset + candidate_idx) < valid_len + if HAS_SLOT_IDS: + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = is_valid & (slot_id >= 0) + + if is_valid: + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_gathered_sparse_mla_attention_chunk( + q: torch.Tensor, + kv: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + slot_ids: torch.Tensor | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + if slot_ids is not None: + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + assert slot_ids.dim() == 2 + assert slot_ids.shape == kv.shape[:2] + assert slot_ids.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = kv.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_gathered_attention_chunk_kernel[grid]( + q, + kv, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + slot_ids.stride(0) if slot_ids is not None else 0, + slot_ids.stride(1) if slot_ids is not None else 0, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HAS_SLOT_IDS=slot_ids is not None, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_indexed_attention_chunk_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (kv_index >= 0) + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_indexed_sparse_mla_attention_chunk( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = indices.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_indexed_attention_chunk_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_fp8ds_global_slots_attention_chunk_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) -def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: - """Choose the SM12x sparse MLA head grouping for decode kernels. + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) - Single-token decode is latency sensitive and does best with one head per - program. Once there are enough query tokens, grouping heads lets the kernel - reuse each dequantized KV row across multiple heads. - """ + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale - if num_decode_tokens <= 4: - return 1 - if num_decode_tokens < 16: - return 2 - return 4 + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) @triton.jit -def _accumulate_indexed_attention_chunk_kernel( +def _accumulate_fp8ds_paged_attention_chunk_kernel( q_ptr, - kv_flat_ptr, - indices_ptr, - lens_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, max_score_ptr, denom_ptr, acc_ptr, stride_q_t: tl.constexpr, stride_q_h: tl.constexpr, stride_q_d: tl.constexpr, - stride_kv_t, - stride_kv_d: tl.constexpr, - stride_indices_t: tl.constexpr, - stride_indices_c: tl.constexpr, + stride_block_table_t, stride_state_t: tl.constexpr, stride_state_h: tl.constexpr, stride_acc_t: tl.constexpr, stride_acc_h: tl.constexpr, stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, num_heads: tl.constexpr, head_dim: tl.constexpr, num_candidates, @@ -68,24 +1801,52 @@ def _accumulate_indexed_attention_chunk_kernel( running_max = tl.load(max_score_ptr + state_offset) running_denom = tl.load(denom_ptr + state_offset) running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) - valid_len = tl.load(lens_ptr + token_idx) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) for candidate_idx in range(0, num_candidates): - kv_index = tl.load( - indices_ptr - + token_idx * stride_indices_t - + candidate_idx * stride_indices_c - ) - is_valid = ((candidate_offset + candidate_idx) < valid_len) & (kv_index >= 0) + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len if is_valid: - kv = tl.load( - kv_flat_ptr - + kv_index.to(tl.int64) * stride_kv_t - + offsets * stride_kv_d, - mask=dim_mask, - other=0.0, - ).to(tl.float32) + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + score = tl.sum(q * kv, axis=0) * scale next_max = tl.maximum(running_max, score) previous_weight = tl.exp(running_max - next_max) @@ -99,27 +1860,29 @@ def _accumulate_indexed_attention_chunk_kernel( tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) -def accumulate_indexed_sparse_mla_attention_chunk( +def accumulate_fp8ds_paged_sparse_mla_attention_chunk( q: torch.Tensor, - kv_flat: torch.Tensor, - indices: torch.Tensor, - lens: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, scale: float, max_score: torch.Tensor, denom: torch.Tensor, acc: torch.Tensor, - candidate_offset: int = 0, + candidate_offset: int, + num_candidates: int, ) -> None: if q.dim() == 4: assert q.shape[1] == 1 q = q[:, 0] assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" - assert kv_flat.dim() == 2 - assert indices.dim() == 2 - assert indices.shape[0] == q.shape[0] - assert kv_flat.shape[-1] == q.shape[-1] - assert lens.shape[0] == q.shape[0] + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] assert max_score.shape[0] == q.shape[0] assert max_score.shape[1] <= q.shape[1] assert denom.shape == max_score.shape @@ -127,34 +1890,45 @@ def accumulate_indexed_sparse_mla_attention_chunk( assert max_score.dtype == torch.float32 assert denom.dtype == torch.float32 assert acc.dtype == torch.float32 - assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + num_tokens, _, head_dim = q.shape num_heads = max_score.shape[1] - num_candidates = indices.shape[1] block_d = min(1024, triton.next_power_of_2(head_dim)) grid = (num_tokens, num_heads) - _accumulate_indexed_attention_chunk_kernel[grid]( + _accumulate_fp8ds_paged_attention_chunk_kernel[grid]( q, - kv_flat, - indices, - lens, + k_cache, + seq_lens, + gather_lens, + block_table, max_score, denom, acc, q.stride(0), q.stride(1), q.stride(2), - kv_flat.stride(0), - kv_flat.stride(1), - indices.stride(0), - indices.stride(1), + block_table.stride(0), max_score.stride(0), max_score.stride(1), acc.stride(0), acc.stride(1), acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, num_heads, head_dim, num_candidates, @@ -166,19 +1940,19 @@ def accumulate_indexed_sparse_mla_attention_chunk( @triton.jit -def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( +def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( q_ptr, k_cache_ptr, - slot_ids_ptr, - lens_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, max_score_ptr, denom_ptr, acc_ptr, stride_q_t: tl.constexpr, stride_q_h: tl.constexpr, stride_q_d: tl.constexpr, - stride_slot_t: tl.constexpr, - stride_slot_c: tl.constexpr, + stride_block_table_t, stride_state_t: tl.constexpr, stride_state_h: tl.constexpr, stride_acc_t: tl.constexpr, @@ -230,22 +2004,26 @@ def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( tl.float32 ) - valid_len = tl.load(lens_ptr + token_idx) + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len fp8_mask = dim_offsets < fp8_dim rope_mask = (dim_offsets >= fp8_dim) & dim_mask rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) for candidate_idx in range(0, num_candidates): - slot_id = tl.load( - slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c - ) - is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len if is_valid: - block_idx = slot_id // cache_block_size - pos_in_block = slot_id % cache_block_size - cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride token_data_ptr = cache_block_ptr + pos_in_block * token_data_size token_scale_ptr = ( cache_block_ptr @@ -288,31 +2066,30 @@ def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) -def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( +def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( q: torch.Tensor, k_cache: torch.Tensor, - slot_ids: torch.Tensor, - lens: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, block_size: int, scale: float, max_score: torch.Tensor, denom: torch.Tensor, acc: torch.Tensor, - candidate_offset: int = 0, + candidate_offset: int, + num_candidates: int, head_block_size: int = 2, ) -> None: if q.dim() == 4: assert q.shape[1] == 1 q = q[:, 0] - if slot_ids.dim() == 3: - assert slot_ids.shape[1] == 1 - slot_ids = slot_ids[:, 0] assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" assert q.shape[-1] == 512 - assert slot_ids.dim() == 2 - assert slot_ids.shape[0] == q.shape[0] - assert lens.shape[0] == q.shape[0] + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] assert max_score.shape[0] == q.shape[0] assert max_score.shape[1] <= q.shape[1] assert denom.shape == max_score.shape @@ -322,7 +2099,8 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( assert denom.dtype == torch.float32 assert acc.dtype == torch.float32 assert k_cache.dtype == torch.uint8 - assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda assert max_score.is_cuda and denom.is_cuda and acc.is_cuda token_fp8_dim = 448 @@ -333,22 +2111,21 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( num_tokens, _, head_dim = q.shape num_heads = max_score.shape[1] - num_candidates = slot_ids.shape[1] block_d = min(1024, triton.next_power_of_2(head_dim)) grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) - _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid]( + _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid]( q, k_cache, - slot_ids, - lens, + seq_lens, + gather_lens, + block_table, max_score, denom, acc, q.stride(0), q.stride(1), q.stride(2), - slot_ids.stride(0), - slot_ids.stride(1), + block_table.stride(0), max_score.stride(0), max_score.stride(1), acc.stride(0), @@ -360,10 +2137,218 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( token_fp8_dim, token_scale_dim, quant_block_size, - num_heads, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _fp8ds_paged_attention_with_sink_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + sink_ptr, + output_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + candidate_offset: tl.constexpr, + num_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) + + +def fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + candidate_offset: int, + num_candidates: int, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert head_block_size in (1, 2, 4) + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_paged_attention_with_sink_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + attn_sink, + output, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + output.stride(0), + output.stride(1), + output.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + active_heads, head_dim, - num_candidates, candidate_offset, + num_candidates, scale, HEAD_BLOCK=head_block_size, BLOCK_D=block_d, @@ -372,34 +2357,38 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( @triton.jit -def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( +def _fp8ds_global_paged_attention_with_sink_multihead_kernel( q_ptr, - k_cache_ptr, + compressed_k_cache_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_k_cache_ptr, seq_lens_ptr, gather_lens_ptr, block_table_ptr, - max_score_ptr, - denom_ptr, - acc_ptr, + sink_ptr, + output_ptr, stride_q_t: tl.constexpr, stride_q_h: tl.constexpr, stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, stride_block_table_t, - stride_state_t: tl.constexpr, - stride_state_h: tl.constexpr, - stride_acc_t: tl.constexpr, - stride_acc_h: tl.constexpr, - stride_acc_d: tl.constexpr, - cache_block_size: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + compressed_cache_block_size: tl.constexpr, + compressed_block_stride: tl.constexpr, + swa_cache_block_size: tl.constexpr, + swa_block_stride: tl.constexpr, token_data_size: tl.constexpr, - block_stride: tl.constexpr, fp8_dim: tl.constexpr, scale_dim: tl.constexpr, quant_block: tl.constexpr, num_heads: tl.constexpr, head_dim: tl.constexpr, - num_candidates, - candidate_offset, + num_compressed_candidates: tl.constexpr, + num_swa_candidates: tl.constexpr, scale: tl.constexpr, HEAD_BLOCK: tl.constexpr, BLOCK_D: tl.constexpr, @@ -420,46 +2409,82 @@ def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( mask=matrix_mask, other=0.0, ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) - state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h - acc_offsets = ( - token_idx * stride_acc_t - + head_offsets[:, None] * stride_acc_h - + dim_offsets[None, :] * stride_acc_d - ) - running_max = tl.load( - max_score_ptr + state_offsets, - mask=head_mask, - other=-float("inf"), - ) - running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) - running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( - tl.float32 - ) - - seq_len = tl.load(seq_lens_ptr + token_idx) - gather_len = tl.load(gather_lens_ptr + token_idx) - start_pos = seq_len - gather_len fp8_mask = dim_offsets < fp8_dim rope_mask = (dim_offsets >= fp8_dim) & dim_mask rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + topk_len = tl.load(topk_lens_ptr + token_idx) - for candidate_idx in range(0, num_candidates): - gather_idx = candidate_offset + candidate_idx - is_valid = gather_idx < gather_len + for candidate_idx in range(0, num_compressed_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = (candidate_idx < topk_len) & (slot_id >= 0) + if is_valid: + block_idx = slot_id // compressed_cache_block_size + pos_in_block = slot_id % compressed_cache_block_size + cache_block_ptr = ( + compressed_k_cache_ptr + + block_idx.to(tl.int64) * compressed_block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + compressed_cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + for candidate_idx in range(0, num_swa_candidates): + is_valid = candidate_idx < gather_len if is_valid: - pos = start_pos + gather_idx - block_in_seq = pos // cache_block_size - pos_in_block = pos % cache_block_size + pos = start_pos + candidate_idx + block_in_seq = pos // swa_cache_block_size + pos_in_block = pos % swa_cache_block_size physical_block = tl.load( block_table_ptr + token_idx * stride_block_table_t + block_in_seq ) - cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + cache_block_ptr = ( + swa_k_cache_ptr + physical_block.to(tl.int64) * swa_block_stride + ) token_data_ptr = cache_block_ptr + pos_in_block * token_data_size token_scale_ptr = ( cache_block_ptr - + cache_block_size * token_data_size + + swa_cache_block_size * token_data_size + pos_in_block * scale_dim ) @@ -474,7 +2499,6 @@ def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( ) dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) x_dequant = x_float * dequant_scale - rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( tl.float32 @@ -493,47 +2517,75 @@ def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( running_denom = running_denom * previous_weight + candidate_weight running_max = next_max - tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) - tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) - tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) -def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( +def fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( q: torch.Tensor, - k_cache: torch.Tensor, + compressed_k_cache: torch.Tensor, + slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + compressed_block_size: int, + swa_k_cache: torch.Tensor, seq_lens: torch.Tensor, gather_lens: torch.Tensor, block_table: torch.Tensor, - block_size: int, + swa_block_size: int, + num_compressed_candidates: int, + num_swa_candidates: int, scale: float, - max_score: torch.Tensor, - denom: torch.Tensor, - acc: torch.Tensor, - candidate_offset: int, - num_candidates: int, - head_block_size: int = 2, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, ) -> None: if q.dim() == 4: assert q.shape[1] == 1 q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert topk_lens.shape[0] == q.shape[0] assert seq_lens.shape[0] == q.shape[0] assert gather_lens.shape[0] == q.shape[0] assert block_table.shape[0] == q.shape[0] - assert max_score.shape[0] == q.shape[0] - assert max_score.shape[1] <= q.shape[1] - assert denom.shape == max_score.shape - assert acc.shape == (*max_score.shape, q.shape[-1]) + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] assert head_block_size in (1, 2, 4) - assert max_score.dtype == torch.float32 - assert denom.dtype == torch.float32 - assert acc.dtype == torch.float32 - assert k_cache.dtype == torch.uint8 - assert q.is_cuda and k_cache.is_cuda + assert compressed_k_cache.dtype == torch.uint8 + assert swa_k_cache.dtype == torch.uint8 + assert q.is_cuda and compressed_k_cache.is_cuda and swa_k_cache.is_cuda + assert slot_ids.is_cuda and topk_lens.is_cuda assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda - assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert attn_sink.is_cuda and output.is_cuda token_fp8_dim = 448 token_bf16_dim = 64 @@ -542,37 +2594,44 @@ def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( token_data_size = token_fp8_dim + token_bf16_dim * 2 num_tokens, _, head_dim = q.shape - num_heads = max_score.shape[1] + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] block_d = min(1024, triton.next_power_of_2(head_dim)) - grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) - _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid]( + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_global_paged_attention_with_sink_multihead_kernel[grid]( q, - k_cache, + compressed_k_cache, + slot_ids, + topk_lens, + swa_k_cache, seq_lens, gather_lens, block_table, - max_score, - denom, - acc, + attn_sink, + output, q.stride(0), q.stride(1), q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), block_table.stride(0), - max_score.stride(0), - max_score.stride(1), - acc.stride(0), - acc.stride(1), - acc.stride(2), - block_size, + output.stride(0), + output.stride(1), + output.stride(2), + compressed_block_size, + compressed_k_cache.stride(0), + swa_block_size, + swa_k_cache.stride(0), token_data_size, - k_cache.stride(0), token_fp8_dim, token_scale_dim, quant_block_size, - num_heads, + active_heads, head_dim, - num_candidates, - candidate_offset, + num_compressed_candidates, + num_swa_candidates, scale, HEAD_BLOCK=head_block_size, BLOCK_D=block_d, @@ -580,6 +2639,114 @@ def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( ) +@triton.jit +def _finish_attention_state_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + output_ptr, + lse_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + stride_lse_t: tl.constexpr, + stride_lse_h: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + is_valid = running_denom > 0.0 + inv_denom = tl.where(is_valid, 1.0 / running_denom, 0.0) + subset_lse = tl.where( + is_valid, + running_max + tl.log(running_denom), + -float("inf"), + ) + + acc = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + subset_output = acc * inv_denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + subset_output, + mask=dim_mask, + ) + if block_d == 0: + tl.store( + lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h, + subset_lse, + ) + + +def finish_gathered_sparse_mla_attention( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + output: torch.Tensor, + lse: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape == acc.shape + assert lse.shape == max_score.shape + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert output.dtype == torch.float32 + assert lse.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert output.is_cuda and lse.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_kernel[grid]( + max_score, + denom, + acc, + output, + lse, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + lse.stride(0), + lse.stride(1), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + @triton.jit def _finish_attention_state_with_sink_kernel( max_score_ptr, From 14be6ce353c53fe097f6936992394ec0ff2ff50a Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 5 May 2026 17:05:19 +0800 Subject: [PATCH 011/131] Stabilize DeepSeek V4 MTP scheduling Signed-off-by: jasl --- tests/v1/spec_decode/test_mtp.py | 301 ++++++++++++++++++++++- vllm/models/deepseek_v4/nvidia/mtp.py | 7 +- vllm/v1/spec_decode/llm_base_proposer.py | 172 ++++++++++--- 3 files changed, 450 insertions(+), 30 deletions(-) diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index e334371f6d8a..53b91739c979 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -25,13 +25,49 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform from vllm.v1.attention.backends.registry import AttentionBackendEnum +from vllm.v1.sample.logits_processor import LogitsProcessors +from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.llm_base_proposer import compute_probs_and_sample_next_token mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base" DEVICE_TYPE = current_platform.device_type -def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: +def _create_sampling_metadata( + all_greedy: bool, + batch_size: int, + top_k: torch.Tensor | None = None, + top_p: torch.Tensor | None = None, +) -> SamplingMetadata: + temperature = None + if not all_greedy: + temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE_TYPE) + return SamplingMetadata( + temperature=temperature, + all_greedy=all_greedy, + all_random=not all_greedy, + top_p=top_p, + top_k=top_k, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=torch.tensor([], device=DEVICE_TYPE), + presence_penalties=torch.tensor([], device=DEVICE_TYPE), + repetition_penalties=torch.tensor([], device=DEVICE_TYPE), + output_token_ids=[], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessors(), + spec_token_ids=[], + ) + + +def _create_mtp_proposer( + num_speculative_tokens: int, + parallel_drafting: bool = False, +) -> EagleProposer: """Create an MTP proposer with unified model configuration.""" model_config = ModelConfig( model=mimo_7b_dir, runner="generate", max_model_len=100, trust_remote_code=True @@ -43,7 +79,10 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: model=mimo_7b_dir, method="mtp", num_speculative_tokens=num_speculative_tokens, + parallel_drafting=parallel_drafting, ) + if parallel_drafting: + speculative_config.draft_model_config.hf_config.ptd_token_id = 0 vllm_config = VllmConfig( model_config=model_config, @@ -219,3 +258,263 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset): assert model_mock.called # Verify output shape assert result.shape == (batch_size, num_speculative_tokens) + + +def test_mtp_propose_random_sampling_records_draft_probs(): + device = torch.device(DEVICE_TYPE) + batch_size = 2 + seq_lens = [3, 2] + total_tokens = sum(seq_lens) + vocab_size = 4 + + proposer = _create_mtp_proposer(num_speculative_tokens=1) + # Mirror upstream's f51f6844f gating: probabilities are only collected + # when the speculative config explicitly opts into the probabilistic + # draft-model rejection path. We force the flag here so the test + # exercises ``_sample_draft_tokens``'s probabilistic branch. + proposer._enable_probabilistic_draft_probs = True + hidden_size = proposer.hidden_size + + model_mock = mock.MagicMock() + model_mock.return_value = torch.zeros(total_tokens, hidden_size, device=device) + logits = torch.tensor([[0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0]], device=device) + model_mock.compute_logits.return_value = logits.clone() + proposer.model = model_mock + proposer._draft_attn_layer_names = {"layer.0"} + + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=list(proposer._draft_attn_layer_names), + vllm_config=proposer.vllm_config, + device=device, + ) + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder + mock_attn_group.layer_names = list(proposer._draft_attn_layer_names) + mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] + + result = proposer.propose( + target_token_ids=torch.randint(0, vocab_size, (total_tokens,), device=device), + target_positions=torch.arange(total_tokens, device=device), + target_hidden_states=torch.randn(total_tokens, hidden_size, device=device), + next_token_ids=torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ), + token_indices_to_sample=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=_create_sampling_metadata( + all_greedy=False, batch_size=batch_size + ), + ) + + assert result.shape == (batch_size, 1) + assert proposer._last_draft_probs is not None + assert proposer._last_draft_probs.shape == (batch_size, 1, vocab_size) + expected_probs = torch.softmax(logits, dim=-1).view(batch_size, 1, vocab_size) + assert torch.allclose(proposer._last_draft_probs, expected_probs) + # ``take_last_draft_probs`` is the upstream-side accessor that + # ``GPUModelRunner`` uses to plumb probs into the rejection sampler. + assert torch.equal( + proposer.take_last_draft_probs(), proposer._last_draft_probs + ) + + +def test_mtp_sequential_drafting_passes_spec_step_indices(): + device = torch.device(DEVICE_TYPE) + batch_size = 2 + seq_lens = [3, 2] + total_tokens = sum(seq_lens) + vocab_size = 4 + num_spec_tokens = 2 + + proposer = _create_mtp_proposer(num_speculative_tokens=num_spec_tokens) + proposer.block_size = 16 + hidden_size = proposer.hidden_size + + model_mock = mock.MagicMock() + model_mock.side_effect = [ + torch.zeros(total_tokens, hidden_size, device=device), + torch.zeros(batch_size, hidden_size, device=device), + ] + + def logits_for_token(token_id: int): + logits = torch.full((batch_size, vocab_size), -100.0, device=device) + logits[:, token_id] = 100.0 + return logits + + model_mock.compute_logits.side_effect = [ + logits_for_token(1), + logits_for_token(2), + ] + proposer.model = model_mock + proposer._draft_attn_layer_names = {"layer.0"} + + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=list(proposer._draft_attn_layer_names), + vllm_config=proposer.vllm_config, + device=device, + ) + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder + mock_attn_group.layer_names = list(proposer._draft_attn_layer_names) + mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] + + result = proposer.propose( + target_token_ids=torch.randint(0, vocab_size, (total_tokens,), device=device), + target_positions=torch.arange(total_tokens, device=device), + target_hidden_states=torch.randn(total_tokens, hidden_size, device=device), + next_token_ids=torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ), + token_indices_to_sample=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=_create_sampling_metadata( + all_greedy=True, batch_size=batch_size + ), + ) + + assert torch.equal( + result, + torch.tensor([[1, 2], [1, 2]], device=device), + ) + assert [ + call.kwargs.get("spec_step_idx", 0) + for call in model_mock.compute_logits.call_args_list + ] == [0, 1] + assert [ + call.kwargs.get("spec_step_idx", 0) + for call in model_mock.call_args_list + ] == [0, 1] + + +def test_mtp_draft_sampling_applies_top_k_to_draft_probs(): + logits = torch.tensor([[0.0, 1.0, 2.0, 3.0]], device=DEVICE_TYPE) + top_k = torch.tensor([2], dtype=torch.int32, device=DEVICE_TYPE) + + _token_ids, draft_probs = compute_probs_and_sample_next_token( + logits, + _create_sampling_metadata(all_greedy=False, batch_size=1, top_k=top_k), + ) + + expected_logits = torch.tensor( + [[-float("inf"), -float("inf"), 2.0, 3.0]], device=DEVICE_TYPE + ) + expected_probs = torch.softmax(expected_logits, dim=-1, dtype=torch.float32) + assert torch.allclose(draft_probs, expected_probs) + + +def test_mtp_parallel_drafting_random_sampling_records_draft_probs(): + device = torch.device(DEVICE_TYPE) + batch_size = 2 + num_spec_tokens = 2 + seq_lens = [2, 2] + total_tokens = sum(seq_lens) + vocab_size = 4 + + proposer = _create_mtp_proposer( + num_speculative_tokens=num_spec_tokens, + parallel_drafting=True, + ) + # Mirror upstream's f51f6844f gating: see the matching comment in + # ``test_mtp_propose_random_sampling_records_draft_probs``. + proposer._enable_probabilistic_draft_probs = True + proposer.block_size = 16 + hidden_size = proposer.hidden_size + + model_mock = mock.MagicMock() + model_mock.return_value = torch.zeros( + total_tokens + batch_size, + hidden_size, + dtype=proposer.dtype, + device=device, + ) + logits = torch.tensor( + [ + [0.0, 1.0, 2.0, 3.0], + [3.0, 2.0, 1.0, 0.0], + [0.0, 0.5, 1.0, 1.5], + [1.5, 1.0, 0.5, 0.0], + ], + device=device, + ) + model_mock.compute_logits.return_value = logits.clone() + proposer.model = model_mock + proposer._draft_attn_layer_names = {"layer.0"} + + batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens) + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size=16, device=device + ) + attn_metadata_builder_cls, _ = try_get_attention_backend( + AttentionBackendEnum.FLASH_ATTN + ) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=list(proposer._draft_attn_layer_names), + vllm_config=proposer.vllm_config, + device=device, + ) + mock_attn_group = mock.MagicMock() + mock_attn_group.get_metadata_builder.return_value = attn_metadata_builder + mock_attn_group.layer_names = list(proposer._draft_attn_layer_names) + mock_attn_group.kv_cache_spec = attn_metadata_builder.kv_cache_spec + proposer.draft_attn_groups = [mock_attn_group] + + result = proposer.propose( + target_token_ids=torch.randint(0, vocab_size, (total_tokens,), device=device), + target_positions=torch.arange(total_tokens, device=device), + target_hidden_states=torch.randn( + total_tokens, + hidden_size, + dtype=proposer.dtype, + device=device, + ), + next_token_ids=torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ), + token_indices_to_sample=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=_create_sampling_metadata( + all_greedy=False, batch_size=batch_size + ), + ) + + assert result.shape == (batch_size, num_spec_tokens) + assert proposer._last_draft_probs is not None + assert proposer._last_draft_probs.shape == ( + batch_size, num_spec_tokens, vocab_size + ) + assert torch.allclose( + proposer._last_draft_probs, + torch.softmax(logits, dim=-1).view(batch_size, num_spec_tokens, vocab_size), + ) + assert torch.equal( + proposer.take_last_draft_probs(), proposer._last_draft_probs + ) + + +# Tests for ``_get_draft_probs_for_rejection`` and the +# positional ``runner_stub._draft_probs`` packing path were removed when +# our MTP scheduling commit dropped that branch in favor of upstream's +# req-id-indexed ``_get_spec_decode_draft_probs`` (added by +# vllm-project/vllm#40269 / f51f6844f). The remaining +# ``_get_draft_probs_for_rejection`` tests that follow have been deleted +# along with the function; equivalent coverage for the new code lives +# under ``tests/v1/worker/test_gpu_model_runner.py``. diff --git a/vllm/models/deepseek_v4/nvidia/mtp.py b/vllm/models/deepseek_v4/nvidia/mtp.py index 64715deae99b..3f9253fe826d 100644 --- a/vllm/models/deepseek_v4/nvidia/mtp.py +++ b/vllm/models/deepseek_v4/nvidia/mtp.py @@ -158,7 +158,12 @@ def forward( inputs_embeds ).unsqueeze(-2) hidden_states, residual, post_mix, res_mix = self.mtp_block( - positions=positions, x=hidden_states, input_ids=None + x=hidden_states, + positions=positions, + input_ids=input_ids, + post_mix=None, + res_mix=None, + residual=None, ) hidden_states = mhc_post_tilelang(hidden_states, residual, post_mix, res_mix) # Return the flat pre-hc_head residual so it can be re-fed as the diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py index b7c01d3ec1ce..b76c1f98160b 100644 --- a/vllm/v1/spec_decode/llm_base_proposer.py +++ b/vllm/v1/spec_decode/llm_base_proposer.py @@ -34,6 +34,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, UniformTypeKVCacheSpecs from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.topk_topp_sampler import ( + apply_top_k_top_p, empty_exponential_noise_like, sample_with_exponential_noise, ) @@ -235,7 +236,14 @@ def __init__( self._enable_probabilistic_draft_probs = ( self.speculative_config.rejection_sample_method == "standard" and self.speculative_config.draft_sample_method == "probabilistic" - ) + ) or self.method == "mtp" + # MTP drafts benefit from probabilistic sampling at ``temperature > 0``: + # the draft model is a near-copy of the target, so sampling from + # ``softmax(draft_logits)`` covers the target distribution much better + # than argmax and lifts acceptance by ~9 pp on DSv4-Flash mt-bench + # (measured 58.9% → 67.8%). The ``all_greedy`` fast path inside + # ``_sample_draft_tokens`` still falls back to argmax for ``temp==0`` + # requests so this override only changes behaviour where it helps. self._last_draft_probs: torch.Tensor | None = None self._slot_mapping_buffer = torch.zeros( @@ -409,11 +417,53 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) - def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor: + def _get_effective_spec_step_idx(self, spec_step_idx: int) -> int: + if self.method != "mtp": + return 0 + + config = getattr(self.model, "config", None) + num_mtp_layers = getattr(config, "num_nextn_predict_layers", None) + if not isinstance(num_mtp_layers, int): + inner_model = getattr(self.model, "model", None) + num_mtp_layers = getattr(inner_model, "num_mtp_layers", None) + + if isinstance(num_mtp_layers, int) and num_mtp_layers > 0: + return spec_step_idx % num_mtp_layers + return spec_step_idx + + def _compute_logits( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if self.method == "mtp": + return self.model.compute_logits( + hidden_states, + spec_step_idx=self._get_effective_spec_step_idx(spec_step_idx), + ) + return self.model.compute_logits(hidden_states) + + def _add_spec_step_idx( + self, + model_kwargs: dict[str, torch.Tensor | None], + spec_step_idx: int, + ) -> dict[str, torch.Tensor | int | None]: + if self.method != "mtp": + return model_kwargs + return { + **model_kwargs, + "spec_step_idx": self._get_effective_spec_step_idx(spec_step_idx), + } + + def _greedy_sample( + self, + hidden_states: torch.Tensor, + spec_step_idx: int = 0, + ) -> torch.Tensor: """Greedy-sample draft tokens from hidden states.""" if self.use_local_argmax_reduction: return self.model.get_top_tokens(hidden_states) - return self.model.compute_logits(hidden_states).argmax(dim=-1) + return self._compute_logits(hidden_states, spec_step_idx).argmax(dim=-1) def _sample_from_logits( self, @@ -432,10 +482,25 @@ def _sample_draft_tokens( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Sample draft token ids (and optionally their probabilities). + + ``spec_step_idx`` is forwarded to ``_compute_logits`` so that + MTP-style drafters can route the logits computation through the + correct draft layer (DSv4 cycles through ``num_nextn_predict_layers`` + layers). For non-MTP drafters this argument is ignored. + + Probabilities are only returned when + :attr:`_enable_probabilistic_draft_probs` is enabled (set by the + speculative config via ``rejection_sample_method="standard"`` and + ``draft_sample_method="probabilistic"``). In every other case we + fall back to argmax, matching the upstream + ``_sample_draft_tokens`` contract. + """ if not self._enable_probabilistic_draft_probs or sampling_metadata.all_greedy: - return self._greedy_sample(hidden_states), None - logits = self.model.compute_logits(hidden_states) + return self._greedy_sample(hidden_states, spec_step_idx), None + logits = self._compute_logits(hidden_states, spec_step_idx) return self._sample_from_logits(logits, sampling_metadata) def take_last_draft_probs(self) -> torch.Tensor | None: @@ -521,7 +586,7 @@ def propose( slot_mapping_size, common_attn_metadata.slot_mapping ), ): - ret_hidden_states = self.model(**model_kwargs) + ret_hidden_states = self.model(**self._add_spec_step_idx(model_kwargs, 0)) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states @@ -548,8 +613,11 @@ def propose( # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1 or self.parallel_drafting: + # ``spec_step_idx=0`` selects the first MTP layer for DSv4 (and + # is ignored for non-MTP drafters). Probabilities only flow + # through when ``_enable_probabilistic_draft_probs`` is set. draft_token_ids, draft_probs = self._sample_draft_tokens( - sample_hidden_states, sampling_metadata + sample_hidden_states, sampling_metadata, spec_step_idx=0 ) if draft_probs is not None: self._last_draft_probs = draft_probs.view( @@ -570,9 +638,8 @@ def propose( self.positions[:batch_size] = positions draft_token_ids, draft_probs = self._sample_draft_tokens( - sample_hidden_states, sampling_metadata + sample_hidden_states, sampling_metadata, spec_step_idx=0 ) - draft_probs_list = None if draft_probs is None else [draft_probs] if self.allowed_attn_types is not None: for group_md in per_group_attn_metadata: @@ -586,6 +653,7 @@ def propose( # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] + draft_probs_list = [] if draft_probs is None else [draft_probs] cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = ( self._determine_batch_execution_and_padding(batch_size) @@ -664,7 +732,9 @@ def propose( cudagraph_runtime_mode=cudagraph_runtime_mode, slot_mapping=self._get_slot_mapping(input_batch_size), ): - ret_hidden_states = self.model(**model_kwargs) + ret_hidden_states = self.model( + **self._add_spec_step_idx(model_kwargs, token_index + 1) + ) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = ret_hidden_states @@ -673,16 +743,17 @@ def propose( hidden_states = hidden_states[:batch_size] draft_token_ids, draft_probs = self._sample_draft_tokens( - last_hidden_states[:batch_size], sampling_metadata + last_hidden_states[:batch_size], + sampling_metadata, + spec_step_idx=token_index + 1, ) + draft_token_ids_list.append(draft_token_ids) if draft_probs is not None: - assert draft_probs_list is not None draft_probs_list.append(draft_probs) - draft_token_ids_list.append(draft_token_ids) # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - if draft_probs_list is not None: + if draft_probs_list: self._last_draft_probs = torch.stack(draft_probs_list, dim=1).contiguous() return draft_token_ids @@ -1702,12 +1773,10 @@ def _determine_batch_execution_and_padding( return cudagraph_mode, num_tokens_padded, num_tokens_across_dp -# NOTE(woosuk): Currently, the below code is not used and we always use argmax -# to sample the draft tokens. We will use this after we find a way to manage -# the draft prob tensor. -# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details. -# FIXME(woosuk): The logic here is duplicated with the main sampling code. -# We should refactor this to reuse the same sampling implementation. +# NOTE(woosuk): This duplicates part of the main sampling code because MTP +# needs both the sampled draft token ids and the draft probability tensor for +# rejection sampling. Refactor this once the sampler exposes a reusable helper +# that returns both values without extra packing. def compute_probs_and_sample_next_token( logits: torch.Tensor, sampling_metadata: SamplingMetadata, @@ -1724,22 +1793,31 @@ def compute_probs_and_sample_next_token( # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0) # consistent with sampler.py's _SAMPLING_EPS threshold - temperature = sampling_metadata.temperature + num_tokens = logits.shape[0] + temperature = _expand_draft_sampling_tensor( + sampling_metadata.temperature, + num_tokens, + ) # Avoid division by zero if there are greedy requests. if not sampling_metadata.all_random: is_greedy = temperature < _SAMPLING_EPS temperature = torch.where(is_greedy, 1.0, temperature) logits.div_(temperature.view(-1, 1)) + top_k = _expand_draft_sampling_tensor(sampling_metadata.top_k, num_tokens) + top_p = _expand_draft_sampling_tensor(sampling_metadata.top_p, num_tokens) + logits = apply_top_k_top_p(logits, top_k, top_p) probs = logits.softmax(dim=-1, dtype=torch.float32) - # NOTE(woosuk): Currently, we ignore most of the sampling parameters in - # generating the draft tokens. We only use the temperature. While this - # could degrade the acceptance rate, it does not affect the distribution - # of the generated tokens after rejection sampling. - - # TODO(woosuk): Consider seeds. + generators = _expand_draft_sampling_generators( + sampling_metadata.generators, + sampling_metadata.temperature.shape[0], + num_tokens, + ) q = empty_exponential_noise_like(probs, use_fp64_gumbel) - q.exponential_() + if len(generators) != num_tokens: + q.exponential_() + for i, generator in generators.items(): + q[i].exponential_(generator=generator) # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs # will be used later for rejection sampling. next_token_ids = sample_with_exponential_noise(probs.clone(), q) @@ -1747,3 +1825,41 @@ def compute_probs_and_sample_next_token( greedy_token_ids = probs.argmax(dim=-1) next_token_ids = torch.where(is_greedy, greedy_token_ids, next_token_ids) return next_token_ids, probs + + +def _expand_draft_sampling_tensor( + tensor: torch.Tensor | None, + num_tokens: int, +) -> torch.Tensor | None: + if tensor is None or tensor.shape[0] == num_tokens: + return tensor + + batch_size = tensor.shape[0] + if num_tokens % batch_size != 0: + raise ValueError( + "Draft sampling metadata must either match the draft logits row " + "count or evenly divide it." + ) + return tensor.repeat_interleave(num_tokens // batch_size, dim=0) + + +def _expand_draft_sampling_generators( + generators: dict[int, torch.Generator], + batch_size: int, + num_tokens: int, +) -> dict[int, torch.Generator]: + if not generators or batch_size == num_tokens: + return generators + + if num_tokens % batch_size != 0: + raise ValueError( + "Draft sampling generators must either match the draft logits row " + "count or evenly divide it." + ) + + repeat = num_tokens // batch_size + return { + req_idx * repeat + offset: generator + for req_idx, generator in generators.items() + for offset in range(repeat) + } From 28714ae7a3039dd413567490a29cbfc789643208 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 8 May 2026 17:36:00 +0800 Subject: [PATCH 012/131] Warm DeepSeek V4 MTP spec-decode kernels Signed-off-by: jasl --- vllm/model_executor/warmup/kernel_warmup.py | 198 +++++++++++++++++++- 1 file changed, 196 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 3698582cfa0b..d89426707a04 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -41,6 +41,7 @@ ) _DEEPSEEK_V4_SPARSE_MLA_MIXED_WARMUP_TOKENS = 16 _DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS = 1024 +_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS = (1, 2) _DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS = tuple(range(1, 17)) + ( 32, 64, @@ -73,6 +74,38 @@ def _clamp_warmup_tokens(num_tokens: int, max_tokens: int) -> int: return max(0, min(num_tokens, max_tokens)) +def _is_deepseek_v4_mtp_spec_decode(runner: "GPUModelRunner") -> bool: + spec_config = getattr(runner, "speculative_config", None) + return ( + getattr(spec_config, "method", None) == "mtp" + and getattr(runner, "num_spec_tokens", 0) > 0 + ) + + +def _deepseek_v4_mtp_uniform_decode_warmup_requests( + runner: "GPUModelRunner", + max_tokens: int, + max_reqs: int, +) -> tuple[int, ...]: + if not _is_deepseek_v4_mtp_spec_decode(runner): + return () + + query_len = getattr( + runner, + "uniform_decode_query_len", + 1 + getattr(runner, "num_spec_tokens", 0), + ) + if query_len <= 0: + return () + + max_warmup_reqs = min(max_reqs, max_tokens // query_len) + return tuple( + reqs + for reqs in _DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS + if reqs <= max_warmup_reqs + ) + + def _deepseek_v4_slot_mapping_warmup(runner: "GPUModelRunner") -> None: max_tokens = getattr(runner, "max_num_tokens", 1) block_table = runner.input_batch.block_table @@ -183,6 +216,131 @@ def _deepseek_v4_request_prep_warmup(worker: "Worker") -> None: torch.accelerator.synchronize() +def _run_deepseek_v4_mtp_spec_decode_warmup_kernels( + *, + device: torch.device, + num_reqs: int, + num_spec_tokens: int, + vocab_size: int, + block_size: int, + max_model_len: int, +) -> None: + from vllm.v1.sample.logits_processor import LogitsProcessors + from vllm.v1.sample.metadata import SamplingMetadata + from vllm.v1.sample.rejection_sampler import rejection_sample + from vllm.v1.spec_decode.utils import ( + eagle_prepare_inputs_padded_kernel, + eagle_prepare_next_token_padded_kernel, + eagle_step_update_slot_mapping_and_metadata, + next_power_of_2, + ) + + num_sampled_tokens = num_spec_tokens + 1 + sampled_token_ids = torch.arange( + num_reqs * num_sampled_tokens, dtype=torch.int32, device=device + ).reshape(num_reqs, num_sampled_tokens) + sampled_token_ids.remainder_(vocab_size) + discard_request_mask = torch.zeros(num_reqs, dtype=torch.bool, device=device) + backup_next_token_ids = torch.zeros(num_reqs, dtype=torch.int32, device=device) + next_token_ids = torch.empty(num_reqs, dtype=torch.int32, device=device) + valid_sampled_tokens_count = torch.empty(num_reqs, dtype=torch.int32, device=device) + eagle_prepare_next_token_padded_kernel[(num_reqs,)]( + sampled_token_ids, + discard_request_mask, + backup_next_token_ids, + next_token_ids, + valid_sampled_tokens_count, + vocab_size, + num_sampled_tokens, + num_reqs, + sampled_token_ids.stride(0), + BLOCK_SIZE_TOKENS=next_power_of_2(num_sampled_tokens), + ) + + cu_num_draft_tokens = torch.arange( + num_spec_tokens, + num_reqs * num_spec_tokens + 1, + num_spec_tokens, + dtype=torch.int32, + device=device, + ) + query_start_loc = torch.arange( + 0, + (num_reqs + 1) * num_sampled_tokens, + num_sampled_tokens, + dtype=torch.int32, + device=device, + ) + token_indices_to_sample = torch.empty(num_reqs, dtype=torch.int32, device=device) + num_rejected_tokens = torch.empty(num_reqs, dtype=torch.int32, device=device) + eagle_prepare_inputs_padded_kernel[(num_reqs,)]( + cu_num_draft_tokens, + valid_sampled_tokens_count, + query_start_loc, + token_indices_to_sample, + num_rejected_tokens, + num_reqs, + ) + + positions = torch.arange(num_reqs, dtype=torch.int64, device=device) + block_table_tensor = torch.zeros((num_reqs, 1), dtype=torch.int32, device=device) + seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device) + out_clamped_positions = torch.empty_like(positions) + out_slot_mapping = torch.empty(num_reqs, dtype=torch.int64, device=device) + eagle_step_update_slot_mapping_and_metadata( + positions, + block_table_tensor, + seq_lens, + block_size, + max_model_len, + out_clamped_positions, + out_slot_mapping, + input_batch_size=num_reqs, + ) + + total_draft_tokens = num_reqs * num_spec_tokens + draft_token_ids = torch.arange(total_draft_tokens, dtype=torch.int32, device=device) + draft_token_ids.remainder_(vocab_size) + draft_probs = torch.rand( + total_draft_tokens, vocab_size, dtype=torch.float32, device=device + ) + draft_probs = draft_probs / draft_probs.sum(dim=-1, keepdim=True) + target_logits = torch.randn( + total_draft_tokens, vocab_size, dtype=torch.float32, device=device + ) + bonus_token_ids = torch.zeros((num_reqs, 1), dtype=torch.int32, device=device) + sampling_metadata = SamplingMetadata( + temperature=torch.full((num_reqs,), 0.7, dtype=torch.float32, device=device), + all_greedy=False, + all_random=True, + top_p=None, + top_k=None, + generators={}, + max_num_logprobs=None, + no_penalties=True, + prompt_token_ids=None, + frequency_penalties=torch.empty(0, device=device), + presence_penalties=torch.empty(0, device=device), + repetition_penalties=torch.empty(0, device=device), + output_token_ids=[[] for _ in range(num_reqs)], + allowed_token_ids_mask=None, + bad_words_token_ids={}, + logitsprocs=LogitsProcessors(), + logprob_token_ids=None, + spec_token_ids=[[] for _ in range(num_reqs)], + ) + rejection_sample( + draft_token_ids=draft_token_ids, + num_draft_tokens=[num_spec_tokens] * num_reqs, + max_spec_len=num_spec_tokens, + cu_num_draft_tokens=cu_num_draft_tokens, + draft_probs=draft_probs, + target_logits=target_logits, + bonus_token_ids=bonus_token_ids, + sampling_metadata=sampling_metadata, + ) + + def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None: if not envs.VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: return @@ -198,14 +356,21 @@ def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None: prefill_tokens = _clamp_warmup_tokens( _DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS, max_tokens ) - if mixed_tokens <= 0 and prefill_tokens <= 0: + uniform_decode_reqs = _deepseek_v4_mtp_uniform_decode_warmup_requests( + runner, + max_tokens=max_tokens, + max_reqs=worker.scheduler_config.max_num_seqs, + ) + if mixed_tokens <= 0 and prefill_tokens <= 0 and not uniform_decode_reqs: return logger.info( "Warming up DeepSeek V4 sparse MLA attention " - "for mixed tokens=%s and prefill tokens=%s.", + "for mixed tokens=%s, prefill tokens=%s, and MTP uniform decode " + "requests=%s.", mixed_tokens, prefill_tokens, + list(uniform_decode_reqs), ) if mixed_tokens > 0: runner._dummy_run( @@ -223,6 +388,35 @@ def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None: force_attention=True, create_single_prefill=True, ) + query_len = getattr(runner, "uniform_decode_query_len", 0) + for num_reqs in uniform_decode_reqs: + runner._dummy_run( + num_tokens=num_reqs * query_len, + skip_eplb=True, + is_profile=True, + force_attention=True, + uniform_decode=True, + ) + + if uniform_decode_reqs and current_platform.is_cuda_alike(): + vocab_size = runner.model_config.get_vocab_size() + block_size = getattr(runner.cache_config, "block_size", None) or 16 + logger.info( + "Warming up DeepSeek V4 MTP spec-decode kernels for request " + "counts=%s and %d draft tokens.", + list(uniform_decode_reqs), + runner.num_spec_tokens, + ) + for num_reqs in uniform_decode_reqs: + _run_deepseek_v4_mtp_spec_decode_warmup_kernels( + device=runner.device, + num_reqs=num_reqs, + num_spec_tokens=runner.num_spec_tokens, + vocab_size=vocab_size, + block_size=block_size, + max_model_len=runner.max_model_len, + ) + torch.accelerator.synchronize() def _flashinfer_autotune_cache_hash(runner: "GPUModelRunner") -> str: From 8102228f00c4c1685511279f9df2fb0ea20f02fe Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 12 May 2026 03:50:30 +0800 Subject: [PATCH 013/131] Tune dense FP8 block-scaled GEMM configs for SM12x DSv4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the placeholder configs added in commit 7b0f8b975 ("Add Blackwell tuning config aliases") with real autotuning results from benchmark_w8a8_ block_fp8.py on the actual hardware. Coverage: - M-keys extended from [1, 4, 8, 16, 32] to [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] — adds short prefill (M=64..128) and long prefill (M=256, 512) anchors that decode dispatch was previously rounding down to "M=32" placeholder. - 6 (N, K) shapes × 4 device variants (RTX PRO 6000 Workstation/Server/ Max-Q Edition + GB10) = 24 JSON files. - Hardware-specific: Workstation Edition tuned on physical RTX PRO 6000 Blackwell Workstation Edition; Server Edition and Max-Q Workstation Edition share the SM120 architecture and identical 24G/96G memory configs, only TGP differs, so they reuse the Workstation Edition tunings. GB10 (SM121) tuned separately on physical hardware. Search space: - Base: vllm's get_configs_compute_bound() — 1280 (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M, num_warps, num_stages) combinations. - Per-M filter: BLOCK_SIZE_M >= max(16, M/8) (cap 64) for M>=64 — drops configs guaranteed to be catastrophic at large M (cdiv(M, BLOCK_M) > 8 iterations sentence the kernel to many M-loops on cold cache). - num_iters: 10 for M<=32, 7 for M=64..128, 5 for M>=256. Why the placeholders mattered: - Placeholder had BLOCK_M=16 for every M (since all 5 keys were copies of the same config). At M=256 the kernel did cdiv(256, 16) = 16 iterations along M; at M=512, 32 iterations. - Observed behavior: long-prefill at M=256 took 7+ minutes per request, M=512 didn't return within 40 minutes. Tuned configs pick BLOCK_M=64.. 128 for these M values (2-4 M-iterations), unblocking long prefill. Tuning wall clock: - Workstation Edition: 57.7 min on RTX PRO 6000 Blackwell Workstation Edition (single GPU). - GB10: 66.2 min on NVIDIA GB10 (single GPU). - Shape 1 + Shape 3 (cold compiles for K=4096 and K=1024) dominated; the other 4 shapes each took <2 min via Triton JIT cache reuse (M/N/K are runtime args, so (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages) cache hits across (N, K) once the K-divisibility class is compiled). Same hardware verifies: tests/quantization/test_sm12x_tuned_config_lookup.py still passes (asserts shape coverage, not contents). Co-Authored-By: Claude Opus 4.7 Signed-off-by: jasl --- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 64 +++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 70 +++++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 70 +++++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 70 +++++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 56 ++++++++++++--- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 54 ++++++++++++-- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 54 ++++++++++++-- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 54 ++++++++++++-- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 64 +++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 68 ++++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 68 ++++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 68 ++++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 50 +++++++++++-- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 66 +++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 66 +++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 66 +++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 58 ++++++++++++--- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 66 +++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 66 +++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 66 +++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 68 ++++++++++++++---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 60 +++++++++++++--- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 60 +++++++++++++--- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 60 +++++++++++++--- 24 files changed, 1236 insertions(+), 276 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json index 50dc6d9575e7..306fdae8639e 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,41 +1,81 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "num_warps": 4, + "num_stages": 4 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3 } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..387b572731b6 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "num_warps": 4, + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..387b572731b6 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "num_warps": 4, + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..387b572731b6 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "num_warps": 4, + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json index 34d0f8699583..1cb973e5c383 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,5 +1,13 @@ { "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, @@ -11,32 +19,64 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 }, "32": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index 8ad3a0197412..ac91f525e96b 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -3,11 +3,11 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, @@ -15,28 +15,68 @@ "num_warps": 8, "num_stages": 3 }, - "8": { + "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index 8ad3a0197412..ac91f525e96b 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -3,11 +3,11 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, @@ -15,28 +15,68 @@ "num_warps": 8, "num_stages": 3 }, - "8": { + "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index 8ad3a0197412..ac91f525e96b 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -3,11 +3,11 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, @@ -15,28 +15,68 @@ "num_warps": 8, "num_stages": 3 }, - "8": { + "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json index cd7e6e91e663..2d655a2debac 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2 + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2 + "num_stages": 5 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..ac1053b588c5 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, "num_stages": 3 }, - "32": { + "64": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..ac1053b588c5 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, "num_stages": 3 }, - "32": { + "64": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..ac1053b588c5 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, "num_stages": 3 }, - "32": { + "64": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json index 8ad3a0197412..2026a21038b9 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -3,23 +3,31 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3 }, @@ -36,7 +44,39 @@ "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..be96d80c51f1 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "GROUP_SIZE_M": 32, + "num_warps": 4, "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..be96d80c51f1 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "GROUP_SIZE_M": 32, + "num_warps": 4, "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..be96d80c51f1 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "GROUP_SIZE_M": 32, + "num_warps": 4, "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json index 8ad3a0197412..1e84c847ff37 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -3,9 +3,17 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 }, "4": { "BLOCK_SIZE_M": 16, @@ -13,30 +21,62 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, "num_stages": 3 }, "32": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..5163bc4f3da1 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "GROUP_SIZE_M": 64, + "num_warps": 4, "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..5163bc4f3da1 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "GROUP_SIZE_M": 64, + "num_warps": 4, "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..5163bc4f3da1 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, - "num_stages": 3 + "num_stages": 4 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3 }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, + "GROUP_SIZE_M": 64, + "num_warps": 4, "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json index 217028c5412e..36a2e6621f2e 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -2,41 +2,81 @@ "1": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 2 + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 }, - "8": { + "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 }, - "16": { + "8": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 }, "32": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..5a6f1de61395 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "16": { "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 }, - "32": { - "BLOCK_SIZE_M": 16, + "128": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..5a6f1de61395 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "16": { "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 }, - "32": { - "BLOCK_SIZE_M": 16, + "128": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 } } diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json index da4016483cc6..5a6f1de61395 100644 --- a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,42 +1,82 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3 }, - "4": { + "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 }, "16": { "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3 }, - "32": { - "BLOCK_SIZE_M": 16, + "128": { + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, - "num_warps": 8, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, "num_stages": 3 } } From 919365f0913689afecd79f4e7da5bde29b5ebbca Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 12 May 2026 07:04:42 +0800 Subject: [PATCH 014/131] T1-D: adaptive BLOCK_M for _fp8_paged_mqa_logits_kernel (SM12x) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The single-stream decode profile showed `_fp8_paged_mqa_logits_kernel` at 12.61% of decode kernel time, 84.87 µs/call — #3 single hotspot on SM120 TP=2 after T1-A. Investigation: the launch used hardcoded `BLOCK_M=4` regardless of `num_rows = batch_size * next_n`. For the common no-MTP single-stream decode case, num_rows=1, which means 75% of the M-axis work (3 of 4 rows) is masked off and discarded — pure waste of compute and memory bandwidth. Fix: pick the smallest power-of-2 tile that still covers num_rows. - num_rows == 1 (no-MTP decode, batch=1): BLOCK_M=1 - num_rows == 2: BLOCK_M=2 - num_rows in [3, 4] (MTP=2 batch=1, or batch=4 prefill chunks): BLOCK_M=4 - num_rows > 4: BLOCK_M=8 Cost: each unique block_m value compiles a separate Triton specialization, so cudagraph capture exercises four variants instead of one. Triton JIT cache amortises this — first warmup adds a few seconds, subsequent loads cache-hit. Expected impact: - Single-stream decode (num_rows=1): 84.87 µs/call → ~25-30 µs/call (eliminate 3 of 4 wasted rows). At 42 calls/tok that's ~2.3 ms/tok TPOT improvement, ~8% throughput uplift on no-MTP single-stream. - MTP=2 (num_rows=3 typical): BLOCK_M=4 unchanged (1 row masked, same as before). No regression. - Prefill (num_rows >= 4): BLOCK_M=4 or 8 picked — covers full work. Risk: low. Kernel logic unchanged; only the launch tile size adapts. Co-Authored-By: Claude Opus 4.7 Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index 85dab1f6d9a1..b7afa0de66a9 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -317,7 +317,20 @@ def fp8_paged_mqa_logits_triton( context_lens_2d = context_lens.reshape(batch_size, -1) if context_lens_2d.shape[1] == 1 and next_n != 1: context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() - grid = (triton.cdiv(num_rows, 4), triton.cdiv(token_count, 64)) + # Adaptive BLOCK_M: the kernel masks off positions >= num_rows, so a fixed + # BLOCK_M=4 wastes ~75% of M-axis work in the common single-stream decode + # case (num_rows=1). Pick the smallest power-of-2 tile that still covers + # num_rows so we keep one grid-program for typical decode while still + # benefiting from larger tiles when batch / MTP push num_rows higher. + if num_rows <= 1: + block_m = 1 + elif num_rows <= 2: + block_m = 2 + elif num_rows <= 4: + block_m = 4 + else: + block_m = 8 + grid = (triton.cdiv(num_rows, block_m), triton.cdiv(token_count, 64)) _fp8_paged_mqa_logits_kernel[grid]( q, kv_values, @@ -350,7 +363,7 @@ def fp8_paged_mqa_logits_triton( block_tables.stride(1), logits.stride(0), logits.stride(1), - BLOCK_M=4, + BLOCK_M=block_m, BLOCK_N=64, BLOCK_D=64, num_warps=4, From 326b4807b0920d0aedee73b51379e661f066c52b Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 12 May 2026 07:22:47 +0800 Subject: [PATCH 015/131] T2-A: clamp BLOCK_D in sparse MLA finish kernel to head_dim MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The candidate_block path in finish_materialized_sparse_mla_scores_with_sink took the caller-supplied value_block_size as BLOCK_D directly. DSv4 calls it with value_block_size=512 (matmul_sparse_mla_attention_with_sink default for use_dot_finish=True) but the actual head_dim from combined_kv is qk_nope+qk_rope = 128+64 = 192. With BLOCK_D=512, the kernel masks off positions 192..511 per program — 62.5% of D-axis work discarded. Fix: clamp block_d to the smallest power-of-2 >= head_dim from the allowed set {64, 128, 256, 512}. For DSv4 head_dim=192 this picks BLOCK_D=256 (25% mask waste instead of 62.5%). Caller-supplied value_block_size smaller than the target (intentional fine-grained D-axis splits) is still respected. Expected impact on SM12x decode profile: _finish_materialized_scores_with _sink_candidate_block_kernel time per call drops from 17.92 µs to roughly half (less work per program, same grid size). At 82 calls/token that's ~0.7 ms/tok TPOT savings → ~2-3% throughput uplift on top of T1-D. Risk: low. Kernel logic unchanged; only the per-launch BLOCK_D adapts to the actual head_dim, falling back to 512 for head_dim > 512. Co-Authored-By: Claude Opus 4.7 Signed-off-by: jasl --- .../backends/mla/sparse_mla_kernels.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 867eaf298149..1f4460fa5d70 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -971,8 +971,24 @@ def finish_materialized_sparse_mla_scores_with_sink( num_tokens, _, num_candidates = scores.shape head_dim = kv.shape[2] + # Clamp BLOCK_D to the smallest power-of-2 >= head_dim (among the allowed + # {64, 128, 256, 512}). Without this, a caller-supplied value_block_size + # larger than head_dim wastes work on masked-off positions — e.g. DSv4 + # head_dim=192 with value_block_size=512 masks off 62.5% of D-axis work + # in every program. Smaller caller values (intentional fine-grained + # splits along D) are respected. + def _smallest_block_d_covering(hd: int) -> int: + for cand in (64, 128, 256, 512): + if cand >= hd: + return cand + return 512 # head_dim > 512: BLOCK_D=512, grid splits along D + if candidate_block_size is not None: - block_d = value_block_size if value_block_size is not None else 128 + target_block_d = _smallest_block_d_covering(head_dim) + if value_block_size is None: + block_d = target_block_d + else: + block_d = min(value_block_size, target_block_d) candidate_grid = (num_tokens, active_heads, triton.cdiv(head_dim, block_d)) _finish_materialized_scores_with_sink_candidate_block_kernel[candidate_grid]( scores, From 5530c3758b03c8fd1e3104d8b626f315c8dc4d79 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 12 May 2026 11:36:07 +0800 Subject: [PATCH 016/131] Extend DeepSeek V4 prefill warmup to max single-chunk size MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The prefill warmup constant was 1024 tokens. With `max_num_batched_tokens = 8192` (the canonical SM12x serve setting), the first real request that prefills more than 1024 tokens in a single chunk has to JIT-compile the dense FP8 W8A8 block-scaled GEMM at the larger M, plus the sparse-MLA prefill kernel against a longer KV slab. T1-A's autotuned config space makes the cold-compile cost bigger, not smaller, so any user who issues an 8K-context first request after a fresh serve currently waits on Triton compilation that the warmup hook is supposed to absorb. Lift the constant to 8192. The call site already clamps via `_clamp_warmup_tokens(requested, scheduler_config.max_num_batched_tokens)` so schedulers running with a smaller batched-token cap naturally warm at their cap, and configurations that lift the cap above 8192 keep this floor (the cost of warming beyond 8192 grows fast enough that we want a deliberate decision rather than implicit scaling). Measured on 2x RTX PRO 6000 Workstation Edition (SM120, TP=2 EP, max_num_batched_tokens=8192) with a cold random ISL=8192 OSL=512 num-prompts=4 c=1 bench against a freshly-restarted serve: Before: TTFT mean ≈ 17 s (cold first request dominates the average) After: TTFT mean 3,172 ms, TTFT p99 3,176 ms (mean ≈ p99 — the cold-start variance disappears) Throughput: 61.16 tok/s vs 54.20 tok/s on the 020e0c89a baseline for the same shape (+13 %) Startup time: 71 s -> 80 s (+9 s one-time) Signed-off-by: jasl --- vllm/model_executor/warmup/kernel_warmup.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index d89426707a04..18f05144de74 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -40,7 +40,12 @@ } ) _DEEPSEEK_V4_SPARSE_MLA_MIXED_WARMUP_TOKENS = 16 -_DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS = 1024 +# Cap warmup at the largest single-chunk prefill the scheduler will ever +# issue (max_num_batched_tokens). 8192 covers the canonical SM12x serve +# (max_num_batched_tokens=8192); larger scheduler caps clamp to this +# value via _clamp_warmup_tokens at the call site, smaller caps clamp +# down naturally. +_DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS = 8192 _DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS = (1, 2) _DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS = tuple(range(1, 17)) + ( 32, From c7653d8da2556be459dd0ca07595a93181de0db5 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 12 May 2026 12:56:21 +0800 Subject: [PATCH 017/131] Extend DeepSeek V4 warmup coverage to multi-request shapes Three follow-on fixes on top of `5c8975591`: 1. Drop the hardcoded `_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS = (1, 2)` ceiling and append `scheduler_config.max_num_seqs` so MTP uniform-decode warmup also covers the largest in-flight batch the server will ever issue. On a Spark MTP=2 cluster with max_num_seqs=4 this lifts the random ISL=8,192 OSL=512 c=4 cold throughput from 23.67 tok/s to 42.82 tok/s (+81 %) by warming the `_fp8_paged_mqa_logits_kernel` adaptive `BLOCK_M=8` path that the (1, 2) tier missed. 2. Add a chunked-prefill warmup `_dummy_run` that sets `profile_seq_lens = prefill_tokens * 2` so the indexer / sparse-MLA builders see "this is the second chunk of a longer sequence", not only the freshly-arriving single chunk. 3. Add a multi-request prefill warmup `_dummy_run` (no `create_*` flags) so the runner splits the batched-token budget across `max_num_seqs` requests and exercises the multi-prefill indexer path that single-request prefill warmup skips. Cost: ~+35 s startup on Spark (init engine: 61.33 s -> 96.81 s) for a one-time JIT pass over the larger shape coverage. Limitation: vLLM's `jit_monitor` shows nine kernels still JIT during the first c=1 cold bench, including `eagle_prepare_next_token_padded_kernel` and `_w8a8_triton_block_scaled_mm` at alt shapes. These kernels are already invoked from `_run_deepseek_v4_mtp_spec_decode_warmup_kernels`, but the synthesized warmup tensors hit a different Triton specialization key (notably pointer 16-byte alignment) than the sampler / spec-decode buffers used in real inference. Closing this gap requires routing warmup through the actual scheduler / sampler pipeline rather than a dummy_run helper, which is a larger upstream change. The harness (`scripts/prewarm_serve.sh`, also auto-invoked by `scripts/dgx_spark_start_mp_serve.sh`) issues real-pipeline prewarm requests after `/health=200` to absorb the remaining cold-start cost on the deployment side. Signed-off-by: jasl --- vllm/model_executor/warmup/kernel_warmup.py | 34 ++++++++++++++++++--- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 18f05144de74..1fbe45f1c50e 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -46,6 +46,11 @@ # value via _clamp_warmup_tokens at the call site, smaller caps clamp # down naturally. _DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS = 8192 +# Steady-state MTP decode shapes to warm. We always include 1 and 2; the +# scheduler's `max_num_seqs` is appended dynamically at the call site so +# kernels selected per-shape (e.g. `_fp8_paged_mqa_logits_kernel`'s +# adaptive BLOCK_M) are covered for the largest in-flight batch the +# server will ever issue. _DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS = (1, 2) _DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS = tuple(range(1, 17)) + ( 32, @@ -104,11 +109,8 @@ def _deepseek_v4_mtp_uniform_decode_warmup_requests( return () max_warmup_reqs = min(max_reqs, max_tokens // query_len) - return tuple( - reqs - for reqs in _DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS - if reqs <= max_warmup_reqs - ) + candidates = sorted(set(_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS) | {max_reqs}) + return tuple(reqs for reqs in candidates if reqs <= max_warmup_reqs) def _deepseek_v4_slot_mapping_warmup(runner: "GPUModelRunner") -> None: @@ -393,6 +395,28 @@ def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None: force_attention=True, create_single_prefill=True, ) + # Simulate the second-and-later chunk of a chunked prefill so + # `_build_prefill_chunk_metadata_kernel` and the alt-shape + # `_w8a8_triton_block_scaled_mm` configs that fire when the + # indexer sees prior context get JIT-compiled here, not on the + # first user request that exceeds `max_num_batched_tokens`. + runner._dummy_run( + num_tokens=prefill_tokens, + skip_eplb=True, + is_profile=True, + force_attention=True, + create_single_prefill=True, + profile_seq_lens=prefill_tokens * 2, + ) + # Multi-request prefill (max_num_seqs prefills sharing the + # batched-token budget) covers the other dense-GEMM shape bucket + # plus the multi-prefill indexer path. + runner._dummy_run( + num_tokens=prefill_tokens, + skip_eplb=True, + is_profile=True, + force_attention=True, + ) query_len = getattr(runner, "uniform_decode_query_len", 0) for num_reqs in uniform_decode_reqs: runner._dummy_run( From a6c292ab586a4d51b11f01850dd8359719a28a0a Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 13 May 2026 11:39:59 +0800 Subject: [PATCH 018/131] Restore rowwise paged-MQA logits kernel for SM12x long context MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The earlier `ds4-sm120-full` PoC branch shipped two FP8 paged-MQA logits kernels — a generic 2D-tiled one and a per-row variant (`_fp8_paged_mqa_logits_rowwise_kernel`) tuned for long decode contexts. During the file split that produced `vllm/v1/attention/ops/deepseek_v4_ops/sm12x_mqa.py`, only the 2D-tiled kernel was carried over; the rowwise variant and its dispatcher gate were dropped. Users running ctx > 100K with MTP=2 on RTX PRO 6000 (Max-Q) report ~20% throughput regression vs the PoC branch on the "Red-Black Tree, max_tokens=2048, thinking on" 5-run probe (~85 tok/s here vs ~108 tok/s on `da4f1c711`). Single-stream short contexts are unaffected because the 2D-tile work scales with `token_count` cdiv 64 and the rowwise win comes from Q-reuse across the full 4K-128K window per program — exactly the regime the bug report hits. This commit restores the rowwise kernel verbatim from `da4f1c711` (its routing predicate is aligned with `4c9ee613d`, dropping the `next_n == 1` constraint so MTP=2 also hits the rowwise path), and re-introduces the dispatch in `fp8_paged_mqa_logits_triton`: if head_dim % 64 == 0 and num_heads % 4 == 0: return fp8_paged_mqa_logits_rowwise_triton(...) DSv4-Flash (head_dim=128, num_heads=64) always satisfies both predicates so all real serves take the rowwise path; the 2D-tiled kernel remains as the fallback for misaligned shapes and is still the canonical reference the rowwise kernel was validated against in the original PoC tests. The recently-added T1-D adaptive `BLOCK_M` (commit `959a04df5`) is preserved in the 2D-tiled path. On DSv4-Flash it becomes dead code in practice, but kept for portability and to keep the diff isolated from the long-context regression fix. Signed-off-by: jasl --- .../deepseek_v4/nvidia/ops/sm12x_mqa.py | 227 ++++++++++++++++++ 1 file changed, 227 insertions(+) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index b7afa0de66a9..8593989e8647 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -287,6 +287,215 @@ def _fp8_paged_mqa_logits_kernel( ) +@triton.jit +def _fp8_paged_mqa_logits_rowwise_kernel( + q_ptr, + kv_ptr, + scale_ptr, + weights_ptr, + context_lens_ptr, + block_tables_ptr, + logits_ptr, + token_start, + num_rows: tl.constexpr, + logits_width: tl.constexpr, + next_n: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + block_size: tl.constexpr, + stride_qb: tl.constexpr, + stride_qn: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvb: tl.constexpr, + stride_kvs: tl.constexpr, + stride_kvd: tl.constexpr, + stride_sb: tl.constexpr, + stride_ss: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_clb: tl.constexpr, + stride_cln: tl.constexpr, + stride_btb: tl.constexpr, + stride_btk: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_H: tl.constexpr, +): + """Per-row paged-MQA logits kernel optimised for long ``token_count``. + + Each Triton program handles one logical row (``batch * next_n + q_pos``) + across a ``BLOCK_N``-wide window of token positions. Q is loaded once per + head tile and reused for every K element in the window, which preserves + L2 / register locality and avoids the M-axis padding waste of the + generic 2D-tiled kernel at long contexts (mt-bench c=1 MTP=2 num_rows=3 + with token_count=131072 launches 12k programs of 128 logits each rather + than 8k programs of 64 logits with 25 % M-axis waste). + """ + row = tl.program_id(0) + pid_n = tl.program_id(1) + offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = token_start + offs_local_n + offs_d = tl.arange(0, BLOCK_D) + + valid_row = row < num_rows + valid_n = offs_local_n < logits_width + batch = row // next_n + q_pos = row - batch * next_n + context_len = tl.load( + context_lens_ptr + batch * stride_clb + q_pos * stride_cln, + mask=valid_row, + other=0, + ) + if token_start + pid_n * BLOCK_N >= context_len: + logits = tl.full((BLOCK_N,), float("-inf"), dtype=tl.float32) + tl.store( + logits_ptr + row * stride_lm + offs_local_n * stride_ln, + logits, + mask=valid_row & valid_n, + ) + return + context_mask = valid_n & (offs_n < context_len) + + block_rank = offs_n // block_size + block_offset = offs_n - block_rank * block_size + block_idx = tl.load( + block_tables_ptr + batch * stride_btb + block_rank * stride_btk, + mask=valid_row & context_mask, + other=0, + ) + + scale = tl.load( + scale_ptr + block_idx * stride_sb + block_offset * stride_ss, + mask=context_mask, + other=0.0, + ) + logits = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for h0 in tl.range(0, num_heads, BLOCK_H): + heads = h0 + tl.arange(0, BLOCK_H) + valid_h = heads < num_heads + scores = tl.zeros((BLOCK_H, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + batch * stride_qb + + q_pos * stride_qn + + heads[:, None] * stride_qh + + d[None, :] * stride_qd, + mask=valid_row & valid_h[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + kv_ptr + + block_idx[None, :] * stride_kvb + + block_offset[None, :] * stride_kvs + + d[:, None] * stride_kvd, + mask=context_mask[None, :] & (d[:, None] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.dot(q, k, input_precision="tf32") + + weighted = tl.maximum(scores * scale[None, :], 0.0) + weight = tl.load( + weights_ptr + row * stride_wm + heads * stride_wh, + mask=valid_row & valid_h, + other=0.0, + ) + logits += tl.sum(weighted * weight[:, None], axis=0) + + logits = tl.where(context_mask & valid_row, logits, float("-inf")) + tl.store( + logits_ptr + row * stride_lm + offs_local_n * stride_ln, + logits, + mask=valid_row & valid_n, + ) + + +def fp8_paged_mqa_logits_rowwise_triton( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + token_start: int = 0, + token_count: int | None = None, +) -> torch.Tensor: + """Rowwise paged-MQA logits wrapper. + + Pre-condition: ``head_dim % 64 == 0`` and ``num_heads % 4 == 0`` so the + ``tl.dot`` inside ``_fp8_paged_mqa_logits_rowwise_kernel`` lands on + tensor-core friendly tile shapes. DSv4-Flash (head_dim=128, + num_heads=64) satisfies both and is the only model that exercises this + path today; the generic 2D kernel below remains the fallback for + misaligned shapes. + """ + batch_size, next_n, num_heads, head_dim = q.size() + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_size, _, _ = kv_values.size() + num_rows = batch_size * next_n + if token_count is None: + token_count = max_model_len - token_start + assert token_start >= 0 + assert token_count >= 0 + assert token_start + token_count <= max_model_len + logits = torch.empty( + (num_rows, token_count), + device=q.device, + dtype=torch.float32, + ) + if num_rows == 0 or token_count == 0: + return logits + + context_lens_2d = context_lens.reshape(batch_size, -1) + if context_lens_2d.shape[1] == 1 and next_n != 1: + context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() + block_n = 128 + grid = (num_rows, triton.cdiv(token_count, block_n)) + _fp8_paged_mqa_logits_rowwise_kernel[grid]( + q, + kv_values, + kv_scale, + weights, + context_lens_2d, + block_tables, + logits, + token_start, + num_rows, + token_count, + next_n, + num_heads, + head_dim, + block_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv_values.stride(0), + kv_values.stride(1), + kv_values.stride(3), + kv_scale.stride(0), + kv_scale.stride(1), + weights.stride(0), + weights.stride(1), + context_lens_2d.stride(0), + context_lens_2d.stride(1), + block_tables.stride(0), + block_tables.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_N=block_n, + BLOCK_D=64, + BLOCK_H=8, + num_warps=4, + ) + return logits + + def fp8_paged_mqa_logits_triton( q: torch.Tensor, kv_cache: torch.Tensor, @@ -298,6 +507,24 @@ def fp8_paged_mqa_logits_triton( token_count: int | None = None, ) -> torch.Tensor: batch_size, next_n, num_heads, head_dim = q.size() + # Aligned head shapes (DSv4-Flash and any future MQA model with + # ``head_dim % 64 == 0`` and ``num_heads % 4 == 0``) get the rowwise + # kernel, which keeps long-context decode (>100K tokens) on a per-row + # grid that re-uses Q across the full token window. The generic 2D + # kernel below still handles misaligned shapes and remains the canonical + # reference for the rowwise variant. + if head_dim % 64 == 0 and num_heads % 4 == 0: + return fp8_paged_mqa_logits_rowwise_triton( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + token_start=token_start, + token_count=token_count, + ) + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) _, block_size, _, _ = kv_values.size() num_rows = batch_size * next_n From a042a41b449981531382387f3110b69fea292f93 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 14 May 2026 06:24:58 +0800 Subject: [PATCH 019/131] reasoning: defensive implicit for DeepSeek V4 tool-call streaming MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a DSv4-specific reasoning parser (`DeepSeekV4ThinkingReasoningParser`, `DeepSeekV4ReasoningParser`) that treats the DSML tool-call start marker `<|DSML|tool_calls>` as an implicit end-of-reasoning when the explicit `` token is absent. Why --- DSv4-Flash at long context (~95k-100k input tokens) occasionally fails to emit `` before opening a tool call. The existing `DeepSeekR1ReasoningParser` keeps the parser stuck in reasoning mode in that case: the tool-call start marker (and everything after) is classified as reasoning, the orchestrator never advances to the tool parser, and the caller sees a turn with reasoning but no tool call. opencode's agent loop interprets that as "no tool to dispatch" and exits — visually indistinguishable from "the model gave up". Reproduced 18% of the time at 95-100k input tokens with `tool_choice` auto and 25 tools in scope. Full repro bundle (Python script + SSE trace + opencode forensics) lives in the harness repo. What ---- - New module `vllm/reasoning/deepseek_v4_reasoning_parser.py` providing `DeepSeekV4ThinkingReasoningParser` (extends `DeepSeekR1ReasoningParser` with one defensive split), plus the dispatcher pair `DeepSeekV4ReasoningParser` and `DeepSeekV4ReasoningWithThinkingParser` matching the V3 shape. - The dispatcher mirrors `DeepSeekV3ReasoningParser`: thinking-mode uses the V4 extension, non-thinking uses `IdentityReasoningParser`. - Sticky `_implicit_end_seen` flag on the parser instance ensures `is_reasoning_end[_streaming]` returns True for every delta after the marker first appears, so the orchestrator state machine transitions correctly even when the marker spans a token boundary. - `vllm/reasoning/__init__.py` re-points the `deepseek_v4` registration from `DeepSeekV3ReasoningParser` to the new `DeepSeekV4ReasoningParser`. `deepseek_v3` is unchanged. What does NOT change -------------------- - Healthy streams (explicit ``) take the same code path as before: the V4 parser defers to `super()` and the defensive split only fires when no explicit start/end token has been seen. - The DSv32 tool parser is untouched. - V3 reasoning parser and registration are untouched. Tests ----- - `tests/reasoning/test_deepseekv4_reasoning_parser.py` covers the registration, dispatcher selection, healthy paths (parent behaviour), implicit-end-marker in isolated delta, implicit-end split within delta, sticky behaviour after first marker, suppression when `` is explicitly present, `is_reasoning_end` for explicit and implicit cases, and the parent's single-token guard. - `tests/reasoning/test_deepseekv3_reasoning_parser.py` updated: the `deepseek_v4` alias now resolves to `DeepSeekV4ReasoningParser`, while `deepseek_v3` still resolves to `DeepSeekV3ReasoningParser`. The fix is intentionally narrow: it addresses one well-defined failure mode (tool call without closing reasoning). The "runaway reasoning to length limit with no tool call" and "premature reasoning stop with no tool call" subtypes seen at long context are model-behaviour issues, not parser bugs, and are left for a separate follow-up. Signed-off-by: jasl --- .../test_deepseekv3_reasoning_parser.py | 15 +- .../test_deepseekv4_reasoning_parser.py | 316 ++++++++++++++++++ vllm/reasoning/__init__.py | 4 +- .../reasoning/deepseek_v4_reasoning_parser.py | 255 ++++++++++++++ 4 files changed, 586 insertions(+), 4 deletions(-) create mode 100644 tests/reasoning/test_deepseekv4_reasoning_parser.py create mode 100644 vllm/reasoning/deepseek_v4_reasoning_parser.py diff --git a/tests/reasoning/test_deepseekv3_reasoning_parser.py b/tests/reasoning/test_deepseekv3_reasoning_parser.py index f5b37194f927..49b373dbe332 100644 --- a/tests/reasoning/test_deepseekv3_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv3_reasoning_parser.py @@ -34,10 +34,21 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type): assert isinstance(parser._parser, expected_parser_type) -def test_deepseek_v4_reasoning_parser_alias(): +def test_deepseek_v4_reasoning_parser_registration(): + """``deepseek_v4`` now resolves to its own parser (a defensive + extension of ``DeepSeekV3ReasoningParser``) rather than reusing the + V3 parser directly. See ``tests/reasoning/test_deepseekv4_reasoning_parser.py``. + """ + from vllm.reasoning.deepseek_v4_reasoning_parser import ( + DeepSeekV4ReasoningParser, + ) + parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4") - assert parser_cls is DeepSeekV3ReasoningParser + assert parser_cls is DeepSeekV4ReasoningParser + # The V3 alias must remain pointed at V3. + v3_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v3") + assert v3_cls is DeepSeekV3ReasoningParser def test_identity_reasoning_parser_basic(tokenizer): diff --git a/tests/reasoning/test_deepseekv4_reasoning_parser.py b/tests/reasoning/test_deepseekv4_reasoning_parser.py new file mode 100644 index 000000000000..800f1285a352 --- /dev/null +++ b/tests/reasoning/test_deepseekv4_reasoning_parser.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Tests for the DeepSeek V4 reasoning parser. + +The V4 parser is a defensive extension of :class:`DeepSeekR1ReasoningParser` +that treats the DSML tool-call start marker as an implicit end-of-reasoning +when ```` is missing. That failure mode is observed at long context +on DSv4-Flash and was previously trapping tool calls inside reasoning, +leaving the agent loop with nothing to dispatch. +""" + +from unittest.mock import MagicMock + +import pytest + +from vllm.entrypoints.openai.engine.protocol import DeltaMessage +from vllm.reasoning import ReasoningParserManager +from vllm.reasoning.deepseek_v4_reasoning_parser import ( + DeepSeekV4ReasoningParser, + DeepSeekV4ThinkingReasoningParser, +) +from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser + +START_TOKEN = "" +END_TOKEN = "" +START_TOKEN_ID = 9001 +END_TOKEN_ID = 9002 +DSML_MARKER = "<|DSML|tool_calls>" +DSML_MARKER_TOKEN_ID = 9100 + + +def _make_tokenizer() -> MagicMock: + """Mock tokenizer mapping the four special strings we care about.""" + tok = MagicMock() + vocab = { + START_TOKEN: START_TOKEN_ID, + END_TOKEN: END_TOKEN_ID, + DSML_MARKER: DSML_MARKER_TOKEN_ID, + } + tok.get_vocab.return_value = vocab + + def _decode(ids, *args, **kwargs): + out = [] + for tid in ids: + for s, sid in vocab.items(): + if sid == tid: + out.append(s) + break + else: + out.append(f"") + return "".join(out) + + tok.decode = _decode + return tok + + +@pytest.fixture +def tokenizer() -> MagicMock: + return _make_tokenizer() + + +@pytest.fixture +def parser(tokenizer) -> DeepSeekV4ThinkingReasoningParser: + return DeepSeekV4ThinkingReasoningParser(tokenizer) + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + + +def test_registration_resolves_to_v4_class(): + parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4") + assert parser_cls is DeepSeekV4ReasoningParser + + +@pytest.mark.parametrize( + "thinking_kwargs,expected_inner", + [ + ({"thinking": True}, DeepSeekV4ThinkingReasoningParser), + ({"enable_thinking": True}, DeepSeekV4ThinkingReasoningParser), + ({"thinking": False}, IdentityReasoningParser), + ({}, IdentityReasoningParser), + ], +) +def test_dispatch_based_on_thinking_kwarg( + tokenizer, thinking_kwargs, expected_inner +): + parser = DeepSeekV4ReasoningParser( + tokenizer, chat_template_kwargs=thinking_kwargs + ) + assert isinstance(parser._parser, expected_inner) + + +# --------------------------------------------------------------------------- +# Healthy path (explicit ) — must match parent behavior +# --------------------------------------------------------------------------- + + +def test_healthy_implicit_start_explicit_end_in_delta(parser): + """Model emits reasoning text then in the same delta; the + parent's R1 splitter handles this — V4 must not interfere.""" + delta = parser.extract_reasoning_streaming( + previous_text="some reasoning ", + current_text="some reasoning after", + delta_text="after", + previous_token_ids=[100, 101], + current_token_ids=[100, 101, END_TOKEN_ID, 102], + delta_token_ids=[END_TOKEN_ID, 102], + ) + assert delta is not None + assert delta.reasoning == "" + assert delta.content == "after" + + +def test_healthy_explicit_end_in_previous_emits_content(parser): + """Once has been seen, the parser emits delta_text as content.""" + delta = parser.extract_reasoning_streaming( + previous_text="reasoning", + current_text="reasoningsome content", + delta_text="some content", + previous_token_ids=[100, END_TOKEN_ID], + current_token_ids=[100, END_TOKEN_ID, 101, 102], + delta_token_ids=[101, 102], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "some content" + + +def test_healthy_explicit_start_in_delta(parser): + """Model emits both and in the same delta.""" + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="quick thoughtanswer", + delta_text="quick thoughtanswer", + previous_token_ids=[], + current_token_ids=[START_TOKEN_ID, 200, 201, END_TOKEN_ID, 202], + delta_token_ids=[START_TOKEN_ID, 200, 201, END_TOKEN_ID, 202], + ) + assert delta is not None + assert delta.reasoning == "quick thought" + assert delta.content == "answer" + + +# --------------------------------------------------------------------------- +# Defensive path (no , DSML marker appears) — the fix +# --------------------------------------------------------------------------- + + +def test_implicit_end_marker_in_isolated_delta(parser): + """Marker arrives as its own delta after pure reasoning. The marker + token alone should be classified as content, not reasoning.""" + # First, two reasoning-only deltas to populate state. + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text="step 1 ", + delta_text="step 1 ", + previous_token_ids=[], + current_token_ids=[300, 301], + delta_token_ids=[300, 301], + ) + assert delta is not None + assert delta.reasoning == "step 1 " + assert delta.content is None + assert parser._implicit_end_seen is False + + delta = parser.extract_reasoning_streaming( + previous_text="step 1 ", + current_text="step 1 step 2", + delta_text="step 2", + previous_token_ids=[300, 301], + current_token_ids=[300, 301, 302, 303], + delta_token_ids=[302, 303], + ) + assert delta is not None + assert delta.reasoning == "step 2" + assert delta.content is None + assert parser._implicit_end_seen is False + + # Now the marker arrives. + delta = parser.extract_reasoning_streaming( + previous_text="step 1 step 2", + current_text=f"step 1 step 2{DSML_MARKER}", + delta_text=DSML_MARKER, + previous_token_ids=[300, 301, 302, 303], + current_token_ids=[300, 301, 302, 303, DSML_MARKER_TOKEN_ID], + delta_token_ids=[DSML_MARKER_TOKEN_ID], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == DSML_MARKER + assert parser._implicit_end_seen is True + + +def test_implicit_end_marker_within_delta_split(parser): + """Marker appears partway through a delta — split it at the boundary.""" + delta_text = f"tail of reasoning{DSML_MARKER}\n<|DSML|invoke name=\"w\"" + delta = parser.extract_reasoning_streaming( + previous_text="head ", + current_text=f"head {delta_text}", + delta_text=delta_text, + previous_token_ids=[400], + current_token_ids=[400, 401, DSML_MARKER_TOKEN_ID, 402, 403], + delta_token_ids=[401, DSML_MARKER_TOKEN_ID, 402, 403], + ) + assert delta is not None + assert delta.reasoning == "tail of reasoning" + assert delta.content == f"{DSML_MARKER}\n<|DSML|invoke name=\"w\"" + assert parser._implicit_end_seen is True + + +def test_subsequent_delta_after_implicit_end_is_content(parser): + """Once the implicit end fires, every later delta is content.""" + # Seed the parser by flipping the sticky flag via a marker delta. + parser.extract_reasoning_streaming( + previous_text="reasoning", + current_text=f"reasoning{DSML_MARKER}", + delta_text=DSML_MARKER, + previous_token_ids=[500], + current_token_ids=[500, DSML_MARKER_TOKEN_ID], + delta_token_ids=[DSML_MARKER_TOKEN_ID], + ) + assert parser._implicit_end_seen is True + + # Next delta: pure tool-call body. Should be content. + delta = parser.extract_reasoning_streaming( + previous_text=f"reasoning{DSML_MARKER}", + current_text=f"reasoning{DSML_MARKER}\n<|DSML|invoke", + delta_text="\n<|DSML|invoke", + previous_token_ids=[500, DSML_MARKER_TOKEN_ID], + current_token_ids=[500, DSML_MARKER_TOKEN_ID, 600, 601], + delta_token_ids=[600, 601], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == "\n<|DSML|invoke" + + +def test_marker_does_not_fire_when_explicit_start_present(parser): + """If the explicit ```` token is in the stream, defer to parent. + This guards against false-positive splits when something that looks + like a marker shows up in the user's prompt history. + """ + delta = parser.extract_reasoning_streaming( + previous_text="", + current_text=f"discussing {DSML_MARKER}", + delta_text=f"discussing {DSML_MARKER}", + previous_token_ids=[START_TOKEN_ID], + current_token_ids=[START_TOKEN_ID, 700, 701, DSML_MARKER_TOKEN_ID], + delta_token_ids=[700, 701, DSML_MARKER_TOKEN_ID], + ) + assert delta is not None + # Parent puts everything after into reasoning until . + assert delta.reasoning == f"discussing {DSML_MARKER}" + assert delta.content is None + assert parser._implicit_end_seen is False + + +# --------------------------------------------------------------------------- +# is_reasoning_end / is_reasoning_end_streaming +# --------------------------------------------------------------------------- + + +def test_is_reasoning_end_with_explicit_end_token(parser): + assert parser.is_reasoning_end([100, END_TOKEN_ID, 101]) is True + + +def test_is_reasoning_end_with_implicit_marker(parser): + """When start/end tokens are absent, decoding to text and finding the + marker counts as end-of-reasoning.""" + assert parser.is_reasoning_end([300, 301, DSML_MARKER_TOKEN_ID]) is True + + +def test_is_reasoning_end_pure_reasoning(parser): + assert parser.is_reasoning_end([300, 301, 302]) is False + + +def test_is_reasoning_end_streaming_sticky_after_split(parser): + """After ``extract_reasoning_streaming`` flips the sticky flag, + ``is_reasoning_end_streaming`` must report end-of-reasoning for any + subsequent delta — even one that contains neither nor the + marker.""" + # Seed via marker delta. + parser.extract_reasoning_streaming( + previous_text="r", + current_text=f"r{DSML_MARKER}", + delta_text=DSML_MARKER, + previous_token_ids=[800], + current_token_ids=[800, DSML_MARKER_TOKEN_ID], + delta_token_ids=[DSML_MARKER_TOKEN_ID], + ) + assert parser._implicit_end_seen is True + # Now a plain content-only delta. + assert parser.is_reasoning_end_streaming([800, DSML_MARKER_TOKEN_ID, 900], [900]) is True + + +# --------------------------------------------------------------------------- +# Sanity: parent's empty-delta and single-token guards still apply +# --------------------------------------------------------------------------- + + +def test_single_end_token_delta_returns_none(parser): + """Parent contract: a delta containing only the end token returns + ``None`` — the orchestrator handles the transition via + ``is_reasoning_end``.""" + out = parser.extract_reasoning_streaming( + previous_text="r", + current_text="r", + delta_text="", + previous_token_ids=[900], + current_token_ids=[900, END_TOKEN_ID], + delta_token_ids=[END_TOKEN_ID], + ) + assert out is None diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index cbb1fa350f55..b99b2e1a43ec 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -29,8 +29,8 @@ "DeepSeekV3ReasoningParser", ), "deepseek_v4": ( - "deepseek_v3_reasoning_parser", - "DeepSeekV3ReasoningParser", + "deepseek_v4_reasoning_parser", + "DeepSeekV4ReasoningParser", ), "poolside_v1": ( "poolside_v1_reasoning_parser", diff --git a/vllm/reasoning/deepseek_v4_reasoning_parser.py b/vllm/reasoning/deepseek_v4_reasoning_parser.py new file mode 100644 index 000000000000..373efb9260a1 --- /dev/null +++ b/vllm/reasoning/deepseek_v4_reasoning_parser.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.engine.protocol import DeltaMessage +from vllm.reasoning import ReasoningParser +from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser + +from .identity_reasoning_parser import IdentityReasoningParser + +if TYPE_CHECKING: + from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest + from vllm.entrypoints.openai.responses.protocol import ResponsesRequest + + +# DSML tool-call start markers that the DSv4 model emits when it decides to +# invoke a tool. The reasoning parser treats these as an implicit +# end-of-reasoning when the explicit token is missing, which the +# model occasionally fails to emit at long context. +_DSV4_TOOL_CALL_IMPLICIT_END_MARKERS: tuple[str, ...] = ( + "<|DSML|tool_calls>", +) + + +class DeepSeekV4ThinkingReasoningParser(DeepSeekR1ReasoningParser): + """ + DeepSeek V4 thinking-mode reasoning parser. + + Extends :class:`DeepSeekR1ReasoningParser` with one behavior change: + if the model emits a DSML tool-call start marker without first emitting + ````, treat the marker as an implicit end-of-reasoning so the + tool call is correctly handed off to the tool parser. + + Works around a model behavior observed at long context (~95k–100k + input tokens) where DSv4-Flash sometimes skips ```` before + opening ``<|DSML|tool_calls>``. Without this defensive split the + orchestrator stays in reasoning phase, the tool parser never runs, and + the caller sees a turn with reasoning but no tool call — + indistinguishable on the client side from "model gave up". + + Healthy paths (explicit ````) are unchanged. The marker check + fires only when no explicit start/end token has been seen. + + State (``_implicit_end_seen``) is per-instance and per-stream because + a fresh ``Parser`` (and therefore a fresh reasoning parser) is + constructed for each request. + """ + + implicit_end_markers: tuple[str, ...] = _DSV4_TOOL_CALL_IMPLICIT_END_MARKERS + + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + # Per-stream sticky flag: once the implicit end marker is observed, + # the rest of the stream is content and the orchestrator's + # is_reasoning_end check must return True for every subsequent delta. + self._implicit_end_seen: bool = False + + def _find_implicit_end_marker(self, text: str) -> tuple[str, int] | None: + """Return ``(marker, index)`` of the earliest implicit end marker in + ``text``, or ``None`` if none of the configured markers are present. + """ + earliest: tuple[str, int] | None = None + for marker in self.implicit_end_markers: + idx = text.find(marker) + if idx < 0: + continue + if earliest is None or idx < earliest[1]: + earliest = (marker, idx) + return earliest + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + # Honor the explicit contract first. + if super().is_reasoning_end(input_ids): + return True + # Sticky: once we've observed the marker in streaming, the rest of + # the stream is content. + if self._implicit_end_seen: + return True + # Non-streaming fallback: scan decoded text for the marker. We only + # decode when start/end tokens are absent, which is the failure + # mode we target. + if not input_ids: + return False + if self.start_token_id in input_ids or self.end_token_id in input_ids: + return False + try: + decoded = self.model_tokenizer.decode(list(input_ids)) + except Exception: + return False + return self._find_implicit_end_marker(decoded) is not None + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Iterable[int] + ) -> bool: + if super().is_reasoning_end_streaming(input_ids, delta_ids): + return True + if self._implicit_end_seen: + return True + return False + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + ret = super().extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) + + # Parent emitted content (explicit in or before this delta). + # Healthy path — nothing to do. + if ret is not None and ret.content is not None: + return ret + + # Defensive check is only meaningful in implicit-reasoning mode + # (no explicit / in the stream). If any explicit + # token is present, defer to parent. + if ( + self.start_token_id in previous_token_ids + or self.start_token_id in delta_token_ids + or self.end_token_id in previous_token_ids + or self.end_token_id in delta_token_ids + ): + return ret + + # Sticky: marker observed in an earlier delta — everything here is + # content for the tool parser. + if self._implicit_end_seen: + return DeltaMessage(content=delta_text) + + marker_in_current = self._find_implicit_end_marker(current_text) + if marker_in_current is None: + # No marker anywhere; parent's classification stands. + return ret + + # First sighting of the implicit end marker. + self._implicit_end_seen = True + _marker_str, marker_idx_current = marker_in_current + # Position within delta_text where the marker begins. + marker_idx_delta = marker_idx_current - len(previous_text) + if marker_idx_delta < 0: + # Marker straddles into previous_text but wasn't detected there + # (parent path didn't hit). Treat all of delta_text as content. + return DeltaMessage(content=delta_text) + + reasoning_part = delta_text[:marker_idx_delta] or None + content_part = delta_text[marker_idx_delta:] or None + if reasoning_part is None and content_part is None: + return ret + return DeltaMessage(reasoning=reasoning_part, content=content_part) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + # If parent finds , use its behavior. + parent_result = super().extract_content_ids(input_ids) + if parent_result: + return parent_result + # Fall back to text-scan for the implicit marker. + if not input_ids: + return [] + try: + decoded = self.model_tokenizer.decode(list(input_ids)) + except Exception: + return [] + marker = self._find_implicit_end_marker(decoded) + if marker is None: + return [] + # Without per-token offsets we can't slice input_ids exactly at the + # marker boundary. Return everything as content tokens — the + # orchestrator drives tool-call extraction off ``current_text`` for + # DSML grammars, so this conservative answer is acceptable. + return list(input_ids) + + +class DeepSeekV4ReasoningParser(ReasoningParser): + """ + V4 reasoning parser that delegates to either + :class:`DeepSeekV4ThinkingReasoningParser` (the V4-aware extension of + R1) or :class:`IdentityReasoningParser` based on the ``thinking`` / + ``enable_thinking`` chat-template kwargs. + + Replaces the previous arrangement where ``deepseek_v4`` reused + :class:`DeepSeekV3ReasoningParser`, which lacked the implicit + DSML-tool-call end-of-reasoning handling. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.get("thinking", False)) + enable_thinking = bool(chat_kwargs.get("enable_thinking", False)) + thinking = thinking or enable_thinking + + self._parser: ReasoningParser + if thinking: + self._parser = DeepSeekV4ThinkingReasoningParser( + tokenizer, *args, **kwargs + ) + else: + self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) + + @property + def reasoning_start_str(self) -> str | None: + return self._parser.reasoning_start_str + + @property + def reasoning_end_str(self) -> str | None: + return self._parser.reasoning_end_str + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Iterable[int] + ) -> bool: + return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest" + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> "DeltaMessage | None": + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) From 411a2de911c9e059c88ab1f53a42459b43a9d4ce Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 14 May 2026 21:33:39 +0800 Subject: [PATCH 020/131] sm12x: keep @torch.compile on HC head reduction via free-function wrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the previous attempt to put the decorator directly on ``HCHeadOp.forward_cuda``: when the outer model is wrapped by ``@support_torch_compile`` (the no-MTP path on SM12x) dynamo can't inline-bind the decorated method through ``CustomOp._forward_method`` and the worker dies with:: torch._dynamo.exc.Unsupported: failed to bind arguments when attempting to inline forward_cuda That blocks every no-MTP serve on SM12x. Move the body into a free ``_hc_head_cuda_impl`` decorated with ``@torch.compile(backend=simple_compile_backend)`` — the layout that existed pre-upstream-#41946 — so the method just delegates and dynamo no longer needs to inline a decorated method. Recovers the DSv4-Flash MTP=2 spec-acceptance gain reported in 16ee3bd (67.6 % → 59.8 % drop) without breaking the no-MTP startup path. ``forward_hip`` is unchanged: ROCm doesn't take the same outer ``@support_torch_compile`` route, so the method-level decorator is fine there. Signed-off-by: jasl --- vllm/model_executor/layers/mhc.py | 48 +++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index de1b2a0c617c..60c942e5e510 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -6,6 +6,7 @@ # import vllm.model_executor.kernels.mhc # noqa: F401 import vllm.model_executor.kernels.mhc as mhc_kernels from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform from vllm.utils.import_utils import has_tilelang HAS_TILELANG = has_tilelang() @@ -260,6 +261,45 @@ def forward_xpu( ) +# ``@torch.compile`` on the CUDA HC head reduction is necessary for accuracy +# as well as performance — upstream a8887c208 ("[Bugfix] [ROCm] [DSV4] [Perf] +# Add aiter mhc support", #41946) refactored ``hc_head`` from a free +# function into ``HCHeadOp(CustomOp)`` and dropped the decorator from the +# CUDA path while keeping it on ``forward_hip``. The drop caused a measured +# ~7 pp regression in DSv4-Flash MTP=2 spec acceptance on SM12x (mt-bench +# c=1, 67.6 % → 59.8 %). +# +# Decorating the ``forward_cuda`` method directly trips +# ``torch._dynamo.exc.Unsupported: failed to bind arguments when attempting +# to inline forward_cuda`` whenever the outer model is wrapped by +# ``@support_torch_compile`` (which is the no-MTP path on SM12x): dynamo +# tries to inline the bound method through ``CustomOp._forward_method`` and +# can't reconcile the ``self`` parameter. Keeping the body as a free +# function — the layout that existed pre-#41946 — sidesteps the bind +# failure while preserving the spec-acceptance recovery. +@torch.compile(backend=current_platform.simple_compile_backend) +def _hc_head_cuda_impl( + hidden_states: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_norm_eps: float, + hc_eps: float, +) -> torch.Tensor: + hc_mult, hidden_size = hidden_states.shape[-2:] + outer_shape = hidden_states.shape[:-2] + hs_flat = hidden_states.view(-1, hc_mult, hidden_size) + out = torch.ops.vllm.hc_head_fused_kernel_tilelang( + hs_flat, + hc_fn, + hc_scale, + hc_base, + rms_norm_eps, + hc_eps, + ) + return out.view(*outer_shape, hidden_size) + + # --8<-- [start:hc_head] @CustomOp.register("hc_head") class HCHeadOp(CustomOp): @@ -284,18 +324,14 @@ def forward_cuda( rms_norm_eps: float, hc_eps: float, ) -> torch.Tensor: - hc_mult, hidden_size = hidden_states.shape[-2:] - outer_shape = hidden_states.shape[:-2] - hs_flat = hidden_states.view(-1, hc_mult, hidden_size) - out = torch.ops.vllm.hc_head_fused_kernel_tilelang( - hs_flat, + return _hc_head_cuda_impl( + hidden_states, hc_fn, hc_scale, hc_base, rms_norm_eps, hc_eps, ) - return out.view(*outer_shape, hidden_size) def forward_hip( self, From 21d46acce9b353b2aef72d64d94940e0ad830b37 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 15 May 2026 00:44:09 +0800 Subject: [PATCH 021/131] sm12x: drop multi-request prefill warmup that crashes CUTeDSL kv-gather The third ``_dummy_run`` call added in f4b3301e8 ("Extend DeepSeek V4 warmup coverage to multi-request shapes") synthesizes a multi-request prefill batch and runs it through ``force_attention=True``. On SM12x this trips an illegal memory access inside the CUTeDSL ``DequantGatherKCacheKernel``: the dummy shape exceeds the ``offset + gather_len <= M`` invariant of the kv-gather output buffer (M is sized for the single-prefill warmup case, not for the multi-request layout). Reproduced with ``CUDA_LAUNCH_BLOCKING=1``:: File ".../dequant_gather_k_cutedsl.py", line 29, in dequantize_and_gather_k_cache_cutedsl DequantGatherKCacheKernel.compile(...)( out, k_cache, seq_lens, gather_lens, block_table, offset) RuntimeError: CUDA Error: cudaErrorIllegalAddress Drop just this third warmup call. The other two (single-prefill and second-chunk-of-chunked-prefill) and the MTP uniform-decode coverage from f4b3301e8 stay. The trade is a one-time JIT compile on the first real multi-prefill user request for the un-pre-warmed indexer path; the alternative is failing to start the serve at production ``--max-num-seqs`` values (e.g. 128). A proper fix would reconcile the gather-buffer sizing for the multi-request prefill warmup with the kernel's bounds; left for a follow-up. Signed-off-by: jasl --- vllm/model_executor/warmup/kernel_warmup.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 1fbe45f1c50e..214e4e901cb1 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -408,15 +408,18 @@ def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None: create_single_prefill=True, profile_seq_lens=prefill_tokens * 2, ) - # Multi-request prefill (max_num_seqs prefills sharing the - # batched-token budget) covers the other dense-GEMM shape bucket - # plus the multi-prefill indexer path. - runner._dummy_run( - num_tokens=prefill_tokens, - skip_eplb=True, - is_profile=True, - force_attention=True, - ) + # NOTE: The multi-request prefill warmup that previously sat here + # (max_num_seqs prefills sharing the batched-token budget) hit a + # CUDA illegal memory access inside the CUTeDSL + # ``DequantGatherKCacheKernel`` on SM12x. The dummy_run shape it + # generated violates an implicit ``offset + gather_len <= M`` + # invariant of the kv-gather output buffer (M is sized for the + # single-prefill warmup case). Removing the warmup gives back the + # one-time JIT cost on the first real multi-prefill request, but + # unblocks serve startup at production ``--max-num-seqs`` values + # (e.g. 128). Re-enable once the gather-buffer sizing for + # multi-request prefill warmup is reconciled with the kernel's + # bounds. query_len = getattr(runner, "uniform_decode_query_len", 0) for num_reqs in uniform_decode_reqs: runner._dummy_run( From 455c28266035d8f74e3be82d242bb145c973603c Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 15 May 2026 02:16:55 +0800 Subject: [PATCH 022/131] sm12x: drop vestigial cudagraph kill-switch on Triton sparse MLA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``FlashMLASparseMetadataBuilder.get_cudagraph_support`` and the parallel override in ``DeepseekSparseSWAMetadataBuilder`` were guarded on:: getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" and is_triton_sparse_mla_enabled_for_platform() The first clause never holds at runtime: the spec the runtime passes to ``get_cudagraph_support`` is an outer ``UniformTypeKVCacheSpecs`` wrapper (``vllm/v1/kv_cache_interface.py``) that only exposes ``block_size``; the per-layer ``MLAAttentionSpec.model_version`` lives one level down under ``.kv_cache_specs``. So the overrides silently fall through to ``cls._cudagraph_support = AttentionCGSupport.UNIFORM_BATCH`` and cudagraphs are captured normally — confirmed by instrumenting the call and by an ``--enforce-eager`` probe (FULL cudagraphs give a ~2.4× decode throughput speedup on 2× RTX PRO 6000 at ISL=2048 OSL=2048 c=16). Cudagraph capture is also fine for the MTP=2 path on this stack — the spec-decode acceptance and TPOT match the no-MTP measurement to within mt-bench noise (66.10 % acceptance, length 2.32). Since the override was both dead code and would *reduce* performance if "fixed" to actually fire, drop it. The default ``UNIFORM_BATCH`` support level on both builders already does the right thing. Signed-off-by: jasl --- .../v1/attention/backends/mla/flashmla_sparse.py | 16 ---------------- vllm/v1/attention/backends/mla/sparse_swa.py | 14 -------------- 2 files changed, 30 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 2d6231051686..ab726cac1f75 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -28,9 +28,6 @@ SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping -from vllm.v1.attention.backends.mla.sparse_mla_env import ( - is_triton_sparse_mla_enabled_for_platform, -) from vllm.v1.attention.backends.mla.sparse_utils import ( triton_convert_req_index_to_global_index, ) @@ -235,19 +232,6 @@ def get_prefill_workspace_size(max_model_len: int): class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - @classmethod - def get_cudagraph_support( - cls, - vllm_config: VllmConfig, - kv_cache_spec: AttentionSpec, - ) -> AttentionCGSupport: - if ( - getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" - and is_triton_sparse_mla_enabled_for_platform() - ): - return AttentionCGSupport.NEVER - return cls._cudagraph_support - def __init__( self, kv_cache_spec: AttentionSpec, diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index fb6e0f9dbcc9..3ff897761370 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -18,7 +18,6 @@ ) from vllm.v1.attention.backends.mla.sparse_mla_env import ( is_triton_sparse_mla_enabled, - is_triton_sparse_mla_enabled_for_platform, ) from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata @@ -216,19 +215,6 @@ class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder): reorder_batch_threshold: int = 1 _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH - @classmethod - def get_cudagraph_support( - cls, - vllm_config: VllmConfig, - kv_cache_spec: AttentionSpec, - ) -> AttentionCGSupport: - if ( - getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" - and is_triton_sparse_mla_enabled_for_platform() - ): - return AttentionCGSupport.NEVER - return cls._cudagraph_support - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec) From 6e4a03799a7a423a05e583b11e961c4cd13a3c67 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 15 May 2026 03:09:56 +0800 Subject: [PATCH 023/131] sm12x: harden sparse_attn_indexer seq_lens slice with .contiguous() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Defensive ``.contiguous()`` on ``decode_metadata.seq_lens[:batch_size]``. On an already-contiguous slice this is a no-op pointer return; on a non-contiguous 2D slice (max_decode_len < next_n under V2 model runner cudagraph capture) it materializes a contiguous copy that satisfies ``persistent_topk`` and the FP8 MQA paged-logits kernels. Reported by @aabbccddwasd in PR #41834 (comment 4450901180) as a crash workaround on their 4× RTX PRO 6000 TP=4 setup; cost is zero on the path we currently exercise (already contiguous). Signed-off-by: jasl --- vllm/model_executor/layers/sparse_attn_indexer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 000241b46b08..d44572b96f69 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -354,7 +354,18 @@ def sparse_attn_indexer( batch_size = padded_q_quant_decode_tokens.shape[0] next_n = padded_q_quant_decode_tokens.shape[1] num_padded_tokens = batch_size * next_n - seq_lens = decode_metadata.seq_lens[:batch_size] + # ``.contiguous()`` was originally required because the + # ``DeepseekV32IndexerMetadataBuilder`` allocated + # ``decode_seq_lens_buffer`` as a 2D ``(max_num_seqs, next_n)`` + # tensor, and a ``[:num_decodes, :max_decode_len]`` slice was + # non-contiguous when ``max_decode_len < next_n`` under V2 model + # runner cudagraph capture. Reported by aabbccddwasd in PR #41834 + # comment 4450901180. Upstream PR #42135 (ee58665aa) since + # unified the buffer to 1D ``(max_num_batched_tokens,)``, so the + # slice is now always contiguous and this call is a no-op pointer + # return. Kept as a defensive belt against future regressions in + # the metadata builder's buffer shape. + seq_lens = decode_metadata.seq_lens[:batch_size].contiguous() # seq_lens is always 2D: (B, next_n) for native spec decode, (B, 1) # otherwise. deep_gemm fp8_fp4_paged_mqa_logits requires 2D context_lens; # the downstream topk kernels accept both 1D and 2D. From 00ea0da9b9efae7c69c06b88050992adfef18553 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 15 May 2026 03:10:06 +0800 Subject: [PATCH 024/131] sm12x: autotune num_warps on fp8_einsum + fused_indexer_q kernels ``_deepseek_v4_sm12x_fp8_einsum_kernel`` was launched with hardcoded ``num_warps=4 num_stages=3``; ``_fused_indexer_q_rope_quant_kernel`` was launched with ``num_warps=1`` (with a "TODO: Tune this" inline). Replace both with ``@triton.autotune`` so the best warp/stage config is picked per shape: - fp8_einsum: configs over ``{(4,3), (8,3), (4,2), (8,2)}`` keyed on ``(num_tokens, num_groups, out_rank, hidden_size)``. - fused_indexer_q: configs over ``num_warps={1,2,4}`` keyed on ``(INDEX_Q_HALF_ROT_DIM, INDEX_Q_HEAD_DIM)``. Both kernels are launched per forward, so autotune fires once per unique key and the cached selection is reused on subsequent calls. Reported by @aabbccddwasd in PR #41834 (comment 4450901180). Signed-off-by: jasl --- vllm/models/deepseek_v4/common/ops/fused_indexer_q.py | 10 +++++++++- vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py | 8 ++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py b/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py index d5aaf10feba4..c7fbd93f329b 100644 --- a/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py +++ b/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py @@ -66,6 +66,14 @@ def _quantize_mxfp4_pair(x_lo, x_hi): return packed, ue8m0 +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + ], + key=["INDEX_Q_HALF_ROT_DIM", "INDEX_Q_HEAD_DIM"], +) @triton.jit def _fused_indexer_q_rope_quant_kernel( pos_ptr, @@ -433,6 +441,6 @@ def fused_indexer_q_rope_quant( index_weights_head_scale, index_weights_out, index_weights_out.stride(0), - num_warps=1, # TODO: Tune this + # num_warps supplied by @triton.autotune above. ) return index_q_fp8, index_weights_out diff --git a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py index a9f52e767b20..6c2d2797333c 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py +++ b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py @@ -174,6 +174,14 @@ def deepseek_v4_sm12x_fp8_einsum( BLOCK_TOKENS=block_tokens, BLOCK_OUT=block_out, BLOCK_HIDDEN=block_hidden, + # Pinned to the SM12x-optimal config: a previous ``@triton.autotune`` + # block selected from {num_warps in {4,8}, num_stages in {2,3}} with + # key=["num_tokens", ...]. ``num_tokens`` varies per request, so the + # autotune cache missed every call and the 4-config bench replayed + # on every shape — pure overhead. The other three keys are + # model-architecture-fixed, so the same config (num_warps=4, + # num_stages=3) always won; we pin it directly. Reported by + # ``alexbi29`` in PR #41834 comment 4464750956. num_warps=4, num_stages=3, ) From 4ed845281717cc48a38e494209df67c166a0776c Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 15 May 2026 03:10:18 +0800 Subject: [PATCH 025/131] sm12x: autotune num_warps/num_stages on 3 sparse MLA accumulate kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add ``@triton.autotune({(num_warps, num_stages) in {4,8} × {2,3}})`` to the three single-head prefill accumulate kernels in ``sparse_mla_kernels.py``:: - ``_accumulate_indexed_attention_chunk_kernel`` - ``_accumulate_fp8ds_global_slots_attention_chunk_kernel`` - ``_accumulate_fp8ds_paged_attention_chunk_kernel`` Each was previously launched with hardcoded ``num_warps=8``; the new configs explore ``{4,8}`` × ``{2,3}`` keyed on ``num_candidates`` (the dominant per-shape factor). Autotune fires once per ``num_candidates`` value seen at runtime and the chosen config is cached for subsequent calls. The two multihead variants (``..._multihead_kernel``) are NOT autotuned in this commit: they share the same accumulate-read-write pattern but per @aabbccddwasd's note (PR #41834 comment 4450901180) need a separate ``num_tokens: tl.constexpr`` + ``reset_to_zero`` treatment for autotune correctness, which we'll add in a follow-up once we've validated the single-head gain on this hardware. Reported by @aabbccddwasd in PR #41834 (comment 4450901180); claimed ~+39 % prefill on 4× RTX PRO 6000 TP=4 32K ctx, with the PR base already having a higher baseline so absolute gain is smaller. Signed-off-by: jasl --- .../backends/mla/sparse_mla_kernels.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 1f4460fa5d70..ecfd12ae3d1e 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -1233,6 +1233,15 @@ def accumulate_gathered_sparse_mla_attention_chunk( ) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=4, num_stages=2), + triton.Config({}, num_warps=4, num_stages=3), + triton.Config({}, num_warps=8, num_stages=2), + triton.Config({}, num_warps=8, num_stages=3), + ], + key=["num_candidates"], +) @triton.jit def _accumulate_indexed_attention_chunk_kernel( q_ptr, @@ -1372,10 +1381,19 @@ def accumulate_indexed_sparse_mla_attention_chunk( candidate_offset, scale, BLOCK_D=block_d, - num_warps=8, + # num_warps / num_stages supplied by @triton.autotune above. ) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=4, num_stages=2), + triton.Config({}, num_warps=4, num_stages=3), + triton.Config({}, num_warps=8, num_stages=2), + triton.Config({}, num_warps=8, num_stages=3), + ], + key=["num_candidates"], +) @triton.jit def _accumulate_fp8ds_global_slots_attention_chunk_kernel( q_ptr, @@ -1557,7 +1575,7 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( candidate_offset, scale, BLOCK_D=block_d, - num_warps=8, + # num_warps / num_stages supplied by @triton.autotune above. ) @@ -1767,6 +1785,15 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( ) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=4, num_stages=2), + triton.Config({}, num_warps=4, num_stages=3), + triton.Config({}, num_warps=8, num_stages=2), + triton.Config({}, num_warps=8, num_stages=3), + ], + key=["num_candidates"], +) @triton.jit def _accumulate_fp8ds_paged_attention_chunk_kernel( q_ptr, @@ -1951,7 +1978,7 @@ def accumulate_fp8ds_paged_sparse_mla_attention_chunk( candidate_offset, scale, BLOCK_D=block_d, - num_warps=8, + # num_warps / num_stages supplied by @triton.autotune above. ) From e8b14e236d2cb88d4ae2645e6039023cad53a60d Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 15 May 2026 05:36:47 +0800 Subject: [PATCH 026/131] sm12x: add 3 dense FP8 W8A8 Block configs for RTX PRO 6000 WS Edition MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tuned via ``scripts/_fp8_block_tune_driver.py`` for the three remaining DSv4-Flash dense linear shapes the workstation hits at TP=2 but didn't yet have ship-tuned configs for: - N=4096, K=2048 (q_b / gate projection) - N=1024, K=4096 (wq_b projection) - N=4096, K=512 (wo_b projection) Suggested by @aabbccddwasd in PR #41834 (comment 4450901180). Tuned on the local 2× RTX PRO 6000 Blackwell Workstation Edition host with the same wrapper that produced the existing six configs in this directory; lookup is device-name keyed so no code changes required. These complement the existing six WS-edition configs (N,K) ∈ {(1536, 4096), (2048, 4096), (4096, 1024), (4096, 4096), (8192, 1024), (16384, 1024)} so DSv4-Flash now hits a tuned config for every dense linear shape it issues, instead of falling back to the default heuristic for the three shapes above. Signed-off-by: jasl --- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 +++++++++++++++++++ 3 files changed, 246 insertions(+) create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=1024,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1024,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1024,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..46a4cc726f7c --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=1024,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..afeb347e229a --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..6059ab794a4f --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,82 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} From 1b69e3d6b9b8dec328918a0faa4ff97d604d60b9 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 15 May 2026 06:28:24 +0800 Subject: [PATCH 027/131] sm12x: cap C128A metadata kernel loop at effective_topk (no shape changes) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cudagraph-safe retry of suggestion #2 from PR #41834 comment 4450901180. Previous attempt (e34daef4d, reverted) also exposed ``c128a_*_effective_topk`` on the metadata and truncated the buffer slice inside ``deepseek_v4_attention``; that truncation baked the shape into the captured forward launch, breaking replay when ``effective_topk`` shifted between capture and replay. This version only touches the metadata builder (which already runs *outside* the captured forward), so per-call ``effective_topk`` variation is fine: 1. Pre-fill ``global_decode_buffer[:num_decode_tokens]`` and ``prefill_buffer[:num_prefill_tokens]`` with ``-1`` before launch. 2. Compute ``effective_topk_arg = cdiv(max num_compressed across in-flight tokens, BLOCK_SIZE) * BLOCK_SIZE``, capped at ``max_compressed_tokens``. 3. Kernel inner loop uses ``effective_topk`` (was ``max_compressed_tokens``); store mask uses the same. The buffer entries the kernel skips (``[effective_topk, max_compressed_tokens)``) stay at ``-1`` from the pre-fill, so downstream sparse MLA accumulate kernels (which still iterate the full ``max_compressed_tokens`` width inside the cudagraph) see only ``-1`` sentinels in the tail and short-circuit them via ``kv_index >= 0`` / ``candidate < valid_len`` checks. No tensor shape changes inside the captured forward → cudagraph capture/replay remains correct. Savings here are limited to the metadata kernel itself; the accumulate kernels' iteration count is unchanged (their loop bound is the captured ``num_candidates`` shape value, which we deliberately do not narrow). Bench at long ``max_model_len`` will confirm whether this is enough to recover a meaningful chunk of the ~27 % TPOT regression observed at ``max_model_len=131072`` vs ``8192``. Signed-off-by: jasl --- .../attention/backends/mla/flashmla_sparse.py | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index ab726cac1f75..0bc2c89bebff 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -15,6 +15,8 @@ ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability +from vllm.triton_utils import tl, triton +from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import num_compute_units from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( @@ -893,3 +895,164 @@ def forward_mqa( ) return attn_out, None + + +def build_c128a_topk_metadata( + positions: torch.Tensor, + compress_ratio: int, + num_decode_tokens: int, + token_to_req_indices: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + slot_mapping: torch.Tensor, + global_decode_buffer: torch.Tensor, + decode_lens_buffer: torch.Tensor, + prefill_buffer: torch.Tensor, + max_compressed_tokens: int = 8192, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Single kernel for all C128A tokens (decode + prefill). + + Decode tokens: position → block_table lookup → global slot ids + topk_lens. + Prefill tokens: position → local indices [0, ..., n-1, -1, ...]. + + Writes into pre-allocated buffers for CUDA graph address stability. + Returns slices of the buffers. + """ + num_tokens = positions.shape[0] + num_prefill_tokens = num_tokens - num_decode_tokens + + global_decode = global_decode_buffer[:num_decode_tokens] + decode_lens = decode_lens_buffer[:num_decode_tokens] + prefill_local = prefill_buffer[:num_prefill_tokens] + + if num_tokens == 0: + return global_decode, decode_lens, prefill_local + + BLOCK_SIZE = 1024 + + # Compute the smallest BLOCK_SIZE-aligned width that covers every + # in-flight token's num_compressed. When ``max_model_len`` is much + # larger than the actual prompts (e.g. 1M cap with 2K inputs) the + # original ``range(0, max_compressed_tokens, BLOCK_SIZE)`` iterated + # over a tail that always wrote ``-1`` for shorter contexts. Capping + # the inner loop at ``effective_topk`` cuts those dead iterations. + # + # This builder runs OUTSIDE the CUDA-graph-captured forward pass, so + # ``effective_topk`` may vary freely per call. To keep downstream + # (which reads the full ``max_compressed_tokens`` buffer width inside + # the captured forward) correct, we pre-fill the active slice with + # ``-1`` so the kernel-skipped tail is treated as "invalid" by the + # sentinel checks in the sparse MLA accumulate kernels. + if num_decode_tokens > 0: + global_decode_buffer[:num_decode_tokens].fill_(-1) + decode_lens_buffer[:num_decode_tokens].zero_() + if num_prefill_tokens > 0: + prefill_buffer[:num_prefill_tokens].fill_(-1) + + max_pos = int(positions.max().item()) + max_num_compressed = min( + max((max_pos + 1) // compress_ratio, 0), + max_compressed_tokens, + ) + effective_topk_arg = min( + max_compressed_tokens, + cdiv(max_num_compressed, BLOCK_SIZE) * BLOCK_SIZE, + ) + if effective_topk_arg == 0: + # Nothing to write; the fill_(-1) above already produced the + # correct "all-invalid" buffer state. + return global_decode, decode_lens, prefill_local + + _build_c128a_topk_metadata_kernel[(num_tokens,)]( + global_decode_buffer, + global_decode_buffer.stride(0), + decode_lens_buffer, + prefill_buffer, + prefill_buffer.stride(0), + positions, + compress_ratio, + effective_topk_arg, + num_decode_tokens, + token_to_req_indices, + block_table, + block_table.stride(0), + block_size, + slot_mapping, + BLOCK_SIZE=BLOCK_SIZE, + ) + return global_decode, decode_lens, prefill_local + + +@triton.jit +def _build_c128a_topk_metadata_kernel( + # Decode outputs + global_decode_ptr, + global_decode_stride, + decode_lens_ptr, + # Prefill output + prefill_local_ptr, + prefill_local_stride, + # Inputs + positions_ptr, + compress_ratio, + effective_topk, + num_decode_tokens, + token_to_req_indices_ptr, + block_table_ptr, + block_table_stride, + block_size, + slot_mapping_ptr, + BLOCK_SIZE: tl.constexpr, +): + # ``effective_topk`` is the BLOCK_SIZE-aligned cap that covers every + # in-flight token's ``num_compressed`` (computed by the Python + # builder). The caller pre-fills the active buffer slice with ``-1`` + # so entries in ``[effective_topk, max_compressed_tokens)`` remain + # ``-1`` (the sentinel the downstream sparse MLA accumulate kernels + # use to skip invalid candidates). + token_idx = tl.program_id(0) + position = tl.load(positions_ptr + token_idx) + num_compressed = (position + 1) // compress_ratio + num_compressed = tl.minimum(num_compressed, effective_topk) + is_decode = token_idx < num_decode_tokens + + if is_decode: + # --- Decode: block-table lookup → global slot ids + count --- + is_valid_token = tl.load(slot_mapping_ptr + token_idx) >= 0 + req_idx = tl.load(token_to_req_indices_ptr + token_idx) + count = tl.zeros((), dtype=tl.int32) + for i in range(0, effective_topk, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < effective_topk + is_valid = offset < num_compressed + + block_indices = offset // block_size + block_numbers = tl.load( + block_table_ptr + req_idx * block_table_stride + block_indices, + mask=mask & is_valid, + ) + block_offsets = offset % block_size + slot_ids = block_numbers * block_size + block_offsets + slot_ids = tl.where(is_valid, slot_ids, -1) + tl.store( + global_decode_ptr + token_idx * global_decode_stride + offset, + slot_ids, + mask=mask, + ) + count += tl.sum(is_valid.to(tl.int32), axis=0) + + tl.store( + decode_lens_ptr + token_idx, + tl.where(is_valid_token, count, 0), + ) + else: + # --- Prefill: write local indices --- + pfx_idx = token_idx - num_decode_tokens + for i in range(0, effective_topk, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + mask = offset < effective_topk + tl.store( + prefill_local_ptr + pfx_idx * prefill_local_stride + offset, + tl.where(offset < num_compressed, offset, -1), + mask=mask, + ) From 5632e820158dfa5b4973c813be9370c7d8ce94d2 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 15 May 2026 17:06:30 +0800 Subject: [PATCH 028/131] sm12x: per-token early-loop-exit on sparse MLA accumulate inner candidate loop MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Redesigned suggestion #3 from PR #41834 comment 4450901180. The first attempt (e34daef4d, reverted; later 72a5ff228, also reverted) tried to truncate ``topk_indices.shape[1]`` in Python so the captured launches iterated a narrower combined slice; that approach broke under cudagraph replay (shape baked at capture) and *also* mis-bounded — the combine kernel writes each token's combined buffer as ``[topk_len_t | swa_len_t | -1 padding]`` with SWA *immediately* following the per-token topk, so a fixed ``effective_topk`` cap cuts off the SWA portion (GSM8K dropped 25 pp on the prior attempt). The kernel already loads the per-token combined length (``valid_len = tl.load(lens_ptr + token_idx)`` for the four ``lens``- gated kernels, ``gather_len`` for the two paged kernels). The existing ``is_valid`` guard only short-circuits the *heavy* work past that length; the outer ``for candidate_idx in range(0, num_candidates)`` still pays one ``tl.load`` + branch per iter on the dead tail. Capping the loop at ``min(num_candidates, valid_len - candidate_offset)`` (clamped to 0) removes those wasted iterations while preserving the existing ``is_valid`` semantics: the iterations we now skip are exactly those the existing guard already discarded. Applied to six accumulate kernels in ``sparse_mla_kernels.py``: - ``_accumulate_gathered_attention_chunk_kernel`` - ``_accumulate_indexed_attention_chunk_kernel`` [autotuned in #1] - ``_accumulate_fp8ds_global_slots_attention_chunk_kernel`` [autotuned in #1] - ``_accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel`` [decode] - ``_accumulate_fp8ds_paged_attention_chunk_kernel`` [autotuned in #1] - ``_accumulate_fp8ds_paged_attention_chunk_multihead_kernel`` [decode] CUDA-graph safety: ``lens_ptr`` / ``gather_lens_ptr`` are stable addresses; their values are refreshed per call by the metadata builder (outside the captured forward) and by ``combine_topk_swa_indices`` (inside the forward but writing only into the persistent buffers the accumulate kernels read from). The kernel inner-loop bound is a runtime-loaded scalar — Triton compiles a dynamic loop and the captured launch picks up the current value on each replay. Savings scale with ``combined_topk_buffer_width - actual valid length`` (i.e. mostly visible at long ``max_model_len`` with shorter actual contexts). At our test shape (``max_model_len=131072``, ISL=2048) the saved iterations come mostly from the decode multihead path; expected to be neutral / no-regression at short ``max_model_len`` where the bound equals ``num_candidates``. Signed-off-by: jasl --- .../backends/mla/sparse_mla_kernels.py | 69 +++++++++++++++++-- 1 file changed, 62 insertions(+), 7 deletions(-) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index ecfd12ae3d1e..5ccdd2d3ef7f 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -1128,8 +1128,13 @@ def _accumulate_gathered_attention_chunk_kernel( running_denom = tl.load(denom_ptr + state_offset) running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) valid_len = tl.load(lens_ptr + token_idx) + # Per-token early-loop-exit (see indexed kernel comment). + local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) - for candidate_idx in range(0, num_candidates): + for candidate_idx in range(0, local_eff): is_valid = (candidate_offset + candidate_idx) < valid_len if HAS_SLOT_IDS: slot_id = tl.load( @@ -1289,8 +1294,21 @@ def _accumulate_indexed_attention_chunk_kernel( running_denom = tl.load(denom_ptr + state_offset) running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) valid_len = tl.load(lens_ptr + token_idx) + # Per-token early-loop-exit: the combine_topk_swa_indices kernel writes + # ``[topk_len_t | swa_len_t | -1 padding]`` and stores + # ``lens[t] = topk_len_t + swa_len_t``. The existing ``is_valid`` guard + # already gates the heavy work past ``valid_len``, but the outer loop + # still iterates the full ``num_candidates`` (= chunk width). Capping + # the loop at ``min(num_candidates, valid_len - candidate_offset)`` + # saves the per-iteration index load + compare overhead on the dead + # tail. CUDA-graph-safe because ``lens_ptr`` is a stable address and + # the loaded value updates per call from the metadata builder. + local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) - for candidate_idx in range(0, num_candidates): + for candidate_idx in range(0, local_eff): kv_index = tl.load( indices_ptr + token_idx * stride_indices_t @@ -1445,12 +1463,17 @@ def _accumulate_fp8ds_global_slots_attention_chunk_kernel( running_denom = tl.load(denom_ptr + state_offset) running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) valid_len = tl.load(lens_ptr + token_idx) + # Per-token early-loop-exit (see indexed kernel comment). + local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) fp8_mask = offsets < fp8_dim rope_mask = (offsets >= fp8_dim) & dim_mask rope_offsets = tl.maximum(offsets - fp8_dim, 0) - for candidate_idx in range(0, num_candidates): + for candidate_idx in range(0, local_eff): slot_id = tl.load( slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c ) @@ -1645,12 +1668,21 @@ def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( tl.float32 ) valid_len = tl.load(lens_ptr + token_idx) + # Per-token early-loop-exit: ``lens[t] = topk_len_t + swa_len_t`` (set + # by combine_topk_swa_indices). Iterating past ``valid_len`` only + # incurs the per-iter index-load + compare cost on padding-tail; cap + # the outer loop at ``valid_len - candidate_offset`` to skip the dead + # tail. CUDA-graph-safe because ``lens_ptr`` is a stable address. + local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) fp8_mask = dim_offsets < fp8_dim rope_mask = (dim_offsets >= fp8_dim) & dim_mask rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) - for candidate_idx in range(0, num_candidates): + for candidate_idx in range(0, local_eff): slot_id = tl.load( slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c ) @@ -1851,8 +1883,13 @@ def _accumulate_fp8ds_paged_attention_chunk_kernel( fp8_mask = offsets < fp8_dim rope_mask = (offsets >= fp8_dim) & dim_mask rope_offsets = tl.maximum(offsets - fp8_dim, 0) + # Per-token early-loop-exit (see indexed kernel comment). + local_eff = tl.minimum( + num_candidates, + tl.maximum(gather_len - candidate_offset, 0), + ) - for candidate_idx in range(0, num_candidates): + for candidate_idx in range(0, local_eff): gather_idx = candidate_offset + candidate_idx is_valid = gather_idx < gather_len @@ -2054,8 +2091,17 @@ def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( fp8_mask = dim_offsets < fp8_dim rope_mask = (dim_offsets >= fp8_dim) & dim_mask rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + # Per-token early-loop-exit: ``gather_len`` is the per-token count of + # cached entries available for this paged read; the existing + # ``is_valid`` guard skips heavy work past that, but we can also skip + # the per-iter index load + branch by capping the loop. CUDA-graph- + # safe because ``gather_lens_ptr`` is a stable address. + local_eff = tl.minimum( + num_candidates, + tl.maximum(gather_len - candidate_offset, 0), + ) - for candidate_idx in range(0, num_candidates): + for candidate_idx in range(0, local_eff): gather_idx = candidate_offset + candidate_idx is_valid = gather_idx < gather_len @@ -2247,8 +2293,17 @@ def _fp8ds_paged_attention_with_sink_multihead_kernel( fp8_mask = dim_offsets < fp8_dim rope_mask = (dim_offsets >= fp8_dim) & dim_mask rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + # Per-token early-loop-exit: ``gather_len`` is the per-token count of + # cached entries available for this paged read; the existing + # ``is_valid`` guard skips heavy work past that, but we can also skip + # the per-iter index load + branch by capping the loop. CUDA-graph- + # safe because ``gather_lens_ptr`` is a stable address. + local_eff = tl.minimum( + num_candidates, + tl.maximum(gather_len - candidate_offset, 0), + ) - for candidate_idx in range(0, num_candidates): + for candidate_idx in range(0, local_eff): gather_idx = candidate_offset + candidate_idx is_valid = gather_idx < gather_len if is_valid: From 47cc4c363308f83fdc9970306518e4308237f856 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 15 May 2026 23:55:17 +0800 Subject: [PATCH 029/131] =?UTF-8?q?sm12x:=20docs=20cleanup=20pass=201=20?= =?UTF-8?q?=E2=80=94=20clarify=20metadata=20+=20MLA=20manager=20docstrings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three pure comment/docstring fixes from the audit, no behavior change: 1. ``_build_c128a_topk_metadata_kernel`` comment was ambiguous about ``max_compressed_tokens`` after the parameter was renamed to ``effective_topk`` in 304944e29. Reword to explicitly point at the Python caller (``build_c128a_topk_metadata``) and explain that ``max_compressed_tokens`` is the buffer column width and entries past ``effective_topk`` stay at ``-1`` via the caller's ``fill_(-1)`` pre-pass. 2. Add an inline note next to ``positions.max().item()`` flagging it as a host sync that is safe here because the builder runs outside the captured forward. 3. Expand ``MLAAttentionManager`` class docstring: the predicate ``_should_protect_prompt_blocks`` triggers on three independent conditions (DSv4 model_version, fp8_ds_mla cache_dtype_str, or compress_ratio > 1), not just DSv4. Document the three conditions inline so a future tightening pass does not accidentally narrow the coverage. Signed-off-by: jasl --- vllm/v1/attention/backends/mla/flashmla_sparse.py | 13 +++++++++---- vllm/v1/core/single_type_kv_cache_manager.py | 14 +++++++++++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 0bc2c89bebff..6835fc22a4e1 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -949,6 +949,9 @@ def build_c128a_topk_metadata( if num_prefill_tokens > 0: prefill_buffer[:num_prefill_tokens].fill_(-1) + # ``.item()`` is a host sync, but this builder runs in metadata + # build (outside the captured forward) so it is harmless w.r.t. + # cudagraph capture/replay (see the comment block above). max_pos = int(positions.max().item()) max_num_compressed = min( max((max_pos + 1) // compress_ratio, 0), @@ -1006,10 +1009,12 @@ def _build_c128a_topk_metadata_kernel( ): # ``effective_topk`` is the BLOCK_SIZE-aligned cap that covers every # in-flight token's ``num_compressed`` (computed by the Python - # builder). The caller pre-fills the active buffer slice with ``-1`` - # so entries in ``[effective_topk, max_compressed_tokens)`` remain - # ``-1`` (the sentinel the downstream sparse MLA accumulate kernels - # use to skip invalid candidates). + # caller in ``build_c128a_topk_metadata``). The buffer columns + # extend out to the Python-side ``max_compressed_tokens`` width; + # entries past ``effective_topk`` are left at ``-1`` by the caller's + # ``fill_(-1)`` pre-pass so the downstream sparse MLA accumulate + # kernels treat them as invalid via their ``kv_index >= 0`` / + # ``slot_id >= 0`` sentinel checks. token_idx = tl.program_id(0) position = tl.load(positions_ptr + token_idx) num_compressed = (position + 1) // compress_ratio diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index b4ecff696de0..042d149393d3 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -677,9 +677,21 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: class MLAAttentionManager(FullAttentionManager): - """KV cache manager for DeepSeek V4 compressed MLA cache.""" + """KV cache manager for compressed / fp8 MLA cache layouts. + + Used by any MLA spec whose hit semantics need prompt-block + protection across decode and unrelated cache churn. ``_should_ + protect_prompt_blocks`` enumerates the triggering conditions. + """ def _should_protect_prompt_blocks(self) -> bool: + # Three independent triggers: + # 1. ``model_version == "deepseek_v4"``: DSv4 explicitly opts in. + # 2. ``cache_dtype_str == "fp8_ds_mla"``: fp8 DeepSeek-style + # MLA cache; protection is needed for the same hybrid-align + # reuse pattern. + # 3. ``compress_ratio > 1``: any compressed MLA cache (today + # only DSv4 sets ``compress_ratio > 1``; V3.2 keeps it at 1). return ( self.kv_cache_spec.model_version == "deepseek_v4" or self.kv_cache_spec.cache_dtype_str == "fp8_ds_mla" From dbed3068f8f7b41a63a78546ac867d74c9ee0e55 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 16 May 2026 00:02:21 +0800 Subject: [PATCH 030/131] =?UTF-8?q?sm12x:=20docs=20cleanup=20pass=202=20?= =?UTF-8?q?=E2=80=94=20dedupe=20=5Fupcast=5Fe8m0=5Fto=5Ffp32=20+=20simplif?= =?UTF-8?q?y=20skip=5Fweight=5Fname?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two refactors from the audit, no behavior change: 1. ``vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py`` had its own copy of ``_upcast_e8m0_to_fp32`` (4 lines, identical to the canonical helper at ``vllm/model_executor/layers/quantization/utils/fp8_utils. py:1017``). Other peer call sites (cutlass.py, rocm_aiter_mla_sparse. py, mxfp4.py) already import from ``fp8_utils``; do the same here. 2. ``DeepseekV4ForCausalLM.skip_weight_name_before_load`` used ``hf_to_vllm_mapper.apply_list([name])`` to map a single name. That builds a one-element list and routes through a list-comprehension that filters ``None``. Use the canonical 1-to-1 helper ``WeightsMapper._map_name`` directly, matching the pattern used in ``compressed_tensors.py``, ``adapters.py``, ``bitsandbytes_loader. py``, and ``lora/utils.py``. Same semantics, 3 lines instead of 5. Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/model.py | 6 ++---- vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py | 9 +++------ 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index ea09509bd4c2..511acaff55c5 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -1375,10 +1375,8 @@ def get_mtp_target_hidden_states(self) -> torch.Tensor | None: return getattr(self.model, "_mtp_hidden_buffer", None) def skip_weight_name_before_load(self, name: str) -> bool: - mapped_names = self.hf_to_vllm_mapper.apply_list([name]) - if not mapped_names: - return True - return all("mtp." in mapped_name for mapped_name in mapped_names) + mapped = self.hf_to_vllm_mapper._map_name(name) + return mapped is None or "mtp." in mapped def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_substrs=["mtp."]) diff --git a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py index 6c2d2797333c..ec256eb03392 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py +++ b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py @@ -5,18 +5,15 @@ import torch from vllm.distributed import get_tensor_model_parallel_rank +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, +) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import fp8_einsum from vllm.utils.torch_utils import direct_register_custom_op -def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: - exp_bits = scale.view(torch.uint8).to(torch.int32) - fp32_bits = exp_bits << 23 - return fp32_bits.view(torch.float32) - - @triton.jit def _deepseek_v4_sm12x_fp8_einsum_kernel( a_ptr, From 1cf15d222c774ff6640bc00d03c58d65b04e2e00 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 16 May 2026 00:14:54 +0800 Subject: [PATCH 031/131] =?UTF-8?q?sm12x:=20docs=20cleanup=20pass=203=20?= =?UTF-8?q?=E2=80=94=20drop=20tautological=20is=5Fvalid=20in=207=20accumul?= =?UTF-8?q?ate=20kernels?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After ``a94e7c289 sm12x: per-token early-loop-exit on sparse MLA accumulate inner candidate loop`` capped each inner loop at ``local_eff = min(num_candidates, max(valid_len - candidate_offset, 0))`` (or the ``gather_len`` equivalent for the paged kernels), the per-iter check ``(candidate_offset + candidate_idx) < valid_len`` / ``gather_idx < gather_len`` became structurally always true: by construction every iteration's index sits inside the valid range. This commit drops the tautological term in 7 sparse MLA accumulate kernels and leaves the remaining cell-sentinel guard: - ``accumulate_..._gathered_chunk`` (was: ``(...) < valid_len`` then AND with ``slot_id >= 0`` when ``HAS_SLOT_IDS``): now just ``is_valid = slot_id >= 0`` (or ``True`` when ``HAS_SLOT_IDS`` is false). The branch on ``HAS_SLOT_IDS`` becomes a ``tl.constexpr`` binary, which Triton compiles into two clean specialisations. - ``accumulate_..._indexed_chunk``: ``is_valid = kv_index >= 0``. - ``accumulate_fp8ds_global_slots_sparse_mla_attention_chunk{,_multihead}``: ``is_valid = slot_id >= 0``. - ``accumulate_fp8ds_paged_sparse_mla_attention_chunk{,_multihead, _multihead_with_sink}``: there is no per-cell sentinel here, so the whole ``is_valid`` variable and ``if is_valid:`` guard go away and the loop body becomes unconditional. Each touched site gains a 2-3 line comment explaining the invariant so a future reader can see why no per-iter clamp is needed. No behavioral change: Triton was already eliminating the tautology after the SSA pass; this commit makes the intent explicit at the source level. Signed-off-by: jasl --- .../backends/mla/sparse_mla_kernels.py | 275 +++++++++--------- 1 file changed, 143 insertions(+), 132 deletions(-) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 5ccdd2d3ef7f..8985d53f9e51 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -1134,13 +1134,17 @@ def _accumulate_gathered_attention_chunk_kernel( tl.maximum(valid_len - candidate_offset, 0), ) + # ``candidate_offset + candidate_idx < valid_len`` is structurally + # guaranteed by the ``local_eff`` cap above, so the only remaining + # gate is ``slot_id >= 0`` when slot ids are present. for candidate_idx in range(0, local_eff): - is_valid = (candidate_offset + candidate_idx) < valid_len if HAS_SLOT_IDS: slot_id = tl.load( slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c ) - is_valid = is_valid & (slot_id >= 0) + is_valid = slot_id >= 0 + else: + is_valid = True if is_valid: kv = tl.load( @@ -1308,13 +1312,16 @@ def _accumulate_indexed_attention_chunk_kernel( tl.maximum(valid_len - candidate_offset, 0), ) + # ``candidate_offset + candidate_idx < valid_len`` is structurally + # guaranteed by the ``local_eff`` cap above; only the per-cell + # sentinel check (``kv_index >= 0``) is still meaningful. for candidate_idx in range(0, local_eff): kv_index = tl.load( indices_ptr + token_idx * stride_indices_t + candidate_idx * stride_indices_c ) - is_valid = ((candidate_offset + candidate_idx) < valid_len) & (kv_index >= 0) + is_valid = kv_index >= 0 if is_valid: kv = tl.load( @@ -1473,11 +1480,14 @@ def _accumulate_fp8ds_global_slots_attention_chunk_kernel( rope_mask = (offsets >= fp8_dim) & dim_mask rope_offsets = tl.maximum(offsets - fp8_dim, 0) + # ``candidate_offset + candidate_idx < valid_len`` is structurally + # guaranteed by the ``local_eff`` cap above; only the per-cell + # sentinel check (``slot_id >= 0``) is still meaningful. for candidate_idx in range(0, local_eff): slot_id = tl.load( slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c ) - is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + is_valid = slot_id >= 0 if is_valid: block_idx = slot_id // cache_block_size @@ -1682,11 +1692,14 @@ def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( rope_mask = (dim_offsets >= fp8_dim) & dim_mask rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + # ``candidate_offset + candidate_idx < valid_len`` is structurally + # guaranteed by the ``local_eff`` cap above; only the per-cell + # sentinel check (``slot_id >= 0``) is still meaningful. for candidate_idx in range(0, local_eff): slot_id = tl.load( slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c ) - is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + is_valid = slot_id >= 0 if is_valid: block_idx = slot_id // cache_block_size @@ -1889,51 +1902,50 @@ def _accumulate_fp8ds_paged_attention_chunk_kernel( tl.maximum(gather_len - candidate_offset, 0), ) + # ``gather_idx < gather_len`` is structurally guaranteed by the + # ``local_eff`` cap above; the body is unconditional. for candidate_idx in range(0, local_eff): gather_idx = candidate_offset + candidate_idx - is_valid = gather_idx < gather_len - - if is_valid: - pos = start_pos + gather_idx - block_in_seq = pos // cache_block_size - pos_in_block = pos % cache_block_size - physical_block = tl.load( - block_table_ptr + token_idx * stride_block_table_t + block_in_seq - ) - cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride - token_data_ptr = cache_block_ptr + pos_in_block * token_data_size - token_scale_ptr = ( - cache_block_ptr - + cache_block_size * token_data_size - + pos_in_block * scale_dim - ) + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) - x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) - x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) - x_float = x_fp8.to(tl.float32) - scale_offsets = offsets // quant_block - encoded_scale = tl.load( - token_scale_ptr + scale_offsets, - mask=fp8_mask, - other=127, - ) - dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) - x_dequant = x_float * dequant_scale + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale - rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) - rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( - tl.float32 - ) - kv = tl.where(fp8_mask, x_dequant, rope) - kv = tl.where(dim_mask, kv, 0.0) + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) - score = tl.sum(q * kv, axis=0) * scale - next_max = tl.maximum(running_max, score) - previous_weight = tl.exp(running_max - next_max) - candidate_weight = tl.exp(score - next_max) - running_acc = running_acc * previous_weight + kv * candidate_weight - running_denom = running_denom * previous_weight + candidate_weight - running_max = next_max + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max tl.store(max_score_ptr + state_offset, running_max) tl.store(denom_ptr + state_offset, running_denom) @@ -2101,54 +2113,53 @@ def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( tl.maximum(gather_len - candidate_offset, 0), ) + # ``gather_idx < gather_len`` is structurally guaranteed by the + # ``local_eff`` cap above; the body is unconditional. for candidate_idx in range(0, local_eff): gather_idx = candidate_offset + candidate_idx - is_valid = gather_idx < gather_len - - if is_valid: - pos = start_pos + gather_idx - block_in_seq = pos // cache_block_size - pos_in_block = pos % cache_block_size - physical_block = tl.load( - block_table_ptr + token_idx * stride_block_table_t + block_in_seq - ) - cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride - token_data_ptr = cache_block_ptr + pos_in_block * token_data_size - token_scale_ptr = ( - cache_block_ptr - + cache_block_size * token_data_size - + pos_in_block * scale_dim - ) - - x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) - x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) - x_float = x_fp8.to(tl.float32) - scale_offsets = dim_offsets // quant_block - encoded_scale = tl.load( - token_scale_ptr + scale_offsets, - mask=fp8_mask, - other=127, - ) - dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) - x_dequant = x_float * dequant_scale + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) - rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) - rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( - tl.float32 - ) - kv = tl.where(fp8_mask, x_dequant, rope) - kv = tl.where(dim_mask, kv, 0.0) + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale - score = tl.sum(q * kv[None, :], axis=1) * scale - next_max = tl.maximum(running_max, score) - previous_weight = tl.exp(running_max - next_max) - candidate_weight = tl.exp(score - next_max) - running_acc = ( - running_acc * previous_weight[:, None] - + kv[None, :] * candidate_weight[:, None] - ) - running_denom = running_denom * previous_weight + candidate_weight - running_max = next_max + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) @@ -2303,53 +2314,53 @@ def _fp8ds_paged_attention_with_sink_multihead_kernel( tl.maximum(gather_len - candidate_offset, 0), ) + # ``gather_idx < gather_len`` is structurally guaranteed by the + # ``local_eff`` cap above; the body is unconditional. for candidate_idx in range(0, local_eff): gather_idx = candidate_offset + candidate_idx - is_valid = gather_idx < gather_len - if is_valid: - pos = start_pos + gather_idx - block_in_seq = pos // cache_block_size - pos_in_block = pos % cache_block_size - physical_block = tl.load( - block_table_ptr + token_idx * stride_block_table_t + block_in_seq - ) - cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride - token_data_ptr = cache_block_ptr + pos_in_block * token_data_size - token_scale_ptr = ( - cache_block_ptr - + cache_block_size * token_data_size - + pos_in_block * scale_dim - ) - - x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) - x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) - x_float = x_fp8.to(tl.float32) - scale_offsets = dim_offsets // quant_block - encoded_scale = tl.load( - token_scale_ptr + scale_offsets, - mask=fp8_mask, - other=127, - ) - dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) - x_dequant = x_float * dequant_scale + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) - rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) - rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( - tl.float32 - ) - kv = tl.where(fp8_mask, x_dequant, rope) - kv = tl.where(dim_mask, kv, 0.0) + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale - score = tl.sum(q * kv[None, :], axis=1) * scale - next_max = tl.maximum(running_max, score) - previous_weight = tl.exp(running_max - next_max) - candidate_weight = tl.exp(score - next_max) - running_acc = ( - running_acc * previous_weight[:, None] - + kv[None, :] * candidate_weight[:, None] - ) - running_denom = running_denom * previous_weight + candidate_weight - running_max = next_max + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) has_tokens = running_denom > 0.0 From 922e64b0dafc844daa240790a3a6029cab87a47b Mon Sep 17 00:00:00 2001 From: alexbi29 Date: Sat, 16 May 2026 06:58:46 +0000 Subject: [PATCH 032/131] sm12x: multi-head prefill accumulate kernel + drop fp8 einsum autotune MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two prefill performance fixes for SM12x DeepSeek V4: 1. Add _accumulate_indexed_attention_chunk_multihead_kernel (HEAD_BLOCK=8) that loads KV once per candidate and reuses across 8 heads, reducing L2 traffic in the prefill accumulate phase. Same pattern as the existing decode _finish_materialized_scores_with_sink_kernel. Prefill throughput on 2× RTX PRO 6000 WS, TP=2, MTP=2: - 1K tokens: +49% (2,746 → 4,100 tok/s) - 4.5K tokens: +37% (3,122 → 4,271 tok/s) - 18K tokens: +36% (2,474 → 3,360 tok/s) - 64K tokens: +28% (1,679 → 2,146 tok/s) Tuned config: HEAD_BLOCK=8, num_warps=4, num_stages=2. Benchmarked against HEAD_BLOCK=4 and num_warps=8 variants — HEAD_BLOCK=8 with num_warps=4 wins at all sizes. 2. Drop @triton.autotune from _deepseek_v4_sm12x_fp8_einsum_kernel and pin num_warps=4, num_stages=3. The autotune key included num_tokens which varies per request, causing ~200 unique keys with zero cache hits — re-benchmarking 4 configs at ~1s each on every request. Co-Authored-By: Claude Opus 4.6 (cherry picked from commit 9c2e7ca85c91b1bfbcf0d04ba5156d4233460592) Signed-off-by: jasl --- .../deepseek_v4/nvidia/ops/fp8_einsum.py | 8 - .../backends/mla/sparse_mla_kernels.py | 215 +++++++++++++++--- 2 files changed, 186 insertions(+), 37 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py index ec256eb03392..31353e62748f 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py +++ b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py @@ -171,14 +171,6 @@ def deepseek_v4_sm12x_fp8_einsum( BLOCK_TOKENS=block_tokens, BLOCK_OUT=block_out, BLOCK_HIDDEN=block_hidden, - # Pinned to the SM12x-optimal config: a previous ``@triton.autotune`` - # block selected from {num_warps in {4,8}, num_stages in {2,3}} with - # key=["num_tokens", ...]. ``num_tokens`` varies per request, so the - # autotune cache missed every call and the 4-config bench replayed - # on every shape — pure overhead. The other three keys are - # model-architecture-fixed, so the same config (num_warps=4, - # num_stages=3) always won; we pin it directly. Reported by - # ``alexbi29`` in PR #41834 comment 4464750956. num_warps=4, num_stages=3, ) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 8985d53f9e51..4afa5c0aed4b 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -1344,6 +1344,128 @@ def _accumulate_indexed_attention_chunk_kernel( tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) +_PREFILL_INDEXED_HEAD_BLOCK = 8 + + +@triton.jit +def _accumulate_indexed_attention_chunk_multihead_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(tl.float32) + + state_base = token_idx * stride_state_t + running_max = tl.load( + max_score_ptr + state_base + head_offsets * stride_state_h, + mask=head_mask, + other=float("-inf"), + ) + running_denom = tl.load( + denom_ptr + state_base + head_offsets * stride_state_h, + mask=head_mask, + other=0.0, + ) + acc_base = token_idx * stride_acc_t + running_acc = tl.load( + acc_ptr + + acc_base + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(tl.float32) + + valid_len = tl.load(lens_ptr + token_idx) + local_eff = tl.minimum( + num_candidates, + tl.maximum(valid_len - candidate_offset, 0), + ) + + for candidate_idx in range(0, local_eff): + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = kv_index >= 0 + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + scores = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, scores) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(scores - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store( + max_score_ptr + state_base + head_offsets * stride_state_h, + running_max, + mask=head_mask, + ) + tl.store( + denom_ptr + state_base + head_offsets * stride_state_h, + running_denom, + mask=head_mask, + ) + tl.store( + acc_ptr + + acc_base + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + running_acc, + mask=head_mask[:, None] & dim_mask[None, :], + ) + + def accumulate_indexed_sparse_mla_attention_chunk( q: torch.Tensor, kv_flat: torch.Tensor, @@ -1379,35 +1501,70 @@ def accumulate_indexed_sparse_mla_attention_chunk( num_heads = max_score.shape[1] num_candidates = indices.shape[1] block_d = min(1024, triton.next_power_of_2(head_dim)) - grid = (num_tokens, num_heads) - _accumulate_indexed_attention_chunk_kernel[grid]( - q, - kv_flat, - indices, - lens, - max_score, - denom, - acc, - q.stride(0), - q.stride(1), - q.stride(2), - kv_flat.stride(0), - kv_flat.stride(1), - indices.stride(0), - indices.stride(1), - max_score.stride(0), - max_score.stride(1), - acc.stride(0), - acc.stride(1), - acc.stride(2), - num_heads, - head_dim, - num_candidates, - candidate_offset, - scale, - BLOCK_D=block_d, - # num_warps / num_stages supplied by @triton.autotune above. - ) + head_block = _PREFILL_INDEXED_HEAD_BLOCK + + if num_heads >= head_block: + grid = (num_tokens, triton.cdiv(num_heads, head_block)) + _accumulate_indexed_attention_chunk_multihead_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block, + BLOCK_D=block_d, + num_warps=4, + num_stages=2, + ) + else: + grid = (num_tokens, num_heads) + _accumulate_indexed_attention_chunk_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + ) + @triton.autotune( From 9c37c22a2fe3b1f67ecec0638b9cdd41d00d0ed9 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 18 May 2026 07:05:17 +0800 Subject: [PATCH 033/131] =?UTF-8?q?sm12x:=20add=20fused-MoE=20FP8=20W8A8?= =?UTF-8?q?=20Block=20configs=20for=20RTX=20PRO=206000=20(4=20shapes=20?= =?UTF-8?q?=C3=97=203=20variants)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tuned with Triton 3.6.0 at vllm@c92696943 on NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition, 10 M-buckets per shape (1, 2, 4, 8, 16, 32, 64, 128, 256, 512), 640-config search space filtered to BLOCK_SIZE_M >= M/8 for M >= 64. Four typical SM12x DSv4-Flash deployment shapes — none had prior tuning in the vLLM tree at this revision: E=128, N=2048 (TP=2 + EP, production shape, 2x RTX PRO 6000) E=64, N=2048 (TP=4 + EP, 4x RTX PRO 6000) — tune ~1h47m E=32, N=2048 (TP=8 + EP, 8x RTX PRO 6000) — tune ~1h00m E=256, N=1024 (TP=2 no-EP fallback, 2x RTX PRO 6000) — tune ~2h14m Aliased identically to the Max-Q Workstation Edition and Server Edition variants since they share the same silicon (GB202) as the Workstation Edition and only differ in power/form-factor envelope; copying yields the same Triton autotune optima. Signed-off-by: jasl --- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 +++++++++++++++++++ 12 files changed, 996 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..06431d4d355e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..06431d4d355e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..06431d4d355e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..4b98ba105c14 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..4b98ba105c14 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..4b98ba105c14 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..952b908297b7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..952b908297b7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..952b908297b7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..36d0361b926f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..36d0361b926f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..36d0361b926f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,83 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} From ff98c4b5836f79ac5b10be7687ff9662e4e7850e Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 18 May 2026 20:57:22 +0800 Subject: [PATCH 034/131] sm120: use Triton MQA logits for direct topk fallback Route the direct FP8 MQA top-k fallback through the existing Triton logits kernel when the materialized logits fit within a bounded workspace, then keep the previous PyTorch path as the fallback for larger or unsupported shapes. This removes the 127K prefill PyTorch/CUTLASS logits hotspot on RTX PRO 6000 while preserving the short-context path. Local SM120 validation showed 127K C=1 cold TTFT mean improving from 60.83s to 37.65s with no short-context regression. Inspired-by: vllm-project/vllm#41834 comment 4476480477 Signed-off-by: jasl --- .../test_sm120_deepgemm_fallbacks.py | 66 ++++++++++++++++++- .../nvidia/ops/sm12x_deep_gemm_fallbacks.py | 45 +++++++++++++ 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py index fa8a31b3676b..720b4d171434 100644 --- a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -11,7 +11,10 @@ ) from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv -from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks +from vllm.v1.attention.ops.deepseek_v4_ops import ( + sm12x_deep_gemm_fallbacks, + sm12x_mqa, +) def _make_indexer_kv_cache( @@ -94,6 +97,67 @@ def test_decode_topk_logits_width_keeps_topk_kernel_width(): assert _decode_topk_logits_width(0, 128, 512) == 0 +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_mqa_direct_topk_uses_triton_logits_when_logits_fit( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(11) + num_q, num_heads, head_dim = 8, 4, 64 + seq_len_kv, topk_tokens = 17, 5 + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES", + num_q * seq_len_kv * torch.float32.itemsize, + ) + + q = torch.randn(num_q, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + kv_scale = kv.abs().float().amax(dim=-1).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()[:, None]).to(torch.float8_e4m3fn) + weights = torch.randn(num_q, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(num_q, device="cuda", dtype=torch.int32) % 3 + cu_seqlen_ke = torch.full( + (num_q,), seq_len_kv, device="cuda", dtype=torch.int32 + ) + out = torch.empty(num_q, topk_tokens, device="cuda", dtype=torch.int32) + + original_triton = sm12x_mqa.fp8_mqa_logits_triton + calls = 0 + + def wrapped_triton(*args, **kwargs): + nonlocal calls + calls += 1 + return original_triton(*args, **kwargs) + + monkeypatch.setattr(sm12x_mqa, "fp8_mqa_logits_triton", wrapped_triton) + + assert deep_gemm_utils.fp8_fp4_mqa_topk_indices( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + out, + ) + assert calls == 1 + + reference_logits = sm12x_deep_gemm_fallbacks._fp8_mqa_logits_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + expected = torch.topk(reference_logits, topk_tokens, dim=1).indices.to( + torch.int32 + ) + torch.testing.assert_close(out, expected, rtol=0, atol=0) + + @pytest.mark.skipif( not current_platform.is_device_capability_family(120), reason="SM120 only" ) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py index 8fa9b2697c91..1c42b7674aa9 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -10,6 +10,7 @@ logger = init_logger(__name__) _SM120_MQA_LOGITS_MAX_SCORE_BYTES = 64 * 1024 * 1024 +_SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES = 512 * 1024 * 1024 _SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192 @@ -209,6 +210,41 @@ def _fp8_mqa_logits_topk_torch( return out +def _fp8_mqa_logits_topk_triton( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor, +) -> bool: + q_values, q_scale = q + k_values, _ = kv + if not (q_scale is None and q_values.dim() == 3 and k_values.dim() == 2): + return False + + logits_bytes = q_values.shape[0] * k_values.shape[0] * torch.float32.itemsize + if logits_bytes > _SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES: + return False + + from vllm.v1.attention.ops.deepseek_v4_ops.sm12x_mqa import ( + fp8_mqa_logits_triton, + ) + + logits = fp8_mqa_logits_triton(q_values, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + topk_tokens = out.shape[1] + select_k = min(topk_tokens, logits.shape[1]) + out.fill_(-1) + if select_k == 0: + return True + + values, indices = torch.topk(logits, select_k, dim=1) + selected = out[:, :select_k] + selected.copy_(indices.to(torch.int32)) + selected.masked_fill_(~torch.isfinite(values), -1) + return True + + def fp8_fp4_mqa_topk_indices( q: tuple[torch.Tensor, torch.Tensor | None], kv: tuple[torch.Tensor, torch.Tensor], @@ -224,6 +260,15 @@ def fp8_fp4_mqa_topk_indices( and q[1] is None ): return False + if _fp8_mqa_logits_topk_triton( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices, + ): + return True _fp8_mqa_logits_topk_torch( q, kv, From ef37fc536d9e0b35ac3feeb0d72c50dc31c2f750 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 18 May 2026 21:24:18 +0800 Subject: [PATCH 035/131] sm120: use custom row topk for MQA fallback indices Use the existing top_k_per_row_prefill CUDA op after the SM120 Triton MQA logits fallback materializes logits. The op writes int32 indices directly and respects row bounds, avoiding torch.topk plus int64-to-int32 copy in the long-context prefill path. On RTX PRO 6000 TP=2, 127K C=1 cold TTFT mean improved from 37.65s to 36.87s with no short-context regression across C=1/2/4. Signed-off-by: jasl --- .../test_sm120_deepgemm_fallbacks.py | 50 +++++++++++++++++-- .../nvidia/ops/sm12x_deep_gemm_fallbacks.py | 33 ++++++++++-- 2 files changed, 76 insertions(+), 7 deletions(-) diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py index 720b4d171434..e617097fe6dd 100644 --- a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -83,6 +83,35 @@ def _reference_paged_mqa_logits( return logits +def _assert_topk_indices_match_values( + logits: torch.Tensor, + actual: torch.Tensor, + topk_tokens: int, +) -> None: + for row_idx in range(logits.shape[0]): + valid_count = int(torch.isfinite(logits[row_idx]).sum().item()) + row_topk = min(topk_tokens, valid_count) + expected = torch.topk(logits[row_idx], row_topk, dim=0).indices.to(torch.int32) + actual_row = actual[row_idx, :row_topk] + assert torch.all(actual_row >= 0) + assert torch.all(torch.isfinite(logits[row_idx, actual_row.long()])) + actual_values = torch.sort(logits[row_idx, actual_row.long()]).values + expected_values = torch.sort(logits[row_idx, expected.long()]).values + torch.testing.assert_close( + actual_values, + expected_values, + rtol=0, + atol=0, + ) + if row_topk < topk_tokens: + torch.testing.assert_close( + actual[row_idx, row_topk:], + torch.full_like(actual[row_idx, row_topk:], -1), + rtol=0, + atol=0, + ) + + def test_decode_logits_width_uses_active_context_bound(): assert _decode_logits_width(262144, 1024) == 1024 assert _decode_logits_width(4096, 8192) == 4096 @@ -133,6 +162,21 @@ def wrapped_triton(*args, **kwargs): return original_triton(*args, **kwargs) monkeypatch.setattr(sm12x_mqa, "fp8_mqa_logits_triton", wrapped_triton) + original_topk_op = sm12x_deep_gemm_fallbacks._top_k_per_row_prefill_op() + if original_topk_op is None: + pytest.skip("top_k_per_row_prefill op is unavailable") + topk_calls = 0 + + def wrapped_topk_op(*args, **kwargs): + nonlocal topk_calls + topk_calls += 1 + return original_topk_op(*args, **kwargs) + + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_top_k_per_row_prefill_op", + lambda: wrapped_topk_op, + ) assert deep_gemm_utils.fp8_fp4_mqa_topk_indices( (q_fp8, None), @@ -143,6 +187,7 @@ def wrapped_triton(*args, **kwargs): out, ) assert calls == 1 + assert topk_calls == 1 reference_logits = sm12x_deep_gemm_fallbacks._fp8_mqa_logits_torch( (q_fp8, None), @@ -152,10 +197,7 @@ def wrapped_triton(*args, **kwargs): cu_seqlen_ke, clean_logits=True, ) - expected = torch.topk(reference_logits, topk_tokens, dim=1).indices.to( - torch.int32 - ) - torch.testing.assert_close(out, expected, rtol=0, atol=0) + _assert_topk_indices_match_values(reference_logits, out, topk_tokens) @pytest.mark.skipif( diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py index 1c42b7674aa9..1e9d456760fc 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -14,6 +14,15 @@ _SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192 +def _top_k_per_row_prefill_op(): + try: + from vllm import _custom_ops as _custom_ops # noqa: F401 + + return torch.ops._C.top_k_per_row_prefill + except (AttributeError, ImportError, RuntimeError): + return None + + def _fp8_mqa_logits_head_chunk_size( seq_len: int, seq_len_kv: int, @@ -238,10 +247,28 @@ def _fp8_mqa_logits_topk_triton( if select_k == 0: return True - values, indices = torch.topk(logits, select_k, dim=1) selected = out[:, :select_k] - selected.copy_(indices.to(torch.int32)) - selected.masked_fill_(~torch.isfinite(values), -1) + topk_op = _top_k_per_row_prefill_op() + if topk_op is not None: + topk_op( + logits, + cu_seqlen_ks, + cu_seqlen_ke, + selected, + logits.shape[0], + logits.stride(0), + logits.stride(1), + select_k, + ) + selected.add_(cu_seqlen_ks[:, None]) + valid = (selected >= cu_seqlen_ks[:, None]) & ( + selected < cu_seqlen_ke[:, None] + ) + selected.masked_fill_(~valid, -1) + else: + values, indices = torch.topk(logits, select_k, dim=1) + selected.copy_(indices.to(torch.int32)) + selected.masked_fill_(~torch.isfinite(values), -1) return True From 2ebb26f80907cbb7a414db8030fef46d495f2c08 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 18 May 2026 23:21:43 +0800 Subject: [PATCH 036/131] sm120: widen FP8 MQA logits tile Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index 8593989e8647..70c8b6d6b521 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -147,7 +147,7 @@ def fp8_mqa_logits_triton( if num_q == 0 or seq_len_kv == 0: return logits - grid = (triton.cdiv(num_q, 8), triton.cdiv(seq_len_kv, 64)) + grid = (triton.cdiv(num_q, 8), triton.cdiv(seq_len_kv, 128)) _fp8_mqa_logits_kernel[grid]( q, k_fp8, @@ -170,7 +170,7 @@ def fp8_mqa_logits_triton( logits.stride(0), logits.stride(1), BLOCK_M=8, - BLOCK_N=64, + BLOCK_N=128, BLOCK_D=64, num_warps=4, ) From e072e2054d7e75eb136a79ef2d5cf5ca7d92a69e Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 19 May 2026 00:14:06 +0800 Subject: [PATCH 037/131] sm120: increase FP8 MQA logits row tile Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index 70c8b6d6b521..23e42cc07d25 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -147,7 +147,7 @@ def fp8_mqa_logits_triton( if num_q == 0 or seq_len_kv == 0: return logits - grid = (triton.cdiv(num_q, 8), triton.cdiv(seq_len_kv, 128)) + grid = (triton.cdiv(num_q, 16), triton.cdiv(seq_len_kv, 128)) _fp8_mqa_logits_kernel[grid]( q, k_fp8, @@ -169,7 +169,7 @@ def fp8_mqa_logits_triton( weights.stride(1), logits.stride(0), logits.stride(1), - BLOCK_M=8, + BLOCK_M=16, BLOCK_N=128, BLOCK_D=64, num_warps=4, From 1028229eae0fb4cc4a95e7df85486320aad52946 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 19 May 2026 09:23:35 +0800 Subject: [PATCH 038/131] Fix DeepSeek V4 MTP sparse SWA reordering Signed-off-by: jasl --- .../attention/test_deepseek_v4_sparse_swa.py | 42 +++++++++++++++++++ vllm/v1/attention/backends/mla/sparse_swa.py | 17 +++----- 2 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 tests/v1/attention/test_deepseek_v4_sparse_swa.py diff --git a/tests/v1/attention/test_deepseek_v4_sparse_swa.py b/tests/v1/attention/test_deepseek_v4_sparse_swa.py new file mode 100644 index 000000000000..158c5af8b1cd --- /dev/null +++ b/tests/v1/attention/test_deepseek_v4_sparse_swa.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from tests.v1.attention.utils import create_vllm_config +from vllm.config import SpeculativeConfig +from vllm.v1.attention.backends.mla.sparse_swa import ( + DeepseekSparseSWAMetadataBuilder, +) +from vllm.v1.kv_cache_interface import MLAAttentionSpec + + +def test_sparse_swa_reorder_threshold_matches_mtp_decode_threshold(): + vllm_config = create_vllm_config( + block_size=256, + hf_config_override={ + "sliding_window": 128, + "compress_ratios": [1, 4, 128], + }, + ) + vllm_config.speculative_config = SpeculativeConfig( + method="ngram", + num_speculative_tokens=2, + ) + kv_cache_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.bfloat16, + compress_ratio=4, + ) + + builder = DeepseekSparseSWAMetadataBuilder( + kv_cache_spec=kv_cache_spec, + layer_names=["dummy"], + vllm_config=vllm_config, + device=torch.device("cpu"), + ) + + assert builder.decode_threshold == 3 + assert builder.reorder_batch_threshold == builder.decode_threshold diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 3ff897761370..3788f59b1a20 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -223,18 +223,11 @@ def __init__(self, *args, **kwargs): self.compress_ratio = mla_spec.compress_ratio self.block_size = mla_spec.block_size - # Handle MTP: adjust decode_threshold like the indexer does - self.num_speculative_tokens = ( - self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config - else 0 - ) - # With MTP, decode can have query_len up to 1 + num_speculative_tokens. - # Must match the threshold used by the indexer and flashmla_sparse so - # that all backends agree on the decode/prefill split. - self.decode_threshold = ( - self.reorder_batch_threshold + self.num_speculative_tokens - ) + # Handle MTP: classify single-token queries plus speculative tokens as + # decodes, matching the runner-side batch reorder threshold. + self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) + assert self.reorder_batch_threshold is not None + self.decode_threshold = self.reorder_batch_threshold hf_config = self.vllm_config.model_config.hf_config assert hasattr(hf_config, "sliding_window") From c429c77dba62141f9698e292882b65d5e47c7642 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 19 May 2026 17:06:14 +0800 Subject: [PATCH 039/131] sm12x: update DeepSeek V4 fallback imports Signed-off-by: jasl --- tests/v1/attention/test_sm120_deepgemm_fallbacks.py | 2 +- .../nvidia/ops/sm12x_deep_gemm_fallbacks.py | 2 +- vllm/utils/deep_gemm.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py index e617097fe6dd..7bc8cc4df55e 100644 --- a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -11,7 +11,7 @@ ) from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv -from vllm.v1.attention.ops.deepseek_v4_ops import ( +from vllm.models.deepseek_v4.nvidia.ops import ( sm12x_deep_gemm_fallbacks, sm12x_mqa, ) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py index 1e9d456760fc..3293659a6895 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -236,7 +236,7 @@ def _fp8_mqa_logits_topk_triton( if logits_bytes > _SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES: return False - from vllm.v1.attention.ops.deepseek_v4_ops.sm12x_mqa import ( + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( fp8_mqa_logits_triton, ) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b6fac63e261..a1508aef1c40 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -374,7 +374,7 @@ def fp8_fp4_mqa_topk_indices( and q[1] is None ): return False - from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks return sm12x_deep_gemm_fallbacks.fp8_fp4_mqa_topk_indices( q, @@ -394,7 +394,7 @@ def _fp8_mqa_logits_sm12x( cu_seqlen_ke: torch.Tensor, clean_logits: bool, ) -> torch.Tensor: - from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks return sm12x_deep_gemm_fallbacks._fp8_mqa_logits_sm12x( q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits @@ -479,7 +479,7 @@ def _fp8_paged_mqa_logits_sm12x( block_tables: torch.Tensor, max_model_len: int, ) -> torch.Tensor: - from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks return sm12x_deep_gemm_fallbacks._fp8_paged_mqa_logits_sm12x( q, kv_cache, weights, context_lens, block_tables, max_model_len @@ -502,7 +502,7 @@ def fp8_fp4_paged_mqa_topk_indices( and q[1] is None ): return False - from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks return sm12x_deep_gemm_fallbacks.fp8_fp4_paged_mqa_topk_indices( q, @@ -580,7 +580,7 @@ def _tf32_hc_prenorm_gemm_sm12x( sqrsum: torch.Tensor, num_split: int, ) -> torch.Tensor: - from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + from vllm.models.deepseek_v4.nvidia.ops import sm12x_deep_gemm_fallbacks return sm12x_deep_gemm_fallbacks._tf32_hc_prenorm_gemm_sm12x( x, fn, out, sqrsum, num_split From 2fb99cceb9024af749e7faa3c8298ce389200dc3 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 19 May 2026 17:18:17 +0800 Subject: [PATCH 040/131] tests: update DeepSeek V4 MegaMoE refactor assumptions Signed-off-by: jasl --- tests/models/test_deepseek_v4_mega_moe.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/models/test_deepseek_v4_mega_moe.py b/tests/models/test_deepseek_v4_mega_moe.py index 3daae242d459..2e9b76f5c030 100644 --- a/tests/models/test_deepseek_v4_mega_moe.py +++ b/tests/models/test_deepseek_v4_mega_moe.py @@ -6,6 +6,7 @@ import pytest import torch +from vllm.config import CompilationConfig from vllm.models.deepseek_v4.nvidia.model import ( DeepseekV4MegaMoEExperts, make_deepseek_v4_expert_params_mapping, @@ -46,7 +47,8 @@ def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float(): def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): vllm_config = SimpleNamespace( - scheduler_config=SimpleNamespace(max_num_batched_tokens=4) + compilation_config=CompilationConfig(), + scheduler_config=SimpleNamespace(max_num_batched_tokens=4), ) experts = DeepseekV4MegaMoEExperts( vllm_config, @@ -111,7 +113,7 @@ def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.", ) def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact(): - from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8 + deep_gemm_utils = pytest.importorskip("deep_gemm.utils") device = torch.device("cuda") num_tokens = 7 @@ -150,7 +152,7 @@ def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact(): generator=generator, ) - ref_x, ref_x_sf = per_token_cast_to_fp8( + ref_x, ref_x_sf = deep_gemm_utils.per_token_cast_to_fp8( hidden_states, use_ue8m0=True, gran_k=32, From f30a7af7445fa413ba13a754576242c3deec7e4a Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 19 May 2026 19:54:45 +0800 Subject: [PATCH 041/131] Fix DeepSeek V4 MLA prompt cache protection Signed-off-by: jasl --- vllm/v1/core/single_type_kv_cache_manager.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 042d149393d3..2ccffaed18d8 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -705,11 +705,7 @@ def cache_blocks( alignment_tokens: int | None = None, ) -> None: super().cache_blocks(request, num_tokens, alignment_tokens=alignment_tokens) - if ( - not self._should_protect_prompt_blocks() - or num_tokens < request.num_prompt_tokens - or request.num_prompt_tokens <= 1 - ): + if not self._should_protect_prompt_blocks() or request.num_prompt_tokens <= 1: return max_cache_hit_length = request.num_prompt_tokens - 1 @@ -718,6 +714,8 @@ def cache_blocks( // self.cache_alignment_tokens * self.cache_alignment_tokens ) + if aligned_cache_hit_length <= 0 or num_tokens < aligned_cache_hit_length: + return num_hit_blocks = aligned_cache_hit_length // self.block_size if num_hit_blocks == 0: return @@ -963,9 +961,7 @@ def cache_blocks( alignment_tokens: int | None = None, ) -> None: super().cache_blocks(request, num_tokens, alignment_tokens=alignment_tokens) - if not self.enable_caching or num_tokens < request.num_prompt_tokens: - return - if request.num_prompt_tokens <= 1: + if not self.enable_caching or request.num_prompt_tokens <= 1: return max_cache_hit_length = request.num_prompt_tokens - 1 @@ -974,7 +970,7 @@ def cache_blocks( // self.cache_alignment_tokens * self.cache_alignment_tokens ) - if aligned_cache_hit_length <= 0: + if aligned_cache_hit_length <= 0 or num_tokens < aligned_cache_hit_length: return aligned_num_hit_blocks = aligned_cache_hit_length // self.block_size From 0101ca23ef3697090a7408dc4743d206203dfb72 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 19 May 2026 19:54:50 +0800 Subject: [PATCH 042/131] Clean up DeepSeek V4 upstream rebase leftovers Signed-off-by: jasl --- tests/models/test_deepseek_v4_mega_moe.py | 64 +++++++++++++++++++ .../attention/test_deepseek_v4_sparse_swa.py | 2 +- tests/v1/spec_decode/test_mtp.py | 20 +----- .../deepseek_v4/common/ops/fused_indexer_q.py | 10 +-- vllm/models/deepseek_v4/nvidia/model.py | 13 +--- vllm/utils/deep_gemm.py | 36 +++++++++++ 6 files changed, 105 insertions(+), 40 deletions(-) diff --git a/tests/models/test_deepseek_v4_mega_moe.py b/tests/models/test_deepseek_v4_mega_moe.py index 2e9b76f5c030..ca0ec781e269 100644 --- a/tests/models/test_deepseek_v4_mega_moe.py +++ b/tests/models/test_deepseek_v4_mega_moe.py @@ -13,6 +13,7 @@ ) from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.platforms import current_platform +from vllm.utils import deep_gemm as deep_gemm_utils pytestmark = pytest.mark.skipif( not current_platform.is_cuda(), @@ -108,6 +109,69 @@ def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): assert torch.count_nonzero(experts.w13_weight[1]) == 0 +def test_deepseek_v4_mega_moe_finalize_uses_deep_gemm_wrapper(monkeypatch): + vllm_config = SimpleNamespace( + compilation_config=CompilationConfig(), + scheduler_config=SimpleNamespace(max_num_batched_tokens=4), + ) + experts = DeepseekV4MegaMoEExperts( + vllm_config, + num_experts=4, + num_local_experts=2, + experts_start_idx=0, + top_k=2, + hidden_size=128, + intermediate_size=128, + ) + + transformed = (object(), object()) + calls = [] + + def fake_runtime_check(self): + calls.append("runtime_check") + + def fake_transform_sf_into_required_layout( + scale, rows, cols, block_shape, num_local_experts + ): + calls.append((rows, cols, block_shape, num_local_experts)) + return scale + + def fake_transform_weights_for_mega_moe(w13_weight, w2_weight): + calls.append((w13_weight[0].shape, w2_weight[0].shape)) + return transformed + + monkeypatch.setattr( + DeepseekV4MegaMoEExperts, + "_check_runtime_supported", + fake_runtime_check, + ) + monkeypatch.setattr( + deep_gemm_utils, + "transform_sf_into_required_layout", + fake_transform_sf_into_required_layout, + ) + monkeypatch.setattr( + deep_gemm_utils, + "transform_weights_for_mega_moe", + fake_transform_weights_for_mega_moe, + ) + + experts.finalize_weights() + + assert experts._transformed_l1_weights is transformed[0] + assert experts._transformed_l2_weights is transformed[1] + assert experts.w13_weight is None + assert experts.w13_weight_scale is None + assert experts.w2_weight is None + assert experts.w2_weight_scale is None + assert calls == [ + "runtime_check", + (256, 128, (1, 32), 2), + (128, 128, (1, 32), 2), + (torch.Size([2, 256, 64]), torch.Size([2, 128, 64])), + ] + + @pytest.mark.skipif( not torch.cuda.is_available(), reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.", diff --git a/tests/v1/attention/test_deepseek_v4_sparse_swa.py b/tests/v1/attention/test_deepseek_v4_sparse_swa.py index 158c5af8b1cd..d2f7c1d5041f 100644 --- a/tests/v1/attention/test_deepseek_v4_sparse_swa.py +++ b/tests/v1/attention/test_deepseek_v4_sparse_swa.py @@ -11,7 +11,7 @@ from vllm.v1.kv_cache_interface import MLAAttentionSpec -def test_sparse_swa_reorder_threshold_matches_mtp_decode_threshold(): +def test_sparse_swa_reorder_threshold_matches_spec_decode_threshold(): vllm_config = create_vllm_config( block_size=256, hf_config_override={ diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index 53b91739c979..b4056fb67567 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -268,11 +268,7 @@ def test_mtp_propose_random_sampling_records_draft_probs(): vocab_size = 4 proposer = _create_mtp_proposer(num_speculative_tokens=1) - # Mirror upstream's f51f6844f gating: probabilities are only collected - # when the speculative config explicitly opts into the probabilistic - # draft-model rejection path. We force the flag here so the test - # exercises ``_sample_draft_tokens``'s probabilistic branch. - proposer._enable_probabilistic_draft_probs = True + assert proposer._enable_probabilistic_draft_probs hidden_size = proposer.hidden_size model_mock = mock.MagicMock() @@ -432,9 +428,7 @@ def test_mtp_parallel_drafting_random_sampling_records_draft_probs(): num_speculative_tokens=num_spec_tokens, parallel_drafting=True, ) - # Mirror upstream's f51f6844f gating: see the matching comment in - # ``test_mtp_propose_random_sampling_records_draft_probs``. - proposer._enable_probabilistic_draft_probs = True + assert proposer._enable_probabilistic_draft_probs proposer.block_size = 16 hidden_size = proposer.hidden_size @@ -508,13 +502,3 @@ def test_mtp_parallel_drafting_random_sampling_records_draft_probs(): assert torch.equal( proposer.take_last_draft_probs(), proposer._last_draft_probs ) - - -# Tests for ``_get_draft_probs_for_rejection`` and the -# positional ``runner_stub._draft_probs`` packing path were removed when -# our MTP scheduling commit dropped that branch in favor of upstream's -# req-id-indexed ``_get_spec_decode_draft_probs`` (added by -# vllm-project/vllm#40269 / f51f6844f). The remaining -# ``_get_draft_probs_for_rejection`` tests that follow have been deleted -# along with the function; equivalent coverage for the new code lives -# under ``tests/v1/worker/test_gpu_model_runner.py``. diff --git a/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py b/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py index c7fbd93f329b..d5aaf10feba4 100644 --- a/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py +++ b/vllm/models/deepseek_v4/common/ops/fused_indexer_q.py @@ -66,14 +66,6 @@ def _quantize_mxfp4_pair(x_lo, x_hi): return packed, ue8m0 -@triton.autotune( - configs=[ - triton.Config({}, num_warps=1), - triton.Config({}, num_warps=2), - triton.Config({}, num_warps=4), - ], - key=["INDEX_Q_HALF_ROT_DIM", "INDEX_Q_HEAD_DIM"], -) @triton.jit def _fused_indexer_q_rope_quant_kernel( pos_ptr, @@ -441,6 +433,6 @@ def fused_indexer_q_rope_quant( index_weights_head_scale, index_weights_out, index_weights_out.stride(0), - # num_warps supplied by @triton.autotune above. + num_warps=1, # TODO: Tune this ) return index_q_fp8, index_weights_out diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 511acaff55c5..8cf5e636b0c0 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -64,6 +64,7 @@ from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLAAttention from vllm.models.deepseek_v4.nvidia.ops.prepare_megamoe import prepare_megamoe_inputs from vllm.sequence import IntermediateTensors +from vllm.utils import deep_gemm from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -308,10 +309,6 @@ def finalize_weights(self) -> None: return self._check_runtime_supported() - from vllm.utils.deep_gemm import _import_deep_gemm - - deep_gemm = _import_deep_gemm() - w13_scale = deep_gemm.transform_sf_into_required_layout( self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(), 2 * self.intermediate_size, @@ -344,10 +341,6 @@ def finalize_weights(self) -> None: self.w2_weight_scale = None def get_symm_buffer(self): - from vllm.utils.deep_gemm import _import_deep_gemm - - deep_gemm = _import_deep_gemm() - group = get_ep_group().device_group device = torch.accelerator.current_device_index() key = ( @@ -434,10 +427,6 @@ def forward( ) y = torch.empty_like(hidden_states, dtype=torch.bfloat16) - from vllm.utils.deep_gemm import _import_deep_gemm - - deep_gemm = _import_deep_gemm() - symm_buffer = self.get_symm_buffer() num_tokens = hidden_states.shape[0] diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index a1508aef1c40..9c532bf2549b 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -139,6 +139,9 @@ def _missing(*_: Any, **__: Any) -> NoReturn: _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None _get_mk_alignment_for_contiguous_layout_impl: Callable[..., Any] | None = None _transform_sf_into_required_layout_impl: Callable[..., Any] | None = None +_transform_weights_for_mega_moe_impl: Callable[..., Any] | None = None +_get_symm_buffer_for_mega_moe_impl: Callable[..., Any] | None = None +_fp8_fp4_mega_moe_impl: Callable[..., Any] | None = None @functools.cache @@ -204,6 +207,8 @@ def _lazy_init() -> None: global _get_mn_major_tma_aligned_tensor_impl global _get_mk_alignment_for_contiguous_layout_impl global _transform_sf_into_required_layout_impl + global _transform_weights_for_mega_moe_impl + global _get_symm_buffer_for_mega_moe_impl, _fp8_fp4_mega_moe_impl # fast path if ( _cublaslt_gemm_nt_impl is not None @@ -218,6 +223,9 @@ def _lazy_init() -> None: or _tf32_hc_prenorm_gemm_impl is not None or _get_mk_alignment_for_contiguous_layout_impl is not None or _transform_sf_into_required_layout_impl is not None + or _transform_weights_for_mega_moe_impl is not None + or _get_symm_buffer_for_mega_moe_impl is not None + or _fp8_fp4_mega_moe_impl is not None ): return @@ -261,6 +269,13 @@ def _lazy_init() -> None: _transform_sf_into_required_layout_impl = getattr( _dg, "transform_sf_into_required_layout", None ) + _transform_weights_for_mega_moe_impl = getattr( + _dg, "transform_weights_for_mega_moe", None + ) + _get_symm_buffer_for_mega_moe_impl = getattr( + _dg, "get_symm_buffer_for_mega_moe", None + ) + _fp8_fp4_mega_moe_impl = getattr(_dg, "fp8_fp4_mega_moe", None) DeepGemmQuantScaleFMT.init_oracle_cache() @@ -359,6 +374,27 @@ def transform_sf_into_required_layout(*args, **kwargs): ) +def transform_weights_for_mega_moe(*args, **kwargs): + _lazy_init() + if _transform_weights_for_mega_moe_impl is None: + return _missing(*args, **kwargs) + return _transform_weights_for_mega_moe_impl(*args, **kwargs) + + +def get_symm_buffer_for_mega_moe(*args, **kwargs): + _lazy_init() + if _get_symm_buffer_for_mega_moe_impl is None: + return _missing(*args, **kwargs) + return _get_symm_buffer_for_mega_moe_impl(*args, **kwargs) + + +def fp8_fp4_mega_moe(*args, **kwargs): + _lazy_init() + if _fp8_fp4_mega_moe_impl is None: + return _missing(*args, **kwargs) + return _fp8_fp4_mega_moe_impl(*args, **kwargs) + + def fp8_fp4_mqa_topk_indices( q: tuple[torch.Tensor, torch.Tensor | None], kv: tuple[torch.Tensor, torch.Tensor], From ca2dd9866b210f6ce7512964727bc7f33b750c78 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 19 May 2026 20:21:04 +0800 Subject: [PATCH 043/131] Fix CUTeDSL availability probe Detect CUTeDSL by importing cutlass instead of only checking package metadata. This lets SM12x deployments fall back to Triton when a broken or incompatible nvidia-cutlass-dsl install leaves the cutlass package visible but not importable. Reported-by: danielwu1987 Signed-off-by: jasl --- tests/utils_/test_import_utils.py | 33 +++++++++++++++++++++++++++++++ vllm/utils/import_utils.py | 9 +++++++-- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/utils_/test_import_utils.py b/tests/utils_/test_import_utils.py index 464f209f0f25..dc1e55bcd0e4 100644 --- a/tests/utils_/test_import_utils.py +++ b/tests/utils_/test_import_utils.py @@ -1,12 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import builtins +import importlib.util from unittest.mock import MagicMock, patch import pytest +from vllm.utils import import_utils from vllm.utils.import_utils import PlaceholderModule, _has_module +def _clear_import_utils_caches(): + import_utils._has_module.cache_clear() + if hasattr(import_utils.has_cutedsl, "cache_clear"): + import_utils.has_cutedsl.cache_clear() + + def _raises_module_not_found(): return pytest.raises(ModuleNotFoundError, match="No module named") @@ -101,3 +110,27 @@ def test_result_is_cached(self): result = _has_module("json") # should hit cache mock_spec.assert_not_called() assert result is True + + +def test_has_cutedsl_requires_importable_cutlass(monkeypatch: pytest.MonkeyPatch): + real_find_spec = importlib.util.find_spec + real_import = builtins.__import__ + + def fake_find_spec(name, *args, **kwargs): + if name == "cutlass": + return object() + return real_find_spec(name, *args, **kwargs) + + def fake_import(name, *args, **kwargs): + if name == "cutlass": + raise ImportError("broken CUTLASS DSL") + return real_import(name, *args, **kwargs) + + _clear_import_utils_caches() + monkeypatch.setattr(import_utils.importlib.util, "find_spec", fake_find_spec) + monkeypatch.setattr(builtins, "__import__", fake_import) + + try: + assert import_utils.has_cutedsl() is False + finally: + _clear_import_utils_caches() diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index 043798a584b8..18c83af71e32 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -539,9 +539,14 @@ def has_fbgemm_gpu() -> bool: return _has_module("fbgemm_gpu") +@cache def has_cutedsl() -> bool: - """Whether the optional `cutelass` package is available.""" - return _has_module("cutlass") + """Whether the optional `cutlass` package is available and importable.""" + try: + import cutlass # noqa: F401 + except Exception: + return False + return True def has_humming() -> bool: From 95fc9073bafe8766836d96b63c4abd84bc72a590 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 20 May 2026 02:33:09 +0800 Subject: [PATCH 044/131] Fix DeepSeek V4 MTP small-batch graph hangs Keep FULL_AND_PIECEWISE enabled for DeepSeek V4 MTP and avoid replaying small speculative decode batches against padded virtual requests by preserving exact spec-decode capture sizes for request counts 1..32. Signed-off-by: jasl --- tests/compile/test_config.py | 32 ++++++++++++++++++ .../test_deepseek_v4_kernel_warmup.py | 22 +++++++++++++ tests/v1/cudagraph/test_cudagraph_dispatch.py | 33 +++++++++++++++++++ vllm/config/compilation.py | 10 ++++++ vllm/model_executor/warmup/kernel_warmup.py | 23 ++++++++----- 5 files changed, 112 insertions(+), 8 deletions(-) create mode 100644 tests/model_executor/test_deepseek_v4_kernel_warmup.py diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index d822b68c5036..be9318b32263 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -468,6 +468,38 @@ def test_cudagraph_sizes_post_init( ) +def test_spec_decode_cudagraph_sizes_keep_small_full_decode_batches_exact(): + config = CompilationConfig( + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 40, + 48, + 56, + 64, + 72, + 80, + 88, + 96, + ], + max_cudagraph_capture_size=96, + ) + + config.adjust_cudagraph_sizes_for_spec_decode( + uniform_decode_query_len=3, + tensor_parallel_size=1, + ) + + for num_reqs in range(1, 33): + assert 3 * num_reqs in config.cudagraph_capture_sizes + + @pytest.mark.skipif( not current_platform.support_static_graph_mode(), reason="Skip if not cudagraph mode supported", diff --git a/tests/model_executor/test_deepseek_v4_kernel_warmup.py b/tests/model_executor/test_deepseek_v4_kernel_warmup.py new file mode 100644 index 000000000000..e73db23ebb7b --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_kernel_warmup.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +from vllm.model_executor.warmup.kernel_warmup import ( + _deepseek_v4_mtp_uniform_decode_warmup_requests, +) + + +def test_deepseek_v4_mtp_uniform_decode_warmup_caps_large_max_num_seqs(): + runner = SimpleNamespace( + speculative_config=SimpleNamespace(method="mtp"), + num_spec_tokens=2, + uniform_decode_query_len=3, + ) + + assert _deepseek_v4_mtp_uniform_decode_warmup_requests( + runner, + max_tokens=4096, + max_reqs=1024, + ) == (1, 2, 4, 8, 16, 24, 32) diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index c10835821f58..94dcfe1f5d59 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -267,6 +267,39 @@ def test_get_capture_descs_empty_when_not_initialized(self): assert dispatcher.get_capture_descs() == [] + def test_deepseek_v4_mtp_spec_decode_keeps_full_and_piecewise_graphs(self): + comp_config = CompilationConfig( + cudagraph_mode="FULL_AND_PIECEWISE", + mode=CompilationMode.VLLM_COMPILE, + cudagraph_capture_sizes=[3, 6, 12, 18], + ) + config = _create_vllm_config(comp_config, max_num_seqs=8) + config.speculative_config = MagicMock() + config.speculative_config.method = "mtp" + config.speculative_config.num_speculative_tokens = 2 + config.model_config = MagicMock() + config.model_config.hf_config = MagicMock() + config.model_config.hf_config.model_type = "deepseek_v4" + config.model_config.hf_config.architectures = ["DeepseekV4ForCausalLM"] + + dispatcher = CudagraphDispatcher(config) + dispatcher.initialize_cudagraph_keys( + cudagraph_mode=comp_config.cudagraph_mode, + uniform_decode_query_len=3, + ) + + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) > 0 + assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) > 0 + + rt_mode, key = dispatcher.dispatch( + num_tokens=12, + uniform_decode=True, + has_lora=False, + ) + + assert rt_mode == CUDAGraphMode.FULL + assert key == BatchDescriptor(num_tokens=12, num_reqs=4, uniform=True) + @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") class TestCUDAGraphWrapper: diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c641fdddf405..fe6416de3c0b 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -1490,6 +1490,16 @@ def adjust_cudagraph_sizes_for_spec_decode( if round_up(size, multiple_of) <= self.max_cudagraph_capture_size ) ) + # Spec decode uniform decode graphs are shaped by request count + # (`num_reqs * (1 + num_speculative_tokens)`). Keep the small + # interactive request counts exact so FULL decode graphs do not replay + # against padded virtual requests. + small_decode_sizes = { + multiple_of * num_reqs + for num_reqs in range(1, 33) + if multiple_of * num_reqs <= self.max_cudagraph_capture_size + } + rounded_sizes = sorted(set(rounded_sizes) | small_decode_sizes) if len(rounded_sizes) == 0 and multiple_of <= self.max_cudagraph_capture_size: # if one valid but would be round_down use that diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 214e4e901cb1..96b9bdbbd471 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -46,12 +46,12 @@ # value via _clamp_warmup_tokens at the call site, smaller caps clamp # down naturally. _DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS = 8192 -# Steady-state MTP decode shapes to warm. We always include 1 and 2; the -# scheduler's `max_num_seqs` is appended dynamically at the call site so -# kernels selected per-shape (e.g. `_fp8_paged_mqa_logits_kernel`'s -# adaptive BLOCK_M) are covered for the largest in-flight batch the -# server will ever issue. -_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS = (1, 2) +# Steady-state MTP decode shapes to warm. Keep this bounded to the edge +# deployment range we expect to optimize; warming the scheduler's raw +# max_num_seqs (often 1024) can consume multiple GiB of temporary workspace +# on long-context SM12x serves before the first request. +_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS = (1, 2, 4, 8, 16, 24, 32) +_DEEPSEEK_V4_MTP_UNIFORM_DECODE_MAX_WARMUP_REQUESTS = 32 _DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS = tuple(range(1, 17)) + ( 32, 64, @@ -108,8 +108,15 @@ def _deepseek_v4_mtp_uniform_decode_warmup_requests( if query_len <= 0: return () - max_warmup_reqs = min(max_reqs, max_tokens // query_len) - candidates = sorted(set(_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS) | {max_reqs}) + max_warmup_reqs = min( + max_reqs, + max_tokens // query_len, + _DEEPSEEK_V4_MTP_UNIFORM_DECODE_MAX_WARMUP_REQUESTS, + ) + candidates = sorted( + set(_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS) + | {max_warmup_reqs} + ) return tuple(reqs for reqs in candidates if reqs <= max_warmup_reqs) From 16c1667ef13f6a99a303c3bc995719e9a9e505bf Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 20 May 2026 03:54:42 +0800 Subject: [PATCH 045/131] Remove ineffective DeepSeek V4 mHC warmup The model-specific TileLang warmup did not reduce startup time, first-request JIT warnings, 127K C=1 TTFT, or short C=4 correctness in the SM120 ablation. Drop the warmup hook and env knobs instead of keeping a dead A/B path. Signed-off-by: jasl --- vllm/envs.py | 11 - .../warmup/deepseek_v4_mhc_warmup.py | 251 ------------------ vllm/model_executor/warmup/kernel_warmup.py | 10 - 3 files changed, 272 deletions(-) delete mode 100644 vllm/model_executor/warmup/deepseek_v4_mhc_warmup.py diff --git a/vllm/envs.py b/vllm/envs.py index 721fa31410c6..8086b24ef238 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -179,8 +179,6 @@ VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True - VLLM_ENABLE_DEEPSEEK_V4_MHC_WARMUP: bool = True - VLLM_DEEPSEEK_V4_MHC_WARMUP_TOKEN_SIZES: list[int] | None = None VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True VLLM_TRITON_MLA_SPARSE: bool | None = None VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 @@ -1460,15 +1458,6 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool( int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1")) ), - # DeepSeek V4 mHC / hc_head TileLang kernels JIT on first use. Enable - # startup warmup by default to avoid first-request latency spikes; set to - # 0 to keep the old lazy-JIT behavior. - "VLLM_ENABLE_DEEPSEEK_V4_MHC_WARMUP": lambda: bool( - int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_MHC_WARMUP", "1")) - ), - "VLLM_DEEPSEEK_V4_MHC_WARMUP_TOKEN_SIZES": lambda: maybe_convert_int_list( - os.getenv("VLLM_DEEPSEEK_V4_MHC_WARMUP_TOKEN_SIZES") - ), "VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP": lambda: bool( int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP", "1")) ), diff --git a/vllm/model_executor/warmup/deepseek_v4_mhc_warmup.py b/vllm/model_executor/warmup/deepseek_v4_mhc_warmup.py deleted file mode 100644 index 9189db0c4739..000000000000 --- a/vllm/model_executor/warmup/deepseek_v4_mhc_warmup.py +++ /dev/null @@ -1,251 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Warm up DeepSeek V4 mHC TileLang kernels before serving requests.""" - -import time -from collections.abc import Iterable - -import torch - -import vllm.envs as envs -from vllm.logger import init_logger -from vllm.tracing import instrument -from vllm.utils.math_utils import cdiv - -logger = init_logger(__name__) - -_AUTO_WARMUP_MAX_TOKENS = 16_384 -_DEFAULT_TOKEN_SIZE_CANDIDATES = ( - 1, - 2, - 4, - 8, - 16, - 32, - 64, - 128, - 256, - 512, - 1024, - 2048, - 4096, - 8192, - 16_384, -) - - -def _compute_mhc_pre_num_split( - *, - num_tokens: int, - hidden_size: int, - hc_mult: int, - num_sms: int, -) -> int: - block_k = 64 - block_m = 64 - k = hc_mult * hidden_size - grid_size = cdiv(num_tokens, block_m) - split_k = num_sms // grid_size - num_block_k = cdiv(k, block_k) - split_k = min(split_k, num_block_k // 4) - return max(split_k, 1) - - -def _normalize_token_sizes( - token_sizes: Iterable[int], - *, - max_tokens: int, -) -> list[int]: - return sorted({size for size in token_sizes if 1 <= size <= max_tokens}) - - -def _select_mhc_warmup_token_sizes( - *, - max_tokens: int, - hidden_size: int, - hc_mult: int, - num_sms: int, - requested_token_sizes: list[int] | None, - cudagraph_capture_sizes: list[int], -) -> list[int]: - if max_tokens <= 0: - return [] - - if requested_token_sizes is None: - max_auto_tokens = min(max_tokens, _AUTO_WARMUP_MAX_TOKENS) - candidates = list(_DEFAULT_TOKEN_SIZE_CANDIDATES) - candidates.extend(cudagraph_capture_sizes) - candidates.append(max_auto_tokens) - candidates = _normalize_token_sizes(candidates, max_tokens=max_auto_tokens) - else: - candidates = _normalize_token_sizes( - requested_token_sizes, - max_tokens=max_tokens, - ) - - return candidates - - -def _find_first_mhc_layer(model: torch.nn.Module) -> torch.nn.Module | None: - for module in model.modules(): - if module.__class__.__name__ != "DeepseekV4DecoderLayer": - continue - if all( - hasattr(module, attr) - for attr in ( - "hc_pre", - "hc_post", - "hc_attn_fn", - "hc_attn_scale", - "hc_attn_base", - "hc_ffn_fn", - "hc_ffn_scale", - "hc_ffn_base", - ) - ): - return module - return None - - -def _find_deepseek_v4_model(model: torch.nn.Module) -> torch.nn.Module | None: - for module in model.modules(): - if module.__class__.__name__ != "DeepseekV4Model": - continue - if all( - hasattr(module, attr) - for attr in ("hc_head_fn", "hc_head_scale", "hc_head_base") - ): - return module - return None - - -def _get_cuda_num_sms(device: torch.device) -> int: - index = device.index - if index is None: - index = torch.accelerator.current_device_index() - return torch.cuda.get_device_properties(index).multi_processor_count - - -def _warmup_layer_mhc( - layer: torch.nn.Module, - token_sizes: list[int], -) -> None: - max_tokens = max(token_sizes) - hidden_size = int(layer.hidden_size) - hc_mult = int(layer.hc_mult) - device = layer.hc_attn_fn.device - residual = torch.zeros( - max_tokens, - hc_mult, - hidden_size, - dtype=torch.bfloat16, - device=device, - ) - - for size in token_sizes: - residual_slice = residual[:size] - for fn, scale, base in ( - (layer.hc_attn_fn, layer.hc_attn_scale, layer.hc_attn_base), - (layer.hc_ffn_fn, layer.hc_ffn_scale, layer.hc_ffn_base), - ): - layer_input, post_mix, comb_mix = layer.hc_pre( - residual_slice, - fn, - scale, - base, - ) - layer.hc_post(layer_input, residual_slice, post_mix, comb_mix) - - -def _warmup_hc_head( - model: torch.nn.Module, - token_sizes: list[int], -) -> None: - if not hasattr(model, "_mtp_hidden_buffer"): - return - - # Upstream (a8887c208 "[DSV4] aiter mhc support (ROCm)") refactored - # ``hc_head`` from a free function exported by ``deepseek_v4`` into the - # ``HCHeadOp`` :class:`CustomOp` instance attached to the model as - # ``hc_head_op``. We call through that instance so the warmup exercises - # the same dispatched implementation as the inference path. - hc_head_op = getattr(model, "hc_head_op", None) - if hc_head_op is None: - return - - max_tokens = max(token_sizes) - hidden_size = int(model.config.hidden_size) - hc_mult = int(model.hc_mult) - device = model.hc_head_fn.device - hidden_states = torch.zeros( - max_tokens, - hc_mult, - hidden_size, - dtype=torch.bfloat16, - device=device, - ) - - for size in token_sizes: - hc_head_op( - hidden_states[:size], - model.hc_head_fn, - model.hc_head_scale, - model.hc_head_base, - model.rms_norm_eps, - model.hc_eps, - ) - - -@instrument(span_name="DeepSeek V4 mHC warmup") -def deepseek_v4_mhc_warmup( - model: torch.nn.Module, - *, - max_tokens: int, - cudagraph_capture_sizes: list[int] | None = None, -) -> None: - if not envs.VLLM_ENABLE_DEEPSEEK_V4_MHC_WARMUP: - return - - # Cheap model-type gate before walking ``model.modules()``. The class - # walk below is O(num_layers) and shows up in startup time on very - # large checkpoints; bail out for any model that is not DeepSeek V4. - config = getattr(model, "config", None) - model_type = getattr(config, "model_type", None) if config is not None else None - if model_type is not None and model_type != "deepseek_v4": - return - - layer = _find_first_mhc_layer(model) - if layer is None: - return - - device = layer.hc_attn_fn.device - if device.type != "cuda": - return - - deepseek_model = _find_deepseek_v4_model(model) - num_sms = _get_cuda_num_sms(device) - token_sizes = _select_mhc_warmup_token_sizes( - max_tokens=max_tokens, - hidden_size=int(layer.hidden_size), - hc_mult=int(layer.hc_mult), - num_sms=num_sms, - requested_token_sizes=envs.VLLM_DEEPSEEK_V4_MHC_WARMUP_TOKEN_SIZES, - cudagraph_capture_sizes=cudagraph_capture_sizes or [], - ) - if not token_sizes: - return - - started = time.perf_counter() - logger.info( - "Warming up DeepSeek V4 mHC TileLang kernels for token sizes: %s", - token_sizes, - ) - with torch.inference_mode(): - _warmup_layer_mhc(layer, token_sizes) - if deepseek_model is not None: - _warmup_hc_head(deepseek_model, token_sizes) - torch.accelerator.synchronize() - logger.info( - "DeepSeek V4 mHC TileLang warmup finished in %.2f seconds.", - time.perf_counter() - started, - ) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 96b9bdbbd471..83473dd6090b 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -18,9 +18,6 @@ from vllm.compilation.caching import aot_compile_hash_factors from vllm.logger import init_logger from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup -from vllm.model_executor.warmup.deepseek_v4_mhc_warmup import ( - deepseek_v4_mhc_warmup, -) from vllm.platforms import current_platform from vllm.utils.deep_gemm import is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer @@ -500,13 +497,6 @@ def kernel_warmup(worker: "Worker"): deep_gemm_warmup(model, max_tokens) minimax_m3_msa_warmup(worker) - deepseek_v4_mhc_warmup( - worker.get_model(), - max_tokens=worker.scheduler_config.max_num_batched_tokens, - cudagraph_capture_sizes=( - worker.vllm_config.compilation_config.cudagraph_capture_sizes or [] - ), - ) _deepseek_v4_sparse_mla_attention_warmup(worker) _deepseek_v4_request_prep_warmup(worker) From 97b7143fbc8a834cae63f920e6efec7cc7531a56 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 20 May 2026 04:30:00 +0800 Subject: [PATCH 046/131] Tune SM120 FP8 MQA logits row tile Widen the direct SM120 FP8 MQA logits Triton row tile from 16 to 64 while keeping BLOCK_N=128, BLOCK_D=64, and four warps. SM120 microbench improved about 22-23% across 128/256/512/1024 query rows. Same-host C=1 repeat gates improved 59K mean TTFT from 11.413s to 11.097s and 124K from 29.868s to 28.042s. Short MT-Bench C=1/2/4 and GSM8K limit-200 temperature=0 passed. Artifacts: codex_mqa_blockm_followup_20260520040105, codex_blockm64_c1_repeat/20260520041210, codex_blockm16_c1_repeat_baseline/20260520041627, codex_blockm64_short_gsm8k_gate/20260520042117. Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index 23e42cc07d25..d589d89390e9 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -147,7 +147,7 @@ def fp8_mqa_logits_triton( if num_q == 0 or seq_len_kv == 0: return logits - grid = (triton.cdiv(num_q, 16), triton.cdiv(seq_len_kv, 128)) + grid = (triton.cdiv(num_q, 64), triton.cdiv(seq_len_kv, 128)) _fp8_mqa_logits_kernel[grid]( q, k_fp8, @@ -169,7 +169,7 @@ def fp8_mqa_logits_triton( weights.stride(1), logits.stride(0), logits.stride(1), - BLOCK_M=16, + BLOCK_M=64, BLOCK_N=128, BLOCK_D=64, num_warps=4, From a81282b740530072dbdb82e14edc800f09333589 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 21 May 2026 05:12:59 +0800 Subject: [PATCH 047/131] Clean up SM120 rebase leftovers Signed-off-by: jasl --- vllm/envs.py | 7 ------- vllm/model_executor/layers/quantization/mxfp4.py | 6 ------ 2 files changed, 13 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 8086b24ef238..f3f1fa998a1f 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -309,13 +309,6 @@ def maybe_convert_int(value: str | None) -> int | None: return int(value) -def maybe_convert_int_list(value: str | None) -> list[int] | None: - if value is None: - return None - values = [int(item.strip()) for item in value.split(",") if item.strip()] - return values or None - - def maybe_convert_bool(value: str | None) -> bool | None: if value is None: return None diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 92c4b23f54bc..1b2a8a74bdcb 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -389,9 +389,6 @@ def process_weights_after_loading(self, layer: RoutedExperts) -> None: return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) - del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias - if torch.cuda.is_available(): - torch.cuda.empty_cache() def get_fused_moe_quant_config( self, layer: RoutedExperts @@ -736,9 +733,6 @@ def process_weights_after_loading(self, layer): return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) - del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias - if torch.cuda.is_available(): - torch.cuda.empty_cache() def get_fused_moe_quant_config( self, From 52a0c1949a0104f68efac60f50ac4e39e23711af Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 21 May 2026 06:00:21 +0800 Subject: [PATCH 048/131] Remove unused SM120 splitKV decode experiment The splitKV sparse MLA decode path stayed behind a default-off flag after benchmarking showed ambiguous value for the current SM120 latency target. Keep the measured matmul decode path active and preserve the experiment on backup branch codex/sm120-splitkv-decode-experiment-backup-20260521054846 for future reference. Signed-off-by: jasl --- vllm/envs.py | 4 - vllm/model_executor/warmup/kernel_warmup.py | 15 +- .../attention/backends/mla/sparse_mla_env.py | 4 - .../backends/mla/sparse_mla_kernels.py | 313 +----------------- 4 files changed, 4 insertions(+), 332 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index f3f1fa998a1f..128289ba1e0b 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -185,7 +185,6 @@ VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE: int | None = None VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE: bool | None = None - VLLM_TRITON_MLA_SPARSE_SPLITKV_DECODE: bool = False VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", "full", @@ -1476,9 +1475,6 @@ def _resolve_rust_frontend_path() -> str | None: if os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE") is None else bool(int(os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE", "0"))) ), - "VLLM_TRITON_MLA_SPARSE_SPLITKV_DECODE": lambda: bool( - int(os.getenv("VLLM_TRITON_MLA_SPARSE_SPLITKV_DECODE", "0")) - ), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 83473dd6090b..819b2050a009 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -412,18 +412,9 @@ def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None: create_single_prefill=True, profile_seq_lens=prefill_tokens * 2, ) - # NOTE: The multi-request prefill warmup that previously sat here - # (max_num_seqs prefills sharing the batched-token budget) hit a - # CUDA illegal memory access inside the CUTeDSL - # ``DequantGatherKCacheKernel`` on SM12x. The dummy_run shape it - # generated violates an implicit ``offset + gather_len <= M`` - # invariant of the kv-gather output buffer (M is sized for the - # single-prefill warmup case). Removing the warmup gives back the - # one-time JIT cost on the first real multi-prefill request, but - # unblocks serve startup at production ``--max-num-seqs`` values - # (e.g. 128). Re-enable once the gather-buffer sizing for - # multi-request prefill warmup is reconciled with the kernel's - # bounds. + # Do not synthesize multi-request prefill here: that dummy shape + # overflows the CUTeDSL KV-gather workspace on SM12x. Revisit only + # with a real buffer-sizing fix for that warmup path. query_len = getattr(runner, "uniform_decode_query_len", 0) for num_reqs in uniform_decode_reqs: runner._dummy_run( diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py index 4567eca9dd83..433d69413bcd 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_env.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -58,7 +58,3 @@ def triton_sparse_mla_matmul_decode_enabled() -> bool: if configured is not None: return configured return current_platform.is_device_capability_family(120) - - -def triton_sparse_mla_splitkv_decode_enabled() -> bool: - return envs.VLLM_TRITON_MLA_SPARSE_SPLITKV_DECODE diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 4afa5c0aed4b..c941d1df433d 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -2,25 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Portable sparse MLA Triton kernels.""" -import math - import torch -from vllm.triton_utils import LOG2E, LOGE2, tl, triton +from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.mla.sparse_mla_env import ( triton_sparse_mla_head_block_size, ) -_SPLITKV_HEAD_BLOCK = 16 -_SPLITKV_MERGE_HEAD_BLOCK = 1 -_SPLITKV_BLOCK_N = 32 -_SPLITKV_MERGE_BLOCK_D = 128 -_SPLITKV_MIN_CANDIDATES_PER_SPLIT = 128 -_SPLITKV_MEDIUM_BATCH_MIN_TOKENS = 16 -_SPLITKV_MEDIUM_BATCH_CANDIDATES_PER_SPLIT = 512 -_SPLITKV_MEDIUM_BATCH_MAX_SPLITS = 8 -_SPLITKV_MAX_OCCUPANCY = 4 - def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: """Choose the SM12x sparse MLA head grouping for decode kernels. @@ -40,305 +28,6 @@ def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: return 4 -def _next_power_of_2(value: int) -> int: - return 1 << max(0, value - 1).bit_length() - - -def choose_sparse_mla_splitkv_splits( - num_tokens: int, - num_heads: int, - num_candidates: int, - sm_count: int, - head_block_size: int = _SPLITKV_HEAD_BLOCK, -) -> int: - if ( - num_tokens <= 0 - or num_heads <= 0 - or num_candidates <= 0 - or sm_count <= 0 - or head_block_size <= 0 - ): - return 1 - - num_head_groups = math.ceil(num_heads / min(head_block_size, num_heads)) - baseline = num_tokens * num_head_groups - if baseline == 0: - return 1 - - ideal = _next_power_of_2( - max(1, num_candidates // _SPLITKV_MIN_CANDIDATES_PER_SPLIT) - ) - max_splits = max(1, (sm_count * _SPLITKV_MAX_OCCUPANCY) // baseline) - max_splits = 1 << (max_splits.bit_length() - 1) - num_splits = min(ideal, max_splits) - if ( - num_tokens >= _SPLITKV_MEDIUM_BATCH_MIN_TOKENS - and baseline <= sm_count * _SPLITKV_MAX_OCCUPANCY - ): - medium_batch_splits = _next_power_of_2( - max(1, num_candidates // _SPLITKV_MEDIUM_BATCH_CANDIDATES_PER_SPLIT) - ) - medium_batch_splits = min( - ideal, medium_batch_splits, _SPLITKV_MEDIUM_BATCH_MAX_SPLITS - ) - num_splits = max(num_splits, medium_batch_splits) - while num_splits > 1 and num_candidates % num_splits != 0: - num_splits //= 2 - return max(1, num_splits) - - -@triton.jit -def _splitkv_sparse_mla_stage1_kernel( - q_ptr, - kv_ptr, - valid_ptr, - mid_ptr, - stride_qt: tl.constexpr, - stride_qh: tl.constexpr, - stride_qd: tl.constexpr, - stride_kvt: tl.constexpr, - stride_kvc: tl.constexpr, - stride_kvd: tl.constexpr, - stride_vt: tl.constexpr, - stride_vc: tl.constexpr, - stride_mt: tl.constexpr, - stride_mh: tl.constexpr, - stride_ms: tl.constexpr, - num_heads: tl.constexpr, - num_candidates: tl.constexpr, - scale: tl.constexpr, - num_splits: tl.constexpr, - HEAD_BLOCK: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D: tl.constexpr, - LOGE2_VALUE: tl.constexpr, -): - token_id = tl.program_id(0) - head_group = tl.program_id(1) - split_id = tl.program_id(2) - - offs_h = head_group * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) - mask_h = offs_h < num_heads - offs_d = tl.arange(0, BLOCK_D) - - q = tl.load( - q_ptr - + token_id * stride_qt - + offs_h[:, None] * stride_qh - + offs_d[None, :] * stride_qd, - mask=mask_h[:, None], - other=0.0, - ) - - split_size: tl.constexpr = tl.cdiv(num_candidates, num_splits) - split_start = split_id * split_size - split_end = tl.minimum(split_start + split_size, num_candidates) - - neg_large = -1.0e30 - e_max = tl.full((HEAD_BLOCK,), neg_large, dtype=tl.float32) - e_sum = tl.zeros((HEAD_BLOCK,), dtype=tl.float32) - acc = tl.zeros((HEAD_BLOCK, BLOCK_D), dtype=tl.float32) - - for cand_start in range(split_start, split_end, BLOCK_N): - offs_c = cand_start + tl.arange(0, BLOCK_N) - mask_c = offs_c < split_end - valid = tl.load( - valid_ptr + token_id * stride_vt + offs_c * stride_vc, - mask=mask_c, - other=0, - ) - mask_kv = mask_c & valid - k = tl.load( - kv_ptr - + token_id * stride_kvt - + offs_c[:, None] * stride_kvc - + offs_d[None, :] * stride_kvd, - mask=mask_kv[:, None], - other=0.0, - ) - qk = tl.dot(q, tl.trans(k.to(q.dtype))) * scale - qk = tl.where(mask_h[:, None] & mask_kv[None, :], qk, neg_large) - - n_e_max = tl.maximum(tl.max(qk, 1), e_max) - re_scale = tl.exp2(e_max - n_e_max) - p = tl.exp2(qk - n_e_max[:, None]) - acc *= re_scale[:, None] - acc += tl.dot(p.to(k.dtype), k) - e_sum = e_sum * re_scale + tl.sum(p, 1) - e_max = n_e_max - - e_sum_safe = tl.where(e_sum > 0, e_sum, 1.0) - mid_base = ( - mid_ptr - + token_id * stride_mt - + offs_h[:, None] * stride_mh - + split_id * stride_ms - ) - tl.store( - mid_base + offs_d[None, :], - acc / e_sum_safe[:, None], - mask=mask_h[:, None], - ) - tl.store( - mid_ptr - + token_id * stride_mt - + offs_h * stride_mh - + split_id * stride_ms - + BLOCK_D, - (e_max + tl.log2(e_sum)) * LOGE2_VALUE, - mask=mask_h, - ) - - -@triton.jit -def _splitkv_sparse_mla_merge_kernel( - mid_ptr, - sink_ptr, - output_ptr, - stride_mt: tl.constexpr, - stride_mh: tl.constexpr, - stride_ms: tl.constexpr, - stride_out_t: tl.constexpr, - stride_oh: tl.constexpr, - stride_od: tl.constexpr, - num_heads: tl.constexpr, - num_splits: tl.constexpr, - HEAD_BLOCK: tl.constexpr, - BLOCK_D: tl.constexpr, - BLOCK_D_TILE: tl.constexpr, -): - token_id = tl.program_id(0) - head_group = tl.program_id(1) - d_tile = tl.program_id(2) - - offs_h = head_group * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) - mask_h = offs_h < num_heads - offs_d = d_tile * BLOCK_D_TILE + tl.arange(0, BLOCK_D_TILE) - mask_d = offs_d < BLOCK_D - - e_max = tl.full((HEAD_BLOCK,), -float("inf"), dtype=tl.float32) - e_sum = tl.zeros((HEAD_BLOCK,), dtype=tl.float32) - acc = tl.zeros((HEAD_BLOCK, BLOCK_D_TILE), dtype=tl.float32) - mid_base = mid_ptr + token_id * stride_mt + offs_h[:, None] * stride_mh - mid_lse = mid_ptr + token_id * stride_mt + offs_h * stride_mh + BLOCK_D - - for split_id in range(num_splits): - part = tl.load( - mid_base + split_id * stride_ms + offs_d[None, :], - mask=mask_h[:, None] & mask_d[None, :], - other=0.0, - ) - lse = tl.load( - mid_lse + split_id * stride_ms, - mask=mask_h, - other=-float("inf"), - ) - n_e_max = tl.maximum(lse, e_max) - old_scale = tl.exp(e_max - n_e_max) - part_scale = tl.exp(lse - n_e_max) - acc = acc * old_scale[:, None] + part * part_scale[:, None] - e_sum = e_sum * old_scale + part_scale - e_max = n_e_max - - sink = tl.load(sink_ptr + offs_h, mask=mask_h, other=-float("inf")) - n_e_max = tl.maximum(sink, e_max) - value_scale = tl.exp(e_max - n_e_max) - sink_scale = tl.exp(sink - n_e_max) - denom = e_sum * value_scale + sink_scale - denom = tl.where(denom > 0, denom, 1.0) - merged = acc * value_scale[:, None] / denom[:, None] - - tl.store( - output_ptr - + token_id * stride_out_t - + offs_h[:, None] * stride_oh - + offs_d[None, :] * stride_od, - merged, - mask=mask_h[:, None] & mask_d[None, :], - ) - - -def splitkv_sparse_mla_attention_with_sink( - q: torch.Tensor, - kv: torch.Tensor, - valid_tokens: torch.Tensor, - scale: float, - attn_sink: torch.Tensor, - output: torch.Tensor, - mid: torch.Tensor, - num_splits: int, - num_heads: int | None = None, -) -> None: - if q.dim() == 4: - assert q.shape[1] == 1 - q = q[:, 0] - - assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" - assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" - assert valid_tokens.shape == kv.shape[:2] - assert q.shape[0] == kv.shape[0] - assert q.shape[-1] == kv.shape[-1] - assert output.shape[0] == q.shape[0] - assert output.shape[2] == q.shape[-1] - assert q.is_cuda and kv.is_cuda and valid_tokens.is_cuda - assert attn_sink.is_cuda and output.is_cuda and mid.is_cuda - - active_heads = num_heads if num_heads is not None else output.shape[1] - assert active_heads <= q.shape[1] - assert active_heads <= output.shape[1] - assert active_heads <= attn_sink.shape[0] - - num_tokens, _, head_dim = q.shape - num_candidates = kv.shape[1] - assert mid.shape == (num_tokens, active_heads, num_splits, head_dim + 1) - num_head_groups = triton.cdiv(active_heads, _SPLITKV_HEAD_BLOCK) - _splitkv_sparse_mla_stage1_kernel[(num_tokens, num_head_groups, num_splits)]( - q, - kv, - valid_tokens, - mid, - q.stride(0), - q.stride(1), - q.stride(2), - kv.stride(0), - kv.stride(1), - kv.stride(2), - valid_tokens.stride(0), - valid_tokens.stride(1), - mid.stride(0), - mid.stride(1), - mid.stride(2), - active_heads, - num_candidates, - scale * LOG2E, - num_splits, - HEAD_BLOCK=_SPLITKV_HEAD_BLOCK, - BLOCK_N=_SPLITKV_BLOCK_N, - BLOCK_D=head_dim, - LOGE2_VALUE=LOGE2, - num_warps=4, - ) - _splitkv_sparse_mla_merge_kernel[ - (num_tokens, active_heads, triton.cdiv(head_dim, _SPLITKV_MERGE_BLOCK_D)) - ]( - mid, - attn_sink, - output, - mid.stride(0), - mid.stride(1), - mid.stride(2), - output.stride(0), - output.stride(1), - output.stride(2), - active_heads, - num_splits, - HEAD_BLOCK=_SPLITKV_MERGE_HEAD_BLOCK, - BLOCK_D=head_dim, - BLOCK_D_TILE=_SPLITKV_MERGE_BLOCK_D, - num_warps=2, - ) - - @triton.jit def _merge_two_subsets_with_sink_kernel( out0_ptr, From a8bdc004cc187e4dd200e230e5cb42026282a7c2 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 21 May 2026 09:03:38 +0800 Subject: [PATCH 049/131] Limit long prefill chunks behind active decode Signed-off-by: jasl --- tests/v1/core/test_scheduler.py | 74 +++++++++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 32 ++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index b2825c34df86..7368d3bea11d 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -866,6 +866,80 @@ def test_schedule_order(enable_chunked_prefill: bool): assert len(scheduler_output1.scheduled_new_reqs) == 1 +def test_mixed_decode_prefill_does_not_cap_short_prefill(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=512, + max_num_seqs=2, + enable_chunked_prefill=True, + ) + decode_req = create_requests(num_requests=1, num_tokens=100, req_ids=["decode"])[0] + short_prefill_req = create_requests( + num_requests=1, + num_tokens=40, + req_ids=["short_prefill"], + )[0] + + scheduler.add_request(decode_req) + prefill_output = scheduler.schedule() + assert prefill_output.num_scheduled_tokens[decode_req.request_id] == 100 + + scheduler.update_from_output( + prefill_output, + ModelRunnerOutput( + req_ids=[decode_req.request_id], + req_id_to_index={decode_req.request_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(short_prefill_req) + mixed_output = scheduler.schedule() + + assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 + assert mixed_output.num_scheduled_tokens[short_prefill_req.request_id] == 40 + + +def test_mixed_decode_prefill_caps_long_prefill_chunk(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=512, + max_num_seqs=2, + enable_chunked_prefill=True, + ) + decode_req = create_requests(num_requests=1, num_tokens=100, req_ids=["decode"])[0] + long_prefill_req = create_requests( + num_requests=1, + num_tokens=300, + req_ids=["long_prefill"], + )[0] + + scheduler.add_request(decode_req) + prefill_output = scheduler.schedule() + assert prefill_output.num_scheduled_tokens[decode_req.request_id] == 100 + + scheduler.update_from_output( + prefill_output, + ModelRunnerOutput( + req_ids=[decode_req.request_id], + req_id_to_index={decode_req.request_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(long_prefill_req) + mixed_output = scheduler.schedule() + + assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 + assert mixed_output.num_scheduled_tokens[long_prefill_req.request_id] == 75 + + def test_preempt_during_execution(): # NOTE(woosuk): The actual number of available blocks is 10 instead of 11 # because block 0 is reserved as the null block. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 25ccf79bc3a7..6110fda46f67 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -384,6 +384,32 @@ def _mamba_block_aligned_split( num_new_tokens = num_new_tokens // block_size * block_size return num_new_tokens + def _has_scheduled_decode(self, requests: list[Request]) -> bool: + return any( + request.num_computed_tokens >= request.num_prompt_tokens + for request in requests + ) + + def _limit_mixed_decode_prefill_chunk( + self, + request: Request, + num_new_tokens: int, + scheduled_running_reqs: list[Request], + ) -> int: + if ( + not self.scheduler_config.enable_chunked_prefill + or request.num_computed_tokens >= request.num_prompt_tokens + or not self._has_scheduled_decode(scheduled_running_reqs) + ): + return num_new_tokens + + remaining_prefill = request.num_prompt_tokens - request.num_computed_tokens + if remaining_prefill <= self.max_num_scheduled_tokens: + return num_new_tokens + + mixed_prefill_budget = max(1, (self.max_num_scheduled_tokens * 3) // 4) + return min(num_new_tokens, mixed_prefill_budget) + def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: self.current_step += 1 # NOTE(woosuk) on the scheduling algorithm: @@ -467,6 +493,9 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) + num_new_tokens = self._limit_mixed_decode_prefill_chunk( + request, num_new_tokens, scheduled_running_reqs + ) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. @@ -808,6 +837,9 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: break num_new_tokens = min(num_new_tokens, token_budget) + num_new_tokens = self._limit_mixed_decode_prefill_chunk( + request, num_new_tokens, scheduled_running_reqs + ) assert num_new_tokens > 0 # Schedule encoder inputs. From 129e12921597c47229ebd34fa51a985a5ca6e03a Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 21 May 2026 17:47:38 +0800 Subject: [PATCH 050/131] Tighten mixed prefill cap for very long prompts Signed-off-by: jasl --- tests/v1/core/test_scheduler.py | 37 +++++++++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 11 +++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 7368d3bea11d..0ea790dd310a 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -940,6 +940,43 @@ def test_mixed_decode_prefill_caps_long_prefill_chunk(): assert mixed_output.num_scheduled_tokens[long_prefill_req.request_id] == 75 +def test_mixed_decode_prefill_caps_very_long_prefill_more_tightly(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=4096, + max_num_seqs=2, + enable_chunked_prefill=True, + ) + decode_req = create_requests(num_requests=1, num_tokens=100, req_ids=["decode"])[0] + very_long_prefill_req = create_requests( + num_requests=1, + num_tokens=2000, + req_ids=["very_long_prefill"], + )[0] + + scheduler.add_request(decode_req) + prefill_output = scheduler.schedule() + assert prefill_output.num_scheduled_tokens[decode_req.request_id] == 100 + + scheduler.update_from_output( + prefill_output, + ModelRunnerOutput( + req_ids=[decode_req.request_id], + req_id_to_index={decode_req.request_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(very_long_prefill_req) + mixed_output = scheduler.schedule() + + assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 + assert mixed_output.num_scheduled_tokens[very_long_prefill_req.request_id] == 50 + + def test_preempt_during_execution(): # NOTE(woosuk): The actual number of available blocks is 10 instead of 11 # because block 0 is reserved as the null block. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 6110fda46f67..e23ddea23a09 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -407,7 +407,16 @@ def _limit_mixed_decode_prefill_chunk( if remaining_prefill <= self.max_num_scheduled_tokens: return num_new_tokens - mixed_prefill_budget = max(1, (self.max_num_scheduled_tokens * 3) // 4) + # Very long prefills span many scheduling steps; a smaller chunk keeps + # already-active decoders from seeing long inter-token gaps. + very_long_prefill_steps = 16 + very_long_prefill_threshold = ( + self.max_num_scheduled_tokens * very_long_prefill_steps + ) + if remaining_prefill > very_long_prefill_threshold: + mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 2) + else: + mixed_prefill_budget = max(1, (self.max_num_scheduled_tokens * 3) // 4) return min(num_new_tokens, mixed_prefill_budget) def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: From a962bf1f2b444b78afa99f0add4e3c57b2313c4b Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 21 May 2026 19:15:48 +0800 Subject: [PATCH 051/131] Improve SM120 mixed prefill scheduling Signed-off-by: jasl --- .../test_sm120_deepgemm_fallbacks.py | 11 ++- tests/v1/core/test_scheduler.py | 80 ++++++++++++++++++- .../deepseek_v4/nvidia/ops/sm12x_mqa.py | 11 ++- vllm/v1/core/sched/scheduler.py | 23 ++++-- 4 files changed, 115 insertions(+), 10 deletions(-) diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py index 7bc8cc4df55e..a0e93d4943d3 100644 --- a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -9,12 +9,12 @@ _decode_logits_width, _decode_topk_logits_width, ) -from vllm.platforms import current_platform -from vllm.utils.math_utils import cdiv from vllm.models.deepseek_v4.nvidia.ops import ( sm12x_deep_gemm_fallbacks, sm12x_mqa, ) +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv def _make_indexer_kv_cache( @@ -126,6 +126,13 @@ def test_decode_topk_logits_width_keeps_topk_kernel_width(): assert _decode_topk_logits_width(0, 128, 512) == 0 +def test_sm120_direct_mqa_logits_block_m_prefers_short_prefill_tile(): + assert sm12x_mqa._fp8_mqa_logits_block_m(1024, 1024) == 16 + assert sm12x_mqa._fp8_mqa_logits_block_m(4096, 4096) == 16 + assert sm12x_mqa._fp8_mqa_logits_block_m(16384, 16384) == 16 + assert sm12x_mqa._fp8_mqa_logits_block_m(65536, 65536) == 64 + + @pytest.mark.skipif( not current_platform.is_device_capability_family(120), reason="SM120 only" ) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 0ea790dd310a..542c88041c63 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -937,7 +937,85 @@ def test_mixed_decode_prefill_caps_long_prefill_chunk(): mixed_output = scheduler.schedule() assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 + assert mixed_output.num_scheduled_tokens[long_prefill_req.request_id] == 25 + + +def test_running_long_prefill_leaves_budget_for_waiting_short_prefill(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=512, + max_num_seqs=2, + enable_chunked_prefill=True, + ) + long_prefill_req = create_requests( + num_requests=1, + num_tokens=300, + req_ids=["long_prefill"], + )[0] + short_prefill_req = create_requests( + num_requests=1, + num_tokens=20, + req_ids=["short_prefill"], + )[0] + + scheduler.add_request(long_prefill_req) + first_chunk = scheduler.schedule() + assert first_chunk.num_scheduled_tokens[long_prefill_req.request_id] == 100 + + scheduler.update_from_output( + first_chunk, + ModelRunnerOutput( + req_ids=[long_prefill_req.request_id], + req_id_to_index={long_prefill_req.request_id: 0}, + sampled_token_ids=[[]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(short_prefill_req) + mixed_output = scheduler.schedule() + assert mixed_output.num_scheduled_tokens[long_prefill_req.request_id] == 75 + assert mixed_output.num_scheduled_tokens[short_prefill_req.request_id] == 20 + + +def test_mixed_decode_prefill_caps_mid_long_prefill_more_tightly(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=1024, + max_num_seqs=2, + enable_chunked_prefill=True, + ) + decode_req = create_requests(num_requests=1, num_tokens=100, req_ids=["decode"])[0] + mid_long_prefill_req = create_requests( + num_requests=1, + num_tokens=500, + req_ids=["mid_long_prefill"], + )[0] + + scheduler.add_request(decode_req) + prefill_output = scheduler.schedule() + assert prefill_output.num_scheduled_tokens[decode_req.request_id] == 100 + + scheduler.update_from_output( + prefill_output, + ModelRunnerOutput( + req_ids=[decode_req.request_id], + req_id_to_index={decode_req.request_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(mid_long_prefill_req) + mixed_output = scheduler.schedule() + + assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 + assert mixed_output.num_scheduled_tokens[mid_long_prefill_req.request_id] == 12 def test_mixed_decode_prefill_caps_very_long_prefill_more_tightly(): @@ -974,7 +1052,7 @@ def test_mixed_decode_prefill_caps_very_long_prefill_more_tightly(): mixed_output = scheduler.schedule() assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 - assert mixed_output.num_scheduled_tokens[very_long_prefill_req.request_id] == 50 + assert mixed_output.num_scheduled_tokens[very_long_prefill_req.request_id] == 12 def test_preempt_during_execution(): diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index d589d89390e9..d308fa6a6268 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -147,7 +147,8 @@ def fp8_mqa_logits_triton( if num_q == 0 or seq_len_kv == 0: return logits - grid = (triton.cdiv(num_q, 64), triton.cdiv(seq_len_kv, 128)) + block_m = _fp8_mqa_logits_block_m(num_q, seq_len_kv) + grid = (triton.cdiv(num_q, block_m), triton.cdiv(seq_len_kv, 128)) _fp8_mqa_logits_kernel[grid]( q, k_fp8, @@ -169,7 +170,7 @@ def fp8_mqa_logits_triton( weights.stride(1), logits.stride(0), logits.stride(1), - BLOCK_M=64, + BLOCK_M=block_m, BLOCK_N=128, BLOCK_D=64, num_warps=4, @@ -177,6 +178,12 @@ def fp8_mqa_logits_triton( return logits +def _fp8_mqa_logits_block_m(num_q: int, seq_len_kv: int) -> int: + if seq_len_kv <= 16 * 1024: + return 16 + return 64 + + @triton.jit def _fp8_paged_mqa_logits_kernel( q_ptr, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e23ddea23a09..e7b4f11447f9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -395,25 +395,35 @@ def _limit_mixed_decode_prefill_chunk( request: Request, num_new_tokens: int, scheduled_running_reqs: list[Request], + has_waiting_requests: bool = False, ) -> int: if ( not self.scheduler_config.enable_chunked_prefill or request.num_computed_tokens >= request.num_prompt_tokens - or not self._has_scheduled_decode(scheduled_running_reqs) ): return num_new_tokens + has_scheduled_decode = self._has_scheduled_decode(scheduled_running_reqs) + if not has_scheduled_decode and not has_waiting_requests: + return num_new_tokens + remaining_prefill = request.num_prompt_tokens - request.num_computed_tokens if remaining_prefill <= self.max_num_scheduled_tokens: return num_new_tokens # Very long prefills span many scheduling steps; a smaller chunk keeps - # already-active decoders from seeing long inter-token gaps. - very_long_prefill_steps = 16 + # already-active decoders from seeing long inter-token gaps and leaves + # room for short requests that arrive behind an active long prefill. + very_long_prefill_steps = 4 very_long_prefill_threshold = ( self.max_num_scheduled_tokens * very_long_prefill_steps ) - if remaining_prefill > very_long_prefill_threshold: + if has_scheduled_decode: + if remaining_prefill > very_long_prefill_threshold: + mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 8) + else: + mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 4) + elif remaining_prefill > very_long_prefill_threshold: mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 2) else: mixed_prefill_budget = max(1, (self.max_num_scheduled_tokens * 3) // 4) @@ -503,7 +513,10 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = self._limit_mixed_decode_prefill_chunk( - request, num_new_tokens, scheduled_running_reqs + request, + num_new_tokens, + scheduled_running_reqs, + bool(self.waiting or self.skipped_waiting), ) # Make sure the input position does not exceed the max model len. From b02f6834a1cf7b9df35d15ef865742aa19453b43 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 22 May 2026 18:50:21 +0800 Subject: [PATCH 052/131] Clean up DeepSeek V4 reasoning parser lint Signed-off-by: jasl --- tests/reasoning/test_deepseekv4_reasoning_parser.py | 6 ++++-- vllm/reasoning/deepseek_v4_reasoning_parser.py | 4 +--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/reasoning/test_deepseekv4_reasoning_parser.py b/tests/reasoning/test_deepseekv4_reasoning_parser.py index 800f1285a352..f72fe27bbfa6 100644 --- a/tests/reasoning/test_deepseekv4_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv4_reasoning_parser.py @@ -14,7 +14,6 @@ import pytest -from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.reasoning import ReasoningParserManager from vllm.reasoning.deepseek_v4_reasoning_parser import ( DeepSeekV4ReasoningParser, @@ -293,7 +292,10 @@ def test_is_reasoning_end_streaming_sticky_after_split(parser): ) assert parser._implicit_end_seen is True # Now a plain content-only delta. - assert parser.is_reasoning_end_streaming([800, DSML_MARKER_TOKEN_ID, 900], [900]) is True + assert ( + parser.is_reasoning_end_streaming([800, DSML_MARKER_TOKEN_ID, 900], [900]) + is True + ) # --------------------------------------------------------------------------- diff --git a/vllm/reasoning/deepseek_v4_reasoning_parser.py b/vllm/reasoning/deepseek_v4_reasoning_parser.py index 373efb9260a1..185af64a37e2 100644 --- a/vllm/reasoning/deepseek_v4_reasoning_parser.py +++ b/vllm/reasoning/deepseek_v4_reasoning_parser.py @@ -98,9 +98,7 @@ def is_reasoning_end_streaming( ) -> bool: if super().is_reasoning_end_streaming(input_ids, delta_ids): return True - if self._implicit_end_seen: - return True - return False + return bool(self._implicit_end_seen) def extract_reasoning_streaming( self, From 54cdf345c2abdc19a674b18168620e8f39a2f4cd Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 23 May 2026 17:18:05 +0800 Subject: [PATCH 053/131] Add DeepSeek V4 prefix cache pressure regression Signed-off-by: jasl --- tests/v1/core/test_prefix_caching.py | 61 ++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index d2f231d034b0..d412fd81e502 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -3637,6 +3637,67 @@ def test_deepseek_v4_mla_cached_prompts_do_not_block_admission(): ) +@pytest.mark.skip_global_cleanup +def test_deepseek_v4_mla_prefix_hit_under_pressure_does_not_overallocate(): + block_size = 8 + prompt_tokens = 5 * block_size + manager = KVCacheManager( + KVCacheConfig( + num_blocks=12, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + + warm_req = make_request("warm", list(range(prompt_tokens)), block_size, sha256) + assert manager.allocate_slots(warm_req, prompt_tokens) is not None + warm_req.num_computed_tokens = prompt_tokens + manager.free(warm_req) + + pressure_blocks = manager.block_pool.get_new_blocks( + manager.block_pool.get_num_free_blocks() - 2 + ) + + reuse_req = make_request( + "reuse", + list(range(prompt_tokens)) + list(range(10_000, 10_000 + 2 * block_size)), + block_size, + sha256, + ) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(reuse_req) + + try: + assert ( + manager.allocate_slots( + reuse_req, + 2 * block_size, + num_new_computed_tokens=num_computed_tokens, + new_computed_blocks=computed_blocks, + ) + is None + ) + req_blocks = manager.coordinator.single_type_managers[0].req_to_blocks + assert "reuse" not in req_blocks + finally: + manager.block_pool.free_blocks(reversed(pressure_blocks)) + + def test_reset_prefix_cache_after_deepseek_v4_mla_prompt_cache(): block_size = 8 prompt_tokens = 4 * block_size + 3 From 75ba377b87f631afddd5a6625df62351b6636c8b Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 23 May 2026 21:21:55 +0800 Subject: [PATCH 054/131] Keep hybrid prefix cache tail blocks Signed-off-by: jasl --- tests/v1/core/test_prefix_caching.py | 33 ++++++++++++++++------------ vllm/v1/core/kv_cache_coordinator.py | 23 ++++--------------- 2 files changed, 23 insertions(+), 33 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index d412fd81e502..63390268bab6 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -2926,12 +2926,14 @@ def test_hybrid_cache_blocks_swa_tail_window_only(): ) -def test_hybrid_cache_blocks_clamped_to_lcm(): - """HybridKVCacheCoordinator.cache_blocks() clamps to scheduler_block_size. - Chunks past the last lcm-aligned boundary can never participate in a - cache hit (find_longest_cache_hit always returns lcm-aligned hits), so - caching them only pollutes the prefix-cache hash map and keeps blocks - on the LRU list that could otherwise return to the free pool.""" +def test_hybrid_cache_blocks_keeps_tail_blocks_but_hits_stay_lcm_aligned(): + """HybridKVCacheCoordinator.cache_blocks() keeps complete tail blocks. + + find_longest_cache_hit still returns lcm-aligned hits, so caching the + complete blocks past the last lcm boundary must not produce partial hybrid + cache hits. Keeping those blocks lets later turns complete a reusable + lcm-aligned segment instead of permanently dropping the tail. + """ block_size = 16 # Full attn block_size=32, SWA block_size=16 -> lcm=32. kv_cache_config = KVCacheConfig( @@ -2966,8 +2968,9 @@ def test_hybrid_cache_blocks_clamped_to_lcm(): hash_block_size=block_size, ) - # 7 hash-blocks of 16 tokens (112 tokens). With lcm=32 the clamp truncates - # to 96 tokens — SWA caches 6 hashes, full-attn caches 3. + # 7 hash-blocks of 16 tokens (112 tokens). The trailing SWA block is a + # complete block and should be cached, but a later hit is still capped by + # the full-attention group at the last 32-token lcm boundary. token_ids = [i for i in range(7) for _ in range(block_size)] req = make_request("0", token_ids, block_size, sha256) computed_blocks, _ = manager.get_computed_blocks(req) @@ -2981,16 +2984,18 @@ def test_hybrid_cache_blocks_clamped_to_lcm(): assert len(req.block_hashes) == 7 pool = manager.block_pool - # SWA group_id=1: hashes 0..5 cached (6 blocks * 16 tokens = 96), hash 6 - # spans tokens [96, 112) past the lcm boundary and must NOT be cached. - for i in range(6): + # SWA group_id=1: all complete SWA blocks should be cached. + for i in range(7): assert ( pool.get_cached_block(req.block_hashes[i], kv_cache_group_ids=[1]) is not None ), f"SWA hash {i} should be cached" - assert pool.get_cached_block(req.block_hashes[6], kv_cache_group_ids=[1]) is None, ( - "SWA hash 6 spans tokens past the lcm boundary; should not be cached" - ) + + warm = make_request("1", token_ids, block_size, sha256) + computed_blocks, num_computed_tokens = manager.get_computed_blocks(warm) + assert len(computed_blocks.blocks[0]) == 3 + assert len(computed_blocks.blocks[1]) == 6 + assert num_computed_tokens == 6 * block_size def test_hybrid_local_kv_retention_interval_aligns_in_manager(monkeypatch): diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 452146f9a97a..b5b109d776c4 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -627,28 +627,13 @@ def verify_and_split_kv_cache_groups(self) -> None: def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: # Cache hits in this coordinator are always a multiple of # ``scheduler_block_size`` tokens (see ``find_longest_cache_hit``). - # Within an aligned region, SWA groups may only consult a subset of blocks - # per ``scheduler_block_size``-segment so the unused blocks also stay - # out of the prefix-cache hash map. - aligned_num_computed_tokens = ( - num_computed_tokens // self.scheduler_block_size * self.scheduler_block_size - ) + # Managers may still cache complete tail blocks after the last aligned + # boundary; ``find_longest_cache_hit`` keeps returned hits aligned. for manager in self.single_type_managers: - num_tokens_to_cache = aligned_num_computed_tokens - # EAGLE groups match one block past each aligned boundary and drop - # it, so make that lookahead block eligible to be cached. - if manager.use_eagle and aligned_num_computed_tokens > 0: - num_tokens_to_cache = min( - num_computed_tokens, - aligned_num_computed_tokens + manager.block_size, - ) - # The manager already knows the fine hit granularity - # (``scheduler_block_size``); retention is passed separately so it - # can keep both the coarse segment tails and the fine replay - # boundary (which needs the fine value). manager.cache_blocks( request, - num_tokens_to_cache, + num_computed_tokens, + alignment_tokens=self.scheduler_block_size, retention_interval=self.retention_interval, ) From ea529148293f60818a04bf7107f5e0a68c2f50a5 Mon Sep 17 00:00:00 2001 From: jasl Date: Sun, 24 May 2026 20:25:12 +0800 Subject: [PATCH 055/131] Stabilize SM12x sparse MLA long prefill Signed-off-by: jasl --- tests/v1/attention/test_sparse_mla_env.py | 75 ++++ vllm/models/deepseek_v4/nvidia/flashmla.py | 348 ++++++++++++++++++ .../attention/backends/mla/sparse_mla_env.py | 28 ++ 3 files changed, 451 insertions(+) create mode 100644 tests/v1/attention/test_sparse_mla_env.py diff --git a/tests/v1/attention/test_sparse_mla_env.py b/tests/v1/attention/test_sparse_mla_env.py new file mode 100644 index 000000000000..74e48c60c61e --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_env.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for sparse MLA environment helpers.""" + +from vllm.v1.attention.backends.mla import sparse_mla_env + + +def test_prefill_topk_uses_sm12x_multi_request_guard(monkeypatch): + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) + monkeypatch.setattr( + sparse_mla_env.current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=1152, + compress_ratio=128, + request_count=2, + ) + == 256 + ) + + +def test_prefill_topk_keeps_default_for_lower_risk_shapes(monkeypatch): + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) + monkeypatch.setattr( + sparse_mla_env.current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=1152, + compress_ratio=128, + request_count=1, + ) + == 512 + ) + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=640, + compress_ratio=4, + request_count=2, + ) + == 512 + ) + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=128, + compress_ratio=1, + request_count=2, + ) + == 128 + ) + + +def test_prefill_topk_honors_explicit_env_override(monkeypatch): + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", "512") + monkeypatch.setattr( + sparse_mla_env.current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + + assert ( + sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=1152, + compress_ratio=128, + request_count=2, + ) + == 512 + ) diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 9fa4e1c11b94..3e8bd7e6c503 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -20,6 +20,25 @@ DeepseekV4FlashMLABackend, DeepseekV4FlashMLAMetadata, ) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled, + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_matmul_decode_enabled, + triton_sparse_mla_prefill_topk_chunk_size, + triton_sparse_mla_query_chunk_size, + triton_sparse_mla_topk_chunk_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_indexed_sparse_mla_attention_chunk, + build_combined_sparse_mla_decode_valid_mask, + finish_sparse_mla_attention_with_sink, + finish_two_sparse_mla_attention_states_with_sink, + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead, + fp8ds_paged_sparse_mla_attention_with_sink_multihead, + matmul_sparse_mla_attention_with_sink, + sparse_mla_decode_head_block_size, +) from vllm.v1.attention.ops.flashmla import ( flash_mla_sparse_fwd, flash_mla_with_kvcache, @@ -142,6 +161,335 @@ def forward_mqa( output=output[:num_decode_tokens], ) + @classmethod + def _forward_sparse_mla_swa_decode_triton( + cls, + layer: "DeepseekV4FlashMLAAttention", + q: torch.Tensor, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + output: torch.Tensor, + ) -> None: + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if not mtp_decode: + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=max_swa_len, + scale=layer.scale, + attn_sink=layer.attn_sink, + output=output, + head_block_size=head_block_size, + num_heads=layer.n_local_heads, + ) + if output.shape[1] > layer.n_local_heads: + output[:, layer.n_local_heads :].zero_() + return + + ( + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads, q.shape[-1]), torch.float32), + ) + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=layer.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_sparse_mla_attention_with_sink( + swa_max_score, + swa_denom, + swa_acc, + layer.attn_sink, + output=output, + ) + if output.shape[1] > layer.n_local_heads: + output[:, layer.n_local_heads :].zero_() + + @classmethod + def _forward_sparse_mla_compressed_decode_triton( + cls, + layer: "DeepseekV4FlashMLAAttention", + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + swa_k_cache: torch.Tensor, + topk_indices: torch.Tensor, + topk_lens: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata, + output: torch.Tensor, + ) -> None: + if layer.compress_ratio not in (4, 128): + raise NotImplementedError( + "Triton sparse MLA compressed decode currently supports " + f"compress_ratio=4 or 128, got {layer.compress_ratio}" + ) + + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + compressed_block_size = attn_metadata.block_size // layer.compress_ratio + compressed_topk = topk_indices.shape[-1] + topk_chunk_size = min( + compressed_topk, + triton_sparse_mla_topk_chunk_size(), + ) + compressed_slot_ids = topk_indices[:, 0, :] + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if ( + compressed_topk <= topk_chunk_size + and triton_sparse_mla_matmul_decode_enabled() + ): + total_candidates = compressed_topk + max_swa_len + ( + combined_kv, + valid_tokens, + score_buffer, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, total_candidates, q.shape[-1]), torch.bfloat16), + ((num_decode_tokens, total_candidates), torch.bool), + ( + (num_decode_tokens, layer.n_local_heads, total_candidates), + torch.bfloat16, + ), + ) + if mtp_decode: + dequantize_global_slots_k_cache( + combined_kv[:, :compressed_topk], + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + ) + dequantize_global_slots_k_cache( + combined_kv[:, compressed_topk:], + swa_k_cache, + swa_indices, + swa_metadata.block_size, + ) + else: + dequantize_combined_sparse_mla_decode_kv( + combined_kv, + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + swa_k_cache, + swa_metadata.seq_lens[:num_decodes], + swa_lens, + swa_metadata.block_table[:num_decodes], + swa_metadata.block_size, + ) + + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + use_dot_finish = num_decode_tokens <= 16 + matmul_sparse_mla_attention_with_sink( + q=q, + kv=combined_kv, + valid_tokens=valid_tokens, + scale=layer.scale, + attn_sink=layer.attn_sink, + output=output, + num_heads=layer.n_local_heads, + score_buffer=score_buffer, + value_block_size=512 if use_dot_finish else 256, + candidate_block_size=128 if use_dot_finish else None, + ) + return + + if not mtp_decode and compressed_topk <= topk_chunk_size: + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + swa_block_size=swa_metadata.block_size, + num_compressed_candidates=compressed_topk, + num_swa_candidates=max_swa_len, + scale=layer.scale, + attn_sink=layer.attn_sink, + output=output, + head_block_size=head_block_size, + num_heads=layer.n_local_heads, + ) + if output.shape[1] > layer.n_local_heads: + output[:, layer.n_local_heads :].zero_() + return + + ( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads, q.shape[-1]), torch.float32), + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads), torch.float32), + ((num_decode_tokens, layer.n_local_heads, q.shape[-1]), torch.float32), + ) + comp_max_score.fill_(float("-inf")) + comp_denom.zero_() + comp_acc.zero_() + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + + for chunk_start in range(0, compressed_topk, topk_chunk_size): + chunk_end = min(chunk_start + topk_chunk_size, compressed_topk) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids[:, chunk_start:chunk_end], + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=chunk_start, + scale=layer.scale, + max_score=comp_max_score, + denom=comp_denom, + acc=comp_acc, + head_block_size=head_block_size, + ) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=layer.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_two_sparse_mla_attention_states_with_sink( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + layer.attn_sink, + output=output, + ) + if output.shape[1] > layer.n_local_heads: + output[:, layer.n_local_heads :].zero_() + + @classmethod + def _forward_sparse_mla_prefill_triton( + cls, + layer: "DeepseekV4FlashMLAAttention", + q: torch.Tensor, + kv: torch.Tensor, + combined_indices: torch.Tensor, + combined_lens: torch.Tensor, + output: torch.Tensor, + state_buffers: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> None: + kv_flat = kv.reshape(-1, q.shape[-1]) + topk_chunk_size = triton_sparse_mla_prefill_topk_chunk_size( + combined_topk_size=combined_indices.shape[-1], + compress_ratio=int(layer.compress_ratio), + request_count=kv.shape[0], + ) + query_chunk_size = min( + q.shape[0], + triton_sparse_mla_query_chunk_size(), + ) + if state_buffers is None: + ( + max_score_buffer, + denom_buffer, + output_buffer, + ) = current_workspace_manager().get_simultaneous( + ((query_chunk_size, layer.n_local_heads), torch.float32), + ((query_chunk_size, layer.n_local_heads), torch.float32), + ((query_chunk_size, layer.n_local_heads, q.shape[-1]), torch.float32), + ) + else: + max_score_buffer, denom_buffer, output_buffer = state_buffers + + for token_start in range(0, q.shape[0], query_chunk_size): + token_end = min(token_start + query_chunk_size, q.shape[0]) + q_chunk = q[token_start:token_end] + indices_chunk_full = combined_indices[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + num_tokens = token_end - token_start + max_score = max_score_buffer[:num_tokens] + denom = denom_buffer[:num_tokens] + subset_acc = output_buffer[:num_tokens] + max_score.fill_(float("-inf")) + denom.zero_() + subset_acc.zero_() + + for index_start in range(0, combined_indices.shape[-1], topk_chunk_size): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=layer.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + ) + + finish_sparse_mla_attention_with_sink( + max_score, + denom, + subset_acc, + layer.attn_sink, + output=output[token_start:token_end], + ) + if output.shape[1] > layer.n_local_heads: + output[token_start:token_end, layer.n_local_heads :].zero_() + def _forward_decode( self, q: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py index 433d69413bcd..e2e739f12dff 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_env.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Environment controls for the portable Triton sparse MLA path.""" +import os + import torch import vllm.envs as envs @@ -42,6 +44,32 @@ def triton_sparse_mla_topk_chunk_size() -> int: return envs.VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE +def triton_sparse_mla_prefill_topk_chunk_size( + *, + combined_topk_size: int, + compress_ratio: int, + request_count: int, +) -> int: + """Choose the Triton sparse MLA prefill topk chunk size. + + Keep explicit user overrides authoritative. The auto path keeps the + previous 512-token chunk except for the SM12x C128A multi-request shape + that is unstable near 128K context. + """ + + configured_topk = triton_sparse_mla_topk_chunk_size() + if os.getenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE") is not None: + return min(combined_topk_size, configured_topk) + if ( + current_platform.is_device_capability_family(120) + and compress_ratio == 128 + and request_count > 1 + and combined_topk_size > 1024 + ): + configured_topk = min(configured_topk, 256) + return min(combined_topk_size, configured_topk) + + def triton_sparse_mla_query_chunk_size() -> int: return envs.VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE From 44c90d2ea39ddf2992936da586015d351a3a0079 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 25 May 2026 03:26:19 +0800 Subject: [PATCH 056/131] Tune SM12x sparse MLA single prefill topk Signed-off-by: jasl --- tests/v1/attention/test_sparse_mla_env.py | 14 ++++++++++++-- .../attention/backends/mla/sparse_mla_env.py | 19 +++++++++---------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/v1/attention/test_sparse_mla_env.py b/tests/v1/attention/test_sparse_mla_env.py index 74e48c60c61e..b50d5f1b9dd8 100644 --- a/tests/v1/attention/test_sparse_mla_env.py +++ b/tests/v1/attention/test_sparse_mla_env.py @@ -23,7 +23,7 @@ def test_prefill_topk_uses_sm12x_multi_request_guard(monkeypatch): ) -def test_prefill_topk_keeps_default_for_lower_risk_shapes(monkeypatch): +def test_prefill_topk_relaxes_sm12x_single_request_c128a(monkeypatch): monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) monkeypatch.setattr( sparse_mla_env.current_platform, @@ -37,8 +37,18 @@ def test_prefill_topk_keeps_default_for_lower_risk_shapes(monkeypatch): compress_ratio=128, request_count=1, ) - == 512 + == 1024 + ) + + +def test_prefill_topk_keeps_default_for_other_lower_risk_shapes(monkeypatch): + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) + monkeypatch.setattr( + sparse_mla_env.current_platform, + "is_device_capability_family", + lambda family: family == 120, ) + assert ( sparse_mla_env.triton_sparse_mla_prefill_topk_chunk_size( combined_topk_size=640, diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py index e2e739f12dff..92273b55111f 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_env.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -52,21 +52,20 @@ def triton_sparse_mla_prefill_topk_chunk_size( ) -> int: """Choose the Triton sparse MLA prefill topk chunk size. - Keep explicit user overrides authoritative. The auto path keeps the - previous 512-token chunk except for the SM12x C128A multi-request shape - that is unstable near 128K context. + Keep explicit user overrides authoritative. The auto path uses a larger + chunk for SM12x C128A single-request prefill to reduce per-request loop + overhead, but keeps a smaller chunk for the multi-request shape that is + unstable near 128K context. """ configured_topk = triton_sparse_mla_topk_chunk_size() if os.getenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE") is not None: return min(combined_topk_size, configured_topk) - if ( - current_platform.is_device_capability_family(120) - and compress_ratio == 128 - and request_count > 1 - and combined_topk_size > 1024 - ): - configured_topk = min(configured_topk, 256) + if current_platform.is_device_capability_family(120) and compress_ratio == 128: + if request_count > 1 and combined_topk_size > 1024: + configured_topk = min(configured_topk, 256) + elif request_count == 1 and combined_topk_size > 1024: + configured_topk = max(configured_topk, 1024) return min(combined_topk_size, configured_topk) From 1059c811720b2f931feae846cf9aad42f22217a4 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 25 May 2026 08:32:53 +0800 Subject: [PATCH 057/131] Protect active decode from very long prefill Signed-off-by: jasl --- tests/v1/core/test_scheduler.py | 4 ++-- vllm/v1/core/sched/scheduler.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 542c88041c63..ed57546115b9 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1015,7 +1015,7 @@ def test_mixed_decode_prefill_caps_mid_long_prefill_more_tightly(): mixed_output = scheduler.schedule() assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 - assert mixed_output.num_scheduled_tokens[mid_long_prefill_req.request_id] == 12 + assert mixed_output.num_scheduled_tokens[mid_long_prefill_req.request_id] == 6 def test_mixed_decode_prefill_caps_very_long_prefill_more_tightly(): @@ -1052,7 +1052,7 @@ def test_mixed_decode_prefill_caps_very_long_prefill_more_tightly(): mixed_output = scheduler.schedule() assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 - assert mixed_output.num_scheduled_tokens[very_long_prefill_req.request_id] == 12 + assert mixed_output.num_scheduled_tokens[very_long_prefill_req.request_id] == 6 def test_preempt_during_execution(): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e7b4f11447f9..01216a898599 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -420,7 +420,7 @@ def _limit_mixed_decode_prefill_chunk( ) if has_scheduled_decode: if remaining_prefill > very_long_prefill_threshold: - mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 8) + mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 16) else: mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 4) elif remaining_prefill > very_long_prefill_threshold: From 650de0698a2e5554812ecdb968b72f21dc162e80 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 27 May 2026 17:56:01 +0800 Subject: [PATCH 058/131] Clean sparse SWA imports after rebase Signed-off-by: jasl --- vllm/v1/attention/backends/mla/sparse_swa.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 3788f59b1a20..4455cba25e90 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -22,7 +22,6 @@ from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata from vllm.v1.kv_cache_interface import ( - AttentionSpec, KVCacheSpec, MLAAttentionSpec, SlidingWindowMLASpec, From f169084f4bf034fe1ba51c18e68e7a3f7a21d7e9 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 27 May 2026 18:50:13 +0800 Subject: [PATCH 059/131] Guard SM120 FP4 sparse indexer dependency Signed-off-by: jasl --- .../test_sm120_deepgemm_fallbacks.py | 32 +++++++++++++++++++ .../layers/sparse_attn_indexer.py | 16 +++++++--- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py index a0e93d4943d3..08c0a9b752e3 100644 --- a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -8,6 +8,7 @@ from vllm.model_executor.layers.sparse_attn_indexer import ( _decode_logits_width, _decode_topk_logits_width, + _sparse_indexer_requires_deep_gemm, ) from vllm.models.deepseek_v4.nvidia.ops import ( sm12x_deep_gemm_fallbacks, @@ -126,6 +127,37 @@ def test_decode_topk_logits_width_keeps_topk_kernel_width(): assert _decode_topk_logits_width(0, 128, 512) == 0 +def test_sparse_indexer_requires_deep_gemm_for_sm120_fp4_cache(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_cuda", + lambda: True, + ) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + + assert not _sparse_indexer_requires_deep_gemm(use_fp4_cache=False) + assert _sparse_indexer_requires_deep_gemm(use_fp4_cache=True) + + +def test_sparse_indexer_requires_deep_gemm_for_other_cuda_arches(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_cuda", + lambda: True, + ) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda family: False, + ) + + assert _sparse_indexer_requires_deep_gemm(use_fp4_cache=False) + + def test_sm120_direct_mqa_logits_block_m_prefers_short_prefill_tile(): assert sm12x_mqa._fp8_mqa_logits_block_m(1024, 1024) == 16 assert sm12x_mqa._fp8_mqa_logits_block_m(4096, 4096) == 16 diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index d44572b96f69..a8372ed7d43a 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -80,10 +80,16 @@ def _decode_topk_logits_width( return min(max_model_len, max(logits_width, topk_tokens)) -def _sparse_indexer_requires_deep_gemm() -> bool: - return current_platform.is_cuda() and not ( - current_platform.is_device_capability_family(120) - ) +def _sparse_indexer_requires_deep_gemm(use_fp4_cache: bool = False) -> bool: + if not current_platform.is_cuda(): + return False + if current_platform.is_device_capability_family(120): + # The SM120 fallback path covers FP8-Q sparse indexer calls. FP4-Q + # indexer calls still route through DeepGEMM's fp8_fp4 kernels, so + # fail during construction instead of letting the first forward hit + # the generic DeepGEMM missing-dependency error. + return use_fp4_cache + return True def _gather_workspace_shapes( @@ -542,7 +548,7 @@ def __init__( self.topk_indices_buffer = topk_indices_buffer self.skip_k_cache_insert = skip_k_cache_insert self.use_fp4_cache = use_fp4_cache - if _sparse_indexer_requires_deep_gemm() and not has_deep_gemm(): + if _sparse_indexer_requires_deep_gemm(use_fp4_cache) and not has_deep_gemm(): raise RuntimeError( "Sparse Attention Indexer CUDA op requires DeepGEMM support in " "the current vLLM environment." From d727d81eed11d8c6518272480cd1aa0d9c133ce8 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 28 May 2026 00:39:18 +0800 Subject: [PATCH 060/131] Absorb SM120 external Marlin fixes Refuse block-FP8 layers in the Marlin FP8 kernel selector so DSv4 block-FP8 compressor layers fall through to a block-FP8-capable path even when Marlin is forced for W4A16/NVFP4 MoE layers. Accept both the Marlin-renamed wo_a weight_scale_inv attribute and the non-Marlin weight_scale attribute in DeepSeek V4 MLA setup. Also absorb the Marlin MoE CUDA graph hardening from vllm-project/vllm#43730: keep the device shared-memory maximum for cudaFuncSetAttribute, launch with the per-config shared-memory size, and allocate c_tmp using a graph-stable upper bound. Refs: vllm-project/vllm#43722 Refs: vllm-project/vllm#43723 Refs: vllm-project/vllm#43730 Co-authored-by: pasta-paul Co-authored-by: haosdent Signed-off-by: jasl --- .../moe/marlin_moe_wna16/ops.cu | 10 +- tests/kernels/moe/test_moe.py | 113 ++++++++++++++++++ .../test_fp8_marlin_kernel_selection.py | 59 +++++++++ .../kernels/linear/scaled_mm/marlin.py | 16 +++ vllm/models/deepseek_v4/nvidia/ops/o_proj.py | 7 +- 5 files changed, 199 insertions(+), 6 deletions(-) create mode 100644 tests/model_executor/test_fp8_marlin_kernel_selection.py diff --git a/csrc/libtorch_stable/moe/marlin_moe_wna16/ops.cu b/csrc/libtorch_stable/moe/marlin_moe_wna16/ops.cu index 177eefa2c6f0..8fd63140e6d8 100644 --- a/csrc/libtorch_stable/moe/marlin_moe_wna16/ops.cu +++ b/csrc/libtorch_stable/moe/marlin_moe_wna16/ops.cu @@ -437,6 +437,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); STD_TORCH_CHECK(max_shared_mem > 0); + int device_max_shared_mem = max_shared_mem; int major_capability, minor_capability; cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor, @@ -527,10 +528,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_mem); + device_max_shared_mem); // avoid ">>>" being formatted to "> > >" // clang-format off - kernel<<>>( + kernel<<>>( A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m, @@ -708,9 +709,8 @@ torch::stable::Tensor moe_wna16_marlin_gemm( torch::stable::Tensor c_tmp; if (use_fp32_reduce && !use_atomic_add) { // max num of threadblocks is sms * 4 - long max_c_tmp_size = min( - (long)size_n * sorted_token_ids.size(0), - (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n); + long max_c_tmp_size = + (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n; if (moe_block_size == 8) max_c_tmp_size *= 2; c_tmp = torch::stable::new_empty(a, {max_c_tmp_size}, kFloat); } else { diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 45cd17b3b115..750a99b20bba 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -1009,6 +1009,119 @@ def test_fused_marlin_moe( torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +def test_fused_marlin_moe_cuda_graph(): + """Test that GPTQ Marlin MoE works correctly with CUDA graphs. + + Regression test for https://github.com/vllm-project/vllm/issues/36811 + The bug was caused by: + 1. cudaFuncSetAttribute(MaxDynamicSharedMemorySize) being overwritten + by later CUDA graph captures with different blocks_per_sm values. + 2. Batch-dependent c_tmp buffer sizing via sorted_token_ids.size(0). + """ + torch.cuda.manual_seed(42) + + # Use 64 experts like Qwen3.5-35B-A3B to trigger the bug + e, topk = 64, 8 + n, k = 1024, 1024 + group_size = 128 + quant_type = scalar_types.uint4b8 + dtype = torch.half + + # Batch sizes matching vLLM's CUDA graph capture sizes + batch_sizes = [1, 2, 4, 8, 16, 24, 32] + + # Create weights (shared across all batch sizes) + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + + w1_data = MarlinMoEWeightData.make( + w=w1, + quant_type=quant_type, + group_size=group_size, + act_order=False, + ) + w2_data = MarlinMoEWeightData.make( + w=w2, + quant_type=quant_type, + group_size=group_size, + act_order=False, + ) + + # Capture a CUDA graph for each batch size + graphs = {} + static_inputs = {} + static_outputs = {} + + for m in batch_sizes: + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + + # Static input tensors for graph replay + a_static = a.clone() + topk_weights_static = topk_weights.clone() + topk_ids_static = topk_ids.clone() + + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.stream(stream), torch.cuda.graph(graph): + out = fused_marlin_moe( + a_static, + w1_data.qweight, + w2_data.qweight, + None, + None, + w1_data.scales, + w2_data.scales, + topk_weights_static, + topk_ids_static, + global_num_experts=e, + quant_type_id=quant_type.id, + is_k_full=True, + ) + + graphs[m] = graph + static_inputs[m] = (a_static, topk_weights_static, topk_ids_static) + static_outputs[m] = out + + torch.accelerator.synchronize() + + # Replay each graph and compare against eager mode. + # The bug manifested when replaying small batch size graphs (e.g., M=1) + # after larger batch sizes had overwritten cudaFuncSetAttribute. + for m in batch_sizes: + a_static, topk_weights_static, topk_ids_static = static_inputs[m] + + # Compute eager reference with the same inputs + eager_out = fused_marlin_moe( + a_static, + w1_data.qweight, + w2_data.qweight, + None, + None, + w1_data.scales, + w2_data.scales, + topk_weights_static, + topk_ids_static, + global_num_experts=e, + quant_type_id=quant_type.id, + is_k_full=True, + ) + + # Replay the captured graph + static_outputs[m].zero_() + graphs[m].replay() + torch.accelerator.synchronize() + + torch.testing.assert_close( + static_outputs[m], + eager_out, + atol=1e-2, + rtol=1e-2, + ) + + @pytest.mark.flaky(reruns=2) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.usefixtures("default_vllm_config") diff --git a/tests/model_executor/test_fp8_marlin_kernel_selection.py b/tests/model_executor/test_fp8_marlin_kernel_selection.py new file mode 100644 index 000000000000..0e3cdbddf5cc --- /dev/null +++ b/tests/model_executor/test_fp8_marlin_kernel_selection.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.model_executor.kernels.linear.scaled_mm import ( + FP8ScaledMMLinearLayerConfig, + MarlinFP8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + QuantKey, + ScaleDesc, +) + + +def _fp8_linear_config( + activation_group_shape: GroupShape, +) -> FP8ScaledMMLinearLayerConfig: + return FP8ScaledMMLinearLayerConfig( + weight_quant_key=QuantKey( + dtype=torch.float8_e4m3fn, + scale=ScaleDesc( + dtype=torch.float32, + static=True, + group_shape=GroupShape.PER_CHANNEL, + ), + ), + activation_quant_key=QuantKey( + dtype=torch.float8_e4m3fn, + scale=ScaleDesc( + dtype=torch.float32, + static=False, + group_shape=activation_group_shape, + ), + ), + weight_shape=(128, 128), + input_dtype=torch.bfloat16, + out_dtype=torch.bfloat16, + ) + + +def test_marlin_fp8_refuses_block_fp8_activation_scales(): + can_implement, reason = MarlinFP8ScaledMMLinearKernel.can_implement( + _fp8_linear_config(GroupShape(1, 128)) + ) + + assert not can_implement + assert reason is not None + assert "block-FP8" in reason + + +def test_marlin_fp8_keeps_non_block_fp8_layers_available(): + can_implement, reason = MarlinFP8ScaledMMLinearKernel.can_implement( + _fp8_linear_config(GroupShape.PER_TOKEN) + ) + + assert can_implement + assert reason is None diff --git a/vllm/model_executor/kernels/linear/scaled_mm/marlin.py b/vllm/model_executor/kernels/linear/scaled_mm/marlin.py index 66a03b4d205b..59427e0d3897 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/marlin.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/marlin.py @@ -57,6 +57,22 @@ def is_supported( @classmethod def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + # Marlin's per-channel scale layout cannot serve block-FP8 layers + # (weight_scale_inv shape [N_blocks_n, N_blocks_k]). + # csrc/quantization/marlin/marlin.cu fires + # ``TORCH_CHECK(b_scales.size(1) == size_n, ...)`` for any layer + # where the block-quant scales arrive in their raw 2-D layout. + # Even when VLLM_TEST_FORCE_FP8_MARLIN=1 is set (which downstream + # operators do for NVFP4-MoE on SM 12.0 / RTX PRO 6000 Blackwell), + # block-FP8 attention compressor layers like DSv4 fused_wqa_wkv + # must fall through to the next supported kernel in + # ``_POSSIBLE_FP8_BLOCK_KERNELS[CUDA]`` (typically the Triton + # block-FP8 path). + if c.activation_quant_key.scale.group_shape.is_per_group(): + return False, ( + "MarlinFP8 cannot serve block-FP8 layers; falling " + "through to the next kernel in the priority list." + ) return True, None def __init__( diff --git a/vllm/models/deepseek_v4/nvidia/ops/o_proj.py b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py index 18e3b10562bd..54c255fc3ada 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/o_proj.py +++ b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py @@ -60,10 +60,15 @@ def deep_gemm_fp8_o_proj( device=o.device, dtype=torch.bfloat16, ) + # MarlinFP8.process_weights_after_loading renames block-FP8 scales to + # weight_scale_inv. Non-Marlin kernels keep the on-disk weight_scale name. + wo_a_scale = getattr(wo_a, "weight_scale_inv", None) + if wo_a_scale is None: + wo_a_scale = wo_a.weight_scale fp8_einsum( "bhr,hdr->bhd", (o_fp8, o_scale), - (wo_a.weight, wo_a.weight_scale_inv), + (wo_a.weight, wo_a_scale), z, recipe=einsum_recipe, ) From 93400dd386a94824542b1daffbf839bc53bb87e4 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 29 May 2026 04:25:44 +0800 Subject: [PATCH 061/131] sm120: keep optimized MHC prenorm path without DeepGEMM Signed-off-by: jasl --- tests/kernels/test_mhc_kernels.py | 15 ++++++++++++++ vllm/model_executor/kernels/mhc/tilelang.py | 23 ++++++++++++--------- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/tests/kernels/test_mhc_kernels.py b/tests/kernels/test_mhc_kernels.py index 0e0e3769f497..0f1892329487 100644 --- a/tests/kernels/test_mhc_kernels.py +++ b/tests/kernels/test_mhc_kernels.py @@ -7,6 +7,7 @@ from vllm.model_executor.kernels.mhc.tilelang import ( _tilelang_hc_prenorm_gemm, _torch_hc_prenorm_gemm, + _use_tf32_hc_prenorm_gemm, ) from vllm.platforms import current_platform from vllm.utils.import_utils import has_tilelang @@ -96,6 +97,20 @@ def hc_head_ref( return torch.sum(pre_mix.unsqueeze(-1) * residual.float(), dim=-2).bfloat16() +def test_sm120_uses_tf32_hc_prenorm_gemm_without_deepgemm(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda family: family == 120, + ) + monkeypatch.setattr( + "vllm.utils.deep_gemm.is_deep_gemm_supported", + lambda: False, + ) + + assert _use_tf32_hc_prenorm_gemm() + + @pytest.mark.skipif( not (current_platform.is_cuda_alike() and has_tilelang()), reason="CUDA or ROCm and tilelang required", diff --git a/vllm/model_executor/kernels/mhc/tilelang.py b/vllm/model_executor/kernels/mhc/tilelang.py index e0007141d53d..0876dc88602d 100644 --- a/vllm/model_executor/kernels/mhc/tilelang.py +++ b/vllm/model_executor/kernels/mhc/tilelang.py @@ -2,9 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op +def _use_tf32_hc_prenorm_gemm() -> bool: + from vllm.utils.deep_gemm import is_deep_gemm_supported + + return current_platform.is_device_capability_family(120) or is_deep_gemm_supported() + + def _torch_hc_prenorm_gemm( x: torch.Tensor, fn: torch.Tensor, @@ -162,10 +169,8 @@ def mhc_pre_tilelang( residual_flat = residual.view(-1, hc_mult, hidden_size) num_tokens = residual_flat.shape[0] - from vllm.utils.deep_gemm import is_deep_gemm_supported - - use_deep_gemm = is_deep_gemm_supported() - if use_deep_gemm: + use_tf32_hc_prenorm_gemm = _use_tf32_hc_prenorm_gemm() + if use_tf32_hc_prenorm_gemm: # these numbers are from deepgemm kernel impl block_k = 64 block_m = 64 @@ -191,7 +196,7 @@ def mhc_pre_tilelang( ) residual_2d = residual_flat.view(num_tokens, hc_mult * hidden_size) - if use_deep_gemm: + if use_tf32_hc_prenorm_gemm: tf32_hc_prenorm_gemm( residual_2d, fn, @@ -405,16 +410,14 @@ def mhc_fused_post_pre_tilelang( post_layer_mix_flat = post_layer_mix.view(num_tokens, hc_mult) comb_res_mix_flat = comb_res_mix.view(num_tokens, hc_mult, hc_mult) - from vllm.utils.deep_gemm import is_deep_gemm_supported - - use_deep_gemm = is_deep_gemm_supported() + use_tf32_hc_prenorm_gemm = _use_tf32_hc_prenorm_gemm() use_small_fma = num_tokens <= 16 if use_small_fma: # TODO(gnovack): investigate autotuning these heuristics tile_n = 2 if num_tokens < 8 else 3 n_splits = 8 if (num_tokens < 8 and hidden_size <= 4096) else 4 else: - if use_deep_gemm: + if use_tf32_hc_prenorm_gemm: # these number are from deepgemm kernel impl block_k = 64 block_m = 64 @@ -485,7 +488,7 @@ def mhc_fused_post_pre_tilelang( ) residual_cur_2d = residual_cur.view(num_tokens, hc_mult * hidden_size) - if use_deep_gemm: + if use_tf32_hc_prenorm_gemm: from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm tf32_hc_prenorm_gemm( From fb9c4d3e5399848a563177cda3ccd2011e348386 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 29 May 2026 07:26:21 +0800 Subject: [PATCH 062/131] sm12x: prune fallback tests and tuned config duplicates Signed-off-by: jasl --- .../test_deepseek_v4_kernel_warmup.py | 22 --- .../test_sm12x_tuned_config_lookup.py | 41 +---- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 ---------- ...-Q_Workstation_Edition,dtype=fp8_w8a8.json | 146 ----------------- ...ackwell_Server_Edition,dtype=fp8_w8a8.json | 146 ----------------- ...-Q_Workstation_Edition,dtype=fp8_w8a8.json | 146 ----------------- ...ackwell_Server_Edition,dtype=fp8_w8a8.json | 146 ----------------- ...-Q_Workstation_Edition,dtype=fp8_w8a8.json | 147 ------------------ ...ackwell_Server_Edition,dtype=fp8_w8a8.json | 147 ------------------ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ------------------ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ------------------ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ------------------ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ------------------ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ------------------ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ------------------ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 83 ---------- .../layers/fused_moe/fused_moe.py | 63 ++++++-- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 82 ---------- .../layers/quantization/utils/fp8_utils.py | 50 ++++-- .../deepseek_v4/nvidia/ops/fp8_einsum.py | 21 --- 37 files changed, 94 insertions(+), 3511 deletions(-) delete mode 100644 tests/model_executor/test_deepseek_v4_kernel_warmup.py delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json delete mode 100644 vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/tests/model_executor/test_deepseek_v4_kernel_warmup.py b/tests/model_executor/test_deepseek_v4_kernel_warmup.py deleted file mode 100644 index e73db23ebb7b..000000000000 --- a/tests/model_executor/test_deepseek_v4_kernel_warmup.py +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from types import SimpleNamespace - -from vllm.model_executor.warmup.kernel_warmup import ( - _deepseek_v4_mtp_uniform_decode_warmup_requests, -) - - -def test_deepseek_v4_mtp_uniform_decode_warmup_caps_large_max_num_seqs(): - runner = SimpleNamespace( - speculative_config=SimpleNamespace(method="mtp"), - num_spec_tokens=2, - uniform_decode_query_len=3, - ) - - assert _deepseek_v4_mtp_uniform_decode_warmup_requests( - runner, - max_tokens=4096, - max_reqs=1024, - ) == (1, 2, 4, 8, 16, 24, 32) diff --git a/tests/quantization/test_sm12x_tuned_config_lookup.py b/tests/quantization/test_sm12x_tuned_config_lookup.py index c27d28e45580..aa85120fdde4 100644 --- a/tests/quantization/test_sm12x_tuned_config_lookup.py +++ b/tests/quantization/test_sm12x_tuned_config_lookup.py @@ -5,24 +5,6 @@ from vllm.model_executor.layers.quantization.utils import fp8_utils from vllm.platforms import current_platform -GB10_BLOCK_FP8_SHAPES = ( - (1536, 4096), - (16384, 1024), - (2048, 4096), - (4096, 1024), - (4096, 4096), - (8192, 1024), -) - -GB10_FUSED_MOE_SHAPES = ( - (128, 704, None), - (129, 704, None), - (20, 1536, None), - (256, 384, (128, 128)), - (256, 512, (128, 128)), - (64, 1536, (128, 128)), -) - def _get_fused_moe_configs(e, n, block_shape): if block_shape is None: @@ -31,22 +13,15 @@ def _get_fused_moe_configs(e, n, block_shape): return fused_moe.get_moe_configs(e, n, "fp8_w8a8", block_n, block_k) -def test_gb10_tuned_configs_cover_dense_and_fused_moe(monkeypatch): - monkeypatch.setattr(current_platform, "get_device_name", lambda: "NVIDIA GB10") +def test_rtx_pro_6000_variants_reuse_workstation_tuned_configs(monkeypatch): + monkeypatch.setattr( + current_platform, + "get_device_name", + lambda: "NVIDIA RTX PRO 6000 Blackwell Server Edition", + ) monkeypatch.setattr(fused_moe.envs, "VLLM_BATCH_INVARIANT", False) fp8_utils.get_w8a8_block_fp8_configs.cache_clear() fused_moe.get_moe_configs.cache_clear() - missing_dense = [ - (n, k) - for n, k in GB10_BLOCK_FP8_SHAPES - if fp8_utils.get_w8a8_block_fp8_configs(n, k, 128, 128) is None - ] - assert not missing_dense - - missing_moe = [ - (e, n, block_shape) - for e, n, block_shape in GB10_FUSED_MOE_SHAPES - if _get_fused_moe_configs(e, n, block_shape) is None - ] - assert not missing_moe + assert fp8_utils.get_w8a8_block_fp8_configs(1536, 4096, 128, 128) is not None + assert _get_fused_moe_configs(256, 384, (128, 128)) is not None diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 06431d4d355e..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,83 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 2 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 06431d4d355e..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=128,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,83 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 2 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json deleted file mode 100644 index 75dfc52cb46d..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json deleted file mode 100644 index 75dfc52cb46d..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json deleted file mode 100644 index 75dfc52cb46d..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json deleted file mode 100644 index 75dfc52cb46d..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json deleted file mode 100644 index 8b78f87e7f73..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "triton_version": "3.5.0", - "1": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 2 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 2 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "96": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "1024": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json deleted file mode 100644 index 8b78f87e7f73..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "triton_version": "3.5.0", - "1": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 2 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 2 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "96": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "1024": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 4b98ba105c14..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,83 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 2 - }, - "8": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 4b98ba105c14..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,83 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 2 - }, - "8": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index bcec61632e3e..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "96": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 2 - }, - "512": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "1024": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "4096": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index bcec61632e3e..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "96": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 2 - }, - "512": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "1024": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "4096": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 705ca33d594b..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 2 - }, - "8": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 2 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 2 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 2 - }, - "256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "512": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 2 - }, - "1024": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "4096": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 705ca33d594b..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 2 - }, - "8": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 2 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 2 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 2 - }, - "256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "512": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 2 - }, - "1024": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "4096": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 952b908297b7..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,83 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "8": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 2 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "512": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 952b908297b7..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,83 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2 - }, - "8": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 2 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 2 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "512": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 9c2ebaddd83f..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 2 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "24": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "4096": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 9c2ebaddd83f..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,147 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 2 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "24": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 - }, - "96": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "1024": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "1536": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "2048": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "3072": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "4096": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 36d0361b926f..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,83 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - } -} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 36d0361b926f..000000000000 --- a/vllm/model_executor/layers/fused_moe/configs/E=64,N=2048,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,83 +0,0 @@ -{ - "triton_version": "3.6.0", - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - } -} diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 49957c8f5e36..5b11243edeed 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -35,6 +35,23 @@ logger = init_logger(__name__) +_SM12X_TUNED_CONFIG_DEVICE_ALIASES = { + "NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition": ( + "NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition", + ), + "NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition": ( + "NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition", + ), +} + + +def _tuned_config_device_names() -> tuple[str, ...]: + device_name = current_platform.get_device_name().replace(" ", "_") + if "H200" in device_name.split("_"): + return ("NVIDIA_H200",) + return (device_name, *_SM12X_TUNED_CONFIG_DEVICE_ALIASES.get(device_name, ())) + + @triton.jit def write_zeros_to_output( c_ptr, @@ -999,10 +1016,30 @@ def zero_experts_compute_triton( def get_config_file_name( E: int, N: int, dtype: str | None, block_shape: list[int] | None = None ) -> str: - device_name = current_platform.get_device_name().replace(" ", "_") - # Set device_name to H200 if a device from the H200 family is detected - if "H200" in device_name.split("_"): - device_name = "NVIDIA_H200" + device_name = _tuned_config_device_names()[0] + dtype_selector = "" if not dtype else f",dtype={dtype}" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ).replace(" ", "") + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 + + +def _get_config_file_names( + E: int, N: int, dtype: str | None, block_shape: list[int] | None = None +) -> tuple[str, ...]: + return tuple( + _get_config_file_name_for_device(E, N, dtype, device_name, block_shape) + for device_name in _tuned_config_device_names() + ) + + +def _get_config_file_name_for_device( + E: int, + N: int, + dtype: str | None, + device_name: str, + block_shape: list[int] | None = None, +) -> str: dtype_selector = "" if not dtype else f",dtype={dtype}" block_shape_selector = ( "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" @@ -1035,22 +1072,26 @@ def get_moe_configs( # First look up if an optimized configuration is available in the configs # directory block_shape = [block_n, block_k] if block_n and block_k else None - json_file_name = get_config_file_name(E, N, dtype, block_shape) + json_file_names = _get_config_file_names(E, N, dtype, block_shape) config_file_paths = [] # note that we prioritize user defined config user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER if user_defined_config_folder is not None: - user_defined_config_file_path = os.path.join( - user_defined_config_folder, json_file_name + config_file_paths.extend( + os.path.join(user_defined_config_folder, json_file_name) + for json_file_name in json_file_names ) - config_file_paths.append(user_defined_config_file_path) - default_config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + config_file_paths.extend( + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "configs", + json_file_name, + ) + for json_file_name in json_file_names ) - config_file_paths.append(default_config_file_path) for config_file_path in config_file_paths: if os.path.exists(config_file_path): diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 387b572731b6..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 387b572731b6..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=1536,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index ac91f525e96b..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index ac91f525e96b..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=16384,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "32": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index ac1053b588c5..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index ac1053b588c5..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=2048,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 5 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 5 - }, - "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index be96d80c51f1..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index be96d80c51f1..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 3 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 5163bc4f3da1..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 5 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 5163bc4f3da1..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=4096,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 4 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 8, - "num_stages": 3 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 4 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 4 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 5 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 5a6f1de61395..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json deleted file mode 100644 index 5a6f1de61395..000000000000 --- a/vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1024,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8,block_shape=[128,128].json +++ /dev/null @@ -1,82 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 5 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "32": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 8, - "num_stages": 4 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 3 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 - }, - "256": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, - "num_stages": 3 - }, - "512": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 - } -} diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index da91124e44ae..9dcc21ff460c 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -38,6 +38,21 @@ logger = init_logger(__name__) +_SM12X_TUNED_CONFIG_DEVICE_ALIASES = { + "NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition": ( + "NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition", + ), + "NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition": ( + "NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition", + ), +} + + +def _tuned_config_device_names() -> tuple[str, ...]: + device_name = current_platform.get_device_name().replace(" ", "_") + return (device_name, *_SM12X_TUNED_CONFIG_DEVICE_ALIASES.get(device_name, ())) + + def is_fp8(x: torch.dtype | torch.Tensor) -> bool: if isinstance(x, torch.Tensor): x = x.dtype @@ -864,27 +879,30 @@ def get_w8a8_block_fp8_configs( # First look up if an optimized configuration is available in the configs # directory - device_name = current_platform.get_device_name().replace(" ", "_") - json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501 - - config_file_path = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name - ) - if os.path.exists(config_file_path): - with open(config_file_path) as f: - logger.info( - "Using configuration from %s for W8A8 Block FP8 kernel.", - config_file_path, - ) - # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} + config_file_paths = [ + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "configs", + f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json", # noqa: E501 + ) + for device_name in _tuned_config_device_names() + ] + for config_file_path in config_file_paths: + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + "Using configuration from %s for W8A8 Block FP8 kernel.", + config_file_path, + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration logger.warning( "Using default W8A8 Block FP8 kernel config. Performance might " - "be sub-optimal! Config file not found at %s", - config_file_path, + "be sub-optimal! Config files not found at %s", + config_file_paths, ) return None diff --git a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py index 31353e62748f..09c0f4aa3f5e 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py +++ b/vllm/models/deepseek_v4/nvidia/ops/fp8_einsum.py @@ -11,7 +11,6 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import fp8_einsum -from vllm.utils.torch_utils import direct_register_custom_op @triton.jit @@ -272,23 +271,3 @@ def deepseek_v4_fp8_einsum( return fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) - - -def deepseek_v4_fp8_einsum_fake( - a: torch.Tensor, - a_scale: torch.Tensor, - b: torch.Tensor, - b_scale: torch.Tensor, - out: torch.Tensor, - equation: str, - recipe: list[int], -) -> None: - return None - - -direct_register_custom_op( - op_name="deepseek_v4_fp8_einsum", - op_func=deepseek_v4_fp8_einsum, - mutates_args=["out"], - fake_impl=deepseek_v4_fp8_einsum_fake, -) From d6e4dcd2c4bfb919d32f05db27b4e30eafd3202c Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 29 May 2026 09:41:05 +0800 Subject: [PATCH 063/131] sm12x: clear MXFP4 loading cache after setup Signed-off-by: jasl --- vllm/model_executor/layers/quantization/mxfp4.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 1b2a8a74bdcb..a205dd612a21 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -37,6 +37,14 @@ logger = init_logger(__name__) +def _clear_mxfp4_moe_weight_loading_cache() -> None: + # WORKAROUND: SM12x/GB10 can hit driver instability while loading MXFP4 + # MoE weights under tight VRAM pressure. Release PyTorch's staging cache + # after the backend kernel has taken ownership of the converted weights. + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + class Mxfp4Config(QuantizationConfig): """Canonical base config for MXFP4 quantization. @@ -389,6 +397,8 @@ def process_weights_after_loading(self, layer: RoutedExperts) -> None: return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) + del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias + _clear_mxfp4_moe_weight_loading_cache() def get_fused_moe_quant_config( self, layer: RoutedExperts @@ -733,6 +743,8 @@ def process_weights_after_loading(self, layer): return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) + del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias + _clear_mxfp4_moe_weight_loading_cache() def get_fused_moe_quant_config( self, From aaef91a44c2275f71d7b9779ed922431ce673667 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 29 May 2026 23:10:47 +0800 Subject: [PATCH 064/131] sm12x: drop obsolete MHC CustomOp wrapper Signed-off-by: jasl --- vllm/model_executor/layers/mhc.py | 48 ++++--------------------------- 1 file changed, 6 insertions(+), 42 deletions(-) diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index 60c942e5e510..de1b2a0c617c 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -6,7 +6,6 @@ # import vllm.model_executor.kernels.mhc # noqa: F401 import vllm.model_executor.kernels.mhc as mhc_kernels from vllm.model_executor.custom_op import CustomOp -from vllm.platforms import current_platform from vllm.utils.import_utils import has_tilelang HAS_TILELANG = has_tilelang() @@ -261,45 +260,6 @@ def forward_xpu( ) -# ``@torch.compile`` on the CUDA HC head reduction is necessary for accuracy -# as well as performance — upstream a8887c208 ("[Bugfix] [ROCm] [DSV4] [Perf] -# Add aiter mhc support", #41946) refactored ``hc_head`` from a free -# function into ``HCHeadOp(CustomOp)`` and dropped the decorator from the -# CUDA path while keeping it on ``forward_hip``. The drop caused a measured -# ~7 pp regression in DSv4-Flash MTP=2 spec acceptance on SM12x (mt-bench -# c=1, 67.6 % → 59.8 %). -# -# Decorating the ``forward_cuda`` method directly trips -# ``torch._dynamo.exc.Unsupported: failed to bind arguments when attempting -# to inline forward_cuda`` whenever the outer model is wrapped by -# ``@support_torch_compile`` (which is the no-MTP path on SM12x): dynamo -# tries to inline the bound method through ``CustomOp._forward_method`` and -# can't reconcile the ``self`` parameter. Keeping the body as a free -# function — the layout that existed pre-#41946 — sidesteps the bind -# failure while preserving the spec-acceptance recovery. -@torch.compile(backend=current_platform.simple_compile_backend) -def _hc_head_cuda_impl( - hidden_states: torch.Tensor, - hc_fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_norm_eps: float, - hc_eps: float, -) -> torch.Tensor: - hc_mult, hidden_size = hidden_states.shape[-2:] - outer_shape = hidden_states.shape[:-2] - hs_flat = hidden_states.view(-1, hc_mult, hidden_size) - out = torch.ops.vllm.hc_head_fused_kernel_tilelang( - hs_flat, - hc_fn, - hc_scale, - hc_base, - rms_norm_eps, - hc_eps, - ) - return out.view(*outer_shape, hidden_size) - - # --8<-- [start:hc_head] @CustomOp.register("hc_head") class HCHeadOp(CustomOp): @@ -324,14 +284,18 @@ def forward_cuda( rms_norm_eps: float, hc_eps: float, ) -> torch.Tensor: - return _hc_head_cuda_impl( - hidden_states, + hc_mult, hidden_size = hidden_states.shape[-2:] + outer_shape = hidden_states.shape[:-2] + hs_flat = hidden_states.view(-1, hc_mult, hidden_size) + out = torch.ops.vllm.hc_head_fused_kernel_tilelang( + hs_flat, hc_fn, hc_scale, hc_base, rms_norm_eps, hc_eps, ) + return out.view(*outer_shape, hidden_size) def forward_hip( self, From ad26f8fc4cb3984dc7c18d551b3a47089a5ddf8c Mon Sep 17 00:00:00 2001 From: jasl Date: Sun, 31 May 2026 20:21:49 +0800 Subject: [PATCH 065/131] Protect running prefills from long prefill starvation Signed-off-by: jasl --- tests/v1/core/test_scheduler.py | 59 +++++++++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 7 +++- 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index ed57546115b9..c613fea7a583 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -981,6 +981,65 @@ def test_running_long_prefill_leaves_budget_for_waiting_short_prefill(): assert mixed_output.num_scheduled_tokens[short_prefill_req.request_id] == 20 +def test_running_long_prefill_leaves_budget_for_running_short_prefill(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=512, + max_num_seqs=2, + enable_chunked_prefill=True, + ) + long_prefill_req = create_requests( + num_requests=1, + num_tokens=300, + req_ids=["long_prefill"], + )[0] + short_prefill_req = create_requests( + num_requests=1, + num_tokens=80, + req_ids=["short_prefill"], + )[0] + + scheduler.add_request(long_prefill_req) + first_chunk = scheduler.schedule() + assert first_chunk.num_scheduled_tokens[long_prefill_req.request_id] == 100 + + scheduler.update_from_output( + first_chunk, + ModelRunnerOutput( + req_ids=[long_prefill_req.request_id], + req_id_to_index={long_prefill_req.request_id: 0}, + sampled_token_ids=[[]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(short_prefill_req) + first_mixed = scheduler.schedule() + assert first_mixed.num_scheduled_tokens[long_prefill_req.request_id] == 75 + assert first_mixed.num_scheduled_tokens[short_prefill_req.request_id] == 25 + + scheduler.update_from_output( + first_mixed, + ModelRunnerOutput( + req_ids=[long_prefill_req.request_id, short_prefill_req.request_id], + req_id_to_index={ + long_prefill_req.request_id: 0, + short_prefill_req.request_id: 1, + }, + sampled_token_ids=[[], []], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + second_mixed = scheduler.schedule() + assert second_mixed.num_scheduled_tokens[long_prefill_req.request_id] == 75 + assert second_mixed.num_scheduled_tokens[short_prefill_req.request_id] == 25 + + def test_mixed_decode_prefill_caps_mid_long_prefill_more_tightly(): scheduler = create_scheduler( max_num_batched_tokens=100, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 01216a898599..4505a0e87795 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -512,11 +512,16 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) + has_unscheduled_running_prefill = any( + later_request.num_computed_tokens < later_request.num_prompt_tokens + for later_request in self.running[req_index + 1 :] + ) num_new_tokens = self._limit_mixed_decode_prefill_chunk( request, num_new_tokens, scheduled_running_reqs, - bool(self.waiting or self.skipped_waiting), + bool(self.waiting or self.skipped_waiting) + or has_unscheduled_running_prefill, ) # Make sure the input position does not exceed the max model len. From 6bce5c93e32117ff2e5bc7f4d2bbef2ad5bb40e3 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 1 Jun 2026 00:05:03 +0800 Subject: [PATCH 066/131] Add chunked SM120 direct MQA top-k fallback Signed-off-by: jasl --- .../test_sm120_deepgemm_fallbacks.py | 64 ++++++++++ .../nvidia/ops/sm12x_deep_gemm_fallbacks.py | 114 ++++++++++++++++++ 2 files changed, 178 insertions(+) diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py index 08c0a9b752e3..7211de486a0a 100644 --- a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -239,6 +239,70 @@ def wrapped_topk_op(*args, **kwargs): _assert_topk_indices_match_values(reference_logits, out, topk_tokens) +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_mqa_direct_topk_uses_triton_chunks_when_logits_do_not_fit( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(13) + num_q, num_heads, head_dim = 9, 4, 64 + seq_len_kv, topk_tokens = 23, 6 + logits_bytes = num_q * seq_len_kv * torch.float32.itemsize + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES", + logits_bytes - 1, + ) + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_SM120_MQA_TRITON_CHUNKED_TOPK_CHUNK_SIZE", + 7, + ) + + q = torch.randn(num_q, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + kv_scale = kv.abs().float().amax(dim=-1).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()[:, None]).to(torch.float8_e4m3fn) + weights = torch.randn(num_q, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(num_q, device="cuda", dtype=torch.int32) % 4 + cu_seqlen_ke = torch.full( + (num_q,), seq_len_kv, device="cuda", dtype=torch.int32 + ) + out = torch.empty(num_q, topk_tokens, device="cuda", dtype=torch.int32) + + original_triton = sm12x_mqa.fp8_mqa_logits_triton + calls = 0 + + def wrapped_triton(*args, **kwargs): + nonlocal calls + calls += 1 + return original_triton(*args, **kwargs) + + monkeypatch.setattr(sm12x_mqa, "fp8_mqa_logits_triton", wrapped_triton) + + assert deep_gemm_utils.fp8_fp4_mqa_topk_indices( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + out, + ) + assert calls == cdiv(seq_len_kv, 7) + + reference_logits = sm12x_deep_gemm_fallbacks._fp8_mqa_logits_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + _assert_topk_indices_match_values(reference_logits, out, topk_tokens) + + @pytest.mark.skipif( not current_platform.is_device_capability_family(120), reason="SM120 only" ) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py index 3293659a6895..64351957917d 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -11,6 +11,7 @@ _SM120_MQA_LOGITS_MAX_SCORE_BYTES = 64 * 1024 * 1024 _SM120_MQA_TRITON_TOPK_MAX_LOGITS_BYTES = 512 * 1024 * 1024 +_SM120_MQA_TRITON_CHUNKED_TOPK_CHUNK_SIZE = 32768 _SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192 @@ -272,6 +273,110 @@ def _fp8_mqa_logits_topk_triton( return True +def _fp8_mqa_logits_topk_triton_chunked( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + out: torch.Tensor, +) -> bool: + q_values, q_scale = q + k_values, k_scales = kv + if not (q_scale is None and q_values.dim() == 3 and k_values.dim() == 2): + return False + + from vllm.models.deepseek_v4.nvidia.ops.sm12x_mqa import ( + fp8_mqa_logits_triton, + ) + + seq_len = q_values.shape[0] + seq_len_kv = k_values.shape[0] + topk_tokens = out.shape[1] + out.fill_(-1) + if seq_len == 0 or seq_len_kv == 0 or topk_tokens == 0: + return True + + chunk_size = max(1, _SM120_MQA_TRITON_CHUNKED_TOPK_CHUNK_SIZE) + best_values = torch.full( + (seq_len, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + max_chunk_topk = min(topk_tokens, chunk_size) + chunk_values_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (seq_len, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + for k_start in range(0, seq_len_kv, chunk_size): + k_end = min(k_start + chunk_size, seq_len_kv) + local_width = k_end - k_start + local_ks = torch.clamp(cu_seqlen_ks - k_start, min=0, max=local_width) + local_ke = torch.clamp(cu_seqlen_ke - k_start, min=0, max=local_width) + chunk_logits = fp8_mqa_logits_triton( + q_values, + (k_values[k_start:k_end], k_scales[k_start:k_end]), + weights, + local_ks, + local_ke, + ) + chunk_topk = min(topk_tokens, local_width) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(k_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(out) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=out) + best_values, next_best_values = next_best_values, best_values + out.masked_fill_(~torch.isfinite(best_values), -1) + + return True + + def fp8_fp4_mqa_topk_indices( q: tuple[torch.Tensor, torch.Tensor | None], kv: tuple[torch.Tensor, torch.Tensor], @@ -296,6 +401,15 @@ def fp8_fp4_mqa_topk_indices( topk_indices, ): return True + if _fp8_mqa_logits_topk_triton_chunked( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices, + ): + return True _fp8_mqa_logits_topk_torch( q, kv, From db3a71f53c0546f5b6016837c580051a27b527f6 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 1 Jun 2026 12:31:53 +0800 Subject: [PATCH 067/131] Protect later running decodes from long prefill starvation Signed-off-by: jasl --- tests/v1/core/test_scheduler.py | 68 +++++++++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 16 ++++++-- 2 files changed, 80 insertions(+), 4 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index c613fea7a583..942f5af3d1f2 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1040,6 +1040,74 @@ def test_running_long_prefill_leaves_budget_for_running_short_prefill(): assert second_mixed.num_scheduled_tokens[short_prefill_req.request_id] == 25 +def test_running_long_prefill_leaves_budget_for_later_running_decode(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=2048, + max_num_seqs=2, + enable_chunked_prefill=True, + ) + long_prefill_req = create_requests( + num_requests=1, + num_tokens=1000, + req_ids=["long_prefill"], + )[0] + short_req = create_requests( + num_requests=1, + num_tokens=80, + req_ids=["short_then_decode"], + )[0] + + scheduler.add_request(long_prefill_req) + first_chunk = scheduler.schedule() + assert first_chunk.num_scheduled_tokens[long_prefill_req.request_id] == 100 + scheduler.update_from_output( + first_chunk, + ModelRunnerOutput( + req_ids=[long_prefill_req.request_id], + req_id_to_index={long_prefill_req.request_id: 0}, + sampled_token_ids=[[]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(short_req) + while short_req.num_computed_tokens < short_req.num_prompt_tokens: + mixed_output = scheduler.schedule() + assert short_req.request_id in mixed_output.num_scheduled_tokens + sampled_token_ids = [] + req_id_to_index = {} + for index, req_id in enumerate(mixed_output.num_scheduled_tokens): + req_id_to_index[req_id] = index + if req_id == short_req.request_id and ( + short_req.num_computed_tokens + + mixed_output.num_scheduled_tokens[req_id] + ) >= short_req.num_prompt_tokens: + sampled_token_ids.append([0]) + else: + sampled_token_ids.append([]) + scheduler.update_from_output( + mixed_output, + ModelRunnerOutput( + req_ids=list(mixed_output.num_scheduled_tokens), + req_id_to_index=req_id_to_index, + sampled_token_ids=sampled_token_ids, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + assert long_prefill_req.num_computed_tokens < long_prefill_req.num_prompt_tokens + assert short_req.num_computed_tokens >= short_req.num_prompt_tokens + + decode_mixed = scheduler.schedule() + assert decode_mixed.num_scheduled_tokens[long_prefill_req.request_id] == 6 + assert decode_mixed.num_scheduled_tokens[short_req.request_id] == 1 + + def test_mixed_decode_prefill_caps_mid_long_prefill_more_tightly(): scheduler = create_scheduler( max_num_batched_tokens=100, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 4505a0e87795..81fe3da33afc 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,6 +7,7 @@ from dataclasses import replace from typing import Any +from vllm import envs from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.config import VllmConfig from vllm.distributed.ec_transfer.ec_connector.base import ( @@ -396,6 +397,7 @@ def _limit_mixed_decode_prefill_chunk( num_new_tokens: int, scheduled_running_reqs: list[Request], has_waiting_requests: bool = False, + has_pending_decode: bool = False, ) -> int: if ( not self.scheduler_config.enable_chunked_prefill @@ -403,8 +405,10 @@ def _limit_mixed_decode_prefill_chunk( ): return num_new_tokens - has_scheduled_decode = self._has_scheduled_decode(scheduled_running_reqs) - if not has_scheduled_decode and not has_waiting_requests: + has_decode_pressure = ( + self._has_scheduled_decode(scheduled_running_reqs) or has_pending_decode + ) + if not has_decode_pressure and not has_waiting_requests: return num_new_tokens remaining_prefill = request.num_prompt_tokens - request.num_computed_tokens @@ -418,7 +422,7 @@ def _limit_mixed_decode_prefill_chunk( very_long_prefill_threshold = ( self.max_num_scheduled_tokens * very_long_prefill_steps ) - if has_scheduled_decode: + if has_decode_pressure: if remaining_prefill > very_long_prefill_threshold: mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 16) else: @@ -453,7 +457,6 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: if self._pause_state == PauseState.PAUSED_ALL: # Do not schedule any requests when paused. token_budget = 0 - # Encoder-related. scheduled_encoder_inputs: dict[str, list[int]] = {} encoder_compute_budget = self.max_num_encoder_input_tokens @@ -516,12 +519,17 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: later_request.num_computed_tokens < later_request.num_prompt_tokens for later_request in self.running[req_index + 1 :] ) + has_pending_running_decode = any( + later_request.num_computed_tokens >= later_request.num_prompt_tokens + for later_request in self.running[req_index + 1 :] + ) num_new_tokens = self._limit_mixed_decode_prefill_chunk( request, num_new_tokens, scheduled_running_reqs, bool(self.waiting or self.skipped_waiting) or has_unscheduled_running_prefill, + has_pending_running_decode, ) # Make sure the input position does not exceed the max model len. From 6dac492601446b9cab361180bba1f29cbd15dbec Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 1 Jun 2026 19:42:08 +0800 Subject: [PATCH 068/131] Protect very-long prefill fairness Signed-off-by: jasl --- tests/v1/core/test_scheduler.py | 87 +++++++++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 48 ++++++++++++++++-- 2 files changed, 130 insertions(+), 5 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 942f5af3d1f2..94c29f38a527 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1040,6 +1040,93 @@ def test_running_long_prefill_leaves_budget_for_running_short_prefill(): assert second_mixed.num_scheduled_tokens[short_prefill_req.request_id] == 25 +def test_running_very_long_prefill_defers_waiting_very_long_prefill(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=2048, + max_num_seqs=3, + enable_chunked_prefill=True, + ) + first_long_req = create_requests( + num_requests=1, + num_tokens=600, + req_ids=["first_long"], + )[0] + second_long_req = create_requests( + num_requests=1, + num_tokens=600, + req_ids=["second_long"], + )[0] + short_req = create_requests( + num_requests=1, + num_tokens=20, + req_ids=["short"], + )[0] + + scheduler.add_request(first_long_req) + first_chunk = scheduler.schedule() + assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 100 + scheduler.update_from_output( + first_chunk, + ModelRunnerOutput( + req_ids=[first_long_req.request_id], + req_id_to_index={first_long_req.request_id: 0}, + sampled_token_ids=[[]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(second_long_req) + scheduler.add_request(short_req) + mixed_output = scheduler.schedule() + + assert mixed_output.num_scheduled_tokens[first_long_req.request_id] == 50 + assert second_long_req.request_id not in mixed_output.num_scheduled_tokens + assert mixed_output.num_scheduled_tokens[short_req.request_id] == 20 + + +def test_running_very_long_prefill_ignores_deferred_long_waiting_pressure(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=2048, + max_num_seqs=2, + enable_chunked_prefill=True, + ) + first_long_req = create_requests( + num_requests=1, + num_tokens=600, + req_ids=["first_long"], + )[0] + second_long_req = create_requests( + num_requests=1, + num_tokens=600, + req_ids=["second_long"], + )[0] + + scheduler.add_request(first_long_req) + first_chunk = scheduler.schedule() + assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 100 + scheduler.update_from_output( + first_chunk, + ModelRunnerOutput( + req_ids=[first_long_req.request_id], + req_id_to_index={first_long_req.request_id: 0}, + sampled_token_ids=[[]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(second_long_req) + mixed_output = scheduler.schedule() + + assert mixed_output.num_scheduled_tokens[first_long_req.request_id] == 100 + assert second_long_req.request_id not in mixed_output.num_scheduled_tokens + + def test_running_long_prefill_leaves_budget_for_later_running_decode(): scheduler = create_scheduler( max_num_batched_tokens=100, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 81fe3da33afc..3c3dc380844c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -391,6 +391,39 @@ def _has_scheduled_decode(self, requests: list[Request]) -> bool: for request in requests ) + def _very_long_prefill_threshold(self) -> int: + return self.max_num_scheduled_tokens * 4 + + def _is_very_long_prefill( + self, + request: Request, + num_computed_tokens: int | None = None, + ) -> bool: + if not self.scheduler_config.enable_chunked_prefill: + return False + if num_computed_tokens is None: + num_computed_tokens = request.num_computed_tokens + return ( + request.num_prompt_tokens > self._very_long_prefill_threshold() + and num_computed_tokens < request.num_prompt_tokens + ) + + def _has_active_very_long_prefill(self) -> bool: + return any(self._is_very_long_prefill(request) for request in self.running) + + def _has_waiting_requests_for_running_prefill(self, request: Request) -> bool: + if not (self.waiting or self.skipped_waiting): + return False + if not self._is_very_long_prefill(request): + return True + return any( + not self._is_very_long_prefill(waiting_request) + for waiting_request in self.waiting + ) or any( + not self._is_very_long_prefill(waiting_request) + for waiting_request in self.skipped_waiting + ) + def _limit_mixed_decode_prefill_chunk( self, request: Request, @@ -418,10 +451,7 @@ def _limit_mixed_decode_prefill_chunk( # Very long prefills span many scheduling steps; a smaller chunk keeps # already-active decoders from seeing long inter-token gaps and leaves # room for short requests that arrive behind an active long prefill. - very_long_prefill_steps = 4 - very_long_prefill_threshold = ( - self.max_num_scheduled_tokens * very_long_prefill_steps - ) + very_long_prefill_threshold = self._very_long_prefill_threshold() if has_decode_pressure: if remaining_prefill > very_long_prefill_threshold: mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 16) @@ -527,7 +557,7 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: request, num_new_tokens, scheduled_running_reqs, - bool(self.waiting or self.skipped_waiting) + self._has_waiting_requests_for_running_prefill(request) or has_unscheduled_running_prefill, has_pending_running_decode, ) @@ -838,6 +868,14 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens + if ( + self._is_very_long_prefill(request, num_computed_tokens) + and self._has_active_very_long_prefill() + ): + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) + continue + encoder_inputs_to_schedule = None external_load_encoder_input = [] new_encoder_compute_budget = encoder_compute_budget From 1381809ab0fcd166eb4d0c4fcd30b6de05f5b1e9 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 1 Jun 2026 21:49:08 +0800 Subject: [PATCH 069/131] sm12x: avoid MHC prenorm GEMM JIT per token count Signed-off-by: jasl --- tests/v1/attention/test_sm120_deepgemm_fallbacks.py | 7 +++++++ vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py index 7211de486a0a..642bc7ef7fd5 100644 --- a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -165,6 +165,13 @@ def test_sm120_direct_mqa_logits_block_m_prefers_short_prefill_tile(): assert sm12x_mqa._fp8_mqa_logits_block_m(65536, 65536) == 64 +def test_tf32_hc_prenorm_gemm_does_not_specialize_prefill_token_count(): + kernel = sm12x_mqa._tf32_hc_prenorm_gemm_kernel + constexpr_names = {kernel.arg_names[index] for index in kernel.constexprs} + + assert "M" not in constexpr_names + + @pytest.mark.skipif( not current_platform.is_device_capability_family(120), reason="SM120 only" ) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index d308fa6a6268..a8d0c8ce58ce 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -611,7 +611,7 @@ def _tf32_hc_prenorm_gemm_kernel( fn_ptr, out_ptr, sqrsum_ptr, - M: tl.constexpr, + M, K: tl.constexpr, N: tl.constexpr, stride_xm: tl.constexpr, From 56897a265143b47c8b3a75e4248e32bc55a031b2 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 3 Jun 2026 04:08:34 +0800 Subject: [PATCH 070/131] test: adapt DS4 prefix cache tests to scheduler block size Signed-off-by: jasl --- tests/v1/core/test_prefix_caching.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 63390268bab6..4a17a4f247e9 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -3538,6 +3538,7 @@ def test_deepseek_v4_mla_prompt_cache_survives_decode_pressure(): manager = KVCacheManager( config, max_model_len=128, + scheduler_block_size=full_block_size, max_num_batched_tokens=chunk_tokens, enable_caching=True, hash_block_size=hash_block_size, @@ -3616,6 +3617,7 @@ def test_deepseek_v4_mla_cached_prompts_do_not_block_admission(): ], ), max_model_len=512, + scheduler_block_size=block_size, max_num_batched_tokens=128, enable_caching=True, hash_block_size=block_size, @@ -3665,6 +3667,7 @@ def test_deepseek_v4_mla_prefix_hit_under_pressure_does_not_overallocate(): ], ), max_model_len=512, + scheduler_block_size=block_size, max_num_batched_tokens=128, enable_caching=True, hash_block_size=block_size, @@ -3725,6 +3728,7 @@ def test_reset_prefix_cache_after_deepseek_v4_mla_prompt_cache(): ], ), max_model_len=512, + scheduler_block_size=block_size, max_num_batched_tokens=128, enable_caching=True, hash_block_size=block_size, From 6a576fa64359845165c12c3320d0f0bb8985b3c0 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 3 Jun 2026 14:49:29 +0800 Subject: [PATCH 071/131] fix: export DeepSeek V4 FusedMoE metadata Signed-off-by: jasl --- .../test_deepseek_v4_moe_metadata.py | 90 +++++++++++++++++++ vllm/models/deepseek_v4/nvidia/model.py | 10 ++- 2 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 tests/model_executor/test_deepseek_v4_moe_metadata.py diff --git a/tests/model_executor/test_deepseek_v4_moe_metadata.py b/tests/model_executor/test_deepseek_v4_moe_metadata.py new file mode 100644 index 000000000000..fbfdba6e19db --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_moe_metadata.py @@ -0,0 +1,90 @@ +from types import SimpleNamespace + +from vllm.models.deepseek_v4.nvidia import model as deepseek_v4_model +from vllm.models.deepseek_v4.nvidia.model import ( + DeepseekV4MixtureOfExperts, + DeepseekV4MoE, +) + + +def test_deepseek_v4_fused_moe_metadata_is_available_to_mixture(): + moe = object.__new__(DeepseekV4MoE) + moe.n_routed_experts = 256 + moe.n_shared_experts = 1 + moe.experts = SimpleNamespace( + logical_num_experts=256, + global_num_experts=256, + local_num_experts=128, + ) + moe._sync_fused_moe_metadata() + + mixture = object.__new__(DeepseekV4MixtureOfExperts) + mixture.extract_moe_parameters(moe) + + assert mixture.num_logical_experts == 256 + assert mixture.num_physical_experts == 256 + assert mixture.num_local_physical_experts == 128 + assert mixture.num_routed_experts == 256 + assert mixture.num_shared_experts == 1 + assert mixture.num_redundant_experts == 0 + + +def test_deepseek_v4_fused_moe_init_exports_moe_metadata(monkeypatch): + class FakeGate: + def __init__(self, *args, **kwargs): + self.e_score_correction_bias = None + self.tid2eid = None + + class FakeFusedMoE: + logical_num_experts = 256 + global_num_experts = 256 + local_num_experts = 128 + + def __init__(self, **kwargs): + self.kwargs = kwargs + + class FakeMLP: + def __init__(self, *args, **kwargs): + pass + + monkeypatch.setattr(deepseek_v4_model, "GateLinear", FakeGate) + monkeypatch.setattr(deepseek_v4_model, "DeepseekV4MLP", FakeMLP) + monkeypatch.setattr(deepseek_v4_model, "FusedMoE", FakeFusedMoE) + monkeypatch.setattr( + deepseek_v4_model, + "get_tensor_model_parallel_world_size", + lambda: 2, + ) + monkeypatch.setattr( + deepseek_v4_model, + "get_tensor_model_parallel_rank", + lambda: 1, + ) + + config = SimpleNamespace( + n_routed_experts=256, + n_shared_experts=1, + num_experts_per_tok=8, + hidden_size=7168, + moe_intermediate_size=2048, + swiglu_limit=7.0, + hidden_act="silu", + norm_topk_prob=True, + num_hash_layers=0, + vocab_size=128000, + ) + vllm_config = SimpleNamespace( + model_config=SimpleNamespace(hf_config=config), + quant_config=None, + kernel_config=SimpleNamespace(moe_backend="auto"), + parallel_config=SimpleNamespace(enable_expert_parallel=True), + ) + + moe = DeepseekV4MoE(vllm_config, prefix="model.layers.3.mlp") + + assert moe.n_logical_experts == 256 + assert moe.n_physical_experts == 256 + assert moe.n_local_physical_experts == 128 + assert moe.n_local_experts == 128 + assert moe.n_shared_experts == 1 + assert moe.n_redundant_experts == 0 diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 8cf5e636b0c0..c5a1a3ae8816 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -562,6 +562,7 @@ def __init__( prefix=f"{prefix}.shared_experts", ) + self.n_shared_experts = config.n_shared_experts or 0 if self.use_mega_moe: self._init_mega_moe_experts(vllm_config, config, prefix) else: @@ -580,7 +581,6 @@ def _init_mega_moe_experts( eplb_config = vllm_config.parallel_config.eplb_config self.n_redundant_experts = eplb_config.num_redundant_experts self.n_routed_experts = config.n_routed_experts - self.n_shared_experts = config.n_shared_experts or 0 self.n_logical_experts = self.n_routed_experts self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts assert self.n_physical_experts % self.ep_size == 0, ( @@ -654,6 +654,14 @@ def _init_fused_moe_experts( enable_eplb=parallel_config.enable_eplb, num_redundant_experts=eplb_config.num_redundant_experts, ) + self._sync_fused_moe_metadata() + + def _sync_fused_moe_metadata(self) -> None: + self.n_logical_experts = self.experts.logical_num_experts + self.n_physical_experts = self.experts.global_num_experts + self.n_local_physical_experts = self.experts.local_num_experts + self.n_local_experts = self.n_local_physical_experts + self.n_redundant_experts = self.n_physical_experts - self.n_logical_experts def forward( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None From 52c549e5736a364037b07baf164db6dcad0146fa Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 4 Jun 2026 00:09:07 +0800 Subject: [PATCH 072/131] sched: defer very long prefill under decode pressure Signed-off-by: jasl --- tests/v1/core/test_scheduler.py | 12 ++++++------ vllm/v1/core/sched/scheduler.py | 5 +++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 94c29f38a527..be4423daaf5e 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1127,7 +1127,7 @@ def test_running_very_long_prefill_ignores_deferred_long_waiting_pressure(): assert second_long_req.request_id not in mixed_output.num_scheduled_tokens -def test_running_long_prefill_leaves_budget_for_later_running_decode(): +def test_running_very_long_prefill_defers_to_later_running_decode(): scheduler = create_scheduler( max_num_batched_tokens=100, max_model_len=2048, @@ -1191,8 +1191,8 @@ def test_running_long_prefill_leaves_budget_for_later_running_decode(): assert short_req.num_computed_tokens >= short_req.num_prompt_tokens decode_mixed = scheduler.schedule() - assert decode_mixed.num_scheduled_tokens[long_prefill_req.request_id] == 6 assert decode_mixed.num_scheduled_tokens[short_req.request_id] == 1 + assert long_prefill_req.request_id not in decode_mixed.num_scheduled_tokens def test_mixed_decode_prefill_caps_mid_long_prefill_more_tightly(): @@ -1205,7 +1205,7 @@ def test_mixed_decode_prefill_caps_mid_long_prefill_more_tightly(): decode_req = create_requests(num_requests=1, num_tokens=100, req_ids=["decode"])[0] mid_long_prefill_req = create_requests( num_requests=1, - num_tokens=500, + num_tokens=300, req_ids=["mid_long_prefill"], )[0] @@ -1229,10 +1229,10 @@ def test_mixed_decode_prefill_caps_mid_long_prefill_more_tightly(): mixed_output = scheduler.schedule() assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 - assert mixed_output.num_scheduled_tokens[mid_long_prefill_req.request_id] == 6 + assert mixed_output.num_scheduled_tokens[mid_long_prefill_req.request_id] == 25 -def test_mixed_decode_prefill_caps_very_long_prefill_more_tightly(): +def test_mixed_decode_prefill_defers_very_long_prefill(): scheduler = create_scheduler( max_num_batched_tokens=100, max_model_len=4096, @@ -1266,7 +1266,7 @@ def test_mixed_decode_prefill_caps_very_long_prefill_more_tightly(): mixed_output = scheduler.schedule() assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 - assert mixed_output.num_scheduled_tokens[very_long_prefill_req.request_id] == 6 + assert very_long_prefill_req.request_id not in mixed_output.num_scheduled_tokens def test_preempt_during_execution(): diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 3c3dc380844c..5cb58cff9bff 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -454,7 +454,7 @@ def _limit_mixed_decode_prefill_chunk( very_long_prefill_threshold = self._very_long_prefill_threshold() if has_decode_pressure: if remaining_prefill > very_long_prefill_threshold: - mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 16) + return 0 else: mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 4) elif remaining_prefill > very_long_prefill_threshold: @@ -913,7 +913,8 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: num_new_tokens = self._limit_mixed_decode_prefill_chunk( request, num_new_tokens, scheduled_running_reqs ) - assert num_new_tokens > 0 + if num_new_tokens == 0: + break # Schedule encoder inputs. if request.has_encoder_inputs: From 6e44ffac76eb2127428f1b1d972e05787d2bacfb Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 3 Jun 2026 03:58:45 +0800 Subject: [PATCH 073/131] sm12x: add sparse MLA prefill D512 split prototype Signed-off-by: jasl --- .../attention/test_sparse_mla_indexed_d512.py | 182 +++++ vllm/envs.py | 4 + vllm/model_executor/warmup/kernel_warmup.py | 10 +- vllm/models/deepseek_v4/nvidia/flashmla.py | 336 ++++++-- .../backends/mla/sparse_mla_kernels.py | 722 +++++++++++++++++- 5 files changed, 1189 insertions(+), 65 deletions(-) create mode 100644 tests/v1/attention/test_sparse_mla_indexed_d512.py diff --git a/tests/v1/attention/test_sparse_mla_indexed_d512.py b/tests/v1/attention/test_sparse_mla_indexed_d512.py new file mode 100644 index 000000000000..531c4146f3ee --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_indexed_d512.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_indexed_d512_split_sparse_mla_attention, + accumulate_indexed_sparse_mla_attention_chunk, +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_indexed_d512_split_sparse_mla_matches_indexed_accumulate(): + torch.cuda.set_device(0) + device = torch.device("cuda:0") + torch.manual_seed(17) + num_tokens = 64 + num_heads = 8 + head_dim = 512 + num_candidates = 640 + kv_tokens = 4096 + scale = head_dim**-0.5 + + q = torch.randn( + num_tokens, + num_heads, + head_dim, + device=device, + dtype=torch.bfloat16, + ) + kv = torch.randn(kv_tokens, head_dim, device=device, dtype=torch.bfloat16) + indices = torch.randint( + 0, + kv_tokens, + (num_tokens, num_candidates), + device=device, + dtype=torch.int32, + ) + lens = torch.randint( + num_candidates // 2, + num_candidates + 1, + (num_tokens,), + device=device, + dtype=torch.int32, + ) + + current_max = torch.full( + (num_tokens, num_heads), + -float("inf"), + device=device, + dtype=torch.float32, + ) + current_denom = torch.zeros_like(current_max) + current_acc = torch.zeros( + num_tokens, num_heads, head_dim, device=device, dtype=torch.float32 + ) + split_max = torch.empty_like(current_max) + split_denom = torch.empty_like(current_denom) + split_acc = torch.empty_like(current_acc) + split_scores = torch.empty( + num_tokens, + num_heads, + num_candidates, + device=device, + dtype=torch.float32, + ) + + accumulate_indexed_sparse_mla_attention_chunk( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=current_max, + denom=current_denom, + acc=current_acc, + ) + accumulate_indexed_d512_split_sparse_mla_attention( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=split_max, + denom=split_denom, + acc=split_acc, + scores=split_scores, + ) + torch.cuda.synchronize() + + current = current_acc / current_denom[:, :, None] + split = split_acc / split_denom[:, :, None] + torch.testing.assert_close(split_max, current_max, atol=2e-5, rtol=2e-5) + torch.testing.assert_close(split_denom, current_denom, atol=2e-3, rtol=2e-3) + torch.testing.assert_close(split, current, atol=2e-3, rtol=2e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_indexed_d512_split_sparse_mla_matches_c128_combined_width(): + torch.cuda.set_device(0) + device = torch.device("cuda:0") + torch.manual_seed(23) + num_tokens = 64 + num_heads = 8 + head_dim = 512 + num_candidates = 1152 + kv_tokens = 4096 + scale = head_dim**-0.5 + + q = torch.randn( + num_tokens, + num_heads, + head_dim, + device=device, + dtype=torch.bfloat16, + ) + kv = torch.randn(kv_tokens, head_dim, device=device, dtype=torch.bfloat16) + indices = torch.randint( + 0, + kv_tokens, + (num_tokens, num_candidates), + device=device, + dtype=torch.int32, + ) + lens = torch.randint( + 128, + 1097, + (num_tokens,), + device=device, + dtype=torch.int32, + ) + + current_max = torch.full( + (num_tokens, num_heads), + -float("inf"), + device=device, + dtype=torch.float32, + ) + current_denom = torch.zeros_like(current_max) + current_acc = torch.zeros( + num_tokens, num_heads, head_dim, device=device, dtype=torch.float32 + ) + split_max = torch.empty_like(current_max) + split_denom = torch.empty_like(current_denom) + split_acc = torch.empty_like(current_acc) + split_scores = torch.empty( + num_tokens, + num_heads, + num_candidates, + device=device, + dtype=torch.float32, + ) + + accumulate_indexed_sparse_mla_attention_chunk( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=current_max, + denom=current_denom, + acc=current_acc, + ) + accumulate_indexed_d512_split_sparse_mla_attention( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=split_max, + denom=split_denom, + acc=split_acc, + scores=split_scores, + ) + torch.cuda.synchronize() + + current = current_acc / current_denom[:, :, None] + split = split_acc / split_denom[:, :, None] + torch.testing.assert_close(split_max, current_max, atol=2e-5, rtol=2e-5) + torch.testing.assert_close(split_denom, current_denom, atol=2e-3, rtol=2e-3) + torch.testing.assert_close(split, current, atol=2e-3, rtol=2e-3) diff --git a/vllm/envs.py b/vllm/envs.py index 128289ba1e0b..90d37f7b5541 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -180,6 +180,7 @@ VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True + VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = False VLLM_TRITON_MLA_SPARSE: bool | None = None VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 @@ -1453,6 +1454,9 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP": lambda: bool( int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP", "1")) ), + "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL", "0")) + ), # Experimental sparse MLA fallback controls. # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse # is unavailable; set 0/1 to force-disable/force-enable the fallback. diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 819b2050a009..d0a6c390912f 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -489,7 +489,15 @@ def kernel_warmup(worker: "Worker"): minimax_m3_msa_warmup(worker) - _deepseek_v4_sparse_mla_attention_warmup(worker) + if envs.VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH: + from vllm.models.deepseek_v4.nvidia.flashmla import ( + _disable_sparse_mla_prefill_stats, + ) + + with _disable_sparse_mla_prefill_stats(): + _deepseek_v4_sparse_mla_attention_warmup(worker) + else: + _deepseek_v4_sparse_mla_attention_warmup(worker) _deepseek_v4_request_prep_warmup(worker) enable_flashinfer_autotune = ( diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 3e8bd7e6c503..36cd8f525776 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -5,12 +5,14 @@ import torch +import vllm.envs as envs from vllm.forward_context import get_forward_context from vllm.models.deepseek_v4.attention import DeepseekV4Attention from vllm.models.deepseek_v4.common.ops import ( combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + sparse_prefill_combined_topk_size, ) from vllm.models.deepseek_v4.nvidia.ops.o_proj import ( compute_fp8_einsum_recipe, @@ -30,6 +32,7 @@ ) from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_indexed_d512_split_sparse_mla_attention, accumulate_indexed_sparse_mla_attention_chunk, build_combined_sparse_mla_decode_valid_mask, finish_sparse_mla_attention_with_sink, @@ -49,6 +52,45 @@ from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata +_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS = 8192 +_INDEXED_D512_SPLIT_PREFILL_MAX_TOPK = 1152 + + +def _use_indexed_d512_split_prefill( + *, + compress_ratio: int, + head_dim: int, + num_prefills: int, + combined_topk: int, + max_prefill_seq_len: int, + swa_only: bool, +) -> bool: + return ( + envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + and not swa_only + and compress_ratio in (4, 128) + and head_dim == 512 + and num_prefills == 1 + and 512 < combined_topk <= _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + and max_prefill_seq_len >= _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS + ) + + +def _sparse_mla_prefill_gather_len_upper_bound( + *, + max_model_len: int, + max_num_batched_tokens: int, + window_size: int, +) -> tuple[int, int]: + max_query_chunk_tokens = max(1, min(max_model_len, max_num_batched_tokens)) + max_prefix_len = max(max_model_len - max_query_chunk_tokens, 0) + max_gather_len = max_query_chunk_tokens + min( + max_prefix_len, + max(window_size - 1, 0), + ) + return max_query_chunk_tokens, max_gather_len + + class DeepseekV4FlashMLAAttention(DeepseekV4Attention): """FlashMLA sparse MLA attention layer for DeepSeek V4 (CUDA).""" @@ -84,6 +126,94 @@ def get_padded_num_q_heads(cls, num_heads: int) -> int: ) return 64 if num_heads <= 64 else 128 + @classmethod + def _prefill_workspace_topk_bound( + cls, + layer: "DeepseekV4FlashMLAAttention", + ) -> int: + if layer.compress_ratio <= 1: + return 0 + if ( + layer.topk_indices_buffer is not None + and layer.topk_indices_buffer.ndim > 0 + and layer.topk_indices_buffer.shape[-1] > 0 + ): + return int(layer.topk_indices_buffer.shape[-1]) + indexer_topk = getattr(layer.indexer, "topk_tokens", None) + if indexer_topk is not None: + return int(indexer_topk) + return 2048 + + @classmethod + def _prefill_workspace_reservation_specs( + cls, + layer: "DeepseekV4FlashMLAAttention", + ) -> tuple[tuple[tuple[int, ...], torch.dtype], ...]: + max_model_len = max(1, int(layer.max_model_len)) + max_num_batched_tokens = max(1, int(layer.max_num_batched_tokens)) + window_size = max(1, int(layer.window_size)) + compress_ratio = max(1, int(layer.compress_ratio)) + head_dim = int(layer.head_dim) + num_heads = int(layer.n_local_heads) + + max_query_chunk_tokens, max_gather_len = ( + _sparse_mla_prefill_gather_len_upper_bound( + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + window_size=window_size, + ) + ) + if compress_ratio <= 1: + m_bound = max_gather_len + else: + compressed_region_size = max_model_len // compress_ratio + m_bound = compressed_region_size + max_gather_len + + combined_topk = sparse_prefill_combined_topk_size( + cls._prefill_workspace_topk_bound(layer), + window_size, + ) + specs: list[tuple[tuple[int, ...], torch.dtype]] = [ + ((layer.PREFILL_CHUNK_SIZE, m_bound, head_dim), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ] + if is_triton_sparse_mla_enabled_for_platform(): + query_chunk_size = min( + max_query_chunk_tokens, + triton_sparse_mla_query_chunk_size(), + ) + specs.extend( + [ + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads, head_dim), torch.float32), + ] + ) + if _use_indexed_d512_split_prefill( + compress_ratio=compress_ratio, + head_dim=head_dim, + num_prefills=1, + combined_topk=combined_topk, + max_prefill_seq_len=max_model_len, + swa_only=False, + ): + specs.append(((query_chunk_size, num_heads, combined_topk), torch.float32)) + return tuple(specs) + + @classmethod + def _reserve_prefill_workspace( + cls, + layer: "DeepseekV4FlashMLAAttention", + ) -> None: + try: + workspace_manager = current_workspace_manager() + except AssertionError: + return + workspace_manager.get_simultaneous( + *cls._prefill_workspace_reservation_specs(layer) + ) + def forward_mqa( self, q: torch.Tensor, @@ -103,20 +233,9 @@ def forward_mqa( attn_metadata = forward_context.attn_metadata if attn_metadata is None: - # Warmup dummy run: no real metadata. Reserve the same bf16 - # gather workspace _forward_prefill would; the dequantize / topk - # / sparse_fwd kernels are skipped this step. - swa_only = self.compress_ratio <= 1 - N = ( - 0 - if swa_only - else (self.max_model_len + self.compress_ratio - 1) - // self.compress_ratio - ) - M = N + self.window_size + self.max_num_batched_tokens - current_workspace_manager().get_simultaneous( - ((self.PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - ) + # Warmup dummy run: no real metadata. Reserve the same graph-stable + # workspace shapes _forward_prefill can use, but skip real kernels. + self._reserve_prefill_workspace(self) output.zero_() return @@ -425,7 +544,7 @@ def _forward_sparse_mla_prefill_triton( combined_indices: torch.Tensor, combined_lens: torch.Tensor, output: torch.Tensor, - state_buffers: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + state_buffers: tuple[torch.Tensor, ...] | None = None, ) -> None: kv_flat = kv.reshape(-1, q.shape[-1]) topk_chunk_size = triton_sparse_mla_prefill_topk_chunk_size( @@ -448,7 +567,20 @@ def _forward_sparse_mla_prefill_triton( ((query_chunk_size, layer.n_local_heads, q.shape[-1]), torch.float32), ) else: - max_score_buffer, denom_buffer, output_buffer = state_buffers + max_score_buffer, denom_buffer, output_buffer = state_buffers[:3] + indexed_d512_scores = None + if ( + state_buffers is not None + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + and layer.compress_ratio in (4, 128) + and q.shape[-1] == 512 + and kv.shape[0] == 1 + and 512 + < combined_indices.shape[-1] + <= _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + and len(state_buffers) == 4 + ): + indexed_d512_scores = state_buffers[3] for token_start in range(0, q.shape[0], query_chunk_size): token_end = min(token_start + query_chunk_size, q.shape[0]) @@ -459,26 +591,43 @@ def _forward_sparse_mla_prefill_triton( max_score = max_score_buffer[:num_tokens] denom = denom_buffer[:num_tokens] subset_acc = output_buffer[:num_tokens] - max_score.fill_(float("-inf")) - denom.zero_() - subset_acc.zero_() - - for index_start in range(0, combined_indices.shape[-1], topk_chunk_size): - index_end = min( - index_start + topk_chunk_size, - combined_indices.shape[-1], - ) - accumulate_indexed_sparse_mla_attention_chunk( + if indexed_d512_scores is not None: + accumulate_indexed_d512_split_sparse_mla_attention( q=q_chunk, kv_flat=kv_flat, - indices=indices_chunk_full[:, index_start:index_end], + indices=indices_chunk_full, lens=lens_chunk, - candidate_offset=index_start, scale=layer.scale, max_score=max_score, denom=denom, acc=subset_acc, + scores=indexed_d512_scores[ + :num_tokens, :, : combined_indices.shape[-1] + ], ) + else: + max_score.fill_(float("-inf")) + denom.zero_() + subset_acc.zero_() + + for index_start in range( + 0, combined_indices.shape[-1], topk_chunk_size + ): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=layer.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + ) finish_sparse_mla_attention_with_sink( max_score, @@ -594,6 +743,7 @@ def _forward_prefill( ) -> None: swa_only = attn_metadata is None + num_prefills = swa_metadata.num_prefills num_prefill_tokens = swa_metadata.num_prefill_tokens num_decodes = swa_metadata.num_decodes num_decode_tokens = swa_metadata.num_decode_tokens @@ -601,8 +751,12 @@ def _forward_prefill( # Use pre-computed prefill metadata. seq_lens = swa_metadata.prefill_seq_lens gather_lens = swa_metadata.prefill_gather_lens + seq_lens_cpu = swa_metadata.prefill_seq_lens_cpu + gather_lens_cpu = swa_metadata.prefill_gather_lens_cpu assert seq_lens is not None assert gather_lens is not None + assert seq_lens_cpu is not None + assert gather_lens_cpu is not None # Derive prefill-local token offsets from the full query_start_loc_cpu. query_start_loc_cpu = swa_metadata.query_start_loc_cpu @@ -621,22 +775,93 @@ def _forward_prefill( assert attn_metadata is not None topk_indices = attn_metadata.c128a_prefill_topk_indices top_k = topk_indices.shape[-1] + # Compressed region must fit the full compressed pool (seq_len // + # compress_ratio), not just top_k. top_k bounds how many indices + # the indexer selects, not the pool size it indexes into. + N = int((seq_lens_cpu // self.compress_ratio).max().item()) else: # NOTE(woosuk): topk_indices will not be used for SWA-only layers. assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[num_decode_tokens:] top_k = 0 - chunk_plan = swa_metadata.get_prefill_chunk_plan( - compress_ratio=self.compress_ratio, - prefill_chunk_size=self.PREFILL_CHUNK_SIZE, - ) - assert chunk_plan, "prefill chunk plan must be non-empty when num_prefills > 0" + N = 0 + + M = N + int(gather_lens_cpu.max().item()) + chunk_size_const = self.PREFILL_CHUNK_SIZE + num_chunks = (num_prefills + chunk_size_const - 1) // chunk_size_const + max_query_chunk_tokens = 0 + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size_const + chunk_end = min(chunk_start + chunk_size_const, num_prefills) + query_start = ( + query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base + ) + query_end = ( + query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base + ) + max_query_chunk_tokens = max( + max_query_chunk_tokens, int(query_end - query_start) + ) + combined_topk = sparse_prefill_combined_topk_size(top_k, self.window_size) + workspace_manager = current_workspace_manager() - for chunk_start, chunk_end, chunk_N, chunk_M in chunk_plan: + triton_sparse_mla_enabled = is_triton_sparse_mla_enabled(q.device) + if triton_sparse_mla_enabled: + query_chunk_size = min(q.shape[0], triton_sparse_mla_query_chunk_size()) + indexed_d512_split_prefill = _use_indexed_d512_split_prefill( + compress_ratio=int(self.compress_ratio), + head_dim=int(self.head_dim), + num_prefills=int(num_prefills), + combined_topk=int(combined_topk), + max_prefill_seq_len=int(seq_lens_cpu.max().item()), + swa_only=swa_only, + ) + extra_specs: list[tuple[tuple[int, ...], torch.dtype]] = [] + if indexed_d512_split_prefill: + extra_specs.append( + ( + (query_chunk_size, self.n_local_heads, combined_topk), + torch.float32, + ) + ) + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + max_score_buffer, + denom_buffer, + output_buffer, + *extra_state_buffers, + ) = workspace_manager.get_simultaneous( + ((chunk_size_const, M, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ((query_chunk_size, self.n_local_heads), torch.float32), + ((query_chunk_size, self.n_local_heads), torch.float32), + ((query_chunk_size, self.n_local_heads, q.shape[-1]), torch.float32), + *extra_specs, + ) + prefill_state_buffers = ( + max_score_buffer, + denom_buffer, + output_buffer, + *extra_state_buffers, + ) + else: + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + ) = workspace_manager.get_simultaneous( + ((chunk_size_const, M, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ) + prefill_state_buffers = None + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size_const + chunk_end = min(chunk_start + chunk_size_const, num_prefills) chunk_size = chunk_end - chunk_start - kv = workspace_manager.get_simultaneous( - ((chunk_size, chunk_M, q.shape[-1]), torch.bfloat16), - )[0] if not swa_only: # Gather compressed KV assert attn_metadata is not None @@ -660,7 +885,7 @@ def _forward_prefill( gather_lens=gather_lens[chunk_start:chunk_end], block_table=swa_block_table[chunk_start:chunk_end], block_size=swa_metadata.block_size, - offset=chunk_N, + offset=N, ) # Combine the topk indices and SWA indices for gathered KV cache @@ -681,15 +906,28 @@ def _forward_prefill( self.window_size, self.compress_ratio, top_k, - chunk_M, - chunk_N, - ) - flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], + M, + N, + combined_indices=combined_indices_buffer, + combined_lens=combined_lens_buffer, ) + if triton_sparse_mla_enabled: + self._forward_sparse_mla_prefill_triton( + self, + q=q[query_start:query_end], + kv=kv[:chunk_size], + combined_indices=combined_indices, + combined_lens=combined_lens, + output=output[query_start:query_end], + state_buffers=prefill_state_buffers, + ) + else: + flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index c941d1df433d..9dc18fed8b6c 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -5,6 +5,7 @@ import torch from vllm.triton_utils import tl, triton +from vllm.utils.math_utils import next_power_of_2 from vllm.v1.attention.backends.mla.sparse_mla_env import ( triton_sparse_mla_head_block_size, ) @@ -115,7 +116,7 @@ def merge_two_sparse_mla_subsets_with_sink( assert output.is_cuda num_tokens, num_heads, head_dim = subset0_output.shape - block_d = min(128, triton.next_power_of_2(head_dim)) + block_d = min(128, next_power_of_2(head_dim)) grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) _merge_two_subsets_with_sink_kernel[grid]( subset0_output, @@ -212,7 +213,7 @@ def merge_sparse_mla_subset_with_sink( assert output.is_cuda num_tokens, num_heads, head_dim = subset_output.shape - block_d = min(128, triton.next_power_of_2(head_dim)) + block_d = min(128, next_power_of_2(head_dim)) grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) _merge_single_subset_with_sink_kernel[grid]( subset_output, @@ -295,7 +296,7 @@ def build_combined_sparse_mla_decode_valid_mask( assert swa_lens.is_cuda num_candidates = output.shape[1] - block_c = triton.next_power_of_2(num_candidates) + block_c = next_power_of_2(num_candidates) _build_combined_decode_valid_mask_kernel[(output.shape[0],)]( output, compressed_slot_ids, @@ -738,7 +739,7 @@ def _smallest_block_d_covering(hd: int) -> int: output[:, active_heads:].zero_() return - block_d = min(1024, triton.next_power_of_2(head_dim)) + block_d = min(1024, next_power_of_2(head_dim)) head_grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) _finish_materialized_scores_with_sink_kernel[head_grid]( scores, @@ -897,7 +898,7 @@ def accumulate_gathered_sparse_mla_attention_chunk( num_tokens, _, head_dim = q.shape num_heads = max_score.shape[1] num_candidates = kv.shape[1] - block_d = min(1024, triton.next_power_of_2(head_dim)) + block_d = min(1024, next_power_of_2(head_dim)) grid = (num_tokens, num_heads) _accumulate_gathered_attention_chunk_kernel[grid]( q, @@ -1155,6 +1156,115 @@ def _accumulate_indexed_attention_chunk_multihead_kernel( ) +@triton.jit +def _accumulate_indexed_attention_partial_states_multihead_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_p: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_p: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + PART_SIZE: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + part_idx = tl.program_id(2) + part_start = part_idx * PART_SIZE + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(tl.float32) + + running_max = tl.full((HEAD_BLOCK,), -float("inf"), dtype=tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), dtype=tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), dtype=tl.float32) + + valid_len = tl.load(lens_ptr + token_idx) + local_remaining = tl.maximum(num_candidates - part_start, 0) + local_valid = tl.maximum(valid_len - candidate_offset - part_start, 0) + local_eff = tl.minimum(PART_SIZE, tl.minimum(local_remaining, local_valid)) + + for local_idx in range(0, local_eff): + candidate_idx = part_start + local_idx + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = kv_index >= 0 + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + scores = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, scores) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(scores - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + state_base = part_idx * stride_state_p + token_idx * stride_state_t + acc_base = part_idx * stride_acc_p + token_idx * stride_acc_t + tl.store( + max_score_ptr + state_base + head_offsets * stride_state_h, + running_max, + mask=head_mask, + ) + tl.store( + denom_ptr + state_base + head_offsets * stride_state_h, + running_denom, + mask=head_mask, + ) + tl.store( + acc_ptr + + acc_base + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + running_acc, + mask=head_mask[:, None] & dim_mask[None, :], + ) + + def accumulate_indexed_sparse_mla_attention_chunk( q: torch.Tensor, kv_flat: torch.Tensor, @@ -1189,7 +1299,7 @@ def accumulate_indexed_sparse_mla_attention_chunk( num_tokens, _, head_dim = q.shape num_heads = max_score.shape[1] num_candidates = indices.shape[1] - block_d = min(1024, triton.next_power_of_2(head_dim)) + block_d = min(1024, next_power_of_2(head_dim)) head_block = _PREFILL_INDEXED_HEAD_BLOCK if num_heads >= head_block: @@ -1422,7 +1532,7 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( num_tokens, _, head_dim = q.shape num_heads = max_score.shape[1] num_candidates = slot_ids.shape[1] - block_d = min(1024, triton.next_power_of_2(head_dim)) + block_d = min(1024, next_power_of_2(head_dim)) grid = (num_tokens, num_heads) _accumulate_fp8ds_global_slots_attention_chunk_kernel[grid]( q, @@ -1639,7 +1749,7 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( num_tokens, _, head_dim = q.shape num_heads = max_score.shape[1] num_candidates = slot_ids.shape[1] - block_d = min(1024, triton.next_power_of_2(head_dim)) + block_d = min(1024, next_power_of_2(head_dim)) grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid]( q, @@ -1676,6 +1786,332 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( ) + +@triton.jit +def _indexed_d512_split_score_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + scores_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + num_heads: tl.constexpr, + num_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_C: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + candidate_block = tl.program_id(2) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + candidate_offsets = candidate_block * BLOCK_C + tl.arange(0, BLOCK_C) + dim_offsets = tl.arange(0, HEAD_DIM) + head_mask = head_offsets < num_heads + valid_len = tl.load(lens_ptr + token_idx) + candidate_mask = candidate_offsets < tl.minimum(valid_len, num_candidates) + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=head_mask[:, None], + other=0.0, + ) + kv_indices = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_offsets * stride_indices_c, + mask=candidate_mask, + other=-1, + ) + valid_kv = kv_indices >= 0 + kv = tl.load( + kv_flat_ptr + + kv_indices[None, :].to(tl.int64) * stride_kv_t + + dim_offsets[:, None] * stride_kv_d, + mask=valid_kv[None, :], + other=0.0, + ) + scores = tl.dot(q, kv) * scale + tl.store( + scores_ptr + + token_idx * stride_scores_t + + head_offsets[:, None] * stride_scores_h + + candidate_offsets[None, :] * stride_scores_c, + scores, + mask=head_mask[:, None] & candidate_mask[None, :], + ) + + +@triton.jit +def _indexed_d512_split_stats_kernel( + scores_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_C: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + candidate_offsets = tl.arange(0, BLOCK_C) + valid_len = tl.load(lens_ptr + token_idx) + candidate_mask = candidate_offsets < tl.minimum(valid_len, num_candidates) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidate_offsets * stride_scores_c, + mask=candidate_mask, + other=-float("inf"), + ).to(tl.float32) + running_max = tl.max(scores, axis=0) + safe_max = tl.where(valid_len > 0, running_max, 0.0) + weights = tl.where(candidate_mask, tl.exp(scores - safe_max), 0.0) + running_denom = tl.sum(weights, axis=0) + + tl.store( + max_score_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + running_max, + ) + tl.store( + denom_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + running_denom, + ) + + +@triton.jit +def _indexed_d512_split_value_kernel( + scores_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + acc_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + num_candidates: tl.constexpr, + head_dim: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_C: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + dim_block = tl.program_id(2) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + candidate_offsets = tl.arange(0, BLOCK_C) + dim_offsets = dim_block * BLOCK_D + tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + valid_len = tl.load(lens_ptr + token_idx) + max_score = tl.load( + max_score_ptr + + token_idx * stride_state_t + + head_offsets * stride_state_h, + mask=head_mask, + other=0.0, + ).to(tl.float32) + safe_max = tl.where(valid_len > 0, max_score, 0.0) + acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + for candidate_start in range(0, num_candidates, BLOCK_C): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < tl.minimum(valid_len, num_candidates) + kv_indices = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidates * stride_indices_c, + mask=candidate_mask, + other=-1, + ) + valid_kv = kv_indices >= 0 + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_offsets[:, None] * stride_scores_h + + candidates[None, :] * stride_scores_c, + mask=head_mask[:, None] & candidate_mask[None, :], + other=-float("inf"), + ).to(tl.float32) + weights = tl.where( + candidate_mask[None, :], + tl.exp(scores - safe_max[:, None]), + 0.0, + ) + values = tl.load( + kv_flat_ptr + + kv_indices[:, None].to(tl.int64) * stride_kv_t + + dim_offsets[None, :] * stride_kv_d, + mask=valid_kv[:, None] & dim_mask[None, :], + other=0.0, + ) + acc += tl.dot(weights.to(tl.bfloat16), values) + + tl.store( + acc_ptr + + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + acc, + mask=head_mask[:, None] & dim_mask[None, :], + ) + + +def accumulate_indexed_d512_split_sparse_mla_attention( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + scores: torch.Tensor, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + head_block_size: int = 16, + candidate_block_size: int = 64, + value_block_size: int = 64, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert q.shape[-1] == 512 + assert scores.shape == (q.shape[0], max_score.shape[1], indices.shape[1]) + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert scores.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert scores.is_cuda and max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert head_block_size in (8, 16, 32) + assert candidate_block_size in (32, 64, 128) + assert value_block_size in (32, 64, 128) + assert indices.shape[1] <= 1152 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = indices.shape[1] + score_grid = ( + num_tokens, + triton.cdiv(num_heads, head_block_size), + triton.cdiv(num_candidates, candidate_block_size), + ) + _indexed_d512_split_score_kernel[score_grid]( + q, + kv_flat, + indices, + lens, + scores, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + scores.stride(0), + scores.stride(1), + scores.stride(2), + num_heads, + num_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_C=candidate_block_size, + HEAD_DIM=head_dim, + num_warps=8, + num_stages=3, + ) + + stats_grid = (num_tokens, num_heads) + stats_block_c = next_power_of_2(num_candidates) + _indexed_d512_split_stats_kernel[stats_grid]( + scores, + lens, + max_score, + denom, + scores.stride(0), + scores.stride(1), + scores.stride(2), + max_score.stride(0), + max_score.stride(1), + num_candidates, + BLOCK_C=stats_block_c, + num_warps=4, + num_stages=3, + ) + + value_grid = ( + num_tokens, + triton.cdiv(num_heads, head_block_size), + triton.cdiv(head_dim, value_block_size), + ) + _indexed_d512_split_value_kernel[value_grid]( + scores, + kv_flat, + indices, + lens, + max_score, + acc, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + num_candidates, + head_dim, + HEAD_BLOCK=head_block_size, + BLOCK_C=candidate_block_size, + BLOCK_D=value_block_size, + num_warps=4, + num_stages=3, + ) + + @triton.autotune( configs=[ triton.Config({}, num_warps=4, num_stages=2), @@ -1841,7 +2277,7 @@ def accumulate_fp8ds_paged_sparse_mla_attention_chunk( num_tokens, _, head_dim = q.shape num_heads = max_score.shape[1] - block_d = min(1024, triton.next_power_of_2(head_dim)) + block_d = min(1024, next_power_of_2(head_dim)) grid = (num_tokens, num_heads) _accumulate_fp8ds_paged_attention_chunk_kernel[grid]( q, @@ -2057,7 +2493,7 @@ def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( num_tokens, _, head_dim = q.shape num_heads = max_score.shape[1] - block_d = min(1024, triton.next_power_of_2(head_dim)) + block_d = min(1024, next_power_of_2(head_dim)) grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid]( q, @@ -2277,7 +2713,7 @@ def fp8ds_paged_sparse_mla_attention_with_sink_multihead( assert active_heads <= q.shape[1] assert active_heads <= output.shape[1] assert active_heads <= attn_sink.shape[0] - block_d = min(1024, triton.next_power_of_2(head_dim)) + block_d = min(1024, next_power_of_2(head_dim)) grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) _fp8ds_paged_attention_with_sink_multihead_kernel[grid]( q, @@ -2553,7 +2989,7 @@ def fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( assert active_heads <= q.shape[1] assert active_heads <= output.shape[1] assert active_heads <= attn_sink.shape[0] - block_d = min(1024, triton.next_power_of_2(head_dim)) + block_d = min(1024, next_power_of_2(head_dim)) grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) _fp8ds_global_paged_attention_with_sink_multihead_kernel[grid]( q, @@ -2677,7 +3113,7 @@ def finish_gathered_sparse_mla_attention( assert output.is_cuda and lse.is_cuda num_tokens, num_heads, head_dim = acc.shape - block_d = min(128, triton.next_power_of_2(head_dim)) + block_d = min(128, next_power_of_2(head_dim)) grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) _finish_attention_state_kernel[grid]( max_score, @@ -2886,7 +3322,7 @@ def finish_two_sparse_mla_attention_states_with_sink( assert attn_sink.is_cuda and output.is_cuda num_tokens, num_heads, head_dim = acc0.shape - block_d = min(128, triton.next_power_of_2(head_dim)) + block_d = min(128, next_power_of_2(head_dim)) grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) _finish_two_attention_states_with_sink_kernel[grid]( max_score0, @@ -2917,6 +3353,262 @@ def finish_two_sparse_mla_attention_states_with_sink( ) +def accumulate_indexed_sparse_mla_attention_partial_states( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + part_size: int, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.dim() == 3 + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.shape[1] == q.shape[0] + assert max_score.shape[2] <= q.shape[1] + assert part_size > 0 + assert max_score.shape[0] * part_size >= indices.shape[1] + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + num_tokens, _, head_dim = q.shape + num_parts, _, num_heads = max_score.shape + num_candidates = indices.shape[1] + block_d = min(1024, next_power_of_2(head_dim)) + head_block = _PREFILL_INDEXED_HEAD_BLOCK + + if num_heads >= head_block: + grid = (num_tokens, triton.cdiv(num_heads, head_block), num_parts) + _accumulate_indexed_attention_partial_states_multihead_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + max_score.stride(2), + acc.stride(0), + acc.stride(1), + acc.stride(2), + acc.stride(3), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + PART_SIZE=part_size, + HEAD_BLOCK=head_block, + BLOCK_D=block_d, + num_warps=4, + num_stages=2, + ) + else: + for part_idx in range(num_parts): + part_start = part_idx * part_size + part_end = min(part_start + part_size, num_candidates) + max_score[part_idx].fill_(float("-inf")) + denom[part_idx].zero_() + acc[part_idx].zero_() + if part_start >= part_end: + continue + accumulate_indexed_sparse_mla_attention_chunk( + q=q, + kv_flat=kv_flat, + indices=indices[:, part_start:part_end], + lens=lens, + candidate_offset=candidate_offset + part_start, + scale=scale, + max_score=max_score[part_idx], + denom=denom[part_idx], + acc=acc[part_idx], + ) + + +@triton.jit +def _merge_attention_states_kernel( + max0_ptr, + denom0_ptr, + acc0_ptr, + max1_ptr, + denom1_ptr, + acc1_ptr, + out_max_ptr, + out_denom_ptr, + out_acc_ptr, + stride_state0_t: tl.constexpr, + stride_state0_h: tl.constexpr, + stride_acc0_t: tl.constexpr, + stride_acc0_h: tl.constexpr, + stride_acc0_d: tl.constexpr, + stride_state1_t: tl.constexpr, + stride_state1_h: tl.constexpr, + stride_acc1_t: tl.constexpr, + stride_acc1_h: tl.constexpr, + stride_acc1_d: tl.constexpr, + stride_out_state_t: tl.constexpr, + stride_out_state_h: tl.constexpr, + stride_out_acc_t: tl.constexpr, + stride_out_acc_h: tl.constexpr, + stride_out_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state0_offset = token_idx * stride_state0_t + head_idx * stride_state0_h + state1_offset = token_idx * stride_state1_t + head_idx * stride_state1_h + out_state_offset = token_idx * stride_out_state_t + head_idx * stride_out_state_h + max0 = tl.load(max0_ptr + state0_offset) + denom0 = tl.load(denom0_ptr + state0_offset) + max1 = tl.load(max1_ptr + state1_offset) + denom1 = tl.load(denom1_ptr + state1_offset) + + has0 = denom0 > 0.0 + has1 = denom1 > 0.0 + valid_max0 = tl.where(has0, max0, -float("inf")) + valid_max1 = tl.where(has1, max1, -float("inf")) + merge_max = tl.maximum(valid_max0, valid_max1) + has_any = has0 | has1 + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_max0 = tl.where(has0, max0, safe_merge_max) + safe_max1 = tl.where(has1, max1, safe_merge_max) + scale0 = tl.where(has0, tl.exp(safe_max0 - safe_merge_max), 0.0) + scale1 = tl.where(has1, tl.exp(safe_max1 - safe_merge_max), 0.0) + merged_denom = denom0 * scale0 + denom1 * scale1 + merged_max = tl.where(has_any, merge_max, -float("inf")) + + acc0 = tl.load( + acc0_ptr + + token_idx * stride_acc0_t + + head_idx * stride_acc0_h + + offsets * stride_acc0_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc1 = tl.load( + acc1_ptr + + token_idx * stride_acc1_t + + head_idx * stride_acc1_h + + offsets * stride_acc1_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc0 = tl.where(has0, acc0, 0.0) + acc1 = tl.where(has1, acc1, 0.0) + merged_acc = acc0 * scale0 + acc1 * scale1 + + tl.store(out_max_ptr + out_state_offset, merged_max) + tl.store(out_denom_ptr + out_state_offset, merged_denom) + tl.store( + out_acc_ptr + + token_idx * stride_out_acc_t + + head_idx * stride_out_acc_h + + offsets * stride_out_acc_d, + merged_acc, + mask=dim_mask, + ) + + +def merge_sparse_mla_attention_states( + max_score0: torch.Tensor, + denom0: torch.Tensor, + acc0: torch.Tensor, + max_score1: torch.Tensor, + denom1: torch.Tensor, + acc1: torch.Tensor, + output_max_score: torch.Tensor, + output_denom: torch.Tensor, + output_acc: torch.Tensor, +) -> None: + assert max_score0.shape == denom0.shape + assert max_score1.shape == denom1.shape + assert max_score0.shape == max_score1.shape + assert output_max_score.shape == max_score0.shape + assert output_denom.shape == denom0.shape + assert acc0.shape == acc1.shape + assert output_acc.shape == acc0.shape + assert acc0.shape[:2] == max_score0.shape + assert max_score0.dtype == torch.float32 + assert denom0.dtype == torch.float32 + assert acc0.dtype == torch.float32 + assert max_score1.dtype == torch.float32 + assert denom1.dtype == torch.float32 + assert acc1.dtype == torch.float32 + assert output_max_score.dtype == torch.float32 + assert output_denom.dtype == torch.float32 + assert output_acc.dtype == torch.float32 + assert max_score0.is_cuda and denom0.is_cuda and acc0.is_cuda + assert max_score1.is_cuda and denom1.is_cuda and acc1.is_cuda + assert output_max_score.is_cuda and output_denom.is_cuda and output_acc.is_cuda + + num_tokens, num_heads, head_dim = acc0.shape + block_d = min(128, next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_attention_states_kernel[grid]( + max_score0, + denom0, + acc0, + max_score1, + denom1, + acc1, + output_max_score, + output_denom, + output_acc, + max_score0.stride(0), + max_score0.stride(1), + acc0.stride(0), + acc0.stride(1), + acc0.stride(2), + max_score1.stride(0), + max_score1.stride(1), + acc1.stride(0), + acc1.stride(1), + acc1.stride(2), + output_max_score.stride(0), + output_max_score.stride(1), + output_acc.stride(0), + output_acc.stride(1), + output_acc.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + def finish_sparse_mla_attention_with_sink( max_score: torch.Tensor, denom: torch.Tensor, @@ -2937,7 +3629,7 @@ def finish_sparse_mla_attention_with_sink( assert attn_sink.is_cuda and output.is_cuda num_tokens, num_heads, head_dim = acc.shape - block_d = min(128, triton.next_power_of_2(head_dim)) + block_d = min(128, next_power_of_2(head_dim)) grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) _finish_attention_state_with_sink_kernel[grid]( max_score, From 42aade9c46e6affa36970d18882cdbab38274619 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 4 Jun 2026 02:23:46 +0800 Subject: [PATCH 074/131] sm12x: warm high-concurrency MTP decode workspace Signed-off-by: jasl --- .../test_deepseek_v4_kernel_warmup.py | 37 +++++++++++++++++++ vllm/model_executor/warmup/kernel_warmup.py | 10 ++--- 2 files changed, 42 insertions(+), 5 deletions(-) create mode 100644 tests/model_executor/test_deepseek_v4_kernel_warmup.py diff --git a/tests/model_executor/test_deepseek_v4_kernel_warmup.py b/tests/model_executor/test_deepseek_v4_kernel_warmup.py new file mode 100644 index 000000000000..6747107d59d2 --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_kernel_warmup.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +from vllm.model_executor.warmup import kernel_warmup + + +def _mtp_runner(query_len: int = 3): + return SimpleNamespace( + speculative_config=SimpleNamespace(method="mtp"), + num_spec_tokens=query_len - 1, + uniform_decode_query_len=query_len, + ) + + +def test_deepseek_v4_mtp_uniform_decode_warmup_covers_c256(): + requests = kernel_warmup._deepseek_v4_mtp_uniform_decode_warmup_requests( + _mtp_runner(), + max_tokens=4096, + max_reqs=256, + ) + + assert requests == (1, 2, 4, 8, 16, 24, 32, 256) + + +def test_deepseek_v4_mtp_uniform_decode_warmup_still_respects_limits(): + assert kernel_warmup._deepseek_v4_mtp_uniform_decode_warmup_requests( + _mtp_runner(), + max_tokens=4096, + max_reqs=24, + ) == (1, 2, 4, 8, 16, 24) + assert kernel_warmup._deepseek_v4_mtp_uniform_decode_warmup_requests( + _mtp_runner(), + max_tokens=96, + max_reqs=256, + ) == (1, 2, 4, 8, 16, 24, 32) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index d0a6c390912f..b23e13668048 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -43,12 +43,12 @@ # value via _clamp_warmup_tokens at the call site, smaller caps clamp # down naturally. _DEEPSEEK_V4_SPARSE_MLA_PREFILL_WARMUP_TOKENS = 8192 -# Steady-state MTP decode shapes to warm. Keep this bounded to the edge -# deployment range we expect to optimize; warming the scheduler's raw -# max_num_seqs (often 1024) can consume multiple GiB of temporary workspace -# on long-context SM12x serves before the first request. +# Steady-state MTP decode shapes to warm. Keep this bounded to high-concurrency +# SM12x gates while still avoiding the scheduler's raw max_num_seqs (often 1024), +# which can consume multiple GiB of temporary workspace on long-context serves +# before the first request. _DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS = (1, 2, 4, 8, 16, 24, 32) -_DEEPSEEK_V4_MTP_UNIFORM_DECODE_MAX_WARMUP_REQUESTS = 32 +_DEEPSEEK_V4_MTP_UNIFORM_DECODE_MAX_WARMUP_REQUESTS = 256 _DEEPSEEK_V4_SLOT_MAPPING_WARMUP_TOKENS = tuple(range(1, 17)) + ( 32, 64, From 1a3e89cdb021f9f127288e8ace019020025cb70a Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 4 Jun 2026 10:57:44 +0800 Subject: [PATCH 075/131] sm12x: enable indexed D512 sparse MLA prefill by default Co-authored-by: OpenAI Codex Signed-off-by: jasl --- vllm/envs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 90d37f7b5541..390f6e877574 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -180,7 +180,7 @@ VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True - VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = False + VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = True VLLM_TRITON_MLA_SPARSE: bool | None = None VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 @@ -1455,7 +1455,7 @@ def _resolve_rust_frontend_path() -> str | None: int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP", "1")) ), "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL": lambda: bool( - int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL", "0")) + int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL", "1")) ), # Experimental sparse MLA fallback controls. # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse From e4b8347164661ef123ce1551c5b8ca70f6deb888 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 4 Jun 2026 18:23:19 +0800 Subject: [PATCH 076/131] sm12x: retune D512 sparse MLA split tiles Signed-off-by: jasl --- vllm/v1/attention/backends/mla/sparse_mla_kernels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 9dc18fed8b6c..8fe3ca0e9b4e 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -1996,9 +1996,9 @@ def accumulate_indexed_d512_split_sparse_mla_attention( max_score: torch.Tensor, denom: torch.Tensor, acc: torch.Tensor, - head_block_size: int = 16, + head_block_size: int = 32, candidate_block_size: int = 64, - value_block_size: int = 64, + value_block_size: int = 128, ) -> None: if q.dim() == 4: assert q.shape[1] == 1 From 087c9617189aa67bc4e4cb75bc39eed54e2a5287 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 4 Jun 2026 18:33:37 +0800 Subject: [PATCH 077/131] fix: align prefix cache manager signatures after rebase Signed-off-by: jasl --- vllm/v1/core/single_type_kv_cache_manager.py | 33 +++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 2ccffaed18d8..87e73c5e4c4f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -397,6 +397,7 @@ def cache_blocks( self, request: Request, num_tokens: int, + alignment_tokens: int | None = None, retention_interval: int | None = None, ) -> None: """ @@ -406,6 +407,8 @@ def cache_blocks( request: The request. num_tokens: The total number of tokens that need to be cached (including tokens that are already cached). + alignment_tokens: The prefix-cache hit alignment in tokens. + ``None`` uses this manager's scheduler block size. retention_interval: Sparse local-checkpoint granularity. ``None`` keeps dense checkpointing; ``0`` keeps only the latest replay boundary; a positive multiple of ``scheduler_block_size`` keeps @@ -417,10 +420,13 @@ def cache_blocks( if num_cached_blocks >= num_full_blocks: return + if alignment_tokens is None: + alignment_tokens = self.scheduler_block_size + block_mask = self.reachable_block_mask( start_block=num_cached_blocks, end_block=num_full_blocks, - alignment_tokens=self.scheduler_block_size, + alignment_tokens=alignment_tokens, kv_cache_spec=self.kv_cache_spec, use_eagle=self.use_eagle, retention_interval=retention_interval, @@ -703,8 +709,14 @@ def cache_blocks( request: Request, num_tokens: int, alignment_tokens: int | None = None, + retention_interval: int | None = None, ) -> None: - super().cache_blocks(request, num_tokens, alignment_tokens=alignment_tokens) + super().cache_blocks( + request, + num_tokens, + alignment_tokens=alignment_tokens, + retention_interval=retention_interval, + ) if not self._should_protect_prompt_blocks() or request.num_prompt_tokens <= 1: return @@ -959,8 +971,14 @@ def cache_blocks( request: Request, num_tokens: int, alignment_tokens: int | None = None, + retention_interval: int | None = None, ) -> None: - super().cache_blocks(request, num_tokens, alignment_tokens=alignment_tokens) + super().cache_blocks( + request, + num_tokens, + alignment_tokens=alignment_tokens, + retention_interval=retention_interval, + ) if not self.enable_caching or request.num_prompt_tokens <= 1: return @@ -1395,10 +1413,16 @@ def cache_blocks( self, request: Request, num_tokens: int, + alignment_tokens: int | None = None, retention_interval: int | None = None, ) -> None: num_cached_blocks_before = self.num_cached_block.get(request.request_id, 0) - super().cache_blocks(request, num_tokens, retention_interval=retention_interval) + super().cache_blocks( + request, + num_tokens, + alignment_tokens=alignment_tokens, + retention_interval=retention_interval, + ) num_cached_blocks_after = self.num_cached_block.get(request.request_id, 0) if num_cached_blocks_after > num_cached_blocks_before: for block in self.req_to_blocks[request.request_id][ @@ -1440,6 +1464,7 @@ def cache_blocks( self, request: Request, num_tokens: int, + alignment_tokens: int | None = None, retention_interval: int | None = None, ) -> None: # We do not cache blocks for cross-attention to be shared between From d3373c5e8cb869d3744733e21befa58c2143aa4a Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 4 Jun 2026 21:29:52 +0800 Subject: [PATCH 078/131] sm12x: skip empty D512 sparse MLA tail blocks Signed-off-by: jasl --- .../backends/mla/sparse_mla_kernels.py | 65 ++++++++++--------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 8fe3ca0e9b4e..6fc9bf924800 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -1819,6 +1819,8 @@ def _indexed_d512_split_score_kernel( dim_offsets = tl.arange(0, HEAD_DIM) head_mask = head_offsets < num_heads valid_len = tl.load(lens_ptr + token_idx) + if candidate_block * BLOCK_C >= tl.minimum(valid_len, num_candidates): + return candidate_mask = candidate_offsets < tl.minimum(valid_len, num_candidates) q = tl.load( @@ -1944,37 +1946,38 @@ def _indexed_d512_split_value_kernel( acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) for candidate_start in range(0, num_candidates, BLOCK_C): - candidates = candidate_start + candidate_offsets - candidate_mask = candidates < tl.minimum(valid_len, num_candidates) - kv_indices = tl.load( - indices_ptr - + token_idx * stride_indices_t - + candidates * stride_indices_c, - mask=candidate_mask, - other=-1, - ) - valid_kv = kv_indices >= 0 - scores = tl.load( - scores_ptr - + token_idx * stride_scores_t - + head_offsets[:, None] * stride_scores_h - + candidates[None, :] * stride_scores_c, - mask=head_mask[:, None] & candidate_mask[None, :], - other=-float("inf"), - ).to(tl.float32) - weights = tl.where( - candidate_mask[None, :], - tl.exp(scores - safe_max[:, None]), - 0.0, - ) - values = tl.load( - kv_flat_ptr - + kv_indices[:, None].to(tl.int64) * stride_kv_t - + dim_offsets[None, :] * stride_kv_d, - mask=valid_kv[:, None] & dim_mask[None, :], - other=0.0, - ) - acc += tl.dot(weights.to(tl.bfloat16), values) + if candidate_start < tl.minimum(valid_len, num_candidates): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < tl.minimum(valid_len, num_candidates) + kv_indices = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidates * stride_indices_c, + mask=candidate_mask, + other=-1, + ) + valid_kv = kv_indices >= 0 + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_offsets[:, None] * stride_scores_h + + candidates[None, :] * stride_scores_c, + mask=head_mask[:, None] & candidate_mask[None, :], + other=-float("inf"), + ).to(tl.float32) + weights = tl.where( + candidate_mask[None, :], + tl.exp(scores - safe_max[:, None]), + 0.0, + ) + values = tl.load( + kv_flat_ptr + + kv_indices[:, None].to(tl.int64) * stride_kv_t + + dim_offsets[None, :] * stride_kv_d, + mask=valid_kv[:, None] & dim_mask[None, :], + other=0.0, + ) + acc += tl.dot(weights.to(tl.bfloat16), values) tl.store( acc_ptr From dad43ed9adca4ad583b07fb43b81a093b241910d Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 5 Jun 2026 15:22:45 +0800 Subject: [PATCH 079/131] sm12x: clean sparse MLA rebase leftovers Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/flashmla.py | 6 +- .../backends/mla/sparse_mla_kernels.py | 365 ------------------ 2 files changed, 5 insertions(+), 366 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 36cd8f525776..29b1bcbc5c11 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -12,6 +12,8 @@ combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + dequantize_global_slots_k_cache, sparse_prefill_combined_topk_size, ) from vllm.models.deepseek_v4.nvidia.ops.o_proj import ( @@ -198,7 +200,9 @@ def _prefill_workspace_reservation_specs( max_prefill_seq_len=max_model_len, swa_only=False, ): - specs.append(((query_chunk_size, num_heads, combined_topk), torch.float32)) + specs.append( + ((query_chunk_size, num_heads, combined_topk), torch.float32) + ) return tuple(specs) @classmethod diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 6fc9bf924800..01f01170218c 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -1156,115 +1156,6 @@ def _accumulate_indexed_attention_chunk_multihead_kernel( ) -@triton.jit -def _accumulate_indexed_attention_partial_states_multihead_kernel( - q_ptr, - kv_flat_ptr, - indices_ptr, - lens_ptr, - max_score_ptr, - denom_ptr, - acc_ptr, - stride_q_t: tl.constexpr, - stride_q_h: tl.constexpr, - stride_q_d: tl.constexpr, - stride_kv_t, - stride_kv_d: tl.constexpr, - stride_indices_t: tl.constexpr, - stride_indices_c: tl.constexpr, - stride_state_p: tl.constexpr, - stride_state_t: tl.constexpr, - stride_state_h: tl.constexpr, - stride_acc_p: tl.constexpr, - stride_acc_t: tl.constexpr, - stride_acc_h: tl.constexpr, - stride_acc_d: tl.constexpr, - num_heads: tl.constexpr, - head_dim: tl.constexpr, - num_candidates, - candidate_offset, - scale: tl.constexpr, - PART_SIZE: tl.constexpr, - HEAD_BLOCK: tl.constexpr, - BLOCK_D: tl.constexpr, -): - token_idx = tl.program_id(0) - head_block_idx = tl.program_id(1) - part_idx = tl.program_id(2) - part_start = part_idx * PART_SIZE - head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) - dim_offsets = tl.arange(0, BLOCK_D) - head_mask = head_offsets < num_heads - dim_mask = dim_offsets < head_dim - - q = tl.load( - q_ptr - + token_idx * stride_q_t - + head_offsets[:, None] * stride_q_h - + dim_offsets[None, :] * stride_q_d, - mask=head_mask[:, None] & dim_mask[None, :], - other=0.0, - ).to(tl.float32) - - running_max = tl.full((HEAD_BLOCK,), -float("inf"), dtype=tl.float32) - running_denom = tl.zeros((HEAD_BLOCK,), dtype=tl.float32) - running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), dtype=tl.float32) - - valid_len = tl.load(lens_ptr + token_idx) - local_remaining = tl.maximum(num_candidates - part_start, 0) - local_valid = tl.maximum(valid_len - candidate_offset - part_start, 0) - local_eff = tl.minimum(PART_SIZE, tl.minimum(local_remaining, local_valid)) - - for local_idx in range(0, local_eff): - candidate_idx = part_start + local_idx - kv_index = tl.load( - indices_ptr - + token_idx * stride_indices_t - + candidate_idx * stride_indices_c - ) - is_valid = kv_index >= 0 - - if is_valid: - kv = tl.load( - kv_flat_ptr - + kv_index.to(tl.int64) * stride_kv_t - + dim_offsets * stride_kv_d, - mask=dim_mask, - other=0.0, - ).to(tl.float32) - scores = tl.sum(q * kv[None, :], axis=1) * scale - next_max = tl.maximum(running_max, scores) - previous_weight = tl.exp(running_max - next_max) - candidate_weight = tl.exp(scores - next_max) - running_acc = ( - running_acc * previous_weight[:, None] - + kv[None, :] * candidate_weight[:, None] - ) - running_denom = running_denom * previous_weight + candidate_weight - running_max = next_max - - state_base = part_idx * stride_state_p + token_idx * stride_state_t - acc_base = part_idx * stride_acc_p + token_idx * stride_acc_t - tl.store( - max_score_ptr + state_base + head_offsets * stride_state_h, - running_max, - mask=head_mask, - ) - tl.store( - denom_ptr + state_base + head_offsets * stride_state_h, - running_denom, - mask=head_mask, - ) - tl.store( - acc_ptr - + acc_base - + head_offsets[:, None] * stride_acc_h - + dim_offsets[None, :] * stride_acc_d, - running_acc, - mask=head_mask[:, None] & dim_mask[None, :], - ) - - def accumulate_indexed_sparse_mla_attention_chunk( q: torch.Tensor, kv_flat: torch.Tensor, @@ -3356,262 +3247,6 @@ def finish_two_sparse_mla_attention_states_with_sink( ) -def accumulate_indexed_sparse_mla_attention_partial_states( - q: torch.Tensor, - kv_flat: torch.Tensor, - indices: torch.Tensor, - lens: torch.Tensor, - scale: float, - part_size: int, - max_score: torch.Tensor, - denom: torch.Tensor, - acc: torch.Tensor, - candidate_offset: int = 0, -) -> None: - if q.dim() == 4: - assert q.shape[1] == 1 - q = q[:, 0] - - assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" - assert kv_flat.dim() == 2 - assert indices.dim() == 2 - assert indices.shape[0] == q.shape[0] - assert kv_flat.shape[-1] == q.shape[-1] - assert lens.shape[0] == q.shape[0] - assert max_score.dim() == 3 - assert denom.shape == max_score.shape - assert acc.shape == (*max_score.shape, q.shape[-1]) - assert max_score.shape[1] == q.shape[0] - assert max_score.shape[2] <= q.shape[1] - assert part_size > 0 - assert max_score.shape[0] * part_size >= indices.shape[1] - assert max_score.dtype == torch.float32 - assert denom.dtype == torch.float32 - assert acc.dtype == torch.float32 - assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda - assert max_score.is_cuda and denom.is_cuda and acc.is_cuda - - num_tokens, _, head_dim = q.shape - num_parts, _, num_heads = max_score.shape - num_candidates = indices.shape[1] - block_d = min(1024, next_power_of_2(head_dim)) - head_block = _PREFILL_INDEXED_HEAD_BLOCK - - if num_heads >= head_block: - grid = (num_tokens, triton.cdiv(num_heads, head_block), num_parts) - _accumulate_indexed_attention_partial_states_multihead_kernel[grid]( - q, - kv_flat, - indices, - lens, - max_score, - denom, - acc, - q.stride(0), - q.stride(1), - q.stride(2), - kv_flat.stride(0), - kv_flat.stride(1), - indices.stride(0), - indices.stride(1), - max_score.stride(0), - max_score.stride(1), - max_score.stride(2), - acc.stride(0), - acc.stride(1), - acc.stride(2), - acc.stride(3), - num_heads, - head_dim, - num_candidates, - candidate_offset, - scale, - PART_SIZE=part_size, - HEAD_BLOCK=head_block, - BLOCK_D=block_d, - num_warps=4, - num_stages=2, - ) - else: - for part_idx in range(num_parts): - part_start = part_idx * part_size - part_end = min(part_start + part_size, num_candidates) - max_score[part_idx].fill_(float("-inf")) - denom[part_idx].zero_() - acc[part_idx].zero_() - if part_start >= part_end: - continue - accumulate_indexed_sparse_mla_attention_chunk( - q=q, - kv_flat=kv_flat, - indices=indices[:, part_start:part_end], - lens=lens, - candidate_offset=candidate_offset + part_start, - scale=scale, - max_score=max_score[part_idx], - denom=denom[part_idx], - acc=acc[part_idx], - ) - - -@triton.jit -def _merge_attention_states_kernel( - max0_ptr, - denom0_ptr, - acc0_ptr, - max1_ptr, - denom1_ptr, - acc1_ptr, - out_max_ptr, - out_denom_ptr, - out_acc_ptr, - stride_state0_t: tl.constexpr, - stride_state0_h: tl.constexpr, - stride_acc0_t: tl.constexpr, - stride_acc0_h: tl.constexpr, - stride_acc0_d: tl.constexpr, - stride_state1_t: tl.constexpr, - stride_state1_h: tl.constexpr, - stride_acc1_t: tl.constexpr, - stride_acc1_h: tl.constexpr, - stride_acc1_d: tl.constexpr, - stride_out_state_t: tl.constexpr, - stride_out_state_h: tl.constexpr, - stride_out_acc_t: tl.constexpr, - stride_out_acc_h: tl.constexpr, - stride_out_acc_d: tl.constexpr, - num_heads: tl.constexpr, - head_dim: tl.constexpr, - BLOCK_D: tl.constexpr, -): - token_head = tl.program_id(0) - block_d = tl.program_id(1) - token_idx = token_head // num_heads - head_idx = token_head - token_idx * num_heads - offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) - dim_mask = offsets < head_dim - - state0_offset = token_idx * stride_state0_t + head_idx * stride_state0_h - state1_offset = token_idx * stride_state1_t + head_idx * stride_state1_h - out_state_offset = token_idx * stride_out_state_t + head_idx * stride_out_state_h - max0 = tl.load(max0_ptr + state0_offset) - denom0 = tl.load(denom0_ptr + state0_offset) - max1 = tl.load(max1_ptr + state1_offset) - denom1 = tl.load(denom1_ptr + state1_offset) - - has0 = denom0 > 0.0 - has1 = denom1 > 0.0 - valid_max0 = tl.where(has0, max0, -float("inf")) - valid_max1 = tl.where(has1, max1, -float("inf")) - merge_max = tl.maximum(valid_max0, valid_max1) - has_any = has0 | has1 - safe_merge_max = tl.where(has_any, merge_max, 0.0) - safe_max0 = tl.where(has0, max0, safe_merge_max) - safe_max1 = tl.where(has1, max1, safe_merge_max) - scale0 = tl.where(has0, tl.exp(safe_max0 - safe_merge_max), 0.0) - scale1 = tl.where(has1, tl.exp(safe_max1 - safe_merge_max), 0.0) - merged_denom = denom0 * scale0 + denom1 * scale1 - merged_max = tl.where(has_any, merge_max, -float("inf")) - - acc0 = tl.load( - acc0_ptr - + token_idx * stride_acc0_t - + head_idx * stride_acc0_h - + offsets * stride_acc0_d, - mask=dim_mask, - other=0.0, - ).to(tl.float32) - acc1 = tl.load( - acc1_ptr - + token_idx * stride_acc1_t - + head_idx * stride_acc1_h - + offsets * stride_acc1_d, - mask=dim_mask, - other=0.0, - ).to(tl.float32) - acc0 = tl.where(has0, acc0, 0.0) - acc1 = tl.where(has1, acc1, 0.0) - merged_acc = acc0 * scale0 + acc1 * scale1 - - tl.store(out_max_ptr + out_state_offset, merged_max) - tl.store(out_denom_ptr + out_state_offset, merged_denom) - tl.store( - out_acc_ptr - + token_idx * stride_out_acc_t - + head_idx * stride_out_acc_h - + offsets * stride_out_acc_d, - merged_acc, - mask=dim_mask, - ) - - -def merge_sparse_mla_attention_states( - max_score0: torch.Tensor, - denom0: torch.Tensor, - acc0: torch.Tensor, - max_score1: torch.Tensor, - denom1: torch.Tensor, - acc1: torch.Tensor, - output_max_score: torch.Tensor, - output_denom: torch.Tensor, - output_acc: torch.Tensor, -) -> None: - assert max_score0.shape == denom0.shape - assert max_score1.shape == denom1.shape - assert max_score0.shape == max_score1.shape - assert output_max_score.shape == max_score0.shape - assert output_denom.shape == denom0.shape - assert acc0.shape == acc1.shape - assert output_acc.shape == acc0.shape - assert acc0.shape[:2] == max_score0.shape - assert max_score0.dtype == torch.float32 - assert denom0.dtype == torch.float32 - assert acc0.dtype == torch.float32 - assert max_score1.dtype == torch.float32 - assert denom1.dtype == torch.float32 - assert acc1.dtype == torch.float32 - assert output_max_score.dtype == torch.float32 - assert output_denom.dtype == torch.float32 - assert output_acc.dtype == torch.float32 - assert max_score0.is_cuda and denom0.is_cuda and acc0.is_cuda - assert max_score1.is_cuda and denom1.is_cuda and acc1.is_cuda - assert output_max_score.is_cuda and output_denom.is_cuda and output_acc.is_cuda - - num_tokens, num_heads, head_dim = acc0.shape - block_d = min(128, next_power_of_2(head_dim)) - grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) - _merge_attention_states_kernel[grid]( - max_score0, - denom0, - acc0, - max_score1, - denom1, - acc1, - output_max_score, - output_denom, - output_acc, - max_score0.stride(0), - max_score0.stride(1), - acc0.stride(0), - acc0.stride(1), - acc0.stride(2), - max_score1.stride(0), - max_score1.stride(1), - acc1.stride(0), - acc1.stride(1), - acc1.stride(2), - output_max_score.stride(0), - output_max_score.stride(1), - output_acc.stride(0), - output_acc.stride(1), - output_acc.stride(2), - num_heads, - head_dim, - BLOCK_D=block_d, - num_warps=4, - ) - - def finish_sparse_mla_attention_with_sink( max_score: torch.Tensor, denom: torch.Tensor, From 05e4a1d0bdb207662fb333b4b766e9eb05905582 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 5 Jun 2026 15:50:22 +0800 Subject: [PATCH 080/131] sm12x: restore DeepSeek V4 O-proj FP8 einsum layout Signed-off-by: jasl --- .../model_executor/test_deepseek_v4_o_proj.py | 32 +++++++++++++++++++ vllm/models/deepseek_v4/nvidia/ops/o_proj.py | 25 +++++++++------ 2 files changed, 47 insertions(+), 10 deletions(-) create mode 100644 tests/model_executor/test_deepseek_v4_o_proj.py diff --git a/tests/model_executor/test_deepseek_v4_o_proj.py b/tests/model_executor/test_deepseek_v4_o_proj.py new file mode 100644 index 000000000000..09b32e43b370 --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_o_proj.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.models.deepseek_v4.nvidia.ops.o_proj import compute_fp8_einsum_recipe +from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability + + +@pytest.mark.parametrize( + ("capability", "expected_recipe", "expected_tma_aligned"), + [ + ((9, 0), (1, 128, 128), False), + ((10, 0), (1, 1, 128), True), + ((12, 0), (1, 128, 128), False), + ((12, 1), (1, 128, 128), False), + ], +) +def test_deepseek_v4_o_proj_recipe_is_arch_specific( + monkeypatch: pytest.MonkeyPatch, + capability: tuple[int, int], + expected_recipe: tuple[int, int, int], + expected_tma_aligned: bool, +): + monkeypatch.setattr( + current_platform, + "get_device_capability", + lambda device_id=0: DeviceCapability(*capability), + ) + + assert compute_fp8_einsum_recipe() == (expected_recipe, expected_tma_aligned) diff --git a/vllm/models/deepseek_v4/nvidia/ops/o_proj.py b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py index 54c255fc3ada..1b6f3db7de5c 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/o_proj.py +++ b/vllm/models/deepseek_v4/nvidia/ops/o_proj.py @@ -6,23 +6,26 @@ from vllm.models.deepseek_v4.common.ops.fused_inv_rope_fp8_quant import ( fused_inv_rope_fp8_quant, ) +from vllm.models.deepseek_v4.nvidia.ops.fp8_einsum import ( + deepseek_v4_fp8_einsum, + deepseek_v4_fp8_einsum_config, +) from vllm.platforms import current_platform -from vllm.utils.deep_gemm import fp8_einsum def compute_fp8_einsum_recipe() -> tuple[tuple[int, int, int], bool]: """fp8_einsum recipe + scale layout for the current GPU arch. SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128. - SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1. + SM100/SM110: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1. + SM12x: RTX PRO / GB10 does not expose the same TMA/TCGEN05 path, so keep + the legacy FP32 block-scale layout expected by DeepGEMM. Returns ``(einsum_recipe, tma_aligned_scales)`` for ``deep_gemm_fp8_o_proj``. """ cap = current_platform.get_device_capability() assert cap is not None, "DeepseekV4 attention requires a CUDA device" - einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) - tma_aligned_scales = cap.major >= 10 - return einsum_recipe, tma_aligned_scales + return deepseek_v4_fp8_einsum_config(cap.major) def deep_gemm_fp8_o_proj( @@ -65,11 +68,13 @@ def deep_gemm_fp8_o_proj( wo_a_scale = getattr(wo_a, "weight_scale_inv", None) if wo_a_scale is None: wo_a_scale = wo_a.weight_scale - fp8_einsum( - "bhr,hdr->bhd", - (o_fp8, o_scale), - (wo_a.weight, wo_a_scale), + deepseek_v4_fp8_einsum( + o_fp8, + o_scale, + wo_a.weight, + wo_a_scale, z, - recipe=einsum_recipe, + "bhr,hdr->bhd", + list(einsum_recipe), ) return wo_b(z.flatten(1)) From 25824b8371dc7e34e4c1f923b2a69b586a939c81 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 5 Jun 2026 15:50:28 +0800 Subject: [PATCH 081/131] sm12x: restore Triton sparse MLA decode dispatch Signed-off-by: jasl --- ...st_deepseek_v4_flashmla_decode_dispatch.py | 146 ++++++++++++++++++ vllm/models/deepseek_v4/nvidia/flashmla.py | 29 ++++ 2 files changed, 175 insertions(+) create mode 100644 tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py diff --git a/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py b/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py new file mode 100644 index 000000000000..06e034f6b01c --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +import torch + +from vllm.models.deepseek_v4.nvidia import flashmla as flashmla_mod +from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLAAttention + + +def _make_layer(compress_ratio: int) -> DeepseekV4FlashMLAAttention: + layer = object.__new__(DeepseekV4FlashMLAAttention) + layer.compress_ratio = compress_ratio + layer.swa_cache_layer = SimpleNamespace(kv_cache=torch.empty(1, 1, 512)) + layer.n_local_heads = 2 + layer.scale = 1.0 + layer.attn_sink = torch.zeros(2) + return layer + + +def _make_swa_metadata() -> SimpleNamespace: + return SimpleNamespace( + num_decodes=1, + num_decode_tokens=1, + decode_swa_indices=torch.zeros((1, 4), dtype=torch.int32), + decode_swa_lens=torch.ones(1, dtype=torch.int32), + is_valid_token=torch.ones(1, dtype=torch.bool), + token_to_req_indices=torch.zeros(1, dtype=torch.int32), + block_table=torch.zeros((1, 1), dtype=torch.int32), + block_size=64, + seq_lens=torch.full((1,), 4, dtype=torch.int32), + tile_sched_swaonly=None, + tile_sched_c4a=None, + tile_sched_c128a=None, + ) + + +def test_swa_decode_uses_triton_path_without_flashmla_tile_sched(monkeypatch): + layer = _make_layer(compress_ratio=1) + metadata = _make_swa_metadata() + calls = [] + + def fake_decode( + cls, + layer, + q, + swa_k_cache, + swa_metadata, + output, + ): + calls.append( + (layer, q.shape, swa_k_cache.shape, swa_metadata, output.shape) + ) + + monkeypatch.setattr(flashmla_mod, "is_triton_sparse_mla_enabled", lambda _: True) + monkeypatch.setattr( + DeepseekV4FlashMLAAttention, + "_forward_sparse_mla_swa_decode_triton", + classmethod(fake_decode), + ) + + q = torch.empty(1, 2, 512) + output = torch.empty_like(q) + + layer._forward_decode( + q=q, + kv_cache=None, + swa_metadata=metadata, + attn_metadata=None, + swa_only=True, + output=output, + ) + + assert calls == [(layer, (1, 1, 2, 512), (1, 1, 512), metadata, (1, 2, 512))] + + +def test_compressed_decode_uses_triton_path_without_flashmla_tile_sched(monkeypatch): + layer = _make_layer(compress_ratio=128) + metadata = _make_swa_metadata() + attn_metadata = SimpleNamespace( + block_size=256, + c128a_global_decode_topk_indices=torch.zeros((1, 1, 2), dtype=torch.int32), + c128a_decode_topk_lens=torch.ones(1, dtype=torch.int32), + ) + kv_cache = torch.empty(1, 1, 512) + calls = [] + + def fake_decode( + cls, + layer, + q, + compressed_k_cache, + swa_k_cache, + topk_indices, + topk_lens, + swa_metadata, + attn_metadata, + output, + ): + calls.append( + ( + layer, + q.shape, + compressed_k_cache.shape, + swa_k_cache.shape, + topk_indices.shape, + topk_lens.shape, + swa_metadata, + attn_metadata, + output.shape, + ) + ) + + monkeypatch.setattr(flashmla_mod, "is_triton_sparse_mla_enabled", lambda _: True) + monkeypatch.setattr( + DeepseekV4FlashMLAAttention, + "_forward_sparse_mla_compressed_decode_triton", + classmethod(fake_decode), + ) + + q = torch.empty(1, 2, 512) + output = torch.empty_like(q) + + layer._forward_decode( + q=q, + kv_cache=kv_cache, + swa_metadata=metadata, + attn_metadata=attn_metadata, + swa_only=False, + output=output, + ) + + assert calls == [ + ( + layer, + (1, 1, 2, 512), + (1, 1, 512), + (1, 1, 512), + (1, 1, 2), + (1,), + metadata, + attn_metadata, + (1, 2, 512), + ) + ] diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 29b1bcbc5c11..ec61458f6900 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -690,9 +690,38 @@ def _forward_decode( # Use unsqueeze to preserve strides (handles padded blocks correctly) swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2) # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) + compressed_k_cache = kv_cache if kv_cache is not None: kv_cache = kv_cache.unsqueeze(-2) + if is_triton_sparse_mla_enabled(q.device): + if swa_only: + self._forward_sparse_mla_swa_decode_triton( + layer=self, + q=q, + swa_k_cache=self.swa_cache_layer.kv_cache, + swa_metadata=swa_metadata, + output=output, + ) + return + if self.compress_ratio in (4, 128): + assert compressed_k_cache is not None + assert attn_metadata is not None + assert topk_indices is not None + assert topk_lens is not None + self._forward_sparse_mla_compressed_decode_triton( + layer=self, + q=q, + compressed_k_cache=compressed_k_cache, + swa_k_cache=self.swa_cache_layer.kv_cache, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + output=output, + ) + return + # One FlashMLASchedMeta per layer type, shared across all same-type # layers within this decode step. The first forward call per type # triggers the in-kernel planner (allocating tile_scheduler_metadata From aac3bdfa6ec9b57d9bc6916d7f93167d7195e3c3 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 5 Jun 2026 17:03:21 +0800 Subject: [PATCH 082/131] config: skip breakable cudagraph auto-enable on SM121 Signed-off-by: jasl --- .../test_deepseek_v4_cudagraph_config.py | 50 +++++++++++++++++++ vllm/config/vllm.py | 49 ++++++++++++++---- 2 files changed, 88 insertions(+), 11 deletions(-) create mode 100644 tests/config/test_deepseek_v4_cudagraph_config.py diff --git a/tests/config/test_deepseek_v4_cudagraph_config.py b/tests/config/test_deepseek_v4_cudagraph_config.py new file mode 100644 index 000000000000..ddbd6c3f010f --- /dev/null +++ b/tests/config/test_deepseek_v4_cudagraph_config.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from types import SimpleNamespace + +from vllm.config.vllm import _should_auto_enable_deepseek_v4_breakable_cudagraph +from vllm.platforms import current_platform + + +def _model_config(*architectures: str): + return SimpleNamespace(architectures=list(architectures)) + + +def test_deepseek_v4_auto_enables_breakable_cudagraph_off_sm121(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability", + lambda capability, device_id=0: False, + ) + + assert _should_auto_enable_deepseek_v4_breakable_cudagraph( + _model_config("DeepseekV4ForCausalLM") + ) + assert _should_auto_enable_deepseek_v4_breakable_cudagraph( + _model_config("DeepSeekV4MTPModel") + ) + + +def test_deepseek_v4_skips_breakable_cudagraph_on_sm121(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability", + lambda capability, device_id=0: capability == 121, + ) + + assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( + _model_config("DeepseekV4ForCausalLM") + ) + + +def test_non_deepseek_v4_does_not_auto_enable_breakable_cudagraph(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability", + lambda capability, device_id=0: False, + ) + + assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( + _model_config("Qwen3ForCausalLM") + ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ba7d26c93b23..634f0e9281ec 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -76,6 +76,13 @@ } ) +DEEPSEEK_V4_CUDAGRAPH_ARCHITECTURES = frozenset( + { + "DeepseekV4ForCausalLM", + "DeepSeekV4MTPModel", + } +) + class OptimizationLevel(IntEnum): """Optimization level enum.""" @@ -104,6 +111,20 @@ class OptimizationLevel(IntEnum): # See https://github.com/vllm-project/vllm/issues/25689. +def _should_auto_enable_deepseek_v4_breakable_cudagraph( + model_config: ModelConfig, +) -> bool: + if not any( + arch in DEEPSEEK_V4_CUDAGRAPH_ARCHITECTURES + for arch in model_config.architectures + ): + return False + + from vllm.platforms import current_platform + + return not current_platform.is_device_capability(121) + + def enable_norm_fusion(cfg: "VllmConfig") -> bool: """Enable if either RMS norm or quant FP8 custom op is active; otherwise Inductor handles fusion.""" @@ -1073,21 +1094,27 @@ def __post_init__(self): ) self.compilation_config.mode = CompilationMode.NONE - # For model classes don't carry @support_torch_compile — - # the breakable cudagraph is the supported PIECEWISE path. Auto-enable - # it unless the user has explicitly opted out via the env var. + # DeepSeek V4's model classes don't carry @support_torch_compile. + # On SM120 the breakable cudagraph is the supported PIECEWISE path; + # on tested SM121/GB10 Ray configs the compiled PIECEWISE path is + # required for correctness (breakable can corrupt graph replay). + # Auto-enable only for the known-good DeepSeek V4 device path; MiniMax + # M3 retains its upstream unconditional auto-enable. if ( self.model_config is not None and "VLLM_USE_BREAKABLE_CUDAGRAPH" not in os.environ - and any( - a - in ( - "DeepseekV4ForCausalLM", - "DeepSeekV4MTPModel", - "MiniMaxM3SparseForCausalLM", - "MiniMaxM3SparseForConditionalGeneration", + and ( + _should_auto_enable_deepseek_v4_breakable_cudagraph( + self.model_config + ) + or any( + a + in ( + "MiniMaxM3SparseForCausalLM", + "MiniMaxM3SparseForConditionalGeneration", + ) + for a in self.model_config.architectures ) - for a in self.model_config.architectures ) ): os.environ["VLLM_USE_BREAKABLE_CUDAGRAPH"] = "1" From e1d6dc859b2cd6381b4353cd3b59e6383be288e5 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 6 Jun 2026 04:32:52 +0800 Subject: [PATCH 083/131] deepseek-v4: preserve ubatch prefill metadata Signed-off-by: jasl --- .../test_indexer_deepseek_v4_slot_mapping.py | 60 +++++++++++++++++++ tests/v1/worker/test_ubatch_utils.py | 33 ++++++++++ vllm/v1/attention/backends/mla/indexer.py | 5 ++ vllm/v1/worker/ubatch_utils.py | 12 ++++ 4 files changed, 110 insertions(+) create mode 100644 tests/v1/worker/test_ubatch_utils.py diff --git a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py index 159bb8af3fb9..b0d4b542ec61 100644 --- a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py +++ b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py @@ -4,12 +4,72 @@ import pytest import torch +import vllm.v1.attention.backends.mla.indexer as indexer_module from tests.v1.attention.utils import create_vllm_config from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadataBuilder from vllm.v1.kv_cache_interface import MLAAttentionSpec +@pytest.mark.parametrize( + ("is_prefilling", "expected_treat_short_extends_as_decodes"), + [ + (torch.tensor([False, False]), True), + (torch.tensor([False, True]), False), + ], +) +def test_indexer_builder_keeps_short_prefill_continuations_as_prefills( + monkeypatch, + is_prefilling, + expected_treat_short_extends_as_decodes, +): + builder = object.__new__(DeepseekV32IndexerMetadataBuilder) + builder.reorder_batch_threshold = 1 + builder.use_flattening = False + + captured = {} + + def fake_split_decodes_and_prefills( + common_attn_metadata, + *, + decode_threshold=1, + require_uniform=False, + treat_short_extends_as_decodes=True, + ): + captured["treat_short_extends_as_decodes"] = ( + treat_short_extends_as_decodes + ) + raise RuntimeError("stop after split_decodes_and_prefills") + + monkeypatch.setattr( + indexer_module, + "split_decodes_and_prefills", + fake_split_decodes_and_prefills, + ) + query_start_loc = torch.tensor([0, 1, 2], dtype=torch.int32) + metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc.clone(), + seq_lens=torch.tensor([128, 129], dtype=torch.int32), + num_reqs=2, + num_actual_tokens=2, + max_query_len=1, + max_seq_len=129, + block_table_tensor=torch.zeros((2, 1), dtype=torch.int32), + slot_mapping=torch.arange(2, dtype=torch.int64), + is_prefilling=is_prefilling, + seq_lens_cpu_upper_bound=torch.tensor([128, 129], dtype=torch.int32), + ) + + with pytest.raises(RuntimeError, match="stop after"): + builder.build(common_prefix_len=0, common_attn_metadata=metadata) + + assert ( + captured["treat_short_extends_as_decodes"] + is expected_treat_short_extends_as_decodes + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") def test_indexer_builder_deepseek_v4_compressed_slot_mapping_uses_storage_block_size(): """Regression test: DeepseekV4 compression path must compute slot_mapping from diff --git a/tests/v1/worker/test_ubatch_utils.py b/tests/v1/worker/test_ubatch_utils.py new file mode 100644 index 000000000000..f0bc987a7c73 --- /dev/null +++ b/tests/v1/worker/test_ubatch_utils.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.v1.attention.backend import CommonAttentionMetadata +from vllm.v1.worker.ubatch_utils import UBatchSlice, split_attn_metadata + + +def test_split_attn_metadata_preserves_position_and_prefill_flags(): + query_start_loc = torch.tensor([0, 3, 7, 9], dtype=torch.int32) + metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc.clone(), + seq_lens=torch.tensor([16, 24, 32], dtype=torch.int32), + num_reqs=3, + num_actual_tokens=9, + max_query_len=4, + max_seq_len=32, + block_table_tensor=torch.arange(12, dtype=torch.int32).reshape(3, 4), + slot_mapping=torch.arange(9, dtype=torch.int64), + positions=torch.arange(100, 109, dtype=torch.int64), + is_prefilling=torch.tensor([False, True, True]), + seq_lens_cpu_upper_bound=torch.tensor([16, 24, 32], dtype=torch.int32), + ) + + split = split_attn_metadata( + [UBatchSlice(request_slice=slice(1, 3), token_slice=slice(3, 9))], + metadata, + )[0] + + torch.testing.assert_close(split.positions, metadata.positions[3:9]) + torch.testing.assert_close(split.is_prefilling, metadata.is_prefilling[1:3]) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index b2719cbbe7dc..150ea1d10555 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -502,12 +502,17 @@ def build( seq_lens = common_attn_metadata.seq_lens slot_mapping = common_attn_metadata.slot_mapping block_table = common_attn_metadata.block_table_tensor + has_prefilling_rows = ( + common_attn_metadata.is_prefilling is not None + and torch.any(common_attn_metadata.is_prefilling).item() + ) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=self.reorder_batch_threshold, require_uniform=not self.use_flattening, + treat_short_extends_as_decodes=not has_prefilling_rows, ) ) diff --git a/vllm/v1/worker/ubatch_utils.py b/vllm/v1/worker/ubatch_utils.py index f4a76529023c..0140a7afe550 100644 --- a/vllm/v1/worker/ubatch_utils.py +++ b/vllm/v1/worker/ubatch_utils.py @@ -231,6 +231,16 @@ def _make_metadata_with_slice( block_table_tensor = attn_metadata.block_table_tensor[request_slice] slot_mapping = attn_metadata.slot_mapping[token_slice] + positions = ( + attn_metadata.positions[token_slice] + if attn_metadata.positions is not None + else None + ) + is_prefilling = ( + attn_metadata.is_prefilling[request_slice] + if attn_metadata.is_prefilling is not None + else None + ) return CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -242,6 +252,8 @@ def _make_metadata_with_slice( max_seq_len=max_seq_len, block_table_tensor=block_table_tensor, slot_mapping=slot_mapping, + positions=positions, + is_prefilling=is_prefilling, seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound, _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=num_computed_tokens_cpu, From 08b329c5346fe857345da650f6a5a18a755497d5 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 6 Jun 2026 04:54:36 +0800 Subject: [PATCH 084/131] deepseek-v4: defunctionalize fused MLA insert op Signed-off-by: jasl --- .../compile/passes/test_functionalization.py | 60 +++++++++++++++++++ .../passes/utility/fix_functionalization.py | 16 +++++ 2 files changed, 76 insertions(+) diff --git a/tests/compile/passes/test_functionalization.py b/tests/compile/passes/test_functionalization.py index 31bf225d4135..6369cb8c980f 100644 --- a/tests/compile/passes/test_functionalization.py +++ b/tests/compile/passes/test_functionalization.py @@ -251,12 +251,72 @@ def ops_not_in_model(self): return [] +class TestFusedDeepseekV4QnormRopeKvInsert(torch.nn.Module): + OP_REGISTERED = False + + def __init__(self): + super().__init__() + self.register_test_custom_op() + + @classmethod + def register_test_custom_op(cls): + if not cls.OP_REGISTERED: + + def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl( + q: torch.Tensor, + kv: torch.Tensor, + k_cache: torch.Tensor, + ) -> None: + q.add_(kv) + k_cache.add_(kv) + + def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake( + q: torch.Tensor, + kv: torch.Tensor, + k_cache: torch.Tensor, + ) -> None: + return None + + direct_register_custom_op( + op_name="fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert", + op_func=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl, + mutates_args=["q", "k_cache"], + fake_impl=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake, + ) + + cls.OP_REGISTERED = True + + def forward( + self, q: torch.Tensor, kv: torch.Tensor, k_cache: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + q, kv, k_cache + ) + return q, k_cache + + def example_inputs(self, num_tokens=32, hidden_size=128): + return ( + torch.rand(num_tokens, hidden_size), + torch.rand(num_tokens, hidden_size), + torch.rand(num_tokens, hidden_size), + ) + + def ops_in_model(self, do_fusion): + return [ + torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default + ] + + def ops_not_in_model(self): + return [] + + MODELS_AND_DO_FUSION = { TestSiluMul: [True, False], TestFusedAddRMSNorm: [True, False], TestRotaryEmbedding: [False], TestRotaryEmbeddingSliceScatter: [False], TestFunctionWithMutatedArgsAndReturn: [False], + TestFusedDeepseekV4QnormRopeKvInsert: [False], } diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py index c0643a916b38..544a99afd2a7 100644 --- a/vllm/compilation/passes/utility/fix_functionalization.py +++ b/vllm/compilation/passes/utility/fix_functionalization.py @@ -39,11 +39,24 @@ def __call__(self, graph: torch.fx.Graph) -> None: count = 0 rope_targets = [torch.ops._C.rotary_embedding.default] + fused_deepseek_v4_mla_targets = [] if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"): rope_targets.append( torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default ) + if hasattr( + torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert" + ): + fused_deepseek_v4_mla_targets.append( + torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default + ) + if hasattr( + torch.ops.vllm, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert" + ): + fused_deepseek_v4_mla_targets.append( + torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default + ) for node in graph.nodes: if not is_func(node, auto_functionalized): @@ -181,6 +194,9 @@ def __call__(self, graph: torch.fx.Graph) -> None: 2: "key", } self.defunctionalize(graph, node, mutated_args=mutated_args) + elif at_target in fused_deepseek_v4_mla_targets: + mutated_args = {1: "q", 2: "k_cache"} + self.defunctionalize(graph, node, mutated_args) elif ( hasattr(torch.ops.vllm, "fused_rope_unified_mla_kv_cache_update") and at_target From 2aa21fbb5c10a36546d4803e65c28e3447ec2427 Mon Sep 17 00:00:00 2001 From: jasl Date: Sun, 7 Jun 2026 07:30:13 +0800 Subject: [PATCH 085/131] deepseek-v4: enable chunked D512 sparse MLA prefill Signed-off-by: jasl --- .../attention/test_sparse_mla_indexed_d512.py | 109 ++++++++ vllm/envs.py | 4 + vllm/models/deepseek_v4/nvidia/flashmla.py | 126 +++++++++- .../backends/mla/sparse_mla_kernels.py | 236 ++++++++++++++++++ 4 files changed, 473 insertions(+), 2 deletions(-) diff --git a/tests/v1/attention/test_sparse_mla_indexed_d512.py b/tests/v1/attention/test_sparse_mla_indexed_d512.py index 531c4146f3ee..8486de516281 100644 --- a/tests/v1/attention/test_sparse_mla_indexed_d512.py +++ b/tests/v1/attention/test_sparse_mla_indexed_d512.py @@ -5,6 +5,7 @@ import torch from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_indexed_d512_chunked_sparse_mla_attention, accumulate_indexed_d512_split_sparse_mla_attention, accumulate_indexed_sparse_mla_attention_chunk, ) @@ -180,3 +181,111 @@ def test_indexed_d512_split_sparse_mla_matches_c128_combined_width(): torch.testing.assert_close(split_max, current_max, atol=2e-5, rtol=2e-5) torch.testing.assert_close(split_denom, current_denom, atol=2e-3, rtol=2e-3) torch.testing.assert_close(split, current, atol=2e-3, rtol=2e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_indexed_d512_chunked_sparse_mla_matches_wide_indexed_accumulate(): + torch.cuda.set_device(0) + device = torch.device("cuda:0") + torch.manual_seed(29) + num_tokens = 64 + num_heads = 8 + head_dim = 512 + num_candidates = 2048 + chunk_candidates = 1152 + kv_tokens = 8192 + scale = head_dim**-0.5 + + q = torch.randn( + num_tokens, + num_heads, + head_dim, + device=device, + dtype=torch.bfloat16, + ) + kv = torch.randn(kv_tokens, head_dim, device=device, dtype=torch.bfloat16) + indices = torch.randint( + 0, + kv_tokens, + (num_tokens, num_candidates), + device=device, + dtype=torch.int32, + ) + lens = torch.randint( + num_candidates // 2, + num_candidates + 1, + (num_tokens,), + device=device, + dtype=torch.int32, + ) + lens[0] = 0 + + current_max = torch.full( + (num_tokens, num_heads), + -float("inf"), + device=device, + dtype=torch.float32, + ) + current_denom = torch.zeros_like(current_max) + current_acc = torch.zeros( + num_tokens, num_heads, head_dim, device=device, dtype=torch.float32 + ) + chunked_max = torch.empty_like(current_max) + chunked_denom = torch.empty_like(current_denom) + chunked_acc = torch.empty_like(current_acc) + chunk_max = torch.empty_like(current_max) + chunk_denom = torch.empty_like(current_denom) + chunk_acc = torch.empty_like(current_acc) + chunk_scores = torch.empty( + num_tokens, + num_heads, + chunk_candidates, + device=device, + dtype=torch.float32, + ) + + accumulate_indexed_sparse_mla_attention_chunk( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=current_max, + denom=current_denom, + acc=current_acc, + ) + accumulate_indexed_d512_chunked_sparse_mla_attention( + q=q, + kv_flat=kv, + indices=indices, + lens=lens, + scale=scale, + max_score=chunked_max, + denom=chunked_denom, + acc=chunked_acc, + scores=chunk_scores, + chunk_max_score=chunk_max, + chunk_denom=chunk_denom, + chunk_acc=chunk_acc, + ) + torch.cuda.synchronize() + + valid_rows = lens > 0 + current = current_acc[valid_rows] / current_denom[valid_rows, :, None] + chunked = chunked_acc[valid_rows] / chunked_denom[valid_rows, :, None] + torch.testing.assert_close( + chunked_max[valid_rows], + current_max[valid_rows], + atol=2e-5, + rtol=2e-5, + ) + torch.testing.assert_close( + chunked_denom[valid_rows], + current_denom[valid_rows], + atol=2e-3, + rtol=2e-3, + ) + torch.testing.assert_close(chunked, current, atol=2e-3, rtol=2e-3) + assert torch.isneginf(chunked_max[~valid_rows]).all() + assert torch.count_nonzero(chunked_denom[~valid_rows]) == 0 + assert torch.count_nonzero(chunked_acc[~valid_rows]) == 0 diff --git a/vllm/envs.py b/vllm/envs.py index 390f6e877574..2757a7e26ac6 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -181,6 +181,7 @@ VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = True + VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL: bool = True VLLM_TRITON_MLA_SPARSE: bool | None = None VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 @@ -1457,6 +1458,9 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL": lambda: bool( int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL", "1")) ), + "VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL", "1")) + ), # Experimental sparse MLA fallback controls. # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse # is unavailable; set 0/1 to force-disable/force-enable the fallback. diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index ec61458f6900..44a5ccbd4eb3 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -34,6 +34,7 @@ ) from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_indexed_d512_chunked_sparse_mla_attention, accumulate_indexed_d512_split_sparse_mla_attention, accumulate_indexed_sparse_mla_attention_chunk, build_combined_sparse_mla_decode_valid_mask, @@ -78,6 +79,27 @@ def _use_indexed_d512_split_prefill( ) +def _use_indexed_d512_chunked_prefill( + *, + compress_ratio: int, + head_dim: int, + num_prefills: int, + combined_topk: int, + max_prefill_seq_len: int, + swa_only: bool, +) -> bool: + return ( + envs.VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + and not swa_only + and compress_ratio in (4, 128) + and head_dim == 512 + and num_prefills == 1 + and combined_topk > _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + and max_prefill_seq_len >= _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS + ) + + def _sparse_mla_prefill_gather_len_upper_bound( *, max_model_len: int, @@ -203,6 +225,29 @@ def _prefill_workspace_reservation_specs( specs.append( ((query_chunk_size, num_heads, combined_topk), torch.float32) ) + elif _use_indexed_d512_chunked_prefill( + compress_ratio=compress_ratio, + head_dim=head_dim, + num_prefills=1, + combined_topk=combined_topk, + max_prefill_seq_len=max_model_len, + swa_only=False, + ): + chunked_score_width = min( + combined_topk, + _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, + ) + specs.extend( + ( + ( + (query_chunk_size, num_heads, chunked_score_width), + torch.float32, + ), + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads, head_dim), torch.float32), + ) + ) return tuple(specs) @classmethod @@ -573,6 +618,7 @@ def _forward_sparse_mla_prefill_triton( else: max_score_buffer, denom_buffer, output_buffer = state_buffers[:3] indexed_d512_scores = None + indexed_d512_chunked_buffers = None if ( state_buffers is not None and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL @@ -585,6 +631,17 @@ def _forward_sparse_mla_prefill_triton( and len(state_buffers) == 4 ): indexed_d512_scores = state_buffers[3] + elif ( + state_buffers is not None + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + and layer.compress_ratio in (4, 128) + and q.shape[-1] == 512 + and kv.shape[0] == 1 + and combined_indices.shape[-1] > _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + and len(state_buffers) == 7 + ): + indexed_d512_chunked_buffers = state_buffers[3:7] for token_start in range(0, q.shape[0], query_chunk_size): token_end = min(token_start + query_chunk_size, q.shape[0]) @@ -595,7 +652,17 @@ def _forward_sparse_mla_prefill_triton( max_score = max_score_buffer[:num_tokens] denom = denom_buffer[:num_tokens] subset_acc = output_buffer[:num_tokens] - if indexed_d512_scores is not None: + can_use_indexed_d512_scores = ( + indexed_d512_scores is not None + and indexed_d512_scores.shape[0] >= num_tokens + and indexed_d512_scores.shape[2] >= combined_indices.shape[-1] + ) + can_use_indexed_d512_chunked = ( + indexed_d512_chunked_buffers is not None + and indexed_d512_chunked_buffers[0].shape[0] >= num_tokens + ) + if can_use_indexed_d512_scores: + assert indexed_d512_scores is not None accumulate_indexed_d512_split_sparse_mla_attention( q=q_chunk, kv_flat=kv_flat, @@ -609,6 +676,28 @@ def _forward_sparse_mla_prefill_triton( :num_tokens, :, : combined_indices.shape[-1] ], ) + elif can_use_indexed_d512_chunked: + assert indexed_d512_chunked_buffers is not None + ( + indexed_d512_scores, + chunk_max_score, + chunk_denom, + chunk_acc, + ) = indexed_d512_chunked_buffers + accumulate_indexed_d512_chunked_sparse_mla_attention( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full, + lens=lens_chunk, + scale=layer.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + scores=indexed_d512_scores[:num_tokens], + chunk_max_score=chunk_max_score[:num_tokens], + chunk_denom=chunk_denom[:num_tokens], + chunk_acc=chunk_acc[:num_tokens], + ) else: max_score.fill_(float("-inf")) denom.zero_() @@ -839,8 +928,13 @@ def _forward_prefill( workspace_manager = current_workspace_manager() triton_sparse_mla_enabled = is_triton_sparse_mla_enabled(q.device) + indexed_d512_split_prefill = False + indexed_d512_chunked_prefill = False if triton_sparse_mla_enabled: - query_chunk_size = min(q.shape[0], triton_sparse_mla_query_chunk_size()) + query_chunk_size = min( + max_query_chunk_tokens, + triton_sparse_mla_query_chunk_size(), + ) indexed_d512_split_prefill = _use_indexed_d512_split_prefill( compress_ratio=int(self.compress_ratio), head_dim=int(self.head_dim), @@ -849,6 +943,15 @@ def _forward_prefill( max_prefill_seq_len=int(seq_lens_cpu.max().item()), swa_only=swa_only, ) + if not indexed_d512_split_prefill: + indexed_d512_chunked_prefill = _use_indexed_d512_chunked_prefill( + compress_ratio=int(self.compress_ratio), + head_dim=int(self.head_dim), + num_prefills=int(num_prefills), + combined_topk=int(combined_topk), + max_prefill_seq_len=int(seq_lens_cpu.max().item()), + swa_only=swa_only, + ) extra_specs: list[tuple[tuple[int, ...], torch.dtype]] = [] if indexed_d512_split_prefill: extra_specs.append( @@ -857,6 +960,25 @@ def _forward_prefill( torch.float32, ) ) + elif indexed_d512_chunked_prefill: + chunked_score_width = min( + combined_topk, + _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, + ) + extra_specs.extend( + ( + ( + (query_chunk_size, self.n_local_heads, chunked_score_width), + torch.float32, + ), + ((query_chunk_size, self.n_local_heads), torch.float32), + ((query_chunk_size, self.n_local_heads), torch.float32), + ( + (query_chunk_size, self.n_local_heads, q.shape[-1]), + torch.float32, + ), + ) + ) ( kv, combined_indices_buffer, diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index 01f01170218c..f0964e46d3a9 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -2006,6 +2006,242 @@ def accumulate_indexed_d512_split_sparse_mla_attention( ) +@triton.jit +def _indexed_d512_chunked_merge_acc_kernel( + max_score_ptr, + acc_ptr, + chunk_max_score_ptr, + chunk_acc_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + dim_block = tl.program_id(2) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = dim_block * BLOCK_D + tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + + running_max = tl.load( + max_score_ptr + token_idx * stride_state_t + head_offsets * stride_state_h, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + chunk_max = tl.load( + chunk_max_score_ptr + + token_idx * stride_state_t + + head_offsets * stride_state_h, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + next_max = tl.maximum(running_max, chunk_max) + running_valid = running_max != -float("inf") + chunk_valid = chunk_max != -float("inf") + running_scale = tl.where(running_valid, tl.exp(running_max - next_max), 0.0) + chunk_scale = tl.where(chunk_valid, tl.exp(chunk_max - next_max), 0.0) + + running_acc = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(tl.float32) + chunk_acc = tl.load( + chunk_acc_ptr + + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + mask=head_mask[:, None] & dim_mask[None, :], + other=0.0, + ).to(tl.float32) + merged_acc = ( + running_acc * running_scale[:, None] + chunk_acc * chunk_scale[:, None] + ) + tl.store( + acc_ptr + + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d, + merged_acc, + mask=head_mask[:, None] & dim_mask[None, :], + ) + + +@triton.jit +def _indexed_d512_chunked_merge_state_kernel( + max_score_ptr, + denom_ptr, + chunk_max_score_ptr, + chunk_denom_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + num_heads: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + head_mask = head_idx < num_heads + + running_max = tl.load( + max_score_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + running_denom = tl.load( + denom_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + mask=head_mask, + other=0.0, + ).to(tl.float32) + chunk_max = tl.load( + chunk_max_score_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + chunk_denom = tl.load( + chunk_denom_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + mask=head_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, chunk_max) + running_valid = running_max != -float("inf") + chunk_valid = chunk_max != -float("inf") + running_scale = tl.where(running_valid, tl.exp(running_max - next_max), 0.0) + chunk_scale = tl.where(chunk_valid, tl.exp(chunk_max - next_max), 0.0) + next_denom = running_denom * running_scale + chunk_denom * chunk_scale + + tl.store( + max_score_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + next_max, + mask=head_mask, + ) + tl.store( + denom_ptr + token_idx * stride_state_t + head_idx * stride_state_h, + next_denom, + mask=head_mask, + ) + + +def accumulate_indexed_d512_chunked_sparse_mla_attention( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + scores: torch.Tensor, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + chunk_max_score: torch.Tensor, + chunk_denom: torch.Tensor, + chunk_acc: torch.Tensor, + head_block_size: int = 32, + candidate_block_size: int = 64, + value_block_size: int = 128, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert q.shape[-1] == 512 + assert scores.shape[0] == q.shape[0] + assert scores.shape[1] == max_score.shape[1] + assert 0 < scores.shape[2] <= 1152 + assert max_score.shape == denom.shape == chunk_max_score.shape == chunk_denom.shape + assert acc.shape == chunk_acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert chunk_max_score.dtype == torch.float32 + assert chunk_denom.dtype == torch.float32 + assert chunk_acc.dtype == torch.float32 + assert scores.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert scores.is_cuda and max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert chunk_max_score.is_cuda and chunk_denom.is_cuda and chunk_acc.is_cuda + assert head_block_size in (8, 16, 32) + assert candidate_block_size in (32, 64, 128) + assert value_block_size in (32, 64, 128) + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + chunk_size = scores.shape[2] + max_score.fill_(float("-inf")) + denom.zero_() + acc.zero_() + + merge_acc_grid = ( + num_tokens, + triton.cdiv(num_heads, head_block_size), + triton.cdiv(head_dim, value_block_size), + ) + merge_state_grid = (num_tokens, num_heads) + for candidate_start in range(0, indices.shape[1], chunk_size): + candidate_end = min(candidate_start + chunk_size, indices.shape[1]) + chunk_candidates = candidate_end - candidate_start + chunk_lens = torch.clamp( + lens - candidate_start, + min=0, + max=chunk_candidates, + ) + accumulate_indexed_d512_split_sparse_mla_attention( + q=q, + kv_flat=kv_flat, + indices=indices[:, candidate_start:candidate_end], + lens=chunk_lens, + scale=scale, + scores=scores[:, :, :chunk_candidates], + max_score=chunk_max_score, + denom=chunk_denom, + acc=chunk_acc, + head_block_size=head_block_size, + candidate_block_size=candidate_block_size, + value_block_size=value_block_size, + ) + _indexed_d512_chunked_merge_acc_kernel[merge_acc_grid]( + max_score, + acc, + chunk_max_score, + chunk_acc, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + HEAD_BLOCK=head_block_size, + BLOCK_D=value_block_size, + num_warps=4, + num_stages=3, + ) + _indexed_d512_chunked_merge_state_kernel[merge_state_grid]( + max_score, + denom, + chunk_max_score, + chunk_denom, + max_score.stride(0), + max_score.stride(1), + num_heads, + num_warps=4, + num_stages=3, + ) + + @triton.autotune( configs=[ triton.Config({}, num_warps=4, num_stages=2), From ed520eaf3f34c8c9a481c221a4a30673db15d12f Mon Sep 17 00:00:00 2001 From: jasl Date: Sun, 7 Jun 2026 16:35:33 +0800 Subject: [PATCH 086/131] deepseek-v4: align sparse MLA metadata after upstream split Signed-off-by: jasl --- .../test_deepseek_v4_sparse_mla_metadata.py | 45 +++++ vllm/models/deepseek_v4/nvidia/flashmla.py | 2 +- vllm/models/deepseek_v4/sparse_mla.py | 65 +++++-- .../attention/backends/mla/flashmla_sparse.py | 169 ------------------ 4 files changed, 100 insertions(+), 181 deletions(-) create mode 100644 tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py diff --git a/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py b/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py new file mode 100644 index 000000000000..0cb1431af1c7 --- /dev/null +++ b/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +from vllm.models.deepseek_v4 import sparse_mla + + +def test_c128a_effective_topk_width_uses_current_positions() -> None: + assert ( + sparse_mla._c128a_effective_topk_width( + positions=torch.tensor([0, 126], dtype=torch.int64), + compress_ratio=128, + max_compressed_tokens=4096, + alignment=128, + ) + == 128 + ) + assert ( + sparse_mla._c128a_effective_topk_width( + positions=torch.tensor([127, 1023], dtype=torch.int64), + compress_ratio=128, + max_compressed_tokens=4096, + alignment=128, + ) + == 128 + ) + assert ( + sparse_mla._c128a_effective_topk_width( + positions=torch.tensor([524287], dtype=torch.int64), + compress_ratio=128, + max_compressed_tokens=8192, + alignment=128, + ) + == 4096 + ) + assert ( + sparse_mla._c128a_effective_topk_width( + positions=torch.tensor([1048575], dtype=torch.int64), + compress_ratio=128, + max_compressed_tokens=8192, + alignment=128, + ) + == 8192 + ) diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 44a5ccbd4eb3..5607112acea3 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -410,7 +410,7 @@ def _forward_sparse_mla_compressed_decode_triton( topk_indices: torch.Tensor, topk_lens: torch.Tensor, swa_metadata: "DeepseekSparseSWAMetadata", - attn_metadata: FlashMLASparseMetadata, + attn_metadata: DeepseekV4FlashMLAMetadata, output: torch.Tensor, ) -> None: if layer.compress_ratio not in (4, 128): diff --git a/vllm/models/deepseek_v4/sparse_mla.py b/vllm/models/deepseek_v4/sparse_mla.py index ca14fe20b138..066b6d1ce55e 100644 --- a/vllm/models/deepseek_v4/sparse_mla.py +++ b/vllm/models/deepseek_v4/sparse_mla.py @@ -319,11 +319,31 @@ def build_c128a_topk_metadata( num_tokens = positions.shape[0] num_prefill_tokens = num_tokens - num_decode_tokens - global_decode = global_decode_buffer[:num_decode_tokens] + if num_tokens == 0: + global_decode = global_decode_buffer[:num_decode_tokens, :0] + decode_lens = decode_lens_buffer[:num_decode_tokens] + prefill_local = prefill_buffer[:num_prefill_tokens, :0] + return global_decode, decode_lens, prefill_local + + KERNEL_BLOCK_SIZE = 1024 + effective_topk = _c128a_effective_topk_width( + positions=positions, + compress_ratio=compress_ratio, + max_compressed_tokens=max_compressed_tokens, + alignment=_C128A_TOPK_ALIGNMENT, + ) + + global_decode = global_decode_buffer[:num_decode_tokens, :effective_topk] decode_lens = decode_lens_buffer[:num_decode_tokens] - prefill_local = prefill_buffer[:num_prefill_tokens] + prefill_local = prefill_buffer[:num_prefill_tokens, :effective_topk] - if num_tokens == 0: + if num_decode_tokens > 0: + global_decode.fill_(-1) + decode_lens.zero_() + if num_prefill_tokens > 0: + prefill_local.fill_(-1) + + if effective_topk == 0: return global_decode, decode_lens, prefill_local _build_c128a_topk_metadata_kernel[(num_tokens,)]( @@ -334,18 +354,41 @@ def build_c128a_topk_metadata( prefill_buffer.stride(0), positions, compress_ratio, - max_compressed_tokens, + effective_topk, num_decode_tokens, token_to_req_indices, block_table, block_table.stride(0), block_size, slot_mapping, - BLOCK_SIZE=1024, + BLOCK_SIZE=KERNEL_BLOCK_SIZE, ) return global_decode, decode_lens, prefill_local +def _c128a_effective_topk_width( + *, + positions: torch.Tensor, + compress_ratio: int, + max_compressed_tokens: int, + alignment: int, +) -> int: + """Return the aligned C128A top-k width needed by current tokens.""" + if positions.numel() == 0: + return 0 + max_pos = int(positions.max().item()) + max_num_compressed = min( + max((max_pos + 1) // int(compress_ratio), 0), + int(max_compressed_tokens), + ) + if max_num_compressed == 0: + return min(int(max_compressed_tokens), int(alignment)) + return min( + int(max_compressed_tokens), + cdiv(max_num_compressed, int(alignment)) * int(alignment), + ) + + @triton.jit def _build_c128a_topk_metadata_kernel( # Decode outputs @@ -358,7 +401,7 @@ def _build_c128a_topk_metadata_kernel( # Inputs positions_ptr, compress_ratio, - max_compressed_tokens, + effective_topk, num_decode_tokens, token_to_req_indices_ptr, block_table_ptr, @@ -370,7 +413,7 @@ def _build_c128a_topk_metadata_kernel( token_idx = tl.program_id(0) position = tl.load(positions_ptr + token_idx) num_compressed = (position + 1) // compress_ratio - num_compressed = tl.minimum(num_compressed, max_compressed_tokens) + num_compressed = tl.minimum(num_compressed, effective_topk) is_decode = token_idx < num_decode_tokens if is_decode: @@ -378,9 +421,9 @@ def _build_c128a_topk_metadata_kernel( is_valid_token = tl.load(slot_mapping_ptr + token_idx) >= 0 req_idx = tl.load(token_to_req_indices_ptr + token_idx) count = tl.zeros((), dtype=tl.int32) - for i in range(0, max_compressed_tokens, BLOCK_SIZE): + for i in range(0, effective_topk, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < max_compressed_tokens + mask = offset < effective_topk is_valid = offset < num_compressed block_indices = offset // block_size @@ -405,9 +448,9 @@ def _build_c128a_topk_metadata_kernel( else: # --- Prefill: write local indices --- pfx_idx = token_idx - num_decode_tokens - for i in range(0, max_compressed_tokens, BLOCK_SIZE): + for i in range(0, effective_topk, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < max_compressed_tokens + mask = offset < effective_topk tl.store( prefill_local_ptr + pfx_idx * prefill_local_stride + offset, tl.where(offset < num_compressed, offset, -1), diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 6835fc22a4e1..6d8dfe13128d 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -15,8 +15,6 @@ ) from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability -from vllm.triton_utils import tl, triton -from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import num_compute_units from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.v1.attention.backend import ( @@ -29,7 +27,6 @@ MultipleOf, SparseMLAAttentionImpl, ) -from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping from vllm.v1.attention.backends.mla.sparse_utils import ( triton_convert_req_index_to_global_index, ) @@ -895,169 +892,3 @@ def forward_mqa( ) return attn_out, None - - -def build_c128a_topk_metadata( - positions: torch.Tensor, - compress_ratio: int, - num_decode_tokens: int, - token_to_req_indices: torch.Tensor, - block_table: torch.Tensor, - block_size: int, - slot_mapping: torch.Tensor, - global_decode_buffer: torch.Tensor, - decode_lens_buffer: torch.Tensor, - prefill_buffer: torch.Tensor, - max_compressed_tokens: int = 8192, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Single kernel for all C128A tokens (decode + prefill). - - Decode tokens: position → block_table lookup → global slot ids + topk_lens. - Prefill tokens: position → local indices [0, ..., n-1, -1, ...]. - - Writes into pre-allocated buffers for CUDA graph address stability. - Returns slices of the buffers. - """ - num_tokens = positions.shape[0] - num_prefill_tokens = num_tokens - num_decode_tokens - - global_decode = global_decode_buffer[:num_decode_tokens] - decode_lens = decode_lens_buffer[:num_decode_tokens] - prefill_local = prefill_buffer[:num_prefill_tokens] - - if num_tokens == 0: - return global_decode, decode_lens, prefill_local - - BLOCK_SIZE = 1024 - - # Compute the smallest BLOCK_SIZE-aligned width that covers every - # in-flight token's num_compressed. When ``max_model_len`` is much - # larger than the actual prompts (e.g. 1M cap with 2K inputs) the - # original ``range(0, max_compressed_tokens, BLOCK_SIZE)`` iterated - # over a tail that always wrote ``-1`` for shorter contexts. Capping - # the inner loop at ``effective_topk`` cuts those dead iterations. - # - # This builder runs OUTSIDE the CUDA-graph-captured forward pass, so - # ``effective_topk`` may vary freely per call. To keep downstream - # (which reads the full ``max_compressed_tokens`` buffer width inside - # the captured forward) correct, we pre-fill the active slice with - # ``-1`` so the kernel-skipped tail is treated as "invalid" by the - # sentinel checks in the sparse MLA accumulate kernels. - if num_decode_tokens > 0: - global_decode_buffer[:num_decode_tokens].fill_(-1) - decode_lens_buffer[:num_decode_tokens].zero_() - if num_prefill_tokens > 0: - prefill_buffer[:num_prefill_tokens].fill_(-1) - - # ``.item()`` is a host sync, but this builder runs in metadata - # build (outside the captured forward) so it is harmless w.r.t. - # cudagraph capture/replay (see the comment block above). - max_pos = int(positions.max().item()) - max_num_compressed = min( - max((max_pos + 1) // compress_ratio, 0), - max_compressed_tokens, - ) - effective_topk_arg = min( - max_compressed_tokens, - cdiv(max_num_compressed, BLOCK_SIZE) * BLOCK_SIZE, - ) - if effective_topk_arg == 0: - # Nothing to write; the fill_(-1) above already produced the - # correct "all-invalid" buffer state. - return global_decode, decode_lens, prefill_local - - _build_c128a_topk_metadata_kernel[(num_tokens,)]( - global_decode_buffer, - global_decode_buffer.stride(0), - decode_lens_buffer, - prefill_buffer, - prefill_buffer.stride(0), - positions, - compress_ratio, - effective_topk_arg, - num_decode_tokens, - token_to_req_indices, - block_table, - block_table.stride(0), - block_size, - slot_mapping, - BLOCK_SIZE=BLOCK_SIZE, - ) - return global_decode, decode_lens, prefill_local - - -@triton.jit -def _build_c128a_topk_metadata_kernel( - # Decode outputs - global_decode_ptr, - global_decode_stride, - decode_lens_ptr, - # Prefill output - prefill_local_ptr, - prefill_local_stride, - # Inputs - positions_ptr, - compress_ratio, - effective_topk, - num_decode_tokens, - token_to_req_indices_ptr, - block_table_ptr, - block_table_stride, - block_size, - slot_mapping_ptr, - BLOCK_SIZE: tl.constexpr, -): - # ``effective_topk`` is the BLOCK_SIZE-aligned cap that covers every - # in-flight token's ``num_compressed`` (computed by the Python - # caller in ``build_c128a_topk_metadata``). The buffer columns - # extend out to the Python-side ``max_compressed_tokens`` width; - # entries past ``effective_topk`` are left at ``-1`` by the caller's - # ``fill_(-1)`` pre-pass so the downstream sparse MLA accumulate - # kernels treat them as invalid via their ``kv_index >= 0`` / - # ``slot_id >= 0`` sentinel checks. - token_idx = tl.program_id(0) - position = tl.load(positions_ptr + token_idx) - num_compressed = (position + 1) // compress_ratio - num_compressed = tl.minimum(num_compressed, effective_topk) - is_decode = token_idx < num_decode_tokens - - if is_decode: - # --- Decode: block-table lookup → global slot ids + count --- - is_valid_token = tl.load(slot_mapping_ptr + token_idx) >= 0 - req_idx = tl.load(token_to_req_indices_ptr + token_idx) - count = tl.zeros((), dtype=tl.int32) - for i in range(0, effective_topk, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < effective_topk - is_valid = offset < num_compressed - - block_indices = offset // block_size - block_numbers = tl.load( - block_table_ptr + req_idx * block_table_stride + block_indices, - mask=mask & is_valid, - ) - block_offsets = offset % block_size - slot_ids = block_numbers * block_size + block_offsets - slot_ids = tl.where(is_valid, slot_ids, -1) - tl.store( - global_decode_ptr + token_idx * global_decode_stride + offset, - slot_ids, - mask=mask, - ) - count += tl.sum(is_valid.to(tl.int32), axis=0) - - tl.store( - decode_lens_ptr + token_idx, - tl.where(is_valid_token, count, 0), - ) - else: - # --- Prefill: write local indices --- - pfx_idx = token_idx - num_decode_tokens - for i in range(0, effective_topk, BLOCK_SIZE): - offset = i + tl.arange(0, BLOCK_SIZE) - mask = offset < effective_topk - tl.store( - prefill_local_ptr + pfx_idx * prefill_local_stride + offset, - tl.where(offset < num_compressed, offset, -1), - mask=mask, - ) From 03913e2687102009cf21226ee474b91a0260f292 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 9 Jun 2026 03:52:51 +0800 Subject: [PATCH 087/131] fix: sync DeepSeek V4 MoE metadata after runner refactor Signed-off-by: jasl --- .../test_deepseek_v4_moe_metadata.py | 25 +++++++++++++ vllm/models/deepseek_v4/nvidia/model.py | 36 +++++++++++++++++-- 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/tests/model_executor/test_deepseek_v4_moe_metadata.py b/tests/model_executor/test_deepseek_v4_moe_metadata.py index fbfdba6e19db..3e680c532d68 100644 --- a/tests/model_executor/test_deepseek_v4_moe_metadata.py +++ b/tests/model_executor/test_deepseek_v4_moe_metadata.py @@ -29,6 +29,31 @@ def test_deepseek_v4_fused_moe_metadata_is_available_to_mixture(): assert mixture.num_redundant_experts == 0 +def test_deepseek_v4_fused_moe_metadata_handles_moe_runner_shape(): + moe = object.__new__(DeepseekV4MoE) + moe.n_routed_experts = 256 + moe.n_shared_experts = 1 + moe.experts = SimpleNamespace( + moe_config=SimpleNamespace( + num_logical_experts=256, + num_experts=256, + num_local_experts=128, + ), + routed_experts=SimpleNamespace( + global_num_experts=256, + local_num_experts=128, + ), + ) + + moe._sync_fused_moe_metadata() + + assert moe.n_logical_experts == 256 + assert moe.n_physical_experts == 256 + assert moe.n_local_physical_experts == 128 + assert moe.n_local_experts == 128 + assert moe.n_redundant_experts == 0 + + def test_deepseek_v4_fused_moe_init_exports_moe_metadata(monkeypatch): class FakeGate: def __init__(self, *args, **kwargs): diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index c5a1a3ae8816..54492781b6d4 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -657,9 +657,39 @@ def _init_fused_moe_experts( self._sync_fused_moe_metadata() def _sync_fused_moe_metadata(self) -> None: - self.n_logical_experts = self.experts.logical_num_experts - self.n_physical_experts = self.experts.global_num_experts - self.n_local_physical_experts = self.experts.local_num_experts + experts = self.experts + moe_config = getattr(experts, "moe_config", None) + routed_experts = getattr(experts, "routed_experts", experts) + + def get_optional_attr(obj, name: str): + return None if obj is None else getattr(obj, name, None) + + def first_defined(*values): + return next((value for value in values if value is not None), None) + + self.n_logical_experts = first_defined( + get_optional_attr(experts, "logical_num_experts"), + get_optional_attr(moe_config, "num_logical_experts"), + ) + self.n_physical_experts = first_defined( + get_optional_attr(routed_experts, "global_num_experts"), + get_optional_attr(experts, "global_num_experts"), + get_optional_attr(moe_config, "num_experts"), + ) + self.n_local_physical_experts = first_defined( + get_optional_attr(routed_experts, "local_num_experts"), + get_optional_attr(experts, "local_num_experts"), + get_optional_attr(moe_config, "num_local_experts"), + ) + if ( + self.n_logical_experts is None + or self.n_physical_experts is None + or self.n_local_physical_experts is None + ): + raise AttributeError( + "DeepseekV4MoE FusedMoE metadata is incomplete after " + "construction." + ) self.n_local_experts = self.n_local_physical_experts self.n_redundant_experts = self.n_physical_experts - self.n_logical_experts From 574905ac9933cf858e69584cf700dcfb9a3f679b Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 9 Jun 2026 04:02:14 +0800 Subject: [PATCH 088/131] sched: admit cached long-prompt tails behind long prefill Signed-off-by: jasl --- tests/v1/core/test_scheduler.py | 177 ++++++++++++++++++++++++++++++- vllm/v1/core/kv_cache_manager.py | 10 +- vllm/v1/core/sched/scheduler.py | 74 ++++++++++--- 3 files changed, 242 insertions(+), 19 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index be4423daaf5e..45cee40524e5 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1087,12 +1087,55 @@ def test_running_very_long_prefill_defers_waiting_very_long_prefill(): assert mixed_output.num_scheduled_tokens[short_req.request_id] == 20 -def test_running_very_long_prefill_ignores_deferred_long_waiting_pressure(): +def test_running_very_long_prefill_defers_waiting_uncached_long_prefills(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=2048, + max_num_seqs=4, + enable_chunked_prefill=True, + ) + first_long_req = create_requests( + num_requests=1, + num_tokens=600, + req_ids=["first_long"], + )[0] + waiting_long_reqs = create_requests( + num_requests=3, + num_tokens=600, + req_ids=["second_long", "third_long", "fourth_long"], + ) + + scheduler.add_request(first_long_req) + first_chunk = scheduler.schedule() + assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 100 + scheduler.update_from_output( + first_chunk, + ModelRunnerOutput( + req_ids=[first_long_req.request_id], + req_id_to_index={first_long_req.request_id: 0}, + sampled_token_ids=[[]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + for request in waiting_long_reqs: + scheduler.add_request(request) + mixed_output = scheduler.schedule() + + assert mixed_output.num_scheduled_tokens[first_long_req.request_id] == 100 + for request in waiting_long_reqs: + assert request.request_id not in mixed_output.num_scheduled_tokens + + +def test_running_very_long_prefill_defers_uncached_long_with_spare_budget(): scheduler = create_scheduler( max_num_batched_tokens=100, max_model_len=2048, max_num_seqs=2, enable_chunked_prefill=True, + long_prefill_token_threshold=50, ) first_long_req = create_requests( num_requests=1, @@ -1107,7 +1150,7 @@ def test_running_very_long_prefill_ignores_deferred_long_waiting_pressure(): scheduler.add_request(first_long_req) first_chunk = scheduler.schedule() - assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 100 + assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 50 scheduler.update_from_output( first_chunk, ModelRunnerOutput( @@ -1123,10 +1166,138 @@ def test_running_very_long_prefill_ignores_deferred_long_waiting_pressure(): scheduler.add_request(second_long_req) mixed_output = scheduler.schedule() - assert mixed_output.num_scheduled_tokens[first_long_req.request_id] == 100 + assert mixed_output.num_scheduled_tokens[first_long_req.request_id] == 50 assert second_long_req.request_id not in mixed_output.num_scheduled_tokens +def _run_request_to_completion(scheduler: Scheduler, request: Request) -> None: + while request.request_id in scheduler.requests: + scheduler_output = scheduler.schedule() + assert scheduler_output.num_scheduled_tokens + req_ids = list(scheduler_output.num_scheduled_tokens) + sampled_token_ids = [] + for req_id in req_ids: + if ( + req_id == request.request_id + and request.num_computed_tokens >= request.num_prompt_tokens + ): + sampled_token_ids.append([0]) + else: + sampled_token_ids.append([]) + scheduler.update_from_output( + scheduler_output, + ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)}, + sampled_token_ids=sampled_token_ids, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + +def test_running_very_long_prefill_admits_cached_tail_request(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=2048, + max_num_seqs=3, + enable_chunked_prefill=True, + enable_prefix_caching=True, + ) + warm_req = create_requests( + num_requests=1, + num_tokens=600, + max_tokens=1, + same_prompt=True, + req_ids=["warm_prefix"], + )[0] + + scheduler.add_request(warm_req) + _run_request_to_completion(scheduler, warm_req) + + first_long_req = create_requests( + num_requests=1, + num_tokens=1000, + req_ids=["first_long"], + )[0] + cached_tail_req = create_requests( + num_requests=1, + num_tokens=620, + same_prompt=True, + req_ids=["cached_tail"], + )[0] + + scheduler.add_request(first_long_req) + first_chunk = scheduler.schedule() + assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 100 + scheduler.update_from_output( + first_chunk, + ModelRunnerOutput( + req_ids=[first_long_req.request_id], + req_id_to_index={first_long_req.request_id: 0}, + sampled_token_ids=[[]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ), + ) + + scheduler.add_request(cached_tail_req) + mixed_output = scheduler.schedule() + + assert mixed_output.num_scheduled_tokens[first_long_req.request_id] < 100 + assert cached_tail_req.request_id in mixed_output.num_scheduled_tokens + assert mixed_output.num_scheduled_tokens[cached_tail_req.request_id] <= 100 + + +def test_prefix_cache_peek_does_not_record_stats(): + scheduler = create_scheduler( + max_num_batched_tokens=100, + max_model_len=2048, + max_num_seqs=2, + enable_chunked_prefill=True, + enable_prefix_caching=True, + ) + warm_req = create_requests( + num_requests=1, + num_tokens=600, + max_tokens=1, + same_prompt=True, + req_ids=["warm_prefix"], + )[0] + replay_req = create_requests( + num_requests=1, + num_tokens=620, + same_prompt=True, + req_ids=["replay_prefix"], + )[0] + + scheduler.add_request(warm_req) + _run_request_to_completion(scheduler, warm_req) + stats = scheduler.kv_cache_manager.make_prefix_cache_stats() + assert stats is not None + + _, peeked_tokens = scheduler.kv_cache_manager.get_computed_blocks( + replay_req, + record_stats=False, + ) + assert peeked_tokens > 0 + stats = scheduler.kv_cache_manager.make_prefix_cache_stats() + assert stats is not None + assert stats.requests == 0 + assert stats.queries == 0 + assert stats.hits == 0 + + _, recorded_tokens = scheduler.kv_cache_manager.get_computed_blocks(replay_req) + assert recorded_tokens == peeked_tokens + stats = scheduler.kv_cache_manager.make_prefix_cache_stats() + assert stats is not None + assert stats.requests == 1 + assert stats.queries == replay_req.num_tokens + assert stats.hits == peeked_tokens + + def test_running_very_long_prefill_defers_to_later_running_decode(): scheduler = create_scheduler( max_num_batched_tokens=100, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index c829b845803e..230134f7f410 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -199,12 +199,18 @@ def make_prefix_cache_stats(self) -> PrefixCacheStats | None: self.prefix_cache_stats = PrefixCacheStats() return stats - def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: + def get_computed_blocks( + self, + request: Request, + *, + record_stats: bool = True, + ) -> tuple[KVCacheBlocks, int]: """Get the computed (cached) blocks for the request. Note that the computed blocks must be full. Args: request: The request to get the computed blocks. + record_stats: Whether to record prefix-cache stats for this lookup. Returns: A tuple containing: @@ -231,7 +237,7 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: ) ) - if self.log_stats: + if self.log_stats and record_stats: assert self.prefix_cache_stats is not None self.prefix_cache_stats.record( num_tokens=request.num_tokens, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5cb58cff9bff..d2e417d505a6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -403,26 +403,41 @@ def _is_very_long_prefill( return False if num_computed_tokens is None: num_computed_tokens = request.num_computed_tokens - return ( - request.num_prompt_tokens > self._very_long_prefill_threshold() - and num_computed_tokens < request.num_prompt_tokens - ) + remaining_prefill = max(0, request.num_prompt_tokens - num_computed_tokens) + return remaining_prefill > self._very_long_prefill_threshold() + + def _waiting_request_remaining_prefill(self, request: Request) -> int: + if request.num_computed_tokens > 0: + num_computed_tokens = request.num_computed_tokens + else: + _, num_computed_tokens = self.kv_cache_manager.get_computed_blocks( + request, + record_stats=False, + ) + return max(0, request.num_prompt_tokens - num_computed_tokens) def _has_active_very_long_prefill(self) -> bool: return any(self._is_very_long_prefill(request) for request in self.running) + def _is_waiting_request_very_long_prefill(self, request: Request) -> bool: + return ( + self._waiting_request_remaining_prefill(request) + > self._very_long_prefill_threshold() + ) + + def _waiting_prefill_competitor_count(self) -> int: + non_long_prefill_count = 0 + for waiting_request in itertools.chain(self.waiting, self.skipped_waiting): + if not self._is_waiting_request_very_long_prefill(waiting_request): + non_long_prefill_count += 1 + return non_long_prefill_count + def _has_waiting_requests_for_running_prefill(self, request: Request) -> bool: if not (self.waiting or self.skipped_waiting): return False if not self._is_very_long_prefill(request): return True - return any( - not self._is_very_long_prefill(waiting_request) - for waiting_request in self.waiting - ) or any( - not self._is_very_long_prefill(waiting_request) - for waiting_request in self.skipped_waiting - ) + return self._waiting_prefill_competitor_count() > 0 def _limit_mixed_decode_prefill_chunk( self, @@ -431,6 +446,7 @@ def _limit_mixed_decode_prefill_chunk( scheduled_running_reqs: list[Request], has_waiting_requests: bool = False, has_pending_decode: bool = False, + prefill_budget_partitions: int = 1, ) -> int: if ( not self.scheduler_config.enable_chunked_prefill @@ -441,7 +457,10 @@ def _limit_mixed_decode_prefill_chunk( has_decode_pressure = ( self._has_scheduled_decode(scheduled_running_reqs) or has_pending_decode ) - if not has_decode_pressure and not has_waiting_requests: + has_prefill_competition = ( + has_waiting_requests or prefill_budget_partitions > 1 + ) + if not has_decode_pressure and not has_prefill_competition: return num_new_tokens remaining_prefill = request.num_prompt_tokens - request.num_computed_tokens @@ -458,7 +477,11 @@ def _limit_mixed_decode_prefill_chunk( else: mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 4) elif remaining_prefill > very_long_prefill_threshold: - mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 2) + mixed_prefill_budget = max( + 1, + self.max_num_scheduled_tokens + // max(2, prefill_budget_partitions), + ) else: mixed_prefill_budget = max(1, (self.max_num_scheduled_tokens * 3) // 4) return min(num_new_tokens, mixed_prefill_budget) @@ -553,6 +576,11 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: later_request.num_computed_tokens >= later_request.num_prompt_tokens for later_request in self.running[req_index + 1 :] ) + prefill_budget_partitions = ( + 1 + + int(has_unscheduled_running_prefill) + + self._waiting_prefill_competitor_count() + ) num_new_tokens = self._limit_mixed_decode_prefill_chunk( request, num_new_tokens, @@ -560,6 +588,7 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: self._has_waiting_requests_for_running_prefill(request) or has_unscheduled_running_prefill, has_pending_running_decode, + prefill_budget_partitions, ) # Make sure the input position does not exceed the max model len. @@ -910,8 +939,25 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: break num_new_tokens = min(num_new_tokens, token_budget) + scheduled_prefill_count = sum( + scheduled_request.num_computed_tokens + < scheduled_request.num_prompt_tokens + for scheduled_request in itertools.chain( + scheduled_running_reqs, + scheduled_new_reqs, + scheduled_resumed_reqs, + ) + ) + prefill_budget_partitions = max( + 1, + scheduled_prefill_count + + self._waiting_prefill_competitor_count(), + ) num_new_tokens = self._limit_mixed_decode_prefill_chunk( - request, num_new_tokens, scheduled_running_reqs + request, + num_new_tokens, + scheduled_running_reqs, + prefill_budget_partitions=prefill_budget_partitions, ) if num_new_tokens == 0: break From c601168d0f2c0f54fdd5bef745da1e028ffaf943 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 9 Jun 2026 05:02:27 +0800 Subject: [PATCH 089/131] fix: clear stale prefix-cache block hashes on reuse Blocks taken from the free pool are about to receive new content, so their old prefix-cache hash must be cleared even if the cache map no longer contains that exact block entry. This avoids stale block hashes accumulating under sustained prefix-cache reuse. Based on vllm-project/vllm#44237. Co-authored-by: Oxygen <1391083091@qq.com> Signed-off-by: jasl --- tests/v1/core/test_prefix_caching.py | 28 ++++++++++++++++++++++++++++ vllm/v1/core/block_pool.py | 22 +++++++++++++++------- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 4a17a4f247e9..aed11d8c826e 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -2037,6 +2037,34 @@ def test_maybe_evict_cached_block(): assert pool.cached_block_hash_to_block._cache == {} +@pytest.mark.parametrize("cache_state", ["missing", "different_block"]) +def test_maybe_evict_cached_block_resets_stale_hash_on_miss(cache_state: str): + pool = BlockPool( + num_gpu_blocks=3, + enable_caching=True, + hash_block_size=16, + enable_kv_cache_events=True, + ) + block = pool.blocks[1] + other_block = pool.blocks[2] + block_hash = make_block_hash_with_group_id(BlockHash(b"stale"), 0) + + block.block_hash = block_hash + if cache_state == "different_block": + other_block.block_hash = block_hash + pool.cached_block_hash_to_block.insert(block_hash, other_block) + + assert pool._maybe_evict_cached_block(block) is False + assert block.block_hash is None + assert pool.kv_event_queue == [] + + if cache_state == "different_block": + assert pool.cached_block_hash_to_block._cache == {block_hash: other_block} + assert other_block.block_hash == block_hash + else: + assert pool.cached_block_hash_to_block._cache == {} + + @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) def test_kv_cache_events(blocks_to_cache: int): block_size = 16 diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index e6bbba146693..0a69ac2e1cfc 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -92,7 +92,11 @@ def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None: def pop(self, key: BlockHashWithGroupId, block_id: int) -> KVCacheBlock | None: """ - Checks if block_hash exists and pop block_id from the cache + Checks if block_hash exists and pop block_id from the cache. + + In the single-block case, the caller is expected to pass the exact + block stored at ``key``. In the collision case, the block is looked up + by ``block_id`` while sibling blocks remain cached. """ blocks = self._cache.pop(key, None) if blocks is None: @@ -283,6 +287,7 @@ def cache_full_blocks( new_hashes.append(maybe_convert_block_hash(block_hash)) if self.enable_kv_cache_events: + assert new_hashes is not None if num_cached_blocks == 0: parent_block_hash: ExternalBlockHash | None = None else: @@ -367,6 +372,9 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: If a block is cached in `cached_block_hash_to_block`, we reset its hash metadata and evict it from the cache. + When the block has a hash but is not found in the cache map, the hash + is still reset because the block is about to be reused for new content. + Args: block: The block to evict. @@ -382,14 +390,14 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: # The block doesn't have hash, eviction is not needed return False - if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None: - # block not found in cached_block_hash_to_block, - # eviction is not needed - return False + evicted = ( + self.cached_block_hash_to_block.pop(block_hash, block.block_id) + is not None + ) block.reset_hash() - if self.enable_kv_cache_events: + if evicted and self.enable_kv_cache_events: self.kv_event_queue.append( BlockRemoved( block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))], @@ -397,7 +405,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: group_idx=get_group_id(block_hash), ) ) - return True + return evicted def touch(self, blocks: Sequence[KVCacheBlock]) -> None: """Touch a block increases its reference count by 1, and may remove From 79232ca15cabdb9a37b11cf3c0700f19dde7125e Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 9 Jun 2026 08:40:58 +0800 Subject: [PATCH 090/131] test: cover DeepSeek V4 RoutedExperts MXFP4 quant dispatch Upstream #44914 carries the runtime fix. Keep a local regression test for the DeepSeek V4 MoE runner refactor path so RoutedExperts continues to use MXFP4 expert quantization. Signed-off-by: jasl --- .../test_deepseek_v4_moe_metadata.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/model_executor/test_deepseek_v4_moe_metadata.py b/tests/model_executor/test_deepseek_v4_moe_metadata.py index 3e680c532d68..e87cdd8c3ff7 100644 --- a/tests/model_executor/test_deepseek_v4_moe_metadata.py +++ b/tests/model_executor/test_deepseek_v4_moe_metadata.py @@ -1,5 +1,7 @@ from types import SimpleNamespace +from vllm.model_executor.layers.fused_moe import RoutedExperts +from vllm.models.deepseek_v4 import quant_config as deepseek_v4_quant_config from vllm.models.deepseek_v4.nvidia import model as deepseek_v4_model from vllm.models.deepseek_v4.nvidia.model import ( DeepseekV4MixtureOfExperts, @@ -54,6 +56,44 @@ def test_deepseek_v4_fused_moe_metadata_handles_moe_runner_shape(): assert moe.n_redundant_experts == 0 +def test_deepseek_v4_fp4_quant_config_handles_routed_experts_after_moe_refactor( + monkeypatch, +): + class FakeMxfp4MoEMethod: + def __init__(self, moe_config): + self.moe_config = moe_config + + quant_config = deepseek_v4_quant_config.DeepseekV4FP8Config( + is_checkpoint_fp8_serialized=True, + weight_block_size=[128, 128], + ) + layer = object.__new__(RoutedExperts) + layer.moe_config = object() + + monkeypatch.setattr( + deepseek_v4_quant_config, + "Mxfp4MoEMethod", + FakeMxfp4MoEMethod, + ) + monkeypatch.setattr( + deepseek_v4_quant_config, + "get_current_vllm_config", + lambda: SimpleNamespace( + model_config=SimpleNamespace( + hf_config=SimpleNamespace( + expert_dtype="fp4", + quantization_config={}, + ) + ) + ), + ) + + method = quant_config.get_quant_method(layer, "model.layers.3.mlp.experts") + + assert isinstance(method, FakeMxfp4MoEMethod) + assert method.moe_config is layer.moe_config + + def test_deepseek_v4_fused_moe_init_exports_moe_metadata(monkeypatch): class FakeGate: def __init__(self, *args, **kwargs): From c828ef49c4d5b6e9d3ded816b857caa82b7abe90 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 9 Jun 2026 17:54:44 +0800 Subject: [PATCH 091/131] deepseek-v4: keep small C128A prefills on D512 path Signed-off-by: jasl --- .../test_deepseek_v4_sparse_mla_metadata.py | 9 +++++++++ vllm/models/deepseek_v4/nvidia/flashmla.py | 15 +++++++++++---- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py b/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py index 0cb1431af1c7..8d26aa7fea24 100644 --- a/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py +++ b/tests/model_executor/test_deepseek_v4_sparse_mla_metadata.py @@ -4,6 +4,7 @@ import torch from vllm.models.deepseek_v4 import sparse_mla +from vllm.models.deepseek_v4.nvidia import flashmla def test_c128a_effective_topk_width_uses_current_positions() -> None: @@ -43,3 +44,11 @@ def test_c128a_effective_topk_width_uses_current_positions() -> None: ) == 8192 ) + + +def test_indexed_d512_split_topk_keeps_small_c128a_prefills() -> None: + assert not flashmla._is_indexed_d512_split_topk(128) + assert flashmla._is_indexed_d512_split_topk(256) + assert flashmla._is_indexed_d512_split_topk(512) + assert flashmla._is_indexed_d512_split_topk(1152) + assert not flashmla._is_indexed_d512_split_topk(1280) diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 5607112acea3..dba35cba992b 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -56,6 +56,7 @@ _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS = 8192 +_INDEXED_D512_SPLIT_PREFILL_MIN_TOPK = 256 _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK = 1152 @@ -74,11 +75,19 @@ def _use_indexed_d512_split_prefill( and compress_ratio in (4, 128) and head_dim == 512 and num_prefills == 1 - and 512 < combined_topk <= _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + and _is_indexed_d512_split_topk(combined_topk) and max_prefill_seq_len >= _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS ) +def _is_indexed_d512_split_topk(combined_topk: int) -> bool: + return ( + _INDEXED_D512_SPLIT_PREFILL_MIN_TOPK + <= combined_topk + <= _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + ) + + def _use_indexed_d512_chunked_prefill( *, compress_ratio: int, @@ -625,9 +634,7 @@ def _forward_sparse_mla_prefill_triton( and layer.compress_ratio in (4, 128) and q.shape[-1] == 512 and kv.shape[0] == 1 - and 512 - < combined_indices.shape[-1] - <= _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK + and _is_indexed_d512_split_topk(combined_indices.shape[-1]) and len(state_buffers) == 4 ): indexed_d512_scores = state_buffers[3] From fde655cc06243196324c0ffc2bfb95039955b41f Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 9 Jun 2026 23:26:05 +0800 Subject: [PATCH 092/131] fix: scale DeepSeek MLA prefix retention by sequence limit Signed-off-by: jasl (cherry picked from commit 90bf5572d9ceb1c68633e28c64f6e09e72eeced5) --- tests/v1/core/test_prefix_caching.py | 53 ++++++++++++++++++++ vllm/v1/core/kv_cache_coordinator.py | 12 +++++ vllm/v1/core/kv_cache_manager.py | 2 + vllm/v1/core/sched/scheduler.py | 1 + vllm/v1/core/single_type_kv_cache_manager.py | 23 +++++++++ 5 files changed, 91 insertions(+) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index aed11d8c826e..7677d759859a 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -127,6 +127,27 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: ) +def make_mla_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: + return KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=512, + dtype=torch.float8_e4m3fn, + cache_dtype_str="fp8_ds_mla", + compress_ratio=2, + model_version="deepseek_v4", + ), + ) + ], + ) + + def make_kv_cache_config_hybrid_model( block_size: int, num_blocks: int, @@ -1656,6 +1677,38 @@ def test_cache_blocks_multi_group(): ) +def test_deepseek_v4_mla_prompt_protection_scales_with_max_num_seqs(): + block_size = 4 + manager = make_kv_cache_manager( + make_mla_kv_cache_config(block_size=block_size, num_blocks=64), + max_model_len=16, + max_num_batched_tokens=16, + max_num_seqs=4, + enable_caching=True, + hash_block_size=block_size, + ) + + mla_manager = manager.coordinator.single_type_managers[0] + + assert mla_manager._max_protected_prompt_blocks() == 16 + + +def test_deepseek_v4_mla_prompt_protection_leaves_allocation_headroom(): + block_size = 4 + manager = make_kv_cache_manager( + make_mla_kv_cache_config(block_size=block_size, num_blocks=15), + max_model_len=16, + max_num_batched_tokens=16, + max_num_seqs=4, + enable_caching=True, + hash_block_size=block_size, + ) + + mla_manager = manager.coordinator.single_type_managers[0] + + assert mla_manager._max_protected_prompt_blocks() == 10 + + def test_mm_prefix_caching(): """ This tests that the multi-modal prefix caching is correct. diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index b5b109d776c4..d741d7e67f5a 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -75,6 +75,7 @@ def __init__( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ): self.kv_cache_config = kv_cache_config @@ -115,6 +116,7 @@ def __init__( dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, scheduler_block_size=self.scheduler_block_size, + max_num_seqs=max_num_seqs, ) for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups) ) @@ -400,6 +402,7 @@ def __init__( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( @@ -413,6 +416,7 @@ def __init__( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) self.num_single_type_manager = len(self.single_type_managers) @@ -450,6 +454,7 @@ def __init__( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( @@ -463,6 +468,7 @@ def __init__( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec @@ -536,6 +542,7 @@ def __init__( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ): super().__init__( @@ -549,6 +556,7 @@ def __init__( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) # hash_block_size: the block size used to compute block hashes. @@ -800,6 +808,7 @@ def get_kv_cache_coordinator( pcp_world_size: int, scheduler_block_size: int, hash_block_size: int, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, ) -> KVCacheCoordinator: if not enable_caching: @@ -813,6 +822,7 @@ def get_kv_cache_coordinator( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) if len(kv_cache_config.kv_cache_groups) == 1: @@ -827,6 +837,7 @@ def get_kv_cache_coordinator( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) return HybridKVCacheCoordinator( @@ -840,5 +851,6 @@ def get_kv_cache_coordinator( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=metrics_collector, ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 230134f7f410..ac782954ce56 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -121,6 +121,7 @@ def __init__( enable_kv_cache_events: bool = False, dcp_world_size: int = 1, pcp_world_size: int = 1, + max_num_seqs: int | None = None, metrics_collector: KVCacheMetricsCollector | None = None, watermark: float = 0.0, ) -> None: @@ -151,6 +152,7 @@ def __init__( pcp_world_size=pcp_world_size, scheduler_block_size=scheduler_block_size, hash_block_size=hash_block_size, + max_num_seqs=max_num_seqs, metrics_collector=self.metrics_collector, ) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index d2e417d505a6..3728b03eed53 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -263,6 +263,7 @@ def __init__( pcp_world_size=self.pcp_world_size, scheduler_block_size=self.block_size, hash_block_size=hash_block_size, + max_num_seqs=self.scheduler_config.max_num_seqs, metrics_collector=self.kv_metrics_collector, watermark=self.scheduler_config.watermark, ) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 87e73c5e4c4f..9ca7155a65fe 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -46,6 +46,7 @@ def __init__( pcp_world_size: int = 1, max_admission_blocks_per_request: int | None = None, max_model_len: int | None = None, + max_num_seqs: int | None = None, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -74,6 +75,7 @@ def __init__( self.enable_caching = enable_caching self._max_admission_blocks_per_request = max_admission_blocks_per_request self.max_model_len = max_model_len + self.max_num_seqs = max_num_seqs self.cache_alignment_tokens = self.block_size self.new_block_ids: list[int] = [] @@ -704,6 +706,23 @@ def _should_protect_prompt_blocks(self) -> bool: or self.kv_cache_spec.compress_ratio > 1 ) + def _max_protected_prompt_blocks(self) -> int | None: + if self.max_num_seqs is None: + return super()._max_protected_prompt_blocks() + if self.max_model_len is None: + return None + + prompt_blocks = cdiv(max(1, self.max_model_len), self.block_size) + target_reqs = max(2, self.max_num_seqs) + target_blocks = target_reqs * prompt_blocks + + # Keep one max-length request worth of blocks available for new work + # before the generic allocation path has to release protected prompts. + pool_blocks = max(0, self.block_pool.num_gpu_blocks - 1) + if pool_blocks <= prompt_blocks: + return pool_blocks + return min(target_blocks, pool_blocks - prompt_blocks) + def cache_blocks( self, request: Request, @@ -1512,6 +1531,7 @@ def __init__( dcp_world_size: int = 1, pcp_world_size: int = 1, max_model_len: int | None = None, + max_num_seqs: int | None = None, ): super().__init__( kv_cache_spec, @@ -1522,6 +1542,7 @@ def __init__( dcp_world_size, pcp_world_size, max_model_len=max_model_len, + max_num_seqs=max_num_seqs, ) sink_len = kv_cache_spec.sink_len assert sink_len is not None and sink_len > 0 and sink_len % self.block_size == 0 @@ -1533,6 +1554,7 @@ def get_manager_for_kv_cache_spec( kv_cache_spec: KVCacheSpec, max_num_batched_tokens: int, max_model_len: int, + max_num_seqs: int | None = None, **kwargs, ) -> SingleTypeKVCacheManager: """ @@ -1554,6 +1576,7 @@ def get_manager_for_kv_cache_spec( f"No manager registered for KVCacheSpec {type(kv_cache_spec)}" ) kwargs["max_model_len"] = max_model_len + kwargs["max_num_seqs"] = max_num_seqs # SlidingWindow / ChunkedLocalAttention managers recycle blocks across # chunks; the runtime admission cap must match the recycling-aware bound # the startup pool sizer uses (single source of truth: the spec method). From 58d8270b142580dde077d05595d4da3cf2d9d563 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 10 Jun 2026 00:05:38 +0800 Subject: [PATCH 093/131] fix: buffer split DeepSeek V4 DSML tool markers Signed-off-by: jasl --- .../test_deepseekv4_reasoning_parser.py | 93 +++++++++++++++++++ .../reasoning/deepseek_v4_reasoning_parser.py | 57 +++++++++++- 2 files changed, 149 insertions(+), 1 deletion(-) diff --git a/tests/reasoning/test_deepseekv4_reasoning_parser.py b/tests/reasoning/test_deepseekv4_reasoning_parser.py index f72fe27bbfa6..2dcaf45202db 100644 --- a/tests/reasoning/test_deepseekv4_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv4_reasoning_parser.py @@ -210,6 +210,99 @@ def test_implicit_end_marker_within_delta_split(parser): assert parser._implicit_end_seen is True +def test_partial_implicit_end_marker_prefix_is_buffered(parser): + """A DSML marker split across text deltas must not leak into reasoning.""" + delta = parser.extract_reasoning_streaming( + previous_text="reasoning ", + current_text="reasoning <|DSML|tool", + delta_text="<|DSML|tool", + previous_token_ids=[450], + current_token_ids=[450, 451], + delta_token_ids=[451], + ) + assert delta is None + assert parser._implicit_end_seen is False + + delta = parser.extract_reasoning_streaming( + previous_text="reasoning <|DSML|tool", + current_text=f"reasoning {DSML_MARKER}\n<|DSML|invoke", + delta_text="_calls>\n<|DSML|invoke", + previous_token_ids=[450, 451], + current_token_ids=[450, 451, 452, 453], + delta_token_ids=[452, 453], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == f"{DSML_MARKER}\n<|DSML|invoke" + assert parser._implicit_end_seen is True + + +def test_partial_implicit_end_marker_with_reasoning_prefix(parser): + """Reasoning before a split DSML marker is emitted, but the marker waits.""" + delta = parser.extract_reasoning_streaming( + previous_text="head ", + current_text="head tail reasoning<|DSML|tool", + delta_text="tail reasoning<|DSML|tool", + previous_token_ids=[460], + current_token_ids=[460, 461, 462], + delta_token_ids=[461, 462], + ) + assert delta is not None + assert delta.reasoning == "tail reasoning" + assert delta.content is None + assert parser._implicit_end_seen is False + + delta = parser.extract_reasoning_streaming( + previous_text="head tail reasoning<|DSML|tool", + current_text=f"head tail reasoning{DSML_MARKER}", + delta_text="_calls>", + previous_token_ids=[460, 461, 462], + current_token_ids=[460, 461, 462, 463], + delta_token_ids=[463], + ) + assert delta is not None + assert delta.reasoning is None + assert delta.content == DSML_MARKER + assert parser._implicit_end_seen is True + + +def test_partial_implicit_end_marker_can_span_multiple_deltas(parser): + """The buffered marker prefix can grow across more than two chunks.""" + first = parser.extract_reasoning_streaming( + previous_text="r", + current_text="r<|DSML|", + delta_text="<|DSML|", + previous_token_ids=[470], + current_token_ids=[470, 471], + delta_token_ids=[471], + ) + assert first is None + + second = parser.extract_reasoning_streaming( + previous_text="r<|DSML|", + current_text="r<|DSML|tool", + delta_text="tool", + previous_token_ids=[470, 471], + current_token_ids=[470, 471, 472], + delta_token_ids=[472], + ) + assert second is None + assert parser._implicit_end_seen is False + + third = parser.extract_reasoning_streaming( + previous_text="r<|DSML|tool", + current_text=f"r{DSML_MARKER}", + delta_text="_calls>", + previous_token_ids=[470, 471, 472], + current_token_ids=[470, 471, 472, 473], + delta_token_ids=[473], + ) + assert third is not None + assert third.reasoning is None + assert third.content == DSML_MARKER + assert parser._implicit_end_seen is True + + def test_subsequent_delta_after_implicit_end_is_content(parser): """Once the implicit end fires, every later delta is content.""" # Seed the parser by flipping the sticky flag via a marker delta. diff --git a/vllm/reasoning/deepseek_v4_reasoning_parser.py b/vllm/reasoning/deepseek_v4_reasoning_parser.py index 185af64a37e2..8b19d901da7a 100644 --- a/vllm/reasoning/deepseek_v4_reasoning_parser.py +++ b/vllm/reasoning/deepseek_v4_reasoning_parser.py @@ -58,6 +58,10 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): # the rest of the stream is content and the orchestrator's # is_reasoning_end check must return True for every subsequent delta. self._implicit_end_seen: bool = False + # Text deltas may split the DSML marker itself, for example + # "<|DSML|tool" then "_calls>". Hold such suffixes back so they do + # not leak into reasoning before the next delta disambiguates them. + self._pending_implicit_marker_prefix: str = "" def _find_implicit_end_marker(self, text: str) -> tuple[str, int] | None: """Return ``(marker, index)`` of the earliest implicit end marker in @@ -72,6 +76,46 @@ def _find_implicit_end_marker(self, text: str) -> tuple[str, int] | None: earliest = (marker, idx) return earliest + def _partial_implicit_end_marker_suffix(self, text: str) -> str: + """Return the longest suffix of ``text`` that prefixes a DSML marker.""" + best = "" + for marker in self.implicit_end_markers: + max_len = min(len(marker) - 1, len(text)) + for length in range(max_len, 0, -1): + suffix = text[-length:] + if marker.startswith(suffix): + if length > len(best): + best = suffix + break + return best + + def _extract_pending_implicit_marker_delta( + self, delta_text: str + ) -> DeltaMessage | None: + pending = self._pending_implicit_marker_prefix + combined = pending + delta_text + + marker = self._find_implicit_end_marker(combined) + if marker is not None and marker[1] == 0: + self._pending_implicit_marker_prefix = "" + self._implicit_end_seen = True + return DeltaMessage(content=combined) + + if any(marker.startswith(combined) for marker in self.implicit_end_markers): + self._pending_implicit_marker_prefix = combined + return None + + # The withheld suffix was a false alarm. Emit it as reasoning, while + # still guarding against a new marker prefix at the end of this delta. + partial = self._partial_implicit_end_marker_suffix(combined) + if partial: + self._pending_implicit_marker_prefix = partial + reasoning = combined[: -len(partial)] + else: + self._pending_implicit_marker_prefix = "" + reasoning = combined + return DeltaMessage(reasoning=reasoning) if reasoning else None + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: # Honor the explicit contract first. if super().is_reasoning_end(input_ids): @@ -134,6 +178,9 @@ def extract_reasoning_streaming( ): return ret + if self._pending_implicit_marker_prefix: + return self._extract_pending_implicit_marker_delta(delta_text) + # Sticky: marker observed in an earlier delta — everything here is # content for the tool parser. if self._implicit_end_seen: @@ -141,11 +188,19 @@ def extract_reasoning_streaming( marker_in_current = self._find_implicit_end_marker(current_text) if marker_in_current is None: - # No marker anywhere; parent's classification stands. + partial = self._partial_implicit_end_marker_suffix(current_text) + if partial: + partial_start = len(current_text) - len(partial) + if partial_start >= len(previous_text): + self._pending_implicit_marker_prefix = partial + reasoning = delta_text[: -len(partial)] or None + return DeltaMessage(reasoning=reasoning) if reasoning else None + # No marker or holdable marker prefix; parent's classification stands. return ret # First sighting of the implicit end marker. self._implicit_end_seen = True + self._pending_implicit_marker_prefix = "" _marker_str, marker_idx_current = marker_in_current # Position within delta_text where the marker begins. marker_idx_delta = marker_idx_current - len(previous_text) From 4b102a02e8d9c09a5e57f93c310dcaec525beb6e Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 10 Jun 2026 07:54:03 +0800 Subject: [PATCH 094/131] test: adapt DeepSeek V4 MoE metadata fixture to EPLB Signed-off-by: jasl --- tests/model_executor/test_deepseek_v4_moe_metadata.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/model_executor/test_deepseek_v4_moe_metadata.py b/tests/model_executor/test_deepseek_v4_moe_metadata.py index e87cdd8c3ff7..25a177a9792f 100644 --- a/tests/model_executor/test_deepseek_v4_moe_metadata.py +++ b/tests/model_executor/test_deepseek_v4_moe_metadata.py @@ -142,7 +142,11 @@ def __init__(self, *args, **kwargs): model_config=SimpleNamespace(hf_config=config), quant_config=None, kernel_config=SimpleNamespace(moe_backend="auto"), - parallel_config=SimpleNamespace(enable_expert_parallel=True), + parallel_config=SimpleNamespace( + enable_expert_parallel=True, + enable_eplb=False, + eplb_config=SimpleNamespace(num_redundant_experts=0), + ), ) moe = DeepseekV4MoE(vllm_config, prefix="model.layers.3.mlp") From 6ad66c450194c73b32a988f59fd6972ef68fb51a Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 11 Jun 2026 21:47:13 +0800 Subject: [PATCH 095/131] sched: keep trace import out of stable path Signed-off-by: jasl --- vllm/v1/core/sched/scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 3728b03eed53..57ea575b4517 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -7,7 +7,6 @@ from dataclasses import replace from typing import Any -from vllm import envs from vllm.compilation.cuda_graph import CUDAGraphStat from vllm.config import VllmConfig from vllm.distributed.ec_transfer.ec_connector.base import ( From 237e7e7667200481b949d217463993e95ec8a38d Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 12 Jun 2026 06:19:48 +0800 Subject: [PATCH 096/131] fix: restore DeepSeek V4 sparse MLA stats env Signed-off-by: jasl --- tests/test_envs.py | 19 +++++++++++++++++++ vllm/envs.py | 4 ++++ 2 files changed, 23 insertions(+) diff --git a/tests/test_envs.py b/tests/test_envs.py index d4d120ecee51..5c70d56e692c 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -36,6 +36,25 @@ def test_nixl_side_channel_host_is_not_compile_factor( assert "VLLM_NIXL_SIDE_CHANNEL_HOST" not in envs.compile_factors() +def test_deepseek_v4_sparse_mla_stats_path_env( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.delenv("VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH", raising=False) + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + assert envs.VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH is None + + monkeypatch.setenv( + "VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH", + "/tmp/sparse_mla_stats", + ) + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + assert envs.VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH == "/tmp/sparse_mla_stats" + + def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1") monkeypatch.setenv("VLLM_PORT", "1234") diff --git a/vllm/envs.py b/vllm/envs.py index 2757a7e26ac6..2ec89025b6cd 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -180,6 +180,7 @@ VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True + VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH: str | None = None VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL: bool = True VLLM_TRITON_MLA_SPARSE: bool | None = None @@ -1455,6 +1456,9 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP": lambda: bool( int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP", "1")) ), + "VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH": lambda: os.getenv( + "VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH" + ), "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL": lambda: bool( int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL", "1")) ), From 2349559a48eba2e6e11f0eac59082286756d24de Mon Sep 17 00:00:00 2001 From: jasl Date: Sun, 14 Jun 2026 02:50:00 +0800 Subject: [PATCH 097/131] fix: allow DeepSeek V4 chat stray think end Signed-off-by: jasl --- .../test_structural_tag_registry.py | 34 +++++++++++++++++++ vllm/tool_parsers/structural_tag_registry.py | 16 ++++++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/tool_parsers/test_structural_tag_registry.py b/tests/tool_parsers/test_structural_tag_registry.py index bd84b2cbbfac..0a554bdc5cd0 100644 --- a/tests/tool_parsers/test_structural_tag_registry.py +++ b/tests/tool_parsers/test_structural_tag_registry.py @@ -91,6 +91,40 @@ def test_get_model_structural_tag_supports_all_xgrammar_builtins( assert isinstance(tag, StructuralTag) +def test_deepseek_v4_chat_structural_tag_allows_stray_think_end( + sample_tools: list[ChatCompletionToolsParam], +): + tag = get_model_structural_tag( + model="deepseek_v4", + tools=sample_tools, + tool_choice="auto", + reasoning=False, + ) + + dumped = tag.model_dump() + excludes = dumped["format"]["excludes"] + assert "" in excludes + assert "" not in excludes + + +def test_deepseek_v4_reasoning_structural_tag_is_unchanged( + sample_tools: list[ChatCompletionToolsParam], +): + tag = get_model_structural_tag( + model="deepseek_v4", + tools=sample_tools, + tool_choice="auto", + reasoning=True, + ) + + dumped = tag.model_dump() + assert dumped["format"]["type"] == "sequence" + assert dumped["format"]["elements"][1]["excludes"] == [ + "", + "", + ] + + def test_get_model_structural_tag_supports_vllm_hermes( sample_tools: list[ChatCompletionToolsParam], ): diff --git a/vllm/tool_parsers/structural_tag_registry.py b/vllm/tool_parsers/structural_tag_registry.py index 99c92f8f0a2e..0dfc879518b0 100644 --- a/vllm/tool_parsers/structural_tag_registry.py +++ b/vllm/tool_parsers/structural_tag_registry.py @@ -128,12 +128,26 @@ def get_model_structural_tag( supported = sorted(SUPPORTED_STRUCTURAL_TAG_MODELS) raise ValueError(f"Unknown format type: {model}, supported types: {supported}") - return get_xgrammar_model_structural_tag( + structural_tag = get_xgrammar_model_structural_tag( model=model, tools=dumped_tools, tool_choice=dumped_tool_choice, reasoning=reasoning, ) + if model == "deepseek_v4" and not reasoning: + _allow_deepseek_v4_chat_think_end(structural_tag) + return structural_tag + + +def _allow_deepseek_v4_chat_think_end(structural_tag: StructuralTag) -> None: + """Allow DSv4 chat-mode outputs to emit an extra ````.""" + fmt = structural_tag.format + if getattr(fmt, "type", None) != "triggered_tags": + return + excludes = getattr(fmt, "excludes", None) + if not isinstance(excludes, list): + return + fmt.excludes = [exclude for exclude in excludes if exclude != ""] def _dump_tool_for_xgrammar( From c13cae00bcecca71b6a13d9ab924eff9b285aaf5 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 3 Jun 2026 03:58:39 +0800 Subject: [PATCH 098/131] sm12x: support FlashInfer CUTLASS MXFP4 opt-in Signed-off-by: jasl (cherry picked from commit 134f3a6365b6a2e328481325a91a2885a739e1d1) --- .../test_flashinfer_cutlass_mxfp4_config.py | 173 ++++++++++++++++++ .../experts/flashinfer_cutlass_moe.py | 27 +-- .../layers/fused_moe/oracle/mxfp4.py | 43 ++++- 3 files changed, 229 insertions(+), 14 deletions(-) create mode 100644 tests/kernels/moe/test_flashinfer_cutlass_mxfp4_config.py diff --git a/tests/kernels/moe/test_flashinfer_cutlass_mxfp4_config.py b/tests/kernels/moe/test_flashinfer_cutlass_mxfp4_config.py new file mode 100644 index 000000000000..6bb12fde710b --- /dev/null +++ b/tests/kernels/moe/test_flashinfer_cutlass_mxfp4_config.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import sys +import types + +import torch + +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, + mxfp4_mxfp8_moe_quant_config, +) +from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import ( + FlashInferExperts, +) +from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( + Mxfp4MoeBackend, + convert_weight_to_mxfp4_moe_kernel_format, +) + + +def _make_moe_config() -> FusedMoEConfig: + return FusedMoEConfig( + num_experts=2, + experts_per_token=1, + hidden_dim=16, + intermediate_size_per_partition=16, + num_local_experts=2, + num_logical_experts=2, + activation=MoEActivation.SILU, + device="cpu", + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + in_dtype=torch.bfloat16, + routing_method=RoutingMethodType.TopK, + max_num_tokens=16, + ) + + +def _make_experts( + *, + gemm1_alpha: float | None = None, + gemm1_beta: float | None = None, + gemm1_clamp_limit: float | None = None, +) -> FlashInferExperts: + quant_config = mxfp4_mxfp8_moe_quant_config( + w1_scale=torch.ones((2, 32, 1), dtype=torch.float8_e4m3fn), + w2_scale=torch.ones((2, 16, 1), dtype=torch.float8_e4m3fn), + gemm1_alpha=gemm1_alpha, + gemm1_beta=gemm1_beta, + gemm1_clamp_limit=gemm1_clamp_limit, + ) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return FlashInferExperts( + moe_config=_make_moe_config(), + quant_config=quant_config, + ) + + +def test_mxfp4_swiglu_parameters_stay_unset_without_quant_config() -> None: + experts = _make_experts() + + assert experts.gemm1_alpha is None + assert experts.gemm1_beta is None + assert experts.gemm1_clamp_limit is None + + +def test_mxfp4_swiglu_parameters_follow_quant_config() -> None: + experts = _make_experts( + gemm1_alpha=1.25, + gemm1_beta=0.75, + gemm1_clamp_limit=5.5, + ) + + torch.testing.assert_close(experts.gemm1_alpha, torch.tensor([1.25, 1.25])) + torch.testing.assert_close(experts.gemm1_beta, torch.tensor([0.75, 0.75])) + torch.testing.assert_close( + experts.gemm1_clamp_limit, + torch.tensor([5.5, 5.5]), + ) + + +def test_cutlass_mxfp8_kernel_format_converts_gate_up_layout(monkeypatch) -> None: + monkeypatch.setitem( + sys.modules, + "flashinfer", + types.SimpleNamespace(block_scale_interleave=lambda x: x.contiguous()), + ) + + num_experts = 1 + intermediate_size = 64 + hidden_size = 64 + packed_hidden_size = hidden_size // 2 + sf_block_size = 32 + + w13_weight = torch.arange( + num_experts * 2 * intermediate_size * packed_hidden_size, + dtype=torch.uint8, + ).reshape(num_experts, 2 * intermediate_size, packed_hidden_size) + w2_weight = torch.arange( + num_experts * hidden_size * (intermediate_size // 2), + dtype=torch.uint8, + ).reshape(num_experts, hidden_size, intermediate_size // 2) + w13_scale_u8 = torch.arange( + num_experts * 2 * intermediate_size * (hidden_size // sf_block_size), + dtype=torch.uint8, + ).reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size) + w2_scale_u8 = torch.arange( + num_experts * hidden_size * (intermediate_size // sf_block_size), + dtype=torch.uint8, + ).reshape(num_experts, hidden_size, intermediate_size // sf_block_size) + w13_bias = torch.arange( + num_experts * 2 * intermediate_size, + dtype=torch.bfloat16, + ).reshape(num_experts, 2 * intermediate_size) + w2_bias = torch.arange( + num_experts * hidden_size, + dtype=torch.bfloat16, + ).reshape(num_experts, hidden_size) + + ( + out_w13, + out_w2, + out_w13_scale, + out_w2_scale, + out_w13_bias, + out_w2_bias, + ) = convert_weight_to_mxfp4_moe_kernel_format( + mxfp4_backend=Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8, + layer=torch.nn.Module(), + w13_weight=w13_weight, + w2_weight=w2_weight, + w13_weight_scale=w13_scale_u8.view(torch.float8_e4m3fn), + w2_weight_scale=w2_scale_u8.view(torch.float8_e4m3fn), + w13_bias=w13_bias, + w2_bias=w2_bias, + ) + + expected_w13 = torch.cat( + [ + w13_weight[:, intermediate_size:, :], + w13_weight[:, :intermediate_size, :], + ], + dim=1, + ) + expected_w13_scale = torch.cat( + [ + w13_scale_u8[:, intermediate_size:, :], + w13_scale_u8[:, :intermediate_size, :], + ], + dim=1, + ) + expected_w13_bias = torch.cat( + [ + w13_bias[:, intermediate_size:], + w13_bias[:, :intermediate_size], + ], + dim=1, + ) + + assert out_w13.is_contiguous() + assert out_w2.is_contiguous() + torch.testing.assert_close(out_w13, expected_w13) + torch.testing.assert_close(out_w2, w2_weight) + torch.testing.assert_close(out_w13_scale, expected_w13_scale) + torch.testing.assert_close(out_w2_scale, w2_scale_u8) + torch.testing.assert_close(out_w13_bias, expected_w13_bias) + torch.testing.assert_close(out_w2_bias, w2_bias) diff --git a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py index 76cd15ff5a0c..cd0b7ac4a54f 100644 --- a/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/flashinfer_cutlass_moe.py @@ -103,20 +103,24 @@ def __init__( ) if quant_config.weight_quant_dtype == "mxfp4": - # This value is used specifically for gpt-oss, - # Need to revisit this for other models - self.gemm1_alpha = torch.tensor( - [1.702] * self.num_experts, dtype=torch.float32, device=self.device - ) - self.gemm1_beta = torch.tensor( - [1.0] * self.num_experts, dtype=torch.float32, device=self.device + self.gemm1_alpha = ( + torch.tensor( + [quant_config.gemm1_alpha] * self.num_experts, + dtype=torch.float32, + device=self.device, + ) + if quant_config.gemm1_alpha is not None + else None ) - if self.gemm1_clamp_limit is None: - self.gemm1_clamp_limit = torch.tensor( - [7.0] * self.num_experts, + self.gemm1_beta = ( + torch.tensor( + [quant_config.gemm1_beta] * self.num_experts, dtype=torch.float32, device=self.device, ) + if quant_config.gemm1_beta is not None + else None + ) if quant_config.quant_dtype == "mxfp8": self.fake_input_scale = torch.ones( self.num_experts, @@ -325,9 +329,6 @@ def apply( elif self.weight_quant_dtype == "mxfp4": assert self.w1_scale is not None and self.w2_scale is not None assert w1.is_contiguous() and w2.is_contiguous() - assert self.gemm1_alpha is not None - assert self.gemm1_beta is not None - assert self.gemm1_clamp_limit is not None assert topk_ids.is_contiguous() fc1_expert_biases = self.w1_bias diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index ab76cea1327b..e5deade5e244 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -1191,7 +1191,7 @@ def convert_weight_to_mxfp4_moe_kernel_format( ]: """Convert loaded weights into backend-specific kernel format. - Supports DeepGEMM, TRTLLM MXFP8, Triton and Marlin backends. + Supports DeepGEMM, TRTLLM MXFP8, CUTLASS MXFP8, Triton and Marlin backends. """ if mxfp4_backend == Mxfp4MoeBackend.DEEPGEMM_MXFP4: @@ -1246,6 +1246,47 @@ def convert_weight_to_mxfp4_moe_kernel_format( sf_block_size = 32 # mxfp4 block size + if mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8: + from flashinfer import block_scale_interleave + + w13_weight = w13_weight.data + w2_weight = w2_weight.data + w13_weight_scale = w13_weight_scale.data + w2_weight_scale = w2_weight_scale.data + + # FlashInfer CUTLASS expects [up, gate], while vLLM stores [gate, up]. + w1_weight = w13_weight[:, :intermediate_size, :] + w3_weight = w13_weight[:, intermediate_size:, :] + w13_weight = torch.cat([w3_weight, w1_weight], dim=1).contiguous() + + w1_scale = w13_weight_scale[:, :intermediate_size, :] + w3_scale = w13_weight_scale[:, intermediate_size:, :] + w13_weight_scale = torch.cat([w3_scale, w1_scale], dim=1).contiguous() + + if w13_bias is not None: + b1 = w13_bias[:, :intermediate_size] + b3 = w13_bias[:, intermediate_size:] + w13_bias = torch.cat([b3, b1], dim=1).contiguous() + + w13_scale_shape = w13_weight_scale.shape + w13_weight_scale = block_scale_interleave( + w13_weight_scale.view(torch.uint8) + ).reshape(w13_scale_shape) + + w2_scale_shape = w2_weight_scale.shape + w2_weight_scale = block_scale_interleave( + w2_weight_scale.view(torch.uint8) + ).reshape(w2_scale_shape) + + return ( + w13_weight, + w2_weight.contiguous(), + w13_weight_scale, + w2_weight_scale, + w13_bias, + w2_bias, + ) + if mxfp4_backend in TRTLLM_BACKENDS: assert _cache_permute_indices is not None from flashinfer.fp4_quantization import nvfp4_block_scale_interleave From 99a9f10e7add9c6c8564acf138558def04877636 Mon Sep 17 00:00:00 2001 From: carlosmolina0615 Date: Sun, 7 Jun 2026 12:57:44 +0200 Subject: [PATCH 099/131] deepseek_v4: route NVFP4-modelopt experts to ModelOptNvFp4FusedMoE DeepSeek-V4-Flash-NVFP4 (NVIDIA modelopt) sets expert_dtype=fp4 but its MoE experts are NVFP4 (weight_scale_2 + input_scale), not MXFP4. DeepseekV4FP8Config previously always used Mxfp4MoEMethod for fp4 experts, causing KeyError: experts.w13_input_scale on GB10 (sm_121). Detect moe_quant_algo==NVFP4 and use the existing ModelOptNvFp4FusedMoE for experts while keeping FP8 block for linear/attn. Adjust weights mapper to not apply the MXFP4 .scale rename to NVFP4 per-expert keys. (cherry picked from commit bf6a74e7eb06e199fc610c625211ac9d744ef529) (cherry picked from commit a34fff5da17603438d443dd0a2c35430694bff37) Signed-off-by: jasl --- vllm/envs.py | 26 +++++++++++++++++ .../model_executor/kernels/linear/__init__.py | 29 +++++++++++++++++-- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 2ec89025b6cd..962a747c9399 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -233,6 +233,7 @@ VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_HAS_FLASHINFER_CUBIN: bool = False + VLLM_NVFP4_GEMM_BACKEND: str | None = None VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_ALLREDUCE_USE_FLASHINFER: bool = False @@ -1673,6 +1674,31 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_HAS_FLASHINFER_CUBIN": lambda: bool( int(os.getenv("VLLM_HAS_FLASHINFER_CUBIN", "0")) ), + # Selects the GEMM backend used for NVFP4-quantized linear/MoE layers. + # Supported options: + # - "flashinfer-b12x": use flashinfer b12x (Blackwell 12.x) NVFP4 kernels + # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend + # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend + # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend + # - "marlin": use marlin GEMM backend (for GPUs without native FP4 support) + # - "emulation": + # use BF16/FP16 GEMM, dequantizing weights and running QDQ on activations. + # This is only meant for research purposes to run on devices where NVFP4 + # GEMM kernels are not available. + # - : automatically pick an available backend + "VLLM_NVFP4_GEMM_BACKEND": env_with_choices( + "VLLM_NVFP4_GEMM_BACKEND", + None, + [ + "flashinfer-b12x", + "flashinfer-cudnn", + "flashinfer-trtllm", + "flashinfer-cutlass", + "cutlass", + "marlin", + "emulation", + ], + ), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. # If set to 1, allows GC to run during capture. diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index 919d71fb8e8b..214110473651 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -400,8 +400,9 @@ def _filter_kernels_by_backend( _POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = { PlatformEnum.CUDA: [ # FlashInferB12xNvFp4LinearKernel excluded from auto-selection until - # upstream CUTLASS SM121 MMA op guard is resolved; use - # --linear-backend flashinfer_b12x to opt in explicitly. + # upstream CUTLASS SM121 MMA op guard is resolved; opt in explicitly + # via --linear-backend flashinfer_b12x or + # VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x. FlashInferCutlassNvFp4LinearKernel, CutlassNvFp4LinearKernel, MarlinNvFp4LinearKernel, @@ -831,6 +832,20 @@ def init_wfp8_a16_linear_kernel( ) +# Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes. This is an +# env-driven alternative to --linear-backend for forcing a specific NVFP4 +# linear kernel (e.g. VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x). +_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = { + "flashinfer-b12x": FlashInferB12xNvFp4LinearKernel, + "flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel, + "cutlass": CutlassNvFp4LinearKernel, + "marlin": MarlinNvFp4LinearKernel, + "flashinfer-trtllm": FlashInferTrtllmNvFp4LinearKernel, + "flashinfer-cudnn": FlashInferCudnnNvFp4LinearKernel, + "emulation": EmulationNvFp4LinearKernel, +} + + def init_nvfp4_linear_kernel(use_a16: bool = False) -> NvFp4LinearKernel: """Select and instantiate the best NVFP4 linear kernel for the current platform.""" @@ -870,6 +885,16 @@ def init_nvfp4_linear_kernel(use_a16: bool = False) -> NvFp4LinearKernel: reason, ) force_kernel = EmulationNvFp4LinearKernel + elif envs.VLLM_NVFP4_GEMM_BACKEND is not None: + # Env-driven override (alternative to --linear-backend). Maps a + # VLLM_NVFP4_GEMM_BACKEND value to a concrete kernel class. + backend_name = envs.VLLM_NVFP4_GEMM_BACKEND + force_kernel = _NVFP4_BACKEND_TO_KERNEL.get(backend_name) + if force_kernel is None: + raise ValueError( + f"Unknown VLLM_NVFP4_GEMM_BACKEND={backend_name!r}. " + f"Valid choices: {list(_NVFP4_BACKEND_TO_KERNEL.keys())}" + ) elif linear_backend == "auto" and use_a16: # Force a16 (Marlin) when running weight-only quantization. force_kernel = MarlinNvFp4LinearKernel From b4c1cab5fa7fe41cced808030a07d9f364f25a04 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 15 Jun 2026 19:48:16 +0800 Subject: [PATCH 100/131] deepseek_v4: allow flashinfer_cutlass NVFP4 MoE for SwiGLU-clamp models The flashinfer_cutlass NVFP4 MoE already supports SM12x (family 100|120) and applies the SwiGLU clamp (populates gemm1_clamp_limit, passes swiglu_limit to flashinfer_cutlass_fused_moe), but NVFP4_BACKENDS_WITH_CLAMP listed only FLASHINFER_TRTLLM (SM100-only). DeepSeek-V4-Flash-NVFP4 (swiglu_limit=10.0) was therefore unservable on SM12x. Add FLASHINFER_CUTLASS to the clamp set so --moe-backend flashinfer_cutlass serves NVFP4 on RTX SM120 / GB10 SM121 with no FlashInfer upgrade. Verified: loads (77.95 GiB), serves, FLASHINFER_CUTLASS selected. Signed-off-by: jasl (cherry picked from commit 401d3174c3e8dc8eddf9f54ab0ea3cbb693bd5b9) --- vllm/model_executor/layers/fused_moe/oracle/nvfp4.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index 93bc81c22be4..2f98e9285977 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -174,6 +174,7 @@ def select_nvfp4_moe_backend( NVFP4_BACKENDS_WITH_CLAMP = { NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTLASS, } if config.swiglu_limit is not None: From 2c547ec52894fd8881db54d27374bd6ab2b8ea55 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 16 Jun 2026 04:58:45 +0800 Subject: [PATCH 101/131] deepseek-v4: integrate upstream adaptive prefill chunk planning (#45061) Replace the fixed PREFILL_CHUNK_SIZE chunking + batch-wide workspace bound in the SM12x sparse-MLA prefill path with upstream's adaptive get_prefill_chunk_plan (vllm-project/vllm#45061): pack as many requests as fit the workspace-area bound per chunk and allocate the kv workspace per-chunk at this chunk's compressed+gather width (chunk_M) instead of the batch-wide worst case. Preserves the SM12x indexed-D512 split/chunked prefill paths and the Triton sparse-MLA dispatch. Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/flashmla.py | 111 ++++++++++--------- vllm/v1/attention/backends/mla/sparse_swa.py | 97 +++++++++++++++- 2 files changed, 150 insertions(+), 58 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index dba35cba992b..cb0dbcc083b4 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -904,24 +904,24 @@ def _forward_prefill( assert attn_metadata is not None topk_indices = attn_metadata.c128a_prefill_topk_indices top_k = topk_indices.shape[-1] - # Compressed region must fit the full compressed pool (seq_len // - # compress_ratio), not just top_k. top_k bounds how many indices - # the indexer selects, not the pool size it indexes into. - N = int((seq_lens_cpu // self.compress_ratio).max().item()) else: # NOTE(woosuk): topk_indices will not be used for SWA-only layers. assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[num_decode_tokens:] top_k = 0 - N = 0 - M = N + int(gather_lens_cpu.max().item()) - chunk_size_const = self.PREFILL_CHUNK_SIZE - num_chunks = (num_prefills + chunk_size_const - 1) // chunk_size_const + # Adaptive prefill chunk plan (#45061): pack as many requests as fit the + # workspace-area bound into each chunk, with per-chunk compressed (chunk_N) + # and total (chunk_M) widths. Replaces the fixed PREFILL_CHUNK_SIZE + # chunking with batch-wide M/N. + chunk_plan = swa_metadata.get_prefill_chunk_plan( + compress_ratio=int(self.compress_ratio), + prefill_chunk_size=self.PREFILL_CHUNK_SIZE, + ) + assert chunk_plan, "prefill chunk plan must be non-empty when num_prefills > 0" + max_query_chunk_tokens = 0 - for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * chunk_size_const - chunk_end = min(chunk_start + chunk_size_const, num_prefills) + for chunk_start, chunk_end, _chunk_n, _chunk_m in chunk_plan: query_start = ( query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base ) @@ -937,6 +937,7 @@ def _forward_prefill( triton_sparse_mla_enabled = is_triton_sparse_mla_enabled(q.device) indexed_d512_split_prefill = False indexed_d512_chunked_prefill = False + extra_specs: list[tuple[tuple[int, ...], torch.dtype]] = [] if triton_sparse_mla_enabled: query_chunk_size = min( max_query_chunk_tokens, @@ -959,7 +960,6 @@ def _forward_prefill( max_prefill_seq_len=int(seq_lens_cpu.max().item()), swa_only=swa_only, ) - extra_specs: list[tuple[tuple[int, ...], torch.dtype]] = [] if indexed_d512_split_prefill: extra_specs.append( ( @@ -986,44 +986,51 @@ def _forward_prefill( ), ) ) - ( - kv, - combined_indices_buffer, - combined_lens_buffer, - max_score_buffer, - denom_buffer, - output_buffer, - *extra_state_buffers, - ) = workspace_manager.get_simultaneous( - ((chunk_size_const, M, q.shape[-1]), torch.bfloat16), - ((max_query_chunk_tokens, combined_topk), torch.int32), - ((max_query_chunk_tokens,), torch.int32), - ((query_chunk_size, self.n_local_heads), torch.float32), - ((query_chunk_size, self.n_local_heads), torch.float32), - ((query_chunk_size, self.n_local_heads, q.shape[-1]), torch.float32), - *extra_specs, - ) - prefill_state_buffers = ( - max_score_buffer, - denom_buffer, - output_buffer, - *extra_state_buffers, - ) - else: - ( - kv, - combined_indices_buffer, - combined_lens_buffer, - ) = workspace_manager.get_simultaneous( - ((chunk_size_const, M, q.shape[-1]), torch.bfloat16), - ((max_query_chunk_tokens, combined_topk), torch.int32), - ((max_query_chunk_tokens,), torch.int32), - ) - prefill_state_buffers = None - for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * chunk_size_const - chunk_end = min(chunk_start + chunk_size_const, num_prefills) + + # Per-chunk workspace allocation (#45061): the kv buffer width is this + # chunk's compressed+gather width (chunk_m), keeping the area bounded by + # the planner's max_workspace_area instead of the batch-wide worst case. + for chunk_start, chunk_end, chunk_n, chunk_m in chunk_plan: chunk_size = chunk_end - chunk_start + if triton_sparse_mla_enabled: + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + max_score_buffer, + denom_buffer, + output_buffer, + *extra_state_buffers, + ) = workspace_manager.get_simultaneous( + ((chunk_size, chunk_m, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ((query_chunk_size, self.n_local_heads), torch.float32), + ((query_chunk_size, self.n_local_heads), torch.float32), + ( + (query_chunk_size, self.n_local_heads, q.shape[-1]), + torch.float32, + ), + *extra_specs, + ) + prefill_state_buffers = ( + max_score_buffer, + denom_buffer, + output_buffer, + *extra_state_buffers, + ) + else: + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + ) = workspace_manager.get_simultaneous( + ((chunk_size, chunk_m, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ) + prefill_state_buffers = None + if not swa_only: # Gather compressed KV assert attn_metadata is not None @@ -1047,7 +1054,7 @@ def _forward_prefill( gather_lens=gather_lens[chunk_start:chunk_end], block_table=swa_block_table[chunk_start:chunk_end], block_size=swa_metadata.block_size, - offset=N, + offset=chunk_n, ) # Combine the topk indices and SWA indices for gathered KV cache @@ -1068,8 +1075,8 @@ def _forward_prefill( self.window_size, self.compress_ratio, top_k, - M, - N, + chunk_m, + chunk_n, combined_indices=combined_indices_buffer, combined_lens=combined_lens_buffer, ) diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 4455cba25e90..a4511096116f 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -9,6 +9,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.platforms import current_platform from vllm.triton_utils import tl, triton +from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import ( AttentionBackend, AttentionCGSupport, @@ -178,6 +179,11 @@ class DeepseekSparseSWAMetadata: prefill_gather_lens: torch.Tensor | None = None prefill_seq_lens_cpu: torch.Tensor | None = None prefill_gather_lens_cpu: torch.Tensor | None = None + # Inputs to the adaptive prefill chunk planner (#45061). + prefill_query_lens_cpu: torch.Tensor | None = None + prefill_window_size: int = 0 + prefill_max_model_len: int = 0 + prefill_max_num_batched_tokens: int = 0 # Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta # per present DeepseekV4 layer type, shared across all ~60 layers of that type @@ -196,6 +202,79 @@ class DeepseekSparseSWAMetadata: default_factory=dict ) + def get_prefill_chunk_plan( + self, compress_ratio: int, prefill_chunk_size: int + ) -> list[tuple[int, int, int, int]]: + if self.num_prefills == 0: + return [] + + assert self.prefill_seq_lens_cpu is not None + assert self.prefill_query_lens_cpu is not None + + # query_len <= max_num_batched_tokens and + # gather_len = query_len + min(prefix_len, window_size - 1), so the + # worst-case gathered width is bounded by + # max_num_batched_tokens + window_size - 1. The compressed prefix pool + # is bounded by ceil(max_model_len / compress_ratio). + max_workspace_area = prefill_chunk_size * ( + ( + 0 + if compress_ratio <= 1 + else cdiv(self.prefill_max_model_len, compress_ratio) + ) + + self.prefill_window_size + + self.prefill_max_num_batched_tokens + ) + prefix_lens_cpu = self.prefill_seq_lens_cpu - self.prefill_query_lens_cpu + gather_lens_cpu = self.prefill_query_lens_cpu + torch.clamp( + prefix_lens_cpu, min=0, max=self.prefill_window_size - 1 + ) + compressed_lens_cpu = ( + torch.zeros_like(self.prefill_seq_lens_cpu) + if compress_ratio <= 1 + else torch.div( + self.prefill_seq_lens_cpu, + compress_ratio, + rounding_mode="floor", + ) + ) + + chunk_plan: list[tuple[int, int, int, int]] = [] + chunk_start = 0 + while chunk_start < self.num_prefills: + chunk_max_compressed = int(compressed_lens_cpu[chunk_start].item()) + chunk_max_gather = int(gather_lens_cpu[chunk_start].item()) + chunk_end = chunk_start + 1 + + while chunk_end < self.num_prefills: + candidate_max_compressed = max( + chunk_max_compressed, + int(compressed_lens_cpu[chunk_end].item()), + ) + candidate_max_gather = max( + chunk_max_gather, + int(gather_lens_cpu[chunk_end].item()), + ) + candidate_width = candidate_max_compressed + candidate_max_gather + candidate_area = (chunk_end - chunk_start + 1) * candidate_width + if candidate_area > max_workspace_area: + break + chunk_max_compressed = candidate_max_compressed + chunk_max_gather = candidate_max_gather + chunk_end += 1 + + chunk_plan.append( + ( + chunk_start, + chunk_end, + chunk_max_compressed, + chunk_max_compressed + chunk_max_gather, + ) + ) + chunk_start = chunk_end + + return chunk_plan + class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder): """Builds metadata for DeepseekV4 SWA cache. @@ -231,6 +310,10 @@ def __init__(self, *args, **kwargs): hf_config = self.vllm_config.model_config.hf_config assert hasattr(hf_config, "sliding_window") self.window_size = hf_config.sliding_window + self.max_model_len = self.vllm_config.model_config.max_model_len + self.max_num_batched_tokens = ( + self.vllm_config.scheduler_config.max_num_batched_tokens + ) # Detect which DeepseekV4 layer types this model uses so we only build a # FlashMLA tile-scheduler plan for types that will actually be called. @@ -353,7 +436,7 @@ def build( tile_sched_swaonly=tile_sched[_LAYER_TYPE_SWAONLY], tile_sched_c4a=tile_sched[_LAYER_TYPE_C4A], tile_sched_c128a=tile_sched[_LAYER_TYPE_C128A], - **deepseek_v4_fields, + **deepseek_v4_fields, # type: ignore[arg-type] ) def build_tile_scheduler( @@ -399,7 +482,7 @@ def _build_deepseek_v4_metadata( query_start_loc: torch.Tensor, query_start_loc_cpu: torch.Tensor, seq_lens_cpu_upper_bound: torch.Tensor | None, - ) -> dict[str, torch.Tensor | None]: + ) -> dict[str, torch.Tensor | int | None]: """Pre-compute DeepseekV4 prefill metadata during the metadata build phase. Returns a dict of keyword arguments to pass to the @@ -408,7 +491,7 @@ def _build_deepseek_v4_metadata( Note: C128A topk indices are computed by the FlashMLASparse builder (which owns the C128A block_table), not here. """ - result: dict[str, torch.Tensor | None] = {} + result: dict[str, torch.Tensor | int | None] = {} # --- Prefill query metadata (single Triton kernel + CPU slicing) --- if num_prefills > 0: @@ -431,9 +514,7 @@ def _build_deepseek_v4_metadata( num_decodes : num_decodes + num_prefills ] query_lens_cpu = ( - query_start_loc_cpu[ - num_decodes + 1 : num_decodes + num_prefills + 1 - ] + query_start_loc_cpu[num_decodes + 1 : num_decodes + num_prefills + 1] - query_start_loc_cpu[num_decodes : num_decodes + num_prefills] ) prefix_lens_cpu = prefill_seq_lens_cpu - query_lens_cpu @@ -446,6 +527,10 @@ def _build_deepseek_v4_metadata( result["prefill_gather_lens"] = pfx_gather_lens result["prefill_seq_lens_cpu"] = prefill_seq_lens_cpu result["prefill_gather_lens_cpu"] = prefill_gather_lens_cpu + result["prefill_query_lens_cpu"] = query_lens_cpu.to(dtype=torch.int32) + result["prefill_window_size"] = self.window_size + result["prefill_max_model_len"] = self.max_model_len + result["prefill_max_num_batched_tokens"] = self.max_num_batched_tokens return result From a1135c0cbc73de21ec866b8b7d7388a3943bd7d6 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 16 Jun 2026 05:08:38 +0800 Subject: [PATCH 102/131] deepseek-v4: guard sparse MLA decode metadata against None (mypy) The rebase onto upstream/main surfaced 9 mypy errors in the sparse-MLA decode kernels where decode_swa_lens / decode_swa_indices / seq_lens (typed torch.Tensor | None) are indexed without a None-guard. They were uncaught because git rebase skips pre-commit hooks. The fields are unconditionally populated when num_decode_tokens > 0 (the only path that reaches the decode kernels), so assert-guard them. Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/flashmla.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index cb0dbcc083b4..5fb3d9137c61 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -351,6 +351,11 @@ def _forward_sparse_mla_swa_decode_triton( num_decode_tokens = swa_metadata.num_decode_tokens mtp_decode = num_decode_tokens != num_decodes + # Decode metadata is unconditionally populated when num_decode_tokens > 0, + # which is the only path that reaches the decode kernels. + assert swa_metadata.decode_swa_lens is not None + assert swa_metadata.decode_swa_indices is not None + assert swa_metadata.seq_lens is not None swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] max_swa_len = swa_metadata.decode_swa_indices.shape[-1] @@ -432,6 +437,11 @@ def _forward_sparse_mla_compressed_decode_triton( num_decode_tokens = swa_metadata.num_decode_tokens mtp_decode = num_decode_tokens != num_decodes + # Decode metadata is unconditionally populated when num_decode_tokens > 0, + # which is the only path that reaches the decode kernels. + assert swa_metadata.decode_swa_lens is not None + assert swa_metadata.decode_swa_indices is not None + assert swa_metadata.seq_lens is not None max_swa_len = swa_metadata.decode_swa_indices.shape[-1] compressed_block_size = attn_metadata.block_size // layer.compress_ratio compressed_topk = topk_indices.shape[-1] From 59c7918b16a0f76b7afef024599ced639c6c6ce1 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 16 Jun 2026 05:23:06 +0800 Subject: [PATCH 103/131] chore: restore mypy + ruff cleanliness after upstream rebase git rebase replays commits without running pre-commit hooks, so the rebased branch carried pre-existing type-safety gaps and ruff-format drift in our changed files (surfaced under the newer upstream config). Fixes: - kernel_warmup: drop the phantom _disable_sparse_mla_prefill_stats reference (the symbol was never defined anywhere; the stats-disable wrapper collapses to a direct warmup call) and type: ignore the intentional SimpleNamespace warmup batch for apply_grammar_bitmask - single_type_kv_cache_manager: access MLA-spec subtype attributes via getattr (mypy-safe, identical runtime) - llm_base_proposer: assert the draft temperature tensor is non-None (it is used unconditionally below) - fused_moe: annotate config_file_paths: list[str] - test_deepseek_v4_mega_moe: annotate the mixed-type calls list - ruff-format drift across the touched files No runtime behavior change except the warmup phantom, which previously raised ImportError when VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH was set. Signed-off-by: jasl --- .../compile/passes/test_functionalization.py | 4 +--- ...st_deepseek_v4_flashmla_decode_dispatch.py | 4 +--- tests/models/test_deepseek_v4_mega_moe.py | 2 +- .../test_deepseekv4_reasoning_parser.py | 12 ++++------ .../test_indexer_deepseek_v4_slot_mapping.py | 4 +--- .../test_sm120_deepgemm_fallbacks.py | 8 ++----- tests/v1/core/test_scheduler.py | 12 ++++++---- tests/v1/spec_decode/test_mtp.py | 15 ++++-------- .../passes/utility/fix_functionalization.py | 8 ++----- vllm/config/vllm.py | 4 +--- vllm/envs.py | 7 +++--- .../layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/warmup/kernel_warmup.py | 24 +++++++------------ vllm/models/deepseek_v4/nvidia/model.py | 3 +-- .../nvidia/ops/sm12x_deep_gemm_fallbacks.py | 4 +--- .../deepseek_v4/nvidia/ops/sm12x_mqa.py | 4 +--- .../reasoning/deepseek_v4_reasoning_parser.py | 8 ++----- .../backends/mla/sparse_mla_kernels.py | 11 +++------ vllm/v1/core/block_pool.py | 3 +-- vllm/v1/core/kv_cache_coordinator.py | 4 +--- vllm/v1/core/sched/scheduler.py | 7 ++---- vllm/v1/core/single_type_kv_cache_manager.py | 6 ++--- vllm/v1/spec_decode/llm_base_proposer.py | 1 + 23 files changed, 54 insertions(+), 103 deletions(-) diff --git a/tests/compile/passes/test_functionalization.py b/tests/compile/passes/test_functionalization.py index 6369cb8c980f..3feb90fe0c56 100644 --- a/tests/compile/passes/test_functionalization.py +++ b/tests/compile/passes/test_functionalization.py @@ -289,9 +289,7 @@ def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake( def forward( self, q: torch.Tensor, kv: torch.Tensor, k_cache: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( - q, kv, k_cache - ) + torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(q, kv, k_cache) return q, k_cache def example_inputs(self, num_tokens=32, hidden_size=128): diff --git a/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py b/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py index 06e034f6b01c..0955306a0f4b 100644 --- a/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py +++ b/tests/model_executor/test_deepseek_v4_flashmla_decode_dispatch.py @@ -49,9 +49,7 @@ def fake_decode( swa_metadata, output, ): - calls.append( - (layer, q.shape, swa_k_cache.shape, swa_metadata, output.shape) - ) + calls.append((layer, q.shape, swa_k_cache.shape, swa_metadata, output.shape)) monkeypatch.setattr(flashmla_mod, "is_triton_sparse_mla_enabled", lambda _: True) monkeypatch.setattr( diff --git a/tests/models/test_deepseek_v4_mega_moe.py b/tests/models/test_deepseek_v4_mega_moe.py index ca0ec781e269..93a68377e310 100644 --- a/tests/models/test_deepseek_v4_mega_moe.py +++ b/tests/models/test_deepseek_v4_mega_moe.py @@ -125,7 +125,7 @@ def test_deepseek_v4_mega_moe_finalize_uses_deep_gemm_wrapper(monkeypatch): ) transformed = (object(), object()) - calls = [] + calls: list[object] = [] def fake_runtime_check(self): calls.append("runtime_check") diff --git a/tests/reasoning/test_deepseekv4_reasoning_parser.py b/tests/reasoning/test_deepseekv4_reasoning_parser.py index 2dcaf45202db..fd3dff449f80 100644 --- a/tests/reasoning/test_deepseekv4_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv4_reasoning_parser.py @@ -83,12 +83,8 @@ def test_registration_resolves_to_v4_class(): ({}, IdentityReasoningParser), ], ) -def test_dispatch_based_on_thinking_kwarg( - tokenizer, thinking_kwargs, expected_inner -): - parser = DeepSeekV4ReasoningParser( - tokenizer, chat_template_kwargs=thinking_kwargs - ) +def test_dispatch_based_on_thinking_kwarg(tokenizer, thinking_kwargs, expected_inner): + parser = DeepSeekV4ReasoningParser(tokenizer, chat_template_kwargs=thinking_kwargs) assert isinstance(parser._parser, expected_inner) @@ -195,7 +191,7 @@ def test_implicit_end_marker_in_isolated_delta(parser): def test_implicit_end_marker_within_delta_split(parser): """Marker appears partway through a delta — split it at the boundary.""" - delta_text = f"tail of reasoning{DSML_MARKER}\n<|DSML|invoke name=\"w\"" + delta_text = f'tail of reasoning{DSML_MARKER}\n<|DSML|invoke name="w"' delta = parser.extract_reasoning_streaming( previous_text="head ", current_text=f"head {delta_text}", @@ -206,7 +202,7 @@ def test_implicit_end_marker_within_delta_split(parser): ) assert delta is not None assert delta.reasoning == "tail of reasoning" - assert delta.content == f"{DSML_MARKER}\n<|DSML|invoke name=\"w\"" + assert delta.content == f'{DSML_MARKER}\n<|DSML|invoke name="w"' assert parser._implicit_end_seen is True diff --git a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py index b0d4b542ec61..eb535598fda6 100644 --- a/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py +++ b/tests/v1/attention/test_indexer_deepseek_v4_slot_mapping.py @@ -36,9 +36,7 @@ def fake_split_decodes_and_prefills( require_uniform=False, treat_short_extends_as_decodes=True, ): - captured["treat_short_extends_as_decodes"] = ( - treat_short_extends_as_decodes - ) + captured["treat_short_extends_as_decodes"] = treat_short_extends_as_decodes raise RuntimeError("stop after split_decodes_and_prefills") monkeypatch.setattr( diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py index 642bc7ef7fd5..178983449b2c 100644 --- a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -194,9 +194,7 @@ def test_sm120_mqa_direct_topk_uses_triton_logits_when_logits_fit( kv_fp8 = (kv * kv_scale.reciprocal()[:, None]).to(torch.float8_e4m3fn) weights = torch.randn(num_q, num_heads, device="cuda", dtype=torch.float32) cu_seqlen_ks = torch.arange(num_q, device="cuda", dtype=torch.int32) % 3 - cu_seqlen_ke = torch.full( - (num_q,), seq_len_kv, device="cuda", dtype=torch.int32 - ) + cu_seqlen_ke = torch.full((num_q,), seq_len_kv, device="cuda", dtype=torch.int32) out = torch.empty(num_q, topk_tokens, device="cuda", dtype=torch.int32) original_triton = sm12x_mqa.fp8_mqa_logits_triton @@ -274,9 +272,7 @@ def test_sm120_mqa_direct_topk_uses_triton_chunks_when_logits_do_not_fit( kv_fp8 = (kv * kv_scale.reciprocal()[:, None]).to(torch.float8_e4m3fn) weights = torch.randn(num_q, num_heads, device="cuda", dtype=torch.float32) cu_seqlen_ks = torch.arange(num_q, device="cuda", dtype=torch.int32) % 4 - cu_seqlen_ke = torch.full( - (num_q,), seq_len_kv, device="cuda", dtype=torch.int32 - ) + cu_seqlen_ke = torch.full((num_q,), seq_len_kv, device="cuda", dtype=torch.int32) out = torch.empty(num_q, topk_tokens, device="cuda", dtype=torch.int32) original_triton = sm12x_mqa.fp8_mqa_logits_triton diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 45cee40524e5..4067c1a9f027 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1339,10 +1339,14 @@ def test_running_very_long_prefill_defers_to_later_running_decode(): req_id_to_index = {} for index, req_id in enumerate(mixed_output.num_scheduled_tokens): req_id_to_index[req_id] = index - if req_id == short_req.request_id and ( - short_req.num_computed_tokens - + mixed_output.num_scheduled_tokens[req_id] - ) >= short_req.num_prompt_tokens: + if ( + req_id == short_req.request_id + and ( + short_req.num_computed_tokens + + mixed_output.num_scheduled_tokens[req_id] + ) + >= short_req.num_prompt_tokens + ): sampled_token_ids.append([0]) else: sampled_token_ids.append([]) diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index b4056fb67567..fa5416ae3edb 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -318,9 +318,7 @@ def test_mtp_propose_random_sampling_records_draft_probs(): assert torch.allclose(proposer._last_draft_probs, expected_probs) # ``take_last_draft_probs`` is the upstream-side accessor that # ``GPUModelRunner`` uses to plumb probs into the rejection sampler. - assert torch.equal( - proposer.take_last_draft_probs(), proposer._last_draft_probs - ) + assert torch.equal(proposer.take_last_draft_probs(), proposer._last_draft_probs) def test_mtp_sequential_drafting_passes_spec_step_indices(): @@ -395,8 +393,7 @@ def logits_for_token(token_id: int): for call in model_mock.compute_logits.call_args_list ] == [0, 1] assert [ - call.kwargs.get("spec_step_idx", 0) - for call in model_mock.call_args_list + call.kwargs.get("spec_step_idx", 0) for call in model_mock.call_args_list ] == [0, 1] @@ -492,13 +489,9 @@ def test_mtp_parallel_drafting_random_sampling_records_draft_probs(): assert result.shape == (batch_size, num_spec_tokens) assert proposer._last_draft_probs is not None - assert proposer._last_draft_probs.shape == ( - batch_size, num_spec_tokens, vocab_size - ) + assert proposer._last_draft_probs.shape == (batch_size, num_spec_tokens, vocab_size) assert torch.allclose( proposer._last_draft_probs, torch.softmax(logits, dim=-1).view(batch_size, num_spec_tokens, vocab_size), ) - assert torch.equal( - proposer.take_last_draft_probs(), proposer._last_draft_probs - ) + assert torch.equal(proposer.take_last_draft_probs(), proposer._last_draft_probs) diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py index 544a99afd2a7..75994ba58039 100644 --- a/vllm/compilation/passes/utility/fix_functionalization.py +++ b/vllm/compilation/passes/utility/fix_functionalization.py @@ -45,15 +45,11 @@ def __call__(self, graph: torch.fx.Graph) -> None: rope_targets.append( torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default ) - if hasattr( - torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert" - ): + if hasattr(torch.ops._C, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert"): fused_deepseek_v4_mla_targets.append( torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default ) - if hasattr( - torch.ops.vllm, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert" - ): + if hasattr(torch.ops.vllm, "fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert"): fused_deepseek_v4_mla_targets.append( torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 634f0e9281ec..4d17ed6e91c8 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1104,9 +1104,7 @@ def __post_init__(self): self.model_config is not None and "VLLM_USE_BREAKABLE_CUDAGRAPH" not in os.environ and ( - _should_auto_enable_deepseek_v4_breakable_cudagraph( - self.model_config - ) + _should_auto_enable_deepseek_v4_breakable_cudagraph(self.model_config) or any( a in ( diff --git a/vllm/envs.py b/vllm/envs.py index 962a747c9399..11c06e5ccfa6 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1623,10 +1623,9 @@ def _resolve_rust_frontend_path() -> str | None: os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1") ), # Enforce function parameter schemas in structural-tag based tool calling. - "VLLM_ENFORCE_STRICT_TOOL_CALLING": lambda: os.getenv( - "VLLM_ENFORCE_STRICT_TOOL_CALLING", "True" - ).lower() - in ("true", "1"), + "VLLM_ENFORCE_STRICT_TOOL_CALLING": lambda: ( + os.getenv("VLLM_ENFORCE_STRICT_TOOL_CALLING", "True").lower() in ("true", "1") + ), # Control the max chunk bytes (in MB) for the rpc message queue. # Object larger than this threshold will be broadcast to worker # processes via zmq. diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 5b11243edeed..3636c7a6dcd1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1074,7 +1074,7 @@ def get_moe_configs( block_shape = [block_n, block_k] if block_n and block_k else None json_file_names = _get_config_file_names(E, N, dtype, block_shape) - config_file_paths = [] + config_file_paths: list[str] = [] # note that we prioritize user defined config user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index b23e13668048..3a5663f47f9d 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -111,8 +111,7 @@ def _deepseek_v4_mtp_uniform_decode_warmup_requests( _DEEPSEEK_V4_MTP_UNIFORM_DECODE_MAX_WARMUP_REQUESTS, ) candidates = sorted( - set(_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS) - | {max_warmup_reqs} + set(_DEEPSEEK_V4_MTP_UNIFORM_DECODE_WARMUP_REQUESTS) | {max_warmup_reqs} ) return tuple(reqs for reqs in candidates if reqs <= max_warmup_reqs) @@ -149,9 +148,9 @@ def _deepseek_v4_slot_mapping_warmup(runner: "GPUModelRunner") -> None: ) if hasattr(runner, "positions"): - saved_positions: torch.Tensor | None = ( - runner.positions[:num_tokens].clone() - ) + saved_positions: torch.Tensor | None = runner.positions[ + :num_tokens + ].clone() runner.positions[:num_tokens].copy_(positions_source) positions = runner.positions[:num_tokens] else: @@ -197,7 +196,10 @@ def _deepseek_v4_structured_output_bitmask_warmup( ) input_batch = SimpleNamespace(req_ids=req_ids) apply_grammar_bitmask( - SchedulerOutput.make_empty(), grammar_output, input_batch, logits + SchedulerOutput.make_empty(), + grammar_output, + input_batch, # type: ignore[arg-type] + logits, ) @@ -489,15 +491,7 @@ def kernel_warmup(worker: "Worker"): minimax_m3_msa_warmup(worker) - if envs.VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH: - from vllm.models.deepseek_v4.nvidia.flashmla import ( - _disable_sparse_mla_prefill_stats, - ) - - with _disable_sparse_mla_prefill_stats(): - _deepseek_v4_sparse_mla_attention_warmup(worker) - else: - _deepseek_v4_sparse_mla_attention_warmup(worker) + _deepseek_v4_sparse_mla_attention_warmup(worker) _deepseek_v4_request_prep_warmup(worker) enable_flashinfer_autotune = ( diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 54492781b6d4..90be150fb516 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -687,8 +687,7 @@ def first_defined(*values): or self.n_local_physical_experts is None ): raise AttributeError( - "DeepseekV4MoE FusedMoE metadata is incomplete after " - "construction." + "DeepseekV4MoE FusedMoE metadata is incomplete after construction." ) self.n_local_experts = self.n_local_physical_experts self.n_redundant_experts = self.n_physical_experts - self.n_logical_experts diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py index 64351957917d..d2f157be884f 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -262,9 +262,7 @@ def _fp8_mqa_logits_topk_triton( select_k, ) selected.add_(cu_seqlen_ks[:, None]) - valid = (selected >= cu_seqlen_ks[:, None]) & ( - selected < cu_seqlen_ke[:, None] - ) + valid = (selected >= cu_seqlen_ks[:, None]) & (selected < cu_seqlen_ke[:, None]) selected.masked_fill_(~valid, -1) else: values, indices = torch.topk(logits, select_k, dim=1) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index a8d0c8ce58ce..3116a8d90c98 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -20,9 +20,7 @@ def _view_packed_fp8_paged_mqa_kv_cache( elif kv_cache.dim() == 4: num_blocks, block_size, num_kv_heads, head_dim_with_scale = kv_cache.shape else: - raise ValueError( - f"Expected 3D or 4D kv_cache, got {kv_cache.dim()} dimensions" - ) + raise ValueError(f"Expected 3D or 4D kv_cache, got {kv_cache.dim()} dimensions") if num_kv_heads != 1: raise ValueError(f"Expected one KV head, got {num_kv_heads}") diff --git a/vllm/reasoning/deepseek_v4_reasoning_parser.py b/vllm/reasoning/deepseek_v4_reasoning_parser.py index 8b19d901da7a..fe38ca84c03b 100644 --- a/vllm/reasoning/deepseek_v4_reasoning_parser.py +++ b/vllm/reasoning/deepseek_v4_reasoning_parser.py @@ -21,9 +21,7 @@ # invoke a tool. The reasoning parser treats these as an implicit # end-of-reasoning when the explicit token is missing, which the # model occasionally fails to emit at long context. -_DSV4_TOOL_CALL_IMPLICIT_END_MARKERS: tuple[str, ...] = ( - "<|DSML|tool_calls>", -) +_DSV4_TOOL_CALL_IMPLICIT_END_MARKERS: tuple[str, ...] = ("<|DSML|tool_calls>",) class DeepSeekV4ThinkingReasoningParser(DeepSeekR1ReasoningParser): @@ -259,9 +257,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs): self._parser: ReasoningParser if thinking: - self._parser = DeepSeekV4ThinkingReasoningParser( - tokenizer, *args, **kwargs - ) + self._parser = DeepSeekV4ThinkingReasoningParser(tokenizer, *args, **kwargs) else: self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py index f0964e46d3a9..5dfa75da3210 100644 --- a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -661,6 +661,7 @@ def finish_materialized_sparse_mla_scores_with_sink( num_tokens, _, num_candidates = scores.shape head_dim = kv.shape[2] + # Clamp BLOCK_D to the smallest power-of-2 >= head_dim (among the allowed # {64, 128, 256, 512}). Without this, a caller-supplied value_block_size # larger than head_dim wastes work on masked-off positions — e.g. DSv4 @@ -1256,7 +1257,6 @@ def accumulate_indexed_sparse_mla_attention_chunk( ) - @triton.autotune( configs=[ triton.Config({}, num_warps=4, num_stages=2), @@ -1677,7 +1677,6 @@ def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( ) - @triton.jit def _indexed_d512_split_score_kernel( q_ptr, @@ -1827,9 +1826,7 @@ def _indexed_d512_split_value_kernel( dim_mask = dim_offsets < head_dim valid_len = tl.load(lens_ptr + token_idx) max_score = tl.load( - max_score_ptr - + token_idx * stride_state_t - + head_offsets * stride_state_h, + max_score_ptr + token_idx * stride_state_t + head_offsets * stride_state_h, mask=head_mask, other=0.0, ).to(tl.float32) @@ -2064,9 +2061,7 @@ def _indexed_d512_chunked_merge_acc_kernel( mask=head_mask[:, None] & dim_mask[None, :], other=0.0, ).to(tl.float32) - merged_acc = ( - running_acc * running_scale[:, None] + chunk_acc * chunk_scale[:, None] - ) + merged_acc = running_acc * running_scale[:, None] + chunk_acc * chunk_scale[:, None] tl.store( acc_ptr + token_idx * stride_acc_t diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 0a69ac2e1cfc..dac8914b4a22 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -391,8 +391,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: return False evicted = ( - self.cached_block_hash_to_block.pop(block_hash, block.block_id) - is not None + self.cached_block_hash_to_block.pop(block_hash, block.block_id) is not None ) block.reset_hash() diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index d741d7e67f5a..96d74d0bb950 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -627,9 +627,7 @@ def verify_and_split_kv_cache_groups(self) -> None: # Attention-group indices (into ``self.attention_groups``) that # contain at least one EAGLE/MTP KV cache group. self.eagle_attn_group_indices: set[int] = { - i - for i, group in enumerate(self.attention_groups) - if group.use_eagle + i for i, group in enumerate(self.attention_groups) if group.use_eagle } def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 57ea575b4517..5f0b926ba739 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -457,9 +457,7 @@ def _limit_mixed_decode_prefill_chunk( has_decode_pressure = ( self._has_scheduled_decode(scheduled_running_reqs) or has_pending_decode ) - has_prefill_competition = ( - has_waiting_requests or prefill_budget_partitions > 1 - ) + has_prefill_competition = has_waiting_requests or prefill_budget_partitions > 1 if not has_decode_pressure and not has_prefill_competition: return num_new_tokens @@ -479,8 +477,7 @@ def _limit_mixed_decode_prefill_chunk( elif remaining_prefill > very_long_prefill_threshold: mixed_prefill_budget = max( 1, - self.max_num_scheduled_tokens - // max(2, prefill_budget_partitions), + self.max_num_scheduled_tokens // max(2, prefill_budget_partitions), ) else: mixed_prefill_budget = max(1, (self.max_num_scheduled_tokens * 3) // 4) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 9ca7155a65fe..0105ad0f482f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -701,9 +701,9 @@ def _should_protect_prompt_blocks(self) -> bool: # 3. ``compress_ratio > 1``: any compressed MLA cache (today # only DSv4 sets ``compress_ratio > 1``; V3.2 keeps it at 1). return ( - self.kv_cache_spec.model_version == "deepseek_v4" - or self.kv_cache_spec.cache_dtype_str == "fp8_ds_mla" - or self.kv_cache_spec.compress_ratio > 1 + getattr(self.kv_cache_spec, "model_version", None) == "deepseek_v4" + or getattr(self.kv_cache_spec, "cache_dtype_str", None) == "fp8_ds_mla" + or getattr(self.kv_cache_spec, "compress_ratio", 1) > 1 ) def _max_protected_prompt_blocks(self) -> int | None: diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py index b76c1f98160b..c512909415e6 100644 --- a/vllm/v1/spec_decode/llm_base_proposer.py +++ b/vllm/v1/spec_decode/llm_base_proposer.py @@ -1798,6 +1798,7 @@ def compute_probs_and_sample_next_token( sampling_metadata.temperature, num_tokens, ) + assert temperature is not None # Avoid division by zero if there are greedy requests. if not sampling_metadata.all_random: is_greedy = temperature < _SAMPLING_EPS From f91caa0cb4859933e16be9b508d3ca1990c301c0 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 16 Jun 2026 05:55:17 +0800 Subject: [PATCH 104/131] deepseek-v4: drop dead VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH env A dev-branch sparse-MLA stats diagnostic leaked into the PR: the only reader was the phantom _disable_sparse_mla_prefill_stats warmup wrapper (removed in the prior cleanup, as it referenced a never-defined symbol). With that gone the env has no production reader, so remove its envs.py declaration + lookup and the dedicated test that exercised it. Signed-off-by: jasl --- tests/test_envs.py | 19 ------------------- vllm/envs.py | 4 ---- 2 files changed, 23 deletions(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index 5c70d56e692c..d4d120ecee51 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -36,25 +36,6 @@ def test_nixl_side_channel_host_is_not_compile_factor( assert "VLLM_NIXL_SIDE_CHANNEL_HOST" not in envs.compile_factors() -def test_deepseek_v4_sparse_mla_stats_path_env( - monkeypatch: pytest.MonkeyPatch, -): - monkeypatch.delenv("VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH", raising=False) - if hasattr(envs.__getattr__, "cache_clear"): - envs.__getattr__.cache_clear() - - assert envs.VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH is None - - monkeypatch.setenv( - "VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH", - "/tmp/sparse_mla_stats", - ) - if hasattr(envs.__getattr__, "cache_clear"): - envs.__getattr__.cache_clear() - - assert envs.VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH == "/tmp/sparse_mla_stats" - - def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1") monkeypatch.setenv("VLLM_PORT", "1234") diff --git a/vllm/envs.py b/vllm/envs.py index 11c06e5ccfa6..364b84f037a4 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -180,7 +180,6 @@ VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True - VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH: str | None = None VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL: bool = True VLLM_TRITON_MLA_SPARSE: bool | None = None @@ -1457,9 +1456,6 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP": lambda: bool( int(os.getenv("VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP", "1")) ), - "VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH": lambda: os.getenv( - "VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH" - ), "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL": lambda: bool( int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL", "1")) ), From 627adeee64e9a8c74f6136e5e5b02fc6d335ff29 Mon Sep 17 00:00:00 2001 From: jasl Date: Sun, 14 Jun 2026 20:44:17 +0800 Subject: [PATCH 105/131] feat: gated FlashInfer SM120 packed sparse-MLA decode for DeepSeek V4 Wire FlashInfer PR3395's packed SM120 sparse-MLA decode kernel into the DeepSeek V4 FlashMLA attention as an env-gated decode override (default off). The kernel ships in official flashinfer >= 0.6.13; we drive it through its low-level _SparseMLAPagedAttentionRunner rather than the public trtllm_batch_decode_sparse_mla_dsv4 wrapper. Root cause of the C8-C64 ctx0 decode gap versus the FlashMLA decode path is this decode kernel (the prior PR3395 reintegration ported only the packed prefill). Holding everything else fixed (MARLIN MoE, packed fp8_ds_mla cache, source tree, MTP2) and swapping only the decode kernel lifts ctx0 decode throughput on dual RTX PRO 6000 / SM120, in128/out512: C default(Triton) this delta 8 542 582 +7% 16 771 833 +8% 32 790 981 +24% 64 1345 1683 +25% GSM8K 5-shot limit-300 is correctness-neutral (flexible 0.953 / strict 0.927, matching the MXFP4 baseline). Why the low-level runner and not the public wrapper: the wrapper's _sparse_mla_decode_workspace returns no scratch when num_tokens > 64, so it allocates mid_out/mid_lse (hundreds of MB) fresh on every decode step. The MTP multi-query decode shape routinely exceeds 64 tokens (C32/C64), making the wrapper a regression (-17 to -20% vs the FlashMLA path). The runner instead takes graph-stable mid_out/mid_lse reserved once from the vLLM workspace manager and reused every step; that cached scratch is the entire win (a decode-shaped autotune pass over the kernel's chunks_per_block tactic added 0% on top, so it is not included). - New DeepseekV4FlashInferSM120Attention(DeepseekV4FlashMLAAttention) overrides only _forward_decode; reuses the packed cache, sparse-index metadata, and packed prefill. The compressed decode index is forced contiguous (the kernel asserts it). - Gated by VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE, SM12x, and has_flashinfer_trtllm_sparse_mla_dsv4(); default off keeps the FlashMLA decode path byte-for-byte. Signed-off-by: jasl --- vllm/envs.py | 4 + .../nvidia/flashinfer_sm120_decode.py | 284 ++++++++++++++++++ vllm/models/deepseek_v4/nvidia/model.py | 25 ++ vllm/utils/flashinfer.py | 16 + 4 files changed, 329 insertions(+) create mode 100644 vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py diff --git a/vllm/envs.py b/vllm/envs.py index 364b84f037a4..7d369e7cab98 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -182,6 +182,7 @@ VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL: bool = True + VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE: bool = False VLLM_TRITON_MLA_SPARSE: bool | None = None VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 @@ -1462,6 +1463,9 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL": lambda: bool( int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL", "1")) ), + "VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE", "0")) + ), # Experimental sparse MLA fallback controls. # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse # is unavailable; set 0/1 to force-disable/force-enable the fallback. diff --git a/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py new file mode 100644 index 000000000000..21717e573add --- /dev/null +++ b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""DeepSeek V4 FlashInfer packed sparse-MLA SM120 decode (gated). + +Subclasses the FlashMLA V4 attention to reuse its packed ``fp8_ds_mla`` KV cache, +sparse-index metadata, and packed prefill, and overrides only the decode path to +use FlashInfer's official SM120 packed sparse-MLA decode kernel (FlashInfer +PR3395, merged in flashinfer >= 0.6.13). That kernel scales better at high +concurrency in the MTP speculative-verify (multi-query) decode shape than the +FlashMLA decode kernel, which is the root cause of the C8-C64 ctx0 decode gap. + +Implementation note: flashinfer main exposes this kernel through the +``trtllm_batch_decode_sparse_mla_dsv4`` wrapper, but that wrapper re-validates +inputs and -- critically -- carves the split-K ``mid_out``/``mid_lse`` scratch +from a fixed workspace only for ``num_tokens <= 64``, falling back to a fresh +``torch.empty`` of hundreds of MB on every decode step above that. The MTP +multi-query decode shape routinely exceeds 64 tokens (C32/C64), so that per-step +allocation dominates and makes the wrapper materially slower than the FlashMLA +path. We instead drive the same kernel through its low-level +``_SparseMLAPagedAttentionRunner``, constructed once and fed graph-stable scratch +from vLLM's workspace manager -- so the scratch is reserved during warmup and +reused, never reallocated per step. + +Gated behind ``VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE``; selected only on SM12x +when the official packed kernel is importable (see ``_select_dsv4_attn_cls``). +Default off; gate-off behavior is identical to the FlashMLA decode path. +""" + +from typing import TYPE_CHECKING + +import torch + +from vllm.config import get_current_vllm_config +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.models.deepseek_v4.common.ops import compute_global_topk_indices_and_lens +from vllm.models.deepseek_v4.nvidia.flashmla import DeepseekV4FlashMLAAttention +from vllm.v1.worker.workspace import current_workspace_manager + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.models.deepseek_v4.sparse_mla import DeepseekV4FlashMLAMetadata + from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata + +# Split-K decode scratch sizing, mirrored from the FlashInfer sparse-sm120 +# kernel (``_BI`` = 64 candidates per partition tile): one tile per SWA top-k +# plus one per compressed (extra) top-k. Cap the warmup reservation at the +# largest single-graph decode batch. +_DECODE_MAX_TOKENS = 64 +_DECODE_SPLIT_TILE = 64 +_C128A_TOPK_ALIGNMENT = 128 + + +def _cdiv(x: int, y: int) -> int: + return (int(x) + int(y) - 1) // int(y) + + +def _max_decode_workspace_tokens(max_num_batched_tokens: int) -> int: + return min(int(max_num_batched_tokens), _DECODE_MAX_TOKENS) + + +def _decode_num_splits(topk: int, extra_topk: int = 0) -> int: + return _cdiv(topk, _DECODE_SPLIT_TILE) + _cdiv(extra_topk, _DECODE_SPLIT_TILE) + + +def _c128a_max_compressed(max_model_len: int, compress_ratio: int) -> int: + return ( + _cdiv(_cdiv(max_model_len, compress_ratio), _C128A_TOPK_ALIGNMENT) + * _C128A_TOPK_ALIGNMENT + ) + + +def _get_decode_scratch( + num_tokens: int, + num_heads: int, + head_dim: int, + topk: int, + extra_topk: int = 0, +) -> tuple[torch.Tensor, torch.Tensor]: + num_splits = _decode_num_splits(topk, extra_topk) + mid_out, mid_lse = current_workspace_manager().get_simultaneous( + ((num_tokens, num_heads, num_splits, head_dim), torch.bfloat16), + ((num_tokens, num_heads, num_splits), torch.float32), + ) + return mid_out, mid_lse + + +def _as_sparse_sm120_cache(kv_cache: torch.Tensor) -> torch.Tensor: + if kv_cache.dtype == torch.float8_e4m3fn: + kv_cache = kv_cache.view(torch.uint8) + if kv_cache.dim() == 4: + return kv_cache + return kv_cache.unsqueeze(-2) + + +class DeepseekV4FlashInferSM120Attention(DeepseekV4FlashMLAAttention): + """FlashMLA V4 attention with the official FlashInfer SM120 packed decode. + + Reuses every FlashMLA V4 behavior (packed ``fp8_ds_mla`` cache, metadata + pipeline, packed prefill); only :meth:`_forward_decode` differs. + """ + + @classmethod + def get_padded_num_q_heads(cls, num_heads: int) -> int: + if num_heads <= 16: + return 16 + if num_heads <= 32: + return 32 + if num_heads <= 64: + return 64 + if num_heads <= 128: + return 128 + raise ValueError( + f"DeepseekV4 FlashInfer sparse-sm120 decode does not support " + f"{num_heads} heads (SM120 kernel requires h_q in {{16, 32, 64, 128}})." + ) + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + from vllm.utils.flashinfer import has_flashinfer_trtllm_sparse_mla_dsv4 + + if not has_flashinfer_trtllm_sparse_mla_dsv4(): + raise RuntimeError( + "VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE requires FlashInfer's " + "SM120 packed sparse-MLA decode kernel " + "(trtllm_batch_decode_sparse_mla_dsv4, PR3395, " + "flashinfer >= 0.6.13)." + ) + + from flashinfer.mla._sparse_mla_sm120 import _SparseMLAPagedAttentionRunner + + max_tokens = get_current_vllm_config().scheduler_config.max_num_batched_tokens + runner_device = torch.device("cuda", torch.accelerator.current_device_index()) + # Construct the low-level runner once: its only per-instance state is a + # pre-sized LSE buffer. We feed it graph-stable mid_out/mid_lse scratch + # explicitly on every call, so it never allocates per step. + self._sm120_runner = _SparseMLAPagedAttentionRunner( + max_num_tokens=max_tokens, + max_num_heads=self.padded_heads, + d_v=self.head_dim, + kv_scale_format="auto", + device=runner_device, + ) + logger.info_once( + "DeepSeek V4: using official FlashInfer SM120 packed sparse-MLA decode " + "via the low-level runner (VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE=1)." + ) + + def _reserve_sm120_decode_workspace(self) -> None: + if self.compress_ratio <= 1: + extra_topk = 0 + elif self.compress_ratio == 4: + assert self.topk_indices_buffer is not None + extra_topk = self.topk_indices_buffer.shape[-1] + elif self.compress_ratio == 128: + extra_topk = _c128a_max_compressed(self.max_model_len, self.compress_ratio) + else: + raise ValueError( + f"Unsupported compress_ratio={self.compress_ratio}; " + "expected 1, 4, or 128." + ) + _get_decode_scratch( + _max_decode_workspace_tokens(self.max_num_batched_tokens), + self.padded_heads, + self.head_dim, + self.window_size, + extra_topk, + ) + + def _prepare_sm120_query( + self, q: torch.Tensor, output: torch.Tensor + ) -> torch.Tensor: + # The SM120 packed kernel consumes a bf16 query; the FlashMLA fp8 path + # keeps q in fp8, so convert here. q already arrives padded to + # ``padded_heads`` by the outer attention wrapper. + if self.kv_cache_torch_dtype == torch.float8_e4m3fn: + assert q.dtype == torch.float8_e4m3fn + q = q.to(torch.bfloat16) + else: + assert q.dtype == torch.bfloat16 + padded_heads = output.shape[1] + if q.shape[1] < padded_heads: + padded_query = q.new_zeros((q.shape[0], padded_heads, q.shape[2])) + padded_query[:, : q.shape[1], :] = q + q = padded_query + return q.contiguous() + + def forward_mqa( + self, + q: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + output: torch.Tensor, + ) -> None: + # Mirror the FlashMLA warmup branch but also reserve the graph-stable + # split-K decode scratch the sparse-sm120 kernel needs, then defer to the + # parent for the real prefill/decode split (which calls our overridden + # _forward_decode). + forward_context = get_forward_context() + if forward_context.attn_metadata is None: + self._reserve_prefill_workspace(self) + self._reserve_sm120_decode_workspace() + output.zero_() + return + super().forward_mqa(q, kv, positions, output) + + def _forward_decode( + self, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: "DeepseekV4FlashMLAMetadata | None", + swa_only: bool, + output: torch.Tensor, + ) -> None: + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + + # Identical decode-side index/length construction to the FlashMLA decode + # path; only the kernel launch below differs. + topk_indices = None + topk_lens = None + if not swa_only: + assert attn_metadata is not None + assert swa_metadata.is_valid_token is not None + block_size = attn_metadata.block_size // self.compress_ratio + is_valid = swa_metadata.is_valid_token[:num_decode_tokens] + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None + global_indices, topk_lens = compute_global_topk_indices_and_lens( + self.topk_indices_buffer[:num_decode_tokens], + swa_metadata.token_to_req_indices, + attn_metadata.block_table[:num_decodes], + block_size, + is_valid, + ) + topk_indices = global_indices.view(num_decode_tokens, 1, -1) + else: + topk_indices = attn_metadata.c128a_global_decode_topk_indices + topk_lens = attn_metadata.c128a_decode_topk_lens + # The sparse-sm120 kernel asserts the extra (compressed) index + # tensor is int32 and contiguous; current's metadata builder can hand + # back a non-contiguous view, so normalize before the launch. + if topk_indices is not None: + topk_indices = topk_indices.contiguous() + + swa_indices = swa_metadata.decode_swa_indices + swa_lens = swa_metadata.decode_swa_lens + assert swa_indices is not None + assert swa_lens is not None + + extra_topk = topk_indices.shape[-1] if topk_indices is not None else 0 + mid_out, mid_lse = _get_decode_scratch( + num_decode_tokens, + output.shape[1], + output.shape[-1], + swa_indices.shape[-1], + extra_topk, + ) + + # Each decode token is a one-token query: [num_decode_tokens, 1, h, d]; + # the runner squeezes the singleton s_q dim internally. + q = self._prepare_sm120_query(q, output).unsqueeze(1) + swa_cache = _as_sparse_sm120_cache(self.swa_cache_layer.kv_cache) + extra_cache = ( + _as_sparse_sm120_cache(kv_cache) + if (kv_cache is not None and not swa_only) + else None + ) + self._sm120_runner.run( + q, + swa_cache, + swa_indices, + output, + self.scale, + topk_length=swa_lens, + attn_sink=self.attn_sink, + extra_kv_cache=extra_cache, + extra_indices=topk_indices, + extra_topk_length=topk_lens, + mid_out=mid_out, + mid_lse=mid_lse, + ) diff --git a/vllm/models/deepseek_v4/nvidia/model.py b/vllm/models/deepseek_v4/nvidia/model.py index 90be150fb516..d25786ef8404 100644 --- a/vllm/models/deepseek_v4/nvidia/model.py +++ b/vllm/models/deepseek_v4/nvidia/model.py @@ -764,12 +764,37 @@ def _select_dsv4_attn_cls(vllm_config: VllmConfig) -> type[DeepseekV4Attention]: An explicit ``--attention-backend FLASHINFER_MLA_SPARSE_DSV4`` selects the FlashInfer TRTLLM-gen path; otherwise the FlashMLA path is used. + + When ``VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE`` is set and the runtime is + SM12x with FlashInfer's packed sparse-MLA decode kernel available, decode is + routed through the official ``trtllm_batch_decode_sparse_mla_dsv4`` SM120 + kernel (FlashInfer PR3395) instead of the FlashMLA decode kernel; everything + else (packed ``fp8_ds_mla`` cache, metadata, prefill) is unchanged. + Default off. """ if ( vllm_config.attention_config.backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE_DSV4 ): return DeepseekV4FlashInferMLAAttention + + import vllm.envs as envs + + if envs.VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE: + from vllm.platforms import current_platform + from vllm.utils.flashinfer import has_flashinfer_trtllm_sparse_mla_dsv4 + + capability = current_platform.get_device_capability() + if ( + capability is not None + and capability.major == 12 + and has_flashinfer_trtllm_sparse_mla_dsv4() + ): + from vllm.models.deepseek_v4.nvidia.flashinfer_sm120_decode import ( + DeepseekV4FlashInferSM120Attention, + ) + + return DeepseekV4FlashInferSM120Attention return DeepseekV4FlashMLAAttention diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index e05182778658..af0c99bd69db 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -209,6 +209,22 @@ def has_flashinfer_moe() -> bool: ) +@functools.cache +def has_flashinfer_trtllm_sparse_mla_dsv4() -> bool: + """Return ``True`` if FlashInfer's official SM120 packed sparse-MLA decode + kernel (``trtllm_batch_decode_sparse_mla_dsv4``, PR3395, merged in + flashinfer >= 0.6.13) is available.""" + if not has_flashinfer(): + return False + try: + from flashinfer.mla import ( # noqa: F401 + trtllm_batch_decode_sparse_mla_dsv4, + ) + except ImportError: + return False + return True + + @functools.cache def has_flashinfer_cutedsl() -> bool: """Return ``True`` if FlashInfer cutedsl module is available.""" From 9ab29807f5395c0039bafc3957d1b9aa80ea3914 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 17 Jun 2026 00:31:40 +0800 Subject: [PATCH 106/131] fix: pre-compile DeepSeek-V4 D512-split sparse-MLA prefill kernels at startup The first long prefill JIT-compiled the D512-split sparse-MLA prefill Triton kernels mid-engine-step (~20s), parking EngineCore in shm_broadcast and surfacing as a "sample_tokens RPC timed out" wedge under concurrency. Pre-compile them during the DeepSeek-V4 sparse-MLA warmup over the complete 128-aligned combined_topk specialization set [256..1152], gated by a new env VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP (default on). Synthetic throwaway tensors only (no workspace manager use, no state leak); cleanly no-ops when the split path is unreachable. No inference-path behavior change. Signed-off-by: jasl --- vllm/envs.py | 7 + vllm/model_executor/warmup/kernel_warmup.py | 193 ++++++++++++++++++++ 2 files changed, 200 insertions(+) diff --git a/vllm/envs.py b/vllm/envs.py index 7d369e7cab98..56ae2eaff300 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -181,6 +181,7 @@ VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = True + VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL: bool = True VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE: bool = False VLLM_TRITON_MLA_SPARSE: bool | None = None @@ -1460,6 +1461,12 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL": lambda: bool( int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL", "1")) ), + # Pre-compile the D512-split sparse-MLA prefill Triton kernels at startup + # (one per 128-aligned combined_topk in [256, 1152]) so the first long + # prefill does not pay a first-use JIT compile that blocks the engine step. + "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP", "1")) + ), "VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL": lambda: bool( int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL", "1")) ), diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 3a5663f47f9d..9ce3a7054f3e 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -354,6 +354,193 @@ def _run_deepseek_v4_mtp_spec_decode_warmup_kernels( ) +def _deepseek_v4_indexed_d512_split_prefill_warmup(runner: "GPUModelRunner") -> None: + """Force-compile the DeepSeek-V4 D512-split sparse-MLA prefill kernels. + + The split path (``_use_indexed_d512_split_prefill`` -> + ``accumulate_indexed_d512_split_sparse_mla_attention``) bottoms out in three + plain ``@triton.jit`` kernels whose compile key is the constexpr set -- + chiefly ``num_candidates`` (= the per-chunk ``combined_topk``) plus the + workspace buffer strides. ``combined_topk`` is 128-aligned + (``_SPARSE_PREFILL_TOPK_ALIGNMENT``) and the split path is gated to + ``[256, 1152]`` (``_is_indexed_d512_split_topk``), so the complete + specialization set is the eight widths {256, 384, ..., 1152}. The kernels + never see ``compress_ratio``, so one warm per width covers cr=4 and cr=128. + + Without this, the first long-prefill request JIT-compiles these kernels + inside the engine step (~20s), parking EngineCore in shm_broadcast and + surfacing as a "sample_tokens RPC timed out" wedge (PR #41834). + + Triton compilation is data-independent, so synthetic zero tensors compile + the same cubin a real request uses -- provided every constexpr matches. Two + non-obvious constexprs (verified against the live jit_monitor compile key): + the per-chunk ``scores``/``indices`` workspaces are sized to that chunk's own + ``combined_topk`` (contiguous at width C, so ``stride_scores_h == C`` and + ``stride_indices_t == C`` -- NOT a slice of a wider buffer), and the prefill + ``q`` buffer is padded to the FP8-decode head count (``padded_heads``), so + ``stride_q_t == padded_heads * head_dim`` even though the kernel reads only + ``n_local_heads``. The synthetic tensors mirror both. + + Scope: only the split path (``combined_topk <= 1152``) is warmed. DeepSeek-V4 + -Flash caps ``combined_topk`` at ``sparse_prefill_combined_topk_size( + index_topk=512, 128) = 640`` for every context length, so that is complete + coverage. A variant whose ``combined_topk`` can exceed 1152 routes onto the + chunked path (extra split-stride and merge kernels) which is not pre-warmed + here; that case is warned at startup rather than left as a silent gap. + """ + if not ( + envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP + and envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL + ): + return + + try: + from vllm.models.deepseek_v4.common.ops.cache_utils import ( + sparse_prefill_combined_topk_size, + ) + from vllm.models.deepseek_v4.nvidia.flashmla import ( + _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, + _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS, + _INDEXED_D512_SPLIT_PREFILL_MIN_TOPK, + DeepseekV4FlashMLAAttention, + ) + from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_query_chunk_size, + ) + from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_indexed_d512_split_sparse_mla_attention, + ) + except ImportError: + logger.debug( + "Skipping DeepSeek V4 D512-split prefill warmup: split kernels or " + "helpers are unavailable on this build." + ) + return + + try: + if not is_triton_sparse_mla_enabled_for_platform(): + return + if getattr(runner, "max_model_len", 0) < _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS: + return + + # The split kernel never sees compress_ratio, so any cr in (4, 128) + # layer yields identical strides; the first one is representative. + layer = None + for module in runner.get_model().modules(): + if ( + isinstance(module, DeepseekV4FlashMLAAttention) + and module.compress_ratio in (4, 128) + ): + layer = module + break + if layer is None: + return + + head_dim = int(layer.head_dim) + if head_dim != 512: + return + num_heads = int(layer.n_local_heads) + window_size = max(1, int(layer.window_size)) + device = layer.attn_sink.device + + # Upper bound on the per-chunk combined_topk a real request can reach: + # the runtime sizes its prefill workspace from the same expression, so a + # request cannot exceed it. The split path is gated to <= 1152. + topk_bound = DeepseekV4FlashMLAAttention._prefill_workspace_topk_bound(layer) + max_reachable_topk = sparse_prefill_combined_topk_size(topk_bound, window_size) + # Non-silent gap: if combined_topk can exceed the split ceiling, the + # request routes onto the chunked path whose kernels we do not pre-warm. + # DSv4-Flash caps at 640 so this never fires for it; warn for variants. + if max_reachable_topk > _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK: + logger.warning( + "DeepSeek V4 D512 prefill: combined_topk can reach %d (> %d); the " + "chunked-prefill kernels are NOT pre-warmed and may JIT on the " + "first very-long prefill. Only the split path is warmed.", + max_reachable_topk, + _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, + ) + max_topk = min(_INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, max_reachable_topk) + topk_widths = list( + range(_INDEXED_D512_SPLIT_PREFILL_MIN_TOPK, max_topk + 1, 128) + ) + if not topk_widths: + return + + # The real prefill q buffer is padded to the FP8-decode head count; the + # split kernel reads only n_local_heads, but stride_q_t (a constexpr in + # the compile key) reflects the padded width, so match it. + padded_heads = int( + getattr(layer, "padded_heads", 0) + or DeepseekV4FlashMLAAttention.get_padded_num_q_heads(num_heads) + ) + # T sizes only the launch grid -- the cubin is T-independent -- so keep + # it small to bound the transient footprint. + num_tokens = max(1, min(triton_sparse_mla_query_chunk_size(), 32)) + + logger.info( + "Warming up DeepSeek V4 D512-split sparse-MLA prefill kernels for " + "combined_topk widths=%s (heads=%d, padded_q_heads=%d).", + topk_widths, + num_heads, + padded_heads, + ) + + # Throwaway tensors -- never the shared workspace, so warmup can't grow + # or leak steady-state memory. q/kv/state are width-independent; scores + # and indices are contiguous at each per-chunk width so their constexpr + # strides (stride_scores_h == width, stride_indices_t == width) match the + # runtime per-chunk workspace exactly. + q = torch.zeros( + (num_tokens, padded_heads, head_dim), dtype=torch.bfloat16, device=device + ) + kv_flat = torch.zeros( + (max_topk, head_dim), dtype=torch.bfloat16, device=device + ) + max_score = torch.zeros( + (num_tokens, num_heads), dtype=torch.float32, device=device + ) + denom = torch.zeros( + (num_tokens, num_heads), dtype=torch.float32, device=device + ) + acc = torch.zeros( + (num_tokens, num_heads, head_dim), dtype=torch.float32, device=device + ) + lens = torch.zeros((num_tokens,), dtype=torch.int32, device=device) + + for width in topk_widths: + # indices=0 (valid row) + lens=width keep every candidate active so + # the full kernel body, including the tl.dot MMA, compiles rather + # than an early-return stub. + indices = torch.zeros( + (num_tokens, width), dtype=torch.int32, device=device + ) + scores = torch.zeros( + (num_tokens, num_heads, width), dtype=torch.float32, device=device + ) + lens.fill_(width) + accumulate_indexed_d512_split_sparse_mla_attention( + q=q, + kv_flat=kv_flat, + indices=indices, + lens=lens, + scale=layer.scale, + scores=scores, + max_score=max_score, + denom=denom, + acc=acc, + ) + torch.accelerator.synchronize() + except Exception as exc: # noqa: BLE001 - warmup must never break startup + # Warn (not debug): a swallowed failure here silently leaves the split + # kernels uncompiled, so the first long prefill pays the JIT stall again. + logger.warning( + "DeepSeek V4 D512-split prefill warmup skipped after error " + "(first long prefill may JIT in-inference): %s", + exc, + ) + + def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None: if not envs.VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: return @@ -417,6 +604,12 @@ def _deepseek_v4_sparse_mla_attention_warmup(worker: "Worker") -> None: # Do not synthesize multi-request prefill here: that dummy shape # overflows the CUTeDSL KV-gather workspace on SM12x. Revisit only # with a real buffer-sizing fix for that warmup path. + + # The prefill dummies above never drive the C128A indexer, so the + # D512-split prefill kernels stay uncompiled until the first long request + # (PR #41834 wedge). Compile them directly with synthetic inputs. + _deepseek_v4_indexed_d512_split_prefill_warmup(runner) + query_len = getattr(runner, "uniform_decode_query_len", 0) for num_reqs in uniform_decode_reqs: runner._dummy_run( From dce4b3b10f80a6729a85dab8dc76cc1e05c16b91 Mon Sep 17 00:00:00 2001 From: jasl Date: Wed, 17 Jun 2026 18:35:29 +0800 Subject: [PATCH 107/131] feat: restore DeepSeek-V4 API semantics on the SM120 min-enable base Port the dropped "Align DeepSeek V4 API semantics" layer (a8746555ef, on ds4-sm120-full but never on the PR line) onto the PR head 73e99c165: - top-level `thinking` request field (DeepSeek OpenAI-compat) -> chat-template enable_thinking, via apply_chat_template_kwargs at the protocol boundary; - bare DeepSeek-V4 requests (no thinking key) now default thinking ON, matching ds4-sm120-full (the PR line currently defaults them OFF); - deepseek_v4_sampling_override (apply DeepSeek's official sampling defaults when thinking is enabled; per-request opt-out); - reasoning_content alias on ChatMessage/DeltaMessage; prefix/wo_eos message fields; tool-call empty-arguments robustness in the DSv4 tokenizer. Addresses the common bare-request instruction-following regression in jasl/vllm#19: preview-dev defaulted bare requests to thinking OFF, so the model answered directly and prepended explanatory prose despite "output ONLY a JSON array"; full defaulted bare -> thinking ON, reasoned, then complied. (The reporter's EXPLICIT enable_thinking=false case is a separate residual softness, not fixed by this change.) Cherry-picked from a8746555ef; conflicts resolved by keeping the June PR's evolved code (build_chat_params reasoning_effort handling; the dedicated DeepSeekV4ReasoningParser) and layering the DSv4 semantics on top (serving._effective_chat_template_kwargs applies apply_chat_template_kwargs after build_chat_params). Signed-off-by: jasl --- .../test_deepseekv3_reasoning_parser.py | 21 +- tests/tokenizers_/test_deepseek_v4.py | 323 ++++++++++++++++++ vllm/entrypoints/chat_utils.py | 11 + .../openai/chat_completion/batch_serving.py | 6 +- .../openai/chat_completion/protocol.py | 101 +++++- .../openai/chat_completion/serving.py | 8 +- vllm/entrypoints/openai/engine/protocol.py | 9 + vllm/entrypoints/serve/render/serving.py | 28 +- vllm/tokenizers/deepseek_v4_encoding.py | 11 +- 9 files changed, 509 insertions(+), 9 deletions(-) diff --git a/tests/reasoning/test_deepseekv3_reasoning_parser.py b/tests/reasoning/test_deepseekv3_reasoning_parser.py index 49b373dbe332..4e19ca3811a8 100644 --- a/tests/reasoning/test_deepseekv3_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv3_reasoning_parser.py @@ -4,16 +4,35 @@ import pytest from transformers import AutoTokenizer +from vllm.config.reasoning import ReasoningConfig from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.reasoning import ReasoningParserManager from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser -from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser +from vllm.reasoning.deepseek_v3_reasoning_parser import ( + DeepSeekV3ReasoningParser, + DeepSeekV3ReasoningWithThinkingParser, +) from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-V3.1" +class FakeReasoningTokenizer: + + def get_vocab(self) -> dict[str, int]: + return {"": 100, "": 101} + + def encode( + self, + text: str, + add_special_tokens: bool = False, + **kwargs, + ) -> list[int]: + assert add_special_tokens is False + return [self.get_vocab()[text]] + + @pytest.fixture(scope="module") def tokenizer(): return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) diff --git a/tests/tokenizers_/test_deepseek_v4.py b/tests/tokenizers_/test_deepseek_v4.py index 358732eabf40..0a3253957add 100644 --- a/tests/tokenizers_/test_deepseek_v4.py +++ b/tests/tokenizers_/test_deepseek_v4.py @@ -8,8 +8,14 @@ import pytest from vllm.entrypoints.chat_utils import parse_chat_messages +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatMessage, +) +from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.renderers.registry import RENDERER_REGISTRY from vllm.tokenizers.deepseek_v4 import get_deepseek_v4_tokenizer +from vllm.tokenizers.deepseek_v4_encoding import encode_arguments_to_dsml from vllm.tokenizers.registry import TokenizerRegistry FIXTURES_DIR = Path(__file__).parent / "fixtures" / "deepseek_v4" @@ -96,6 +102,130 @@ def test_deepseek_v4_enables_thinking_with_compatible_kwargs(kwargs): assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") +def test_deepseek_v4_honors_official_thinking_request_field(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + prompt = _tokenizer().apply_chat_template( + request.messages, + tokenize=False, + **chat_kwargs, + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") + + +def test_deepseek_v4_defaults_to_official_thinking_for_openai_request(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + + +def test_deepseek_v4_preserves_official_reasoning_content_alias(): + messages = [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "reasoning_content": "because", "content": "A1"}, + {"role": "user", "content": "Q2"}, + ] + + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + + assert conversation[1]["reasoning"] == "because" + assert conversation[1]["reasoning_content"] == "because" + + +def test_deepseek_v4_response_messages_expose_reasoning_content_alias(): + message = ChatMessage(role="assistant", reasoning="because", content="answer") + delta = DeltaMessage(reasoning="because") + + assert message.reasoning_content == "because" + assert delta.reasoning_content == "because" + assert ( + ChatMessage( + role="assistant", + reasoning_content="because", + content="answer", + ).reasoning + == "because" + ) + + +def test_deepseek_v4_preserves_official_prefix_assistant_message(): + messages = [ + {"role": "user", "content": "Please write quick sort code"}, + {"role": "assistant", "content": "```python\n", "prefix": True}, + ] + + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + prompt = _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tokenize=False, + ) + + assert conversation[1]["prefix"] is True + assert conversation[1]["wo_eos"] is True + assert prompt.endswith("<|Assistant|>```python\n") + assert not prompt.endswith("<|end▁of▁sentence|>") + + +def test_deepseek_v4_thinking_ignores_sampling_controls(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert sampling_params.temperature == 1.0 + assert sampling_params.top_p == 1.0 + assert sampling_params.top_k == 0 + assert sampling_params.presence_penalty == 0.0 + assert sampling_params.frequency_penalty == 0.0 + + def test_deepseek_v4_uses_v4_tool_prompt_from_request_tools(): tools = [ { @@ -183,6 +313,66 @@ def test_deepseek_v4_renders_parsed_history_tool_arguments(): assert 'parameter name="arguments"' not in prompt +@pytest.mark.parametrize( + ("tool_call", "expected_parameter"), + [ + ({"name": "refresh", "arguments": None}, None), + ({"name": "refresh"}, None), + ({"name": "refresh", "arguments": ""}, None), + ( + {"name": "refresh", "arguments": '{"target": "cache"}'}, + '<|DSML|parameter name="target" string="true">cache', + ), + ( + {"name": "refresh", "arguments": {"target": "cache"}}, + '<|DSML|parameter name="target" string="true">cache', + ), + ], +) +def test_deepseek_v4_encodes_empty_history_tool_arguments( + tool_call, expected_parameter +): + prompt = encode_arguments_to_dsml(tool_call) + + if expected_parameter is None: + assert prompt == "" + else: + assert expected_parameter in prompt + + +def test_deepseek_v4_renders_openai_history_tool_call_with_null_arguments(): + messages = [ + {"role": "user", "content": "Refresh state"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "refresh", + "arguments": None, + }, + } + ], + }, + ] + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + + prompt = _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tokenize=False, + ) + + assert '<|DSML|invoke name="refresh">' in prompt + assert "<|DSML|parameter" not in prompt + + @pytest.mark.parametrize("reasoning_effort", ["minimal", "low", "medium", "high"]) def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort): prompt = _tokenizer().apply_chat_template( @@ -288,3 +478,136 @@ def test_deepseek_v4_matches_reference_golden_fixtures(case_id, kwargs): expected = (FIXTURES_DIR / f"test_output_{case_id}.txt").read_text() assert prompt == expected + + +@pytest.mark.parametrize( + "model", + [ + "deepseek-ai/DeepSeek-V4-Flash", + "deepseek-ai/DeepSeek-V4-Pro", + ], +) +def test_deepseek_v4_official_api_defaults_to_thinking_for_v4_family(model): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": model, + "messages": [{"role": "user", "content": "Hello"}], + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + + +def test_deepseek_v4_official_api_uses_model_config_for_family_detection(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "local-ds4-alias", + "messages": [{"role": "user", "content": "Hello"}], + "temperature": 0.2, + } + ) + model_config = SimpleNamespace( + hf_config=SimpleNamespace(model_type="deepseek_v4", architectures=[]), + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs, + model_config=model_config, + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + model_config=model_config, + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + assert sampling_params.temperature == 1.0 + + +def test_deepseek_v4_official_api_sampling_override_can_be_disabled(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "deepseek_v4_sampling_override": False, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "min_p": 0.05, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert sampling_params.temperature == 0.2 + assert sampling_params.top_p == 0.3 + assert sampling_params.top_k == 4 + assert sampling_params.min_p == 0.05 + assert sampling_params.presence_penalty == 1.5 + assert sampling_params.frequency_penalty == 1.25 + + +def test_deepseek_v4_official_api_sampling_override_is_v4_only(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-R1", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "min_p": 0.05, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert "thinking" not in chat_kwargs + assert "enable_thinking" not in chat_kwargs + assert sampling_params.temperature == 0.2 + assert sampling_params.top_p == 0.3 + assert sampling_params.top_k == 4 + assert sampling_params.min_p == 0.05 + assert sampling_params.presence_penalty == 1.5 + assert sampling_params.frequency_penalty == 1.25 diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f0b56da3432d..625a65d733a4 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -382,6 +382,12 @@ class ConversationMessage(TypedDict, total=False): reasoning_content: str | None """Deprecated: The reasoning content for interleaved thinking.""" + prefix: bool + """Whether this assistant message is a prefix for continuation.""" + + wo_eos: bool + """Whether this message should be rendered without an EOS marker.""" + tools: list[ChatCompletionFunctionToolParam] | None """The tools for developer role.""" @@ -1750,6 +1756,8 @@ def _parse_chat_message_content( role = message["role"] content = message.get("content") reasoning = message.get("reasoning") + if reasoning is None: + reasoning = message.get("reasoning_content") if content is None: content = [] @@ -1779,6 +1787,9 @@ def _parse_chat_message_content( result_msg["reasoning_content"] = cast( str, reasoning ) # keep compatibility + if parsed_msg.get("prefix"): + result_msg["prefix"] = True + result_msg["wo_eos"] = True elif role == "tool": parsed_msg = _ToolParser(message) if "tool_call_id" in parsed_msg: diff --git a/vllm/entrypoints/openai/chat_completion/batch_serving.py b/vllm/entrypoints/openai/chat_completion/batch_serving.py index 96ed7dcb777b..5ee929976a19 100644 --- a/vllm/entrypoints/openai/chat_completion/batch_serving.py +++ b/vllm/entrypoints/openai/chat_completion/batch_serving.py @@ -160,8 +160,12 @@ async def create_batch_chat_completion( self.override_max_tokens, ) single_request = single_requests[i] + chat_template_kwargs = self._effective_chat_template_kwargs(single_request) sampling_params = single_request.to_sampling_params( - max_tokens, self.default_sampling_params + max_tokens, + self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, ) self._log_inputs( sub_request_id, diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 3457aa12f4ac..a4f0b54b74c5 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -65,6 +65,15 @@ class ChatMessage(OpenAIBaseModel): # vLLM-specific fields that are not in OpenAI spec reasoning: str | None = None + reasoning_content: str | None = None + + @model_validator(mode="after") + def _populate_reasoning_content_alias(self) -> "ChatMessage": + if self.reasoning_content is None and self.reasoning is not None: + self.reasoning_content = self.reasoning + elif self.reasoning is None and self.reasoning_content is not None: + self.reasoning = self.reasoning_content + return self class ChatCompletionLogProb(OpenAIBaseModel): @@ -183,6 +192,10 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): type: Literal["function"] = "function" +class DeepSeekThinkingParam(OpenAIBaseModel): + type: Literal["enabled", "disabled"] = "enabled" + + class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create @@ -228,6 +241,15 @@ class ChatCompletionRequest(OpenAIBaseModel): "part of the standard OpenAI API specification." ), ) + thinking: DeepSeekThinkingParam | None = None + deepseek_v4_sampling_override: bool = Field( + default=True, + description=( + "Apply DeepSeek V4 official sampling defaults when thinking is " + "enabled. This only affects the DeepSeek V4 family and can be " + "disabled per request." + ), + ) thinking_token_budget: ThinkingTokenBudget = None include_reasoning: bool = True parallel_tool_calls: bool | None = True @@ -500,6 +522,63 @@ def build_chat_params( media_io_kwargs=self.media_io_kwargs, ) + def _is_deepseek_v4_model(self, model_config: ModelConfig | None = None) -> bool: + hf_config = getattr(model_config, "hf_config", None) + if getattr(hf_config, "model_type", None) == "deepseek_v4": + return True + + architectures = getattr(hf_config, "architectures", None) or () + if any("deepseekv4" in str(arch).replace("_", "").lower() + for arch in architectures): + return True + + model = (self.model or "").lower().replace("_", "-") + return "deepseek-v4" in model + + def apply_chat_template_kwargs( + self, + chat_template_kwargs: dict[str, Any], + *, + model_config: ModelConfig | None = None, + ) -> dict[str, Any]: + """Apply request-level DeepSeek API compatibility knobs. + + DeepSeek's OpenAI-compatible API exposes ``thinking`` as a top-level + request field, while vLLM's DeepSeek tokenizer consumes it as a chat + template kwarg. Keep the translation at the protocol boundary so the + tokenizer and reasoning parser see the same effective state. + """ + chat_template_kwargs = dict(chat_template_kwargs) + if not self._is_deepseek_v4_model(model_config): + return chat_template_kwargs + + if self.thinking is not None: + enabled = self.thinking.type == "enabled" + chat_template_kwargs["thinking"] = enabled + chat_template_kwargs["enable_thinking"] = enabled + elif ( + "thinking" not in chat_template_kwargs + and "enable_thinking" not in chat_template_kwargs + ): + chat_template_kwargs["thinking"] = True + chat_template_kwargs["enable_thinking"] = True + + return chat_template_kwargs + + def _use_deepseek_v4_sampling_override(self) -> bool: + return self.deepseek_v4_sampling_override + + @staticmethod + def _is_thinking_enabled( + chat_template_kwargs: dict[str, Any] | None, + ) -> bool: + if chat_template_kwargs is None: + return False + return bool( + chat_template_kwargs.get("thinking") + or chat_template_kwargs.get("enable_thinking") + ) + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: if self.max_completion_tokens is not None: max_output_tokens: int | None = self.max_completion_tokens @@ -550,6 +629,9 @@ def to_sampling_params( self, max_tokens: int, default_sampling_params: dict, + *, + chat_template_kwargs: dict[str, Any] | None = None, + model_config: ModelConfig | None = None, ) -> SamplingParams: # Default parameters if (repetition_penalty := self.repetition_penalty) is None: @@ -574,6 +656,21 @@ def to_sampling_params( "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] ) + if ( + self._is_deepseek_v4_model(model_config) + and self._use_deepseek_v4_sampling_override() + and self._is_thinking_enabled(chat_template_kwargs) + ): + temperature = self._DEFAULT_SAMPLING_PARAMS["temperature"] + top_p = self._DEFAULT_SAMPLING_PARAMS["top_p"] + top_k = self._DEFAULT_SAMPLING_PARAMS["top_k"] + min_p = self._DEFAULT_SAMPLING_PARAMS["min_p"] + presence_penalty = 0.0 + frequency_penalty = 0.0 + else: + presence_penalty = self.presence_penalty or 0.0 + frequency_penalty = self.frequency_penalty or 0.0 + prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs @@ -616,8 +713,8 @@ def to_sampling_params( extra_args["kv_transfer_params"] = self.kv_transfer_params return SamplingParams.from_optional( n=self.n, - presence_penalty=self.presence_penalty, - frequency_penalty=self.frequency_penalty, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, repetition_penalty=repetition_penalty, temperature=temperature, top_p=top_p, diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 911421029c3b..ecb798f3e5c1 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -196,7 +196,7 @@ def warmup(self) -> None: def _effective_chat_template_kwargs( self, request: ChatCompletionRequest ) -> dict[str, Any]: - return ( + chat_template_kwargs = ( request.build_chat_params( self.chat_template, self.chat_template_content_format, @@ -204,6 +204,10 @@ def _effective_chat_template_kwargs( .with_defaults(self.default_chat_template_kwargs) .chat_template_kwargs ) + return request.apply_chat_template_kwargs( + chat_template_kwargs, + model_config=self.model_config, + ) async def render_chat_request( self, @@ -319,6 +323,8 @@ async def _create_chat_completion( sampling_params = request.to_sampling_params( max_tokens, self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, ) self._log_inputs( diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py index d86c77561dbd..6817f01b6b80 100644 --- a/vllm/entrypoints/openai/engine/protocol.py +++ b/vllm/entrypoints/openai/engine/protocol.py @@ -351,8 +351,17 @@ class DeltaMessage(OpenAIBaseModel): role: str | None = None content: str | None = None reasoning: str | None = None + reasoning_content: str | None = None tool_calls: list[DeltaToolCall] = Field(default_factory=list) + @model_validator(mode="after") + def _populate_reasoning_content_alias(self) -> "DeltaMessage": + if self.reasoning_content is None and self.reasoning is not None: + self.reasoning_content = self.reasoning + elif self.reasoning is None and self.reasoning_content is not None: + self.reasoning = self.reasoning_content + return self + class GenerationError(Exception): """raised when finish_reason indicates internal server error (500)""" diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 1f7296cdaa7c..adcefdbb74e8 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time from collections.abc import Sequence +from dataclasses import replace from http import HTTPStatus from typing import Any, cast @@ -261,7 +262,21 @@ async def render_chat_request( self.override_max_tokens, truncate_prompt_tokens=request.truncate_prompt_tokens, ) - params = request.to_sampling_params(max_tokens, self.default_sampling_params) + chat_template_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params( + self.chat_template, + self.chat_template_content_format, + ) + .with_defaults(self.default_chat_template_kwargs) + .chat_template_kwargs, + model_config=self.model_config, + ) + params = request.to_sampling_params( + max_tokens, + self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, + ) request_id = f"chatcmpl-{random_uuid()}" @@ -789,6 +804,17 @@ async def preprocess_chat( default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), ) + apply_chat_template_kwargs = getattr( + request, "apply_chat_template_kwargs", None + ) + if apply_chat_template_kwargs is not None: + chat_params = replace( + chat_params, + chat_template_kwargs=apply_chat_template_kwargs( + chat_params.chat_template_kwargs, + model_config=self.model_config, + ), + ) (conversation,), (engine_input,) = await renderer.render_chat_async( [messages], diff --git a/vllm/tokenizers/deepseek_v4_encoding.py b/vllm/tokenizers/deepseek_v4_encoding.py index 6895771e2f59..74f01ce017e4 100644 --- a/vllm/tokenizers/deepseek_v4_encoding.py +++ b/vllm/tokenizers/deepseek_v4_encoding.py @@ -155,10 +155,15 @@ def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str: p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}' P_dsml_strs = [] - if isinstance(tool_call["arguments"], str): - arguments = json.loads(tool_call["arguments"]) + raw_arguments = tool_call.get("arguments") + if raw_arguments is None or raw_arguments == "": + arguments = {} + elif isinstance(raw_arguments, str): + arguments = json.loads(raw_arguments) + if arguments is None: + arguments = {} else: - arguments = tool_call["arguments"] + arguments = raw_arguments for k, v in arguments.items(): p_dsml_str = p_dsml_template.format( From a8280db3013bf0f5d4eb5a11374d77da2077170a Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 18 Jun 2026 04:25:20 +0800 Subject: [PATCH 108/131] fix: correct SM12x indexer prefill top-k non-contiguous output corruption top_k_per_row_prefill writes its output as a contiguous [M, select_k] buffer (it receives the logits strides, not the output's). The indexer passes out[:, :select_k], which is non-contiguous whenever the compressed-KV count is below the top-k width -- i.e. for short prompts and the early queries of long prompts. Writing it as contiguous silently corrupts the later rows' top-k (all -1), so the C4A sparse-MLA prefill drops the distant/downsampled context and attends only the recent sliding window. This degrades instruction following (returns prose instead of the requested JSON) and garbles long-context generation under concurrent traffic. Hand the op a contiguous work buffer and copy the result back; this is a no-op when the slice is already contiguous (select_k == top-k width), so behavior is unchanged outside the corrupted case. The chunked long-context path was already safe (stride-aware torch.topk / torch.gather). Signed-off-by: jasl --- .../nvidia/ops/sm12x_deep_gemm_fallbacks.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py index d2f157be884f..da5f9922a912 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -251,19 +251,30 @@ def _fp8_mqa_logits_topk_triton( selected = out[:, :select_k] topk_op = _top_k_per_row_prefill_op() if topk_op is not None: + # top_k_per_row_prefill writes its output as a contiguous [M, select_k] + # buffer (it is given the logits strides, not the output strides). When + # select_k < out.shape[1] -- i.e. the compressed-KV count is below the + # topk width, which happens for short prompts and the early queries of + # long prompts -- out[:, :select_k] is non-contiguous (row stride = + # out.shape[1]), so writing it as contiguous silently corrupts later + # rows and drops their top-k (all -1). Hand the op a contiguous buffer + # and copy back. + work = selected if selected.is_contiguous() else selected.contiguous() topk_op( logits, cu_seqlen_ks, cu_seqlen_ke, - selected, + work, logits.shape[0], logits.stride(0), logits.stride(1), select_k, ) - selected.add_(cu_seqlen_ks[:, None]) - valid = (selected >= cu_seqlen_ks[:, None]) & (selected < cu_seqlen_ke[:, None]) - selected.masked_fill_(~valid, -1) + work.add_(cu_seqlen_ks[:, None]) + valid = (work >= cu_seqlen_ks[:, None]) & (work < cu_seqlen_ke[:, None]) + work.masked_fill_(~valid, -1) + if work is not selected: + selected.copy_(work) else: values, indices = torch.topk(logits, select_k, dim=1) selected.copy_(indices.to(torch.int32)) From 667205ef57bfd482ef94e7b6df849357e56cd601 Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 18 Jun 2026 05:33:23 +0800 Subject: [PATCH 109/131] perf: skip throwaway copy in SM12x indexer top-k fallback The non-contiguous-output fix used selected.contiguous(), which copies the slice's current contents (the -1 placeholders just written by out.fill_(-1)) into the work buffer. top_k_per_row_prefill then overwrites every element, so that copy is wasted. Allocate an uninitialized contiguous buffer via selected.new_empty(selected.shape) instead; behavior is unchanged (copy-back still lands the result in the strided slice), one elementwise pass saved on the short-prompt / early-query path. Signed-off-by: jasl --- .../deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py index da5f9922a912..543f4035f9d1 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -257,9 +257,11 @@ def _fp8_mqa_logits_topk_triton( # topk width, which happens for short prompts and the early queries of # long prompts -- out[:, :select_k] is non-contiguous (row stride = # out.shape[1]), so writing it as contiguous silently corrupts later - # rows and drops their top-k (all -1). Hand the op a contiguous buffer - # and copy back. - work = selected if selected.is_contiguous() else selected.contiguous() + # rows and drops their top-k (all -1). Hand the op a fresh contiguous + # buffer and copy back. The buffer is left uninitialized (new_empty) + # rather than copied (.contiguous()) because the op overwrites every + # element, so copying the placeholder -1s in would be wasted work. + work = selected if selected.is_contiguous() else selected.new_empty(selected.shape) topk_op( logits, cu_seqlen_ks, From edf0d37427db481cef675e06920b28f4a1149ffa Mon Sep 17 00:00:00 2001 From: jasl Date: Thu, 18 Jun 2026 20:13:15 +0800 Subject: [PATCH 110/131] feat: re-enable breakable cudagraph auto-enable on SM121 for DeepSeek-V4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The SM121 carve-out in _should_auto_enable_deepseek_v4_breakable_cudagraph was added because breakable cudagraph produced garbage on trivial prompts on SM121. That was upstream #45309 (reduced eager_break_during_capture), reverted upstream in #45972 and now in our base. With the full @eager_break_during_capture split restored, breakable cudagraph generates correctly on SM121 again -- verified on 2x GB10 (EP off): "2+2等于几" and arithmetic clean, throughput on-par-or-slightly- better than FULL_AND_PIECEWISE (38.5 vs 36.7 out_tok/s). Drop the carve-out so breakable auto-enables for DeepSeek-V4 on all SM12x platforms. Signed-off-by: jasl --- .../config/test_deepseek_v4_cudagraph_config.py | 7 +++++-- vllm/config/vllm.py | 16 +++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/config/test_deepseek_v4_cudagraph_config.py b/tests/config/test_deepseek_v4_cudagraph_config.py index ddbd6c3f010f..0bf933643631 100644 --- a/tests/config/test_deepseek_v4_cudagraph_config.py +++ b/tests/config/test_deepseek_v4_cudagraph_config.py @@ -26,14 +26,17 @@ def test_deepseek_v4_auto_enables_breakable_cudagraph_off_sm121(monkeypatch): ) -def test_deepseek_v4_skips_breakable_cudagraph_on_sm121(monkeypatch): +def test_deepseek_v4_auto_enables_breakable_cudagraph_on_sm121(monkeypatch): + # Re-enabled on SM121 after upstream reverted #45309 (#45972): with the full + # @eager_break_during_capture split restored, breakable cudagraph generates + # correctly on SM121 again (verified 2x GB10, EP off, "2+2等于几" clean). monkeypatch.setattr( current_platform, "is_device_capability", lambda capability, device_id=0: capability == 121, ) - assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( + assert _should_auto_enable_deepseek_v4_breakable_cudagraph( _model_config("DeepseekV4ForCausalLM") ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 4d17ed6e91c8..df22a932d6d4 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -114,15 +114,17 @@ class OptimizationLevel(IntEnum): def _should_auto_enable_deepseek_v4_breakable_cudagraph( model_config: ModelConfig, ) -> bool: - if not any( + # Auto-enable breakable cudagraph for DeepSeek-V4 on all SM12x platforms. + # The earlier SM121 carve-out (breakable cudagraph produced garbage on + # trivial prompts there) was removed after upstream reverted #45309 in + # #45972: with the full @eager_break_during_capture split restored, + # breakable cudagraph generates correctly on SM121 again (verified on 2x + # GB10, EP off: "2+2等于几" and arithmetic clean) and is on-par-or-faster + # than FULL_AND_PIECEWISE. + return any( arch in DEEPSEEK_V4_CUDAGRAPH_ARCHITECTURES for arch in model_config.architectures - ): - return False - - from vllm.platforms import current_platform - - return not current_platform.is_device_capability(121) + ) def enable_norm_fusion(cfg: "VllmConfig") -> bool: From 3476a25a5d6a7fa54fb5c726b1bcd631b4fa5890 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 19 Jun 2026 00:02:43 +0800 Subject: [PATCH 111/131] chore: wrap SM12x indexer fallback line under ruff line-length Restore ruff cleanliness after the upstream rebase: the indexer top-k fallback contiguity guard exceeded the 88-char line limit (E501). Signed-off-by: jasl --- .../deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py index 543f4035f9d1..7ed1ecc4c05c 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_deep_gemm_fallbacks.py @@ -261,7 +261,9 @@ def _fp8_mqa_logits_topk_triton( # buffer and copy back. The buffer is left uninitialized (new_empty) # rather than copied (.contiguous()) because the op overwrites every # element, so copying the placeholder -1s in would be wasted work. - work = selected if selected.is_contiguous() else selected.new_empty(selected.shape) + work = ( + selected if selected.is_contiguous() else selected.new_empty(selected.shape) + ) topk_op( logits, cu_seqlen_ks, From 285b542b83f3573713d8dee88c218f5b2f580ffe Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 19 Jun 2026 09:24:04 +0800 Subject: [PATCH 112/131] fix(sm12x): eager-break DeepSeek-V4 attention under FULL cudagraph for spec-decode The DSv4 sparse-MLA decode is per-token (token_to_req_indices -> per-request block_table/topk gather). Captured into a FULL monolithic cudagraph under speculative decoding (MTP) it cross-contaminates concurrent requests (long-context high-concurrency gibberish) and, for the q=1 draft forward, collapses deep-context recall. Keep cudagraph_mode=FULL_AND_PIECEWISE but eager-break the DSv4 attention out of the FULL graph whenever spec-decode is active (draft + verify); the nested indexer then runs eagerly too. Non-spec single-token decode and all non-DSv4 ops keep FULL capture (zero default-path regression). GPU-validated on 2xRTX SM120: forcing the DSv4 attention to eager-break under FULL recovered no-MTP parity vs gibberish when fully captured. Signed-off-by: jasl --- vllm/compilation/breakable_cudagraph.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/breakable_cudagraph.py b/vllm/compilation/breakable_cudagraph.py index 6da3ec717861..787cc16b4d23 100644 --- a/vllm/compilation/breakable_cudagraph.py +++ b/vllm/compilation/breakable_cudagraph.py @@ -53,6 +53,13 @@ def is_breakable_cudagraph_enabled() -> bool: return bool(envs.VLLM_USE_BREAKABLE_CUDAGRAPH) +# Set True (once, at wrapper init) when the engine runs speculative decoding. +# The DeepSeek-V4 sparse-MLA decode is per-token and not FULL-cudagraph-safe for +# spec-decode batches; when this is set we eager-break the DSv4 attention out of +# the FULL graph (see eager_break_during_capture). +_BREAK_DSV4_ATTN_UNDER_FULL_FOR_SPEC = False + + F = TypeVar("F", bound=Callable[..., Any]) @@ -99,7 +106,16 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return fn(*args, **kwargs) if is_forward_context_available(): mode = get_forward_context().cudagraph_runtime_mode - if mode == CUDAGraphMode.FULL: + # Under spec-decode, the per-token DeepSeek-V4 sparse-MLA decode + # cross-contaminates requests when captured into a FULL monolithic + # cudagraph (token_to_req_indices -> per-request block_table/topk + # gather). Eager-break the DSv4 attention (the nested indexer then + # runs eagerly too) for the q=1 draft and the q>1 verify forwards; + # keep FULL capture for non-spec decode and every non-DSv4 op. + if mode == CUDAGraphMode.FULL and not ( + _BREAK_DSV4_ATTN_UNDER_FULL_FOR_SPEC + and "deepseek_v4.attention" in (getattr(fn, "__module__", "") or "") + ): return fn(*args, **kwargs) # Weak-ref args: strong refs in the replay lambda pin cudagraph-pool @@ -278,6 +294,10 @@ def __init__( # BatchDescriptor which already encodes batch shape / uniformity. self.runnable = runnable self.vllm_config = vllm_config + _sc = getattr(vllm_config, "speculative_config", None) + if _sc is not None and (getattr(_sc, "num_speculative_tokens", 0) or 0) > 0: + global _BREAK_DSV4_ATTN_UNDER_FULL_FOR_SPEC + _BREAK_DSV4_ATTN_UNDER_FULL_FOR_SPEC = True self.compilation_config = vllm_config.compilation_config self.graph_pool = current_platform.get_global_graph_pool() self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG" From 20e147276b509d7fa1a1aecb904bf837413074a5 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 19 Jun 2026 09:26:05 +0800 Subject: [PATCH 113/131] fix(v1): write-completion fence for prefix-cache block sharing A cached block was registered cache-shareable at schedule time (for total_computed_tokens + num_new_tokens, i.e. including this step's not-yet-forwarded tokens), with no barrier proving the forward that writes its KV has retired. Under >=3 concurrent identical-prefix requests a sibling could bind a recent-region block whose write was still in flight; for DeepSeek-V4, whose SWA + C4/C128 + compressor-state groups are byte-packed into one physical page, this persistently committed a corrupted block to the shared prefix cache, dropping the most-recent long-context needle (long-context high-concurrency recall failure). Add a write-completion fence: tag each block with the schedule pass at which it was committed, and only hand a cached block to OTHER requests once a forward has retired past that pass (two decoupled counters: schedule_pass advances at schedule() start, retired_forward at update_from_output). Safe cross-step prefix hits are preserved; only the unsafe in-flight intra-pass hand-off is withheld. Gated by VLLM_PREFIX_CACHE_WRITE_FENCE (default on). GPU-validated 2xRTX SM120: arthur 280-line conc=8 recall ~20% -> ~91% (MTP2) / ~78% -> ~98% (no-MTP), prefix-cache hit rate preserved (~87% vs ~90%); GSM8K-200 0.97 + issue19 JSON-only PASS (no correctness regression); GB10 SM121 2-node non-regressive + serves cleanly. Signed-off-by: jasl --- vllm/envs.py | 10 +++++++ vllm/v1/core/block_pool.py | 48 ++++++++++++++++++++++++++++++--- vllm/v1/core/kv_cache_utils.py | 5 ++++ vllm/v1/core/sched/scheduler.py | 4 +++ 4 files changed, 64 insertions(+), 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 56ae2eaff300..0cb64c7c734e 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -157,6 +157,7 @@ VLLM_USE_STANDALONE_COMPILE: bool = True VLLM_ENABLE_PREGRAD_PASSES: bool = True VLLM_USE_BREAKABLE_CUDAGRAPH: bool = False + VLLM_PREFIX_CACHE_WRITE_FENCE: bool = True VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False @@ -708,6 +709,15 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_USE_BREAKABLE_CUDAGRAPH": lambda: ( os.environ.get("VLLM_USE_BREAKABLE_CUDAGRAPH", "0") == "1" ), + # Prefix-cache write-completion fence: only expose a cached block to OTHER + # requests after the forward that wrote its tokens has retired. Prevents a + # concurrent same-prefix request from binding a recent-region block whose + # KV/compressed write is still in flight (DeepSeek-V4 packed multi-group + # pages make this corrupt the recent context under conc>=3). Default on; + # set to 0 to restore the legacy expose-at-schedule-time behavior. + "VLLM_PREFIX_CACHE_WRITE_FENCE": lambda: ( + os.environ.get("VLLM_PREFIX_CACHE_WRITE_FENCE", "1") == "1" + ), # Debug pattern matching inside custom passes. # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get( diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index dac8914b4a22..3b949513f054 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -3,6 +3,7 @@ from collections.abc import Iterable, Sequence from typing import Any +import vllm.envs as envs from vllm.distributed.kv_events import ( MEDIUM_GPU, AllBlocksCleared, @@ -72,6 +73,25 @@ def get_one_block(self, key: BlockHashWithGroupId) -> KVCacheBlock | None: self._unexpected_blocks_type(blocks) return None + def get_one_block_retired( + self, key: BlockHashWithGroupId, retired_forward: int + ) -> "KVCacheBlock | None": + """Like get_one_block, but only returns a block whose committed_step + has retired (committed_step <= retired_forward), so a concurrent request + never binds a block whose writing forward is still in flight.""" + blocks = self._cache.get(key) + if blocks is None: + return None + if isinstance(blocks, KVCacheBlock): + return blocks if blocks.committed_step <= retired_forward else None + if isinstance(blocks, dict): + for blk in blocks.values(): + if blk.committed_step <= retired_forward: + return blk + return None + self._unexpected_blocks_type(blocks) + return None + def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None: """ Inserts the KVCacheBlock to the cache @@ -185,6 +205,22 @@ def __init__( self.metrics_collector = metrics_collector + # Prefix-cache write-completion fence (VLLM_PREFIX_CACHE_WRITE_FENCE). + # schedule_pass advances at each schedule() start; retired_forward + # advances when a forward completes (update_from_output). A block + # committed at schedule_pass S is exposed to other requests only once + # retired_forward >= S (its writing forward has retired). Async-safe: + # the two clocks decouple commit-time from write-completion-time. + self.write_fence = envs.VLLM_PREFIX_CACHE_WRITE_FENCE + self.schedule_pass = 0 + self.retired_forward = 0 + + def advance_schedule_pass(self) -> None: + self.schedule_pass += 1 + + def advance_retired_forward(self) -> None: + self.retired_forward += 1 + def get_cached_block( self, block_hash: BlockHash, kv_cache_group_ids: list[int] ) -> list[KVCacheBlock] | None: @@ -204,9 +240,14 @@ def get_cached_block( block_hash_with_group_id = make_block_hash_with_group_id( block_hash, group_id ) - block = self.cached_block_hash_to_block.get_one_block( - block_hash_with_group_id - ) + if self.write_fence: + block = self.cached_block_hash_to_block.get_one_block_retired( + block_hash_with_group_id, self.retired_forward + ) + else: + block = self.cached_block_hash_to_block.get_one_block( + block_hash_with_group_id + ) if not block: return None cached_blocks.append(block) @@ -282,6 +323,7 @@ def cache_full_blocks( block_hash, kv_cache_group_id ) blk.block_hash = block_hash_with_group_id + blk.committed_step = self.schedule_pass self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk) if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index a1ebe08c0789..2120e5382b60 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -133,6 +133,11 @@ class KVCacheBlock: # Whether the block is a null block that should never be cached. is_null: bool = False + # Schedule-pass index at which this block was committed to the prefix cache. + # Used by the write-completion fence to withhold the block from other + # requests until the forward that wrote it has retired (-1 = uncommitted). + committed_step: int = -1 + @property def block_hash(self) -> BlockHashWithGroupId | None: return self._block_hash diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 5f0b926ba739..20aa1fa48ccd 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -517,6 +517,8 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: scheduled_timestamp = time.monotonic() self.kv_cache_manager.new_step_starts() + if self.kv_cache_manager.block_pool.write_fence: + self.kv_cache_manager.block_pool.advance_schedule_pass() # DP prefill balancing: on a throttled (non-cadence-aligned) step, defer # all prefill compute unless saturated. @@ -1614,6 +1616,8 @@ def update_from_output( scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: + if self.kv_cache_manager.block_pool.write_fence: + self.kv_cache_manager.block_pool.advance_retired_forward() sampled_token_ids = model_runner_output.sampled_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict From 6c92b0972813a97f39594a64d90a1a5a4e536b5f Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 19 Jun 2026 17:33:28 +0800 Subject: [PATCH 114/131] perf(sm12x): default DeepSeek-V4 to FULL_AND_PIECEWISE (drop breakable-cudagraph auto-enable) Breakable cudagraph was auto-enabled for DeepSeek-V4 on the assumption it was "on-par-or-faster than FULL_AND_PIECEWISE". That is wrong: breakable mode disables the torch.compile pipeline (equivalent to -O.mode=none) and runs attention eagerly every decode step, so single-stream MTP decode is 1.5-3.8x SLOWER and degrades with output length. Measured on both arches: RTX PRO 6000 (SM120), single-stream MTP2 decode tok/s, breakable on vs off: max_tokens 400: ~103 vs ~160 600: ~55 vs ~175 800: ~45 vs ~172 2x GB10 (SM121) shows the same on/off split (community report on #41834). FULL_AND_PIECEWISE + torch.compile is correct (GSM8K-200 0.96, bare prompts clean: 2+2->4, capital->Paris) and faster, so make it the default for both SM120 and SM121. Breakable stays available via VLLM_USE_BREAKABLE_CUDAGRAPH=1 for the MTP + long-context + high-concurrency garbled-output workaround (which then also engages the spec-decode attention eager-break). The prefix-cache write-fence and the spec-break fix are unaffected; the latter is now opt-in with breakable. Signed-off-by: jasl --- .../test_deepseek_v4_cudagraph_config.py | 15 ++++++------- vllm/config/vllm.py | 22 +++++++++---------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/tests/config/test_deepseek_v4_cudagraph_config.py b/tests/config/test_deepseek_v4_cudagraph_config.py index 0bf933643631..5a938b0ab9d9 100644 --- a/tests/config/test_deepseek_v4_cudagraph_config.py +++ b/tests/config/test_deepseek_v4_cudagraph_config.py @@ -11,32 +11,31 @@ def _model_config(*architectures: str): return SimpleNamespace(architectures=list(architectures)) -def test_deepseek_v4_auto_enables_breakable_cudagraph_off_sm121(monkeypatch): +def test_deepseek_v4_does_not_auto_enable_breakable_cudagraph(monkeypatch): + # Breakable cudagraph disables torch.compile and is 1.5-3.8x slower for MTP + # decode on SM12x (measured); DeepSeek-V4 defaults to FULL_AND_PIECEWISE. monkeypatch.setattr( current_platform, "is_device_capability", lambda capability, device_id=0: False, ) - assert _should_auto_enable_deepseek_v4_breakable_cudagraph( + assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( _model_config("DeepseekV4ForCausalLM") ) - assert _should_auto_enable_deepseek_v4_breakable_cudagraph( + assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( _model_config("DeepSeekV4MTPModel") ) -def test_deepseek_v4_auto_enables_breakable_cudagraph_on_sm121(monkeypatch): - # Re-enabled on SM121 after upstream reverted #45309 (#45972): with the full - # @eager_break_during_capture split restored, breakable cudagraph generates - # correctly on SM121 again (verified 2x GB10, EP off, "2+2等于几" clean). +def test_deepseek_v4_does_not_auto_enable_breakable_cudagraph_on_sm121(monkeypatch): monkeypatch.setattr( current_platform, "is_device_capability", lambda capability, device_id=0: capability == 121, ) - assert _should_auto_enable_deepseek_v4_breakable_cudagraph( + assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( _model_config("DeepseekV4ForCausalLM") ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index df22a932d6d4..4c2a60e4f3a3 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -114,17 +114,17 @@ class OptimizationLevel(IntEnum): def _should_auto_enable_deepseek_v4_breakable_cudagraph( model_config: ModelConfig, ) -> bool: - # Auto-enable breakable cudagraph for DeepSeek-V4 on all SM12x platforms. - # The earlier SM121 carve-out (breakable cudagraph produced garbage on - # trivial prompts there) was removed after upstream reverted #45309 in - # #45972: with the full @eager_break_during_capture split restored, - # breakable cudagraph generates correctly on SM121 again (verified on 2x - # GB10, EP off: "2+2等于几" and arithmetic clean) and is on-par-or-faster - # than FULL_AND_PIECEWISE. - return any( - arch in DEEPSEEK_V4_CUDAGRAPH_ARCHITECTURES - for arch in model_config.architectures - ) + # DeepSeek-V4 does NOT auto-enable breakable cudagraph. Breakable mode + # disables the torch.compile pipeline (equivalent to -O.mode=none) and runs + # attention eagerly every decode step; on SM12x that is 1.5-3.8x SLOWER for + # MTP decode and degrades with output length, measured on both RTX PRO 6000 + # (SM120) and 2x GB10 (SM121). FULL_AND_PIECEWISE + torch.compile is correct + # (GSM8K parity, bare-prompt clean) and faster, so it is the default. + # Opt in with VLLM_USE_BREAKABLE_CUDAGRAPH=1 for the MTP + long-context + + # high-concurrency garbled-output workaround (which also engages the + # spec-decode attention eager-break). + del model_config # architecture-independent: never auto-enable + return False def enable_norm_fusion(cfg: "VllmConfig") -> bool: From f108fa9d1c4e52702976644f00d1503c511a8e89 Mon Sep 17 00:00:00 2001 From: jasl Date: Fri, 19 Jun 2026 23:08:29 +0800 Subject: [PATCH 115/131] fix(sm12x): COW shared final block for DeepSeek-V4 writable cache groups under spec decode Under MTP/spec decode, _annotate_eagle_groups_deepseek_v4 flags only the MTP draft layer's group as an eagle group, so only it receives drop_eagle_block (the COW that gives spec decode a private final block). DeepSeek-V4's other writable cache groups -- notably the sliding-window caches (SWA + compressor state, SlidingWindowMLASpec), which are rewritten every decode step -- kept a final block that is physically SHARED across concurrent identical prefix-hit requests. Under MTP that shared block is written while sibling requests read it, corrupting the KV the sparse-MLA decode consumes and producing ~9% needle-drop / mixed-script gibberish under long-context concurrent traffic (reporter: arthur, PR #41834). Fix: in HybridKVCacheCoordinator.find_longest_cache_hit and find_longest_cache_hit_per_group, extend drop_eagle_block to fire for writable compressed (compress_ratio>1) and sliding-window groups whenever spec decode is active (eagle_group_ids non-empty), not only the MTP group. The sliding-window drop is the load-bearing part; compressed (C4/C128) is included for completeness (all DeepSeek-V4 writable groups get a private final block under spec). Perf-neutral: one extra private final block per writable group under spec. Validated: RTX SM120 TP=2 warm-full-hit conc=8 MTP2 ~9%->0% recall-miss (0.4% conc=16); GB10 SM121 2-node 64/64; GSM8K-200 0.96-0.975; decode 170-180 tok/s (no regression); bare prompts clean. Signed-off-by: jasl --- vllm/v1/core/kv_cache_coordinator.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 96d74d0bb950..4fa9e8121961 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -705,7 +705,16 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: ) continue - drop_eagle_block = use_eagle and idx not in eagle_verified + drop_eagle_block = ( + use_eagle + or ( + bool(self.eagle_group_ids) + and ( + getattr(spec, "compress_ratio", 1) > 1 + or getattr(spec, "sliding_window", None) is not None + ) + ) + ) and idx not in eagle_verified _max_length = curr_hit_length if drop_eagle_block: @@ -784,7 +793,16 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: kv_cache_group_ids=group_ids, block_pool=self.block_pool, kv_cache_spec=spec, - drop_eagle_block=use_eagle, + drop_eagle_block=( + use_eagle + or ( + bool(self.eagle_group_ids) + and ( + getattr(spec, "compress_ratio", 1) > 1 + or getattr(spec, "sliding_window", None) is not None + ) + ) + ), alignment_tokens=self.scheduler_block_size, ) group_hit = len(blocks[0]) * spec.block_size From 67e060d8c272240ec52642f7d50640d74e81d590 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 02:47:15 +0800 Subject: [PATCH 116/131] perf(sm12x): default indexed-D512 prefill min-token gate to 4096 (env-tunable) Re-add VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS (dropped in the upstream rebase, then left hard-coded at 8192) and default it to 4096. With max-num-batched-tokens=4096 this admits the first ~8192-token region (the early chunks) of every long prefill to the fast indexed-D512 path instead of the slow fallback. Measured on 2x RTX PRO 6000 (SM120): +57%@4096, +28%@8192, decaying to +9%@24k (gain proportional to 1/num_chunks); GSM8K-200 strict 0.965 (clean) and KV-cache-neutral. Set =8192 to restore the prior threshold. Signed-off-by: jasl --- vllm/envs.py | 9 +++++++++ vllm/models/deepseek_v4/nvidia/flashmla.py | 7 ++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 0cb64c7c734e..bb9f3c77fcee 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -182,6 +182,7 @@ VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True VLLM_ENABLE_DEEPSEEK_V4_SPARSE_MLA_WARMUP: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL: bool = True + VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS: int = 4096 VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL: bool = True VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE: bool = False @@ -1471,6 +1472,14 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL": lambda: bool( int(os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL", "1")) ), + # Minimum prefill sequence length to admit a row to the indexed-D512 fast + # prefill path. Lower = more (shorter / early-chunk) prefills use the fast + # kernel. 4096 measured +9-59% short/medium-prefill tok/s (biggest at the + # 4-8k band), GSM8K-clean and KV-cache-neutral on SM12x; set 8192 to revert + # to the prior conservative threshold. + "VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS": lambda: int( + os.getenv("VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS", "4096") + ), # Pre-compile the D512-split sparse-MLA prefill Triton kernels at startup # (one per 128-aligned combined_topk in [256, 1152]) so the first long # prefill does not pay a first-use JIT compile that blocks the engine step. diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 5fb3d9137c61..37cac1ae0019 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -55,7 +55,6 @@ from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadata -_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS = 8192 _INDEXED_D512_SPLIT_PREFILL_MIN_TOPK = 256 _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK = 1152 @@ -76,7 +75,8 @@ def _use_indexed_d512_split_prefill( and head_dim == 512 and num_prefills == 1 and _is_indexed_d512_split_topk(combined_topk) - and max_prefill_seq_len >= _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS + and max_prefill_seq_len + >= envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS ) @@ -105,7 +105,8 @@ def _use_indexed_d512_chunked_prefill( and head_dim == 512 and num_prefills == 1 and combined_topk > _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK - and max_prefill_seq_len >= _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS + and max_prefill_seq_len + >= envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS ) From e57276b5540ebddcd066e636a31ebcac6bbe58bd Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 04:44:23 +0800 Subject: [PATCH 117/131] =?UTF-8?q?revert:=20drop=5Feagle=5Fblock=20broade?= =?UTF-8?q?ning=20=E2=80=94=20breaks=20prefix=20caching=20under=20MTP?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverts a743ef5dfbd. The broadened drop_eagle_block (copy-on-write the shared final block for all DSv4 compressed/sliding-window cache groups under spec decode) fixed the long-context + high-concurrency recall/garble bug, but it drops nearly all prefix-cache reuse under MTP, not just the contended block. On 2-node GB10 the prefix-cache hit rate was 0% with it (warm request = full re-prefill) vs 33-71% without; a same-build recall-fix ON/OFF x MTP ON/OFF matrix plus a revert test isolated it definitively. That is a broad perf regression for every multi-turn / agentic MTP request, traded for a narrow recall corner case. Reverting to restore caching; reopening the recall bug pending a surgical fix (drop only the genuinely spec-written block). Signed-off-by: jasl --- vllm/v1/core/kv_cache_coordinator.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 4fa9e8121961..96d74d0bb950 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -705,16 +705,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: ) continue - drop_eagle_block = ( - use_eagle - or ( - bool(self.eagle_group_ids) - and ( - getattr(spec, "compress_ratio", 1) > 1 - or getattr(spec, "sliding_window", None) is not None - ) - ) - ) and idx not in eagle_verified + drop_eagle_block = use_eagle and idx not in eagle_verified _max_length = curr_hit_length if drop_eagle_block: @@ -793,16 +784,7 @@ def _get_block_hashes(kv_cache_spec: KVCacheSpec) -> BlockHashList: kv_cache_group_ids=group_ids, block_pool=self.block_pool, kv_cache_spec=spec, - drop_eagle_block=( - use_eagle - or ( - bool(self.eagle_group_ids) - and ( - getattr(spec, "compress_ratio", 1) > 1 - or getattr(spec, "sliding_window", None) is not None - ) - ) - ), + drop_eagle_block=use_eagle, alignment_tokens=self.scheduler_block_size, ) group_hit = len(blocks[0]) * spec.block_size From 197d21ee790ce0cb366dae40e5d4f359f7ed34d9 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 11:02:21 +0800 Subject: [PATCH 118/131] fix(sm12x): int64 block offsets in indexer paged-MQA-logits kernel The post-rebase indexer kv_cache block stride grew (~1039680, a strided slice of the fused KV block); block_idx * stride overflowed int32 in the SM120 paged MQA-logits Triton kernel for block_idx beyond ~2065, giving an illegal memory access on SM120 and silent garbage on SM121 under long-context / multi-request indexer calls. Cast block_idx to int64 at the K and scale gather sites in _fp8_paged_mqa_logits_rowwise_kernel. Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index 3116a8d90c98..1b5bffdf6b4f 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -373,7 +373,7 @@ def _fp8_paged_mqa_logits_rowwise_kernel( ) scale = tl.load( - scale_ptr + block_idx * stride_sb + block_offset * stride_ss, + scale_ptr + block_idx.to(tl.int64) * stride_sb + block_offset * stride_ss, mask=context_mask, other=0.0, ) @@ -396,7 +396,7 @@ def _fp8_paged_mqa_logits_rowwise_kernel( ).to(tl.float32) k = tl.load( kv_ptr - + block_idx[None, :] * stride_kvb + + block_idx[None, :].to(tl.int64) * stride_kvb + block_offset[None, :] * stride_kvs + d[:, None] * stride_kvd, mask=context_mask[None, :] & (d[:, None] < head_dim), From d782c62b865de011e836b2316d87ed866858e203 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 11:02:28 +0800 Subject: [PATCH 119/131] feat(sm12x): env-gated packed FlashInfer sparse-MLA prefill for DeepSeek-V4 Add a _forward_prefill override driving the SM120 packed sparse-MLA runner, gated by VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL (default off; defers to the FlashMLA indexed-D512 prefill when off). Rebase the indexer's batch-global compressed top-k to per-request-local before the per-request block-table map, and slice the query to the real prefill-token count to stay consistent under padded / MTP-draft batches. ~+5-6% single-stream prefill, flat with concurrency vs the FlashMLA prefill path. Signed-off-by: jasl --- vllm/envs.py | 4 + .../nvidia/flashinfer_sm120_decode.py | 173 ++++++++++++++++++ vllm/v1/attention/backends/mla/sparse_swa.py | 55 +++++- 3 files changed, 229 insertions(+), 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index bb9f3c77fcee..0a31094d5981 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -186,6 +186,7 @@ VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_WARMUP: bool = True VLLM_DEEPSEEK_V4_INDEXED_D512_CHUNKED_PREFILL: bool = True VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE: bool = False + VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL: bool = False VLLM_TRITON_MLA_SPARSE: bool | None = None VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 @@ -1492,6 +1493,9 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE": lambda: bool( int(os.getenv("VLLM_DEEPSEEK_V4_FLASHINFER_SM120_DECODE", "0")) ), + "VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL": lambda: bool( + int(os.getenv("VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL", "0")) + ), # Experimental sparse MLA fallback controls. # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse # is unavailable; set 0/1 to force-disable/force-enable the fallback. diff --git a/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py index 21717e573add..2ad3c8c8af38 100644 --- a/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py +++ b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py @@ -94,6 +94,17 @@ def _as_sparse_sm120_cache(kv_cache: torch.Tensor) -> torch.Tensor: return kv_cache.unsqueeze(-2) +def _get_prefill_swa_scratch( + num_tokens: int, window_size: int +) -> tuple[torch.Tensor, torch.Tensor]: + # Graph-stable per-token SWA window indices + lengths for the prefill tokens. + swa_indices, swa_lens = current_workspace_manager().get_simultaneous( + ((num_tokens, 1, window_size), torch.int32), + ((num_tokens,), torch.int32), + ) + return swa_indices, swa_lens + + class DeepseekV4FlashInferSM120Attention(DeepseekV4FlashMLAAttention): """FlashMLA V4 attention with the official FlashInfer SM120 packed decode. @@ -282,3 +293,165 @@ def _forward_decode( mid_out=mid_out, mid_lse=mid_lse, ) + + def _forward_prefill( + self, + q: torch.Tensor, + positions: torch.Tensor, + compressed_k_cache: torch.Tensor | None, + swa_k_cache: torch.Tensor, + output: torch.Tensor, + attn_metadata: "DeepseekV4FlashMLAMetadata | None", + swa_metadata: "DeepseekSparseSWAMetadata", + ) -> None: + import vllm.envs as envs + + # Packed prefill is an independent opt-in on top of the decode port; when + # off, defer to the FlashMLA indexed-D512 prefill path byte-for-byte. + if not envs.VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL: + super()._forward_prefill( + q, + positions, + compressed_k_cache, + swa_k_cache, + output, + attn_metadata, + swa_metadata, + ) + return + + swa_only = attn_metadata is None + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + num_prefills = swa_metadata.num_prefills + num_prefill_tokens = swa_metadata.num_prefill_tokens + num_reqs = num_decodes + num_prefills + num_tokens = num_decode_tokens + num_prefill_tokens + if num_prefill_tokens == 0: + return + + assert swa_metadata.is_valid_token is not None + assert swa_metadata.query_start_loc is not None + assert swa_metadata.seq_lens is not None + assert swa_metadata.token_to_req_indices is not None + assert swa_metadata.block_table is not None + + # --- Prefill SWA window indices. The metadata builder hoists this once per + # step (DeepseekSparseSWAMetadataBuilder.build widens its decode-SWA launch + # over the prefill tail), so steady-state just reads the precomputed views + # and skips ~60 redundant per-layer kernel launches. The builder deliberately + # leaves them None on the warmup/profile dummy (its all-(-1) slot_mapping + # makes is_valid_token all-False over the prefill tail) and during CUDA-graph + # capture; we then self-compute exactly as before, keeping those paths + # byte-identical to the validated v1 (the prefill packed kernel is a no-op + # over the all-invalid dummy: every token gets swa_len=0). + swa_indices = swa_metadata.prefill_swa_indices + swa_lens = swa_metadata.prefill_swa_lens + if swa_indices is None or swa_lens is None: + from vllm.v1.attention.backends.mla.sparse_swa import ( + _compute_swa_indices_and_lens_kernel, + ) + + swa_idx_full, swa_len_full = _get_prefill_swa_scratch( + num_tokens, self.window_size + ) + _compute_swa_indices_and_lens_kernel[(num_tokens,)]( + swa_idx_full, + swa_idx_full.stride(0), + swa_len_full, + self.window_size, + swa_metadata.query_start_loc, + swa_metadata.seq_lens, + swa_metadata.token_to_req_indices, + swa_metadata.is_valid_token, + swa_metadata.block_table, + swa_metadata.block_table.stride(0), + swa_metadata.block_size, + TRITON_BLOCK_SIZE=1024, + ) + swa_indices = swa_idx_full[num_decode_tokens:num_tokens] + swa_lens = swa_len_full[num_decode_tokens:num_tokens] + + # --- Compressed (extra) prefill indices, mirroring the FlashMLA prefill + # construction but converted to global slots for the packed kernel. + topk_indices = None + topk_lens = None + if not swa_only: + assert attn_metadata is not None + block_size = attn_metadata.block_size // self.compress_ratio + if self.compress_ratio == 4: + assert self.topk_indices_buffer is not None + prefill_local = self.topk_indices_buffer[num_decode_tokens:num_tokens] + # Rebase the indexer's BATCH-GLOBAL compressed top-k positions + # (cu_seqlen_ks = exclusive cumsum of seq_len // compress_ratio; see + # indexer.py) to per-request-local so block_table[req] maps them + # in-range. Without this, req>0 positions overflow into the wrong + # request's physical blocks. No-op at num_prefills==1 (cu_base[0]==0). + comp_lens = ( + swa_metadata.seq_lens[num_decodes:num_reqs] // self.compress_ratio + ) + cu_base = (torch.cumsum(comp_lens, dim=0) - comp_lens).to(torch.int32) + req_local = ( + swa_metadata.token_to_req_indices[num_decode_tokens:num_tokens] + - num_decodes + ).long() + base_per_token = cu_base[req_local].unsqueeze(1) + prefill_local = torch.where( + prefill_local >= 0, prefill_local - base_per_token, prefill_local + ) + global_indices, topk_lens = compute_global_topk_indices_and_lens( + prefill_local, + swa_metadata.token_to_req_indices[num_decode_tokens:num_tokens], + attn_metadata.block_table[:num_reqs], + block_size, + swa_metadata.is_valid_token[num_decode_tokens:num_tokens], + ) + topk_indices = global_indices.view(num_prefill_tokens, 1, -1) + else: + assert attn_metadata.c128a_prefill_topk_indices is not None + topk_indices = attn_metadata.c128a_prefill_topk_indices.view( + num_prefill_tokens, 1, -1 + ) + topk_indices = topk_indices.contiguous() + + # --- Launch the packed prefill kernel via the runner. num_tokens > 64 + # auto-dispatches the prefill kernel; mid_out/mid_lse are decode-only and + # only needed for the (rare) <=64-token prefill chunk. + query = self._prepare_sm120_query(q, output) + # Bug-C guard: under CUDA-graph padding or MTP-draft, q can carry more rows + # than the real prefill-token count; the runner sizes its writes by the + # query row count, so slice to num_prefill_tokens to match output/indices/ + # scratch (no-op in the common, unpadded case). + if query.shape[0] > num_prefill_tokens: + query = query[:num_prefill_tokens] + swa_cache = _as_sparse_sm120_cache(swa_k_cache) + extra_cache = ( + _as_sparse_sm120_cache(compressed_k_cache) + if (compressed_k_cache is not None and not swa_only) + else None + ) + mid_out = None + mid_lse = None + if num_prefill_tokens <= _DECODE_MAX_TOKENS: + extra_topk = topk_indices.shape[-1] if topk_indices is not None else 0 + mid_out, mid_lse = _get_decode_scratch( + num_prefill_tokens, + output.shape[1], + output.shape[-1], + swa_indices.shape[-1], + extra_topk, + ) + self._sm120_runner.run( + query, + swa_cache, + swa_indices, + output, + self.scale, + topk_length=swa_lens, + attn_sink=self.attn_sink, + extra_kv_cache=extra_cache, + extra_indices=topk_indices, + extra_topk_length=topk_lens, + mid_out=mid_out, + mid_lse=mid_lse, + ) diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index a4511096116f..4e14eed3901f 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -5,7 +5,9 @@ import torch +import vllm.envs as envs from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -28,6 +30,8 @@ SlidingWindowMLASpec, ) +logger = init_logger(__name__) + # DeepseekV4 decode layer types, keyed by compress_ratio. Each type has a distinct # (topk, extra_topk, extra_page_block_size) config, so they cannot share a # FlashMLA tile-scheduler plan. Within a type, all ~60 DeepseekV4 layers share one @@ -167,6 +171,15 @@ class DeepseekSparseSWAMetadata: token_to_req_indices: torch.Tensor | None = None # [num_tokens] decode_swa_indices: torch.Tensor | None = None # [num_decode_tokens, window_size] decode_swa_lens: torch.Tensor | None = None # [num_decode_tokens] + # Prefill SWA window indices, hoisted once-per-step by the builder as views of + # the decode_swa_* buffers over the prefill token range [num_decode_tokens: + # num_tokens]. Populated only when the FlashInfer SM120 packed-prefill feature + # is active AND the step has real prefill tokens; None otherwise (gate-off / + # warmup / CUDA-graph capture / decode-only) -> the layer self-computes. + prefill_swa_indices: torch.Tensor | None = ( + None # [num_prefill_tokens, 1, window_size] + ) + prefill_swa_lens: torch.Tensor | None = None # [num_prefill_tokens] # Number of decode/prefill requests/tokens (batch is reordered: decodes first) num_decodes: int = 0 @@ -385,9 +398,35 @@ def build( is_valid_token = self.is_valid_token[: slot_mapping.shape[0]] is_valid_token.copy_(slot_mapping >= 0) - if num_decode_tokens > 0: - self.decode_swa_lens[num_decode_tokens:] = 0 - _compute_swa_indices_and_lens_kernel[(num_decode_tokens,)]( + # SWA window indices are keyed by the GLOBAL token index, so a single launch + # over [0, swa_total_tokens) fills the decode rows and -- when the FlashInfer + # SM120 packed-prefill feature is active -- the prefill rows too, hoisting the + # per-token SWA compute out of the per-layer _forward_prefill (~60x/step) into + # this once-per-step build(). decode_swa_* are sized max_num_batched_tokens, + # so num_tokens always fits. + num_tokens = num_decode_tokens + num_prefill_tokens + # Gate the prefill widening conservatively; the predicate short-circuits + # left-to-right so the is_valid_token.any() device sync runs ONLY when the + # feature is on, there are prefill tokens, and we are not in CUDA-graph + # capture (the sync is illegal during capture). The warmup/profile prefill + # dummy fills slot_mapping with -1, so is_valid_token over the prefill tail is + # all-False -> .any() is False -> the widened launch is skipped (this is the + # exact OOB that hung the earlier metadata attempt). + want_prefill_swa = ( + envs.VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL + and num_prefill_tokens > 0 + and not torch.cuda.is_current_stream_capturing() + and bool(is_valid_token[num_decode_tokens:num_tokens].any()) + ) + swa_total_tokens = num_tokens if want_prefill_swa else num_decode_tokens + if want_prefill_swa: + logger.info_once( + "DeepSeek V4 SM120: prefill SWA window indices hoisted into the " + "metadata builder (once per step, replacing per-layer recompute)." + ) + if swa_total_tokens > 0: + self.decode_swa_lens[swa_total_tokens:] = 0 + _compute_swa_indices_and_lens_kernel[(swa_total_tokens,)]( self.decode_swa_indices, self.decode_swa_indices.stride(0), self.decode_swa_lens, @@ -428,6 +467,16 @@ def build( token_to_req_indices=token_to_req_indices, decode_swa_indices=self.decode_swa_indices[:num_decode_tokens], decode_swa_lens=self.decode_swa_lens[:num_decode_tokens], + prefill_swa_indices=( + self.decode_swa_indices[num_decode_tokens:num_tokens] + if want_prefill_swa + else None + ), + prefill_swa_lens=( + self.decode_swa_lens[num_decode_tokens:num_tokens] + if want_prefill_swa + else None + ), block_size=self.block_size, num_decodes=num_decodes, num_prefills=num_prefills, From c736caf00441740ba1955b34f3f1847c745ce1d4 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 11:02:30 +0800 Subject: [PATCH 120/131] fix(sm12x): rebase compressed top-k per-request in FlashMLA prefill combine combine_topk_swa_indices maps a compressed position p of the k-th in-chunk request to gathered slot p + M*k, which is only correct for request-local p; the indexer writes batch-global (cu_seqlen_ks) positions, so non-first prefill requests indexed past their gathered slot and read stale workspace (latent C4A multi-request prefill correctness bug). Rebase to per-request-local before combine. No-op at num_prefills == 1. Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/flashmla.py | 23 ++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/vllm/models/deepseek_v4/nvidia/flashmla.py b/vllm/models/deepseek_v4/nvidia/flashmla.py index 37cac1ae0019..a59452712a97 100644 --- a/vllm/models/deepseek_v4/nvidia/flashmla.py +++ b/vllm/models/deepseek_v4/nvidia/flashmla.py @@ -910,6 +910,29 @@ def _forward_prefill( assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[num_decode_tokens:] topk_indices = topk_indices[:num_prefill_tokens] + # Rebase the indexer's BATCH-GLOBAL compressed top-k positions to + # per-request-local. combine_topk_swa_indices maps a position p of + # the k-th in-chunk request to gathered slot p + M*k, which is only + # correct when p is request-local; the indexer writes cu_seqlen_ks- + # cumulative (batch-global) positions, so without this rebase non- + # first prefill requests index past their gathered slot and read + # stale workspace (latent C4A multi-request prefill bug). torch.where + # preserves -1 sentinels; no-op at num_prefills==1 (cu_base[0]==0). + assert swa_metadata.token_to_req_indices is not None + _comp_lens = seq_lens // self.compress_ratio + _cu_base = (torch.cumsum(_comp_lens, dim=0) - _comp_lens).to( + torch.int32 + ) + _req_local = ( + swa_metadata.token_to_req_indices[ + num_decode_tokens : num_decode_tokens + num_prefill_tokens + ] + - num_decodes + ).long() + _base = _cu_base[_req_local].unsqueeze(1) + topk_indices = torch.where( + topk_indices >= 0, topk_indices - _base, topk_indices + ) else: # C128A: pre-computed during metadata build. assert attn_metadata is not None From 88ec87e1e074249f5b9194d66c1f6ce1b4fe8408 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 13:18:20 +0800 Subject: [PATCH 121/131] fix(sm12x): int64 block offsets in the non-rowwise paged-MQA-logits kernel too 197d21ee79 cast block_idx to int64 in _fp8_paged_mqa_logits_rowwise_kernel but left _fp8_paged_mqa_logits_kernel (the non-rowwise variant) unfixed: its `block_idx * stride_sb` (scale) and `block_idx[:, :, None] * stride_kvb` (KV) still overflow int32 for the post-rebase packed-KV block stride (~1039680) at higher block ids. conc=1 exercises the rowwise kernel (already clean) but conc=8 hits this one and hard-crashes (Xid 31 MMU fault, SM121 / IMA on SM120). Cast both block-offset multiplies to int64 to match the rowwise kernel. Validated on GB10 2-node (SM121, MTP2): conc=8 long-context coherence 68/68 (0 gibberish, 0 recall-miss; was a hard CUDA crash), GSM8K-200 0.965. Signed-off-by: jasl --- vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py index 1b5bffdf6b4f..ac15de394957 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sm12x_mqa.py @@ -249,7 +249,7 @@ def _fp8_paged_mqa_logits_kernel( logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) scale = tl.load( - scale_ptr + block_idx * stride_sb + block_offset[None, :] * stride_ss, + scale_ptr + block_idx.to(tl.int64) * stride_sb + block_offset[None, :] * stride_ss, mask=context_mask, other=0.0, ) @@ -268,7 +268,7 @@ def _fp8_paged_mqa_logits_kernel( ).to(tl.float32) k = tl.load( kv_ptr - + block_idx[:, :, None] * stride_kvb + + block_idx[:, :, None].to(tl.int64) * stride_kvb + block_offset[None, :, None] * stride_kvs + d[None, None, :] * stride_kvd, mask=context_mask[:, :, None] & (d[None, None, :] < head_dim), From 3f42d92be83b867c2f79ced18e4924f3061d4016 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 22:16:02 +0800 Subject: [PATCH 122/131] chore(sm12x-audit): restore upstream multi-line form for VLLM_ENFORCE_STRICT_TOOL_CALLING Cosmetic reflow churn from 59c7918b16 on an upstream-owned env; restore byte-identical to base 0fbf42af84. No behavior change. --- vllm/envs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 0a31094d5981..c6a24245897f 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1653,9 +1653,10 @@ def _resolve_rust_frontend_path() -> str | None: os.getenv("VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS", "1") ), # Enforce function parameter schemas in structural-tag based tool calling. - "VLLM_ENFORCE_STRICT_TOOL_CALLING": lambda: ( - os.getenv("VLLM_ENFORCE_STRICT_TOOL_CALLING", "True").lower() in ("true", "1") - ), + "VLLM_ENFORCE_STRICT_TOOL_CALLING": lambda: os.getenv( + "VLLM_ENFORCE_STRICT_TOOL_CALLING", "True" + ).lower() + in ("true", "1"), # Control the max chunk bytes (in MB) for the rpc message queue. # Object larger than this threshold will be broadcast to worker # processes via zmq. From ec4ca5739f106294e8195f185e3d65fc1b7bbb9e Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 22:19:33 +0800 Subject: [PATCH 123/131] refactor(sm12x-audit): express breakable-cudagraph auto-enable as a real MiniMax-only gate The DSv4-out behavior (default FULL_AND_PIECEWISE, 1.5-3.8x faster MTP decode, measured RTX/SM120 + GB10/SM121) is unchanged. Replaces the dead always-False _should_auto_enable_deepseek_v4_breakable_cudagraph stub + unused DEEPSEEK_V4_CUDAGRAPH_ARCHITECTURES frozenset + misleading SM120/SM121 comment with a single meaningful _should_auto_enable_breakable_cudagraph(model_config) that returns True only for the MiniMax M3 architectures (upstream's auto-enable set minus DSv4). Test upgraded from tautological always-False asserts to observable behavior: DSv4 off, MiniMax on, others off. --- .../test_deepseek_v4_cudagraph_config.py | 39 ++++--------- vllm/config/vllm.py | 55 ++++++++----------- 2 files changed, 35 insertions(+), 59 deletions(-) diff --git a/tests/config/test_deepseek_v4_cudagraph_config.py b/tests/config/test_deepseek_v4_cudagraph_config.py index 5a938b0ab9d9..2989a3f17d3d 100644 --- a/tests/config/test_deepseek_v4_cudagraph_config.py +++ b/tests/config/test_deepseek_v4_cudagraph_config.py @@ -3,50 +3,35 @@ from types import SimpleNamespace -from vllm.config.vllm import _should_auto_enable_deepseek_v4_breakable_cudagraph -from vllm.platforms import current_platform +from vllm.config.vllm import _should_auto_enable_breakable_cudagraph def _model_config(*architectures: str): return SimpleNamespace(architectures=list(architectures)) -def test_deepseek_v4_does_not_auto_enable_breakable_cudagraph(monkeypatch): +def test_deepseek_v4_does_not_auto_enable_breakable_cudagraph(): # Breakable cudagraph disables torch.compile and is 1.5-3.8x slower for MTP # decode on SM12x (measured); DeepSeek-V4 defaults to FULL_AND_PIECEWISE. - monkeypatch.setattr( - current_platform, - "is_device_capability", - lambda capability, device_id=0: False, - ) - - assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( + assert not _should_auto_enable_breakable_cudagraph( _model_config("DeepseekV4ForCausalLM") ) - assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( + assert not _should_auto_enable_breakable_cudagraph( _model_config("DeepSeekV4MTPModel") ) -def test_deepseek_v4_does_not_auto_enable_breakable_cudagraph_on_sm121(monkeypatch): - monkeypatch.setattr( - current_platform, - "is_device_capability", - lambda capability, device_id=0: capability == 121, +def test_minimax_m3_auto_enables_breakable_cudagraph(): + # MiniMax M3 retains upstream's unconditional auto-enable. + assert _should_auto_enable_breakable_cudagraph( + _model_config("MiniMaxM3SparseForCausalLM") ) - - assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( - _model_config("DeepseekV4ForCausalLM") + assert _should_auto_enable_breakable_cudagraph( + _model_config("MiniMaxM3SparseForConditionalGeneration") ) -def test_non_deepseek_v4_does_not_auto_enable_breakable_cudagraph(monkeypatch): - monkeypatch.setattr( - current_platform, - "is_device_capability", - lambda capability, device_id=0: False, - ) - - assert not _should_auto_enable_deepseek_v4_breakable_cudagraph( +def test_other_models_do_not_auto_enable_breakable_cudagraph(): + assert not _should_auto_enable_breakable_cudagraph( _model_config("Qwen3ForCausalLM") ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 4c2a60e4f3a3..c7ddeb5f787f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -76,10 +76,10 @@ } ) -DEEPSEEK_V4_CUDAGRAPH_ARCHITECTURES = frozenset( +_BREAKABLE_CUDAGRAPH_AUTO_ENABLE_ARCHITECTURES = frozenset( { - "DeepseekV4ForCausalLM", - "DeepSeekV4MTPModel", + "MiniMaxM3SparseForCausalLM", + "MiniMaxM3SparseForConditionalGeneration", } ) @@ -111,20 +111,22 @@ class OptimizationLevel(IntEnum): # See https://github.com/vllm-project/vllm/issues/25689. -def _should_auto_enable_deepseek_v4_breakable_cudagraph( +def _should_auto_enable_breakable_cudagraph( model_config: ModelConfig, ) -> bool: - # DeepSeek-V4 does NOT auto-enable breakable cudagraph. Breakable mode - # disables the torch.compile pipeline (equivalent to -O.mode=none) and runs - # attention eagerly every decode step; on SM12x that is 1.5-3.8x SLOWER for - # MTP decode and degrades with output length, measured on both RTX PRO 6000 - # (SM120) and 2x GB10 (SM121). FULL_AND_PIECEWISE + torch.compile is correct - # (GSM8K parity, bare-prompt clean) and faster, so it is the default. - # Opt in with VLLM_USE_BREAKABLE_CUDAGRAPH=1 for the MTP + long-context + - # high-concurrency garbled-output workaround (which also engages the - # spec-decode attention eager-break). - del model_config # architecture-independent: never auto-enable - return False + # Auto-enable breakable cudagraph only for architectures that lack + # @support_torch_compile and are known-good under it (MiniMax M3 retains its + # upstream unconditional auto-enable). DeepSeek-V4 is deliberately excluded: + # breakable mode disables the torch.compile pipeline (equivalent to + # -cc.mode=none) and runs attention eagerly every decode step, which on SM12x + # is 1.5-3.8x SLOWER for MTP decode and degrades with output length (measured + # on RTX PRO 6000 / SM120 and 2x GB10 / SM121). FULL_AND_PIECEWISE + + # torch.compile is correct (GSM8K parity, bare-prompt clean) and faster, so + # it is the DSv4 default. Opt in with VLLM_USE_BREAKABLE_CUDAGRAPH=1. + return any( + arch in _BREAKABLE_CUDAGRAPH_AUTO_ENABLE_ARCHITECTURES + for arch in model_config.architectures + ) def enable_norm_fusion(cfg: "VllmConfig") -> bool: @@ -1096,26 +1098,15 @@ def __post_init__(self): ) self.compilation_config.mode = CompilationMode.NONE - # DeepSeek V4's model classes don't carry @support_torch_compile. - # On SM120 the breakable cudagraph is the supported PIECEWISE path; - # on tested SM121/GB10 Ray configs the compiled PIECEWISE path is - # required for correctness (breakable can corrupt graph replay). - # Auto-enable only for the known-good DeepSeek V4 device path; MiniMax - # M3 retains its upstream unconditional auto-enable. + # Some model classes don't carry @support_torch_compile and rely on the + # breakable cudagraph PIECEWISE path; auto-enable it for those unless the + # user explicitly opted out. DeepSeek-V4 is deliberately excluded (see + # _should_auto_enable_breakable_cudagraph) — it is faster on + # FULL_AND_PIECEWISE + torch.compile. if ( self.model_config is not None and "VLLM_USE_BREAKABLE_CUDAGRAPH" not in os.environ - and ( - _should_auto_enable_deepseek_v4_breakable_cudagraph(self.model_config) - or any( - a - in ( - "MiniMaxM3SparseForCausalLM", - "MiniMaxM3SparseForConditionalGeneration", - ) - for a in self.model_config.architectures - ) - ) + and _should_auto_enable_breakable_cudagraph(self.model_config) ): os.environ["VLLM_USE_BREAKABLE_CUDAGRAPH"] = "1" logger.info_once( From 13166bb1bcc6eba7516655109dba28ec5aa394c4 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 22:21:12 +0800 Subject: [PATCH 124/131] chore(sm12x-audit): move out VLLM_NVFP4_GEMM_BACKEND b12x research lever Reverts 99a9f10e7a (whose actual content is solely the gemm-backend env + _NVFP4_BACKEND_TO_KERNEL force-map, not the modelopt-routing its title names). The DSv4-Flash shipped path does not use the flashinfer-b12x NVFP4 route; this env was only a research lever (and the sole working way to reach b12x, since FlashInferB12xNvFp4LinearKernel is excluded from auto-selection and --linear-backend flashinfer_b12x is filtered out). Preserved verbatim on backup/min-enable-88ec-pre-audit-20260620 for future NVFP4-backend experiments; restore the env+map from that commit to re-enable b12x A/B. --- vllm/envs.py | 26 ----------------- .../model_executor/kernels/linear/__init__.py | 29 ++----------------- 2 files changed, 2 insertions(+), 53 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index c6a24245897f..f5aa11f4434c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -237,7 +237,6 @@ VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = True VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_HAS_FLASHINFER_CUBIN: bool = False - VLLM_NVFP4_GEMM_BACKEND: str | None = None VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True VLLM_ALLREDUCE_USE_FLASHINFER: bool = False @@ -1704,31 +1703,6 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_HAS_FLASHINFER_CUBIN": lambda: bool( int(os.getenv("VLLM_HAS_FLASHINFER_CUBIN", "0")) ), - # Selects the GEMM backend used for NVFP4-quantized linear/MoE layers. - # Supported options: - # - "flashinfer-b12x": use flashinfer b12x (Blackwell 12.x) NVFP4 kernels - # - "flashinfer-cudnn": use flashinfer cudnn GEMM backend - # - "flashinfer-trtllm": use flashinfer trtllm GEMM backend - # - "flashinfer-cutlass": use flashinfer cutlass GEMM backend - # - "marlin": use marlin GEMM backend (for GPUs without native FP4 support) - # - "emulation": - # use BF16/FP16 GEMM, dequantizing weights and running QDQ on activations. - # This is only meant for research purposes to run on devices where NVFP4 - # GEMM kernels are not available. - # - : automatically pick an available backend - "VLLM_NVFP4_GEMM_BACKEND": env_with_choices( - "VLLM_NVFP4_GEMM_BACKEND", - None, - [ - "flashinfer-b12x", - "flashinfer-cudnn", - "flashinfer-trtllm", - "flashinfer-cutlass", - "cutlass", - "marlin", - "emulation", - ], - ), # Controls garbage collection during CUDA graph capture. # If set to 0 (default), enables GC freezing to speed up capture time. # If set to 1, allows GC to run during capture. diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index 214110473651..919d71fb8e8b 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -400,9 +400,8 @@ def _filter_kernels_by_backend( _POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = { PlatformEnum.CUDA: [ # FlashInferB12xNvFp4LinearKernel excluded from auto-selection until - # upstream CUTLASS SM121 MMA op guard is resolved; opt in explicitly - # via --linear-backend flashinfer_b12x or - # VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x. + # upstream CUTLASS SM121 MMA op guard is resolved; use + # --linear-backend flashinfer_b12x to opt in explicitly. FlashInferCutlassNvFp4LinearKernel, CutlassNvFp4LinearKernel, MarlinNvFp4LinearKernel, @@ -832,20 +831,6 @@ def init_wfp8_a16_linear_kernel( ) -# Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes. This is an -# env-driven alternative to --linear-backend for forcing a specific NVFP4 -# linear kernel (e.g. VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x). -_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = { - "flashinfer-b12x": FlashInferB12xNvFp4LinearKernel, - "flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel, - "cutlass": CutlassNvFp4LinearKernel, - "marlin": MarlinNvFp4LinearKernel, - "flashinfer-trtllm": FlashInferTrtllmNvFp4LinearKernel, - "flashinfer-cudnn": FlashInferCudnnNvFp4LinearKernel, - "emulation": EmulationNvFp4LinearKernel, -} - - def init_nvfp4_linear_kernel(use_a16: bool = False) -> NvFp4LinearKernel: """Select and instantiate the best NVFP4 linear kernel for the current platform.""" @@ -885,16 +870,6 @@ def init_nvfp4_linear_kernel(use_a16: bool = False) -> NvFp4LinearKernel: reason, ) force_kernel = EmulationNvFp4LinearKernel - elif envs.VLLM_NVFP4_GEMM_BACKEND is not None: - # Env-driven override (alternative to --linear-backend). Maps a - # VLLM_NVFP4_GEMM_BACKEND value to a concrete kernel class. - backend_name = envs.VLLM_NVFP4_GEMM_BACKEND - force_kernel = _NVFP4_BACKEND_TO_KERNEL.get(backend_name) - if force_kernel is None: - raise ValueError( - f"Unknown VLLM_NVFP4_GEMM_BACKEND={backend_name!r}. " - f"Valid choices: {list(_NVFP4_BACKEND_TO_KERNEL.keys())}" - ) elif linear_backend == "auto" and use_a16: # Force a16 (Marlin) when running weight-only quantization. force_kernel = MarlinNvFp4LinearKernel From 0d66c897ad8776e292d1e55117c5e4b83c5277d0 Mon Sep 17 00:00:00 2001 From: jasl Date: Sat, 20 Jun 2026 23:13:20 +0800 Subject: [PATCH 125/131] refactor(sm12x-audit): move out scheduler prefill-fairness heuristics + tests Removes the 9-commit very-long-prefill starvation / mixed-decode-prefill chunk limiting family (a8bdc004cc, 129e129215, a962bf1f2b[sched part], 1059c81172, ad26f8fc4c, db3a71f53c, 6dac492601, 52c549e573, 574905ac99[sched part]) from scheduler.py: 9 helper methods + 3 call sites; restores the deleted blank line and the original 'assert num_new_tokens > 0'. Drops the 11 tautological fairness tests. Ungated generic-vLLM tuning aimed at a cliff re-diagnosed as MoE-GEMM + NCCL-all-reduce bound (config knobs proven dead on GB10) plus a phantom wedge; never required for correctness. Preserved (verified standalone, not fairness-coupled): the write-fence hooks (20e147276b, kept pending the fence-OFF recall gate), max_num_seqs + DSv4 MLA prefix-retention (fde655cc06), the a962bf1f2b adaptive BLOCK_M kernel tuning in sm12x_mqa.py, and the 574905ac99 record_stats param + its test_prefix_cache_peek_does_not_record_stats (tests kept kv_cache_manager stats-suppression behavior). Net: scheduler.py == base except the 3 KEEP hunks; test_scheduler.py == base except the peek test. Needs RTX long-ctx-concurrency + GSM8K + toolcall-15 no-regression revalidation. --- tests/v1/core/test_scheduler.py | 504 -------------------------------- vllm/v1/core/sched/scheduler.py | 152 +--------- 2 files changed, 2 insertions(+), 654 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 4067c1a9f027..08731742a779 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -866,310 +866,6 @@ def test_schedule_order(enable_chunked_prefill: bool): assert len(scheduler_output1.scheduled_new_reqs) == 1 -def test_mixed_decode_prefill_does_not_cap_short_prefill(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=512, - max_num_seqs=2, - enable_chunked_prefill=True, - ) - decode_req = create_requests(num_requests=1, num_tokens=100, req_ids=["decode"])[0] - short_prefill_req = create_requests( - num_requests=1, - num_tokens=40, - req_ids=["short_prefill"], - )[0] - - scheduler.add_request(decode_req) - prefill_output = scheduler.schedule() - assert prefill_output.num_scheduled_tokens[decode_req.request_id] == 100 - - scheduler.update_from_output( - prefill_output, - ModelRunnerOutput( - req_ids=[decode_req.request_id], - req_id_to_index={decode_req.request_id: 0}, - sampled_token_ids=[[0]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(short_prefill_req) - mixed_output = scheduler.schedule() - - assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 - assert mixed_output.num_scheduled_tokens[short_prefill_req.request_id] == 40 - - -def test_mixed_decode_prefill_caps_long_prefill_chunk(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=512, - max_num_seqs=2, - enable_chunked_prefill=True, - ) - decode_req = create_requests(num_requests=1, num_tokens=100, req_ids=["decode"])[0] - long_prefill_req = create_requests( - num_requests=1, - num_tokens=300, - req_ids=["long_prefill"], - )[0] - - scheduler.add_request(decode_req) - prefill_output = scheduler.schedule() - assert prefill_output.num_scheduled_tokens[decode_req.request_id] == 100 - - scheduler.update_from_output( - prefill_output, - ModelRunnerOutput( - req_ids=[decode_req.request_id], - req_id_to_index={decode_req.request_id: 0}, - sampled_token_ids=[[0]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(long_prefill_req) - mixed_output = scheduler.schedule() - - assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 - assert mixed_output.num_scheduled_tokens[long_prefill_req.request_id] == 25 - - -def test_running_long_prefill_leaves_budget_for_waiting_short_prefill(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=512, - max_num_seqs=2, - enable_chunked_prefill=True, - ) - long_prefill_req = create_requests( - num_requests=1, - num_tokens=300, - req_ids=["long_prefill"], - )[0] - short_prefill_req = create_requests( - num_requests=1, - num_tokens=20, - req_ids=["short_prefill"], - )[0] - - scheduler.add_request(long_prefill_req) - first_chunk = scheduler.schedule() - assert first_chunk.num_scheduled_tokens[long_prefill_req.request_id] == 100 - - scheduler.update_from_output( - first_chunk, - ModelRunnerOutput( - req_ids=[long_prefill_req.request_id], - req_id_to_index={long_prefill_req.request_id: 0}, - sampled_token_ids=[[]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(short_prefill_req) - mixed_output = scheduler.schedule() - - assert mixed_output.num_scheduled_tokens[long_prefill_req.request_id] == 75 - assert mixed_output.num_scheduled_tokens[short_prefill_req.request_id] == 20 - - -def test_running_long_prefill_leaves_budget_for_running_short_prefill(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=512, - max_num_seqs=2, - enable_chunked_prefill=True, - ) - long_prefill_req = create_requests( - num_requests=1, - num_tokens=300, - req_ids=["long_prefill"], - )[0] - short_prefill_req = create_requests( - num_requests=1, - num_tokens=80, - req_ids=["short_prefill"], - )[0] - - scheduler.add_request(long_prefill_req) - first_chunk = scheduler.schedule() - assert first_chunk.num_scheduled_tokens[long_prefill_req.request_id] == 100 - - scheduler.update_from_output( - first_chunk, - ModelRunnerOutput( - req_ids=[long_prefill_req.request_id], - req_id_to_index={long_prefill_req.request_id: 0}, - sampled_token_ids=[[]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(short_prefill_req) - first_mixed = scheduler.schedule() - assert first_mixed.num_scheduled_tokens[long_prefill_req.request_id] == 75 - assert first_mixed.num_scheduled_tokens[short_prefill_req.request_id] == 25 - - scheduler.update_from_output( - first_mixed, - ModelRunnerOutput( - req_ids=[long_prefill_req.request_id, short_prefill_req.request_id], - req_id_to_index={ - long_prefill_req.request_id: 0, - short_prefill_req.request_id: 1, - }, - sampled_token_ids=[[], []], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - second_mixed = scheduler.schedule() - assert second_mixed.num_scheduled_tokens[long_prefill_req.request_id] == 75 - assert second_mixed.num_scheduled_tokens[short_prefill_req.request_id] == 25 - - -def test_running_very_long_prefill_defers_waiting_very_long_prefill(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=2048, - max_num_seqs=3, - enable_chunked_prefill=True, - ) - first_long_req = create_requests( - num_requests=1, - num_tokens=600, - req_ids=["first_long"], - )[0] - second_long_req = create_requests( - num_requests=1, - num_tokens=600, - req_ids=["second_long"], - )[0] - short_req = create_requests( - num_requests=1, - num_tokens=20, - req_ids=["short"], - )[0] - - scheduler.add_request(first_long_req) - first_chunk = scheduler.schedule() - assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 100 - scheduler.update_from_output( - first_chunk, - ModelRunnerOutput( - req_ids=[first_long_req.request_id], - req_id_to_index={first_long_req.request_id: 0}, - sampled_token_ids=[[]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(second_long_req) - scheduler.add_request(short_req) - mixed_output = scheduler.schedule() - - assert mixed_output.num_scheduled_tokens[first_long_req.request_id] == 50 - assert second_long_req.request_id not in mixed_output.num_scheduled_tokens - assert mixed_output.num_scheduled_tokens[short_req.request_id] == 20 - - -def test_running_very_long_prefill_defers_waiting_uncached_long_prefills(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=2048, - max_num_seqs=4, - enable_chunked_prefill=True, - ) - first_long_req = create_requests( - num_requests=1, - num_tokens=600, - req_ids=["first_long"], - )[0] - waiting_long_reqs = create_requests( - num_requests=3, - num_tokens=600, - req_ids=["second_long", "third_long", "fourth_long"], - ) - - scheduler.add_request(first_long_req) - first_chunk = scheduler.schedule() - assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 100 - scheduler.update_from_output( - first_chunk, - ModelRunnerOutput( - req_ids=[first_long_req.request_id], - req_id_to_index={first_long_req.request_id: 0}, - sampled_token_ids=[[]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - for request in waiting_long_reqs: - scheduler.add_request(request) - mixed_output = scheduler.schedule() - - assert mixed_output.num_scheduled_tokens[first_long_req.request_id] == 100 - for request in waiting_long_reqs: - assert request.request_id not in mixed_output.num_scheduled_tokens - - -def test_running_very_long_prefill_defers_uncached_long_with_spare_budget(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=2048, - max_num_seqs=2, - enable_chunked_prefill=True, - long_prefill_token_threshold=50, - ) - first_long_req = create_requests( - num_requests=1, - num_tokens=600, - req_ids=["first_long"], - )[0] - second_long_req = create_requests( - num_requests=1, - num_tokens=600, - req_ids=["second_long"], - )[0] - - scheduler.add_request(first_long_req) - first_chunk = scheduler.schedule() - assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 50 - scheduler.update_from_output( - first_chunk, - ModelRunnerOutput( - req_ids=[first_long_req.request_id], - req_id_to_index={first_long_req.request_id: 0}, - sampled_token_ids=[[]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(second_long_req) - mixed_output = scheduler.schedule() - - assert mixed_output.num_scheduled_tokens[first_long_req.request_id] == 50 - assert second_long_req.request_id not in mixed_output.num_scheduled_tokens - - def _run_request_to_completion(scheduler: Scheduler, request: Request) -> None: while request.request_id in scheduler.requests: scheduler_output = scheduler.schedule() @@ -1197,60 +893,6 @@ def _run_request_to_completion(scheduler: Scheduler, request: Request) -> None: ) -def test_running_very_long_prefill_admits_cached_tail_request(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=2048, - max_num_seqs=3, - enable_chunked_prefill=True, - enable_prefix_caching=True, - ) - warm_req = create_requests( - num_requests=1, - num_tokens=600, - max_tokens=1, - same_prompt=True, - req_ids=["warm_prefix"], - )[0] - - scheduler.add_request(warm_req) - _run_request_to_completion(scheduler, warm_req) - - first_long_req = create_requests( - num_requests=1, - num_tokens=1000, - req_ids=["first_long"], - )[0] - cached_tail_req = create_requests( - num_requests=1, - num_tokens=620, - same_prompt=True, - req_ids=["cached_tail"], - )[0] - - scheduler.add_request(first_long_req) - first_chunk = scheduler.schedule() - assert first_chunk.num_scheduled_tokens[first_long_req.request_id] == 100 - scheduler.update_from_output( - first_chunk, - ModelRunnerOutput( - req_ids=[first_long_req.request_id], - req_id_to_index={first_long_req.request_id: 0}, - sampled_token_ids=[[]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(cached_tail_req) - mixed_output = scheduler.schedule() - - assert mixed_output.num_scheduled_tokens[first_long_req.request_id] < 100 - assert cached_tail_req.request_id in mixed_output.num_scheduled_tokens - assert mixed_output.num_scheduled_tokens[cached_tail_req.request_id] <= 100 - - def test_prefix_cache_peek_does_not_record_stats(): scheduler = create_scheduler( max_num_batched_tokens=100, @@ -1298,152 +940,6 @@ def test_prefix_cache_peek_does_not_record_stats(): assert stats.hits == peeked_tokens -def test_running_very_long_prefill_defers_to_later_running_decode(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=2048, - max_num_seqs=2, - enable_chunked_prefill=True, - ) - long_prefill_req = create_requests( - num_requests=1, - num_tokens=1000, - req_ids=["long_prefill"], - )[0] - short_req = create_requests( - num_requests=1, - num_tokens=80, - req_ids=["short_then_decode"], - )[0] - - scheduler.add_request(long_prefill_req) - first_chunk = scheduler.schedule() - assert first_chunk.num_scheduled_tokens[long_prefill_req.request_id] == 100 - scheduler.update_from_output( - first_chunk, - ModelRunnerOutput( - req_ids=[long_prefill_req.request_id], - req_id_to_index={long_prefill_req.request_id: 0}, - sampled_token_ids=[[]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(short_req) - while short_req.num_computed_tokens < short_req.num_prompt_tokens: - mixed_output = scheduler.schedule() - assert short_req.request_id in mixed_output.num_scheduled_tokens - sampled_token_ids = [] - req_id_to_index = {} - for index, req_id in enumerate(mixed_output.num_scheduled_tokens): - req_id_to_index[req_id] = index - if ( - req_id == short_req.request_id - and ( - short_req.num_computed_tokens - + mixed_output.num_scheduled_tokens[req_id] - ) - >= short_req.num_prompt_tokens - ): - sampled_token_ids.append([0]) - else: - sampled_token_ids.append([]) - scheduler.update_from_output( - mixed_output, - ModelRunnerOutput( - req_ids=list(mixed_output.num_scheduled_tokens), - req_id_to_index=req_id_to_index, - sampled_token_ids=sampled_token_ids, - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - assert long_prefill_req.num_computed_tokens < long_prefill_req.num_prompt_tokens - assert short_req.num_computed_tokens >= short_req.num_prompt_tokens - - decode_mixed = scheduler.schedule() - assert decode_mixed.num_scheduled_tokens[short_req.request_id] == 1 - assert long_prefill_req.request_id not in decode_mixed.num_scheduled_tokens - - -def test_mixed_decode_prefill_caps_mid_long_prefill_more_tightly(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=1024, - max_num_seqs=2, - enable_chunked_prefill=True, - ) - decode_req = create_requests(num_requests=1, num_tokens=100, req_ids=["decode"])[0] - mid_long_prefill_req = create_requests( - num_requests=1, - num_tokens=300, - req_ids=["mid_long_prefill"], - )[0] - - scheduler.add_request(decode_req) - prefill_output = scheduler.schedule() - assert prefill_output.num_scheduled_tokens[decode_req.request_id] == 100 - - scheduler.update_from_output( - prefill_output, - ModelRunnerOutput( - req_ids=[decode_req.request_id], - req_id_to_index={decode_req.request_id: 0}, - sampled_token_ids=[[0]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(mid_long_prefill_req) - mixed_output = scheduler.schedule() - - assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 - assert mixed_output.num_scheduled_tokens[mid_long_prefill_req.request_id] == 25 - - -def test_mixed_decode_prefill_defers_very_long_prefill(): - scheduler = create_scheduler( - max_num_batched_tokens=100, - max_model_len=4096, - max_num_seqs=2, - enable_chunked_prefill=True, - ) - decode_req = create_requests(num_requests=1, num_tokens=100, req_ids=["decode"])[0] - very_long_prefill_req = create_requests( - num_requests=1, - num_tokens=2000, - req_ids=["very_long_prefill"], - )[0] - - scheduler.add_request(decode_req) - prefill_output = scheduler.schedule() - assert prefill_output.num_scheduled_tokens[decode_req.request_id] == 100 - - scheduler.update_from_output( - prefill_output, - ModelRunnerOutput( - req_ids=[decode_req.request_id], - req_id_to_index={decode_req.request_id: 0}, - sampled_token_ids=[[0]], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - ), - ) - - scheduler.add_request(very_long_prefill_req) - mixed_output = scheduler.schedule() - - assert mixed_output.num_scheduled_tokens[decode_req.request_id] == 1 - assert very_long_prefill_req.request_id not in mixed_output.num_scheduled_tokens - - def test_preempt_during_execution(): # NOTE(woosuk): The actual number of available blocks is 10 instead of 11 # because block 0 is reserved as the null block. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 20aa1fa48ccd..7de5aa427b44 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -385,104 +385,6 @@ def _mamba_block_aligned_split( num_new_tokens = num_new_tokens // block_size * block_size return num_new_tokens - def _has_scheduled_decode(self, requests: list[Request]) -> bool: - return any( - request.num_computed_tokens >= request.num_prompt_tokens - for request in requests - ) - - def _very_long_prefill_threshold(self) -> int: - return self.max_num_scheduled_tokens * 4 - - def _is_very_long_prefill( - self, - request: Request, - num_computed_tokens: int | None = None, - ) -> bool: - if not self.scheduler_config.enable_chunked_prefill: - return False - if num_computed_tokens is None: - num_computed_tokens = request.num_computed_tokens - remaining_prefill = max(0, request.num_prompt_tokens - num_computed_tokens) - return remaining_prefill > self._very_long_prefill_threshold() - - def _waiting_request_remaining_prefill(self, request: Request) -> int: - if request.num_computed_tokens > 0: - num_computed_tokens = request.num_computed_tokens - else: - _, num_computed_tokens = self.kv_cache_manager.get_computed_blocks( - request, - record_stats=False, - ) - return max(0, request.num_prompt_tokens - num_computed_tokens) - - def _has_active_very_long_prefill(self) -> bool: - return any(self._is_very_long_prefill(request) for request in self.running) - - def _is_waiting_request_very_long_prefill(self, request: Request) -> bool: - return ( - self._waiting_request_remaining_prefill(request) - > self._very_long_prefill_threshold() - ) - - def _waiting_prefill_competitor_count(self) -> int: - non_long_prefill_count = 0 - for waiting_request in itertools.chain(self.waiting, self.skipped_waiting): - if not self._is_waiting_request_very_long_prefill(waiting_request): - non_long_prefill_count += 1 - return non_long_prefill_count - - def _has_waiting_requests_for_running_prefill(self, request: Request) -> bool: - if not (self.waiting or self.skipped_waiting): - return False - if not self._is_very_long_prefill(request): - return True - return self._waiting_prefill_competitor_count() > 0 - - def _limit_mixed_decode_prefill_chunk( - self, - request: Request, - num_new_tokens: int, - scheduled_running_reqs: list[Request], - has_waiting_requests: bool = False, - has_pending_decode: bool = False, - prefill_budget_partitions: int = 1, - ) -> int: - if ( - not self.scheduler_config.enable_chunked_prefill - or request.num_computed_tokens >= request.num_prompt_tokens - ): - return num_new_tokens - - has_decode_pressure = ( - self._has_scheduled_decode(scheduled_running_reqs) or has_pending_decode - ) - has_prefill_competition = has_waiting_requests or prefill_budget_partitions > 1 - if not has_decode_pressure and not has_prefill_competition: - return num_new_tokens - - remaining_prefill = request.num_prompt_tokens - request.num_computed_tokens - if remaining_prefill <= self.max_num_scheduled_tokens: - return num_new_tokens - - # Very long prefills span many scheduling steps; a smaller chunk keeps - # already-active decoders from seeing long inter-token gaps and leaves - # room for short requests that arrive behind an active long prefill. - very_long_prefill_threshold = self._very_long_prefill_threshold() - if has_decode_pressure: - if remaining_prefill > very_long_prefill_threshold: - return 0 - else: - mixed_prefill_budget = max(1, self.max_num_scheduled_tokens // 4) - elif remaining_prefill > very_long_prefill_threshold: - mixed_prefill_budget = max( - 1, - self.max_num_scheduled_tokens // max(2, prefill_budget_partitions), - ) - else: - mixed_prefill_budget = max(1, (self.max_num_scheduled_tokens * 3) // 4) - return min(num_new_tokens, mixed_prefill_budget) - def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: self.current_step += 1 # NOTE(woosuk) on the scheduling algorithm: @@ -507,6 +409,7 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: if self._pause_state == PauseState.PAUSED_ALL: # Do not schedule any requests when paused. token_budget = 0 + # Encoder-related. scheduled_encoder_inputs: dict[str, list[int]] = {} encoder_compute_budget = self.max_num_encoder_input_tokens @@ -567,28 +470,6 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) - has_unscheduled_running_prefill = any( - later_request.num_computed_tokens < later_request.num_prompt_tokens - for later_request in self.running[req_index + 1 :] - ) - has_pending_running_decode = any( - later_request.num_computed_tokens >= later_request.num_prompt_tokens - for later_request in self.running[req_index + 1 :] - ) - prefill_budget_partitions = ( - 1 - + int(has_unscheduled_running_prefill) - + self._waiting_prefill_competitor_count() - ) - num_new_tokens = self._limit_mixed_decode_prefill_chunk( - request, - num_new_tokens, - scheduled_running_reqs, - self._has_waiting_requests_for_running_prefill(request) - or has_unscheduled_running_prefill, - has_pending_running_decode, - prefill_budget_partitions, - ) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. @@ -896,14 +777,6 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens - if ( - self._is_very_long_prefill(request, num_computed_tokens) - and self._has_active_very_long_prefill() - ): - request_queue.pop_request() - step_skipped_waiting.prepend_request(request) - continue - encoder_inputs_to_schedule = None external_load_encoder_input = [] new_encoder_compute_budget = encoder_compute_budget @@ -938,28 +811,7 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: break num_new_tokens = min(num_new_tokens, token_budget) - scheduled_prefill_count = sum( - scheduled_request.num_computed_tokens - < scheduled_request.num_prompt_tokens - for scheduled_request in itertools.chain( - scheduled_running_reqs, - scheduled_new_reqs, - scheduled_resumed_reqs, - ) - ) - prefill_budget_partitions = max( - 1, - scheduled_prefill_count - + self._waiting_prefill_competitor_count(), - ) - num_new_tokens = self._limit_mixed_decode_prefill_chunk( - request, - num_new_tokens, - scheduled_running_reqs, - prefill_budget_partitions=prefill_budget_partitions, - ) - if num_new_tokens == 0: - break + assert num_new_tokens > 0 # Schedule encoder inputs. if request.has_encoder_inputs: From 72261a7af149fa5d3fe2ed2b9956e92590731012 Mon Sep 17 00:00:00 2001 From: jasl Date: Sun, 21 Jun 2026 00:28:11 +0800 Subject: [PATCH 126/131] refactor(sm12x-audit): remove prefix-cache write fence (int64 fix holds recall) GPU-validated 2026-06-21 on the int64-fixed 88ec build: the fence-OFF recall gate (VLLM_PREFIX_CACHE_WRITE_FENCE=0) holds arthur long-context coherence 8/8 at conc=8 and 16/16 at conc=16 (MTP2, 0 miss) -- exercising exactly the >=3 concurrent-identical-prefix in-flight hand-off window the fence guarded. So the write fence is redundant: the int64 block-offset overflow fix (197d21ee79 + 88ec87e1e0) is the real long-context recall fix; the fence was built on the disproven shared-write/COW theory and its 06-19 commit-message recall claim (~20%->91%) was masking the then-unfixed int64 bug. Reverts 20e147276b (committed_step/schedule_pass/retired_forward clocks, get_one_block_retired, the default-on env, and the two scheduler hooks). scheduler.py is now identical to base except the fde655cc06 max_num_seqs prefix-retention arg. --- vllm/envs.py | 10 ------- vllm/v1/core/block_pool.py | 48 +++------------------------------ vllm/v1/core/kv_cache_utils.py | 5 ---- vllm/v1/core/sched/scheduler.py | 4 --- 4 files changed, 3 insertions(+), 64 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index f5aa11f4434c..4d85103ac159 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -157,7 +157,6 @@ VLLM_USE_STANDALONE_COMPILE: bool = True VLLM_ENABLE_PREGRAD_PASSES: bool = True VLLM_USE_BREAKABLE_CUDAGRAPH: bool = False - VLLM_PREFIX_CACHE_WRITE_FENCE: bool = True VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False @@ -710,15 +709,6 @@ def _resolve_rust_frontend_path() -> str | None: "VLLM_USE_BREAKABLE_CUDAGRAPH": lambda: ( os.environ.get("VLLM_USE_BREAKABLE_CUDAGRAPH", "0") == "1" ), - # Prefix-cache write-completion fence: only expose a cached block to OTHER - # requests after the forward that wrote its tokens has retired. Prevents a - # concurrent same-prefix request from binding a recent-region block whose - # KV/compressed write is still in flight (DeepSeek-V4 packed multi-group - # pages make this corrupt the recent context under conc>=3). Default on; - # set to 0 to restore the legacy expose-at-schedule-time behavior. - "VLLM_PREFIX_CACHE_WRITE_FENCE": lambda: ( - os.environ.get("VLLM_PREFIX_CACHE_WRITE_FENCE", "1") == "1" - ), # Debug pattern matching inside custom passes. # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get( diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 3b949513f054..dac8914b4a22 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -3,7 +3,6 @@ from collections.abc import Iterable, Sequence from typing import Any -import vllm.envs as envs from vllm.distributed.kv_events import ( MEDIUM_GPU, AllBlocksCleared, @@ -73,25 +72,6 @@ def get_one_block(self, key: BlockHashWithGroupId) -> KVCacheBlock | None: self._unexpected_blocks_type(blocks) return None - def get_one_block_retired( - self, key: BlockHashWithGroupId, retired_forward: int - ) -> "KVCacheBlock | None": - """Like get_one_block, but only returns a block whose committed_step - has retired (committed_step <= retired_forward), so a concurrent request - never binds a block whose writing forward is still in flight.""" - blocks = self._cache.get(key) - if blocks is None: - return None - if isinstance(blocks, KVCacheBlock): - return blocks if blocks.committed_step <= retired_forward else None - if isinstance(blocks, dict): - for blk in blocks.values(): - if blk.committed_step <= retired_forward: - return blk - return None - self._unexpected_blocks_type(blocks) - return None - def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None: """ Inserts the KVCacheBlock to the cache @@ -205,22 +185,6 @@ def __init__( self.metrics_collector = metrics_collector - # Prefix-cache write-completion fence (VLLM_PREFIX_CACHE_WRITE_FENCE). - # schedule_pass advances at each schedule() start; retired_forward - # advances when a forward completes (update_from_output). A block - # committed at schedule_pass S is exposed to other requests only once - # retired_forward >= S (its writing forward has retired). Async-safe: - # the two clocks decouple commit-time from write-completion-time. - self.write_fence = envs.VLLM_PREFIX_CACHE_WRITE_FENCE - self.schedule_pass = 0 - self.retired_forward = 0 - - def advance_schedule_pass(self) -> None: - self.schedule_pass += 1 - - def advance_retired_forward(self) -> None: - self.retired_forward += 1 - def get_cached_block( self, block_hash: BlockHash, kv_cache_group_ids: list[int] ) -> list[KVCacheBlock] | None: @@ -240,14 +204,9 @@ def get_cached_block( block_hash_with_group_id = make_block_hash_with_group_id( block_hash, group_id ) - if self.write_fence: - block = self.cached_block_hash_to_block.get_one_block_retired( - block_hash_with_group_id, self.retired_forward - ) - else: - block = self.cached_block_hash_to_block.get_one_block( - block_hash_with_group_id - ) + block = self.cached_block_hash_to_block.get_one_block( + block_hash_with_group_id + ) if not block: return None cached_blocks.append(block) @@ -323,7 +282,6 @@ def cache_full_blocks( block_hash, kv_cache_group_id ) blk.block_hash = block_hash_with_group_id - blk.committed_step = self.schedule_pass self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk) if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2120e5382b60..a1ebe08c0789 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -133,11 +133,6 @@ class KVCacheBlock: # Whether the block is a null block that should never be cached. is_null: bool = False - # Schedule-pass index at which this block was committed to the prefix cache. - # Used by the write-completion fence to withhold the block from other - # requests until the forward that wrote it has retired (-1 = uncommitted). - committed_step: int = -1 - @property def block_hash(self) -> BlockHashWithGroupId | None: return self._block_hash diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7de5aa427b44..e9372d27ee87 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -420,8 +420,6 @@ def schedule(self, throttle_prefills: bool = False) -> SchedulerOutput: scheduled_timestamp = time.monotonic() self.kv_cache_manager.new_step_starts() - if self.kv_cache_manager.block_pool.write_fence: - self.kv_cache_manager.block_pool.advance_schedule_pass() # DP prefill balancing: on a throttled (non-cadence-aligned) step, defer # all prefill compute unless saturated. @@ -1468,8 +1466,6 @@ def update_from_output( scheduler_output: SchedulerOutput, model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: - if self.kv_cache_manager.block_pool.write_fence: - self.kv_cache_manager.block_pool.advance_retired_forward() sampled_token_ids = model_runner_output.sampled_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict From 052bf0b3a31a5401fcc3b92f1abb4eafe8583454 Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 22 Jun 2026 18:48:28 +0800 Subject: [PATCH 127/131] fix(sm12x): slice packed-prefill output to num_prefill_tokens (84-vs-83 crash) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR#41834 user 1zilc crash with VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL=1: tvm.error.InternalError: Check failed: output.size(0) == num_tokens (84 vs 83) at flashinfer sparse_mla_sm120.cu:183, cascading to a CUDA illegal memory access that kills EngineCore. Root cause: the packed _forward_prefill Bug-C guard sliced the QUERY to num_prefill_tokens under CUDA-graph / MTP-draft padding but passed the UNSLICED padded OUTPUT to the runner. The flashinfer kernel derives num_tokens from the query rows and hard-asserts output.size(0) == num_tokens, so a padded output (84) vs sliced query (83) aborts. The guard comment already said to slice 'output/ indices/scratch' — only output was missed. Fix: slice output the same way (a view into the same storage; padded tail rows are never read downstream). No-op in the unpadded case; PREFILL-gate-only (the default FlashMLA prefill loops over q.shape[0] and has no such assert). 256k is irrelevant — 32k + small max-num-seqs + MTP reproduces it. --- .../deepseek_v4/nvidia/flashinfer_sm120_decode.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py index 2ad3c8c8af38..927e925d8b01 100644 --- a/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py +++ b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py @@ -424,6 +424,18 @@ def _forward_prefill( # scratch (no-op in the common, unpadded case). if query.shape[0] > num_prefill_tokens: query = query[:num_prefill_tokens] + # The packed kernel hard-asserts output.size(0) == num_tokens (derived from + # the sliced query row count). The output buffer can still carry padded + # prefill rows (output[num_decode_tokens:] of a CUDA-graph / MTP-draft padded + # batch), so slice it the same way as the query/indices/scratch above. It is + # a view into the same storage and the padded tail rows are never read or + # written downstream, so this only narrows what the kernel writes; no-op in + # the common unpadded case. Without it the kernel aborts (84 vs 83). + out = ( + output[:num_prefill_tokens] + if output.shape[0] > num_prefill_tokens + else output + ) swa_cache = _as_sparse_sm120_cache(swa_k_cache) extra_cache = ( _as_sparse_sm120_cache(compressed_k_cache) @@ -445,7 +457,7 @@ def _forward_prefill( query, swa_cache, swa_indices, - output, + out, self.scale, topk_length=swa_lens, attn_sink=self.attn_sink, From 5ba0f19f02941c4542c1961e9895a470657e7d6f Mon Sep 17 00:00:00 2001 From: jasl Date: Mon, 22 Jun 2026 19:32:20 +0800 Subject: [PATCH 128/131] fix(sm12x): cast MTP draft logits to float32 before top-k/top-p sampling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Default-path crash under MTP + ANY top-k/top-p sampling (i.e. ordinary non-greedy chat traffic): compute_probs_and_sample_next_token (the MTP draft sampler, a noted duplicate of the main sampler) called apply_top_k_top_p on the bf16 draft-head logits, but the triton top-k/top-p kernel asserts logits.dtype == torch.float32 (topk_topp_triton.py:881) -> AssertionError -> worker dies -> EngineDeadError. Greedy requests return early (all_greedy branch) and never reach the sampler, so greedy-only GSM8K validation never exercised this — a 256k sampled-traffic soak (temperature 0.7, top_p 0.9) crashes in ~25s. Fix mirrors the main sampler: cast logits to float32 before div_/apply_top_k_top_p. Likely the (or a) root cause of the PR#41834 default-path crash reports under real chat traffic. --- vllm/v1/spec_decode/llm_base_proposer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/v1/spec_decode/llm_base_proposer.py b/vllm/v1/spec_decode/llm_base_proposer.py index c512909415e6..9c5eb2098611 100644 --- a/vllm/v1/spec_decode/llm_base_proposer.py +++ b/vllm/v1/spec_decode/llm_base_proposer.py @@ -1794,6 +1794,12 @@ def compute_probs_and_sample_next_token( # Use epsilon comparison to detect greedy sampling (temperature ~ 0.0) # consistent with sampler.py's _SAMPLING_EPS threshold num_tokens = logits.shape[0] + # The triton top-k/top-p sampler (apply_top_k_top_p) asserts float32 logits, + # matching the main sampler (sampler.py). The MTP draft head emits bf16 logits, + # so without this cast MTP draft sampling with any top-k/top-p (i.e. ordinary + # non-greedy chat traffic) trips that assertion and kills the engine. Greedy + # requests return above and never reach here. + logits = logits.to(torch.float32) temperature = _expand_draft_sampling_tensor( sampling_metadata.temperature, num_tokens, From a94657e60173c2d2075493f4a8468252b2d529e0 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 23 Jun 2026 04:11:46 +0800 Subject: [PATCH 129/131] fix(sm120): don't auto-enable DeepGEMM on SM120 (pinned ref asserts) #43477 added family-120 to CudaPlatform.support_deep_gemm, which selects the DeepGEMM SM120 MXFP4 kernels. Those need the still-unmerged DeepGEMM PR #324; the released/pinned DeepGEMM ref aborts at engine init on SM120 with a scale-factor layout assertion (sf.size(-2) == ceil_div(mn, gran_mn)), so DSv4 fails to serve on stock deps. Drop family-120 here so SM120 uses the Marlin/cutlass + sm12x DeepGEMM-fallback path (matches pre-#43477 behavior). Re-enable when #324 lands. Co-Authored-By: Claude Opus 4.8 --- vllm/platforms/cuda.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index dabe6058e42f..abd86e518ded 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -661,12 +661,17 @@ def support_static_graph_mode(cls) -> bool: @classmethod def support_deep_gemm(cls) -> bool: - """Currently, only Hopper and Blackwell GPUs are supported.""" - return ( - cls.is_device_capability(90) - or cls.is_device_capability_family(100) - or cls.is_device_capability_family(120) - ) + """Currently, only Hopper and Blackwell SM90/SM100 GPUs are supported. + + SM120 (family-120, consumer Blackwell) is intentionally excluded: the + DeepGEMM SM120 MXFP4 kernels require the still-unmerged DeepGEMM PR #324, + and the released/pinned DeepGEMM ref aborts on SM120 with a scale-factor + layout assertion (`sf.size(-2) == ceil_div(mn, gran_mn)`). DSv4 on SM120 + runs the Marlin / cutlass + sm12x DeepGEMM-fallback path instead, which + serves on stock deps. (#43477 enabled family-120 here; re-enable once + DeepGEMM #324 lands.) + """ + return cls.is_device_capability(90) or cls.is_device_capability_family(100) @classmethod def is_integrated_gpu(cls, device_id: int = 0) -> bool: From f7b4b425b094bcf52279095c683cbab8f78a9542 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 23 Jun 2026 16:44:13 +0800 Subject: [PATCH 130/131] fix(sm120): gate #43477 prefill-SWA launch + clamp kernel OOB (inference crash) Reconciling #43477's sparse_swa I dropped the prefill-SWA gate, so _compute_swa_indices_and_lens_kernel launched unconditionally over all prefill tokens. Its block_table load computes the address for every lane (only the load is masked); the masked-off tail lanes of a deep (32k) prefill row index past the request's block_table row -> cudaErrorLaunchFailure under concurrent load (the 'unspecified launch failure' that wedged SM120). The sibling kernel in this file already clamps this exact SM12x+Triton-3.6 masked-lane IMA via safe_offset. The prefill-SWA indices are consumed only by the FlashInfer-SM120 fork attention path; the stock FlashMLA/Triton prefill self-computes and discards them. Fix: (1) re-gate the launch + metadata behind VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL (default off) -> stock path back to pre-reconcile decode-only behavior; (2) clamp the masked-off lanes in _compute_swa_indices_and_lens_kernel (defense-in-depth); (3) pass the now-mandatory token_offset=0 in the flashinfer_sm120_decode self-compute fallback (latent TypeError). Co-Authored-By: Claude Opus 4.8 --- .../nvidia/flashinfer_sm120_decode.py | 1 + vllm/v1/attention/backends/mla/sparse_swa.py | 29 ++++++++++++++----- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py index 272922f9bd5d..547c2380f6d8 100644 --- a/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py +++ b/vllm/models/deepseek_v4/nvidia/flashinfer_sm120_decode.py @@ -367,6 +367,7 @@ def _forward_prefill( swa_metadata.block_table, swa_metadata.block_table.stride(0), swa_metadata.block_size, + token_offset=0, TRITON_BLOCK_SIZE=1024, ) swa_indices = swa_idx_full[num_decode_tokens:num_tokens] diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 783b844ce3b5..0e54a27ba30a 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -5,6 +5,7 @@ import torch +import vllm.envs as envs from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.platforms import current_platform @@ -423,10 +424,18 @@ def build( TRITON_BLOCK_SIZE=1024, ) - # Prefill SWA indices live in paged coordinates. `token_offset` lets - # the kernel read is_valid_token / token_to_req_indices at absolute - # prefill positions while writing output starting at index 0. - if num_prefill_tokens > 0: + # Prefill SWA indices (paged coords; `token_offset` lets the kernel read at + # absolute prefill positions while writing from index 0) are consumed ONLY by + # the FlashInfer SM120 sparse-MLA fork path. The stock FlashMLA/Triton prefill + # self-computes and never reads them, so gate the launch behind + # VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL (default off). Running it + # unconditionally faulted `_compute_swa_indices_and_lens_kernel` over 32k + # prefill rows (unclamped block_table address arithmetic on masked-off lanes + # -> cudaErrorLaunchFailure under concurrent load). + want_prefill_swa = ( + num_prefill_tokens > 0 and envs.VLLM_DEEPSEEK_V4_FLASHINFER_SM120_PREFILL + ) + if want_prefill_swa: prefill_swa_indices = self.prefill_swa_indices[:num_prefill_tokens] prefill_swa_lens = self.prefill_swa_lens[:num_prefill_tokens] _compute_swa_indices_and_lens_kernel[(num_prefill_tokens,)]( @@ -473,12 +482,12 @@ def build( decode_swa_lens=self.decode_swa_lens[:num_decode_tokens], prefill_swa_indices=( self.prefill_swa_indices[:num_prefill_tokens] - if num_prefill_tokens > 0 + if want_prefill_swa else None ), prefill_swa_lens=( self.prefill_swa_lens[:num_prefill_tokens] - if num_prefill_tokens > 0 + if want_prefill_swa else None ), block_size=self.block_size, @@ -664,8 +673,14 @@ def _compute_swa_indices_and_lens_kernel( pos_offset = start_pos + offset block_indices = pos_offset // block_size + # Clamp masked-off lanes before the address add: SM12x + Triton 3.6 raises + # IMA on out-of-bounds address arithmetic even when the load mask gates the + # read (same hazard the sibling _compute_prefill_metadata_kernel clamps via + # safe_offset). Over a deep prefill row the tail lanes index past the + # request's block_table row -> cudaErrorLaunchFailure without this. + safe_block_indices = tl.where(pos_offset < end_pos, block_indices, 0) block_numbers = tl.load( - block_table_ptr + req_idx * block_table_stride + block_indices, + block_table_ptr + req_idx * block_table_stride + safe_block_indices, mask=pos_offset < end_pos, ) block_offsets = pos_offset % block_size From 28fef2c703a350e1f6df7db266689251067bd968 Mon Sep 17 00:00:00 2001 From: jasl Date: Tue, 23 Jun 2026 23:50:38 +0800 Subject: [PATCH 131/131] fix(sm12x): restore DeepSeek-V4 D512-split prefill warmup (stale import) kernel_warmup.py imported the removed module constant _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS from flashmla.py, which a prior refactor replaced with envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS. The ImportError was swallowed by the surrounding `except ImportError`, so the D512-split prefill warmup (default-on for DSv4 SM12x) silently self-skipped: the first long prefill (>4096 tokens) JIT-compiled the split kernels mid-inference (latency spike, and a hang/crash on FULL-capture builds at long context), negating the warmup PR #41834 added. Drop the dead import item and point the max_model_len guard at the env var (envs is already imported). Also raise the swallowed-import log from DEBUG to WARNING: the early gate already confirms the warmup was requested, so a failed import here is a real problem (a renamed symbol) that should surface instead of no-op'ing for weeks. Reported, diagnosed, and patched by @wingcomm (PR #41834). --- vllm/model_executor/warmup/kernel_warmup.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index d2b3190e045c..b97368269348 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -408,7 +408,6 @@ def _deepseek_v4_indexed_d512_split_prefill_warmup(runner: "GPUModelRunner") -> ) from vllm.models.deepseek_v4.nvidia.flashmla import ( _INDEXED_D512_SPLIT_PREFILL_MAX_TOPK, - _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS, _INDEXED_D512_SPLIT_PREFILL_MIN_TOPK, DeepseekV4FlashMLAAttention, ) @@ -419,17 +418,26 @@ def _deepseek_v4_indexed_d512_split_prefill_warmup(runner: "GPUModelRunner") -> from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( accumulate_indexed_d512_split_sparse_mla_attention, ) - except ImportError: - logger.debug( - "Skipping DeepSeek V4 D512-split prefill warmup: split kernels or " - "helpers are unavailable on this build." + except ImportError as exc: + # The early gate above already confirmed the warmup is requested, so a + # failed import here is not a benign "kernels unavailable" case — it is + # usually a renamed symbol (it silently disabled this warmup for weeks). + # Surface it at WARNING so a future rename does not no-op the warmup. + logger.warning( + "Skipping DeepSeek V4 D512-split prefill warmup: a required symbol " + "failed to import (%s). The split kernels are likely present but a " + "helper was renamed; the first long prefill will JIT them mid-inference.", + exc, ) return try: if not is_triton_sparse_mla_enabled_for_platform(): return - if getattr(runner, "max_model_len", 0) < _INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS: + if ( + getattr(runner, "max_model_len", 0) + < envs.VLLM_DEEPSEEK_V4_INDEXED_D512_SPLIT_PREFILL_MIN_TOKENS + ): return # The split kernel never sees compress_ratio, so any cr in (4, 128)