Skip to content

Commit 0ef1b98

Browse files
committed
Extract cast_mxfp4_to_nvfp4 quant_cfg mutation into helper
Move the inline weight-quantizer block_sizes='static' rewrite out of quantize_main() into a public force_weight_quantizers_static() helper in cast_mxfp4_to_nvfp4.py, keeping the cast-specific config logic colocated with the rest of the cast flow. Addresses review feedback on PR #1372. Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 024c428 commit 0ef1b98

2 files changed

Lines changed: 19 additions & 12 deletions

File tree

examples/llm_ptq/cast_mxfp4_to_nvfp4.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,23 @@ def build_amax_map(checkpoint_dir: str | Path) -> dict[str, dict]:
291291
return amax_map
292292

293293

294+
def force_weight_quantizers_static(quant_cfg: list) -> None:
    """Rewrite every weight-quantizer entry so its ``block_sizes`` carries ``type='static'``.

    The MXFP4 -> NVFP4 cast needs the per-block weight ``_amax`` to be recorded
    by max-cal (so it can be paired with the pinned global_amax later). Setting
    ``block_sizes['type'] = 'static'`` makes ``is_static_block_quant`` True so
    ``promote_nvfp4_static_quantizers`` picks the entry up automatically at the
    end of max_calibrate.

    Mutates ``quant_cfg`` in place (entries are replaced with patched copies,
    never edited through shared references) and returns ``None``.
    """
    for idx in range(len(quant_cfg)):
        entry = quant_cfg[idx]
        # Only weight-quantizer entries are affected.
        if "weight_quantizer" not in entry.get("quantizer_name", ""):
            continue
        cfg = entry.get("cfg") or {}
        block_sizes = cfg.get("block_sizes")
        # Entries without a block_sizes dict (e.g. disabled quantizers) are left untouched.
        if not isinstance(block_sizes, dict):
            continue
        patched_block_sizes = dict(block_sizes, type="static")
        quant_cfg[idx] = dict(entry, cfg=dict(cfg, block_sizes=patched_block_sizes))
309+
310+
294311
def apply_to_model(
295312
model: "torch.nn.Module",
296313
source_checkpoint_path: str | Path,

examples/llm_ptq/hf_ptq.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import torch
2626
from accelerate.hooks import remove_hook_from_module
2727
from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4
28+
from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static
2829
from example_utils import (
2930
build_quant_cfg,
3031
copy_custom_model_files,
@@ -1088,20 +1089,9 @@ def quantize_main(
10881089
f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}"
10891090
)
10901091

1091-
# MXFP4 -> NVFP4 cast needs the per-block weight ``_amax`` to be recorded
1092-
# by max-cal (so it can be paired with the pinned global_amax later).
1093-
# Force every weight-quantizer entry to ``block_sizes['type'] = 'static'``
1094-
# so ``is_static_block_quant`` is True and ``promote_nvfp4_static_quantizers``
1095-
# picks them up automatically at the end of max_calibrate.
10961092
if args.cast_mxfp4_to_nvfp4:
10971093
quant_cfg = copy.deepcopy(quant_cfg)
1098-
for entry in quant_cfg.get("quant_cfg", []):
1099-
qname = entry.get("quantizer_name", "")
1100-
cfg = entry.get("cfg") or {}
1101-
bs = cfg.get("block_sizes")
1102-
if "weight_quantizer" in qname and isinstance(bs, dict):
1103-
bs = {**bs, "type": "static"}
1104-
entry["cfg"] = {**cfg, "block_sizes": bs}
1094+
force_weight_quantizers_static(quant_cfg["quant_cfg"])
11051095

11061096
if args.qformat in QUANT_CFG_CHOICES:
11071097
mono_quantize(

0 commit comments

Comments
 (0)