
Commit 50706d1

cjluo-nv and claude authored
Add closed-form MXFP4 -> NVFP4 weight cast (--cast_mxfp4_to_nvfp4) (#1372)
## Summary

- New `--cast_mxfp4_to_nvfp4` flag in `hf_ptq.py` (and `huggingface_example.sh`) that converts an MXFP4 source checkpoint (e.g. `openai/gpt-oss-20b`) into an NVFP4 export with **bit-exact** weight reconstruction for the in-range blocks.
- The cast pins NVFP4's `scale_2 = 2^m` (where `m = k_max − 8`) and `_amax = 6·2^k_j` per NVFP4 block, both read from the source `*_scales`. The resulting per-block scale `2^(k_j − m)` is exactly representable in E4M3, so `round_to_E2M1(value / 2^k_j)` yields the original MXFP4 nibble verbatim. For out-of-range blocks (`k_max − k_j > 17`) the per-block amax falls back to the data-derived `max(|w_block|)`, which keeps the post-E4M3-clamp scale close to the block's actual magnitude.

## Verification

End-to-end on `openai/gpt-oss-20b` with `--qformat=nvfp4_mlp_only --cast_mxfp4_to_nvfp4`:

```
[cast_mxfp4_to_nvfp4] overrode 48/48 weight quantizers
[cast_mxfp4_to_nvfp4] lossless layers: 48/48 (100.00%)
[cast_mxfp4_to_nvfp4] lossless blocks: 597196800/597196800 (100.0000%)
```

End-to-end on `openai/gpt-oss-120b` with the same flags (4×B200, `--use_seq_device_map --gpu_max_mem_percentage 0.5 --calib_batch_size 4`):

```
[cast_mxfp4_to_nvfp4] overrode 72/72 weight quantizers
[cast_mxfp4_to_nvfp4] lossless layers: 67/72 (93.06%)
[cast_mxfp4_to_nvfp4] lossless blocks: 3583179586/3583180800 (100.0000%)
```

Five layers fall into the OOR regime (block spread > 17); their 1,214 OOR blocks use the data-derived per-block amax fallback. Block-level losslessness is **99.99996%** end-to-end.

Per-tensor MSE between the MXFP4 source dequant and the NVFP4 export dequant (~19B elements):

| Metric | Without cast | With cast |
|---|---|---|
| Per-tensor SNR | ~26.4 dB (FP4 noise floor) | **∞ (every tensor)** |
| Total RMSE | 8.67e−02 | **0** |
| max\|err\| | up to 8.0e+1 | **0** |

## Modelopt-side enablers

- `max_calibrate` auto-promotes static-block NVFP4 weight quantizers to `NVFP4StaticQuantizer` at the end of calibration.
- The `static_blockwise_fp4_fake_quant` kernel accepts N-D inputs (was 2D-only), unblocking MoE expert weights of shape `(E, F, K)`.
- BMM-experts NVFP4 export routes through `get_weights_scaling_factor_from_quantizer` for static-mode quantizers, so the pinned `_amax` is actually consumed.
- `set_expert_quantizer_amax` scalar-reduces per-quantizer amax before stacking, supporting per-block (vs scalar) static-mode amax.

## Test plan

- [x] Unit tests at `tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py` (15 tests, all passing) cover: scalar/global-amax math, per-block hybrid (in-range closed-form vs OOR data-derived), shape preservation, key collection, and end-to-end `build_amax_map` against a synthetic safetensors checkpoint.
- [x] End-to-end PTQ → export on `openai/gpt-oss-20b` (`nvfp4_mlp_only` qformat) with `--cast_mxfp4_to_nvfp4` succeeds; export takes ~21 s. 100% lossless cast (48/48 layers, 597,196,800 / 597,196,800 blocks).
- [x] End-to-end PTQ → export on `openai/gpt-oss-120b` (4×B200, `nvfp4_mlp_only`, `--use_seq_device_map --gpu_max_mem_percentage 0.5 --calib_batch_size 4`). 67/72 layers fully lossless; 99.99996% block-level losslessness (3,583,179,586 / 3,583,180,800 blocks).
- [x] TRT-LLM serving validation (TRT-LLM 1.3.0rc11, B200) on both exported NVFP4 checkpoints via `examples/llm_ptq/run_tensorrt_llm.py`:
  - **20b** (TP=1): 18.3 GB GPU memory; coherent generation. Sample: *"Quantum computing is poised to revolutionize data analysis. However, its potential is currently limited by quantum hardware constraints, including error rates, qubit lifetimes, and lack of fault tolerance…"*
  - **120b** (TP=4): 36.4 GB / GPU; coherent generation. Sample: *"Quantum computing is poised to revolutionize data storage and processing. These rare earth-based systems could serve as robust qubits; resistant to environmental decoherence…"*
- [x] MSE comparison script (run separately during development) confirms per-tensor SNR = ∞ across all 48 MoE expert tensors.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added an MXFP4 → NVFP4 weight-format cast utility and a CLI flag to enable it; helper scripts updated to expose the option.
* **Bug Fixes**
  * Fixed static NVFP4 export for expert weights.
  * Improved collection/handling of quantizer amax values to avoid shape issues.
  * Generalized the FP4 kernel to accept flexible tensor dimensionality.
  * Ensured static-block NVFP4 promotion during calibration.
* **Tests**
  * Added comprehensive tests for the conversion workflow, helpers, and end-to-end application.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
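To make the closed-form math in the summary concrete, here is a toy numerical check (NumPy; all names are illustrative and none of this is the code added by the commit). It fabricates MXFP4 blocks, pins `scale_2 = 2^(k_max - 8)` and per-block `_amax = 6 * 2^k_j` from the block exponents, and confirms that the pinned scales reproduce the MXFP4 dequant bit-for-bit. Real NVFP4 uses 16-value sub-blocks and an E4M3-rounded fallback scale for out-of-range blocks; neither changes the scale arithmetic shown here, so the fallback is only noted in a comment.

```python
import numpy as np

# E2M1 (FP4) magnitudes: 1 sign, 2 exponent, 1 mantissa bit.
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def round_to_e2m1(x):
    """Round each magnitude to the nearest E2M1 value, keeping the sign."""
    idx = np.abs(np.abs(x)[..., None] - E2M1_GRID).argmin(axis=-1)
    return np.sign(x) * E2M1_GRID[idx]


# --- fabricate an MXFP4 tensor: one E8M0 exponent k_j per 32-value block ----
rng = np.random.default_rng(0)
k = rng.integers(-6, 4, size=8)                       # block exponents, spread <= 17
codes = rng.choice(E2M1_GRID, size=(8, 32)) * rng.choice([-1.0, 1.0], size=(8, 32))
w = codes * 2.0 ** k[:, None]                         # MXFP4 dequant of the source weights

# --- closed-form NVFP4 parameters ("the cast") -------------------------------
m = k.max() - 8                                       # pin scale_2 = 2^(k_max - 8)
scale_2 = 2.0 ** m
amax = 6.0 * 2.0 ** k                                 # per-block _amax = 6 * 2^k_j

# The per-block scale implied by the pinned amax is amax_j / (6 * scale_2)
# = 2^(k_j - m).  Since k_j - m = k_j - k_max + 8 lies in [-9, 8] whenever
# k_max - k_j <= 17, this power of two is exactly representable in E4M3.
# A larger spread would instead use the data-derived fallback
# amax = max(|w_block|), which is not exercised in this toy example.
s_block = amax / (6.0 * scale_2)

# --- NVFP4 fake-quant / dequant with the pinned scales -----------------------
q = round_to_e2m1(w / (s_block[:, None] * scale_2))   # recovers the E2M1 codes
w_nvfp4 = q * s_block[:, None] * scale_2

assert np.array_equal(q, codes)                       # nibble values match verbatim
assert np.array_equal(w_nvfp4, w)                     # dequant is bit-exact
print(f"lossless blocks: {len(k)}/{len(k)}")
```

Pinning `scale_2` to a power of two is what keeps the implied per-block scale `2^(k_j - m)` exact in E4M3; a data-calibrated `scale_2` would generally reintroduce rounding in both the scale and the E2M1 codes.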
1 parent 168cd82 commit 50706d1

11 files changed

Lines changed: 974 additions & 25 deletions
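For reference, a minimal sketch of the per-tensor SNR and RMSE metrics reported in the verification table above (a hypothetical helper, not the comparison script used for this PR):

```python
import torch


def per_tensor_snr_rmse(ref: torch.Tensor, test: torch.Tensor) -> tuple[float, float]:
    """SNR (dB) and RMSE of `test` against the reference dequant `ref`."""
    err = (test - ref).float()
    mse = err.pow(2).mean()
    rmse = mse.sqrt().item()
    if rmse == 0.0:
        return float("inf"), 0.0          # bit-exact cast: SNR = inf, RMSE = 0
    snr_db = 10.0 * torch.log10(ref.float().pow(2).mean() / mse)
    return snr_db.item(), rmse
```

With a bit-exact cast the error tensor is identically zero, which is how the table's SNR = ∞ and RMSE = 0 entries arise; the ~26.4 dB figure without the cast corresponds to the FP4 noise floor the table mentions.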


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ Changelog
 **New Features**
 
 - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|<list>``) and optional assistant-token ``loss_mask`` for answer-only-loss training.
+- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
 
 0.44 (2026-05-xx)
 ^^^^^^^^^^^^^^^^^

examples/llm_ptq/README.md

Lines changed: 19 additions & 1 deletion
@@ -114,6 +114,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 | GLM-4.7<sup>8</sup> || - | - | - ||
 | Kimi K2 | - | - | - | - ||
 | MiniMax M2.1 | - | - | - | - ||
+| GPT-OSS<sup>10</sup> | - | - | - | - ||
 | T5 ||||| - |
 | Whisper<sup>9</sup> ||||| - |
 | Nemotron-3 ||||||
@@ -128,7 +129,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
 > *<sup>6.</sup>Some models currently support export to HF format only.* \
 > *<sup>7.</sup>[PTQ for DeepSeek](../deepseek/README.md)* \
 > *<sup>8.</sup>GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.* \
-> *<sup>9.</sup>Running Whisper model with transformers>=5.0 requires [torchcodec](https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-cuda-enabled-torchcodec) and other system packages (e.g. ffmpeg).*
+> *<sup>9.</sup>Running Whisper model with transformers>=5.0 requires [torchcodec](https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-cuda-enabled-torchcodec) and other system packages (e.g. ffmpeg).* \
+> *<sup>10.</sup>GPT-OSS ships with native MXFP4 weights; NVFP4 export is produced via the closed-form `--cast_mxfp4_to_nvfp4` cast (see [MXFP4 → NVFP4 cast](#mxfp4--nvfp4-cast-for-gpt-oss)).*
 
 > *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead. For NVFP4 quantization specifically, we recommend `nvfp4_mlp_only`, `nvfp4_experts_only`, or `nvfp4_omlp_only` to achieve higher accuracy by restricting quantization to the MLP/expert layers (and optionally the `o_proj` layer) while keeping the attention QKV projections unquantized.*
@@ -221,6 +223,22 @@ Available KV cache formats:
 
 > *Formats ending in `_cast` (fp8_cast, nvfp4_cast) are fast — they set the amax to the format's full range without data-driven calibration. Other formats use data-driven calibration for potentially better accuracy.*
 
+#### MXFP4 → NVFP4 cast (for GPT-OSS)
+
+GPT-OSS checkpoints (`openai/gpt-oss-20b`, `openai/gpt-oss-120b`) ship with native MXFP4 weights (`*_blocks` + `*_scales` in the checkpoint, `quantization_config.quant_method == "mxfp4"`). Passing `--cast_mxfp4_to_nvfp4` tells `hf_ptq.py` to read the source MXFP4 scales and produce a closed-form, bit-exact NVFP4 weight export — no GEMM-level recalibration of the weights needed.
+
+```bash
+python hf_ptq.py \
+    --pyt_ckpt_path openai/gpt-oss-20b \
+    --qformat nvfp4_mlp_only \
+    --cast_mxfp4_to_nvfp4 \
+    --export_path <quantized_ckpt_path>
+```
+
+The cast pins each NVFP4 block's `scale_2 = 2^(k_max - 8)` and `_amax = 6 * 2^k_j`, both derived from the source MXFP4 E8M0 scales. For blocks whose `k_j` lands in E4M3's representable window (`k_max - k_j ≤ 17`), NVFP4 dequant matches MXFP4 dequant bit-for-bit; out-of-range blocks fall back to a data-derived per-block amax.
+
+> *`--cast_mxfp4_to_nvfp4` requires an NVFP4-family `--qformat` (e.g. `nvfp4_mlp_only`, `nvfp4_experts_only`, `nvfp4`) and is incompatible with `--auto_quantize_bits`.*
+
 #### Deepseek R1
 
 [PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM.
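The README section added above describes the source layout GPT-OSS ships (`*_blocks` + `*_scales`). As background, a hedged dequant sketch of that layout follows; the packing assumptions (two E2M1 nibbles per uint8, low nibble first, one E8M0 scale per 32-value block with a bias of 127) and the file/tensor names in the usage comment are illustrative and are not defined by this commit.

```python
import numpy as np

# E2M1 magnitude for each of the 8 non-sign nibble patterns.
E2M1_LUT = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)


def dequant_mxfp4(blocks: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """blocks: uint8 [..., n_blocks, 16] (two E2M1 nibbles per byte);
    scales: uint8 [..., n_blocks] (E8M0 exponent, biased by 127)."""
    lo, hi = blocks & 0x0F, blocks >> 4
    nibbles = np.stack([lo, hi], axis=-1).reshape(*blocks.shape[:-1], 32)
    sign = np.where(nibbles & 0x8, -1.0, 1.0).astype(np.float32)
    vals = sign * E2M1_LUT[nibbles & 0x7]
    k = scales.astype(np.int32) - 127                 # scale = 2^(e - 127) per block
    return vals * np.exp2(k).astype(np.float32)[..., None]


# Illustrative usage (hypothetical file and tensor names):
#   from safetensors.numpy import load_file
#   sd = load_file("gpt-oss-20b/model-00001-of-00002.safetensors")
#   w = dequant_mxfp4(sd["model.layers.0.mlp.experts.down_proj_blocks"],
#                     sd["model.layers.0.mlp.experts.down_proj_scales"])
```

Note that for in-range blocks the closed-form cast only needs the `*_scales` exponents; the dequantized weights are consulted only when deriving the out-of-range fallback amax.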
