Open
Commits
128 commits
c7b19da
Fix DeepSeek V4 MLA prefix cache reuse
jasl May 6, 2026
0cc91bf
Add Blackwell tuning config aliases
jasl May 5, 2026
b7c8123
Add portable sparse MLA Triton kernels
jasl May 6, 2026
d3c3739
Add DeepSeek V4 SM12x fallback ops
jasl May 6, 2026
f8f18ce
Route SM12x DeepGEMM fallbacks
jasl May 6, 2026
cbb6995
Wire SM12x sparse MLA into DeepSeek V4
jasl May 6, 2026
b32ed20
Reduce DeepSeek V4 load overhead on GB10
jasl May 6, 2026
e46bd0d
Apply weight filter to fast safetensors loading
jasl May 6, 2026
f809e14
Warm DeepSeek V4 startup kernels
jasl May 5, 2026
e47bdd1
Add SM12x sparse MLA direct decode kernels
jasl May 6, 2026
14be6ce
Stabilize DeepSeek V4 MTP scheduling
jasl May 5, 2026
28714ae
Warm DeepSeek V4 MTP spec-decode kernels
jasl May 8, 2026
8102228
Tune dense FP8 block-scaled GEMM configs for SM12x DSv4
jasl May 11, 2026
919365f
T1-D: adaptive BLOCK_M for _fp8_paged_mqa_logits_kernel (SM12x)
jasl May 11, 2026
326b480
T2-A: clamp BLOCK_D in sparse MLA finish kernel to head_dim
jasl May 11, 2026
5530c37
Extend DeepSeek V4 prefill warmup to max single-chunk size
jasl May 12, 2026
c7653d8
Extend DeepSeek V4 warmup coverage to multi-request shapes
jasl May 12, 2026
a6c292a
Restore rowwise paged-MQA logits kernel for SM12x long context
jasl May 13, 2026
a042a41
reasoning: defensive implicit </think> for DeepSeek V4 tool-call stre…
jasl May 13, 2026
411a2de
sm12x: keep @torch.compile on HC head reduction via free-function wra…
jasl May 14, 2026
21d46ac
sm12x: drop multi-request prefill warmup that crashes CUTeDSL kv-gather
jasl May 14, 2026
455c282
sm12x: drop vestigial cudagraph kill-switch on Triton sparse MLA
jasl May 14, 2026
6e4a037
sm12x: harden sparse_attn_indexer seq_lens slice with .contiguous()
jasl May 14, 2026
00ea0da
sm12x: autotune num_warps on fp8_einsum + fused_indexer_q kernels
jasl May 14, 2026
4ed8452
sm12x: autotune num_warps/num_stages on 3 sparse MLA accumulate kernels
jasl May 14, 2026
e8b14e2
sm12x: add 3 dense FP8 W8A8 Block configs for RTX PRO 6000 WS Edition
jasl May 14, 2026
1b69e3d
sm12x: cap C128A metadata kernel loop at effective_topk (no shape cha…
jasl May 14, 2026
5632e82
sm12x: per-token early-loop-exit on sparse MLA accumulate inner candi…
jasl May 15, 2026
47cc4c3
sm12x: docs cleanup pass 1 — clarify metadata + MLA manager docstrings
jasl May 15, 2026
dbed306
sm12x: docs cleanup pass 2 — dedupe _upcast_e8m0_to_fp32 + simplify s…
jasl May 15, 2026
1cf15d2
sm12x: docs cleanup pass 3 — drop tautological is_valid in 7 accumula…
jasl May 15, 2026
922e64b
sm12x: multi-head prefill accumulate kernel + drop fp8 einsum autotune
alexbi29 May 16, 2026
9c37c22
sm12x: add fused-MoE FP8 W8A8 Block configs for RTX PRO 6000 (4 shape…
jasl May 17, 2026
ff98c4b
sm120: use Triton MQA logits for direct topk fallback
jasl May 18, 2026
ef37fc5
sm120: use custom row topk for MQA fallback indices
jasl May 18, 2026
2ebb26f
sm120: widen FP8 MQA logits tile
jasl May 18, 2026
e072e20
sm120: increase FP8 MQA logits row tile
jasl May 18, 2026
1028229
Fix DeepSeek V4 MTP sparse SWA reordering
jasl May 19, 2026
c429c77
sm12x: update DeepSeek V4 fallback imports
jasl May 19, 2026
2fb99cc
tests: update DeepSeek V4 MegaMoE refactor assumptions
jasl May 19, 2026
f30a7af
Fix DeepSeek V4 MLA prompt cache protection
jasl May 19, 2026
0101ca2
Clean up DeepSeek V4 upstream rebase leftovers
jasl May 19, 2026
ca2dd98
Fix CUTeDSL availability probe
jasl May 19, 2026
95fc907
Fix DeepSeek V4 MTP small-batch graph hangs
jasl May 19, 2026
16c1667
Remove ineffective DeepSeek V4 mHC warmup
jasl May 19, 2026
97b7143
Tune SM120 FP8 MQA logits row tile
jasl May 19, 2026
a81282b
Clean up SM120 rebase leftovers
jasl May 20, 2026
52a0c19
Remove unused SM120 splitKV decode experiment
jasl May 20, 2026
a8bdc00
Limit long prefill chunks behind active decode
jasl May 21, 2026
129e129
Tighten mixed prefill cap for very long prompts
jasl May 21, 2026
a962bf1
Improve SM120 mixed prefill scheduling
jasl May 21, 2026
b02f683
Clean up DeepSeek V4 reasoning parser lint
jasl May 22, 2026
54cdf34
Add DeepSeek V4 prefix cache pressure regression
jasl May 23, 2026
75ba377
Keep hybrid prefix cache tail blocks
jasl May 23, 2026
ea52914
Stabilize SM12x sparse MLA long prefill
jasl May 24, 2026
44c90d2
Tune SM12x sparse MLA single prefill topk
jasl May 24, 2026
1059c81
Protect active decode from very long prefill
jasl May 25, 2026
650de06
Clean sparse SWA imports after rebase
jasl May 27, 2026
f169084
Guard SM120 FP4 sparse indexer dependency
jasl May 27, 2026
d727d81
Absorb SM120 external Marlin fixes
jasl May 27, 2026
93400dd
sm120: keep optimized MHC prenorm path without DeepGEMM
jasl May 28, 2026
fb9c4d3
sm12x: prune fallback tests and tuned config duplicates
jasl May 28, 2026
d6e4dcd
sm12x: clear MXFP4 loading cache after setup
jasl May 29, 2026
aaef91a
sm12x: drop obsolete MHC CustomOp wrapper
jasl May 29, 2026
ad26f8f
Protect running prefills from long prefill starvation
jasl May 31, 2026
6bce5c9
Add chunked SM120 direct MQA top-k fallback
jasl May 31, 2026
db3a71f
Protect later running decodes from long prefill starvation
jasl Jun 1, 2026
6dac492
Protect very-long prefill fairness
jasl Jun 1, 2026
1381809
sm12x: avoid MHC prenorm GEMM JIT per token count
jasl Jun 1, 2026
56897a2
test: adapt DS4 prefix cache tests to scheduler block size
jasl Jun 2, 2026
6a576fa
fix: export DeepSeek V4 FusedMoE metadata
jasl Jun 3, 2026
52c549e
sched: defer very long prefill under decode pressure
jasl Jun 3, 2026
6e44ffa
sm12x: add sparse MLA prefill D512 split prototype
jasl Jun 2, 2026
42aade9
sm12x: warm high-concurrency MTP decode workspace
jasl Jun 3, 2026
1a3e89c
sm12x: enable indexed D512 sparse MLA prefill by default
jasl Jun 4, 2026
e4b8347
sm12x: retune D512 sparse MLA split tiles
jasl Jun 4, 2026
087c961
fix: align prefix cache manager signatures after rebase
jasl Jun 4, 2026
d3373c5
sm12x: skip empty D512 sparse MLA tail blocks
jasl Jun 4, 2026
dad43ed
sm12x: clean sparse MLA rebase leftovers
jasl Jun 5, 2026
05e4a1d
sm12x: restore DeepSeek V4 O-proj FP8 einsum layout
jasl Jun 5, 2026
25824b8
sm12x: restore Triton sparse MLA decode dispatch
jasl Jun 5, 2026
aac3bdf
config: skip breakable cudagraph auto-enable on SM121
jasl Jun 5, 2026
e1d6dc8
deepseek-v4: preserve ubatch prefill metadata
jasl Jun 5, 2026
08b329c
deepseek-v4: defunctionalize fused MLA insert op
jasl Jun 5, 2026
2aa21fb
deepseek-v4: enable chunked D512 sparse MLA prefill
jasl Jun 6, 2026
ed520ea
deepseek-v4: align sparse MLA metadata after upstream split
jasl Jun 7, 2026
03913e2
fix: sync DeepSeek V4 MoE metadata after runner refactor
jasl Jun 8, 2026
574905a
sched: admit cached long-prompt tails behind long prefill
jasl Jun 8, 2026
c601168
fix: clear stale prefix-cache block hashes on reuse
jasl Jun 8, 2026
79232ca
test: cover DeepSeek V4 RoutedExperts MXFP4 quant dispatch
jasl Jun 9, 2026
c828ef4
deepseek-v4: keep small C128A prefills on D512 path
jasl Jun 9, 2026
fde655c
fix: scale DeepSeek MLA prefix retention by sequence limit
jasl Jun 9, 2026
58d8270
fix: buffer split DeepSeek V4 DSML tool markers
jasl Jun 9, 2026
4b102a0
test: adapt DeepSeek V4 MoE metadata fixture to EPLB
jasl Jun 9, 2026
6ad66c4
sched: keep trace import out of stable path
jasl Jun 11, 2026
237e7e7
fix: restore DeepSeek V4 sparse MLA stats env
jasl Jun 11, 2026
2349559
fix: allow DeepSeek V4 chat stray think end
jasl Jun 13, 2026
c13cae0
sm12x: support FlashInfer CUTLASS MXFP4 opt-in
jasl Jun 2, 2026
99a9f10
deepseek_v4: route NVFP4-modelopt experts to ModelOptNvFp4FusedMoE
vedcsolution Jun 7, 2026
b4c1cab
deepseek_v4: allow flashinfer_cutlass NVFP4 MoE for SwiGLU-clamp models
jasl Jun 15, 2026
2c547ec
deepseek-v4: integrate upstream adaptive prefill chunk planning (#45061)
jasl Jun 15, 2026
a1135c0
deepseek-v4: guard sparse MLA decode metadata against None (mypy)
jasl Jun 15, 2026
59c7918
chore: restore mypy + ruff cleanliness after upstream rebase
jasl Jun 15, 2026
f91caa0
deepseek-v4: drop dead VLLM_DEEPSEEK_V4_SPARSE_MLA_STATS_PATH env
jasl Jun 15, 2026
627adee
feat: gated FlashInfer SM120 packed sparse-MLA decode for DeepSeek V4
jasl Jun 14, 2026
9ab2980
fix: pre-compile DeepSeek-V4 D512-split sparse-MLA prefill kernels at…
jasl Jun 16, 2026
dce4b3b
feat: restore DeepSeek-V4 API semantics on the SM120 min-enable base
jasl Jun 17, 2026
a8280db
fix: correct SM12x indexer prefill top-k non-contiguous output corrup…
jasl Jun 17, 2026
667205e
perf: skip throwaway copy in SM12x indexer top-k fallback
jasl Jun 17, 2026
edf0d37
feat: re-enable breakable cudagraph auto-enable on SM121 for DeepSeek-V4
jasl Jun 18, 2026
3476a25
chore: wrap SM12x indexer fallback line under ruff line-length
jasl Jun 18, 2026
285b542
fix(sm12x): eager-break DeepSeek-V4 attention under FULL cudagraph fo…
jasl Jun 19, 2026
20e1472
fix(v1): write-completion fence for prefix-cache block sharing
jasl Jun 19, 2026
6c92b09
perf(sm12x): default DeepSeek-V4 to FULL_AND_PIECEWISE (drop breakabl…
jasl Jun 19, 2026
f108fa9
fix(sm12x): COW shared final block for DeepSeek-V4 writable cache gro…
jasl Jun 19, 2026
67e060d
perf(sm12x): default indexed-D512 prefill min-token gate to 4096 (env…
jasl Jun 19, 2026
e57276b
revert: drop_eagle_block broadening — breaks prefix caching under MTP
jasl Jun 19, 2026
197d21e
fix(sm12x): int64 block offsets in indexer paged-MQA-logits kernel
jasl Jun 20, 2026
d782c62
feat(sm12x): env-gated packed FlashInfer sparse-MLA prefill for DeepS…
jasl Jun 20, 2026
c736caf
fix(sm12x): rebase compressed top-k per-request in FlashMLA prefill c…
jasl Jun 20, 2026
88ec87e
fix(sm12x): int64 block offsets in the non-rowwise paged-MQA-logits k…
jasl Jun 20, 2026
3f42d92
chore(sm12x-audit): restore upstream multi-line form for VLLM_ENFORCE…
jasl Jun 20, 2026
ec4ca57
refactor(sm12x-audit): express breakable-cudagraph auto-enable as a r…
jasl Jun 20, 2026
13166bb
chore(sm12x-audit): move out VLLM_NVFP4_GEMM_BACKEND b12x research lever
jasl Jun 20, 2026
0d66c89
refactor(sm12x-audit): move out scheduler prefill-fairness heuristics…
jasl Jun 20, 2026
72261a7
refactor(sm12x-audit): remove prefix-cache write fence (int64 fix hol…
jasl Jun 20, 2026
052bf0b
fix(sm12x): slice packed-prefill output to num_prefill_tokens (84-vs-…
jasl Jun 22, 2026
5ba0f19
fix(sm12x): cast MTP draft logits to float32 before top-k/top-p sampling
jasl Jun 22, 2026
10 changes: 5 additions & 5 deletions csrc/libtorch_stable/moe/marlin_moe_wna16/ops.cu
@@ -437,6 +437,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
   cudaDeviceGetAttribute(&max_shared_mem,
                          cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
   STD_TORCH_CHECK(max_shared_mem > 0);
+  int device_max_shared_mem = max_shared_mem;

   int major_capability, minor_capability;
   cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
@@ -527,10 +528,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
   }

   cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                       max_shared_mem);
+                       device_max_shared_mem);
   // avoid ">>>" being formatted to "> > >"
   // clang-format off
-  kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
+  kernel<<<blocks, num_threads, sh_cache_size, stream>>>(
       A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, g_idx_ptr,
       sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
       topk_weights_ptr, top_k, mul_topk_weights, num_groups, prob_m,
@@ -708,9 +709,8 @@ torch::stable::Tensor moe_wna16_marlin_gemm(
   torch::stable::Tensor c_tmp;
   if (use_fp32_reduce && !use_atomic_add) {
     // max num of threadblocks is sms * 4
-    long max_c_tmp_size = min(
-        (long)size_n * sorted_token_ids.size(0),
-        (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
+    long max_c_tmp_size =
+        (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n;
     if (moe_block_size == 8) max_c_tmp_size *= 2;
     c_tmp = torch::stable::new_empty(a, {max_c_tmp_size}, kFloat);
   } else {
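The last `ops.cu` hunk sizes the `c_tmp` scratch buffer from the thread-block bound alone instead of taking a `min` with `size_n * sorted_token_ids.size(0)`. The following is a minimal Python sketch of that arithmetic, not the actual implementation. The function names are hypothetical, and the default `max_thread_n=256` is an assumption: the real value comes from `MARLIN_NAMESPACE_NAME::max_thread_n`, which this diff does not show.

```python
def c_tmp_elems_old(sms, moe_block_size, size_n, num_sorted_tokens, max_thread_n=256):
    """Element count before the change: min of problem size and thread-block bound.

    max_thread_n=256 is an assumed placeholder for MARLIN_NAMESPACE_NAME::max_thread_n.
    """
    bound = sms * 4 * moe_block_size * max_thread_n
    size = min(size_n * num_sorted_tokens, bound)
    if moe_block_size == 8:
        size *= 2
    return size


def c_tmp_elems_new(sms, moe_block_size, max_thread_n=256):
    """Element count after the change: the thread-block bound only."""
    size = sms * 4 * moe_block_size * max_thread_n
    if moe_block_size == 8:
        size *= 2
    return size


# Example values are illustrative only.
old = c_tmp_elems_old(sms=48, moe_block_size=8, size_n=512, num_sorted_tokens=64)
new = c_tmp_elems_new(sms=48, moe_block_size=8)
```

Under these assumptions the new size is never smaller than the old one, because the old formula only ever clamped the same bound downward.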
58 changes: 58 additions & 0 deletions tests/compile/passes/test_functionalization.py
@@ -251,12 +251,70 @@ def ops_not_in_model(self):
         return []


+class TestFusedDeepseekV4QnormRopeKvInsert(torch.nn.Module):
+    OP_REGISTERED = False
+
+    def __init__(self):
+        super().__init__()
+        self.register_test_custom_op()
+
+    @classmethod
+    def register_test_custom_op(cls):
+        if not cls.OP_REGISTERED:
+
+            def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl(
+                q: torch.Tensor,
+                kv: torch.Tensor,
+                k_cache: torch.Tensor,
+            ) -> None:
+                q.add_(kv)
+                k_cache.add_(kv)
+
+            def fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake(
+                q: torch.Tensor,
+                kv: torch.Tensor,
+                k_cache: torch.Tensor,
+            ) -> None:
+                return None
+
+            direct_register_custom_op(
+                op_name="fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert",
+                op_func=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_impl,
+                mutates_args=["q", "k_cache"],
+                fake_impl=fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert_fake,
+            )
+
+            cls.OP_REGISTERED = True
+
+    def forward(
+        self, q: torch.Tensor, kv: torch.Tensor, k_cache: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(q, kv, k_cache)
+        return q, k_cache
+
+    def example_inputs(self, num_tokens=32, hidden_size=128):
+        return (
+            torch.rand(num_tokens, hidden_size),
+            torch.rand(num_tokens, hidden_size),
+            torch.rand(num_tokens, hidden_size),
+        )
+
+    def ops_in_model(self, do_fusion):
+        return [
+            torch.ops.vllm.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert.default
+        ]
+
+    def ops_not_in_model(self):
+        return []
+
+
 MODELS_AND_DO_FUSION = {
     TestSiluMul: [True, False],
     TestFusedAddRMSNorm: [True, False],
     TestRotaryEmbedding: [False],
     TestRotaryEmbeddingSliceScatter: [False],
     TestFunctionWithMutatedArgsAndReturn: [False],
+    TestFusedDeepseekV4QnormRopeKvInsert: [False],
 }


32 changes: 32 additions & 0 deletions tests/compile/test_config.py
@@ -468,6 +468,38 @@ def test_cudagraph_sizes_post_init(
         )


+def test_spec_decode_cudagraph_sizes_keep_small_full_decode_batches_exact():
+    config = CompilationConfig(
+        cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
+        cudagraph_capture_sizes=[
+            1,
+            2,
+            4,
+            8,
+            16,
+            24,
+            32,
+            40,
+            48,
+            56,
+            64,
+            72,
+            80,
+            88,
+            96,
+        ],
+        max_cudagraph_capture_size=96,
+    )
+
+    config.adjust_cudagraph_sizes_for_spec_decode(
+        uniform_decode_query_len=3,
+        tensor_parallel_size=1,
+    )
+
+    for num_reqs in range(1, 33):
+        assert 3 * num_reqs in config.cudagraph_capture_sizes
+
+
 @pytest.mark.skipif(
     not current_platform.support_static_graph_mode(),
     reason="Skip if not cudagraph mode supported",
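The new `test_config.py` test asserts that, with `uniform_decode_query_len=3`, every batch of 1 to 32 requests has an exact capture size of `3 * num_reqs` tokens. As a rough illustration of what that requires, the sketch below lists those sizes and shows which ones the original capture list lacks. It does not reproduce `adjust_cudagraph_sizes_for_spec_decode`; the helper name `required_decode_sizes` is hypothetical.

```python
def required_decode_sizes(query_len, max_capture_size):
    """Token counts of uniform decode batches (1..N requests) that fit under the cap."""
    max_reqs = max_capture_size // query_len
    return [query_len * n for n in range(1, max_reqs + 1)]


# The capture sizes passed to CompilationConfig in the test.
original = [1, 2, 4, 8] + list(range(16, 97, 8))

required = required_decode_sizes(query_len=3, max_capture_size=96)
already_present = [s for s in required if s in original]
missing = [s for s in required if s not in original]

print(len(required), already_present, len(missing))
```

Only the multiples of 3 that also appear in the original list are covered up front, so the adjustment step has to supply the rest for the test's assertion to hold.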
37 changes: 37 additions & 0 deletions tests/config/test_deepseek_v4_cudagraph_config.py
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from types import SimpleNamespace
+
+from vllm.config.vllm import _should_auto_enable_breakable_cudagraph
+
+
+def _model_config(*architectures: str):
+    return SimpleNamespace(architectures=list(architectures))
+
+
+def test_deepseek_v4_does_not_auto_enable_breakable_cudagraph():
+    # Breakable cudagraph disables torch.compile and is 1.5-3.8x slower for MTP
+    # decode on SM12x (measured); DeepSeek-V4 defaults to FULL_AND_PIECEWISE.
+    assert not _should_auto_enable_breakable_cudagraph(
+        _model_config("DeepseekV4ForCausalLM")
+    )
+    assert not _should_auto_enable_breakable_cudagraph(
+        _model_config("DeepSeekV4MTPModel")
+    )
+
+
+def test_minimax_m3_auto_enables_breakable_cudagraph():
+    # MiniMax M3 retains upstream's unconditional auto-enable.
+    assert _should_auto_enable_breakable_cudagraph(
+        _model_config("MiniMaxM3SparseForCausalLM")
+    )
+    assert _should_auto_enable_breakable_cudagraph(
+        _model_config("MiniMaxM3SparseForConditionalGeneration")
+    )
+
+
+def test_other_models_do_not_auto_enable_breakable_cudagraph():
+    assert not _should_auto_enable_breakable_cudagraph(
+        _model_config("Qwen3ForCausalLM")
+    )
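The three tests in `test_deepseek_v4_cudagraph_config.py` pin down the observable behaviour of `_should_auto_enable_breakable_cudagraph` without showing its body. One predicate that would satisfy them is an allow-list over `architectures`, sketched below. This is an assumption for illustration only: the function name without the leading underscore, the `_AUTO_ENABLE_ARCHS` set, and the allow-list approach itself are hypothetical, and the actual vLLM implementation may differ.

```python
from types import SimpleNamespace

# Assumption: architecture names are copied from the tests; the set is hypothetical.
_AUTO_ENABLE_ARCHS = frozenset(
    {
        "MiniMaxM3SparseForCausalLM",
        "MiniMaxM3SparseForConditionalGeneration",
    }
)


def should_auto_enable_breakable_cudagraph(model_config) -> bool:
    """Return True only when the model lists an allow-listed architecture."""
    archs = getattr(model_config, "architectures", None) or []
    return any(arch in _AUTO_ENABLE_ARCHS for arch in archs)


# Usage mirroring the test helper _model_config.
deepseek = SimpleNamespace(architectures=["DeepseekV4ForCausalLM"])
minimax = SimpleNamespace(architectures=["MiniMaxM3SparseForCausalLM"])
print(should_auto_enable_breakable_cudagraph(deepseek))
print(should_auto_enable_breakable_cudagraph(minimax))
```

An allow-list keeps unknown architectures such as `Qwen3ForCausalLM` on the default path, which matches the third test.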
173 changes: 173 additions & 0 deletions tests/kernels/moe/test_flashinfer_cutlass_mxfp4_config.py
@@ -0,0 +1,173 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import sys
+import types
+
+import torch
+
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
+    FusedMoEParallelConfig,
+    RoutingMethodType,
+    mxfp4_mxfp8_moe_quant_config,
+)
+from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import (
+    FlashInferExperts,
+)
+from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import (
+    Mxfp4MoeBackend,
+    convert_weight_to_mxfp4_moe_kernel_format,
+)
+
+
+def _make_moe_config() -> FusedMoEConfig:
+    return FusedMoEConfig(
+        num_experts=2,
+        experts_per_token=1,
+        hidden_dim=16,
+        intermediate_size_per_partition=16,
+        num_local_experts=2,
+        num_logical_experts=2,
+        activation=MoEActivation.SILU,
+        device="cpu",
+        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+        in_dtype=torch.bfloat16,
+        routing_method=RoutingMethodType.TopK,
+        max_num_tokens=16,
+    )
+
+
+def _make_experts(
+    *,
+    gemm1_alpha: float | None = None,
+    gemm1_beta: float | None = None,
+    gemm1_clamp_limit: float | None = None,
+) -> FlashInferExperts:
+    quant_config = mxfp4_mxfp8_moe_quant_config(
+        w1_scale=torch.ones((2, 32, 1), dtype=torch.float8_e4m3fn),
+        w2_scale=torch.ones((2, 16, 1), dtype=torch.float8_e4m3fn),
+        gemm1_alpha=gemm1_alpha,
+        gemm1_beta=gemm1_beta,
+        gemm1_clamp_limit=gemm1_clamp_limit,
+    )
+    with set_current_vllm_config(
+        VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
+    ):
+        return FlashInferExperts(
+            moe_config=_make_moe_config(),
+            quant_config=quant_config,
+        )
+
+
+def test_mxfp4_swiglu_parameters_stay_unset_without_quant_config() -> None:
+    experts = _make_experts()
+
+    assert experts.gemm1_alpha is None
+    assert experts.gemm1_beta is None
+    assert experts.gemm1_clamp_limit is None
+
+
+def test_mxfp4_swiglu_parameters_follow_quant_config() -> None:
+    experts = _make_experts(
+        gemm1_alpha=1.25,
+        gemm1_beta=0.75,
+        gemm1_clamp_limit=5.5,
+    )
+
+    torch.testing.assert_close(experts.gemm1_alpha, torch.tensor([1.25, 1.25]))
+    torch.testing.assert_close(experts.gemm1_beta, torch.tensor([0.75, 0.75]))
+    torch.testing.assert_close(
+        experts.gemm1_clamp_limit,
+        torch.tensor([5.5, 5.5]),
+    )
+
+
+def test_cutlass_mxfp8_kernel_format_converts_gate_up_layout(monkeypatch) -> None:
+    monkeypatch.setitem(
+        sys.modules,
+        "flashinfer",
+        types.SimpleNamespace(block_scale_interleave=lambda x: x.contiguous()),
+    )
+
+    num_experts = 1
+    intermediate_size = 64
+    hidden_size = 64
+    packed_hidden_size = hidden_size // 2
+    sf_block_size = 32
+
+    w13_weight = torch.arange(
+        num_experts * 2 * intermediate_size * packed_hidden_size,
+        dtype=torch.uint8,
+    ).reshape(num_experts, 2 * intermediate_size, packed_hidden_size)
+    w2_weight = torch.arange(
+        num_experts * hidden_size * (intermediate_size // 2),
+        dtype=torch.uint8,
+    ).reshape(num_experts, hidden_size, intermediate_size // 2)
+    w13_scale_u8 = torch.arange(
+        num_experts * 2 * intermediate_size * (hidden_size // sf_block_size),
+        dtype=torch.uint8,
+    ).reshape(num_experts, 2 * intermediate_size, hidden_size // sf_block_size)
+    w2_scale_u8 = torch.arange(
+        num_experts * hidden_size * (intermediate_size // sf_block_size),
+        dtype=torch.uint8,
+    ).reshape(num_experts, hidden_size, intermediate_size // sf_block_size)
+    w13_bias = torch.arange(
+        num_experts * 2 * intermediate_size,
+        dtype=torch.bfloat16,
+    ).reshape(num_experts, 2 * intermediate_size)
+    w2_bias = torch.arange(
+        num_experts * hidden_size,
+        dtype=torch.bfloat16,
+    ).reshape(num_experts, hidden_size)
+
+    (
+        out_w13,
+        out_w2,
+        out_w13_scale,
+        out_w2_scale,
+        out_w13_bias,
+        out_w2_bias,
+    ) = convert_weight_to_mxfp4_moe_kernel_format(
+        mxfp4_backend=Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
+        layer=torch.nn.Module(),
+        w13_weight=w13_weight,
+        w2_weight=w2_weight,
+        w13_weight_scale=w13_scale_u8.view(torch.float8_e4m3fn),
+        w2_weight_scale=w2_scale_u8.view(torch.float8_e4m3fn),
+        w13_bias=w13_bias,
+        w2_bias=w2_bias,
+    )
+
+    expected_w13 = torch.cat(
+        [
+            w13_weight[:, intermediate_size:, :],
+            w13_weight[:, :intermediate_size, :],
+        ],
+        dim=1,
+    )
+    expected_w13_scale = torch.cat(
+        [
+            w13_scale_u8[:, intermediate_size:, :],
+            w13_scale_u8[:, :intermediate_size, :],
+        ],
+        dim=1,
+    )
+    expected_w13_bias = torch.cat(
+        [
+            w13_bias[:, intermediate_size:],
+            w13_bias[:, :intermediate_size],
+        ],
+        dim=1,
+    )
+
+    assert out_w13.is_contiguous()
+    assert out_w2.is_contiguous()
+    torch.testing.assert_close(out_w13, expected_w13)
+    torch.testing.assert_close(out_w2, w2_weight)
+    torch.testing.assert_close(out_w13_scale, expected_w13_scale)
+    torch.testing.assert_close(out_w2_scale, w2_scale_u8)
+    torch.testing.assert_close(out_w13_bias, expected_w13_bias)
+    torch.testing.assert_close(out_w2_bias, w2_bias)
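The expected tensors in `test_cutlass_mxfp8_kernel_format_converts_gate_up_layout` are built by exchanging the two halves of the fused `w13` dimension, and the same exchange is applied to the weight, the scale, and the bias. A minimal sketch of that reordering on plain Python lists follows, so it runs without `torch`. The helper name `swap_w13_halves` is hypothetical, and the sketch makes no claim about which half holds the gate projection and which the up projection.

```python
def swap_w13_halves(rows, intermediate_size):
    """Exchange the first and second halves of a fused w13 row dimension.

    Mirrors torch.cat([w[:, I:], w[:, :I]], dim=1) from the test for one expert.
    """
    if len(rows) != 2 * intermediate_size:
        raise ValueError("expected exactly 2 * intermediate_size rows")
    return rows[intermediate_size:] + rows[:intermediate_size]


# One expert with intermediate_size = 2: rows 0-1 are the first half, rows 2-3 the second.
w13_rows = ["first0", "first1", "second0", "second1"]
swapped = swap_w13_halves(w13_rows, intermediate_size=2)
print(swapped)
```

Applying the swap twice returns the original order, which is a quick sanity check that the conversion only permutes rows and does not alter their contents.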
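The two `test_mxfp4_swiglu_parameters_*` tests in the same file expect a scalar SwiGLU parameter such as `gemm1_alpha=1.25` to show up as one value per local expert (`[1.25, 1.25]` with `num_local_experts=2`), and an unset parameter to stay `None`. The sketch below shows that behaviour with plain lists in place of tensors. The helper name `expand_per_expert` is hypothetical, and how `FlashInferExperts` actually performs the expansion is not shown in this diff.

```python
def expand_per_expert(value, num_local_experts):
    """Repeat a scalar once per local expert; keep None as None."""
    if value is None:
        return None
    return [float(value)] * num_local_experts


# Values taken from the test; num_local_experts=2 matches _make_moe_config().
alpha = expand_per_expert(1.25, num_local_experts=2)
clamp = expand_per_expert(5.5, num_local_experts=2)
unset = expand_per_expert(None, num_local_experts=2)
print(alpha, clamp, unset)
```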