Add weight-only quantization MoE example #4507
Adds examples/quantize_moe.py: a self-contained, self-verifying script that applies int8 (CPU) or int4 (CUDA + mslk) weight-only quantization to the experts of a small token-choice top-2 MoE block via quantize_(), keeping the router in high precision. Prints weight types, serialized size reduction (~3.9x for int8), and SQNR vs the float32 baseline, and points users with fused-3D-expert checkpoints at the FqnToConfig + PerRow(1) pattern from quantize_llama_4.py. Addresses pytorch#729
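Why the router is kept in high precision can be illustrated with a small pure-Python sketch. This is not the code from examples/quantize_moe.py: the router weights, the token, the helper names (softmax, router_logits, top2_route, coarse_round) and the deliberately exaggerated rounding step of 0.25 are all assumptions chosen so the effect is visible by hand. It shows token-choice top-2 routing and how rounding the router weights can send the same token to a different expert.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def router_logits(weights, token):
    """One logit per expert: dot product of that expert's router row with the token."""
    return [sum(w * x for w, x in zip(row, token)) for row in weights]


def top2_route(logits):
    """Token-choice top-2: pick the two most probable experts and renormalize their weights."""
    probs = softmax(logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:2]
    norm = sum(probs[i] for i in top)
    return [(i, probs[i] / norm) for i in top]


def coarse_round(weights, step):
    """Round every weight to a multiple of `step` (an exaggerated stand-in for quantization error)."""
    return [[round(w / step) * step for w in row] for row in weights]


# Hypothetical router for 4 experts and a 2-dimensional token.
router_weights = [
    [1.00, 1.00],
    [0.13, 0.13],
    [0.36, 0.00],
    [-0.50, -0.50],
]
token = [1.0, 1.0]

exact_route = top2_route(router_logits(router_weights, token))
coarse_route = top2_route(router_logits(coarse_round(router_weights, 0.25), token))

exact_experts = [i for i, _ in exact_route]
coarse_experts = [i for i, _ in coarse_route]

# The second expert chosen for this token changes from 2 to 1 after rounding.
print("exact top-2:", exact_experts, "rounded top-2:", coarse_experts)
```

Real int8 rounding is far finer than a step of 0.25, so flips like this would only affect tokens whose expert scores are nearly tied. The point of the sketch is only that an error in the router changes which experts run, whereas an error in an expert weight perturbs the output numerically.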
Adds ToyMoEModel and a CPU test that quantizes only the expert linears via quantize_(filter_fn=...), asserting expert weights become Int8Tensor, the router weight stays unquantized, and SQNR vs float32 stays above 25 dB. Addresses pytorch#729
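For readers unfamiliar with the SQNR threshold used in the test, the following is a minimal pure-Python sketch of the metric applied to a symmetric per-row int8 round trip. It is an illustration under stated assumptions, not torchao's implementation: the helper names (quantize_row_int8, dequantize_row, sqnr_db), the Gaussian weight row and the seed are hypothetical, and the PR measures SQNR on model outputs while this sketch measures it directly on a weight row.

```python
import math
import random


def quantize_row_int8(row):
    """Symmetric per-row int8: one float scale per row, values clamped to [-127, 127]."""
    max_abs = max(abs(v) for v in row)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in row]
    return q, scale


def dequantize_row(q, scale):
    """Map int8 codes back to floats using the row's scale."""
    return [v * scale for v in q]


def sqnr_db(reference, approx):
    """Signal-to-quantization-noise ratio in dB: 10 * log10(signal power / noise power)."""
    signal = sum(r * r for r in reference)
    noise = sum((r - a) ** 2 for r, a in zip(reference, approx))
    if noise == 0:
        return math.inf
    return 10.0 * math.log10(signal / noise)


random.seed(0)
# Hypothetical weight row; real expert weights would come from the model.
row = [random.gauss(0.0, 0.02) for _ in range(256)]

q, scale = quantize_row_int8(row)
restored = dequantize_row(q, scale)
sqnr = sqnr_db(row, restored)

# Same style of guard as the unit test: fail loudly if quality drops below 25 dB.
assert sqnr > 25.0, f"SQNR too low: {sqnr:.1f} dB"
print(f"SQNR: {sqnr:.1f} dB")
```

A higher SQNR means the quantized values track the reference more closely; every 10 dB corresponds to a tenfold reduction in noise power relative to signal power.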
Hi @jcaip @msaroufim, this is my first contribution to torchao, addressing #729. It adds a runnable weight-only quantization example for MoE models. I left a few open questions in the description that I'd value your steer on, especially the preferred example location.
Why this PR
Issue #729 notes that quantize_() should already support Mixture-of-Experts (MoE) models, but there is no example or tutorial demonstrating weight-only quantization on an MoE, so users have no reference for the workflow even though the API reportedly covers it. The existing examples/quantize_llama_4.py quantizes routed experts to float8 w8a8 dynamic (activation + weight), not weight-only, and torchao/prototype/moe_training/ is MoE training, a different feature. So the weight-only-MoE showcase gap is real.
I verified the maintainer's claim before writing anything: applying Int8WeightOnlyConfig to an MoE block's expert linears works unmodified. Expert weights become Int8Tensor, the model shrinks ~3.9x, and outputs reach ~45 dB SQNR against the fp32 baseline. The contribution is therefore purely additive: a runnable example plus a unit test, with no core changes.
What this PR does
Adds examples/quantize_moe.py: it builds a small token-choice top-2 MoE block (a softmax router plus nn.Linear experts) and quantizes only the expert weights via quantize_(model, Int8WeightOnlyConfig(), filter_fn=is_expert_linear). The router is deliberately left in high precision, because quantizing it would change token-to-expert routing decisions, not just numerics.
The script is self-verifying: it prints before/after weight types and serialized sizes, runs a forward pass, reports SQNR vs the fp32 baseline, and asserts four conditions (experts quantized, router not, ≥1.5x smaller, SQNR > 25 dB), exiting non-zero on any failure.
--dtype int8|int4 and --device flags; int4 is gated behind a hardware/dependency warning.
Points users with fused-3D-expert checkpoints (meta-llama/Llama-4-Scout-17B-16E-Instruct) at the FqnToConfig + PerRow(1) pattern, mirroring examples/quantize_llama_4.py.
It also adds a unit test, TestQuantFlow.test_int8_weight_only_moe_experts_only in test/quantization/test_quant_api.py (with a module-level ToyMoEModel modeled on the file's ToyLinearModel), since CI lints examples but does not execute them, so without a test MoE support stays demonstrated-but-unguarded.
Relevant issues
Closes #729
Test plan / evidence
CI lints examples/ but does not execute them (only tutorials/ run via run_tutorials.yml), so the example is self-asserting and the unit test provides the regression guard.
CPU (macOS, torch 2.12): example runs twice identically, with experts → Int8Tensor, router float32, serialized 8.40 → 2.15 MB (3.90x), SQNR 45.1 dB, and all asserts passing. pytest test/quantization/test_quant_api.py -k moe → 1 passed.
CUDA (RTX 3090, SM 8.6, torch 2.12.1+cu130):
The 28 skips are all hardware-gated (Need SM 8.9+ / Checkpoints are produced in SM90+, i.e. float8 paths on this SM 8.6 card) or pre-existing unconditional skips; none are introduced by this change.
Acceptance criteria
test_int8_weight_only_moe_experts_only + ToyMoEModel)
git stash)
ruff check + ruff format clean, ruff 0.11.6; BSD-3 header via scripts/check_copyright_header.py)
Open questions for maintainers
examples/quantize_moe.py (matching examples/quantize_llama_4.py, added in #3408, "add an example for quantizing LLaMa 4 Scout") vs examples/inference/ (hinted by the literalinclude in quant_api.py)?
examples/README.md entry: should I add one? #3408 didn't.
mslk >= 1.0.0, but the mslk package on public PyPI is a 0.0.0 placeholder and torchao/utils.py:1226 gates real availability on is_fbcode(). On a fresh RTX 3090 with the public wheel, --dtype int4 raises ImportError: Requires mslk >= 1.0.0 (same on CPU and CUDA). What's the supported public way to exercise int4 weight-only? (The example gates int4 behind a flag + warning, so the int8 showcase is unaffected.)
Note
Minor and separate from this PR: torchao/quantization/quant_api.py:1499 has an invalid escape '\.' in a docstring that triggers SyntaxWarning: invalid escape sequence on import under Python ≥ 3.12. Happy to send a one-liner """ micro-fix as its own PR.
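The warning described in the note can be reproduced without torchao. The sketch below compiles two tiny, hypothetical source strings (not the actual docstring from quant_api.py) and records which warning categories the compiler raises. It assumes the fix is to make the docstring raw with an r prefix; doubling the backslash would be another option, and the note does not say which one is intended. On Python 3.12 and later the category is SyntaxWarning; on earlier 3.x releases it is DeprecationWarning.

```python
import warnings


def escape_warnings(source):
    """Compile `source` and return the categories of warnings raised while compiling."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        compile(source, "<example>", "exec")
    return [w.category for w in caught]


# Hypothetical module text: a normal docstring containing backslash-dot,
# and the same docstring marked raw with an r prefix.
plain_src = 'def f():\n    """Matches names like layers\\.0"""\n'
raw_src = 'def f():\n    r"""Matches names like layers\\.0"""\n'

plain = escape_warnings(plain_src)
raw = escape_warnings(raw_src)

print("plain docstring warnings:", [c.__name__ for c in plain])
print("raw docstring warnings:", [c.__name__ for c in raw])
```

The raw version compiles with no warnings because backslashes in a raw string literal are kept as-is and never interpreted as escape sequences.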