@@ -59,9 +59,26 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
     # 2-3. Split + export each per-expert projection.
     fused_dim0 = gate_up.shape[1]  # 2 * expert_dim
 
+    def _safe_cpu_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
+        """Extract _amax to CPU float32, surfacing and clearing any pending CUDA error first."""
+        amax = getattr(quantizer_src, "_amax", None)
+        if amax is None or not isinstance(amax, torch.Tensor):
+            return None
+        try:
+            if amax.is_cuda:
+                torch.cuda.synchronize(amax.device)
+            return amax.detach().cpu().float()
+        except Exception:
+            return None
+
     for idx in range(n):
         expert = nn.Module()
 
+        # Extract amaxes to CPU before deepcopy: cloning a corrupt CUDA _amax tensor
+        # (e.g. from an under-calibrated expert) triggers an async CUDA error.
+        gu_amax_cpu = _safe_cpu_amax(module.gate_up_proj_weight_quantizers[idx])
+        down_amax_cpu = _safe_cpu_amax(module.down_proj_weight_quantizers[idx])
+
         projections = [
             ("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True),
             ("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True),
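Note: the helper leans on CUDA's asynchronous execution model: a kernel fault
(e.g. an illegal memory access) is only reported at the next synchronizing
call, so without the explicit synchronize() the error would first surface
inside copy.deepcopy. A minimal, CPU-safe sketch of the same pattern; the
function name and the narrower RuntimeError catch are mine, not the PR's:

    import torch

    def checked_to_cpu(t: torch.Tensor) -> torch.Tensor | None:
        try:
            if t.is_cuda:
                # Force any pending fault from earlier kernels to surface
                # here, rather than inside a later clone/deepcopy.
                torch.cuda.synchronize(t.device)
            return t.detach().cpu().float()
        except RuntimeError:
            return None  # device state unusable; caller falls back

    print(checked_to_cpu(torch.ones(2)))  # tensor([1., 1.])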
@@ -76,8 +93,17 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             )
             i_quantizer = gate_up_input_q if is_gate_up else down_input_q
 
-            # gate/up share a weight quantizer — clone so each gets independent amax.
-            w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
+            # gate/up share a quantizer — deepcopy with _amax nulled to avoid cloning
+            # the corrupt CUDA tensor, then inject the pre-extracted CPU amax.
+            if is_gate_up:
+                _saved_amax = getattr(w_quantizer_src, "_amax", None)
+                w_quantizer_src._amax = None
+                w_quantizer = copy.deepcopy(w_quantizer_src)
+                w_quantizer_src._amax = _saved_amax
+                w_quantizer._amax = gu_amax_cpu
+            else:
+                w_quantizer = w_quantizer_src
+                w_quantizer._amax = down_amax_cpu
 
             # For per-channel amax (dim >= 1), proportionally slice dim-0
             # to match the split weight.
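Note: nulling _amax before deepcopy and restoring it afterwards is the
swap-out idiom written inline. A generic, exception-safe variant as a context
manager; the name attr_nulled and the try/finally hardening are mine, not
part of the PR:

    import copy
    from contextlib import contextmanager

    @contextmanager
    def attr_nulled(obj, name):
        saved = getattr(obj, name, None)
        setattr(obj, name, None)
        try:
            yield obj
        finally:
            setattr(obj, name, saved)  # restore even if deepcopy raises

    # with attr_nulled(w_quantizer_src, "_amax"):
    #     w_quantizer = copy.deepcopy(w_quantizer_src)  # clone skips _amax
    # w_quantizer._amax = gu_amax_cpu                   # inject the CPU copy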
@@ -86,12 +112,14 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
                 and w_quantizer._amax is not None
                 and w_quantizer._amax.dim() >= 1
             ):
-                amax = w_quantizer._amax
+                amax = w_quantizer._amax  # CPU float32
                 amax_dim0 = amax.shape[0]
-                if fused_total % amax_dim0 == 0:
+                if amax_dim0 % fused_total == 0:
                     slice_start = fused_start * amax_dim0 // fused_total
                     slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
-                    w_quantizer.amax = amax[slice_start:slice_end].contiguous()
+                    # Bypass amax.setter (which forbids shape changes); w_quantizer is a
+                    # deepcopy for gate/up so mutating it is safe.
+                    w_quantizer._amax = amax[slice_start:slice_end].contiguous()
                 else:
                     warnings.warn(
                         f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
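Note: the operands of the divisibility test are flipped relative to the
removed line: only when amax_dim0 is a whole multiple of the fused row count
do the proportional slice boundaries land on integers. A worked example with
made-up sizes:

    import torch

    fused_total = 8               # fused gate_up rows (2 * expert_dim)
    amax = torch.arange(16.0)     # per-channel amax, 16 = 2 entries per row
    assert amax.shape[0] % fused_total == 0

    expert_dim = fused_total // 2
    for name, fused_start in [("gate_proj", 0), ("up_proj", expert_dim)]:
        s = fused_start * amax.shape[0] // fused_total
        e = (fused_start + expert_dim) * amax.shape[0] // fused_total
        print(name, s, e)         # gate_proj 0 8, up_proj 8 16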
@@ -100,20 +128,68 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
                         stacklevel=2,
                     )
 
-            # If the weight quantizer was never calibrated, compute amax from weights.
+            # Patch invalid per-block amax entries (NaN/inf/negative/zero/too-small/too-large)
+            # with weight-derived fallback values.
+            _MIN_VALID_AMAX = 1e-4
+            _MAX_VALID_AMAX = 1e6
+            if (
+                hasattr(w_quantizer, "_amax")
+                and w_quantizer._amax is not None
+                and w_quantizer._amax.numel() > 1
+            ):
+                amax_cpu = w_quantizer._amax
+                invalid_mask = ~(
+                    torch.isfinite(amax_cpu)
+                    & (amax_cpu >= _MIN_VALID_AMAX)
+                    & (amax_cpu <= _MAX_VALID_AMAX)
+                )
+                if invalid_mask.any():
+                    per_block_fallback = (
+                        weight_slice.detach()
+                        .reshape(-1, 16)
+                        .abs()
+                        .amax(dim=1, keepdim=True)
+                        .cpu()
+                        .float()
+                        .clamp(min=2e-3)
+                        .reshape(amax_cpu.shape)
+                    )
+                    amax_cpu[invalid_mask] = per_block_fallback[invalid_mask]
+                    w_quantizer._amax = amax_cpu
+
+            # For uncalibrated experts (amax missing or invalid scalar), fall back to
+            # per-block amax from weights so the static export path can reshape it correctly.
             if (
                 hasattr(w_quantizer, "is_enabled")
                 and w_quantizer.is_enabled
                 and (
                     not hasattr(w_quantizer, "_amax")
                     or w_quantizer._amax is None
-                    or torch.all(w_quantizer._amax == 0)
+                    or (
+                        w_quantizer._amax.numel() == 1
+                        and not (
+                            torch.isfinite(w_quantizer._amax)
+                            and w_quantizer._amax >= _MIN_VALID_AMAX
+                            and w_quantizer._amax <= _MAX_VALID_AMAX
+                        )
+                    )
                 )
             ):
-                w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
+                _block_size = 16
+                fallback_per_block = (
+                    weight_slice.detach()
+                    .reshape(-1, _block_size)
+                    .abs()
+                    .amax(dim=1, keepdim=True)
+                    .cpu()
+                    .float()
+                    .clamp(min=2e-3)
+                    .reshape(*weight_slice.shape[:-1], weight_slice.shape[-1] // _block_size)
+                )
+                w_quantizer._amax = fallback_per_block
                 warnings.warn(
                     f"Expert {idx} {proj_name} weight quantizer was not calibrated "
-                    f"(amax missing or zero). Using weight-derived amax as fallback. "
+                    f"(amax missing or invalid). Using weight-derived per-block amax as fallback. "
                     f"Consider using more calibration data to activate all experts.",
                     stacklevel=2,
                 )
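Note: both new blocks derive replacement values the same way: group the last
weight dimension into 16-wide blocks (NVFP4's block size), take max|w| per
block, and clamp away from zero. A self-contained sketch of the
patch-invalid-entries path on a toy weight; the sizes and the corrupt values
are made up:

    import torch

    _MIN, _MAX = 1e-4, 1e6
    w = torch.randn(4, 32)               # toy (out, in) weight, in % 16 == 0
    amax = torch.zeros(4, 2)             # one amax per 16-wide block
    amax[0, 0] = float("nan")            # simulate corruption

    fallback = (
        w.detach().abs().reshape(-1, 16) # one row per block
        .amax(dim=1, keepdim=True)       # max |w| per block
        .clamp(min=2e-3)                 # keep scales off zero
        .reshape(amax.shape)
    )
    bad = ~(torch.isfinite(amax) & (amax >= _MIN) & (amax <= _MAX))
    amax[bad] = fallback[bad]            # patch only the invalid entries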
@@ -123,6 +199,18 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             wrapper.weight_quantizer = w_quantizer
             wrapper.input_quantizer = i_quantizer
 
+            # Set global_amax to route to the static NVFP4 export path (reads per-block _amax).
+            # Always recompute from the current (possibly patched) _amax — a stale zero
+            # global_amax causes division-by-zero in the per-block scale formula.
+            wq = wrapper.weight_quantizer
+            if (
+                hasattr(wq, "_amax")
+                and wq._amax is not None
+                and wq._amax.numel() > 1
+            ):
+                wq._amax = wq._amax.to(weight_slice.device)
+                wq.global_amax = wq._amax.float().amax().clamp(min=2e-3)
+
             _export_quantized_weight(wrapper, dtype)
 
             proj = nn.Module()
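Note: global_amax feeds the two-level NVFP4 scaling, in which each 16-element
block gets an FP8-encoded scale expressed relative to a single FP32 global
scale. A rough sketch of why a zero global_amax is fatal; the formula follows
the common global_amax / (6 * 448) convention (FP4 max 6, FP8-E4M3 max 448)
and is illustrative, not lifted from the export code:

    import torch

    amax_block = torch.tensor([[0.8, 0.1], [1.2, 0.5]])  # per-block amax
    global_amax = amax_block.amax().clamp(min=2e-3)      # clamp: never zero
    s_global = global_amax / (6.0 * 448.0)               # FP32 global scale
    s_block = (amax_block / 6.0) / s_global              # if s_global were 0,
                                                         # this divides by zero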