[PyTorch] Integrate the cuBLAS single GEMM MXFP8 NN, NT support for sm120#3050

Draft
KshitijLakhani wants to merge 8 commits into NVIDIA:main from KshitijLakhani:klakhani/test/mxfp8-cublas-gemm-sm120

Conversation

@KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented May 28, 2026

Description

Enable MXFP8 NN and NT single GEMMs via cuBLAS on sm120.
For a linear layer, the three GEMMs use the layouts fwd=TN, dgrad=NN, wgrad=NT.
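
As a rough illustration of the three GEMMs named above, the sketch below computes the forward, dgrad, and wgrad products of a linear layer with plain Python lists. The layout labels (TN, NN, NT) come from this PR; the helper names `matmul` and `transpose`, the shapes, and the values are made up for illustration and do not reflect how Transformer Engine or cuBLAS store operands.

```python
# Minimal sketch: the three matrix products behind a linear layer.
# All names and values here are illustrative assumptions, not TE code.

def transpose(m):
    """Return the transpose of a list-of-lists matrix."""
    return [list(row) for row in zip(*m)]

def matmul(a, b):
    """Multiply two list-of-lists matrices (a: [m, k], b: [k, n])."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

# x: [tokens=2, in=3], w: [out=4, in=3], dy: [tokens=2, out=4]
x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
w = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [2.0, 0.0, 0.0]]
dy = [[1.0, 1.0, 1.0, 1.0], [0.5, 0.5, 0.5, 0.5]]

y = matmul(x, transpose(w))    # fwd   (labelled TN in the PR): [tokens, out]
dx = matmul(dy, w)             # dgrad (labelled NN in the PR): [tokens, in]
dw = matmul(transpose(dy), x)  # wgrad (labelled NT in the PR): [out, in]
```

Each product transposes a different operand, which is why the forward and backward passes of one layer need all three layouts to be supported.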

Fixes #2668

This PR is complementary to #2833.
It would be best to merge this PR (#3050) together with, ideally right after, PR #2833.
PR #2833 goes hand in hand with a CI PR that enables sm120 in CI, so #3050 should be merged once sm120 is available in CI and not before.

TODO: Code and git history clean up

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Modify _compute_mxfp8_support() to enable usage of cuBLAS MXFP8 NN and NT single GEMMs in TE
  • Add a helper _compute_mxfp8_grouped_gemm_support() so that single-GEMM support (_compute_mxfp8_support()) and grouped-GEMM support (_compute_mxfp8_grouped_gemm_support()) are reported separately; grouped MXFP8 GEMM is not supported in cuBLAS yet
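
The split between the two gates might look roughly like the sketch below. It is only a sketch: the function names mirror the helpers listed above without the leading underscore, but their signatures, the version-tuple encoding, the reason strings, and the compute capability 10.0 baseline are assumptions for illustration. The 13.6.0.2 threshold and the sm_120 / sm_121 scope are taken from the commit messages in this PR.

```python
from typing import Tuple

# Assumption: versions and compute capabilities are modelled as plain tuples.
MIN_CUBLASLT_SM120 = (13, 6, 0, 2)   # threshold quoted in the commit messages
SM120_FAMILY = {(12, 0), (12, 1)}    # sm_120 and sm_121

def compute_mxfp8_support(cc: Tuple[int, int],
                          cublaslt: Tuple[int, ...]) -> Tuple[bool, str]:
    """Hypothetical single-GEMM gate: returns (supported, reason)."""
    if cc < (10, 0):  # assumed baseline for MXFP8
        return False, "MXFP8 requires compute capability 10.0 or higher"
    if cc in SM120_FAMILY and cublaslt < MIN_CUBLASLT_SM120:
        return False, "MXFP8 on sm_120 / sm_121 requires cuBLASLt >= 13.6.0.2"
    return True, ""

def compute_mxfp8_grouped_gemm_support(cc: Tuple[int, int],
                                       cublaslt: Tuple[int, ...]) -> Tuple[bool, str]:
    """Hypothetical grouped-GEMM gate: stricter than the single-GEMM gate."""
    supported, reason = compute_mxfp8_support(cc, cublaslt)
    if not supported:
        return supported, reason
    if cc in SM120_FAMILY:
        return False, "grouped MXFP8 GEMM is not supported by cuBLASLt on sm_120 / sm_121"
    return True, ""
```

Keeping the grouped gate as a thin layer over the single-GEMM gate means any future single-GEMM restriction automatically applies to grouped GEMMs as well.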

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Adds NVTE_ENABLE_MXFP8_SM120 environment variable to unblock MXFP8
testing on sm120 (compute capability 12.0) devices. Default behavior
unchanged; MXFP8 remains gated off on sm120 without explicit opt-in
since not all GEMM layouts are currently supported.
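
An opt-in check of this kind could be sketched as follows. Only the variable name NVTE_ENABLE_MXFP8_SM120 comes from the commit message; the helper name `mxfp8_sm120_opt_in` and the convention that the value "1" enables the feature are assumptions.

```python
import os
from typing import Mapping, Optional

def mxfp8_sm120_opt_in(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True only when the user explicitly opts in.

    Assumption: the variable is treated as enabled when set to "1";
    an unset variable keeps the default (gated off) behavior.
    """
    env = os.environ if env is None else env
    return env.get("NVTE_ENABLE_MXFP8_SM120", "0") == "1"
```

Passing the environment as an argument keeps the helper easy to test without mutating `os.environ`.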

Also adds tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py: a focused
layout x shape x dtype matrix exercising MXFP8 single GEMM via the
underlying general_gemm call directly. The TN layout is exercised
across small/medium/transformer-sized shapes and BF16/FP32 outputs.
NN and NT layouts on sm120 are marked strict-xfail; the suite will
fail-on-XPASS once full layout support is added so the markers can
be removed.
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…13.6+

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Remove tests/pytorch/mxfp8/test_mxfp8_gemm_exact.py. The TN/NN/NT MXFP8
GEMM code paths it was added to localize are already exercised end-to-end
on sm_120 (with cuBLASLt >= 13.6.0.2) by the existing te.Linear /
te.LayerNormLinear / te.GroupedLinear / te.TransformerLayer numerics
tests in tests/pytorch/test_numerics.py via the MXFP8BlockScaling entry
in fp8_recipes (each Linear forward + backward dispatches the three
cuBLAS GEMMs as fwd=TN, dgrad=NN, wgrad=NT).

The runtime _compute_mxfp8_support gate added in the earlier commits
on this branch already module-skips MXFP8 below cuBLASLt 13.6.0.2 on
sm_120, so the per-layout strict-xfail layer in this file is redundant.
Out-of-tree triage material (Testing/repro_mxfp8_layouts.cu and the
Testing/repro_mxfp8_layouts.py driver) remains available if a future
cuBLAS regression needs layout-localized signal again.
@KshitijLakhani KshitijLakhani changed the title from "[PyTorch] Integrate the cuBLAS MXFP8 NN, NT support for sm120" to "[PyTorch] Integrate the cuBLAS single GEMM MXFP8 NN, NT support for sm120" May 28, 2026
cuBLASLt 13.6.0.2 supports single-GEMM MXFP8 on sm_120 / sm_121 but not
the grouped variant. Route general_grouped_gemm and
general_grouped_gemm_for_grouped_tensor through
check_mxfp8_grouped_gemm_support() and raise NotImplementedError when
unsupported, instead of failing opaquely inside cuBLAS.
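
The fail-fast routing described in this commit could be sketched as below. The name check_mxfp8_grouped_gemm_support appears in the commit message, but its signature and return shape here are assumptions, and `general_grouped_gemm_guarded` and `run_gemm` are hypothetical stand-ins for the real entry points.

```python
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")

def check_mxfp8_grouped_gemm_support(cc: Tuple[int, int]) -> Tuple[bool, str]:
    """Stand-in gate (assumed signature): returns (supported, reason)."""
    if cc in {(12, 0), (12, 1)}:
        return False, "grouped MXFP8 GEMM is not supported on sm_120 / sm_121"
    return True, ""

def general_grouped_gemm_guarded(cc: Tuple[int, int],
                                 uses_mxfp8: bool,
                                 run_gemm: Callable[[], T]) -> T:
    """Hypothetical wrapper: check the gate before dispatching the GEMM."""
    if uses_mxfp8:
        supported, reason = check_mxfp8_grouped_gemm_support(cc)
        if not supported:
            # Surface a descriptive error instead of an opaque library failure.
            raise NotImplementedError(reason)
    return run_gemm()
```

Non-MXFP8 inputs bypass the gate entirely, so existing grouped GEMM paths are unaffected in this sketch.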

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Introduce _compute_mxfp8_grouped_gemm_support / check_mxfp8_grouped_gemm_support
and a public is_mxfp8_grouped_gemm_available helper so callers (te.GroupedLinear,
general_grouped_gemm[_for_grouped_tensor], and grouped-GEMM tests) can gate on
grouped MXFP8 separately from single-GEMM MXFP8. On sm_120 / sm_121, cuBLASLt
13.6.0.2 enables single MXFP8 GEMM (TN/NN/NT) but not the grouped variant; the
new gate returns False there with a descriptive reason. Also widen the
single-GEMM gate to sm_121 alongside sm_120.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- Probe is_mxfp8_grouped_gemm_available in test_fusible_ops, test_numerics,
and test_sanity, and pytest.skip MXFP8 grouped_linear / padding_grouped_linear /
grouped_gemm cases (plus a maybe_skip_quantization_for_grouped_gemm helper in
test_fusible_ops) with the gate's reason.
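
A skip helper along these lines might look like the following sketch. The helper name maybe_skip_quantization_for_grouped_gemm is taken from the commit message; its parameters and the "mxfp8" quantization label are assumptions, and the availability flag and reason are passed in explicitly here rather than probed from the library.

```python
import pytest

def maybe_skip_quantization_for_grouped_gemm(quantization: str,
                                             grouped_available: bool,
                                             reason: str) -> None:
    """Skip the current test when it needs grouped MXFP8 GEMM but the gate is closed.

    Assumption: `quantization` is a string label and "mxfp8" selects MXFP8.
    """
    if quantization == "mxfp8" and not grouped_available:
        pytest.skip(reason)
```

Inside a test, the helper would be called first, for example `maybe_skip_quantization_for_grouped_gemm("mxfp8", available, reason)`, so that unsupported configurations are reported as skipped with the gate's reason rather than as failures.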

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/test/mxfp8-cublas-gemm-sm120 branch from ffa0eab to 0370659 Compare May 29, 2026 00:41
Development

Successfully merging this pull request may close these issues.

[Inquiry] Support status and roadmap for MXFP8 on SM120 (Blackwell)