Add nvfp4_omlp_only config and simplify the config.py (#973)
### What does this PR do?
Type of change: new feature

1) Add nvfp4_omlp_only config == nvfp4_mlp_only + o_proj quant
2) Include block-sparse MoE layers in the MLP-only config
3) Simplify config.py
4) Update the README in llm_ptq to mention these two configs for better accuracy.
### Usage
`huggingface_script.sh ... --quant nvfp4_omlp_only`
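For reference, a minimal sketch of the equivalent flow through the Python API, assuming the new config is exposed as `mtq.NVFP4_OMLP_ONLY_CFG` (as the README update below suggests); the model card and calibration prompts here are placeholders:

```python
# Sketch only: quantize a Hugging Face model with the new o_proj + MLP NVFP4 config.
# Assumes mtq.NVFP4_OMLP_ONLY_CFG is exported by this PR; model/prompts are illustrative.
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model card
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

def forward_loop(model):
    # Small calibration pass; real usage would iterate over a calibration dataset.
    for prompt in ["Hello, world!", "Calibration sample for NVFP4 PTQ."]:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        model(**inputs)

model = mtq.quantize(model, mtq.NVFP4_OMLP_ONLY_CFG, forward_loop)
```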
### Before your PR is "*Ready for review*"
Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (`git commit -s -S`).
Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded `trust_remote_code=True`, using
`torch.load(..., weights_only=True)`, avoiding `pickle`, etc.).
- Is this change backward compatible?: ✅ / ❌ / N/A <!--- If ❌, explain
why. -->
- If you copied code from any other source, did you follow IP policy in
[CONTRIBUTING.md](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md#-copying-code-from-other-sources)?:
✅ / ❌ / N/A <!--- Mandatory -->
- Did you write any new necessary tests?: ✅ / ❌ / N/A <!--- Mandatory
for new features or examples. -->
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
✅ / ❌ / N/A <!--- Only for new features, API changes, critical bug fixes
or backward incompatible changes. -->
### Additional Information
<!-- E.g. related issue. -->
## Summary by CodeRabbit
* **New Features**
* Added nvfp4_omlp_only quantization format for NVFP4, enabling
selective quantization of MLP and output projection layers while
preserving attention QKV projection accuracy.
* **Changed**
* `pass_through_bwd` now defaults to True; set it to False to use STE with zeroed outlier gradients for potentially better QAT accuracy (see the override sketch after this summary).
* **Documentation**
* Updated post-training quantization guidance with NVFP4-specific
configuration recommendations and usage examples.
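The `pass_through_bwd` default change mainly matters for QAT. Below is a minimal sketch of opting back into STE with zeroed outlier gradients, assuming `pass_through_bwd` sits alongside the other per-quantizer fields in the config's `quant_cfg` entries (the exact key layout is an assumption, not confirmed by this PR):

```python
# Sketch only: override the new pass_through_bwd=True default before QAT.
# Assumption: pass_through_bwd is a per-quantizer field inside quant_cfg entries.
import copy

import modelopt.torch.quantization as mtq

qat_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
for pattern, quantizer_cfg in qat_cfg["quant_cfg"].items():
    if isinstance(quantizer_cfg, dict) and quantizer_cfg.get("enable", True):
        quantizer_cfg["pass_through_bwd"] = False  # STE with zeroed outlier gradients
# qat_cfg can then be passed to mtq.quantize(model, qat_cfg, forward_loop) before QAT.
```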
---------
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
**CHANGELOG.rst** (2 additions, 0 deletions)

@@ -17,6 +17,8 @@ NVIDIA Model Optimizer Changelog
 - Add support for rotating the input before quantization for RHT.
 - Add support for advanced weight scale search for NVFP4 quantization and its export path.
 - Enable PTQ workflow for Qwen3.5 MoE models.
+- Add ``nvfp4_omlp_only`` quantization format for NVFP4 quantization. This is similar to ``nvfp4_mlp_only`` but also quantizes the output projection layer in attention.
+- ``pass_through_bwd`` in the quantization config now defaults to True. Please set it to False if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy.
**examples/llm_ptq/README.md** (6 additions, 4 deletions)

@@ -69,6 +69,8 @@ def forward_loop(model):
 model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)
 ```

+> *For higher NVFP4 PTQ accuracy, we recommend using `mtq.NVFP4_MLP_ONLY_CFG` or `mtq.NVFP4_OMLP_ONLY_CFG` instead of `mtq.NVFP4_DEFAULT_CFG`. `NVFP4_MLP_ONLY_CFG` applies NVFP4 quantization to MLP (and MoE) layers, leaving attention layers unquantized. `NVFP4_OMLP_ONLY_CFG` additionally quantizes the `o_proj` layer. Both preserve accuracy in the sensitive attention QKV projections while still providing significant compression.*
+
 ### 2. Export Quantized Model

 Once your model is quantized, you can now export that model to a checkpoint for easy deployment. \

@@ ... @@
 > *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)*\
 > *<sup>8.</sup>GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.*

-> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead.*
+> *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead. For NVFP4 quantization specifically, we recommend `nvfp4_mlp_only` or `nvfp4_omlp_only` to achieve higher accuracy by restricting quantization to the MLP layers (and optionally the `o_proj` layer) while keeping the attention QKV projections unquantized.*

 > You can also create your own custom config using [this](https://nvidia.github.io/Model-Optimizer/guides/_pytorch_quantization.html#custom-calibration-algorithm) guide.

@@ -144,7 +146,7 @@ For LLM models like [Llama-3](https://huggingface.co/meta-llama):
 # Install model specific pip dependencies if needed

 export HF_PATH=<the downloaded LLaMA checkpoint from the Hugging Face hub, or simply the model card>

@@ -460,4 +462,4 @@ There are many quantization schemes supported in the example scripts:

 1. The W4A8 AWQ is an extension of the INT4 AWQ quantization that it also uses FP8 for activation for more speed up and acceleration.

-1. The [NVFP4](https://blogs.nvidia.com/blog/generative-ai-studio-ces-geforce-rtx-50-series/) is one of the new FP4 formats supported by NVIDIA Blackwell GPU and demonstrates good accuracy compared with other 4-bit alternatives. NVFP4 can be applied to both model weights as well as activations, providing the potential for both a significant increase in math throughput and reductions in memory footprint and memory bandwidth usage compared to the FP8 data format on Blackwell.
+1. The [NVFP4](https://blogs.nvidia.com/blog/generative-ai-studio-ces-geforce-rtx-50-series/) is one of the new FP4 formats supported by NVIDIA Blackwell GPU and demonstrates good accuracy compared with other 4-bit alternatives. NVFP4 can be applied to both model weights as well as activations, providing the potential for both a significant increase in math throughput and reductions in memory footprint and memory bandwidth usage compared to the FP8 data format on Blackwell. For higher accuracy with NVFP4 PTQ, we recommend `nvfp4_mlp_only` or `nvfp4_omlp_only`. `nvfp4_mlp_only` restricts NVFP4 quantization to MLP (and MoE) layers only, leaving attention layers in higher precision. `nvfp4_omlp_only` extends this by also quantizing the `o_proj` layer, providing a middle ground between full NVFP4 and MLP-only quantization.
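To make the recommendation above concrete, here is a minimal sketch of a custom config in the same spirit, assuming typical Hugging Face attention module names (`q_proj`/`k_proj`/`v_proj`/`o_proj`); the wildcard patterns are illustrative, not the exact contents of the shipped `NVFP4_MLP_ONLY_CFG` or `NVFP4_OMLP_ONLY_CFG`:

```python
# Sketch only: a custom config in the spirit of nvfp4_mlp_only / nvfp4_omlp_only.
# The wildcard patterns below assume common HF module names and are illustrative,
# not the exact patterns used by the shipped configs.
import copy

import modelopt.torch.quantization as mtq

custom_cfg = copy.deepcopy(mtq.NVFP4_DEFAULT_CFG)
custom_cfg["quant_cfg"].update({
    "*q_proj*": {"enable": False},  # keep attention QKV projections in high precision
    "*k_proj*": {"enable": False},
    "*v_proj*": {"enable": False},
    "*o_proj*": {"enable": False},  # drop this entry for the omlp_only-style variant
})
# custom_cfg can then be passed to mtq.quantize(model, custom_cfg, forward_loop).
```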