
Commit 8fa0d4c

feat(ptq): replace Nemotron-H ad-hoc lm_head/embedding helper with YAML recipe
Move the Nemotron-H-specific quantization extensions out of `hf_ptq.py` and into a declarative recipe at `modelopt_recipes/models/Nemotron-H/nvfp4_w4a16.yaml`, addressing PR #1327 review feedback. The recipe captures exactly what the removed `_enable_lm_head_and_embedding_quantization` helper did:

* All Linear weight quantizers ON (NVFP4 W4A16, group_size 16, scale_bits e4m3).
* Standard `_default_disabled_quantizer_cfg` exclusions (BatchNorm, conv1d, etc.).
* `*lm_head*weight_quantizer`, `*embeddings*weight_quantizer`, and `*embed_tokens*weight_quantizer` re-enabled AFTER the default disables so they take precedence (last matching entry wins).

Drop the helpers (`_enable_lm_head_and_embedding_quantization`, `_extract_wildcard_quantizer_cfg`) and the `if model_type == "nemotron_h":` block in `mono_quantize`. Users now opt in explicitly via `--recipe models/Nemotron-H/nvfp4_w4a16` instead of relying on auto-detection.

Verified end-to-end on `nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16` (RTX 6000 Ada, calib_size=16, calib_seq=256): 94 weight quantizers enabled and 21 disabled (the Mamba `*mixer.conv1d*` layers), `lm_head.weight_quantizer` and `model.embeddings.weight_quantizer` carry the NVFP4 cfg, the exported safetensors is 2.13 GiB (matches the prior PR-validation export size), and `hf_quant_config.json` reports `quant_algo=NVFP4_W4A16`, `group_size=16`, `exclude_modules=[21 conv1d layers]`.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent a115c88 commit 8fa0d4c
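The third bullet above relies on ordered, last-match-wins pattern matching over the quant_cfg entry list. A minimal standalone sketch of that rule (plain fnmatch over hand-written entries in recipe order; the quantizer names are illustrative and this is not ModelOpt's actual matcher):

    from fnmatch import fnmatch

    # Entries in recipe order: broad enable, default disables, Nemotron-H override last.
    entries = [
        {"quantizer_name": "*weight_quantizer", "enable": True},
        {"quantizer_name": "*lm_head*", "enable": False},         # default disable
        {"quantizer_name": "*mixer.conv1d*", "enable": False},    # Mamba conv stays bf16
        {"quantizer_name": "*lm_head*weight_quantizer", "enable": True},  # recipe override
    ]

    def resolve(name: str) -> bool:
        state = False
        for entry in entries:
            if fnmatch(name, entry["quantizer_name"]):
                state = entry["enable"]  # a later match overrides an earlier one
        return state

    print(resolve("lm_head.weight_quantizer"))                         # True: the override wins
    print(resolve("backbone.layers.0.mixer.conv1d.weight_quantizer"))  # False: stays excluded
    print(resolve("backbone.layers.0.mlp.up_proj.weight_quantizer"))   # True: plain Linear weight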

3 files changed

Lines changed: 132 additions & 120 deletions

File tree

CHANGELOG.rst
examples/llm_ptq/hf_ptq.py
modelopt_recipes/models/Nemotron-H/nvfp4_w4a16.yaml

CHANGELOG.rst

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ Changelog
 **New Features**

 - Add NVFP4 W4A16 weight-only quantization (``nvfp4_w4a16``): FP4 weights with group_size=16, BF16 activations, no calibration forward pass required. Use ``mtq.NVFP4_W4A16_CFG`` or ``--qformat nvfp4_w4a16`` in ``hf_ptq.py``. Exported checkpoints can be served on vLLM after conversion to compressed-tensors format.
-- Register ``nn.Embedding`` with ``QuantModuleRegistry`` (weight-only wrapper) and extend the unified HF exporter to pack quantized embedding weights. Enables NVFP4 quantization of ``lm_head`` and the input token embedding on hybrid SSM+Attention models such as Nemotron-H, where those two tables are a sizeable fraction of parameters and leaving them in bf16 wastes most of the compression. Nemotron-H-specific enablement + ``--exclude_modules`` CLI flag wired up in ``examples/llm_ptq/hf_ptq.py``.
+- Register ``nn.Embedding`` with ``QuantModuleRegistry`` (weight-only wrapper) and extend the unified HF exporter to pack quantized embedding weights. Enables NVFP4 quantization of ``lm_head`` and the input token embedding on hybrid SSM+Attention models such as Nemotron-H, where those two tables are a sizeable fraction of parameters and leaving them in bf16 wastes most of the compression. Use ``--recipe models/Nemotron-H/nvfp4_w4a16`` (see `modelopt_recipes/models/Nemotron-H/nvfp4_w4a16.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Nemotron-H/nvfp4_w4a16.yaml>`_) to opt in. The ``--exclude_modules`` CLI flag in ``examples/llm_ptq/hf_ptq.py`` lets users selectively exclude individual modules from the recipe's coverage.
 - Support full Transformer Engine spec for Minitron pruning (``mcore_minitron``). Now we no longer need to use custom ModelOpt spec. Note that this does not affect the usage of the pruning workflow but makes pruning slightly faster and may result in slightly different pruned model because of different kernel and numerics.
 - Add Puzzletron - a new algorithm for heterogeneous pruning of LLM and VLM models. See `examples/puzzletron/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/puzzletron>`_ for more details.
 - Added iterator interface using CalibrationDataReader in ONNX quantization workflow.
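For the first entry above, the weight-only flow needs no calibration loop; a minimal sketch of the programmatic path (the checkpoint name comes from the verification note in the commit message; the trust_remote_code / dtype arguments are assumptions about how that checkpoint loads):

    import torch
    import modelopt.torch.quantization as mtq
    from transformers import AutoModelForCausalLM

    # Any BF16 causal-LM checkpoint works; this one matches the commit's verification run.
    model = AutoModelForCausalLM.from_pretrained(
        "nvidia/NVIDIA-Nemotron-3-Nano-4B-BF16",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    # NVFP4 W4A16 is weight-only: weights get FP4 with group_size=16, activations stay
    # BF16, so no forward_loop (calibration pass) is required.
    model = mtq.quantize(model, mtq.NVFP4_W4A16_CFG, forward_loop=None)
    mtq.print_quant_summary(model)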

examples/llm_ptq/hf_ptq.py

Lines changed: 5 additions & 119 deletions
@@ -596,99 +596,6 @@ def sparsity_main(
     mts.export(full_model)


-def _enable_lm_head_and_embedding_quantization(
-    quant_cfg: dict[str, Any],
-    weight_quantizer_cfg: dict[str, Any],
-    input_quantizer_cfg: dict[str, Any] | None = None,
-    user_excluded_modules: list[str] | None = None,
-) -> None:
-    """Re-enable quantization of ``lm_head`` and the input embedding table.
-
-    ModelOpt's default PTQ recipes exclude ``*lm_head*`` and never touch ``nn.Embedding``
-    because most LLM deployment runtimes keep those layers at full precision. For Nemotron-H
-    (and similar SSM+Attention hybrids) the embedding and lm_head are a large fraction of the
-    total parameters — quantizing them recovers most of the promised memory savings. This
-    helper appends entries to the cfg list that override earlier ``*lm_head*`` disables
-    and explicitly target the embedding weight quantizer.
-
-    For activation-aware recipes (``fp8``, ``nvfp4``, ...) ``input_quantizer_cfg`` is mirrored
-    onto ``*lm_head*input_quantizer`` so ``lm_head`` keeps the same activation format as the
-    rest of the model. Embedding input quantizers are left alone since
-    ``QuantEmbedding._setup`` disables them by default (embedding inputs are integer indices).
-
-    If ``user_excluded_modules`` is provided, entries matching any user exclusion pattern
-    are skipped so ``--exclude_modules lm_head`` / ``--exclude_modules embeddings`` is not
-    silently overridden.
-
-    Args:
-        quant_cfg: the primary quant_cfg dict (``{"quant_cfg": [...], "algorithm": ...}``).
-        weight_quantizer_cfg: the weight-quantizer attribute dict to apply (e.g. ``_nvfp4_cfg``).
-        input_quantizer_cfg: the activation-quantizer attribute dict to mirror on ``lm_head``.
-            ``None`` for weight-only recipes, in which case no input-quantizer entry is added.
-        user_excluded_modules: raw ``--exclude_modules`` patterns from the CLI; targets
-            matching any of them (bidirectional substring match) are skipped.
-    """
-    excluded = user_excluded_modules or []
-
-    def _user_excluded(target_hint: str) -> bool:
-        # Bidirectional substring: "lm_head" user pattern excludes target "lm_head"; a more
-        # specific user pattern (e.g. "backbone.embeddings") also excludes "embeddings".
-        return any(p in target_hint or target_hint in p for p in excluded)
-
-    # Ordering matters: these entries must come AFTER the _default_disabled_quantizer_cfg
-    # entries (which set *lm_head* → disabled) so they take effect.
-    if not _user_excluded("lm_head"):
-        quant_cfg["quant_cfg"].append(
-            {
-                "quantizer_name": "*lm_head*weight_quantizer",
-                "cfg": copy.deepcopy(weight_quantizer_cfg),
-            }
-        )
-        # For activation-aware recipes, keep lm_head's input format aligned with the rest of
-        # the model — otherwise lm_head silently downgrades to weight-only and gets
-        # reclassified as e.g. NVFP4_W4A16 on export while the rest of the model is NVFP4.
-        if input_quantizer_cfg is not None:
-            quant_cfg["quant_cfg"].append(
-                {
-                    "quantizer_name": "*lm_head*input_quantizer",
-                    "cfg": copy.deepcopy(input_quantizer_cfg),
-                }
-            )
-
-    # nn.Embedding quantizers only exist once `quant_embedding.py` registers the class.
-    # Nemotron-H's backbone attribute name differs between the remote-code ("backbone.embeddings")
-    # and transformers built-in ("model.embeddings") paths; both are weight-only vocab
-    # embeddings here. The broad "*embeddings*" wildcard covers both and does not match
-    # any other layer in a Nemotron-H model (no positional/rotary embeddings exist).
-    if not _user_excluded("embeddings"):
-        quant_cfg["quant_cfg"].append(
-            {
-                "quantizer_name": "*embeddings*weight_quantizer",
-                "cfg": copy.deepcopy(weight_quantizer_cfg),
-            }
-        )
-    # Also keep the standard HF "embed_tokens" naming in case future Nemotron-H variants
-    # rename the attribute.
-    if not _user_excluded("embed_tokens"):
-        quant_cfg["quant_cfg"].append(
-            {
-                "quantizer_name": "*embed_tokens*weight_quantizer",
-                "cfg": copy.deepcopy(weight_quantizer_cfg),
-            }
-        )
-
-
-def _extract_wildcard_quantizer_cfg(
-    quant_cfg: dict[str, Any], quantizer_attr: str
-) -> dict[str, Any] | None:
-    """Return the first ``*<quantizer_attr>`` cfg dict from an ordered quant_cfg list."""
-    target = f"*{quantizer_attr}"
-    for entry in quant_cfg.get("quant_cfg", []):
-        if entry.get("quantizer_name") == target and isinstance(entry.get("cfg"), dict):
-            return entry["cfg"]
-    return None
-
-
 def mono_quantize(
     args: argparse.Namespace,
     quant_cfg: dict[str, Any],
@@ -725,32 +632,11 @@ def mono_quantize(
         )  # Nemotron-Parse specific
         print("Quantization will only be applied to the decoder (text generation) component")

-    # For Nemotron-H (Mamba-2 + MLP + Attention hybrid, e.g. NVIDIA-Nemotron-3-Nano-4B),
-    # extend quantization coverage to the lm_head and the input token embedding. On this
-    # architecture those two 131072x3136 tables account for ~21% of parameters, so leaving
-    # them at bf16 wastes most of the NVFP4 memory benefit.
-    if model_type == "nemotron_h":
-        weight_quantizer_cfg = _extract_wildcard_quantizer_cfg(quant_cfg, "weight_quantizer")
-        if weight_quantizer_cfg is not None:
-            # ``input_quantizer_cfg`` is present only for activation-aware recipes (fp8, nvfp4,
-            # ...). For weight-only recipes (nvfp4_w4a16, fp8_pb_wo, ...) this returns None and
-            # ``lm_head`` stays weight-only along with the embedding.
-            input_quantizer_cfg = _extract_wildcard_quantizer_cfg(quant_cfg, "input_quantizer")
-            print(
-                "Nemotron-H detected: extending quantization to lm_head and input embedding "
-                "(backbone.embeddings)."
-            )
-            _enable_lm_head_and_embedding_quantization(
-                quant_cfg,
-                weight_quantizer_cfg,
-                input_quantizer_cfg=input_quantizer_cfg,
-                user_excluded_modules=args.exclude_modules or None,
-            )
-        else:
-            warnings.warn(
-                "Nemotron-H detected but quant_cfg has no wildcard '*weight_quantizer' entry; "
-                "skipping lm_head/embedding extension (model-specific or non-standard recipe)."
-            )
+    # Model-specific quantization extensions (e.g. quantizing lm_head + input embedding for
+    # Nemotron-H, where those tables are a large fraction of parameters and leaving them at
+    # bf16 wastes most of the memory savings) are now expressed as recipes under
+    # ``modelopt_recipes/models/<ModelName>/``. Pass ``--recipe models/<ModelName>/<flavor>``
+    # (e.g. ``--recipe models/Nemotron-H/nvfp4_w4a16``) to opt in.

     if not model_is_already_quantized or calibration_only:
         # quantize the model
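The five-line comment that replaces the deleted block points at the recipe. In the ordered list-of-entries form that the deleted helper manipulated, the recipe's override section amounts to the following sketch (nvfp4_w4a16_weight_cfg is a placeholder for the wildcard weight-quantizer attributes, which the YAML spells out inline; this illustrates the equivalence and is not code that remains in hf_ptq.py):

    import copy

    # Placeholder for the NVFP4 W4A16 weight-quantizer attributes (group_size 16,
    # dynamic e4m3 scales, e2m1 values), matching what the recipe writes inline.
    nvfp4_w4a16_weight_cfg = {
        "num_bits": "e2m1",
        "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": "e4m3"},
    }

    quant_cfg = {"algorithm": "max", "quant_cfg": []}  # ordered list, last match wins
    # ... the broad "*weight_quantizer" enable and the default disables go here ...

    # Re-enable lm_head plus both embedding naming schemes AFTER the default
    # "*lm_head*" disable, so these later entries take precedence.
    for pattern in (
        "*lm_head*weight_quantizer",
        "*embeddings*weight_quantizer",
        "*embed_tokens*weight_quantizer",
    ):
        quant_cfg["quant_cfg"].append(
            {"quantizer_name": pattern, "cfg": copy.deepcopy(nvfp4_w4a16_weight_cfg)}
        )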
modelopt_recipes/models/Nemotron-H/nvfp4_w4a16.yaml

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NVFP4 W4A16 (weight-only) recipe for Nemotron-H hybrid models.
+#
+# Mirrors the general ``nvfp4_w4a16`` qformat (NVFP4_W4A16_CFG) but additionally
+# re-enables quantization of ``lm_head`` and the input token embedding. On
+# Nemotron-3-Nano-4B those two 131072x3136 tables account for ~21% of model
+# parameters, so leaving them at bf16 wastes most of the NVFP4 memory benefit.
+#
+# Coverage:
+#   * Linear layers in attention + MLP: NVFP4 W4A16 weight-only.
+#   * lm_head: NVFP4 W4A16 weight-only (re-enabled here; default disables it).
+#   * Input embedding (``backbone.embeddings`` / ``model.embed_tokens``):
+#     NVFP4 W4A16 weight-only via ``QuantEmbedding``. Embedding inputs are
+#     integer indices, so the input quantizer is intentionally not enabled.
+#   * Mamba ``*mixer.conv1d*``: kept at bf16 (default exclusion).
+#
+# Notes for vLLM consumption:
+#   * ``vllm.compressed-tensors`` consumes packed NVFP4 weights for ``Linear``
+#     and ``Embedding`` layers when the corresponding kernels are present. As
+#     of vLLM 0.19, ``ParallelLMHead`` and ``VocabParallelEmbedding`` need an
+#     additional patch to dispatch ``CompressedTensorsLinearMethod``; see the
+#     PR notes for details. If the target deployment is stock vLLM and you
+#     can't apply that patch, use the general ``nvfp4_w4a16`` qformat
+#     instead, which leaves ``lm_head`` and embeddings at bf16.
+
+metadata:
+  recipe_type: ptq
+  description: NVFP4 W4A16 weight-only for Nemotron-H, including lm_head and input embedding.
+quantize:
+  algorithm: max
+  quant_cfg:
+    # Start with everything disabled, then enable layers explicitly.
+    - quantizer_name: '*'
+      enable: false
+
+    # Quantize all Linear weight quantizers (attention q/k/v/o + MLP up/down).
+    - quantizer_name: '*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+
+    # Standard exclusions copied from ``_default_disabled_quantizer_cfg``.
+    # Order matters: later entries override earlier ones in
+    # ``modelopt.torch.quantization.set_quantizer_by_cfg``.
+    - quantizer_name: '*lm_head*'
+      enable: false
+    - quantizer_name: '*proj_out.*'
+      enable: false
+    - quantizer_name: '*block_sparse_moe.gate*'
+      enable: false
+    - quantizer_name: '*router*'
+      enable: false
+    - quantizer_name: '*mlp.gate.*'
+      enable: false
+    - quantizer_name: '*mlp.shared_expert_gate.*'
+      enable: false
+    - quantizer_name: '*linear_attn.conv1d*'
+      enable: false
+    - quantizer_name: '*mixer.conv1d*'
+      enable: false
+    - quantizer_name: '*output_layer*'
+      enable: false
+    - quantizer_name: 'output.*'
+      enable: false
+    - parent_class: 'nn.BatchNorm1d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm2d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm3d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.LeakyReLU'
+      quantizer_name: '*'
+      enable: false
+
+    # Nemotron-H specific overrides: re-enable the weight quantizer for
+    # ``lm_head`` and the input embedding. These come AFTER the default
+    # disables above so they take precedence (last matching entry wins).
+    - quantizer_name: '*lm_head*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+
+    # Two embedding patterns cover both the Nemotron-H remote-code path
+    # (``backbone.embeddings``) and the standard transformers naming
+    # (``model.embed_tokens``).
+    - quantizer_name: '*embeddings*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*embed_tokens*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
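A quick way to sanity-check that the committed recipe matches the commit-message claims (group_size 16, e4m3 scales, overrides listed after the disables) is to load it with PyYAML and print the enabled entries that carry a cfg block; a standalone sketch, assumed to be run from the repository root:

    import yaml

    with open("modelopt_recipes/models/Nemotron-H/nvfp4_w4a16.yaml") as f:
        recipe = yaml.safe_load(f)

    # Entries keep file order, so the lm_head / embeddings overrides appear last,
    # which is exactly what lets them win over the earlier '*lm_head*' disable.
    for entry in recipe["quantize"]["quant_cfg"]:
        if entry.get("enable") and "cfg" in entry:
            cfg = entry["cfg"]
            print(
                f"{entry['quantizer_name']:35s}"
                f" group_size={cfg['block_sizes'][-1]}"
                f" scale_bits={cfg['block_sizes']['scale_bits']}"
                f" num_bits={cfg['num_bits']}"
            )

    # Expected output: '*weight_quantizer', '*lm_head*weight_quantizer',
    # '*embeddings*weight_quantizer', and '*embed_tokens*weight_quantizer',
    # each with group_size=16, scale_bits=e4m3, num_bits=e2m1.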
