[mx_formats/cutedsl] Unified NVFP4 + MXFP4 (+/- RHT) quantize kernel #4517

Open
santoshmo wants to merge 1 commit into pytorch:main from santoshmo:fp4-unified-cutedsl

Summary

Adds a self-contained CuTeDSL FP4 quantize subpackage under
torchao/prototype/mx_formats/cutedsl/. A single no-smem streaming kernel
handles both FP4 formats, all three scale layouts the SM100 blockscaled
GEMMs consume, and an optional fused Random Hadamard Transform — replacing
the need for separate per-format casts.

  • NVFP4 — E2M1 block-16, two-level E4M3 block scale + per-tensor global
    scale (float8_e4m3fn scales). Supports arbitrary K % 16 via a masked
    remainder.
  • MXFP4 — E2M1 block-32, single-level E8M0 block scale (float8_e8m0fnu),
    floor or rceil. Requires K % 32 == 0.
  • scale_layout ∈ {linear, cublas_blocked, mma_tiled}, selected at compile
    time. cublas_blocked is the to_blocked padded swizzle (f4f4bf16 GEMM);
    mma_tiled is the SM100 blockscaled-GEMM atom layout (no separate
    scale-conversion pass); linear is the plain (M, K // blk) scale matrix.
  • optional fused RHT — register FWHT16/32 + sign, applied per block before
    amax/scale/pack; compiled out via a constexpr on the plain path.
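
To make the cublas_blocked layout concrete, here is a minimal host-side sketch of a padded 128 × 4 tile swizzle, written from my understanding of what torchao's to_blocked helper produces. The function names `blocked_offset` and `to_blocked_reference` are hypothetical and are not part of this PR; treat the exact index formula as an assumption to check against the real helper.

```python
def blocked_offset(row: int, col: int, n_col_blocks: int) -> int:
    """Flat offset of scale (row, col) in the 128x4-tile swizzled layout (assumed)."""
    row_block, r = divmod(row, 128)
    col_block, c = divmod(col, 4)
    tile = row_block * n_col_blocks + col_block  # tiles are stored row-major
    # Inside a tile, the four 32-row sub-blocks are interleaved into 32 x 16.
    return tile * 512 + (r % 32) * 16 + (r // 32) * 4 + c


def to_blocked_reference(scales: list[list[int]]) -> list[int]:
    """Zero-pad to multiples of (128, 4) and return the swizzled flat buffer."""
    rows, cols = len(scales), len(scales[0])
    n_row_blocks = -(-rows // 128)  # ceil division
    n_col_blocks = -(-cols // 4)
    out = [0] * (n_row_blocks * n_col_blocks * 512)
    for r in range(rows):
        for c in range(cols):
            out[blocked_offset(r, c, n_col_blocks)] = scales[r][c]
    return out
```

Under this sketch, a linear (M, K // blk) scale matrix and its blocked form hold the same scale bytes; only the addressing (and zero padding) differs, which is why qdata can stay identical across layouts.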

What's added

  • cutedsl/fp4_unified_quantize.py — the kernel, the
    torchao::fp4_quantize_unified custom op, and the gated
    fp4_quantize_unified_2d wrapper.
  • cutedsl/cute_utils.py — E2M1 packing + E4M3/E8M0 block-scale recipes + amax
    (all bit-exact vs eager torchao).
  • cutedsl/fwht.py — register-resident FWHT16/32 + sign helpers.
  • cutedsl/__init__.py — availability gating + lazy (CPU-safe) re-exports.
  • test/prototype/mx_formats/test_fp4_unified_cutedsl.py.
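
For readers unfamiliar with E2M1 packing, the following is a small pure-Python sketch of the 4-bit code table and a two-codes-per-byte packing. The helper names are hypothetical, and the nibble order (even-indexed element in the low nibble) is an assumption; the actual helpers in cute_utils.py may differ.

```python
# E2M1 magnitudes indexed by the low 3 bits of the code; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code to its float value."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[code & 0x7]


def pack_e2m1(codes: list[int]) -> bytes:
    """Pack pairs of 4-bit codes into bytes (assumed: first code in the low nibble)."""
    assert len(codes) % 2 == 0 and all(0 <= c < 16 for c in codes)
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))


def unpack_e2m1(packed: bytes) -> list[int]:
    """Inverse of pack_e2m1."""
    out = []
    for byte in packed:
        out += [byte & 0x0F, byte >> 4]
    return out
```

With this packing, 32 codes occupy 16 bytes, which matches the single 128-bit store per group described under Design.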

Design

A "group" is 32 input elements = one 128-bit store = two NVFP4 blocks or one
MXFP4 block. The per-format scale recipe, FWHT size, and MMA row-atom are
compile-time FORMAT-selected, so a single kernel body covers both formats.
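
The group arithmetic above can be checked with a few lines of Python. This is only a worked calculation; `group_geometry` and `linear_scale_shape` are hypothetical names, and the load count assumes bf16 input as in the Performance section.

```python
FP4_BITS = 4
BF16_BITS = 16
GROUP_ELEMS = 32
BLOCK_SIZE = {"nvfp4": 16, "mxfp4": 32}


def group_geometry(fmt: str) -> dict:
    """Bits stored, 128-bit loads needed, and scale blocks per 32-element group."""
    return {
        "store_bits": GROUP_ELEMS * FP4_BITS,            # 32 * 4 = 128
        "loads_128b": GROUP_ELEMS * BF16_BITS // 128,    # 32 * 16 / 128 = 4
        "blocks_per_group": GROUP_ELEMS // BLOCK_SIZE[fmt],
    }


def linear_scale_shape(M: int, K: int, fmt: str) -> tuple:
    """Shape of the linear scale layout, (M, K // blk)."""
    return (M, K // BLOCK_SIZE[fmt])
```

For example, `group_geometry("nvfp4")["blocks_per_group"]` is 2 and `group_geometry("mxfp4")["blocks_per_group"]` is 1, so each group emits two NVFP4 scales or one MXFP4 scale.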

Two byte-identical thread mappings are exposed via mapping=:

  • "striped" — threads stripe a row's groups; grid-strided rows. Best at very
    large N.
  • "wpr" — warp-per-row: warp w owns a contiguous row, with the 32 lanes + a
    grid.x column split + ILP covering the columns (replicates the dense-GEMM
    grid). Best at small / mid N; requires K % 32 == 0.
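
As a rough illustration of why the two mappings can be byte-identical, the sketch below enumerates on the host which (row, group) pairs each mapping would visit. It is a simplified model, not the kernel's indexing: the parameter names (`n_ctas`, `threads`, `warps_per_cta`, `ilp`) and the exact loop order are assumptions.

```python
def striped_pairs(M: int, groups: int, n_ctas: int, threads: int) -> list:
    """Threads stripe a row's groups; CTAs take rows in a grid-strided loop."""
    pairs = []
    for cta in range(n_ctas):
        for row in range(cta, M, n_ctas):
            for tid in range(threads):
                pairs += [(row, g) for g in range(tid, groups, threads)]
    return pairs


def wpr_pairs(M: int, groups: int, warps_per_cta: int, ilp: int) -> list:
    """Warp-per-row: 32 lanes x ilp groups per column CTA, split along grid.x."""
    span = 32 * ilp                      # groups covered by one warp per column CTA
    n_row_ctas = -(-M // warps_per_cta)  # ceil division
    n_col_ctas = -(-groups // span)
    pairs = []
    for row_cta in range(n_row_ctas):
        for warp in range(warps_per_cta):
            row = row_cta * warps_per_cta + warp
            for col_cta in range(n_col_ctas):
                for i in range(ilp):
                    for lane in range(32):
                        g = col_cta * span + i * 32 + lane
                        if row < M and g < groups:  # bounds mask
                            pairs.append((row, g))
    return pairs
```

Both functions visit every (row, group) pair exactly once, and each group is quantized independently of which thread owns it, so the output bytes do not depend on the mapping.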

The kernel is no-smem (register streaming) with forced 128-bit loads and wide
128-bit stores; RHT uses a register f32 FWHT to stay bit-exact.
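
For reference, a fast Walsh-Hadamard transform over one block can be sketched in plain Python as below. This is a generic textbook butterfly, not the code in fwht.py: the 1/sqrt(n) normalization and applying the sign vector before the transform are assumptions, and Python floats are f64 rather than the f32 registers the kernel uses.

```python
import math


def fwht(values: list) -> list:
    """Orthonormal fast Walsh-Hadamard transform of a 16- or 32-element block."""
    x = [float(v) for v in values]
    n = len(x)
    assert n in (16, 32)
    h = 1
    while h < n:
        # One butterfly stage: combine elements that are h apart.
        for start in range(0, n, 2 * h):
            for j in range(start, start + h):
                a, b = x[j], x[j + h]
                x[j], x[j + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in x]


def rht(values: list, signs: list) -> list:
    """Random Hadamard transform: elementwise +/-1 signs, then FWHT (assumed order)."""
    return fwht([v * s for v, s in zip(values, signs)])


def inverse_rht(values: list, signs: list) -> list:
    """Inverse of rht: the orthonormal FWHT is its own inverse, then undo the signs."""
    return [v * s for v, s in zip(fwht(values), signs)]
```

A block of size n needs log2(n) butterfly stages (4 for FWHT16, 5 for FWHT32), all of which can stay in registers because each block fits within a single 32-element group.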

Correctness

Gated behind SM 10.x (Blackwell), CUDA ≥ 12.8, and the CuTeDSL runtime. The
B200-gated test suite (35 cases) checks:

  • plain-cast per-block scales byte-exact vs the cute_utils host scale
    references (which are themselves bit-exact vs eager NVFP4/MXFP4), for both
    formats incl. MXFP4 floor and rceil;
  • wpr output is byte-identical to striped across all three scale layouts
    and both formats (± RHT);
  • qdata is invariant across the three scale layouts;
  • a plain-cast dequant round-trip reconstructs the input to within FP4 error.
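
To illustrate what "within FP4 error" can mean for the round-trip check, here is a simplified pure-Python model. It is not the test in this PR: it uses an unrounded float scale of amax / 6 (no E4M3 or E8M0 rounding of the scale) and nearest-value rounding without any particular tie rule, so it only demonstrates the error bound, not bit-exact behavior.

```python
import math

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def fake_quant_block(block: list) -> tuple:
    """Quantize one block to the E2M1 grid and dequantize it again."""
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0  # map amax onto the largest E2M1 value
    out = []
    for v in block:
        scaled = abs(v) / scale
        mag = min(E2M1_GRID, key=lambda g: abs(scaled - g))  # nearest grid value
        out.append(math.copysign(mag * scale, v))
    return out, scale


def max_abs_error(block: list) -> tuple:
    """Return (largest reconstruction error, scale) for one block."""
    deq, scale = fake_quant_block(block)
    return max(abs(a - b) for a, b in zip(block, deq)), scale
```

In this model the widest gap on the grid is between 4 and 6, so the reconstruction error of any element is at most one scale unit, that is amax / 6 for the block.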

Performance (B200, kernel-only device time, mma_tiled)

Throughput in GB/s over total bytes moved (bf16 read + packed-FP4 write + scale write).

NVFP4 (block-16, two-level E4M3)

| M × N | plain GB/s | % peak | + RHT GB/s | % peak | RHT / plain |
|---|---|---|---|---|---|
| 2304 × 4096 | 4405 | 55% | 2944 | 37% | 0.72× |
| 16384 × 4096 | 5581 | 70% | 3627 | 45% | 0.70× |
| 56064 × 4096 | 6040 | 76% | 4010 | 50% | 0.70× |
| 2304 × 65536 | 6125 | 77% | 4183 | 52% | 0.68× |
| 11776 × 65536 | 6363 | 80% | 4446 | 56% | 0.70× |

MXFP4 (block-32, E8M0, floor)

| M × N | plain GB/s | % peak | + RHT GB/s | % peak | RHT / plain |
|---|---|---|---|---|---|
| 2304 × 4096 | 3523 | 44% | 2613 | 33% | 0.74× |
| 16384 × 4096 | 4064 | 51% | 3058 | 38% | 0.75× |
| 56064 × 4096 | 4434 | 55% | 3329 | 42% | 0.75× |
| 2304 × 65536 | 4814 | 60% | 3240 | 41% | 0.67× |
| 11776 × 65536 | 5238 | 65% | 3489 | 44% | 0.67× |

Notes

  • % peak is vs ~8 TB/s B200 HBM. Plain NVFP4 reaches ~80% at large N
    (roofline-class — a plain coalesced copy itself only hits ~79%).
  • Best of an autotune sweep over threads / ilp / rows_per_cta; NVFP4
    plain also picks the better mapping (striped at very large N,
    warp-per-row at small/mid N).
  • MXFP4 plain runs ~13% below NVFP4 plain — the E8M0 scale recipe + amax over a
    32-element block are heavier than the E4M3 block-16 path.
  • Fused RHT costs a consistent ~30% (the register f32 FWHT butterfly); the
    RHT / plain ratio is computed against the same mapping.
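
The GB/s and % peak columns can be reproduced from a byte count and a device time roughly as follows. This is a sketch of the arithmetic only: the function names are hypothetical, the scale traffic is counted as one unpadded byte per block (the padded layouts would move slightly more), the per-tensor global scale is ignored, and 8 TB/s is taken as 8000 GB/s.

```python
def bytes_moved(M: int, N: int, block: int) -> int:
    """bf16 read + packed-FP4 write + one scale byte per block (unpadded)."""
    read = M * N * 2                 # bf16 input, 2 bytes per element
    fp4 = M * N // 2                 # two E2M1 codes per byte
    scales = M * (N // block)        # E4M3 or E8M0, 1 byte each
    return read + fp4 + scales


def achieved_gbps(M: int, N: int, block: int, seconds: float) -> float:
    """Throughput in GB/s for a measured kernel-only device time."""
    return bytes_moved(M, N, block) / seconds / 1e9


def percent_of_peak(gbps: float, peak_gbps: float = 8000.0) -> float:
    """Share of the assumed ~8 TB/s HBM peak."""
    return 100.0 * gbps / peak_gbps
```

For instance, `percent_of_peak(4405)` is about 55 and `percent_of_peak(6363)` is about 80, matching the first and last NVFP4 plain rows.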

Notes

  • Supersedes the separate nvfp4_rht / mxfp4_rht CuTeDSL casts (same kernel,
    same byte-exact output, plus MXFP4 + the mma_tiled layout + warp-per-row).
  • The RHT path is deliberately the f32 register FWHT: MMA-izing the per-block
    16/32 Hadamard would require a fragment-layout shuffle (no net win in a
    no-smem kernel) and bf16 tensor-core math (breaks bit-exactness), so the
    ~30% RHT cost is the floor for a byte-exact fused RHT.

Adds a self-contained CuTeDSL FP4 quantize subpackage: one no-smem streaming
kernel that serves both FP4 formats and all three scale layouts the GEMM
consumers use, with optional fused RHT. Supersedes the separate per-format
nvfp4_rht / mxfp4_rht casts.

* fmt="nvfp4": block 16, two-level E4M3 block scale + per-tensor global scale
  (float8_e4m3fn scales); supports arbitrary K % 16 via a masked remainder.
* fmt="mxfp4": block 32, single-level E8M0 block scale (float8_e8m0fnu), floor
  or rceil; requires K % 32 == 0.
* scale_layout in {linear, cublas_blocked, mma_tiled} selected at compile time
  (cublas_blocked feeds f4f4bf16; mma_tiled feeds the SM100 blockscaled GEMM
  with no separate scale-conversion pass).
* optional fused RHT (register FWHT16/32 + sign), skipped via a constexpr on
  the plain path.
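
The difference between the floor and rceil E8M0 modes can be sketched in pure Python as below. This follows the usual MX convention as I understand it (floor: floor(log2(amax)) minus the E2M1 maximum exponent of 2; rceil: ceil(log2(amax / 6))); the function names, the clamping range, and the zero-amax handling are assumptions, and the bit-level recipe in cute_utils.py may treat edge cases differently.

```python
import math

E2M1_MAX = 6.0     # largest E2M1 magnitude
E2M1_EMAX = 2      # exponent of the largest E2M1 binade (4 = 2**2)
E8M0_BIAS = 127


def floor_log2(v: float) -> int:
    _, e = math.frexp(v)  # v = m * 2**e with 0.5 <= m < 1
    return e - 1


def ceil_log2(v: float) -> int:
    m, e = math.frexp(v)
    return e - 1 if m == 0.5 else e  # exact powers of two do not round up


def e8m0_scale_exponent(amax: float, mode: str) -> int:
    """Unbiased power-of-two exponent of the block scale."""
    if amax <= 0:
        return -E8M0_BIAS  # assumed: smallest representable scale
    if mode == "floor":
        exp = floor_log2(amax) - E2M1_EMAX
    elif mode == "rceil":
        exp = ceil_log2(amax / E2M1_MAX)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return max(-E8M0_BIAS, min(E8M0_BIAS, exp))


def e8m0_byte(exp: int) -> int:
    """Biased E8M0 encoding of 2**exp (255 is reserved for NaN)."""
    return exp + E8M0_BIAS
```

With amax = 7, floor gives a scale of 1 (so 7 saturates to 6), while rceil gives a scale of 2 (so 7 becomes 3.5 and stays in range); that is the practical trade-off between the two modes.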

A "group" is 32 input elements = one 128-bit store = two NVFP4 blocks or one
MXFP4 block; the per-format scale recipe, FWHT size, and MMA row-atom are
compile-time FORMAT-selected so a single kernel body covers both formats. Two
byte-identical thread mappings are exposed via mapping=: "striped" (best at very
large N) and "wpr" (warp-per-row + grid.x column split; best at small/mid N).

Files: cute_utils.py (E2M1 pack + E4M3/E8M0 scale recipes + amax, bit-exact vs
eager), fwht.py (register FWHT16/32 + sign), fp4_unified_quantize.py (the
kernel + torchao::fp4_quantize_unified op + gated fp4_quantize_unified_2d).

Test: test/prototype/mx_formats/test_fp4_unified_cutedsl.py (B200-gated) checks
the plain-cast per-block scales byte-exact vs the cute_utils host references
(themselves bit-exact vs eager), wpr == striped across all layouts, qdata
invariance across the three scale layouts, and a plain-cast dequant round-trip.
@pytorch-bot

pytorch-bot Bot commented Jun 19, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4517

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

❌ 3 New Failures

As of commit 0a22213 with merge base 213d25c (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
