feat: support mxfp4 × mxfp4 gemm on sm100 #348
Open
z52527 wants to merge 12 commits into
Migrate fea-fp4's standalone MXF4 (SM100_MMA_MXF4_SS) GEMM on top of main's
post-PR-304 refactor. Keeps the FP4-specialized hardware path distinct from
main's MXF8F6F4 unified kernel.
- csrc/jit_kernels/impls/sm100_fp4_gemm_1d1d.hpp: rebuilt on GemmDesc + nested
GemmConfig {layout, storage_config, pipeline_config, launch_config}. Forces
a_dtype=b_dtype=kPackedFP4. Three entry points: dense, m-grouped contiguous,
m-grouped masked. BLOCK_K and SHAPE_K convert main's byte/FP4-count to my
kernel's int32-count units.
- deep_gemm/include/deep_gemm/impls/sm100_fp4_gemm_1d1d.cuh: scheduler ->
sched::Scheduler with shape_k ctor arg; KGroupedIndexType -> sched::IndexType;
helpers via math:: / ptx:: / mma::sm100:: namespaces; tma_copy signature
(multicast moved to runtime); make_runtime_instr_desc_with_sf_id takes
sfb_id. swap_ab epilogue / MXF4 MMA path preserved.
- csrc/apis/gemm.hpp: dispatch FP4xFP4 -> sm100_fp4_gemm_1d1d (MXF4 path);
FP8xFP4 and other mixed cases unchanged -> main's sm100_fp8_fp4_gemm_1d1d.
Status: host build clean, JIT NVCC clean (was 19 errors before). Runtime
IMA on first launch -- next: localize via compute-sanitizer (TMA descriptor
K-unit or SF tensor layout suspected). tests/test_fp4.py copied as-is from
fea-fp4; needs recipe=(1,1,32) update for FP4 SF granularity.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
…lamp
Stack of incremental fixes on top of previous port commit (7d27b8a). Build
clean, JIT clean, kernel launches without IMA. NaN output remains (deeper
SF/data-layout debug needed).
Wrapper changes (sm100_fp4_gemm_1d1d.hpp):
- view int8 (kPackedFP4) tensors as int32 before TMA descriptor creation, so
TMA uses CU_TENSOR_MAP_DATA_TYPE_INT32 (matches my kernel's int32-packed smem
expectation, sidesteps main's 16U4_ALIGN16B unpacked-smem path).
- Convert k and block_k to int32 units (k/8, block_k/4) for TMA + kernel
template instantiation.
- num_stages clamp: my kernel allocates 2x SF smem/stage vs main's heuristic
estimate; cap stages so total smem fits 232448-byte capacity and
num_stages <= num_k_blocks (fea-fp4 invariant).
- Single-CTA force (v0): override cluster_m=cluster_n=1; recompute
storage_config / pipeline_config so they stay self-consistent with the new
layout. main's heuristic picks cluster=2 too aggressively for shapes fea-fp4
stayed single-CTA on.
Status: kernel runs, produces wrong output (suspected SF byte ordering or TMEM
column placement mismatch between main's UE8M0 packed transform and my kernel's
read expectation). fea-fp4 tests historically used sf=1.0 which masks SF
reading bugs (1.0 multiplier regardless of byte order); my smoke test with
main's generators uses varying SF, exposing the issue.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
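Why sf=1.0 hides a byte-order bug can be shown with a small host-side sketch. It assumes the standard UE8M0 encoding (value = 2^(e - 127)) and four scales packed into one 32-bit word; the helper names and the little-endian packing order are illustrative assumptions, not taken from the kernel.

```python
def ue8m0_to_float(e):
    """Decode one UE8M0 byte: an 8-bit biased exponent, no sign, no mantissa."""
    assert 0 <= e <= 254  # 255 encodes NaN in the MX scale format
    return 2.0 ** (e - 127)


def pack_sf(exponents):
    """Pack four UE8M0 bytes into one 32-bit word (byte 0 lowest; an assumption)."""
    word = 0
    for i, e in enumerate(exponents):
        word |= (e & 0xFF) << (8 * i)
    return word


def unpack_sf(word, order=(0, 1, 2, 3)):
    """Read the four scales back, visiting the bytes in the given order."""
    raw = [(word >> (8 * i)) & 0xFF for i in range(4)]
    return [ue8m0_to_float(raw[i]) for i in order]


ones = pack_sf([127, 127, 127, 127])    # sf = 1.0 in every slot
varied = pack_sf([125, 126, 127, 128])  # sf = 0.25, 0.5, 1.0, 2.0

# A reader with the wrong byte order still sees all-ones scales...
print(unpack_sf(ones), unpack_sf(ones, order=(3, 2, 1, 0)))
# ...but sees permuted scales as soon as the scales differ.
print(unpack_sf(varied), unpack_sf(varied, order=(3, 2, 1, 0)))
```

With constant scales every byte permutation decodes to the same values, so only a test with varying SF can distinguish a correct reader from a misordered one.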
Continuing port towards numerical correctness. Current state: kernel builds
clean, JIT clean, runs without IMA, but writes zeros / wrong values.
Wrapper changes (sm100_fp4_gemm_1d1d.hpp):
- Override main's layout choice to fea-fp4's tested config: block_m=128,
block_n=112 (kSwizzleCDMode=32, matches fea-fp4 test coverage), single-CTA
cluster. main's heuristic picks block_n=16 / cluster=2, which the fea-fp4
kernel was never validated against.
- Fix sf_packed_k_per_stage formula in recompute_stages_for_fp4 (was off by
4x with block_k_int32/4; correct: block_k_bytes/64).
Test adaptation (tests/test_fp4.py):
- pack_fp4_random / pack_fp4_constant now produce int8 (kPackedFP4) instead
of int32, matching main's API. Same byte layout when viewed as int32.
- fp4_reference handles int8 packing (2 FP4 per byte: low nibble + high
nibble), works for both int8 and int32 packed inputs.
- run_kernel uses recipe_a=(1,32), recipe_b=(1,32) for FP4 SF granularity
(was (1,1,128), which fea-fp4 allowed by commenting out the shape check).
Open issue: kernel writes 0 / wrong values for both test_constant (sf=1.0)
and test_random. Needs device-side printf to localize: A/B smem content vs
expected layout, MMA output before vs after epilogue, SF byte ordering in
TMEM after UTCCP.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
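The int8 packing that fp4_reference has to decode can be sketched on the host as follows. This is not the test's actual code: the function names are hypothetical, and it assumes the low nibble holds the earlier element. The E2M1 value table (1 sign bit, 2 exponent bits, 1 mantissa bit) is the standard FP4 encoding.

```python
# Magnitudes for the 3 low bits of an E2M1 code; bit 3 is the sign.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(code):
    """Decode one 4-bit E2M1 code into a Python float."""
    sign = -1.0 if code & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[code & 0x7]


def unpack_fp4_bytes(packed):
    """Expand packed bytes into FP4 values, two per byte.

    Assumes the low nibble is the first element and the high nibble the
    second. Signed int8 inputs (-128..127) are accepted and masked to 8 bits.
    """
    out = []
    for byte in packed:
        byte &= 0xFF
        out.append(decode_e2m1(byte & 0xF))
        out.append(decode_e2m1(byte >> 4))
    return out


# 0x21: low nibble 0x1 -> 0.5, high nibble 0x2 -> 1.0
print(unpack_fp4_bytes([0x21]))
```

Because the decode works byte by byte, the same routine covers an int32-packed buffer once it is reinterpreted as bytes, which is the sense in which the two packings share a byte layout.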
…gue path)
Root cause of port "writes zeros" bug found: the kernel's epilogue store loop
has only an `if constexpr (cd_dtype_t == float)` branch, with no `else` for
bf16, so when d.dtype is bf16, the store body is a no-op and d stays at
initial value (zero). This matches the existing memory `fp4_output_dtype.md`
note that bf16 epilogue branch is a missing TODO.
fea-fp4's run_kernel uses `torch.float32` for d; my earlier port test was
generating with `torch.bfloat16` (main's typical FP8 output), which silently
hit the no-op branch. Switching back to fp32 makes everything work.
Verified PASS on fea-fp4-synced after this fix:
- test_constant (9/9 PASS, exact integer match for sf=1.0 case)
- test_random (9/9 PASS, max_diff=0.0000 vs CPU reference)
- test_random_sf (8/8 PASS, max_diff=0.0000 with varying SF)
This closes the dense-FP4 port. m-grouped (contiguous + masked) paths still
need a similar walkthrough and test_m_grouped_* validation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Complete the test adaptation to main's API surface:
- Switch m_grouped helpers to int8 (kPackedFP4) packing for both 2D and 3D
pack_fp4_random variants.
- Rename remaining deep_gemm.m_grouped_fp8_gemm_* -> m_grouped_fp8_fp4_gemm_*.
- Use recipe_a=(1,32) / recipe_b=(1,32) for FP4 SF granularity throughout.
csrc/apis/gemm.hpp: relax `d.scalar_type() == kBFloat16` assertion in the two
m-grouped dispatch paths to also allow `kFloat` when both A/B are kPackedFP4
(my MXF4 kernel hardcodes fp32 output; bf16 epilogue branch remains TODO).
Final result: `python tests/test_fp4.py` -> ALL FP4 TESTS PASSED
constant, random, sweep, asymmetric, uniform_sf, random_sf, multicast,
m_grouped (contiguous), m_grouped_masked, trtllm_cmp.
This completes the FP4 main-sync port from fea-fp4 -> fea-fp4-synced.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Replace the v0 hardcoded layout (block_n=112, cluster=1, swap_ab=false) with a
port of fea-fp4's wave-aware FP4 heuristic. Now returns main's nested Layout
struct instead of fea-fp4's flat GemmConfig.
pick_fp4_layout(gemm_type, m, n, k, num_groups, num_sms, expected_m_per_group):
- block_m = 128 (fixed, MXF4 UMMA_M)
- block_k = 128 bytes (= 32 int32 = 256 FP4 per K block)
- block_n picked by wave count + composite score = est_stages^2 * bn;
2-epi-stage tiebreak when waves >= 2 (TMEM double-buffer benefit needs >= 2
tiles per SM).
- cluster_m=2, cluster_n=1 (B-multicast) when m >= 512, divisible, and gemm
type is Normal / KGroupedContiguous. m-grouped stays cluster=1.
- swap_ab=true for MGroupedContiguous when useful_m_per_group < BLOCK_M
(sparse MoE benefit; forces block_n=128 + cluster=1 per kernel asserts).
Wrapper changes:
- 3 entry points now call pick_fp4_layout instead of get_best_config +
override.
- m-grouped contiguous derives useful_per_group from
(m_indices >= 0).sum() / G to drive swap_ab gating accurately (matches
fea-fp4 wrapper behavior).
- m-grouped masked passes expected_m as per-group hint (swap_ab itself
disabled for masked per fea-fp4 v0).
- recompute_stages_for_fp4: fix CD smem formula to account for swap_ab path
(STORE_BLOCK_M=16 * STORE_BLOCK_N=block_n * sizeof(cd_dtype=fp32)); was using
non-swap formula and underestimating CD smem, causing IMA on the swap_ab path
when num_stages was over-budgeted.
Perf vs fea-fp4 (m-grouped contiguous prod shapes):
G=4 m=8192 N=4096 K=7168: 4290 -> 4255 TFLOPS (-0.8%)
G=4 m=8192 N=7168 K=2048: 2879 -> 2874 TFLOPS (-0.2%)
G=8 m=4096 N=4096 K=7168: 4270 -> 4249 TFLOPS (-0.5%)
G=8 m=4096 N=7168 K=2048: 2877 -> 2873 TFLOPS (-0.1%)
Parity within noise.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
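The swap_ab and cluster gating described in this commit can be sketched in Python as below. This is a simplified illustration, not the C++ heuristic: the names LayoutSketch and pick_layout_sketch are hypothetical, block_n selection (wave count and composite score) is left to the caller, and the divisibility condition is passed in as a flag because the commit message does not spell it out.

```python
from dataclasses import dataclass

BLOCK_M = 128  # fixed by the MXF4 UMMA_M, per the commit message


@dataclass(frozen=True)
class LayoutSketch:
    block_m: int
    block_n: int
    cluster_m: int
    cluster_n: int
    swap_ab: bool


def pick_layout_sketch(gemm_type, m, useful_m_per_group, block_n, m_divisible):
    """Apply only the swap_ab and cluster rules; block_n is chosen elsewhere."""
    # Sparse m-grouped contiguous case: swap A/B, which forces block_n=128
    # and a single-CTA cluster.
    if gemm_type == "MGroupedContiguous" and useful_m_per_group < BLOCK_M:
        return LayoutSketch(BLOCK_M, 128, 1, 1, True)
    # B-multicast across a 2-CTA cluster only for dense-style GEMM types.
    multicast = (
        gemm_type in ("Normal", "KGroupedContiguous")
        and m >= 512
        and m_divisible
    )
    return LayoutSketch(BLOCK_M, block_n, 2 if multicast else 1, 1, False)


print(pick_layout_sketch("MGroupedContiguous", 8192, 64, 112, True))
print(pick_layout_sketch("Normal", 4096, 4096, 112, True))
```

Keeping the swap_ab branch first mirrors the stated constraint that swap_ab overrides both block_n and the cluster shape.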
The previous test_fp4.py (752 lines) used custom pack_fp4_random + CPU LUT
reference + bespoke PASS/FAIL printing, predating the QuantConfig / generators
infrastructure. This rewrites it to match the official test_fp8_fp4.py:
- Reuse generate_normal / generate_m_grouped_contiguous / generate_m_grouped_masked
from tests/generators.py instead of pack_fp4_random + per-test fp4_reference.
- QuantConfig((32, 32, True, True)) drives both quantization and reference.
- assert + max_diff() threshold instead of custom PASS/FAIL prints.
- '> Perf (m=..., n=..., k=..., 1D1D, layout=NT, FP32): X us | Y TFLOPS | Z GB/s'
output line matches main's official format exactly.
- Three test functions named to align with test_fp8_fp4.py:
test_gemm, test_m_grouped_gemm_contiguous, test_m_grouped_gemm_masked.
- Entry block: torch.manual_seed(0); random.seed(0); print('Library path:'); test_*().
Shape coverage:
Dense: m in [128, 4096] x 7 (n, k) production combos from main's nk_list.
Contig: (num_groups, m_per_group) in [(4, 8192), (8, 4096)] x 4 (n, k).
Masked: (num_groups, expected_m) in [(1, 1024), (2, 512), (4, 256)] x 2 (n, k).
All shapes PASS with diff < QuantConfig.max_diff() = 0.02.
Perf parity check (same kernel, same heuristic, only test wrapper changed):
4, 8192, 4096, 7168: 4255 -> 4306 TFLOPS (+1.2%, noise)
4, 8192, 7168, 2048: 2874 -> 2891 TFLOPS (+0.6%, noise)
8, 4096, 4096, 7168: 4249 -> 4305 TFLOPS (+1.3%, noise)
8, 4096, 7168, 2048: 2873 -> 2849 TFLOPS (-0.8%, noise)
Note: kernel hardcodes fp32 output (bf16 epilogue is a remaining TODO); the
test sets out_dtype=torch.float and casts ref_d from generators' bf16 default
for the diff check.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Comment hygiene pass against the official FP8 kernel's style:
- Removed 4 unused helper functions (fp4_e2m1_to_float, pack_ue8m0_*,
swizzled_smem_k_major_idx) — leftover debug code, never called.
- Replaced 18 '// ========== Chinese label ==========' section banners with
short English labels matching the FP8 kernel ('// MMA configs',
'// SF configs', '// Shared memory sizes', '// Block scheduler', etc.).
- Translated the M-wave / warp-row band stride explanation: kTmemDpStride
/ kTmemMWaveStride / kTmemWarpRowBandStride were only referenced from
commented-out alternative tmem_addr formulas, so all three constants
and the associated dead-code blocks are removed.
- Removed verbose SF-packing math docstrings (3 lines) and inline
type-cast notes — the code is self-evident.
Net: -73 lines (-100 deletions / +27 insertions). 0 Chinese characters left
in any branch file. test_fp4.py still PASS (exit 0).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Apply the P0/P1/P2 items from FP4_PR_REVIEW_NOTES.md that don't change the
swap_ab gating contract:
P0-1. Add cudaGridDependencySynchronize() before block scheduling in
sm100_fp4_gemm_1d1d.cuh. Mirrors sm100_fp8_fp4_gemm_1d1d.cuh — required
because the API transforms SFA/SFB in a prior kernel; without the wait
the GEMM can race the producer.
P0-2. Tighten the D dtype assertions in csrc/apis/gemm.hpp at all three
FP4xFP4 dispatch sites (fp8_fp4_gemm_nt, m_grouped_*_contiguous,
m_grouped_*_masked): when both a and b are kPackedFP4, require
d.scalar_type() == kFloat. The kernel has no bf16 epilogue store path;
previously the assertion still accepted bf16 and silently produced
garbage output.
P1-2. Drop trailing whitespace at the `} else if` join in the kernel
(`git diff --check origin/main..HEAD` is now clean).
P2. Remove unused template/runtime plumbing:
- kNumLastStages template parameter (kernel body recomputes the value
from runtime shape_k; the compile-time arg was never referenced).
- tensor_map_c kernel arg, Args field, wrapper-side make_tma_cd_desc
call, and the local `cd` alias. Accumulation is folded into D
pre-launch by `d.copy_(c.value())`, so the kernel never loads C.
- compute_num_last_stages_fp4() helper (callers gone).
P2 (style). Rewrite remaining comments to drop personal/migration wording
("my kernel", "fea-fp4", "Phase 2", "NaN issues unresolved",
"wait simpler", exploratory math notes). Keep only invariants needed
to read the code.
Swap_ab gating in pick_fp4_layout still uses
`(grouped_layout >= 0).sum().item()` for MGroupedContiguous (P0-3 deferred).
That host sync only fires when swap_ab might apply; documented as a known
limitation for CUDA-graph callers.
Build clean. test_fp4.py PASS (exit 0). A manual check with a bf16 D now
raises the strict-dtype assertion instead of returning zeroed D.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
Replaces the per-launch device→host sync
`(grouped_layout >= 0).sum().item<int64_t>() / num_groups` with a host-side
hint plumbed through the dispatch:
- sm100_m_grouped_fp4_gemm_contiguous_1d1d gains an `expected_m_per_group`
int parameter, consumed directly by pick_fp4_layout for swap_ab gating.
- csrc/apis/gemm.hpp dispatch passes
`expected_m_for_psum_layout.value_or(m / num_groups)`:
- Caller-known sparse MoE workloads (DeepSeek-style routed expert
shapes) pass the actual useful per-group row count via the existing
`expected_m_for_psum_layout` keyword and keep the swap_ab speedup.
- Dense callers get `m / num_groups`, which matches the actual
per-group count when m_indices has no padding.
No Python API surface change; the optional `expected_m_for_psum_layout`
argument was already available. Removes the only host sync inside FP4
grouped dispatch — kernel is now safe to capture in a CUDA graph.
Build clean. test_fp4.py PASS (exit 0). Production-shape perf unchanged
(swap_ab didn't fire for these dense shapes either way).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
… style)
Re-audit follow-ups from FP4_PR_REVIEW_NOTES.md:
1. Guard the FP4xFP4 dispatch against MN-major / transposed inputs.
`a.view(torch::kInt)` requires `stride(-1) == 1`, which the API's
nn/tn/tt aliases break. Add explicit major checks at all three
FP4xFP4 dispatch sites in csrc/apis/gemm.hpp:
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == ...);
2. Assert `k % 8 == 0` at the same dispatch sites. The wrapper divides
logical FP4 K by 8 to get packed int32 K; non-8-multiple K would be
silently truncated.
3. Drop the redundant `c -> d` copy in sm100_fp4_gemm_1d1d. early_return()
in csrc/apis/gemm.hpp already performs the merge before dispatch.
4. Stop overloading expected_m_for_psum_layout as a sparse-MoE hint —
the shared layout-check code asserts it is unset when use_psum_layout
is false, so the .value_or() never fired. The FP4xFP4 contiguous path
now uses m / num_groups unconditionally and the comment reflects that.
5. Comment cleanup: drop "==========" banners and "Existing non-swap
epilogue (unchanged)" wording in the .cuh, drop the "hardcodes" /
"TODO" / "proper fp32 reference" notes in tests/test_fp4.py and the
"matches main's block_k convention" line in the .hpp.
Build clean, full test_fp4.py PASS (exit 0, 0 FAIL).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
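The reason for the `k % 8 == 0` assertion in item 2 can be seen from the unit conversion alone. The sketch below is a host-side illustration with a hypothetical helper name; the only facts taken from the commit are that logical FP4 K is divided by 8 to obtain packed int32 K.

```python
FP4_PER_INT32 = 8  # 4 bits per FP4 value, 32 bits per packed word


def packed_k_int32(k_fp4):
    """Convert a logical FP4 K extent into packed int32 units, rejecting
    values that integer division would silently truncate."""
    if k_fp4 % FP4_PER_INT32 != 0:
        raise ValueError(f"k={k_fp4} is not a multiple of {FP4_PER_INT32}")
    return k_fp4 // FP4_PER_INT32


print(packed_k_int32(7168))
# Without the guard, 7172 // 8 gives the same packed extent as 7168 // 8,
# so the last 4 FP4 values along K would be dropped.
print(7172 // FP4_PER_INT32 == 7168 // FP4_PER_INT32)
```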
The previous comment claimed sparse MoE callers pass an actual per-group row
count, but no such path exists; the API passes m / num_groups unconditionally.
Rewrite to describe the real behavior: dense layouts are accurate, sparse
layouts with -1 padding over-estimate and conservatively disable swap_ab,
which is the accepted trade-off vs a device sync to inspect m_indices.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Runchu Zhao <zhaorunchu@gmail.com>
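A small host-side example makes the over-estimate concrete. It assumes m_indices maps each row to its group index with -1 marking padding rows, as implied by the earlier `(m_indices >= 0).sum()` expression; the function names and the example shape are illustrative only.

```python
BLOCK_M = 128


def host_estimate(m, num_groups):
    """Per-group row estimate used by dispatch: no device data needed."""
    return m // num_groups


def useful_rows_per_group(m_indices, num_groups):
    """Exact per-group average; on device this would need a host sync."""
    return sum(1 for g in m_indices if g >= 0) // num_groups


# Hypothetical sparse layout: 4 groups, each padded to 128 rows, 32 useful.
num_groups = 4
m_indices = []
for g in range(num_groups):
    m_indices += [g] * 32 + [-1] * 96

estimate = host_estimate(len(m_indices), num_groups)
actual = useful_rows_per_group(m_indices, num_groups)
print(estimate, actual)
# swap_ab gate (useful rows per group < BLOCK_M) under each value:
print(estimate < BLOCK_M, actual < BLOCK_M)
```

In this layout the estimate equals the padded group size, so the gate stays closed even though the exact count would open it. For a layout with no -1 rows the two values coincide.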
Adds an FP4×FP4 (MXF4) GEMM path on SM100/B200 using `SM100_MMA_MXF4_SS` directly, separate from main's unified MXF8F6F4 kernel. Dispatched from the existing fp8_fp4 APIs (dense / m-grouped contiguous / m-grouped masked) when both operands are FP4. Tested with `tests/test_fp4.py` on B200 (28/28 pass). Bench: median ~1.07× vs cublasTest 13.3 direct, 10/13 shapes ≥ cuBLAS.
Dense GEMM perf (B200, CUDA 12.9, vs cublasTest 13.3 direct binary)
TODO
- `DG_HOST_ASSERT`s for aliasing / K-major / `k % 8` but no test coverage (repo-wide gap, not FP4-only)
- `m / num_groups` as a host estimate accurate for dense, but misses swap_ab on `-1`-padded sparse layouts (padding inflates the estimate above `block_m`). Avoids a device sync on `m_indices`; revisit if sparse MoE perf is a bottleneck.