
Commit 2e29ee7

Add unittest
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 7da77b9 commit 2e29ee7

1 file changed

Lines changed: 317 additions & 0 deletions

@@ -0,0 +1,317 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for _is_sparse_moe_block and _QuantSparseMoe."""

import pytest
import torch
import torch.nn as nn

pytest.importorskip("transformers")

from _test_utils.torch.transformers_models import get_tiny_qwen3_moe

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.nn import QuantModuleRegistry
from modelopt.torch.quantization.plugins.huggingface import (
    TRANSFORMERS_VERSION_GE_5_0,
    _is_sparse_moe_block,
    register_sparse_moe_on_the_fly,
)
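
# The detection contract exercised below, as inferred from these tests (the
# authoritative definition lives in _is_sparse_moe_block itself): a module is
# treated as a sparse MoE block when it has an `experts` container and either
#   (a) a `gate` submodule exposing both `top_k` and `num_experts`
#       (the transformers v5.x TopKRouter layout), or
#   (b) `top_k` and `num_experts` attributes on the block itself
#       (the pre-5.0 block layout, the fallback path).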


# ---------------------------------------------------------------------------
# Helpers: lightweight mock modules for _is_sparse_moe_block
# ---------------------------------------------------------------------------
class _FakeGateWithRouter(nn.Module):
    """Mimics a v5.x TopKRouter gate with top_k and num_experts."""

    def __init__(self, top_k=2, num_experts=4):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        self.linear = nn.Linear(8, num_experts)

    def forward(self, x):
        return self.linear(x)


class _FakeExperts(nn.ModuleList):
    def __init__(self, n=4):
        super().__init__([nn.Linear(8, 8) for _ in range(n)])
        self.num_experts = n


class _MoEBlockWithGateRouter(nn.Module):
    """Matches the primary detection path: gate.top_k + gate.num_experts."""

    def __init__(self, num_experts=4, top_k=2):
        super().__init__()
        self.gate = _FakeGateWithRouter(top_k=top_k, num_experts=num_experts)
        self.experts = _FakeExperts(num_experts)

    def forward(self, hidden_states):
        logits = self.gate(hidden_states)
        routing_weights, selected = torch.topk(logits, self.gate.top_k, dim=-1)
        out = torch.zeros_like(hidden_states)
        for i in range(self.gate.num_experts):
            mask = (selected == i).any(dim=-1)
            if mask.any():
                out[mask] += self.experts[i](hidden_states[mask])
        return out


class _MoEBlockFallback(nn.Module):
    """Matches the fallback path: top_k + num_experts on the block itself."""

    def __init__(self, num_experts=4, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(8, num_experts)
        self.experts = _FakeExperts(num_experts)

    def forward(self, hidden_states):
        logits = self.gate(hidden_states)
        routing_weights, selected = torch.topk(logits, self.top_k, dim=-1)
        out = torch.zeros_like(hidden_states)
        for i in range(self.num_experts):
            mask = (selected == i).any(dim=-1)
            if mask.any():
                out[mask] += self.experts[i](hidden_states[mask])
        return out
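

# Quick sanity check of the mock blocks themselves, shown as a sketch (the
# real assertions live in TestIsSparseBlock below):
#
#   block = _MoEBlockWithGateRouter(num_experts=4, top_k=2)
#   out = block(torch.randn(3, 8))  # 3 tokens, hidden size 8, top-2 of 4 experts
#   assert out.shape == (3, 8)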


# ---------------------------------------------------------------------------
# Tests for _is_sparse_moe_block
# ---------------------------------------------------------------------------
class TestIsSparseBlock:
    def test_no_experts_returns_false(self):
        module = nn.Linear(8, 8)
        assert _is_sparse_moe_block(module) is False

    def test_experts_but_no_gate_or_topk_returns_false(self):
        module = nn.Module()
        module.experts = nn.ModuleList([nn.Linear(8, 8)])
        assert _is_sparse_moe_block(module) is False

    def test_gate_with_router_attrs_returns_true(self):
        block = _MoEBlockWithGateRouter(num_experts=4, top_k=2)
        assert _is_sparse_moe_block(block) is True

    def test_fallback_block_level_attrs_returns_true(self):
        block = _MoEBlockFallback(num_experts=4, top_k=2)
        assert _is_sparse_moe_block(block) is True

    def test_gate_missing_num_experts_returns_false(self):
        """gate.top_k present but gate.num_experts absent -> primary path fails."""
        module = nn.Module()
        module.experts = nn.ModuleList([nn.Linear(8, 8)])
        gate = nn.Module()
        gate.top_k = 2
        module.gate = gate
        assert _is_sparse_moe_block(module) is False

    def test_gate_missing_top_k_returns_false(self):
        """gate.num_experts present but gate.top_k absent -> primary path fails."""
        module = nn.Module()
        module.experts = nn.ModuleList([nn.Linear(8, 8)])
        gate = nn.Module()
        gate.num_experts = 4
        module.gate = gate
        assert _is_sparse_moe_block(module) is False

    def test_block_level_only_top_k_returns_false(self):
        """Only top_k on block (no num_experts) -> fallback fails."""
        module = nn.Module()
        module.experts = nn.ModuleList([nn.Linear(8, 8)])
        module.top_k = 2
        assert _is_sparse_moe_block(module) is False

    def test_block_level_only_num_experts_returns_false(self):
        """Only num_experts on block (no top_k) -> fallback fails."""
        module = nn.Module()
        module.experts = nn.ModuleList([nn.Linear(8, 8)])
        module.num_experts = 4
        assert _is_sparse_moe_block(module) is False

    def test_glm4_like_block_rejected(self):
        """A module with n_routed_experts instead of num_experts should be rejected."""
        module = nn.Module()
        module.experts = nn.ModuleList([nn.Linear(8, 8)])
        gate = nn.Module()
        gate.top_k = 2
        gate.n_routed_experts = 4  # different attr name
        module.gate = gate
        assert _is_sparse_moe_block(module) is False


# ---------------------------------------------------------------------------
# Tests for _QuantSparseMoe
# ---------------------------------------------------------------------------
class TestQuantSparseMoe:
    """Tests for _QuantSparseMoe using a real tiny Qwen3Moe model."""

    @staticmethod
    def _get_moe_block(model):
        """Return the first MoE block from the model."""
        for module in model.modules():
            if _is_sparse_moe_block(module):
                return module
        raise RuntimeError("No MoE block found in model")
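
    # Behavior assumed of _QuantSparseMoe in the tests below (inferred from
    # the test docstrings rather than restated from the implementation):
    # outside calibration it should be a transparent wrapper around the
    # original block, and during calibration it should temporarily route every
    # token to every expert (so each expert collects calibration statistics)
    # and restore the original top_k afterwards.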

    def test_register_sparse_moe_on_the_fly(self):
        model = get_tiny_qwen3_moe()
        moe_block = self._get_moe_block(model)
        moe_type = type(moe_block)

        if QuantModuleRegistry.get(moe_type) is not None:
            pytest.skip("MoE type already registered (upstream change)")

        register_sparse_moe_on_the_fly(model)
        assert QuantModuleRegistry.get(moe_type) is not None

    def test_setup_creates_expert_token_count(self):
        model = get_tiny_qwen3_moe()
        moe_block = self._get_moe_block(model)
        moe_type = type(moe_block)

        if QuantModuleRegistry.get(moe_type) is None:
            register_sparse_moe_on_the_fly(model)

        converted = QuantModuleRegistry.convert(moe_block)
        assert hasattr(converted, "expert_token_count")
        expected_num_experts = moe_block.num_experts if hasattr(moe_block, "num_experts") else 0
        assert converted.expert_token_count.shape == (expected_num_experts,)
        assert converted.expert_token_count.dtype == torch.long
        assert (converted.expert_token_count == 0).all()

    def test_setup_count_expert_tokens_default_false(self):
        model = get_tiny_qwen3_moe()
        moe_block = self._get_moe_block(model)
        moe_type = type(moe_block)

        if QuantModuleRegistry.get(moe_type) is None:
            register_sparse_moe_on_the_fly(model)

        converted = QuantModuleRegistry.convert(moe_block)
        assert converted._count_expert_tokens is False

    def test_forward_no_calib_matches_original(self):
        """When calibration is off, _QuantSparseMoe should produce the same output as the original."""
        model = get_tiny_qwen3_moe()
        moe_block = self._get_moe_block(model)
        moe_type = type(moe_block)

        if QuantModuleRegistry.get(moe_type) is None:
            register_sparse_moe_on_the_fly(model)

        ref_block = self._get_moe_block(get_tiny_qwen3_moe())
        ref_block.load_state_dict(moe_block.state_dict())

        converted = QuantModuleRegistry.convert(moe_block)

        torch.manual_seed(42)
        x = torch.randn(1, 4, 32)
        with torch.no_grad():
            out_ref = ref_block(x)
            out_test = converted(x)

        if isinstance(out_ref, tuple):
            out_ref = out_ref[0]
        if isinstance(out_test, tuple):
            out_test = out_test[0]
        assert torch.allclose(out_ref, out_test, atol=1e-5)

    def test_forward_calib_sends_all_tokens_to_all_experts(self):
        """During calibration, all experts should see tokens (expert_token_count all > 0)."""
        model = get_tiny_qwen3_moe()
        register_sparse_moe_on_the_fly(model)

        def calib_fn(model):
            x = model.dummy_inputs["input_ids"]
            model(x)

        mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_fn)

        for name, module in model.named_modules():
            if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0:
                assert (module.expert_token_count > 0).all(), (
                    f"Not all experts received tokens in {name}: {module.expert_token_count}"
                )

    def test_forward_calib_restores_top_k(self):
        """After calibration forward, top_k should be restored to its original value."""
        model = get_tiny_qwen3_moe()
        moe_block = self._get_moe_block(model)
        moe_type = type(moe_block)

        if QuantModuleRegistry.get(moe_type) is None:
            register_sparse_moe_on_the_fly(model)

        if TRANSFORMERS_VERSION_GE_5_0:
            original_top_k = moe_block.gate.top_k
        else:
            original_top_k = moe_block.top_k

        converted = QuantModuleRegistry.convert(moe_block)

        # Simulate calibration mode: set _if_calib on a child TensorQuantizer
        for m in converted.experts.modules():
            if hasattr(m, "_if_calib"):
                m._if_calib = True
                break

        x = torch.randn(1, 4, 32)
        with torch.no_grad():
            converted(x)

        if TRANSFORMERS_VERSION_GE_5_0:
            assert converted.gate.top_k == original_top_k
        else:
            assert converted.top_k == original_top_k

    def test_gate_forward_hook_counts_tokens(self):
        """Verify the gate forward hook correctly counts expert token assignments."""
        model = get_tiny_qwen3_moe()
        moe_block = self._get_moe_block(model)
        moe_type = type(moe_block)

        if QuantModuleRegistry.get(moe_type) is None:
            register_sparse_moe_on_the_fly(model)

        converted = QuantModuleRegistry.convert(moe_block)

        # Reset counts and enable counting
        converted.expert_token_count.zero_()
        converted._count_expert_tokens = True

        hidden_size = converted.gate.in_features
        x = torch.randn(8, hidden_size)
        with torch.no_grad():
            converted.gate(x)

        # After one gate call with counting enabled, the total number of
        # assigned tokens should equal num_tokens * top_k (here 8 * top_k).
        top_k = converted.top_k if hasattr(converted, "top_k") else converted.gate.top_k
        total_assigned = converted.expert_token_count.sum().item()
        assert total_assigned == 8 * top_k

        # Disable counting and verify counts don't change
        converted._count_expert_tokens = False
        prev_counts = converted.expert_token_count.clone()
        with torch.no_grad():
            converted.gate(x)
        assert torch.equal(converted.expert_token_count, prev_counts)
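
To run just these tests locally, point pytest at this file (the path below is
a placeholder; use wherever the file lands in the repo's test tree):

    pytest -q tests/unit/torch/quantization/test_quant_sparse_moe.py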
