[mx_formats/cutedsl] Unified NVFP4 + MXFP4 (+/- RHT) quantize kernel#4517
Open
santoshmo wants to merge 1 commit into
Adds a self-contained CuTeDSL FP4 quantize subpackage: one no-smem streaming
kernel that serves both FP4 formats and all three scale layouts the GEMM
consumers use, with optional fused RHT. Supersedes the separate per-format
nvfp4_rht / mxfp4_rht casts.
* fmt="nvfp4": block 16, two-level E4M3 block scale + per-tensor global scale
(float8_e4m3fn scales); supports arbitrary K % 16 via a masked remainder.
* fmt="mxfp4": block 32, single-level E8M0 block scale (float8_e8m0fnu), floor
or rceil rounding; requires K % 32 == 0.
* scale_layout in {linear, cublas_blocked, mma_tiled} selected at compile time
(cublas_blocked feeds f4f4bf16; mma_tiled feeds the SM100 blockscaled GEMM
with no separate scale-conversion pass).
* optional fused RHT (register FWHT16/32 + sign), skipped via a constexpr on
the plain path.
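To make the floor / rceil distinction concrete, here is a minimal pure-Python sketch of how a power-of-two E8M0 exponent could be derived from a block amax. It assumes the usual MX-style recipes (floor: floor(log2(amax)) minus the E2M1 max exponent; rceil: the smallest power of two that keeps amax / scale within the E2M1 max of 6.0). The function names are hypothetical and this is not the cute_utils code.

```python
import math

E2M1_MAX = 6.0     # largest finite E2M1 magnitude
E2M1_EMAX = 2      # exponent of the top E2M1 binade (4.0 == 2**2)
E8M0_BIAS = 127    # E8M0 stores exponent + 127 in one byte (255 is NaN)


def e8m0_exponent(amax: float, mode: str = "floor") -> int:
    """Unbiased power-of-two exponent for one 32-element block (sketch)."""
    if amax <= 0.0:
        return -E8M0_BIAS  # assumption: all-zero block maps to the smallest scale
    if mode == "floor":
        # frexp gives amax == m * 2**e with m in [0.5, 1), so floor(log2(amax)) == e - 1
        _, e = math.frexp(amax)
        exp = (e - 1) - E2M1_EMAX
    elif mode == "rceil":
        # ceil(log2(amax / E2M1_MAX)): exact powers of two have m == 0.5
        m, e = math.frexp(amax / E2M1_MAX)
        exp = e - 1 if m == 0.5 else e
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return max(-E8M0_BIAS, min(exp, E8M0_BIAS))


def e8m0_byte(amax: float, mode: str = "floor") -> int:
    """Biased E8M0 byte for the block scale."""
    return e8m0_exponent(amax, mode) + E8M0_BIAS
```

With amax = 7.0, floor picks a scale of 2**0 (so 7.0 saturates to 6.0 after scaling), while rceil picks 2**1 so the scaled value 3.5 stays in range. That is the trade-off between the two modes.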
A "group" is 32 input elements = one 128-bit store = two NVFP4 blocks or one
MXFP4 block; the per-format scale recipe, FWHT size, and MMA row-atom are
compile-time FORMAT-selected so a single kernel body covers both formats. Two
byte-identical thread mappings are exposed via mapping=: "striped" (best at very
large N) and "wpr" (warp-per-row + grid.x column split; best at small/mid N).
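The "group" arithmetic can be checked with a small host-side sketch: 32 values become 32 four-bit E2M1 codes, which pack into 16 bytes (128 bits). The rounding (nearest, ties to the even code) and the nibble order (even-indexed element in the low nibble) are assumptions made for illustration, and the helper names are hypothetical rather than taken from cute_utils.py.

```python
# Magnitudes representable by E2M1 codes 0..7 (sign lives in bit 3).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

GROUP = 32                         # elements per group
NVFP4_BLOCKS_PER_GROUP = GROUP // 16   # two block-16 scales per group
MXFP4_BLOCKS_PER_GROUP = GROUP // 32   # one block-32 scale per group


def e2m1_code(x: float) -> int:
    """4-bit E2M1 code for an already-scaled value; saturates at +/-6."""
    sign = 0x8 if x < 0 else 0x0
    mag = min(abs(x), E2M1_VALUES[-1])
    # Nearest magnitude; on an exact tie prefer the even code (mantissa bit 0).
    code = min(range(8), key=lambda c: (abs(E2M1_VALUES[c] - mag), c & 1))
    return sign | code


def pack_group(values) -> bytes:
    """Pack one 32-element group into 16 bytes (one 128-bit store)."""
    assert len(values) == GROUP
    codes = [e2m1_code(v) for v in values]
    # Assumption: element 2i goes to the low nibble, element 2i+1 to the high nibble.
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, GROUP, 2))
```

For example, `pack_group([1.0, -1.0] + [0.0] * 30)` yields 16 bytes whose first byte is `0xA2` under the nibble order assumed here.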
Files: cute_utils.py (E2M1 pack + E4M3/E8M0 scale recipes + amax, bit-exact vs
eager), fwht.py (register FWHT16/32 + sign), fp4_unified_quantize.py (the
kernel + torchao::fp4_quantize_unified op + gated fp4_quantize_unified_2d).
Test: test/prototype/mx_formats/test_fp4_unified_cutedsl.py (B200-gated) checks
the plain-cast per-block scales byte-exact vs the cute_utils host references
(themselves bit-exact vs eager), wpr == striped across all layouts, qdata
invariance across the three scale layouts, and a plain-cast dequant round-trip.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4517
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEVs: There are 1 currently active SEVs.
❌ 3 New Failures: As of commit 0a22213 with merge base 213d25c, NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This was referenced Jun 19, 2026
Summary
Adds a self-contained CuTeDSL FP4 quantize subpackage under
torchao/prototype/mx_formats/cutedsl/. A single no-smem streaming kernel handles both FP4 formats, all three scale layouts the SM100 blockscaled
GEMMs consume, and an optional fused Random Hadamard Transform — replacing
the need for separate per-format casts.
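* fmt="nvfp4": block 16, two-level E4M3 block scale + per-tensor global scale (float8_e4m3fn scales). Supports arbitrary K % 16 via a masked remainder.
* fmt="mxfp4": block 32, single-level E8M0 block scale (float8_e8m0fnu), floor or rceil. Requires K % 32 == 0.
* scale_layout in {linear, cublas_blocked, mma_tiled}, selected at compile time. cublas_blocked is the to_blocked padded swizzle (f4f4bf16 GEMM); mma_tiled is the SM100 blockscaled-GEMM atom layout (no separate scale-conversion pass); linear is the plain (M, K//blk).
* Optional fused RHT (register FWHT16/32 + sign) ahead of amax/scale/pack; compiled out via a constexpr on the plain path.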
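As an illustration of what a padded, swizzled scale layout involves, the sketch below maps a linear (row, col) scale index to a flat offset in a 128x4-tiled layout where each tile is stored as 32x16. This follows my understanding of the cuBLAS block-scale layout that torchao's to_blocked produces; it is an assumption that the kernel's cublas_blocked path uses exactly this indexing, and the helper names are hypothetical.

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def blocked_offset(row: int, col: int, n_scale_cols: int) -> int:
    """Flat offset of scale (row, col) in a 128x4-tiled, 32x16-swizzled layout."""
    n_col_tiles = ceil_div(n_scale_cols, 4)
    tile = (row // 128) * n_col_tiles + (col // 4)   # row-major over tiles
    r, c = row % 128, col % 4                        # position inside the tile
    # Inside a tile: 128 rows fold into 32 rows of 16 (4 row-quarters x 4 cols).
    return tile * 512 + (r % 32) * 16 + (r // 32) * 4 + c


def blocked_numel(m: int, n_scale_cols: int) -> int:
    """Number of scale slots after padding rows to 128 and columns to 4."""
    return ceil_div(m, 128) * 128 * ceil_div(n_scale_cols, 4) * 4
```

For a linear layout the offset would simply be `row * n_scale_cols + col`; the blocked form trades that simplicity for padding plus a fixed permutation, which is why writing it directly from the quantize kernel avoids a separate conversion pass.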
What's added
* cutedsl/fp4_unified_quantize.py: the kernel, the torchao::fp4_quantize_unified custom op, and the gated fp4_quantize_unified_2d wrapper.
* cutedsl/cute_utils.py: E2M1 packing + E4M3/E8M0 block-scale recipes + amax (all bit-exact vs eager torchao).
* cutedsl/fwht.py: register-resident FWHT16/32 + sign helpers.
* cutedsl/__init__.py: availability gating + lazy (CPU-safe) re-exports.
* test/prototype/mx_formats/test_fp4_unified_cutedsl.py.
Design
A "group" is 32 input elements = one 128-bit store = two NVFP4 blocks or one
MXFP4 block. The per-format scale recipe, FWHT size, and MMA row-atom are
compile-time FORMAT-selected, so a single kernel body covers both formats.
Two thread mappings that produce byte-identical output are exposed via mapping=:
* "striped": threads stripe a row's groups; grid-strided rows. Best at very large N.
* "wpr": warp-per-row. Warp w owns a contiguous row, with the 32 lanes + a grid.x column split + ILP covering the columns (replicates the dense-GEMM grid). Best at small / mid N; requires K % 32 == 0.
The kernel is no-smem (register streaming) with forced 128-bit loads and wide
128-bit stores; RHT uses a register f32 FWHT to stay bit-exact.
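For readers unfamiliar with the transform, a host-side sketch of a length-16 fast Walsh-Hadamard transform with a sign vector is shown below. It is a plain-Python reference for the butterfly structure only: applying the signs before the butterflies and normalizing by 1/sqrt(n) are assumptions, and the names are hypothetical rather than those in fwht.py.

```python
import random


def fwht(x):
    """Unnormalized fast Walsh-Hadamard transform; len(x) must be a power of two."""
    y = list(x)
    n = len(y)
    h = 1
    while h < n:
        # One butterfly stage: pair elements h apart within blocks of 2*h.
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = y[j], y[j + h]
                y[j], y[j + h] = a + b, a - b
        h *= 2
    return y


def rht(x, signs):
    """Random Hadamard transform sketch: sign flip, FWHT, then 1/sqrt(n) scaling."""
    n = len(x)
    scale = n ** -0.5
    return [v * scale for v in fwht([s * xi for s, xi in zip(signs, x)])]


# Fixed seed so the sign vector is reproducible in this example.
rng = random.Random(0)
signs16 = [rng.choice((-1, 1)) for _ in range(16)]
```

A length-16 transform needs log2(16) = 4 butterfly stages of adds and subtracts only, which is what makes a register-resident version practical; with the 1/sqrt(n) scaling the transform is orthogonal, so it preserves the vector's energy.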
Correctness
Gated behind SM 10.x (Blackwell), CUDA ≥ 12.8, and the CuTeDSL runtime. The
B200-gated test suite (35 cases) checks:
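* plain-cast per-block scales are byte-exact vs the cute_utils host scale references (which are themselves bit-exact vs eager NVFP4/MXFP4), for both formats incl. MXFP4 floor and rceil;
* wpr output is byte-identical to striped across all three scale layouts and both formats (± RHT);
* qdata is invariant across the three scale layouts;
* a plain-cast dequant round-trip.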
Performance (B200, kernel-only device time, mma_tiled)
Throughput in GB/s over total bytes moved (bf16 read + packed-FP4 write + scale
NVFP4 (block-16, two-level E4M3)
MXFP4 (block-32, E8M0, floor)
Notes
* ... (roofline-class; a plain coalesced copy itself only hits ~79%).
* ... threads/ilp/rows_per_cta; NVFP4 plain also picks the better mapping (striped at very large N, warp-per-row at small/mid N).
* ... 32-element block are heavier than the E4M3 block-16 path.
* RHT / plain ratio is computed against the same mapping.
Notes
* Supersedes the separate per-format nvfp4_rht / mxfp4_rht CuTeDSL casts (same kernel, same byte-exact output, plus MXFP4 + the mma_tiled layout + warp-per-row).
* ... 16/32 Hadamard would require a fragment-layout shuffle (no net win in a no-smem kernel) and bf16 tensor-core math (breaks bit-exactness), so the ~30% RHT cost is the floor for a byte-exact fused RHT.