Commit 35dad9a

fixes for fused moe (qwen3.6, GLM5.1) + MSE calibration
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent 50706d1 commit 35dad9a

7 files changed

Lines changed: 263 additions & 16 deletions

examples/llm_ptq/hf_ptq.py

Lines changed: 6 additions & 1 deletion
@@ -797,6 +797,11 @@ def pre_quantize(
     preview_input_ids = next(iter(calib_dataloader))[
         "input_features" if model_type == "whisper" else "input_ids"
     ][0:1]
+    # Strip leading padding tokens so the preview input shows real content
+    if model_type not in ("whisper",) and tokenizer is not None and tokenizer.pad_token_id is not None:
+        first_non_pad = (preview_input_ids[0] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
+        if first_non_pad.numel() > 0:
+            preview_input_ids = preview_input_ids[:, first_non_pad[0]:]
 
     # Generate preview before quantization
     if args.skip_generate:
@@ -897,7 +902,7 @@ def input_decode(input_ids):
         if processor is not None and isinstance(processor, WhisperProcessor):
             return first_text_speech_dataset
         elif tokenizer is not None:
-            return tokenizer.batch_decode(input_ids)
+            return tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         else:
             raise ValueError("The processor or tokenizer must be set")
 
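The leading-pad stripping above assumes left-padded calibration batches. Below is a minimal standalone sketch of the same slicing logic, using a toy batch and a hypothetical pad_token_id of 0 (both invented for illustration, not taken from the script):

import torch

# Toy preview batch of shape (1, seq_len); pad_token_id assumed to be 0 here.
preview_input_ids = torch.tensor([[0, 0, 0, 101, 2023, 2003, 102]])
pad_token_id = 0

# Indices of all non-pad positions in the first (and only) row.
first_non_pad = (preview_input_ids[0] != pad_token_id).nonzero(as_tuple=True)[0]
if first_non_pad.numel() > 0:
    preview_input_ids = preview_input_ids[:, first_non_pad[0]:]

print(preview_input_ids)  # tensor([[ 101, 2023, 2003,  102]])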

modelopt/torch/export/moe_utils.py

Lines changed: 97 additions & 9 deletions
@@ -59,9 +59,26 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
     # 2-3. Split + export each per-expert projection.
     fused_dim0 = gate_up.shape[1]  # 2 * expert_dim
 
+    def _safe_cpu_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
+        """Extract _amax to CPU float32, surfacing and clearing any pending CUDA error first."""
+        amax = getattr(quantizer_src, "_amax", None)
+        if amax is None or not isinstance(amax, torch.Tensor):
+            return None
+        try:
+            if amax.is_cuda:
+                torch.cuda.synchronize(amax.device)
+            return amax.detach().cpu().float()
+        except Exception:
+            return None
+
     for idx in range(n):
         expert = nn.Module()
 
+        # Extract amaxes to CPU before deepcopy: cloning a corrupt CUDA _amax tensor
+        # (e.g. from an under-calibrated expert) triggers an async CUDA error.
+        gu_amax_cpu = _safe_cpu_amax(module.gate_up_proj_weight_quantizers[idx])
+        down_amax_cpu = _safe_cpu_amax(module.down_proj_weight_quantizers[idx])
+
         projections = [
             ("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True),
             ("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True),
@@ -76,8 +93,17 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             )
             i_quantizer = gate_up_input_q if is_gate_up else down_input_q
 
-            # gate/up share a weight quantizer — clone so each gets independent amax.
-            w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
+            # gate/up share a quantizer — deepcopy with _amax nulled to avoid cloning
+            # the corrupt CUDA tensor, then inject the pre-extracted CPU amax.
+            if is_gate_up:
+                _saved_amax = getattr(w_quantizer_src, "_amax", None)
+                w_quantizer_src._amax = None
+                w_quantizer = copy.deepcopy(w_quantizer_src)
+                w_quantizer_src._amax = _saved_amax
+                w_quantizer._amax = gu_amax_cpu
+            else:
+                w_quantizer = w_quantizer_src
+                w_quantizer._amax = down_amax_cpu
 
             # For per-channel amax (dim >= 1), proportionally slice dim-0
             # to match the split weight.
@@ -86,12 +112,14 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
                 and w_quantizer._amax is not None
                 and w_quantizer._amax.dim() >= 1
             ):
-                amax = w_quantizer._amax
+                amax = w_quantizer._amax  # CPU float32
                 amax_dim0 = amax.shape[0]
-                if fused_total % amax_dim0 == 0:
+                if amax_dim0 % fused_total == 0:
                     slice_start = fused_start * amax_dim0 // fused_total
                     slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
-                    w_quantizer.amax = amax[slice_start:slice_end].contiguous()
+                    # Bypass amax.setter (which forbids shape changes); w_quantizer is a
+                    # deepcopy for gate/up so mutating it is safe.
+                    w_quantizer._amax = amax[slice_start:slice_end].contiguous()
                 else:
                     warnings.warn(
                         f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
@@ -100,20 +128,68 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
                         stacklevel=2,
                     )
 
-            # If the weight quantizer was never calibrated, compute amax from weights.
+            # Patch invalid per-block amax entries (NaN/inf/negative/zero/too-small/too-large)
+            # with weight-derived fallback values.
+            _MIN_VALID_AMAX = 1e-4
+            _MAX_VALID_AMAX = 1e6
+            if (
+                hasattr(w_quantizer, "_amax")
+                and w_quantizer._amax is not None
+                and w_quantizer._amax.numel() > 1
+            ):
+                amax_cpu = w_quantizer._amax
+                invalid_mask = ~(
+                    torch.isfinite(amax_cpu)
+                    & (amax_cpu >= _MIN_VALID_AMAX)
+                    & (amax_cpu <= _MAX_VALID_AMAX)
+                )
+                if invalid_mask.any():
+                    per_block_fallback = (
+                        weight_slice.detach()
+                        .reshape(-1, 16)
+                        .abs()
+                        .amax(dim=1, keepdim=True)
+                        .cpu()
+                        .float()
+                        .clamp(min=2e-3)
+                        .reshape(amax_cpu.shape)
+                    )
+                    amax_cpu[invalid_mask] = per_block_fallback[invalid_mask]
+                    w_quantizer._amax = amax_cpu
+
+            # For uncalibrated experts (amax missing or invalid scalar), fall back to
+            # per-block amax from weights so the static export path can reshape it correctly.
             if (
                 hasattr(w_quantizer, "is_enabled")
                 and w_quantizer.is_enabled
                 and (
                     not hasattr(w_quantizer, "_amax")
                     or w_quantizer._amax is None
-                    or torch.all(w_quantizer._amax == 0)
+                    or (
+                        w_quantizer._amax.numel() == 1
+                        and not (
+                            torch.isfinite(w_quantizer._amax)
+                            and w_quantizer._amax >= _MIN_VALID_AMAX
+                            and w_quantizer._amax <= _MAX_VALID_AMAX
+                        )
+                    )
                 )
             ):
-                w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
+                _block_size = 16
+                fallback_per_block = (
+                    weight_slice.detach()
+                    .reshape(-1, _block_size)
+                    .abs()
+                    .amax(dim=1, keepdim=True)
+                    .cpu()
+                    .float()
+                    .clamp(min=2e-3)
+                    .reshape(*weight_slice.shape[:-1], weight_slice.shape[-1] // _block_size)
+                )
+                w_quantizer._amax = fallback_per_block
                 warnings.warn(
                     f"Expert {idx} {proj_name} weight quantizer was not calibrated "
-                    f"(amax missing or zero). Using weight-derived amax as fallback. "
+                    f"(amax missing or zero). Using weight-derived per-block amax as fallback. "
                     f"Consider using more calibration data to activate all experts.",
                     stacklevel=2,
                 )
@@ -123,6 +199,18 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             wrapper.weight_quantizer = w_quantizer
             wrapper.input_quantizer = i_quantizer
 
+            # Set global_amax to route to the static NVFP4 export path (reads per-block _amax).
+            # Always recompute from the current (possibly patched) _amax — a stale zero
+            # global_amax causes division-by-zero in the per-block scale formula.
+            wq = wrapper.weight_quantizer
+            if (
+                hasattr(wq, "_amax")
+                and wq._amax is not None
+                and wq._amax.numel() > 1
+            ):
+                wq._amax = wq._amax.to(weight_slice.device)
+                wq.global_amax = wq._amax.float().amax().clamp(min=2e-3)
+
             _export_quantized_weight(wrapper, dtype)
 
             proj = nn.Module()
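Both fallback branches above derive a per-block amax directly from the weight slice: the last dimension is split into blocks of 16, the absolute maximum is taken per block, and the result is clamped away from zero so the later FP8 scale quantization cannot underflow. A small self-contained sketch of that computation, with an invented (8, 64) weight shape not tied to any real expert:

import torch

block_size = 16
# Hypothetical per-expert weight slice: (out_features, in_features), in_features divisible by 16.
weight_slice = torch.randn(8, 64)

# One amax value per contiguous block of 16 input channels, clamped away from zero
# so the downstream FP8 scale quantization cannot produce a zero scale.
fallback_per_block = (
    weight_slice.detach()
    .reshape(-1, block_size)       # (32, 16): every row is one block
    .abs()
    .amax(dim=1, keepdim=True)     # (32, 1): per-block absolute maximum
    .cpu()
    .float()
    .clamp(min=2e-3)
    .reshape(*weight_slice.shape[:-1], weight_slice.shape[-1] // block_size)  # (8, 4)
)
print(fallback_per_block.shape)  # torch.Size([8, 4])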

modelopt/torch/quantization/model_calib.py

Lines changed: 17 additions & 1 deletion
@@ -421,6 +421,18 @@ def mse_calibrate(
            if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
                if getattr(weight_quantizer, "_calibrator", None) is not None:
                    weight_quantizers.append((parent_module, weight_name, weight_quantizer))
+        # _QuantFusedExperts stores per-expert weight quantizers as nn.ModuleList named
+        # {param_name}_weight_quantizers (plural). Detect this pattern and enqueue each
+        # per-expert quantizer individually.
+        for param_name, _ in parent_module.named_parameters(recurse=False):
+            qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
+            if not isinstance(qlist, nn.ModuleList):
+                continue
+            for expert_idx, wq in enumerate(qlist):
+                if isinstance(wq, TensorQuantizer) and wq.is_enabled:
+                    if getattr(wq, "_calibrator", None) is not None:
+                        weight_quantizers.append((parent_module, (param_name, expert_idx), wq))
+
        seen_modules.add(parent_module)
 
    # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
@@ -432,7 +444,11 @@ def mse_calibrate(
        weight_quantizer.disable_quant()
        weight_quantizer.enable_calib()
        with enable_weight_access_and_writeback(parent_module, model, name_to_module):
-            weight = getattr(parent_module, weight_name)
+            if isinstance(weight_name, tuple):
+                param_name, expert_idx = weight_name
+                weight = getattr(parent_module, param_name)[expert_idx]
+            else:
+                weight = getattr(parent_module, weight_name)
            weight_quantizer(weight)
 
        # IMMEDIATELY compute amax and reset calibrator to free memory
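The scan above keys off a naming convention rather than a class check: a fused-experts module exposes a 3D parameter plus an nn.ModuleList called {param_name}_weight_quantizers. The toy module below imitates only that convention (it is not the real _QuantFusedExperts, and nn.Identity stands in for TensorQuantizer) to show what the loop would enqueue and how the (param_name, expert_idx) tuple later indexes the 3D weight:

import torch
import torch.nn as nn

class ToyFusedExperts(nn.Module):
    """Toy stand-in for a fused-experts module; only the naming convention matters here."""

    def __init__(self, n_experts=4, dim=32):
        super().__init__()
        # 3D fused parameter: (n_experts, 2 * dim, dim), mirroring gate_up_proj-style storage.
        self.gate_up_proj = nn.Parameter(torch.randn(n_experts, 2 * dim, dim))
        # One placeholder "quantizer" per expert, named "<param_name>_weight_quantizers".
        self.gate_up_proj_weight_quantizers = nn.ModuleList(nn.Identity() for _ in range(n_experts))

m = ToyFusedExperts()
for param_name, _ in m.named_parameters(recurse=False):
    qlist = getattr(m, f"{param_name}_weight_quantizers", None)
    if isinstance(qlist, nn.ModuleList):
        for expert_idx, wq in enumerate(qlist):
            # The real code appends (module, (param_name, expert_idx), quantizer) and later
            # reads the weight via getattr(module, param_name)[expert_idx].
            print(param_name, expert_idx, getattr(m, param_name)[expert_idx].shape)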

modelopt/torch/quantization/model_quant.py

Lines changed: 1 addition & 0 deletions
@@ -595,6 +595,7 @@ def print_quant_summary(model: nn.Module, output_dir: str | None = None):
     lines.append(f"{len(lines)} TensorQuantizers found in model")
 
     if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
         path = os.path.join(output_dir, ".quant_summary.txt")
         with open(path, "w", encoding="utf-8") as f:
             f.write("\n".join(lines) + "\n")

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 2 additions & 2 deletions
@@ -1112,7 +1112,7 @@ def forward(self, inputs):
 
         return outputs
 
-    def _short_amax(self, fmt=".4f"):
+    def _short_amax(self, fmt=".2e"):
         """Short description of amax.
 
         Returns:
@@ -1130,7 +1130,7 @@ def _short_amax(self, fmt=".4f"):
             return "meta"
         return self._short_tensor(self._amax, fmt)
 
-    def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"):
+    def _short_tensor(self, tensor: torch.Tensor, fmt=".2e"):
         """Short description of tensor."""
         if tensor.numel() == 1:
             return f"{tensor.item():{fmt}}"
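The format change from ".4f" to ".2e" matters for exactly the tiny amax values this commit deals with: fixed-point formatting collapses them to zero in the printed summary, while scientific notation keeps them distinguishable. A two-line illustration:

amax = 3.1e-05
print(f"{amax:.4f}")  # 0.0000  (indistinguishable from a genuinely zero amax)
print(f"{amax:.2e}")  # 3.10e-05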

modelopt/torch/quantization/qtensor/nvfp4_tensor.py

Lines changed: 10 additions & 3 deletions
@@ -124,9 +124,10 @@ def get_weights_scaling_factor_from_quantizer(
 
         # Quantize scales to FP8
         if not keep_high_precision:
-            per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
-                torch.float8_e4m3fn
-            )
+            _FP8_E4M3FN_MIN = 2**-9  # 0.001953125 — smallest positive subnormal
+            per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).clamp(
+                min=_FP8_E4M3FN_MIN
+            ).to(torch.float8_e4m3fn)
         return per_block_scale, weights_scaling_factor_2
     else:
         # Dynamic path: compute from weight tensor
@@ -167,6 +168,12 @@ def get_weights_scaling_factor(
         per_block_scale[per_block_scale == 0] = 1.0
         # Convert to torch.float8_e4m3fn
         if not keep_high_precision:
+            # Clamp to the minimum positive FP8 E4M3FN subnormal (~0.00195 = 2^-9) before
+            # casting. Without this, blocks whose scale falls below the FP8 representable
+            # range silently underflow to 0, causing those blocks to produce zero output at
+            # inference even when the weights are non-trivial.
+            _FP8_E4M3FN_MIN = 2**-9  # 0.001953125 — smallest positive subnormal
+            per_block_scale = per_block_scale.clamp(min=_FP8_E4M3FN_MIN)
             per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
         return per_block_scale, weights_scaling_factor_2
 
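For reference, torch.float8_e4m3fn has 4 exponent bits (bias 7) and 3 mantissa bits, so its smallest positive subnormal is 2^-6 * 2^-3 = 2^-9 = 0.001953125; values much below that round to zero on cast, which is the underflow the clamp guards against. A quick check in plain PyTorch:

import torch

fp8_min = 2 ** -9  # smallest positive subnormal of float8_e4m3fn

print(torch.tensor(fp8_min).to(torch.float8_e4m3fn).float().item())  # 0.001953125
print(torch.tensor(1e-4).to(torch.float8_e4m3fn).float().item())     # 0.0 (underflows)
print(torch.tensor(1e-4).clamp(min=fp8_min).to(torch.float8_e4m3fn).float().item())  # 0.001953125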

New file: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+  recipe_type: ptq
+  description: >
+    NVFP4 W4A4 for MoE routed experts only. Static weight scales via MSE + FP8 scale sweep;
+    dynamic activation scales. Supports sequential experts (nn.Linear-based) and fused experts
+    (_QuantFusedExperts, HF transformers 5.0+ 3D nn.Parameter style).
+quantize:
+  algorithm:
+    method: mse
+    fp8_scale_sweep: true
+    layerwise: false
+  quant_cfg:
+    # ── Disable everything first ─────────────────────────────────────────────
+    - quantizer_name: '*'
+      enable: false
+
+    # ── Sequential experts (nn.Linear per expert) ────────────────────────────
+    - quantizer_name: '*mlp.experts*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: static
+        scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*mlp.experts*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: dynamic
+        scale_bits: e4m3
+        num_bits: e2m1
+
+    # ── Sequential experts: Mixtral / block_sparse_moe style ────────────────
+    - quantizer_name: '*block_sparse_moe*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: static
+        scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: dynamic
+        scale_bits: e4m3
+        num_bits: e2m1
+
+    # ── Fused experts (_QuantFusedExperts, HF transformers 5.0+ 3D nn.Parameter style) ──
+    - quantizer_name: '*gate_up_proj_weight_quantizers*'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: static
+        scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*gate_up_proj_input_quantizer*'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: dynamic
+        scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*down_proj_weight_quantizers*'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: static
+        scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*down_proj_input_quantizer*'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+        type: dynamic
+        scale_bits: e4m3
+        num_bits: e2m1
+
+    # ── Exclusions: shared experts, attention, routers, lm_head ─────────────
+    - quantizer_name: '*block_sparse_moe.gate*'
+      enable: false
+    - quantizer_name: '*linear_attn.conv1d*'
+      enable: false
+    - quantizer_name: '*lm_head*'
+      enable: false
+    - quantizer_name: '*mlp.gate.*'
+      enable: false
+    - quantizer_name: '*mlp.shared_expert*'
+      enable: false
+    - quantizer_name: '*mlp.shared_expert_gate.*'
+      enable: false
+    - quantizer_name: '*router*'
+      enable: false
+    - quantizer_name: 'output.*'
+      enable: false
+    - parent_class: 'nn.BatchNorm1d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm2d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm3d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.LeakyReLU'
+      quantizer_name: '*'
+      enable: false
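The quantizer_name entries read as glob-style patterns matched against fully qualified quantizer names, with later rules overriding the initial disable-all entry. The matching itself happens in the recipe loader, which is not part of this diff, so the fnmatch sketch below is only an illustration of that pattern style under the assumption that the last matching rule wins; the module names are invented:

from fnmatch import fnmatch

patterns = [
    ("*", False),                                      # disable everything first
    ("*mlp.experts*weight_quantizer", True),           # sequential experts
    ("*gate_up_proj_weight_quantizers*", True),        # fused experts
    ("*mlp.shared_expert*", False),                    # exclusions listed last
]

def enabled(name: str) -> bool:
    state = False
    for pattern, enable in patterns:
        if fnmatch(name, pattern):
            state = enable  # last matching rule wins in this toy model
    return state

print(enabled("model.layers.3.mlp.experts.7.down_proj.weight_quantizer"))    # True
print(enabled("model.layers.3.mlp.shared_expert.up_proj.weight_quantizer"))  # False
print(enabled("model.layers.3.self_attn.q_proj.weight_quantizer"))           # False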
