add FlashInfer MoE autotune and PyTorch grouped-MM MoE backend#1153
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for FlashInfer autotuning and adds a fallback PyTorch-based grouped matrix multiplication (_grouped_mm) implementation for Mixture of Experts (MoE) routing. It adds a new autotuning session manager, splits the decoder layer into separate attention and FFN components to optimize Magi Compile compatibility, and updates weight loading to support both backends. The review feedback highlights a critical bug where expert weights must be transposed before stacking to prevent shape mismatches in torch._grouped_mm. Additionally, it recommends performance optimizations to eliminate redundant indexing computations in the padding helpers and to replace memory-allocating .repeat() calls with .expand() during scatter-reduction.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _build_pytorch_grouped_mm_weights(self): | ||
| gate_list, up_list, down_list = [], [], [] | ||
| for expert_w in self.experts: | ||
| gate_list.append(expert_w.gate_proj._get_actual_weight()) | ||
| up_list.append(expert_w.up_proj._get_actual_weight()) | ||
| down_list.append(expert_w.down_proj._get_actual_weight()) | ||
| self._pt_gate_weight = torch.stack(gate_list, dim=0).contiguous() | ||
| self._pt_up_weight = torch.stack(up_list, dim=0).contiguous() | ||
| self._pt_down_weight = torch.stack(down_list, dim=0).contiguous() |
There was a problem hiding this comment.
In _build_pytorch_grouped_mm_weights, the expert weights are stacked directly without transposing. However, torch._grouped_mm expects the weight matrices to be of shape [in_features, out_features] (i.e., [K, N]), whereas standard PyTorch linear layer weights are stored as [out_features, in_features] (i.e., [N, K]). This will cause a runtime shape mismatch error or mathematically incorrect results during matrix multiplication. They must be transposed using .t().contiguous() before stacking, similar to how it is done in _build_flashinfer_weights.
| def _build_pytorch_grouped_mm_weights(self): | |
| gate_list, up_list, down_list = [], [], [] | |
| for expert_w in self.experts: | |
| gate_list.append(expert_w.gate_proj._get_actual_weight()) | |
| up_list.append(expert_w.up_proj._get_actual_weight()) | |
| down_list.append(expert_w.down_proj._get_actual_weight()) | |
| self._pt_gate_weight = torch.stack(gate_list, dim=0).contiguous() | |
| self._pt_up_weight = torch.stack(up_list, dim=0).contiguous() | |
| self._pt_down_weight = torch.stack(down_list, dim=0).contiguous() | |
| def _build_pytorch_grouped_mm_weights(self): | |
| gate_list, up_list, down_list = [], [], [] | |
| for expert_w in self.experts: | |
| gate_list.append(expert_w.gate_proj._get_actual_weight().t().contiguous()) | |
| up_list.append(expert_w.up_proj._get_actual_weight().t().contiguous()) | |
| down_list.append(expert_w.down_proj._get_actual_weight().t().contiguous()) | |
| self._pt_gate_weight = torch.stack(gate_list, dim=0) | |
| self._pt_up_weight = torch.stack(up_list, dim=0) | |
| self._pt_down_weight = torch.stack(down_list, dim=0) |
| def _pad_tokens_for_grouped_mm(x_perm, counts): | ||
| padded_counts = _expert_padded_counts(counts) | ||
| offsets = padded_counts.cumsum(0).to(torch.int32) | ||
|
|
||
| total = counts.sum() | ||
| expert_for_row = _sorted_expert_row_map(counts) | ||
| perm_starts = counts.cumsum(0) - counts | ||
| padded_starts = padded_counts.cumsum(0) - padded_counts | ||
| row_idx = torch.arange(total, device=counts.device, dtype=torch.long) | ||
| within = row_idx - perm_starts[expert_for_row] | ||
| dst_idx = padded_starts[expert_for_row] + within | ||
|
|
||
| x_padded = x_perm.new_zeros(padded_counts.sum(), x_perm.shape[-1]) | ||
| x_padded[dst_idx] = x_perm | ||
| return x_padded, offsets, padded_counts | ||
|
|
||
|
|
||
| def _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts): | ||
| total = counts.sum() | ||
| expert_for_row = _sorted_expert_row_map(counts) | ||
| perm_starts = counts.cumsum(0) - counts | ||
| padded_starts = padded_counts.cumsum(0) - padded_counts | ||
| row_idx = torch.arange(total, device=counts.device, dtype=torch.long) | ||
| within = row_idx - perm_starts[expert_for_row] | ||
| src_idx = padded_starts[expert_for_row] + within | ||
| return out_padded[src_idx] |
There was a problem hiding this comment.
The functions _pad_tokens_for_grouped_mm and _strip_padding_from_grouped_mm_output both independently compute the exact same indexing and routing variables (total, expert_for_row, perm_starts, padded_starts, row_idx, within, and dst_idx/src_idx). Since these functions are called sequentially in every MoE layer at every inference step, recomputing these tensors is highly redundant and introduces significant CPU-GPU launch overhead. We can return dst_idx from _pad_tokens_for_grouped_mm and pass it directly to _strip_padding_from_grouped_mm_output, completely eliminating the redundant computations.
| def _pad_tokens_for_grouped_mm(x_perm, counts): | |
| padded_counts = _expert_padded_counts(counts) | |
| offsets = padded_counts.cumsum(0).to(torch.int32) | |
| total = counts.sum() | |
| expert_for_row = _sorted_expert_row_map(counts) | |
| perm_starts = counts.cumsum(0) - counts | |
| padded_starts = padded_counts.cumsum(0) - padded_counts | |
| row_idx = torch.arange(total, device=counts.device, dtype=torch.long) | |
| within = row_idx - perm_starts[expert_for_row] | |
| dst_idx = padded_starts[expert_for_row] + within | |
| x_padded = x_perm.new_zeros(padded_counts.sum(), x_perm.shape[-1]) | |
| x_padded[dst_idx] = x_perm | |
| return x_padded, offsets, padded_counts | |
| def _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts): | |
| total = counts.sum() | |
| expert_for_row = _sorted_expert_row_map(counts) | |
| perm_starts = counts.cumsum(0) - counts | |
| padded_starts = padded_counts.cumsum(0) - padded_counts | |
| row_idx = torch.arange(total, device=counts.device, dtype=torch.long) | |
| within = row_idx - perm_starts[expert_for_row] | |
| src_idx = padded_starts[expert_for_row] + within | |
| return out_padded[src_idx] | |
| def _pad_tokens_for_grouped_mm(x_perm, counts): | |
| padded_counts = _expert_padded_counts(counts) | |
| offsets = padded_counts.cumsum(0).to(torch.int32) | |
| total = counts.sum() | |
| expert_for_row = _sorted_expert_row_map(counts) | |
| perm_starts = counts.cumsum(0) - counts | |
| padded_starts = padded_counts.cumsum(0) - padded_counts | |
| row_idx = torch.arange(total, device=counts.device, dtype=torch.long) | |
| within = row_idx - perm_starts[expert_for_row] | |
| dst_idx = padded_starts[expert_for_row] + within | |
| x_padded = x_perm.new_zeros(padded_counts.sum(), x_perm.shape[-1]) | |
| x_padded[dst_idx] = x_perm | |
| return x_padded, offsets, padded_counts, dst_idx | |
| def _strip_padding_from_grouped_mm_output(out_padded, dst_idx): | |
| return out_padded[dst_idx] |
| def _sparse_moe_pytorch(self, moe_w, hidden_states, selected_experts, routing_weights): | ||
| hidden_dim = hidden_states.shape[-1] | ||
| flat_topk_idx = selected_experts.reshape(-1) | ||
| flat_topk_weight = routing_weights.reshape(-1, 1) | ||
|
|
||
| idxs = flat_topk_idx.argsort() | ||
| token_idxs = idxs // self.num_experts_per_tok | ||
| counts = flat_topk_idx.bincount(minlength=moe_w.num_experts) | ||
|
|
||
| x_perm = hidden_states[token_idxs] | ||
| x_padded, offsets, padded_counts = _pad_tokens_for_grouped_mm(x_perm, counts) | ||
| gate_out = torch._grouped_mm(x_padded, moe_w._pt_gate_weight, offs=offsets) | ||
| up_out = torch._grouped_mm(x_padded, moe_w._pt_up_weight, offs=offsets) | ||
| hidden = F.silu(gate_out) * up_out | ||
| out_padded = torch._grouped_mm(hidden, moe_w._pt_down_weight, offs=offsets) | ||
| expert_out = _strip_padding_from_grouped_mm_output(out_padded, counts, padded_counts) | ||
| expert_out.mul_(flat_topk_weight[idxs]) | ||
|
|
||
| expert_cache = torch.zeros_like(hidden_states) | ||
| expert_cache = expert_cache.to(expert_out.dtype) | ||
| expert_cache.scatter_reduce_( | ||
| 0, | ||
| token_idxs.view(-1, 1).repeat(1, hidden_dim), | ||
| expert_out, | ||
| reduce="sum", | ||
| ) | ||
| return expert_cache |
There was a problem hiding this comment.
In _sparse_moe_pytorch, token_idxs.view(-1, 1).repeat(1, hidden_dim) is used to expand the index tensor for scatter_reduce_. Using .repeat() allocates a new tensor and copies the data, which can be very expensive and cause memory fragmentation during step-by-step inference. Using .expand(-1, hidden_dim) achieves the same result without any memory allocation or copy overhead. Additionally, we should update the call sites to pass dst_idx directly as per the optimization in _pad_tokens_for_grouped_mm.
def _sparse_moe_pytorch(self, moe_w, hidden_states, selected_experts, routing_weights):
hidden_dim = hidden_states.shape[-1]
flat_topk_idx = selected_experts.reshape(-1)
flat_topk_weight = routing_weights.reshape(-1, 1)
idxs = flat_topk_idx.argsort()
token_idxs = idxs // self.num_experts_per_tok
counts = flat_topk_idx.bincount(minlength=moe_w.num_experts)
x_perm = hidden_states[token_idxs]
x_padded, offsets, padded_counts, dst_idx = _pad_tokens_for_grouped_mm(x_perm, counts)
gate_out = torch._grouped_mm(x_padded, moe_w._pt_gate_weight, offs=offsets)
up_out = torch._grouped_mm(x_padded, moe_w._pt_up_weight, offs=offsets)
hidden = F.silu(gate_out) * up_out
out_padded = torch._grouped_mm(hidden, moe_w._pt_down_weight, offs=offsets)
expert_out = _strip_padding_from_grouped_mm_output(out_padded, dst_idx)
expert_out.mul_(flat_topk_weight[idxs])
expert_cache = torch.zeros_like(hidden_states)
expert_cache = expert_cache.to(expert_out.dtype)
expert_cache.scatter_reduce_(
0,
token_idxs.view(-1, 1).expand(-1, hidden_dim),
expert_out,
reduce="sum",
)
return expert_cache|
Thanks @gemini-code-assist for the review. Applied: (1) reuse dst_idx between pad/strip helpers; (2) repeat → expand for scatter indices. Not applied: transpose in _build_pytorch_grouped_mm_weights. MMWeight already stores weights as [in, out] for torch.mm / torch._grouped_mm. FlashInfer’s extra .t() is for Cutlass layout only — the two backends differ by design. |
|
Thanks for the update and for clarifying the reasoning behind the weight layout design. Your explanation regarding the |
No description provided.