
feat: support mxfp4 × mxfp4 gemm on sm100 #348

Open
z52527 wants to merge 12 commits into deepseek-ai:nv_dev from z52527:fea-fp4-synced

z52527 commented Jun 1, 2026


Adds an FP4×FP4 (MXF4) GEMM path on SM100/B200 using SM100_MMA_MXF4_SS directly, separate from main's unified MXF8F6F4 kernel. Dispatched from the existing fp8_fp4 APIs (dense / m-grouped contiguous / m-grouped masked) when both operands are FP4. Tested with tests/test_fp4.py on B200 (28/28 pass). Bench: median ~1.07× vs cublasTest 13.3 direct, 10/13 shapes ≥ cuBLAS.

Dense GEMM perf (B200, CUDA 12.9, vs cublasTest 13.3 direct binary)

| M×N×K | DG TFLOPS | cBLT TFLOPS | DG/cBLT |
|---|---|---|---|
| 128×2048×7168 | 328 | 338 | 0.97× |
| 128×24576×1536 | 980 | 745 | 1.32× |
| 512×2048×7168 | 1178 | 1422 | 0.83× |
| 512×7168×2048 | 1439 | 1190 | 1.21× |
| 1024×7168×2048 | 2056 | 1399 | 1.47× |
| 4096×4096×7168 | 5332 | 3939 | 1.35× |
| 8192×8192×8192 | 5702 | 5239 | 1.09× |

TODO

  • BF16 D epilogue — kernel currently hardcodes fp32 output
  • Contract-rejection tests — wrapper has DG_HOST_ASSERTs for aliasing / K-major / k % 8 but no test coverage (repo-wide gap, not FP4-only)
  • Sparse-MoE swap_ab gating uses m / num_groups as a host-side estimate. The estimate is accurate for dense layouts but misses swap_ab on -1-padded sparse layouts, because padding inflates it above block_m. This avoids a device sync on m_indices; revisit if sparse MoE perf becomes a bottleneck.
  • ~5–15% perf gap vs cuBLAS on mid-M shapes (M ∈ {256, 512}, N ∈ {2k, 4k}, K=7168)

z52527 and others added 12 commits May 29, 2026 02:46
Migrate fea-fp4's standalone MXF4 (SM100_MMA_MXF4_SS) GEMM on top of main's
post-PR-304 refactor. Keeps the FP4-specialized hardware path distinct from
main's MXF8F6F4 unified kernel.

- csrc/jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp: rebuilt on GemmDesc + nested
  GemmConfig {layout, storage_config, pipeline_config, launch_config}. Forces
  a_dtype=b_dtype=kPackedFP4. Three entry points: dense, m-grouped contiguous,
  m-grouped masked. BLOCK_K and SHAPE_K convert main's byte/FP4-count to my
  kernel's int32-count units.
- deep_gemm/include/deep_gemm/impls/sm100_fp4_gemm_1d1d.cuh: scheduler ->
  sched::Scheduler with shape_k ctor arg; KGroupedIndexType -> sched::IndexType;
  helpers via math:: / ptx:: / mma::sm100:: namespaces; tma_copy signature
  (multicast moved to runtime); make_runtime_instr_desc_with_sf_id takes
  sfb_id. swap_ab epilogue / MXF4 MMA path preserved.
- csrc/apis/gemm.hpp: dispatch FP4xFP4 -> sm100_fp4_gemm_1d1d (MXF4 path);
  FP8xFP4 and other mixed cases unchanged -> main's sm100_fp8_fp4_gemm_1d1d.

Status: host build clean, JIT NVCC clean (was 19 errors before). Runtime
IMA (illegal memory access) on first launch -- next: localize via compute-sanitizer (TMA descriptor
K-unit or SF tensor layout suspected). tests/test_fp4.py copied as-is from
fea-fp4; needs recipe=(1,1,32) update for FP4 SF granularity.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
…lamp

Stack of incremental fixes on top of previous port commit (7d27b8a). Build
clean, JIT clean, kernel launches without IMA. NaN output remains (deeper
SF/data-layout debug needed).

Wrapper changes (sm100_fp4_gemm_1d1d.hpp):
- view int8 (kPackedFP4) tensors as int32 before TMA descriptor creation, so
  TMA uses CU_TENSOR_MAP_DATA_TYPE_INT32 (matches my kernel's int32-packed
  smem expectation, sidesteps main's 16U4_ALIGN16B unpacked-smem path).
- Convert k and block_k to int32 units (k/8, block_k/4) for TMA + kernel
  template instantiation.
- num_stages clamp: my kernel allocates 2x SF smem/stage vs main's heuristic
  estimate; cap stages so total smem fits 232448-byte capacity and
  num_stages <= num_k_blocks (fea-fp4 invariant).
- Single-CTA force (v0): override cluster_m=cluster_n=1; recompute
  storage_config / pipeline_config so they stay self-consistent with the
  new layout. main's heuristic picks cluster=2 too aggressively for shapes
  fea-fp4 stayed single-CTA on.
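
The unit conversion and the stage clamp described above can be sketched in a few lines. This is a minimal Python sketch, not the wrapper's code: the helper names, `per_stage_bytes`, and `fixed_bytes` are hypothetical, and the example byte counts are made up for illustration. Only the 232448-byte capacity, the k/8 and block_k/4 divisors, and the `num_stages <= num_k_blocks` invariant come from the commit message.

```python
SMEM_CAPACITY = 232448  # bytes, as quoted in the commit message


def to_int32_units(k_fp4: int, block_k_bytes: int) -> tuple:
    """Convert logical K (FP4 element count) and BLOCK_K (bytes) to int32 counts.

    8 FP4 values fit in one int32, and 4 bytes fit in one int32.
    Assumes both inputs are exact multiples of the divisor.
    """
    return k_fp4 // 8, block_k_bytes // 4


def clamp_num_stages(heuristic_stages: int, per_stage_bytes: int,
                     fixed_bytes: int, num_k_blocks: int) -> int:
    """Cap the pipeline depth by shared-memory budget and by the K-block count.

    per_stage_bytes and fixed_bytes are hypothetical inputs standing in for
    the per-stage A/B/SF tiles and the stage-independent buffers.
    """
    fit = (SMEM_CAPACITY - fixed_bytes) // per_stage_bytes
    stages = min(heuristic_stages, fit, num_k_blocks)
    if stages < 1:
        raise ValueError("not even one stage fits in shared memory")
    return stages


if __name__ == "__main__":
    print(to_int32_units(7168, 128))
    print(clamp_num_stages(6, per_stage_bytes=33792, fixed_bytes=66560,
                           num_k_blocks=28))
```

With these illustrative numbers the shared-memory budget is the binding limit; for a very short K the `num_k_blocks` bound takes over instead.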

Status: kernel runs, produces wrong output (suspected SF byte ordering or
TMEM column placement mismatch between main's UE8M0 packed transform and
my kernel's read expectation). fea-fp4 tests historically used sf=1.0
which masks SF reading bugs (1.0 multiplier regardless of byte order);
my smoke test with main's generators uses varying SF, exposing the issue.
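
Why a constant sf=1.0 hides scale-ordering bugs can be shown with a toy model. The sketch below assumes UE8M0 scales decode as 2^(byte - 127), one scale per 32 elements, and simulates a byte-order bug by reversing the scale bytes inside each 4-byte word. The function names are hypothetical and the model ignores the real TMEM layout.

```python
def ue8m0_to_float(byte: int) -> float:
    """Decode an unsigned 8-bit exponent-only scale (bias 127)."""
    return 2.0 ** (byte - 127)


def scaled_dot(a, b, sfa, sfb, group=32):
    """Block-scaled dot product: one (sfa, sfb) pair per `group` elements."""
    total = 0.0
    for g in range(len(a) // group):
        lo, hi = g * group, (g + 1) * group
        partial = sum(x * y for x, y in zip(a[lo:hi], b[lo:hi]))
        total += partial * ue8m0_to_float(sfa[g]) * ue8m0_to_float(sfb[g])
    return total


def reverse_bytes_in_words(sf):
    """Simulated bug: scale bytes read in reversed order within each int32."""
    out = []
    for i in range(0, len(sf), 4):
        out.extend(reversed(sf[i:i + 4]))
    return out


# Four K groups whose partial sums differ (32, 64, 96, 128).
a = [float(i // 32 + 1) for i in range(128)]
b = [1.0] * 128
ones = [127] * 4                # every scale is exactly 1.0
varying = [126, 127, 128, 129]  # scales 0.5, 1, 2, 4

if __name__ == "__main__":
    # Constant scales: the buggy read gives the same answer.
    print(scaled_dot(a, b, ones, ones),
          scaled_dot(a, b, reverse_bytes_in_words(ones), ones))
    # Varying scales: the buggy read is visibly wrong.
    print(scaled_dot(a, b, varying, ones),
          scaled_dot(a, b, reverse_bytes_in_words(varying), ones))
```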

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Continuing port towards numerical correctness. Current state: kernel builds
clean, JIT clean, runs without IMA, but writes zeros / wrong values.

Wrapper changes (sm100_fp4_gemm_1d1d.hpp):
- Override main's layout choice to fea-fp4's tested config: block_m=128,
  block_n=112 (kSwizzleCDMode=32, matches fea-fp4 test coverage), single-CTA
  cluster. main's heuristic picks block_n=16 / cluster=2 which fea-fp4 kernel
  was never validated against.
- Fix sf_packed_k_per_stage formula in recompute_stages_for_fp4 (was off by 4x
  with block_k_int32/4; correct: block_k_bytes/64).

Test adaptation (tests/test_fp4.py):
- pack_fp4_random / pack_fp4_constant now produce int8 (kPackedFP4) instead
  of int32, matching main's API. Same byte layout when viewed as int32.
- fp4_reference handles int8 packing (2 FP4 per byte: low nibble + high
  nibble), works for both int8 and int32 packed inputs.
- run_kernel uses recipe_a=(1,32), recipe_b=(1,32) for FP4 SF granularity
  (was (1,1,128), which fea-fp4 allowed only by commenting out the shape check).
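
The packing claim above (two FP4 values per byte, identical bytes whether viewed as int8 or int32) can be checked with a small decoder. This is an illustrative sketch, not the test's `fp4_reference`: it assumes the standard E2M1 value table, assumes the low nibble holds the earlier element, and assumes a little-endian int32 view. All function names are hypothetical.

```python
import struct

# E2M1 magnitudes for the low three bits; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_nibble(nibble: int) -> float:
    """Decode one 4-bit E2M1 code to a Python float."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def unpack_fp4_bytes(packed: bytes) -> list:
    """Unpack byte-packed FP4: low nibble first, then high nibble (assumed order)."""
    out = []
    for byte in packed:
        out.append(decode_nibble(byte & 0xF))
        out.append(decode_nibble(byte >> 4))
    return out


def unpack_fp4_int32(words) -> list:
    """Unpack int32-packed FP4 by reinterpreting the words as little-endian bytes."""
    return unpack_fp4_bytes(struct.pack(f"<{len(words)}i", *words))


if __name__ == "__main__":
    packed = bytes([0x21, 0xF7, 0x00, 0x4A])
    word = struct.unpack("<i", packed)[0]
    # Both views decode to the same eight values.
    print(unpack_fp4_bytes(packed))
    print(unpack_fp4_int32([word]))
```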

Open issue: kernel writes 0 / wrong values for both test_constant (sf=1.0)
and test_random. Needs device-side printf to localize: A/B smem content vs
expected layout, MMA output before vs after epilogue, SF byte ordering in
TMEM after UTCCP.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
…gue path)

Root cause of port "writes zeros" bug found: the kernel's epilogue store
loop has only an `if constexpr (cd_dtype_t == float)` branch — no `else` for
bf16 — so when d.dtype is bf16, the store body is a no-op and d stays at
its initial value (zero). This matches the existing `fp4_output_dtype.md` memory
note, which records the missing bf16 epilogue branch as a TODO.

fea-fp4's run_kernel uses `torch.float32` for d; my earlier port test was
generating with `torch.bfloat16` (main's typical FP8 output), which silently
hit the no-op branch. Switching back to fp32 makes everything work.
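
The failure mode is easy to reproduce in miniature. The following Python analogy (not the kernel code, and with hypothetical function names) shows a store routine that handles only one dtype and silently does nothing for any other, next to a variant that rejects unsupported dtypes up front, which is the kind of check a later commit in this PR adds on the host side.

```python
def store_unguarded(acc, d, d_dtype):
    """Mimics a dtype branch with no fallback: non-float32 is a silent no-op."""
    if d_dtype == "float32":
        d[:] = acc
    # No else branch: for any other dtype, d keeps its initial contents.


def store_guarded(acc, d, d_dtype):
    """Rejects unsupported output dtypes instead of leaving d untouched."""
    if d_dtype != "float32":
        raise TypeError(f"unsupported output dtype: {d_dtype}")
    d[:] = acc


if __name__ == "__main__":
    d = [0.0, 0.0]
    store_unguarded([1.0, 2.0], d, "bfloat16")
    print(d)  # still zeros, with no error raised
```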

Verified PASS on fea-fp4-synced after this fix:
  test_constant (9/9 PASS, exact integer match for sf=1.0 case)
  test_random   (9/9 PASS, max_diff=0.0000 vs CPU reference)
  test_random_sf (8/8 PASS, max_diff=0.0000 with varying SF)

This closes the dense-FP4 port. m-grouped (contiguous + masked) paths still
need a similar walkthrough and test_m_grouped_* validation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Complete the test adaptation to main's API surface:
- Switch m_grouped helpers to int8 (kPackedFP4) packing for both 2D and 3D
  pack_fp4_random variants.
- Rename remaining deep_gemm.m_grouped_fp8_gemm_* -> m_grouped_fp8_fp4_gemm_*.
- Use recipe_a=(1,32) / recipe_b=(1,32) for FP4 SF granularity throughout.

csrc/apis/gemm.hpp: relax `d.scalar_type() == kBFloat16` assertion in the two
m-grouped dispatch paths to also allow `kFloat` when both A/B are kPackedFP4
(my MXF4 kernel hardcodes fp32 output; bf16 epilogue branch remains TODO).

Final result: `python tests/test_fp4.py` -> ALL FP4 TESTS PASSED
  constant, random, sweep, asymmetric, uniform_sf, random_sf, multicast,
  m_grouped (contiguous), m_grouped_masked, trtllm_cmp.

This completes the FP4 main-sync port from fea-fp4 -> fea-fp4-synced.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Replace the v0 hardcoded layout (block_n=112, cluster=1, swap_ab=false) with
a port of fea-fp4's wave-aware FP4 heuristic. Now returns main's nested Layout
struct instead of fea-fp4's flat GemmConfig.

pick_fp4_layout(gemm_type, m, n, k, num_groups, num_sms, expected_m_per_group):
- block_m = 128 (fixed, MXF4 UMMA_M)
- block_k = 128 bytes (= 32 int32 = 256 FP4 per K block)
- block_n picked by wave count + composite score = est_stages^2 * bn;
  2-epi-stage tiebreak when waves >= 2 (TMEM double-buffer benefit needs
  >= 2 tiles per SM).
- cluster_m=2, cluster_n=1 (B-multicast) when m >= 512, divisible, and gemm
  type is Normal / KGroupedContiguous. m-grouped stays cluster=1.
- swap_ab=true for MGroupedContiguous when useful_m_per_group < BLOCK_M
  (sparse MoE benefit; forces block_n=128 + cluster=1 per kernel asserts).
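
The gating rules listed above can be sketched as follows. This is a simplified illustration, not pick_fp4_layout itself: the `Layout` fields, the string gemm-type tags, and the reading of "divisible" as "m is a multiple of 2 * BLOCK_M" are assumptions. How the wave count and the composite score are combined to choose block_n is not spelled out in the commit message, so the sketch exposes both quantities but takes block_n as an input.

```python
from dataclasses import dataclass

BLOCK_M = 128  # fixed, per the commit message


@dataclass(frozen=True)
class Layout:
    block_m: int
    block_n: int
    cluster_m: int
    cluster_n: int
    swap_ab: bool


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def num_waves(m: int, n: int, block_n: int, num_sms: int) -> int:
    """Number of full or partial waves needed to cover all output tiles."""
    return ceil_div(ceil_div(m, BLOCK_M) * ceil_div(n, block_n), num_sms)


def composite_score(est_stages: int, block_n: int) -> int:
    """Score from the commit message: est_stages^2 * block_n."""
    return est_stages ** 2 * block_n


def gate(gemm_type: str, m: int, block_n: int, expected_m_per_group: int) -> Layout:
    """Apply the swap_ab and cluster rules to a candidate block_n."""
    if gemm_type == "MGroupedContiguous" and expected_m_per_group < BLOCK_M:
        # swap_ab forces block_n=128 and a single-CTA cluster.
        return Layout(BLOCK_M, 128, 1, 1, True)
    use_cluster = (gemm_type in ("Normal", "KGroupedContiguous")
                   and m >= 512
                   and m % (2 * BLOCK_M) == 0)  # assumed meaning of "divisible"
    return Layout(BLOCK_M, block_n, 2 if use_cluster else 1, 1, False)


if __name__ == "__main__":
    # 148 is an assumed SM count used only for illustration.
    print(num_waves(4096, 4096, 128, 148))
    print(gate("Normal", 4096, 224, 4096))
    print(gate("MGroupedContiguous", 8192, 224, 64))
```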

Wrapper changes:
- 3 entry points now call pick_fp4_layout instead of get_best_config + override.
- m-grouped contiguous derives useful_per_group from (m_indices >= 0).sum() / G
  to drive swap_ab gating accurately (matches fea-fp4 wrapper behavior).
- m-grouped masked passes expected_m as per-group hint (swap_ab itself disabled
  for masked per fea-fp4 v0).
- recompute_stages_for_fp4: fix CD smem formula to account for swap_ab path
  (STORE_BLOCK_M=16 * STORE_BLOCK_N=block_n * sizeof(cd_dtype=fp32)) — was
  using non-swap formula and underestimating CD smem, causing IMA on the
  swap_ab path when num_stages was over-budgeted.

Perf vs fea-fp4 (m-grouped contiguous prod shapes):
  G=4 m=8192 N=4096 K=7168: 4290 -> 4255 TFLOPS (-0.8%)
  G=4 m=8192 N=7168 K=2048: 2879 -> 2874 TFLOPS (-0.2%)
  G=8 m=4096 N=4096 K=7168: 4270 -> 4249 TFLOPS (-0.5%)
  G=8 m=4096 N=7168 K=2048: 2877 -> 2873 TFLOPS (-0.1%)
Parity within noise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
The previous test_fp4.py (752 lines) used custom pack_fp4_random + CPU LUT
reference + bespoke PASS/FAIL printing, predating the QuantConfig / generators
infrastructure. This rewrites it to match the official test_fp8_fp4.py:

- Reuse generate_normal / generate_m_grouped_contiguous / generate_m_grouped_masked
  from tests/generators.py instead of pack_fp4_random + per-test fp4_reference.
- QuantConfig((32, 32, True, True)) drives both quantization and reference.
- assert + max_diff() threshold instead of custom PASS/FAIL prints.
- '> Perf (m=..., n=..., k=..., 1D1D, layout=NT, FP32): X us | Y TFLOPS | Z GB/s'
  output line matches main's official format exactly.
- Three test functions named to align with test_fp8_fp4.py:
    test_gemm, test_m_grouped_gemm_contiguous, test_m_grouped_gemm_masked.
- Entry block: torch.manual_seed(0); random.seed(0); print('Library path:'); test_*().

Shape coverage:
  Dense:    m in [128, 4096] x 7 (n, k) production combos from main's nk_list.
  Contig:   (num_groups, m_per_group) in [(4, 8192), (8, 4096)] x 4 (n, k).
  Masked:   (num_groups, expected_m) in [(1, 1024), (2, 512), (4, 256)] x 2 (n, k).
All shapes PASS with diff < QuantConfig.max_diff() = 0.02.
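
For readers unfamiliar with this style of check, here is a stdlib-only sketch of a threshold assertion and of the TFLOPS figure in the perf line. The similarity-based `calc_diff` below is an assumed metric chosen for illustration; the commit message only states that the diff is compared against QuantConfig.max_diff() = 0.02. The TFLOPS formula is the usual 2·m·n·k flop count divided by elapsed time.

```python
MAX_DIFF = 0.02  # threshold quoted in the commit message


def calc_diff(x, y) -> float:
    """1 - 2*<x,y> / (|x|^2 + |y|^2): 0 for identical inputs, 2 for negated ones."""
    denom = sum(a * a + b * b for a, b in zip(x, y))
    if denom == 0:
        return 0.0
    return 1.0 - 2.0 * sum(a * b for a, b in zip(x, y)) / denom


def tflops(m: int, n: int, k: int, seconds: float) -> float:
    """Throughput of an m x n x k GEMM, counting one multiply-add as 2 flops."""
    return 2 * m * n * k / seconds / 1e12


if __name__ == "__main__":
    ref = [1.0, 2.0, 3.0]
    out = [1.0, 2.0, 3.0]
    assert calc_diff(out, ref) < MAX_DIFF
    print(f"{tflops(4096, 4096, 7168, 45e-6):.0f} TFLOPS")
```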

Perf parity check (same kernel, same heuristic, only test wrapper changed):
  4, 8192, 4096, 7168: 4255 -> 4306 TFLOPS (+1.2%, noise)
  4, 8192, 7168, 2048: 2874 -> 2891 TFLOPS (+0.6%, noise)
  8, 4096, 4096, 7168: 4249 -> 4305 TFLOPS (+1.3%, noise)
  8, 4096, 7168, 2048: 2873 -> 2849 TFLOPS (-0.8%, noise)

Note: kernel hardcodes fp32 output (bf16 epilogue is a remaining TODO); the
test sets out_dtype=torch.float and casts ref_d from generators' bf16 default
for the diff check.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Comment hygiene pass against the official FP8 kernel's style:
- Removed 4 unused helper functions (fp4_e2m1_to_float, pack_ue8m0_*,
  swizzled_smem_k_major_idx) — leftover debug code, never called.
- Replaced 18 '// ========== Chinese label ==========' section banners with
  short English labels matching the FP8 kernel ('// MMA configs',
  '// SF configs', '// Shared memory sizes', '// Block scheduler', etc.).
- Translated the M-wave / warp-row band stride explanation: kTmemDpStride
  / kTmemMWaveStride / kTmemWarpRowBandStride were only referenced from
  commented-out alternative tmem_addr formulas, so all three constants
  and the associated dead-code blocks are removed.
- Removed verbose SF-packing math docstrings (3 lines) and inline
  type-cast notes — the code is self-evident.

Net: -73 lines (-100 deletions / +27 insertions). 0 Chinese characters left
in any branch file. test_fp4.py still PASS (exit 0).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Apply the P0/P1/P2 items from FP4_PR_REVIEW_NOTES.md that don't change the
swap_ab gating contract:

P0-1. Add cudaGridDependencySynchronize() before block scheduling in
      sm100_fp4_gemm_1d1d.cuh. Mirrors sm100_fp8_fp4_gemm_1d1d.cuh — required
      because the API transforms SFA/SFB in a prior kernel; without the wait
      the GEMM can race the producer.

P0-2. Tighten the D dtype assertions in csrc/apis/gemm.hpp at all three
      FP4xFP4 dispatch sites (fp8_fp4_gemm_nt, m_grouped_*_contiguous,
      m_grouped_*_masked): when both a and b are kPackedFP4, require
      d.scalar_type() == kFloat. The kernel has no bf16 epilogue store path;
      previously the assertion still accepted bf16 and silently produced
      garbage output.

P1-2. Drop trailing whitespace at the `} else if` join in the kernel
      (`git diff --check origin/main..HEAD` is now clean).

P2.   Remove unused template/runtime plumbing:
      - kNumLastStages template parameter (kernel body recomputes the value
        from runtime shape_k; the compile-time arg was never referenced).
      - tensor_map_c kernel arg, Args field, wrapper-side make_tma_cd_desc
        call, and the local `cd` alias. Accumulation is folded into D
        pre-launch by `d.copy_(c.value())`, so the kernel never loads C.
      - compute_num_last_stages_fp4() helper (callers gone).

P2 (style). Rewrite remaining comments to drop personal/migration wording
      ("my kernel", "fea-fp4", "Phase 2", "NaN issues unresolved",
      "wait simpler", exploratory math notes). Keep only invariants needed
      to read the code.

Swap_ab gating in pick_fp4_layout still uses
`(grouped_layout >= 0).sum().item()` for MGroupedContiguous (P0-3 deferred).
That host sync only fires when swap_ab might apply; documented as a known
limitation for CUDA-graph callers.

Build clean. test_fp4.py PASS (exit 0). A manual check with a bf16 D tensor now
raises the strict-dtype assertion instead of returning zeroed D.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Replaces the per-launch device→host sync
`(grouped_layout >= 0).sum().item<int64_t>() / num_groups` with a host-side
hint plumbed through the dispatch:

- sm100_m_grouped_fp4_gemm_contiguous_1d1d gains an `expected_m_per_group`
  int parameter, consumed directly by pick_fp4_layout for swap_ab gating.
- csrc/apis/gemm.hpp dispatch passes
  `expected_m_for_psum_layout.value_or(m / num_groups)`:
    - Caller-known sparse MoE workloads (DeepSeek-style routed expert
      shapes) pass the actual useful per-group row count via the existing
      `expected_m_for_psum_layout` keyword and keep the swap_ab speedup.
    - Dense callers get `m / num_groups`, which matches the actual
      per-group count when m_indices has no padding.

No Python API surface change; the optional `expected_m_for_psum_layout`
argument was already available. Removes the only host sync inside FP4
grouped dispatch — kernel is now safe to capture in a CUDA graph.

Build clean. test_fp4.py PASS (exit 0). Production-shape perf unchanged
(swap_ab didn't fire for these dense shapes either way).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
… style)

Re-audit follow-ups from FP4_PR_REVIEW_NOTES.md:

1. Guard the FP4xFP4 dispatch against MN-major / transposed inputs.
   `a.view(torch::kInt)` requires `stride(-1) == 1`, which the API's
   nn/tn/tt aliases break. Add explicit major checks at all three
   FP4xFP4 dispatch sites in csrc/apis/gemm.hpp:
     DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == ...);

2. Assert `k % 8 == 0` at the same dispatch sites. The wrapper divides
   logical FP4 K by 8 to get packed int32 K; non-8-multiple K would be
   silently truncated.

3. Drop the redundant `c -> d` copy in sm100_fp4_gemm_1d1d. early_return()
   in csrc/apis/gemm.hpp already performs the merge before dispatch.

4. Stop overloading expected_m_for_psum_layout as a sparse-MoE hint —
   the shared layout-check code asserts it is unset when use_psum_layout
   is false, so the .value_or() never fired. The FP4xFP4 contiguous path
   now uses m / num_groups unconditionally and the comment reflects that.

5. Comment cleanup: drop "==========" banners and "Existing non-swap
   epilogue (unchanged)" wording in the .cuh, drop the "hardcodes" /
   "TODO" / "proper fp32 reference" notes in tests/test_fp4.py and the
   "matches main's block_k convention" line in the .hpp.

Build clean, full test_fp4.py PASS (exit 0, 0 FAIL).
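
Items 1 and 2 amount to two contract checks before the int32 reinterpretation. The Python sketch below mirrors them under simplifying assumptions: strides are given as plain tuples in element units, and `check_fp4_contract` is a hypothetical name, not the DG_HOST_ASSERT code in csrc/apis/gemm.hpp.

```python
def check_fp4_contract(k: int, a_strides: tuple, b_strides: tuple) -> int:
    """Validate K-major operands and K divisibility, then return K in int32 units."""
    # Reinterpreting packed bytes as int32 needs a contiguous innermost dimension.
    if a_strides[-1] != 1 or b_strides[-1] != 1:
        raise ValueError("FP4xFP4 path requires K-major A and B (stride(-1) == 1)")
    # 8 FP4 values per int32: any remainder would be dropped by integer division.
    if k % 8 != 0:
        raise ValueError(f"k={k} is not a multiple of 8")
    return k // 8


if __name__ == "__main__":
    # Without the check, k=100 would become 12 int32 words, i.e. only 96 of
    # the 100 FP4 values, silently dropping the last 4.
    print(100 // 8 * 8)
    print(check_fp4_contract(7168, (3584, 1), (3584, 1)))
```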

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
The previous comment claimed sparse MoE callers pass an actual
per-group row count, but no such path exists — the API passes
m / num_groups unconditionally. Rewrite to describe the real
behavior: dense layouts are accurate, sparse layouts with -1
padding over-estimate and conservatively disable swap_ab, which
is the accepted trade-off vs a device sync to inspect m_indices.
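
A small numeric example makes the trade-off concrete. The sketch assumes a contiguous layout in which each group owns a fixed slot of rows and unused rows carry m_indices == -1; the slot sizes and helper names are invented for illustration, while BLOCK_M = 128 and the m / num_groups estimate come from the PR.

```python
BLOCK_M = 128


def host_estimate(m: int, num_groups: int) -> int:
    """Estimate used by the dispatch: no device sync, no look at m_indices."""
    return m // num_groups


def useful_rows_per_group(m_indices, num_groups: int) -> int:
    """What inspecting m_indices (a device sync in practice) would report."""
    return sum(1 for i in m_indices if i >= 0) // num_groups


def swap_ab_enabled(expected_m_per_group: int) -> bool:
    return expected_m_per_group < BLOCK_M


def make_layout(num_groups: int, slot: int, valid: int) -> list:
    """Each group gets `slot` rows, of which only the first `valid` are real."""
    rows = []
    for g in range(num_groups):
        rows += [g] * valid + [-1] * (slot - valid)
    return rows


if __name__ == "__main__":
    sparse = make_layout(num_groups=4, slot=256, valid=32)
    est = host_estimate(len(sparse), 4)
    real = useful_rows_per_group(sparse, 4)
    # Padding inflates the estimate, so swap_ab stays off even though the
    # real per-group row count is below BLOCK_M.
    print(est, swap_ab_enabled(est))
    print(real, swap_ab_enabled(real))
```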

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
z52527 changed the base branch from main to nv_dev on June 2, 2026 02:39