Skip to content

add FlashInfer MoE autotune and PyTorch grouped-MM MoE backend#1153

Merged
helloyongyang merged 3 commits into
mainfrom
yr/moe_opt
Jun 16, 2026
Merged

add FlashInfer MoE autotune and PyTorch grouped-MM MoE backend#1153
helloyongyang merged 3 commits into
mainfrom
yr/moe_opt

Conversation

@STwangyingrui

Copy link
Copy Markdown
Contributor

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for FlashInfer autotuning and adds a fallback PyTorch-based grouped matrix multiplication (_grouped_mm) implementation for Mixture of Experts (MoE) routing. It adds a new autotuning session manager, splits the decoder layer into separate attention and FFN components to optimize Magi Compile compatibility, and updates weight loading to support both backends. The review feedback highlights a critical bug where expert weights must be transposed before stacking to prevent shape mismatches in torch._grouped_mm. Additionally, it recommends performance optimizations to eliminate redundant indexing computations in the padding helpers and to replace memory-allocating .repeat() calls with .expand() during scatter-reduction.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +161 to +169
def _build_pytorch_grouped_mm_weights(self):
gate_list, up_list, down_list = [], [], []
for expert_w in self.experts:
gate_list.append(expert_w.gate_proj._get_actual_weight())
up_list.append(expert_w.up_proj._get_actual_weight())
down_list.append(expert_w.down_proj._get_actual_weight())
self._pt_gate_weight = torch.stack(gate_list, dim=0).contiguous()
self._pt_up_weight = torch.stack(up_list, dim=0).contiguous()
self._pt_down_weight = torch.stack(down_list, dim=0).contiguous()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

In _build_pytorch_grouped_mm_weights, the expert weights are stacked directly without transposing. However, torch._grouped_mm expects the weight matrices to be of shape [in_features, out_features] (i.e., [K, N]), whereas standard PyTorch linear layer weights are stored as [out_features, in_features] (i.e., [N, K]). This will cause a runtime shape mismatch error or mathematically incorrect results during matrix multiplication. They must be transposed using .t().contiguous() before stacking, similar to how it is done in _build_flashinfer_weights.

Suggested change
def _build_pytorch_grouped_mm_weights(self):
gate_list, up_list, down_list = [], [], []
for expert_w in self.experts:
gate_list.append(expert_w.gate_proj._get_actual_weight())
up_list.append(expert_w.up_proj._get_actual_weight())
down_list.append(expert_w.down_proj._get_actual_weight())
self._pt_gate_weight = torch.stack(gate_list, dim=0).contiguous()
self._pt_up_weight = torch.stack(up_list, dim=0).contiguous()
self._pt_down_weight = torch.stack(down_list, dim=0).contiguous()
def _build_pytorch_grouped_mm_weights(self):
gate_list, up_list, down_list = [], [], []
for expert_w in self.experts:
gate_list.append(expert_w.gate_proj._get_actual_weight().t().contiguous())
up_list.append(expert_w.up_proj._get_actual_weight().t().contiguous())
down_list.append(expert_w.down_proj._get_actual_weight().t().contiguous())
self._pt_gate_weight = torch.stack(gate_list, dim=0)
self._pt_up_weight = torch.stack(up_list, dim=0)
self._pt_down_weight = torch.stack(down_list, dim=0)

Comment on lines +45 to +70
def _pad_tokens_for_grouped_mm(x_perm, counts):
padded_counts = _expert_padded_counts(counts)
offsets = padded_counts.cumsum(0).to(torch.int32)

total = counts.sum()
expert_for_row = _sorted_expert_row_map(counts)
perm_starts = counts.cumsum(0) - counts
padded_starts = padded_counts.cumsum(0) - padded_counts
row_idx = torch.arange(total, device=counts.device, dtype=torch.long)
within = row_idx - perm_starts[expert_for_row]
dst_idx = padded_starts[expert_for_row] + within

x_padded = x_perm.new_zeros(padded_counts.sum(), x_perm.shape[-1])
x_padded[dst_idx] = x_perm
return x_padded, offsets, padded_counts


def _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts):
total = counts.sum()
expert_for_row = _sorted_expert_row_map(counts)
perm_starts = counts.cumsum(0) - counts
padded_starts = padded_counts.cumsum(0) - padded_counts
row_idx = torch.arange(total, device=counts.device, dtype=torch.long)
within = row_idx - perm_starts[expert_for_row]
src_idx = padded_starts[expert_for_row] + within
return out_padded[src_idx]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The functions _pad_tokens_for_grouped_mm and _strip_padding_from_grouped_mm_output both independently compute the exact same indexing and routing variables (total, expert_for_row, perm_starts, padded_starts, row_idx, within, and dst_idx/src_idx). Since these functions are called sequentially in every MoE layer at every inference step, recomputing these tensors is highly redundant and introduces significant CPU-GPU launch overhead. We can return dst_idx from _pad_tokens_for_grouped_mm and pass it directly to _strip_padding_from_grouped_mm_output, completely eliminating the redundant computations.

Suggested change
def _pad_tokens_for_grouped_mm(x_perm, counts):
padded_counts = _expert_padded_counts(counts)
offsets = padded_counts.cumsum(0).to(torch.int32)
total = counts.sum()
expert_for_row = _sorted_expert_row_map(counts)
perm_starts = counts.cumsum(0) - counts
padded_starts = padded_counts.cumsum(0) - padded_counts
row_idx = torch.arange(total, device=counts.device, dtype=torch.long)
within = row_idx - perm_starts[expert_for_row]
dst_idx = padded_starts[expert_for_row] + within
x_padded = x_perm.new_zeros(padded_counts.sum(), x_perm.shape[-1])
x_padded[dst_idx] = x_perm
return x_padded, offsets, padded_counts
def _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts):
total = counts.sum()
expert_for_row = _sorted_expert_row_map(counts)
perm_starts = counts.cumsum(0) - counts
padded_starts = padded_counts.cumsum(0) - padded_counts
row_idx = torch.arange(total, device=counts.device, dtype=torch.long)
within = row_idx - perm_starts[expert_for_row]
src_idx = padded_starts[expert_for_row] + within
return out_padded[src_idx]
def _pad_tokens_for_grouped_mm(x_perm, counts):
padded_counts = _expert_padded_counts(counts)
offsets = padded_counts.cumsum(0).to(torch.int32)
total = counts.sum()
expert_for_row = _sorted_expert_row_map(counts)
perm_starts = counts.cumsum(0) - counts
padded_starts = padded_counts.cumsum(0) - padded_counts
row_idx = torch.arange(total, device=counts.device, dtype=torch.long)
within = row_idx - perm_starts[expert_for_row]
dst_idx = padded_starts[expert_for_row] + within
x_padded = x_perm.new_zeros(padded_counts.sum(), x_perm.shape[-1])
x_padded[dst_idx] = x_perm
return x_padded, offsets, padded_counts, dst_idx
def _strip_padding_from_grouped_mm_output(out_padded, dst_idx):
return out_padded[dst_idx]

Comment on lines +396 to +422
def _sparse_moe_pytorch(self, moe_w, hidden_states, selected_experts, routing_weights):
hidden_dim = hidden_states.shape[-1]
flat_topk_idx = selected_experts.reshape(-1)
flat_topk_weight = routing_weights.reshape(-1, 1)

idxs = flat_topk_idx.argsort()
token_idxs = idxs // self.num_experts_per_tok
counts = flat_topk_idx.bincount(minlength=moe_w.num_experts)

x_perm = hidden_states[token_idxs]
x_padded, offsets, padded_counts = _pad_tokens_for_grouped_mm(x_perm, counts)
gate_out = torch._grouped_mm(x_padded, moe_w._pt_gate_weight, offs=offsets)
up_out = torch._grouped_mm(x_padded, moe_w._pt_up_weight, offs=offsets)
hidden = F.silu(gate_out) * up_out
out_padded = torch._grouped_mm(hidden, moe_w._pt_down_weight, offs=offsets)
expert_out = _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts)
expert_out.mul_(flat_topk_weight[idxs])

expert_cache = torch.zeros_like(hidden_states)
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(
0,
token_idxs.view(-1, 1).repeat(1, hidden_dim),
expert_out,
reduce="sum",
)
return expert_cache

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _sparse_moe_pytorch, token_idxs.view(-1, 1).repeat(1, hidden_dim) is used to expand the index tensor for scatter_reduce_. Using .repeat() allocates a new tensor and copies the data, which can be very expensive and cause memory fragmentation during step-by-step inference. Using .expand(-1, hidden_dim) achieves the same result without any memory allocation or copy overhead. Additionally, we should update the call sites to pass dst_idx directly as per the optimization in _pad_tokens_for_grouped_mm.

    def _sparse_moe_pytorch(self, moe_w, hidden_states, selected_experts, routing_weights):
        hidden_dim = hidden_states.shape[-1]
        flat_topk_idx = selected_experts.reshape(-1)
        flat_topk_weight = routing_weights.reshape(-1, 1)

        idxs = flat_topk_idx.argsort()
        token_idxs = idxs // self.num_experts_per_tok
        counts = flat_topk_idx.bincount(minlength=moe_w.num_experts)

        x_perm = hidden_states[token_idxs]
        x_padded, offsets, padded_counts, dst_idx = _pad_tokens_for_grouped_mm(x_perm, counts)
        gate_out = torch._grouped_mm(x_padded, moe_w._pt_gate_weight, offs=offsets)
        up_out = torch._grouped_mm(x_padded, moe_w._pt_up_weight, offs=offsets)
        hidden = F.silu(gate_out) * up_out
        out_padded = torch._grouped_mm(hidden, moe_w._pt_down_weight, offs=offsets)
        expert_out = _strip_padding_from_grouped_mm_output(out_padded, dst_idx)
        expert_out.mul_(flat_topk_weight[idxs])

        expert_cache = torch.zeros_like(hidden_states)
        expert_cache = expert_cache.to(expert_out.dtype)
        expert_cache.scatter_reduce_(
            0,
            token_idxs.view(-1, 1).expand(-1, hidden_dim),
            expert_out,
            reduce="sum",
        )
        return expert_cache

@STwangyingrui

Copy link
Copy Markdown
Contributor Author

Thanks @gemini-code-assist for the review.

Applied: (1) reuse dst_idx between pad/strip helpers; (2) repeat → expand for scatter indices.

Not applied: transpose in _build_pytorch_grouped_mm_weights. MMWeight already stores weights as [in, out] for torch.mm / torch._grouped_mm. FlashInfer’s extra .t() is for Cutlass layout only — the two backends differ by design.

@gemini-code-assist

Copy link
Copy Markdown
Contributor

Thanks for the update and for clarifying the reasoning behind the weight layout design. Your explanation regarding the MMWeight storage format ([in, out]) and the specific requirement for FlashInfer's .t() being backend-specific to Cutlass makes perfect sense. It's good to see the dst_idx reuse and the expand optimization applied. The implementation looks solid.

@helloyongyang helloyongyang merged commit 97b2320 into main Jun 16, 2026
2 checks passed
@helloyongyang helloyongyang deleted the yr/moe_opt branch June 16, 2026 09:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants