
Commit 3e6b05e

[DeepSeek] Fix weight_dequant kwargs in fixup_moe_expert_amax
weight_dequant(x, s, block_size=128, dtype=...): the third positional argument is block_size, not dtype. Passing torch.bfloat16 positionally therefore sets block_size to the dtype object, which would either fail inside the Triton kernel or compute amax over corrupt blocks for any uncalibrated expert. The bug never fired in our validation run because every expert was activated during calibration (top-k over 256 experts × 1024 samples), so the _missing(wq.amax) branch was dead. Spotted by bot review on PR #1380.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
1 parent f5d57cb commit 3e6b05e
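
For context, the mis-binding the subject line describes is plain positional-argument order. A minimal sketch, assuming the weight_dequant signature quoted in the commit message; the stub below only mimics the signature, not the real Triton-backed kernel in the DeepSeek example code:

import torch

# Stand-in with the signature quoted above (an assumption about the real helper);
# it just reports how the arguments were bound.
def weight_dequant(x, s, block_size=128, dtype=torch.bfloat16):
    print(f"block_size={block_size!r}, dtype={dtype!r}")

x = torch.randn(4, 4)
s = torch.ones(1, 1)

weight_dequant(x, s, torch.bfloat16)        # old call: binds block_size=torch.bfloat16
weight_dequant(x, s, dtype=torch.bfloat16)  # fixed call: block_size stays 128

Passing the dtype as a keyword, as the fix does, leaves block_size at its default of 128.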

1 file changed: 3 additions & 1 deletion

examples/deepseek/ptq.py
@@ -290,7 +290,9 @@ def _missing(amax):
             continue
         # DeepSeek stores experts as FP8 with a per-block .scale; dequantize
         # to bf16 first so we measure the real weight distribution, not bytes.
-        deq = weight_dequant(w, w.scale, torch.bfloat16) if w.element_size() == 1 else w
+        deq = (
+            weight_dequant(w, w.scale, dtype=torch.bfloat16) if w.element_size() == 1 else w
+        )
         axis = getattr(wq, "_axis", None)
         if axis is None:
             reduce_axis = None
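
For readers without the full file open, here is a rough sketch of where deq goes next, assuming the usual per-tensor vs. per-axis amax convention; the helper name and the exact lines below the hunk are assumptions, not the actual ptq.py code:

import torch

def _fill_missing_amax(wq, deq):
    # Hypothetical continuation of the hunk above (assumption, not the real file):
    # measure amax from the dequantized bf16 weight for an uncalibrated expert.
    axis = getattr(wq, "_axis", None)
    if axis is None:
        # per-tensor quantizer: one scalar amax over the whole weight
        wq.amax = deq.abs().amax()
    else:
        # per-axis quantizer: reduce over every dim except the quantization axis
        reduce_axis = tuple(i for i in range(deq.dim()) if i != axis)
        wq.amax = deq.abs().amax(dim=reduce_axis, keepdim=True)

Either way, the amax is only meaningful if deq was dequantized correctly, which is why binding torch.bfloat16 to block_size mattered on this path.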
