|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Tests for _is_sparse_moe_block and _QuantSparseMoe.""" |
| 17 | + |
| 18 | +import pytest |
| 19 | +import torch |
| 20 | +import torch.nn as nn |
| 21 | + |
| 22 | +pytest.importorskip("transformers") |
| 23 | + |
| 24 | +from _test_utils.torch.transformers_models import get_tiny_qwen3_moe |
| 25 | + |
| 26 | +import modelopt.torch.quantization as mtq |
| 27 | +from modelopt.torch.quantization.nn import QuantModuleRegistry |
| 28 | +from modelopt.torch.quantization.plugins.huggingface import ( |
| 29 | + TRANSFORMERS_VERSION_GE_5_0, |
| 30 | + _is_sparse_moe_block, |
| 31 | + register_sparse_moe_on_the_fly, |
| 32 | +) |
| 33 | + |
| 34 | + |
| 35 | +# --------------------------------------------------------------------------- |
| 36 | +# Helpers: lightweight mock modules for _is_sparse_moe_block |
| 37 | +# --------------------------------------------------------------------------- |
| 38 | +class _FakeGateWithRouter(nn.Module): |
| 39 | + """Mimics a v5.x TopKRouter gate with top_k and num_experts.""" |
| 40 | + |
| 41 | + def __init__(self, top_k=2, num_experts=4): |
| 42 | + super().__init__() |
| 43 | + self.top_k = top_k |
| 44 | + self.num_experts = num_experts |
| 45 | + self.linear = nn.Linear(8, num_experts) |
| 46 | + |
| 47 | + def forward(self, x): |
| 48 | + return self.linear(x) |
| 49 | + |
| 50 | + |
| 51 | +class _FakeExperts(nn.ModuleList): |
| 52 | + def __init__(self, n=4): |
| 53 | + super().__init__([nn.Linear(8, 8) for _ in range(n)]) |
| 54 | + self.num_experts = n |
| 55 | + |
| 56 | + |
| 57 | +class _MoEBlockWithGateRouter(nn.Module): |
| 58 | + """Matches the primary detection path: gate.top_k + gate.num_experts.""" |
| 59 | + |
| 60 | + def __init__(self, num_experts=4, top_k=2): |
| 61 | + super().__init__() |
| 62 | + self.gate = _FakeGateWithRouter(top_k=top_k, num_experts=num_experts) |
| 63 | + self.experts = _FakeExperts(num_experts) |
| 64 | + |
| 65 | + def forward(self, hidden_states): |
| 66 | + logits = self.gate(hidden_states) |
| 67 | + routing_weights, selected = torch.topk(logits, self.gate.top_k, dim=-1) |
| 68 | + out = torch.zeros_like(hidden_states) |
| 69 | + for i in range(self.gate.num_experts): |
| 70 | + mask = (selected == i).any(dim=-1) |
| 71 | + if mask.any(): |
| 72 | + out[mask] += self.experts[i](hidden_states[mask]) |
| 73 | + return out |
| 74 | + |
| 75 | + |
| 76 | +class _MoEBlockFallback(nn.Module): |
| 77 | + """Matches the fallback path: top_k + num_experts on the block itself.""" |
| 78 | + |
| 79 | + def __init__(self, num_experts=4, top_k=2): |
| 80 | + super().__init__() |
| 81 | + self.num_experts = num_experts |
| 82 | + self.top_k = top_k |
| 83 | + self.gate = nn.Linear(8, num_experts) |
| 84 | + self.experts = _FakeExperts(num_experts) |
| 85 | + |
| 86 | + def forward(self, hidden_states): |
| 87 | + logits = self.gate(hidden_states) |
| 88 | + routing_weights, selected = torch.topk(logits, self.top_k, dim=-1) |
| 89 | + out = torch.zeros_like(hidden_states) |
| 90 | + for i in range(self.num_experts): |
| 91 | + mask = (selected == i).any(dim=-1) |
| 92 | + if mask.any(): |
| 93 | + out[mask] += self.experts[i](hidden_states[mask]) |
| 94 | + return out |
| 95 | + |
| 96 | + |
| 97 | +# --------------------------------------------------------------------------- |
| 98 | +# Tests for _is_sparse_moe_block |
| 99 | +# --------------------------------------------------------------------------- |
| 100 | +class TestIsSparseBlock: |
| 101 | + def test_no_experts_returns_false(self): |
| 102 | + module = nn.Linear(8, 8) |
| 103 | + assert _is_sparse_moe_block(module) is False |
| 104 | + |
| 105 | + def test_experts_but_no_gate_or_topk_returns_false(self): |
| 106 | + module = nn.Module() |
| 107 | + module.experts = nn.ModuleList([nn.Linear(8, 8)]) |
| 108 | + assert _is_sparse_moe_block(module) is False |
| 109 | + |
| 110 | + def test_gate_with_router_attrs_returns_true(self): |
| 111 | + block = _MoEBlockWithGateRouter(num_experts=4, top_k=2) |
| 112 | + assert _is_sparse_moe_block(block) is True |
| 113 | + |
| 114 | + def test_fallback_block_level_attrs_returns_true(self): |
| 115 | + block = _MoEBlockFallback(num_experts=4, top_k=2) |
| 116 | + assert _is_sparse_moe_block(block) is True |
| 117 | + |
| 118 | + def test_gate_missing_num_experts_returns_false(self): |
| 119 | + """gate.top_k present but gate.num_experts absent -> primary path fails.""" |
| 120 | + module = nn.Module() |
| 121 | + module.experts = nn.ModuleList([nn.Linear(8, 8)]) |
| 122 | + gate = nn.Module() |
| 123 | + gate.top_k = 2 |
| 124 | + module.gate = gate |
| 125 | + assert _is_sparse_moe_block(module) is False |
| 126 | + |
| 127 | + def test_gate_missing_top_k_returns_false(self): |
| 128 | + """gate.num_experts present but gate.top_k absent -> primary path fails.""" |
| 129 | + module = nn.Module() |
| 130 | + module.experts = nn.ModuleList([nn.Linear(8, 8)]) |
| 131 | + gate = nn.Module() |
| 132 | + gate.num_experts = 4 |
| 133 | + module.gate = gate |
| 134 | + assert _is_sparse_moe_block(module) is False |
| 135 | + |
| 136 | + def test_block_level_only_top_k_returns_false(self): |
| 137 | + """Only top_k on block (no num_experts) -> fallback fails.""" |
| 138 | + module = nn.Module() |
| 139 | + module.experts = nn.ModuleList([nn.Linear(8, 8)]) |
| 140 | + module.top_k = 2 |
| 141 | + assert _is_sparse_moe_block(module) is False |
| 142 | + |
| 143 | + def test_block_level_only_num_experts_returns_false(self): |
| 144 | + """Only num_experts on block (no top_k) -> fallback fails.""" |
| 145 | + module = nn.Module() |
| 146 | + module.experts = nn.ModuleList([nn.Linear(8, 8)]) |
| 147 | + module.num_experts = 4 |
| 148 | + assert _is_sparse_moe_block(module) is False |
| 149 | + |
| 150 | + def test_glm4_like_block_rejected(self): |
| 151 | + """A module with n_routed_experts instead of num_experts should be rejected.""" |
| 152 | + module = nn.Module() |
| 153 | + module.experts = nn.ModuleList([nn.Linear(8, 8)]) |
| 154 | + gate = nn.Module() |
| 155 | + gate.top_k = 2 |
| 156 | + gate.n_routed_experts = 4 # different attr name |
| 157 | + module.gate = gate |
| 158 | + assert _is_sparse_moe_block(module) is False |
| 159 | + |
| 160 | + |
| 161 | +# --------------------------------------------------------------------------- |
| 162 | +# Tests for _QuantSparseMoe |
| 163 | +# --------------------------------------------------------------------------- |
| 164 | +class TestQuantSparseMoe: |
| 165 | + """Tests for _QuantSparseMoe using a real tiny Qwen3Moe model.""" |
| 166 | + |
| 167 | + @staticmethod |
| 168 | + def _get_moe_block(model): |
| 169 | + """Return the first MoE block from the model.""" |
| 170 | + for module in model.modules(): |
| 171 | + if _is_sparse_moe_block(module): |
| 172 | + return module |
| 173 | + raise RuntimeError("No MoE block found in model") |
| 174 | + |
| 175 | + def test_register_sparse_moe_on_the_fly(self): |
| 176 | + model = get_tiny_qwen3_moe() |
| 177 | + moe_block = self._get_moe_block(model) |
| 178 | + moe_type = type(moe_block) |
| 179 | + |
| 180 | + if QuantModuleRegistry.get(moe_type) is not None: |
| 181 | + pytest.skip("MoE type already registered (upstream change)") |
| 182 | + |
| 183 | + register_sparse_moe_on_the_fly(model) |
| 184 | + assert QuantModuleRegistry.get(moe_type) is not None |
| 185 | + |
| 186 | + def test_setup_creates_expert_token_count(self): |
| 187 | + model = get_tiny_qwen3_moe() |
| 188 | + moe_block = self._get_moe_block(model) |
| 189 | + moe_type = type(moe_block) |
| 190 | + |
| 191 | + if QuantModuleRegistry.get(moe_type) is None: |
| 192 | + register_sparse_moe_on_the_fly(model) |
| 193 | + |
| 194 | + converted = QuantModuleRegistry.convert(moe_block) |
| 195 | + assert hasattr(converted, "expert_token_count") |
| 196 | + expected_num_experts = moe_block.num_experts if hasattr(moe_block, "num_experts") else 0 |
| 197 | + assert converted.expert_token_count.shape == (expected_num_experts,) |
| 198 | + assert converted.expert_token_count.dtype == torch.long |
| 199 | + assert (converted.expert_token_count == 0).all() |
| 200 | + |
| 201 | + def test_setup_count_expert_tokens_default_false(self): |
| 202 | + model = get_tiny_qwen3_moe() |
| 203 | + moe_block = self._get_moe_block(model) |
| 204 | + moe_type = type(moe_block) |
| 205 | + |
| 206 | + if QuantModuleRegistry.get(moe_type) is None: |
| 207 | + register_sparse_moe_on_the_fly(model) |
| 208 | + |
| 209 | + converted = QuantModuleRegistry.convert(moe_block) |
| 210 | + assert converted._count_expert_tokens is False |
| 211 | + |
| 212 | + def test_forward_no_calib_matches_original(self): |
| 213 | + """When calibration is off, _QuantSparseMoe should produce the same output as the original.""" |
| 214 | + model = get_tiny_qwen3_moe() |
| 215 | + moe_block = self._get_moe_block(model) |
| 216 | + moe_type = type(moe_block) |
| 217 | + |
| 218 | + if QuantModuleRegistry.get(moe_type) is None: |
| 219 | + register_sparse_moe_on_the_fly(model) |
| 220 | + |
| 221 | + ref_block = self._get_moe_block(get_tiny_qwen3_moe()) |
| 222 | + ref_block.load_state_dict(moe_block.state_dict()) |
| 223 | + |
| 224 | + converted = QuantModuleRegistry.convert(moe_block) |
| 225 | + |
| 226 | + torch.manual_seed(42) |
| 227 | + x = torch.randn(1, 4, 32) |
| 228 | + with torch.no_grad(): |
| 229 | + out_ref = ref_block(x) |
| 230 | + out_test = converted(x) |
| 231 | + |
| 232 | + if isinstance(out_ref, tuple): |
| 233 | + out_ref = out_ref[0] |
| 234 | + if isinstance(out_test, tuple): |
| 235 | + out_test = out_test[0] |
| 236 | + assert torch.allclose(out_ref, out_test, atol=1e-5) |
| 237 | + |
| 238 | + def test_forward_calib_sends_all_tokens_to_all_experts(self): |
| 239 | + """During calibration, all experts should see tokens (expert_token_count all > 0).""" |
| 240 | + model = get_tiny_qwen3_moe() |
| 241 | + register_sparse_moe_on_the_fly(model) |
| 242 | + |
| 243 | + def calib_fn(model): |
| 244 | + x = model.dummy_inputs["input_ids"] |
| 245 | + model(x) |
| 246 | + |
| 247 | + mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_fn) |
| 248 | + |
| 249 | + for name, module in model.named_modules(): |
| 250 | + if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0: |
| 251 | + assert (module.expert_token_count > 0).all(), ( |
| 252 | + f"Not all experts received tokens in {name}: {module.expert_token_count}" |
| 253 | + ) |
| 254 | + |
| 255 | + def test_forward_calib_restores_top_k(self): |
| 256 | + """After calibration forward, top_k should be restored to its original value.""" |
| 257 | + model = get_tiny_qwen3_moe() |
| 258 | + moe_block = self._get_moe_block(model) |
| 259 | + moe_type = type(moe_block) |
| 260 | + |
| 261 | + if QuantModuleRegistry.get(moe_type) is None: |
| 262 | + register_sparse_moe_on_the_fly(model) |
| 263 | + |
| 264 | + if TRANSFORMERS_VERSION_GE_5_0: |
| 265 | + original_top_k = moe_block.gate.top_k |
| 266 | + else: |
| 267 | + original_top_k = moe_block.top_k |
| 268 | + |
| 269 | + converted = QuantModuleRegistry.convert(moe_block) |
| 270 | + |
| 271 | + # Simulate calibration mode: set _if_calib on a child TensorQuantizer |
| 272 | + for m in converted.experts.modules(): |
| 273 | + if hasattr(m, "_if_calib"): |
| 274 | + m._if_calib = True |
| 275 | + break |
| 276 | + |
| 277 | + x = torch.randn(1, 4, 32) |
| 278 | + with torch.no_grad(): |
| 279 | + converted(x) |
| 280 | + |
| 281 | + if TRANSFORMERS_VERSION_GE_5_0: |
| 282 | + assert converted.gate.top_k == original_top_k |
| 283 | + else: |
| 284 | + assert converted.top_k == original_top_k |
| 285 | + |
| 286 | + def test_gate_forward_hook_counts_tokens(self): |
| 287 | + """Verify the gate forward hook correctly counts expert token assignments.""" |
| 288 | + model = get_tiny_qwen3_moe() |
| 289 | + moe_block = self._get_moe_block(model) |
| 290 | + moe_type = type(moe_block) |
| 291 | + |
| 292 | + if QuantModuleRegistry.get(moe_type) is None: |
| 293 | + register_sparse_moe_on_the_fly(model) |
| 294 | + |
| 295 | + converted = QuantModuleRegistry.convert(moe_block) |
| 296 | + |
| 297 | + # Reset counts and enable counting |
| 298 | + converted.expert_token_count.zero_() |
| 299 | + converted._count_expert_tokens = True |
| 300 | + |
| 301 | + hidden_size = converted.gate.in_features |
| 302 | + x = torch.randn(8, hidden_size) |
| 303 | + with torch.no_grad(): |
| 304 | + converted.gate(x) |
| 305 | + |
| 306 | + # After one gate call with counting enabled, total assigned tokens should equal |
| 307 | + # num_tokens * top_k |
| 308 | + top_k = converted.top_k if hasattr(converted, "top_k") else converted.gate.top_k |
| 309 | + total_assigned = converted.expert_token_count.sum().item() |
| 310 | + assert total_assigned == 8 * top_k |
| 311 | + |
| 312 | + # Disable counting and verify counts don't change |
| 313 | + converted._count_expert_tokens = False |
| 314 | + prev_counts = converted.expert_token_count.clone() |
| 315 | + with torch.no_grad(): |
| 316 | + converted.gate(x) |
| 317 | + assert torch.equal(converted.expert_token_count, prev_counts) |
0 commit comments