From cd2d9ff7d16da81ff0c0ecc9d2182fa83eef6366 Mon Sep 17 00:00:00 2001 From: STwangyingrui Date: Mon, 15 Jun 2026 08:40:19 +0000 Subject: [PATCH 1/3] add pytorch backend of moe --- .../networks/neopp/infer/transformer_infer.py | 161 ++++++++++++++---- .../neopp/weights/transformer_weights.py | 26 ++- 2 files changed, 148 insertions(+), 39 deletions(-) diff --git a/lightx2v/models/networks/neopp/infer/transformer_infer.py b/lightx2v/models/networks/neopp/infer/transformer_infer.py index 88e5bf96c..47cd7ea1f 100755 --- a/lightx2v/models/networks/neopp/infer/transformer_infer.py +++ b/lightx2v/models/networks/neopp/infer/transformer_infer.py @@ -9,16 +9,59 @@ flashinfer_cutlass_fused_moe = None try: - from magi_compiler import magi_compile, magi_register_custom_op + from magi_compiler import magi_compile except ImportError: magi_compile = None - magi_register_custom_op = None from lightx2v.common.magi_custom_op_mode import configure_dynamo_for_magi_compile from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from lightx2v.models.networks.neopp.infer.kv_cache_manager import KVCacheManager from lightx2v.utils.profiler import * +_GROUPED_MM_ALIGN = 8 + + +def _expert_padded_counts(counts, align=_GROUPED_MM_ALIGN): + pad = (align - counts.remainder(align)) % align + return torch.where(counts > 0, counts + pad, counts) + + +def _sorted_expert_row_map(counts): + num_experts = counts.shape[0] + return torch.repeat_interleave( + torch.arange(num_experts, device=counts.device, dtype=torch.long), + counts.to(torch.long), + ) + + +def _pad_tokens_for_grouped_mm(x_perm, counts): + padded_counts = _expert_padded_counts(counts) + offsets = padded_counts.cumsum(0).to(torch.int32) + + total = counts.sum() + expert_for_row = _sorted_expert_row_map(counts) + perm_starts = counts.cumsum(0) - counts + padded_starts = padded_counts.cumsum(0) - padded_counts + row_idx = torch.arange(total, device=counts.device, dtype=torch.long) + within = row_idx - perm_starts[expert_for_row] + dst_idx = padded_starts[expert_for_row] + within + + x_padded = x_perm.new_zeros(padded_counts.sum(), x_perm.shape[-1]) + x_padded[dst_idx] = x_perm + return x_padded, offsets, padded_counts + + +def _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts): + total = counts.sum() + expert_for_row = _sorted_expert_row_map(counts) + perm_starts = counts.cumsum(0) - counts + padded_starts = padded_counts.cumsum(0) - padded_counts + row_idx = torch.arange(total, device=counts.device, dtype=torch.long) + within = row_idx - perm_starts[expert_for_row] + src_idx = padded_starts[expert_for_row] + within + return out_padded[src_idx] + + # Register neopp::kv_update as a PyTorch custom op via torch.library. # We use torch.library (define + impl) instead of magi_register_custom_op # because the latter internally calls torch.library.custom_op, which has @@ -76,6 +119,7 @@ def __init__(self, config): if self.version == "moe": self.num_experts_per_tok = llm_config["num_experts_per_tok"] self.norm_topk_prob = llm_config.get("norm_topk_prob", True) + logger.info(f"NeoPP MoE backend: {config.get('moe_backend', 'flashinfer')}") self._mlp_forward = self._sparse_moe else: self._mlp_forward = self._dense_mlp @@ -92,7 +136,10 @@ def __init__(self, config): self.use_magi_compile = False if self.use_magi_compile: configure_dynamo_for_magi_compile() - logger.info("Using Magi Compile (per-layer decoder, split at kv_update)") + if self.version == "moe": + logger.info("Using Magi Compile (per-layer attention, MoE FFN runs eager)") + else: + logger.info("Using Magi Compile (per-layer attn + FFN)") @torch.no_grad() def infer(self, weights, pre_infer_out, inputs): @@ -110,7 +157,6 @@ def infer(self, weights, pre_infer_out, inputs): def infer_without_offload(self, blocks, hidden_states, cos_sin, past_key_values): seq_len_q = hidden_states.shape[0] kvcache_len = past_key_values.shape[2] - seq_len_k = kvcache_len + seq_len_q # Allocate the KV buffer fresh each step so Dynamo sees it as a local # tensor inside the compiled region. @@ -118,23 +164,8 @@ def infer_without_offload(self, blocks, hidden_states, cos_sin, past_key_values) self.kv_cache.prepare(past_key_values, seq_len_q) kv_buf = self.kv_cache._kv_buf - cos_t, sin_t, cos_h, sin_h, cos_w, sin_w = cos_sin for layer_idx, block_weight in enumerate(blocks): - if self.use_magi_compile: - hidden_states = self._decoder_layer_magi( - block_weight, - layer_idx, - hidden_states, - cos_t, - sin_t, - cos_h, - sin_h, - cos_w, - sin_w, - kv_buf, - ) - else: - hidden_states = self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin, kv_buf) + hidden_states = self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin, kv_buf) return hidden_states if magi_compile is not None: @@ -159,7 +190,7 @@ def _magi_config_patch(c): }, config_patch=_magi_config_patch, ) - def _decoder_layer_magi( + def _decoder_layer_attn_magi( self, block_weight, layer_idx, @@ -173,22 +204,49 @@ def _decoder_layer_magi( kv_buf, ): cos_sin = (cos_t, sin_t, cos_h, sin_h, cos_w, sin_w) - return self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin, kv_buf) + return self._decoder_layer_attn(block_weight, layer_idx, hidden_states, cos_sin, kv_buf) - # @ProfilingContext4DebugL1("Decoder Layer") - def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin, kv_buf=None): + @magi_compile( + dynamic_arg_dims={"hidden_states": 0}, + config_patch=_magi_config_patch, + ) + def _decoder_layer_ffn_magi(self, block_weight, hidden_states): + return self._decoder_layer_ffn(block_weight, hidden_states) + + def _decoder_layer_attn(self, block_weight, layer_idx, hidden_states, cos_sin, kv_buf=None): residual = hidden_states hidden_states = block_weight.input_layernorm_mot_gen.apply(hidden_states) - hidden_states = self._self_attn(block_weight.self_attn, layer_idx, hidden_states, cos_sin, kv_buf) - hidden_states = residual + hidden_states + return residual + hidden_states + def _decoder_layer_ffn(self, block_weight, hidden_states): residual = hidden_states gen_hidden = block_weight.post_attention_layernorm_mot_gen.apply(hidden_states) gen_hidden = self._mlp_forward(block_weight.mlp_mot_gen, gen_hidden) - hidden_states = residual + gen_hidden + return residual + gen_hidden - return hidden_states + # @ProfilingContext4DebugL1("Decoder Layer") + def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin, kv_buf=None): + if self.use_magi_compile: + cos_t, sin_t, cos_h, sin_h, cos_w, sin_w = cos_sin + hidden_states = self._decoder_layer_attn_magi( + block_weight, + layer_idx, + hidden_states, + cos_t, + sin_t, + cos_h, + sin_h, + cos_w, + sin_w, + kv_buf, + ) + if self.version == "moe": + return self._decoder_layer_ffn(block_weight, hidden_states) + return self._decoder_layer_ffn_magi(block_weight, hidden_states) + + hidden_states = self._decoder_layer_attn(block_weight, layer_idx, hidden_states, cos_sin, kv_buf) + return self._decoder_layer_ffn(block_weight, hidden_states) # @ProfilingContext4DebugL1("Self Attn") def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin, kv_buf=None): @@ -305,8 +363,7 @@ def _compute_attn(self, attn_w, query_states, key_states, value_states): ) return attn_output - # @ProfilingContext4DebugL1("Sparse MoE") - def _sparse_moe(self, moe_w, hidden_states): + def _moe_route(self, moe_w, hidden_states): router_logits = moe_w.gate.apply(hidden_states) if self.norm_topk_prob: _, selected_experts = torch.topk(router_logits, self.num_experts_per_tok, dim=-1, sorted=False) @@ -314,6 +371,44 @@ def _sparse_moe(self, moe_w, hidden_states): else: routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.num_experts_per_tok, dim=-1) + return selected_experts, routing_weights + + def _sparse_moe_pytorch(self, moe_w, hidden_states, selected_experts, routing_weights): + hidden_dim = hidden_states.shape[-1] + flat_topk_idx = selected_experts.reshape(-1) + flat_topk_weight = routing_weights.reshape(-1, 1) + + idxs = flat_topk_idx.argsort() + token_idxs = idxs // self.num_experts_per_tok + counts = flat_topk_idx.bincount(minlength=moe_w.num_experts) + + x_perm = hidden_states[token_idxs] + x_padded, offsets, padded_counts = _pad_tokens_for_grouped_mm(x_perm, counts) + gate_out = torch._grouped_mm(x_padded, moe_w._pt_gate_weight, offs=offsets) + up_out = torch._grouped_mm(x_padded, moe_w._pt_up_weight, offs=offsets) + hidden = F.silu(gate_out) * up_out + out_padded = torch._grouped_mm(hidden, moe_w._pt_down_weight, offs=offsets) + expert_out = _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts) + expert_out.mul_(flat_topk_weight[idxs]) + + expert_cache = torch.zeros_like(hidden_states) + expert_cache = expert_cache.to(expert_out.dtype) + expert_cache.scatter_reduce_( + 0, + token_idxs.view(-1, 1).repeat(1, hidden_dim), + expert_out, + reduce="sum", + ) + return expert_cache + + # @ProfilingContext4DebugL1("Sparse MoE") + def _sparse_moe(self, moe_w, hidden_states): + selected_experts, routing_weights = self._moe_route(moe_w, hidden_states) + if moe_w.moe_backend == "pytorch": + return self._sparse_moe_pytorch(moe_w, hidden_states, selected_experts, routing_weights) + + if flashinfer_cutlass_fused_moe is None: + raise RuntimeError("moe_backend=flashinfer but flashinfer.fused_moe is not available") output = flashinfer_cutlass_fused_moe( hidden_states if hidden_states.is_contiguous() else hidden_states.contiguous(), @@ -324,7 +419,6 @@ def _sparse_moe(self, moe_w, hidden_states): hidden_states.dtype, quant_scales=None, )[0] - return output # @ProfilingContext4DebugL1("FM Head") @@ -339,8 +433,3 @@ def _dense_mlp(self, mlp_w, hidden_states): gate_states = mlp_w.gate_proj.apply(hidden_states) intermediate_states = F.silu(gate_states) * up_states return mlp_w.down_proj.apply(intermediate_states) - - # def _dense_mlp(self, mlp_w, hidden_states): - # gate_up_states = torch.mm(hidden_states, mlp_w._fi_gate_up_weight) - # intermediate_states = flashinfer_silu_and_mul(gate_up_states) - # return mlp_w.down_proj.apply(intermediate_states) diff --git a/lightx2v/models/networks/neopp/weights/transformer_weights.py b/lightx2v/models/networks/neopp/weights/transformer_weights.py index cda373aab..767bbd458 100755 --- a/lightx2v/models/networks/neopp/weights/transformer_weights.py +++ b/lightx2v/models/networks/neopp/weights/transformer_weights.py @@ -5,6 +5,7 @@ except ImportError: get_cutlass_fused_moe_module = None + from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList from lightx2v.common.ops.attn import FlashAttn2Weight, FlashAttn3Weight # noqa: F401 from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightFusedQKNorm3DRope @@ -68,7 +69,10 @@ def __init__(self, block_index, config, mm_type, attn_type="flash_attn2", lora_p if config["version"] == "moe": gen_num_experts = int(config["llm_config"]["gen_num_experts"]) - mlp_mot_gen = NeoppSparseMoeWeights(block_index, mm_type, "mlp_mot_gen", gen_num_experts, lora_path=lora_path) + moe_backend = config.get("moe_backend", "flashinfer") + if moe_backend not in ("pytorch", "flashinfer"): + raise ValueError(f"Invalid moe_backend={moe_backend!r}, expected 'pytorch' or 'flashinfer'") + mlp_mot_gen = NeoppSparseMoeWeights(block_index, mm_type, "mlp_mot_gen", gen_num_experts, moe_backend=moe_backend, lora_path=lora_path) elif config["version"] == "dense": mlp_mot_gen = NeoppMlpWeights(block_index, mm_type, lora_path=lora_path) else: @@ -130,11 +134,12 @@ def __init__(self, config, block_index, mm_type, attn_type="flash_attn2", use_tr class NeoppSparseMoeWeights(WeightModule): - def __init__(self, block_index, mm_type, subname, num_experts, lora_path=None): + def __init__(self, block_index, mm_type, subname, num_experts, moe_backend, lora_path=None): super().__init__() prefix = f"language_model.model.layers.{block_index}.{subname}" lora_prefix = "language_model" + self.moe_backend = moe_backend self.add_module("gate", MM_WEIGHT_REGISTER[mm_type](f"{prefix}.gate.weight", None, lora_prefix=lora_prefix, lora_path=lora_path)) self.num_experts = num_experts @@ -143,7 +148,22 @@ def __init__(self, block_index, mm_type, subname, num_experts, lora_path=None): def load(self, weight_dict): super().load(weight_dict) - self._build_flashinfer_weights() + if self.moe_backend == "flashinfer": + self._build_flashinfer_weights() + elif self.moe_backend == "pytorch": + self._build_pytorch_grouped_mm_weights() + else: + raise ValueError(f"Invalid moe_backend={self.moe_backend!r}, expected 'pytorch' or 'flashinfer'") + + def _build_pytorch_grouped_mm_weights(self): + gate_list, up_list, down_list = [], [], [] + for expert_w in self.experts: + gate_list.append(expert_w.gate_proj._get_actual_weight()) + up_list.append(expert_w.up_proj._get_actual_weight()) + down_list.append(expert_w.down_proj._get_actual_weight()) + self._pt_gate_weight = torch.stack(gate_list, dim=0).contiguous() + self._pt_up_weight = torch.stack(up_list, dim=0).contiguous() + self._pt_down_weight = torch.stack(down_list, dim=0).contiguous() def _build_flashinfer_weights(self): if torch.cuda.is_available(): From 9922dcc340865127a5958f278c515bc0b41af538 Mon Sep 17 00:00:00 2001 From: STwangyingrui Date: Tue, 16 Jun 2026 07:45:26 +0000 Subject: [PATCH 2/3] add flashinfer moe autotune --- configs/neopp/neopp_moe.json | 5 + lightx2v/common/flashinfer_autotune.py | 97 +++++++++++++++++++ .../networks/neopp/infer/moe_fi_autotune.py | 43 ++++++++ .../networks/neopp/infer/transformer_infer.py | 23 ++++- .../neopp/weights/transformer_weights.py | 3 + lightx2v/models/runners/neopp/neopp_runner.py | 36 ++++--- 6 files changed, 195 insertions(+), 12 deletions(-) create mode 100644 lightx2v/common/flashinfer_autotune.py create mode 100644 lightx2v/models/networks/neopp/infer/moe_fi_autotune.py diff --git a/configs/neopp/neopp_moe.json b/configs/neopp/neopp_moe.json index 1b5379d41..c3a74bb01 100644 --- a/configs/neopp/neopp_moe.json +++ b/configs/neopp/neopp_moe.json @@ -3,6 +3,11 @@ "load_kv_cache_in_pipeline_for_debug": true, "infer_steps": 50, "attn_type": "flash_attn3", + "moe_backend": "flashinfer", + "moe_flashinfer_setting": { + "autotune": true, + "tune_max_num_tokens": 8192 + }, "cfg_scale": 4.0, "timestep_shift": 3.0, "cfg_interval": [-1, 2], diff --git a/lightx2v/common/flashinfer_autotune.py b/lightx2v/common/flashinfer_autotune.py new file mode 100644 index 000000000..bbb50ba87 --- /dev/null +++ b/lightx2v/common/flashinfer_autotune.py @@ -0,0 +1,97 @@ +import os +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import torch +from loguru import logger + +FI_FORCE_RETUNE_ENV = "LIGHTX2V_FI_FORCE_RETUNE" + +try: + from flashinfer.autotuner import autotune as flashinfer_autotune +except ImportError: + flashinfer_autotune = None + + +def fi_force_retune(env_name: str = FI_FORCE_RETUNE_ENV) -> bool: + return os.environ.get(env_name, "0").strip().lower() in ("1", "true", "yes", "on") + + +def fi_sm_arch() -> int: + if not torch.cuda.is_available(): + return 0 + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def fi_autotune_cache_path(cache_namespace: str, model_sig: str) -> str: + root = Path.home() / ".cache" / "lightx2v" / "autotune" / cache_namespace + return str(root / model_sig / f"sm{fi_sm_arch()}.json") + + +def _resolve_tune_mode(cache_path: str, *, tune_mode: bool | None) -> bool: + if tune_mode is True: + return True + if tune_mode is False: + return False + return not os.path.isfile(os.path.expanduser(cache_path)) + + +def _tune_mode_label(tune_mode: bool | None, effective: bool) -> str: + if tune_mode is None: + return f"auto->{effective}" + return str(effective) + + +@dataclass +class FlashInferAutotune: + """Generic FlashInfer autotune session (cache + tune_mode dispatch).""" + + enabled: bool = False + cache_path: Optional[str] = None + force_retune_env: str = FI_FORCE_RETUNE_ENV + log_prefix: str = "Flashinfer autotune" + + def cache_rebuild_needed(self) -> bool: + if not self.enabled or not self.cache_path: + return False + if fi_force_retune(self.force_retune_env): + return True + return not os.path.isfile(os.path.expanduser(self.cache_path)) + + @contextmanager + def session(self, *, tune_mode: bool | None = None): + """FlashInfer autotune session. + + ``tune_mode``: + None: cache hit → cache-only; cache miss → lazy online rebuild. + True: profile uncovered shapes (offline tune / one-shot rebuild step). + False: cache-only even when cache is missing (benchmark fallback path). + """ + if not self.enabled or not self.cache_path or flashinfer_autotune is None: + yield + return + + cache_path = os.path.expanduser(self.cache_path) + cache_dir = os.path.dirname(cache_path) + if cache_dir: + os.makedirs(cache_dir, exist_ok=True) + + effective_tune_mode = _resolve_tune_mode(cache_path, tune_mode=tune_mode) + mode_label = _tune_mode_label(tune_mode, effective_tune_mode) + + if fi_force_retune(self.force_retune_env) and effective_tune_mode and os.path.isfile(cache_path): + os.remove(cache_path) + logger.info(f"Removed {self.log_prefix} cache ({self.force_retune_env}=1): {cache_path}; will profile once in this session, then cache-only for later steps/runs") + + if os.path.isfile(cache_path): + logger.info(f"{self.log_prefix}: loading cache from {cache_path} (tune_mode={mode_label})") + elif effective_tune_mode: + logger.info(f"{self.log_prefix}: cache not found at {cache_path}, lazy-rebuilding online (tune_mode={mode_label}); first inference after cache loss will be slower") + else: + logger.warning(f"{self.log_prefix}: cache not found at {cache_path} and tune_mode=False; will use fallback tactics until cache is built.") + + with flashinfer_autotune(effective_tune_mode, cache=cache_path): + yield diff --git a/lightx2v/models/networks/neopp/infer/moe_fi_autotune.py b/lightx2v/models/networks/neopp/infer/moe_fi_autotune.py new file mode 100644 index 000000000..b04fe7ed6 --- /dev/null +++ b/lightx2v/models/networks/neopp/infer/moe_fi_autotune.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass + +from lightx2v.common.flashinfer_autotune import ( + FlashInferAutotune, + fi_autotune_cache_path, +) + +MOE_FI_CACHE_NAMESPACE = "neopp_moe" +MOE_FI_FORCE_RETUNE_ENV = "LIGHTX2V_MOE_FI_FORCE_RETUNE" + + +def build_moe_model_sig(config) -> str: + llm = config["llm_config"] + hidden = int(llm["hidden_size"]) + intermediate = int(llm.get("moe_intermediate_size", llm.get("gen_moe_intermediate_size", 0))) + num_experts = int(llm["gen_num_experts"]) + top_k = int(llm["num_experts_per_tok"]) + return f"neopp_moe_e{num_experts}_k{top_k}_h{hidden}_i{intermediate}_swiglu" + + +def moe_fi_autotune_cache(config) -> str: + return fi_autotune_cache_path(MOE_FI_CACHE_NAMESPACE, build_moe_model_sig(config)) + + +@dataclass +class MoeFiAutotune(FlashInferAutotune): + tune_max_num_tokens: int = 8192 + + @classmethod + def from_neopp_config(cls, config) -> "MoeFiAutotune": + fi_cfg = config.get("moe_flashinfer_setting") or {} + tune_max = int(fi_cfg.get("tune_max_num_tokens", 8192)) + if config.get("version", "moe") != "moe" or config.get("moe_backend") != "flashinfer": + return cls(tune_max_num_tokens=tune_max) + if not fi_cfg.get("autotune", False): + return cls(tune_max_num_tokens=tune_max) + return cls( + enabled=True, + cache_path=moe_fi_autotune_cache(config), + tune_max_num_tokens=tune_max, + force_retune_env=MOE_FI_FORCE_RETUNE_ENV, + log_prefix="Flashinfer MoE autotune", + ) diff --git a/lightx2v/models/networks/neopp/infer/transformer_infer.py b/lightx2v/models/networks/neopp/infer/transformer_infer.py index 47cd7ea1f..4602ebc53 100755 --- a/lightx2v/models/networks/neopp/infer/transformer_infer.py +++ b/lightx2v/models/networks/neopp/infer/transformer_infer.py @@ -1,3 +1,5 @@ +import os + import torch import torch.nn.functional as F from loguru import logger @@ -8,6 +10,12 @@ except ImportError: flashinfer_cutlass_fused_moe = None +from lightx2v.common.flashinfer_autotune import flashinfer_autotune +from lightx2v.models.networks.neopp.infer.moe_fi_autotune import ( + MOE_FI_FORCE_RETUNE_ENV, + MoeFiAutotune, +) + try: from magi_compiler import magi_compile except ImportError: @@ -116,11 +124,23 @@ def __init__(self, config): self.scaling = self.head_dim**-0.5 self.use_triton_qknorm_rope = config.get("use_triton_qknorm_rope", True) self.version = config.get("version", "moe") + self.fi_moe_autotune = MoeFiAutotune.from_neopp_config(config) if self.version == "moe": self.num_experts_per_tok = llm_config["num_experts_per_tok"] self.norm_topk_prob = llm_config.get("norm_topk_prob", True) - logger.info(f"NeoPP MoE backend: {config.get('moe_backend', 'flashinfer')}") + moe_backend = config.get("moe_backend", "flashinfer") + logger.info(f"NeoPP MoE backend: {moe_backend}") self._mlp_forward = self._sparse_moe + if moe_backend == "flashinfer" and self.fi_moe_autotune.enabled: + if flashinfer_autotune is None or flashinfer_cutlass_fused_moe is None: + raise RuntimeError("moe_flashinfer_setting.autotune=true but flashinfer MoE autotuner is not available") + logger.info( + f"NeoPP flashinfer MoE autotune enabled " + f"(cache={self.fi_moe_autotune.cache_path}, " + f"tune_mode=auto (cache-only if present, else lazy rebuild), " + f"tune_max_num_tokens={self.fi_moe_autotune.tune_max_num_tokens}, " + f"{MOE_FI_FORCE_RETUNE_ENV}={os.environ.get(MOE_FI_FORCE_RETUNE_ENV, '0')})" + ) else: self._mlp_forward = self._dense_mlp if self.config["seq_parallel"]: @@ -418,6 +438,7 @@ def _sparse_moe(self, moe_w, hidden_states): moe_w._fi_fc2_weight, hidden_states.dtype, quant_scales=None, + tune_max_num_tokens=self.fi_moe_autotune.tune_max_num_tokens, )[0] return output diff --git a/lightx2v/models/networks/neopp/weights/transformer_weights.py b/lightx2v/models/networks/neopp/weights/transformer_weights.py index 767bbd458..5e296b914 100755 --- a/lightx2v/models/networks/neopp/weights/transformer_weights.py +++ b/lightx2v/models/networks/neopp/weights/transformer_weights.py @@ -72,6 +72,9 @@ def __init__(self, block_index, config, mm_type, attn_type="flash_attn2", lora_p moe_backend = config.get("moe_backend", "flashinfer") if moe_backend not in ("pytorch", "flashinfer"): raise ValueError(f"Invalid moe_backend={moe_backend!r}, expected 'pytorch' or 'flashinfer'") + fi_cfg = config.get("moe_flashinfer_setting") or {} + if fi_cfg.get("autotune") and moe_backend != "flashinfer": + raise ValueError("moe_flashinfer_setting.autotune=true requires moe_backend='flashinfer'") mlp_mot_gen = NeoppSparseMoeWeights(block_index, mm_type, "mlp_mot_gen", gen_num_experts, moe_backend=moe_backend, lora_path=lora_path) elif config["version"] == "dense": mlp_mot_gen = NeoppMlpWeights(block_index, mm_type, lora_path=lora_path) diff --git a/lightx2v/models/runners/neopp/neopp_runner.py b/lightx2v/models/runners/neopp/neopp_runner.py index 7f48ea025..83931f11c 100755 --- a/lightx2v/models/runners/neopp/neopp_runner.py +++ b/lightx2v/models/runners/neopp/neopp_runner.py @@ -266,20 +266,34 @@ def clear_kvcache(self): def init_run(self): self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape) - def run_main(self): - self.init_run() - infer_steps = self.model.scheduler.infer_steps - for step_index in range(infer_steps): - logger.info(f"==> step_index: {step_index + 1} / {infer_steps}") + def _run_infer_step(self, step_index: int, infer_steps: int) -> None: + logger.info(f"==> step_index: {step_index + 1} / {infer_steps}") + + with ProfilingContext4DebugL1("step_pre"): + self.scheduler.step_pre(step_index) - with ProfilingContext4DebugL1("step_pre"): - self.scheduler.step_pre(step_index) + with ProfilingContext4DebugL1("🚀 infer_main"): + self.model.infer(self.inputs) - with ProfilingContext4DebugL1("🚀 infer_main"): - self.model.infer(self.inputs) + with ProfilingContext4DebugL1("step_post"): + self.scheduler.step_post() - with ProfilingContext4DebugL1("step_post"): - self.scheduler.step_post() + def run_main(self): + self.init_run() + infer_steps = self.model.scheduler.infer_steps + infer = self.model.transformer_infer + at = infer.fi_moe_autotune + start_step = 0 + + if at.cache_rebuild_needed(): + logger.info("Flashinfer MoE autotune: cache rebuild required; profiling on step 1 only, then cache-only for remaining steps") + with at.session(tune_mode=True): + self._run_infer_step(0, infer_steps) + start_step = 1 + + with at.session(tune_mode=False): + for step_index in range(start_step, infer_steps): + self._run_infer_step(step_index, infer_steps) if self.config.get("save_result_for_debug", True): gen_result = self.process_images_after_vae_decoder_for_debug() From cb334b427bbcca38a1e092109df10b57e7732668 Mon Sep 17 00:00:00 2001 From: STwangyingrui Date: Tue, 16 Jun 2026 08:12:57 +0000 Subject: [PATCH 3/3] fix review comments of gemini code assist --- .../networks/neopp/infer/transformer_infer.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/lightx2v/models/networks/neopp/infer/transformer_infer.py b/lightx2v/models/networks/neopp/infer/transformer_infer.py index 4602ebc53..e880c8747 100755 --- a/lightx2v/models/networks/neopp/infer/transformer_infer.py +++ b/lightx2v/models/networks/neopp/infer/transformer_infer.py @@ -56,18 +56,11 @@ def _pad_tokens_for_grouped_mm(x_perm, counts): x_padded = x_perm.new_zeros(padded_counts.sum(), x_perm.shape[-1]) x_padded[dst_idx] = x_perm - return x_padded, offsets, padded_counts + return x_padded, offsets, padded_counts, dst_idx -def _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts): - total = counts.sum() - expert_for_row = _sorted_expert_row_map(counts) - perm_starts = counts.cumsum(0) - counts - padded_starts = padded_counts.cumsum(0) - padded_counts - row_idx = torch.arange(total, device=counts.device, dtype=torch.long) - within = row_idx - perm_starts[expert_for_row] - src_idx = padded_starts[expert_for_row] + within - return out_padded[src_idx] +def _strip_padding_from_grouped_mm_output(out_padded, dst_idx): + return out_padded[dst_idx] # Register neopp::kv_update as a PyTorch custom op via torch.library. @@ -403,19 +396,19 @@ def _sparse_moe_pytorch(self, moe_w, hidden_states, selected_experts, routing_we counts = flat_topk_idx.bincount(minlength=moe_w.num_experts) x_perm = hidden_states[token_idxs] - x_padded, offsets, padded_counts = _pad_tokens_for_grouped_mm(x_perm, counts) + x_padded, offsets, _padded_counts, dst_idx = _pad_tokens_for_grouped_mm(x_perm, counts) gate_out = torch._grouped_mm(x_padded, moe_w._pt_gate_weight, offs=offsets) up_out = torch._grouped_mm(x_padded, moe_w._pt_up_weight, offs=offsets) hidden = F.silu(gate_out) * up_out out_padded = torch._grouped_mm(hidden, moe_w._pt_down_weight, offs=offsets) - expert_out = _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts) + expert_out = _strip_padding_from_grouped_mm_output(out_padded, dst_idx) expert_out.mul_(flat_topk_weight[idxs]) expert_cache = torch.zeros_like(hidden_states) expert_cache = expert_cache.to(expert_out.dtype) expert_cache.scatter_reduce_( 0, - token_idxs.view(-1, 1).repeat(1, hidden_dim), + token_idxs.view(-1, 1).expand(-1, hidden_dim), expert_out, reduce="sum", )