Commit 50706d1
Add closed-form MXFP4 -> NVFP4 weight cast (--cast_mxfp4_to_nvfp4) (#1372)
## Summary
- New `--cast_mxfp4_to_nvfp4` flag in `hf_ptq.py` (and
`huggingface_example.sh`) that converts an MXFP4 source checkpoint (e.g.
`openai/gpt-oss-20b`) into an NVFP4 export with **bit-exact** weight
reconstruction for the in-range blocks.
- The cast pins NVFP4's `scale_2 = 2^m` (where `m = k_max − 8`) and
`_amax = 6·2^k_j` per NVFP4 block, both read from the source `*_scales`.
The resulting per-block scale `2^(k_j − m)` is exactly representable in
E4M3, so `round_to_E2M1(value / 2^k_j)` yields the original MXFP4 nibble
verbatim. For out-of-range blocks (`k_max − k_j > 17`) the per-block
amax falls back to data-derived `max(|w_block|)`, which keeps the
post-E4M3-clamp scale close to the block's actual magnitude.
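For concreteness, here is a minimal sketch of the closed-form math in the bullet above (the helper name `closed_form_nvfp4_amax` is illustrative, not the PR's API; it assumes modelopt's NVFP4 convention `block_scale = amax / (6 · scale_2)` and source E8M0 exponents already decoded to integer `k_j`):
```python
import torch

E2M1_MAX = 6.0      # largest magnitude representable in FP4 E2M1
E4M3_MIN_EXP = -9   # smallest power of two representable in E4M3 (subnormal 2^-9)


def closed_form_nvfp4_amax(k: torch.Tensor, w_blocks: torch.Tensor):
    """Sketch: derive pinned NVFP4 scales from MXFP4 block exponents.

    k        : per-block MXFP4 exponents k_j (int tensor), decoded from the
               source ``*_scales`` (E8M0, i.e. k_j = raw_byte - 127).
    w_blocks : dequantized source weights, shape (num_blocks, 32); used only
               for the out-of-range fallback.
    """
    m = int(k.max()) - 8                            # pin scale_2 = 2^m
    scale_2 = torch.tensor(2.0**m)
    amax = E2M1_MAX * torch.pow(2.0, k.float())     # 6 * 2^k_j per block
    # block_scale = amax / (6 * scale_2) = 2^(k_j - m); this underflows E4M3
    # when k_j - m < -9, i.e. k_max - k_j > 17, so fall back to data there.
    oor = (k.max() - k) > (8 - E4M3_MIN_EXP)
    amax[oor] = w_blocks[oor].abs().amax(dim=-1)
    return scale_2, amax
```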
## Verification
End-to-end on `openai/gpt-oss-20b` with `--qformat=nvfp4_mlp_only
--cast_mxfp4_to_nvfp4`:
```
[cast_mxfp4_to_nvfp4] overrode 48/48 weight quantizers
[cast_mxfp4_to_nvfp4] lossless layers: 48/48 (100.00%)
[cast_mxfp4_to_nvfp4] lossless blocks: 597196800/597196800 (100.0000%)
```
End-to-end on `openai/gpt-oss-120b` with the same flags (4×B200,
`--use_seq_device_map --gpu_max_mem_percentage 0.5 --calib_batch_size
4`):
```
[cast_mxfp4_to_nvfp4] overrode 72/72 weight quantizers
[cast_mxfp4_to_nvfp4] lossless layers: 67/72 (93.06%)
[cast_mxfp4_to_nvfp4] lossless blocks: 3583179586/3583180800 (100.0000%)
```
Five layers fall into the OOR regime (block spread `k_max − k_j > 17`); their
1,214 OOR blocks use the data-derived per-block amax fallback.
End-to-end block-level losslessness is **99.99996%**.
Per-tensor comparison between the MXFP4 source dequant and the NVFP4 export
dequant (~19B elements):
| Metric | Without cast | With cast |
|---|---|---|
| Per-tensor SNR | ~26.4 dB (FP4 noise floor) | **∞ (every tensor)** |
| Total RMSE | 8.67e−02 | **0** |
| max\|err\| | up to 8.0e+1 | **0** |
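The SNR row uses the standard `10·log10(‖ref‖² / ‖err‖²)` definition, so a bit-exact cast yields zero error and hence infinite SNR. A small sketch of such a comparison (the helper name is illustrative, not part of the PR):
```python
import torch


def per_tensor_snr_db(ref: torch.Tensor, test: torch.Tensor) -> float:
    """SNR in dB between the MXFP4-source dequant (ref) and the NVFP4-export
    dequant (test); a lossless cast gives err == 0 and therefore SNR = inf."""
    err = (test - ref).double()
    noise = err.pow(2).sum()
    if noise == 0:
        return float("inf")
    return (10.0 * torch.log10(ref.double().pow(2).sum() / noise)).item()
```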
## Modelopt-side enablers
- `max_calibrate` auto-promotes static-block NVFP4 weight quantizers to
`NVFP4StaticQuantizer` at the end of calibration.
- `static_blockwise_fp4_fake_quant` kernel accepts N-D inputs (was
2D-only), unblocking MoE expert weights of shape `(E, F, K)`; see the
reshape sketch after this list.
- BMM-experts NVFP4 export routes through
`get_weights_scaling_factor_from_quantizer` for static-mode quantizers,
so the pinned `_amax` is actually consumed.
- `set_expert_quantizer_amax` scalar-reduces per-quantizer amax before
stacking, supporting per-block (vs scalar) static-mode amax.
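One plausible shape for the N-D generalization referenced in the list (an illustrative wrapper over a 2D fake-quant callable, not the actual kernel change): flatten the leading dims down to the 2D case, then restore the input shape.
```python
import torch


def fake_quant_nd(w: torch.Tensor, fake_quant_2d, block_size: int = 16):
    """Sketch: run a 2D-only blockwise fake-quant on N-D inputs such as MoE
    expert weights of shape (E, F, K). Assumes blocking runs along the last
    dim, which must be divisible by block_size."""
    assert w.shape[-1] % block_size == 0
    out = fake_quant_2d(w.reshape(-1, w.shape[-1]), block_size)
    return out.reshape(w.shape)
```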
## Test plan
- [x] Unit tests at `tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py`
(15 tests, all passing) cover: scalar/global-amax math, per-block hybrid
(in-range closed-form vs. OOR data-derived), shape preservation, key
collection, and end-to-end `build_amax_map` against a synthetic
safetensors checkpoint; a sketch of such a round-trip check follows this
list.
- [x] End-to-end PTQ → export on `openai/gpt-oss-20b` (`nvfp4_mlp_only`
qformat) with `--cast_mxfp4_to_nvfp4` succeeds; export takes ~21 s. 100%
lossless cast (48/48 layers, 597,196,800 / 597,196,800 blocks).
- [x] End-to-end PTQ → export on `openai/gpt-oss-120b` (4×B200,
`nvfp4_mlp_only`, `--use_seq_device_map --gpu_max_mem_percentage 0.5
--calib_batch_size 4`). 67/72 layers fully lossless; 99.99996%
block-level losslessness (3,583,179,586 / 3,583,180,800).
- [x] TRT-LLM serving validation (TRT-LLM 1.3.0rc11, B200) on both
exported NVFP4 checkpoints via `examples/llm_ptq/run_tensorrt_llm.py`:
- **20b** (TP=1): 18.3 GB GPU memory; coherent generation. Sample:
*"Quantum computing is poised to revolutionize data analysis. However,
its potential is currently limited by quantum hardware constraints,
including error rates, qubit lifetimes, and lack of fault tolerance…"*
- **120b** (TP=4): 36.4 GB / GPU; coherent generation. Sample: *"Quantum
computing is poised to revolutionize data storage and processing. These
rare earth-based systems could serve as robust qubits; resistant to
environmental decoherence…"*
- [x] MSE comparison script (run separately during development) confirms
per-tensor SNR=∞ across all 48 MoE expert tensors.
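For flavor, a round-trip property test for in-range blocks could look like the following (names are hypothetical and reuse the `closed_form_nvfp4_amax` sketch from the Summary; this is not the PR's actual test file):
```python
import torch

E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def round_to_e2m1(x: torch.Tensor) -> torch.Tensor:
    # Nearest E2M1 magnitude (signs omitted for brevity).
    idx = (x.unsqueeze(-1) - E2M1_GRID).abs().argmin(dim=-1)
    return E2M1_GRID[idx]


def test_in_range_blocks_round_trip():
    torch.manual_seed(0)
    k = torch.randint(-5, 9, (64,))                    # spread <= 13 < 17: in-range
    nib = E2M1_GRID[torch.randint(0, 8, (64, 32))]     # MXFP4 nibble magnitudes
    w = nib * torch.pow(2.0, k.float()).unsqueeze(-1)  # source dequant
    scale_2, amax = closed_form_nvfp4_amax(k, w)       # sketch from the Summary
    block_scale = amax / (6.0 * scale_2)               # = 2^(k_j - m), E4M3-exact
    recon = round_to_e2m1(w / (block_scale.unsqueeze(-1) * scale_2)) \
        * block_scale.unsqueeze(-1) * scale_2
    assert torch.equal(recon, w)                       # bit-exact round trip
```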
🤖 Generated with [Claude Code](https://claude.com/claude-code)
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **New Features**
* Added a MXFP4→NVFP4 weight-format cast utility and a CLI flag to
enable it; helper scripts updated to expose the option.
* **Bug Fixes**
* Fixed static NVFP4 export for expert weights.
* Improved collection/handling of quantizer amax values to avoid shape
issues.
* Generalized FP4 kernel to accept flexible tensor dimensionality.
* Ensured static-block NVFP4 promotion during calibration.
* **Tests**
* Added comprehensive tests for the conversion workflow, helpers, and
end-to-end application.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
11 files changed: 974 additions, 25 deletions. Touched paths:
`examples/llm_ptq` (incl. `scripts`), `modelopt/torch/export`,
`modelopt/torch/kernels/quantization/gemm`, `modelopt/torch/quantization`,
`tests/examples/llm_ptq`.