Commit f5d57cb

[DeepSeek] Default to top-k calibration with peer-max input amax sync
Previously DeepSeek PTQ forced every token through every MoE expert during calibration via CalibMoe. This doubled the calibration forward pass and exposed cold-routing experts to outliers they would never see at inference, inflating their input_quantizer.amax.

Default now uses native top-k routing and runs fixup_moe_expert_amax after mtq.quantize. For each MoE layer x linear (w1/w2/w3), every expert's input_quantizer.amax is synced to the per-layer global peer max (dist.all_reduce(MAX) across EP ranks). weight_quantizer.amax stays per-expert; uncalibrated experts fall back to a compute path over the dequantized FP8 weight. The previous behavior is preserved behind --calib_all_experts.

Also write mtq.print_quant_summary output to <output_path>/.quant_summary.txt to mirror llm_ptq/hf_ptq.py.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent 3ad4f4f commit f5d57cb
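
For orientation, here is a minimal standalone sketch of the peer-max input-amax sync described in the commit message. The helper name `sync_input_amax_peer_max` and the toy tensors are illustrative only; the commit's actual implementation is `fixup_moe_expert_amax` in `examples/deepseek/ptq.py`, shown in the diff below.

```python
import torch
import torch.distributed as dist


def sync_input_amax_peer_max(amaxes: list[torch.Tensor | None]) -> list[torch.Tensor | None]:
    """Sync per-expert input amaxes to a single per-layer peer max.

    `amaxes` holds one input_quantizer.amax per local expert for a single
    (MoE layer, linear) pair, e.g. all w1 input amaxes of one layer. Entries
    may be None or all-zero for experts that saw no tokens under top-k routing.
    """
    valid = [a.float() for a in amaxes if a is not None and not torch.all(a == 0)]
    if not valid:
        return amaxes  # no expert on this rank was calibrated for this linear
    peer_max = torch.stack(valid).amax(dim=0)  # elementwise max over local experts
    if dist.is_available() and dist.is_initialized():
        # Make the max global across EP ranks so remote experts contribute too.
        dist.all_reduce(peer_max, op=dist.ReduceOp.MAX)
    return [peer_max.clone() for _ in amaxes]


if __name__ == "__main__":
    # Single-process toy example: two calibrated experts, one cold expert.
    print(sync_input_amax_peer_max([torch.tensor(1.5), torch.tensor(2.0), None]))
    # -> every expert now sees amax == 2.0
```

Weight amaxes are deliberately left out of this sync; as the diff below shows, they stay per-expert.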

3 files changed

Lines changed: 141 additions & 10 deletions


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ Changelog
 **New Features**
 
 - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|<list>``) and optional assistant-token ``loss_mask`` for answer-only-loss training.
+- DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
 
 0.44 (2026-05-xx)
 ^^^^^^^^^^^^^^^^^

examples/deepseek/README.md

Lines changed: 20 additions & 0 deletions
@@ -63,6 +63,26 @@ DeepSeek V3.2
 torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path $DS_CKPT --config DeepSeek-V3.2-Exp/inference/config_671B_v3.2.json --quant_cfg NVFP4_DEFAULT_CFG --output_path $FP4_QUANT_PATH
 ```
 
+#### MoE expert calibration
+
+By default, calibration uses the model's native top-k routing and then runs a
+post-calibration sync that sets every expert's `input_quantizer.amax` (w1/w2/w3)
+to the per-layer global peer max (all-reduced across EP ranks).
+`weight_quantizer.amax` stays per-expert; any uncalibrated expert falls back to
+a compute path over the dequantized FP8 weight. This mirrors the
+`layer_sync_moe_local_experts_amax` flow that mtq runs automatically for
+QuantSequentialMLP-derived MoEs.
+
+To restore the original behavior — force every token through every expert
+during calibration (slower, ~2x forwards, no post-calibration sync) — pass
+`--calib_all_experts`:
+
+```bash
+torchrun --nproc-per-node 8 --master_port=12346 ptq.py --model_path $DS_CKPT --config DeepSeek-V3.2-Exp/inference/config_671B_v3.2.json --quant_cfg NVFP4_DEFAULT_CFG --output_path $FP4_QUANT_PATH --calib_all_experts
+```
+
+A summary of every TensorQuantizer is written to `$FP4_QUANT_PATH/.quant_summary.txt`.
+
 ### Quantize the FP8 hf checkpoint to FP4
 
 We provide a one-step-script which will:
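
A note on the "compute path over the dequantized FP8 weight" mentioned in the README hunk above: it amounts to an abs-max reduction over every dimension except the axes the weight quantizer keeps. A minimal plain-torch sketch follows; the helper name `fallback_weight_amax` is illustrative and not part of the commit, which instead uses `weight_dequant` plus `reduce_amax` as shown in the `ptq.py` diff below.

```python
import torch


def fallback_weight_amax(
    weight: torch.Tensor, axis: int | tuple[int, ...] | None
) -> torch.Tensor:
    """Abs-max amax over an already-dequantized expert weight.

    `axis` lists the dimensions to KEEP (e.g. the output-channel dim for
    per-channel scales); every other dimension is reduced away. None means
    a single per-tensor amax.
    """
    w = weight.detach().abs()
    if axis is None:
        return w.max().to(torch.float32)
    keep = {a % w.dim() for a in ((axis,) if isinstance(axis, int) else axis)}
    reduce_dims = tuple(d for d in range(w.dim()) if d not in keep)
    return w.amax(dim=reduce_dims).to(torch.float32)


if __name__ == "__main__":
    # Stand-in for an expert weight already dequantized from FP8 to bf16.
    w = torch.randn(8, 16, dtype=torch.bfloat16)
    print(fallback_weight_amax(w, axis=None).shape)  # per-tensor: torch.Size([])
    print(fallback_weight_amax(w, axis=0).shape)  # per-output-channel: torch.Size([8])
```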

examples/deepseek/ptq.py

Lines changed: 120 additions & 10 deletions
@@ -61,6 +61,7 @@
     is_quantized_column_parallel_linear,
     is_quantized_parallel_linear,
     is_quantized_row_parallel_linear,
+    reduce_amax,
 )
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 from modelopt.torch.utils.distributed import ParallelState
@@ -81,7 +82,7 @@
 from kernel import act_quant, fp8_gemm  # noqa: E402
 
 
-def monkey_patch_deepseek_model():
+def monkey_patch_deepseek_model(calib_all_experts: bool = False):
     gemm_impl: Literal["bf16", "fp8"] = "bf16"
     block_size = 128
 
@@ -199,6 +200,10 @@ def _setup(self):
             self.pe_bmm_quantizer = TensorQuantizer()
 
     class CalibMoe(deekseep_model.MoE):
+        """MoE override that forces every token through every expert during
+        calibration. Slower (~2x forwards) but every expert sees the full token
+        distribution, so per-expert input amaxes are calibrated directly."""
+
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
             self._setup()
@@ -208,14 +213,11 @@ def _setup(self):
             self._original_topk_groups = self.gate.topk_groups
 
         def forward(self, x: torch.Tensor) -> torch.Tensor:
-            # Forward all tokens to all experts for calibration
            self.gate.topk = self.n_routed_experts
             self.gate.topk_groups = self.gate.n_groups
             super().forward(x)
-            # Restore the original topk and topk_groups
             self.gate.topk = self._original_topk
             self.gate.topk_groups = self._original_topk_groups
-
             return super().forward(x)
 
     mtq.register(
@@ -228,10 +230,86 @@
     )
     mtq.register(original_cls=deekseep_model.Linear, quantized_cls=QuantLinear)
     mtq.register(original_cls=deekseep_model.MLA, quantized_cls=QuantMLA)
-    mtq.register(original_cls=deekseep_model.MoE, quantized_cls=CalibMoe)
+    if calib_all_experts:
+        mtq.register(original_cls=deekseep_model.MoE, quantized_cls=CalibMoe)
+
+
+def _expert_linear_names() -> list[str]:
+    return ["w1", "w2", "w3"]
+
+
+def fixup_moe_expert_amax(transformer):
+    """Post-calibration amax sweep for MoE experts.
+
+    * ``input_quantizer`` (w1/w2/w3): every expert in a layer takes the per-layer
+      peer max, all-reduced across ranks so the max is global across all experts.
+    * ``weight_quantizer``: kept per-expert. Calibrated values are preserved;
+      missing values are filled by computing amax over the dequantized FP8 weight.
+    """
+    world_size = dist.get_world_size() if dist.is_initialized() else 1
+    synced_input = fixed_weight = 0
+
+    def _missing(amax):
+        return amax is None or torch.all(amax == 0)
+
+    for module in transformer.modules():
+        if not isinstance(module, deekseep_model.MoE):
+            continue
+        local_experts = [
+            module.experts[i]
+            for i in range(module.experts_start_idx, module.experts_end_idx)
+            if module.experts[i] is not None
+        ]
 
+        for linear_name in _expert_linear_names():
+            linears = [getattr(e, linear_name) for e in local_experts]
 
-def load_deepseek_model(model_config: str, model_path: str, batch_size: int):
+            qs = [
+                lin.input_quantizer
+                for lin in linears
+                if lin.input_quantizer.is_enabled
+                and not getattr(lin.input_quantizer, "_dynamic", False)
+            ]
+            valid = [q.amax.float() for q in qs if not _missing(q.amax)]
+            if valid:
+                m = torch.stack(valid).amax(dim=0)
+                if world_size > 1:
+                    dist.all_reduce(m, op=dist.ReduceOp.MAX)
+                for q in qs:
+                    q.amax = m.clone()
+                synced_input += 1
+
+            for lin in linears:
+                wq = lin.weight_quantizer
+                if not wq.is_enabled or getattr(wq, "_dynamic", False):
+                    continue
+                if not _missing(wq.amax):
+                    continue
+                w = lin.weight
+                if w.is_meta:
+                    continue
+                # DeepSeek stores experts as FP8 with a per-block .scale; dequantize
+                # to bf16 first so we measure the real weight distribution, not bytes.
+                deq = weight_dequant(w, w.scale, torch.bfloat16) if w.element_size() == 1 else w
+                axis = getattr(wq, "_axis", None)
+                if axis is None:
+                    reduce_axis = None
+                else:
+                    axis = (axis,) if isinstance(axis, int) else axis
+                    keep = {a % deq.dim() for a in axis}
+                    reduce_axis = tuple(d for d in range(deq.dim()) if d not in keep)
+                wq.amax = reduce_amax(deq.detach(), axis=reduce_axis).to(torch.float32)
+                fixed_weight += 1
+
+    return synced_input, fixed_weight
+
+
+def load_deepseek_model(
+    model_config: str,
+    model_path: str,
+    batch_size: int,
+    calib_all_experts: bool = False,
+):
     """Loads the deepseek model to memory."""
     # get distributed info
     world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -252,7 +330,7 @@ def load_deepseek_model(model_config: str, model_path: str, batch_size: int):
     model = deekseep_model.Transformer(model_args)
 
     # monkey path the model definition for quantization
-    monkey_patch_deepseek_model()
+    monkey_patch_deepseek_model(calib_all_experts=calib_all_experts)
 
     # load model
     checkpoint_path = os.path.join(model_path, f"model{rank}-mp{world_size}.safetensors")
@@ -280,6 +358,8 @@ def ptq(
     batch_size: int,
     calib_size: int,
     mla_quant: str | None = None,
+    calib_all_experts: bool = False,
+    output_path: str | None = None,
 ):
     """Runs Deepseek model PTQ and returns the quantized model."""
 
@@ -384,8 +464,18 @@ def calibrate_loop(model):
     ## ptq
     transformer = mtq.quantize(transformer, mtq_cfg, calibrate_loop)
 
+    if not calib_all_experts:
+        synced_input, fixed_weight = fixup_moe_expert_amax(transformer)
+        if int(os.environ["LOCAL_RANK"]) == 0:
+            print(
+                f"Synced peer-max for {synced_input} expert input_quantizer(s) "
+                f"and computed {fixed_weight} weight_quantizer amax(es) on rank 0."
+            )
+
     if int(os.environ["LOCAL_RANK"]) == 0:
-        mtq.print_quant_summary(transformer)
+        if output_path:
+            os.makedirs(output_path, exist_ok=True)
+        mtq.print_quant_summary(transformer, output_path)
 
     return model
 
@@ -472,11 +562,31 @@ def state_dict_filter(state_dict):
         default=None,
         help="MLA quantization type: None (disable), per_tensor_fp8, nvfp4",
     )
+    parser.add_argument(
+        "--calib_all_experts",
+        action="store_true",
+        help=(
+            "Force every token through every MoE expert during calibration "
+            "(slower, ~2x forwards). Default: use native top-k routing and "
+            "post-calibration peer-max sync of expert input amaxes."
+        ),
+    )
 
     args = parser.parse_args()
-    model = load_deepseek_model(args.config, args.model_path, args.batch_size)
+    model = load_deepseek_model(
+        args.config, args.model_path, args.batch_size, calib_all_experts=args.calib_all_experts
+    )
     tokenizer = AutoTokenizer.from_pretrained(
         args.model_path, trust_remote_code=args.trust_remote_code
     )
-    model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size, args.mla_quant)
+    model = ptq(
+        model,
+        tokenizer,
+        args.quant_cfg,
+        args.batch_size,
+        args.calib_size,
+        args.mla_quant,
+        calib_all_experts=args.calib_all_experts,
+        output_path=args.output_path,
+    )
     save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache)
