Skip to content
17 changes: 17 additions & 0 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@
# Check for supported quantization schemes
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = te.is_mxfp8_grouped_gemm_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

# Supported data types
Expand Down Expand Up @@ -89,6 +92,17 @@ def _reset_rng_states_per_test():
yield


def maybe_skip_quantization_for_grouped_gemm(quantization: Optional[str]) -> None:
"""Skip MXFP8 grouped-GEMM cases on devices where they're not yet supported.

cuBLASLt 13.6.0.2 supports single-GEMM MXFP8 on sm_120 but not grouped
MXFP8 GEMM; the grouped-MXFP8 dispatch in ``general_grouped_gemm`` /
``general_grouped_gemm_for_grouped_tensor`` will refuse those inputs.
"""
if quantization == "mxfp8" and not mxfp8_grouped_gemm_available:
pytest.skip(reason_for_no_mxfp8_grouped_gemm)


def maybe_skip_quantization(
quantization: Optional[str],
*,
Expand Down Expand Up @@ -2111,6 +2125,7 @@ def test_grouped_linear(
# Skip invalid configurations
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization(quantization, dims=out_shape)
maybe_skip_quantization_for_grouped_gemm(quantization)
if quantization is None and (quantized_compute or quantized_weight):
pytest.skip("Quantization scheme is not specified")
if quantization is not None and not (quantized_compute or quantized_weight):
Expand Down Expand Up @@ -2316,6 +2331,7 @@ def test_grouped_linear_cuda_graph_safe(
pytest.skip("quantized_weight requires a quantization recipe")
if single_grouped_bias and not bias:
pytest.skip("single_grouped_bias requires bias=True")
maybe_skip_quantization_for_grouped_gemm(quantization)

# Split sizes (statically pinned for graph capture)
split_sizes = [split_alignment * (i + 1) for i in range(group_size)]
Expand Down Expand Up @@ -3741,6 +3757,7 @@ def test_grouped_mlp(
raise ValueError(f"Unexpected grouped MLP activation ({activation})")
activation_is_glu = is_glu_activation(scaled_act)
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype)
maybe_skip_quantization_for_grouped_gemm(quantization)
if single_grouped_weight and quantization != "mxfp8":
pytest.skip("single_grouped_weight is only supported for MXFP8 quantization")
if single_grouped_bias and not bias:
Expand Down
16 changes: 16 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
get_device_compute_capability,
is_fp8_available,
is_mxfp8_available,
is_mxfp8_grouped_gemm_available,
is_fp8_block_scaling_available,
is_bf16_available,
is_nvfp4_available,
Expand All @@ -60,6 +61,9 @@
# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = is_mxfp8_grouped_gemm_available(
return_reason=True
)
fp8_block_scaling_available = is_fp8_block_scaling_available()
nvfp4_available = is_nvfp4_available()

Expand Down Expand Up @@ -1954,6 +1958,8 @@ def test_grouped_linear_accuracy(
pytest.skip("FP8 parameters are not supported in debug mode.")
if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
pytest.skip("Delayed wgrad compute is not supported in debug mode.")
if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available:
pytest.skip(reason_for_no_mxfp8_grouped_gemm)
skip_unsupported_backward_override(
"grouped_linear", recipe, getattr(recipe, "backward_override", None)
)
Expand Down Expand Up @@ -2101,6 +2107,8 @@ def test_grouped_linear_accuracy_save_original_input(
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute:
pytest.skip("Delayed wgrad compute is not supported in debug mode.")
if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available:
pytest.skip(reason_for_no_mxfp8_grouped_gemm)
skip_unsupported_backward_override(
"grouped_linear", recipe, getattr(recipe, "backward_override", None)
)
Expand Down Expand Up @@ -2310,6 +2318,8 @@ def test_padding_grouped_linear_accuracy(
):
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available:
pytest.skip(reason_for_no_mxfp8_grouped_gemm)
skip_unsupported_backward_override(
"grouped_linear", recipe, getattr(recipe, "backward_override", None)
)
Expand Down Expand Up @@ -2390,6 +2400,8 @@ def test_padding_grouped_linear_accuracy_save_original_input(
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.delayed():
pytest.skip("DelayedScaling recipe is not supported with save_original_input")
if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available:
pytest.skip(reason_for_no_mxfp8_grouped_gemm)
skip_unsupported_backward_override(
"grouped_linear", recipe, getattr(recipe, "backward_override", None)
)
Expand Down Expand Up @@ -3095,6 +3107,8 @@ def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) -
pytest.skip("bfloat16 is required for grouped GEMM test.")
if quant_type == "mxfp8" and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
if quant_type == "mxfp8" and not mxfp8_grouped_gemm_available:
pytest.skip(reason_for_no_mxfp8_grouped_gemm)

z = 4
k, n = 256, 256
Expand Down Expand Up @@ -3263,6 +3277,8 @@ def test_grouped_gemm_grouped_tensor_mxfp8(
pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
if dtype == torch.bfloat16 and not is_bf16_available():
pytest.skip("bfloat16 is required for grouped GEMM test.")
if not mxfp8_grouped_gemm_available:
pytest.skip(reason_for_no_mxfp8_grouped_gemm)

torch.manual_seed(0)
z, m, k, n = shape
Expand Down
5 changes: 5 additions & 0 deletions tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = te.is_mxfp8_grouped_gemm_available(
return_reason=True
)
nvfp4_available, _ = te.is_nvfp4_available(return_reason=True)

# Record initial RNG state from script run.
Expand Down Expand Up @@ -607,6 +610,8 @@ def test_sanity_grouped_linear(
if fp8_recipe is not None:
if not is_fp8_supported(config):
pytest.skip("Model config does not support FP8")
if fp8_recipe.mxfp8() and not mxfp8_grouped_gemm_available:
pytest.skip(reason_for_no_mxfp8_grouped_gemm)
if fp8_recipe.nvfp4():
if not getattr(fp8_recipe, "row_scaled_activation", False):
pytest.skip("NVFP4 not supported for grouped linear")
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from transformer_engine.pytorch.quantization import quantized_model_init
from transformer_engine.pytorch.quantization import is_fp8_available
from transformer_engine.pytorch.quantization import is_mxfp8_available
from transformer_engine.pytorch.quantization import is_mxfp8_grouped_gemm_available
from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available
from transformer_engine.pytorch.quantization import is_nvfp4_available
from transformer_engine.pytorch.quantization import get_default_recipe
Expand Down
44 changes: 44 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor

from ..quantization import check_mxfp8_grouped_gemm_support
from ..quantized_tensor import Quantizer
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.storage.grouped_tensor_storage import GroupedTensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from ..tensor.utils import is_custom
from ..custom_recipes.gemm import custom_gemm
Expand Down Expand Up @@ -76,6 +78,44 @@ def _is_nvfp4_row_scaled_tensor(tensor: torch.Tensor) -> bool:
return isinstance(tensor, NVFP4TensorStorage) and tensor._row_scaled_nvfp4


def _is_mxfp8_storage(tensor) -> bool:
"""Whether ``tensor`` is MXFP8-quantized storage (per-tensor or grouped)."""
if tensor is None:
return False
if isinstance(tensor, MXFP8TensorStorage):
return True
if isinstance(tensor, GroupedTensorStorage):
quantizer = getattr(tensor, "quantizer", None)
if quantizer is not None:
try:
recipe = quantizer._get_compatible_recipe()
except (AttributeError, NotImplementedError):
return False
return bool(recipe.mxfp8())
return False


def _check_mxfp8_grouped_gemm_inputs(*tensor_iterables) -> None:
"""Raise ``NotImplementedError`` if MXFP8 grouped GEMM is requested but
not supported on the current device / cuBLASLt combination.

Accepts one or more iterables of tensors (e.g. ``A``, ``B`` from
:func:`general_grouped_gemm`) so callers don't have to flatten lists.
"""
for tensors in tensor_iterables:
if tensors is None:
continue
if isinstance(tensors, (list, tuple)):
has_mxfp8 = any(_is_mxfp8_storage(t) for t in tensors)
else:
has_mxfp8 = _is_mxfp8_storage(tensors)
if has_mxfp8:
supported, reason = check_mxfp8_grouped_gemm_support()
if not supported:
raise NotImplementedError(reason)
return


def _nvfp4_row_scaled_gemm_inputs(
A: NVFP4TensorStorage,
B: NVFP4TensorStorage,
Expand Down Expand Up @@ -294,6 +334,8 @@ def general_grouped_gemm(
"""
TN layout Grouped GEMM with fp8 inputs.
"""
_check_mxfp8_grouped_gemm_inputs(A, B)

num_gemms = len(A)

transa = layout[0] == "T"
Expand Down Expand Up @@ -470,6 +512,8 @@ def general_grouped_gemm_for_grouped_tensor(
The caller must ensure that GroupedTensor metadata is already compatible with the
underlying GEMM implementation (e.g., aligned offsets and output metadata layout).
"""
_check_mxfp8_grouped_gemm_inputs(A, B)

assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported."
if grad:
raise NotImplementedError("grad is not supported for grouped_tensor GEMM yet.")
Expand Down
102 changes: 97 additions & 5 deletions transformer_engine/pytorch/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"quantized_model_init",
"is_fp8_available",
"is_mxfp8_available",
"is_mxfp8_grouped_gemm_available",
"is_fp8_block_scaling_available",
"is_nvfp4_available",
"get_default_recipe",
Expand All @@ -49,6 +50,7 @@

_FP8_SUPPORT: Optional[Tuple[bool, str]] = None
_MXFP8_SUPPORT: Optional[Tuple[bool, str]] = None
_MXFP8_GROUPED_GEMM_SUPPORT: Optional[Tuple[bool, str]] = None
_NVFP4_SUPPORT: Optional[Tuple[bool, str]] = None
_FP8_BLOCK_SCALING_SUPPORT: Optional[Tuple[bool, str]] = None

Expand Down Expand Up @@ -160,14 +162,54 @@ def _compute_fp8_support() -> Tuple[bool, str]:


def _compute_mxfp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if get_device_compute_capability() >= (12, 0):
return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet."
"""Return if MXFP8 single-GEMM support is available.

On sm_120 / sm_121 this covers the single-GEMM TN/NN/NT paths via cuBLASLt >= 13.6.0.2;
grouped MXFP8 GEMM is gated separately by :func:`_compute_mxfp8_grouped_gemm_support`.
"""
if get_device_compute_capability() in ((12, 0), (12, 1)):
cublaslt_version = tex.get_cublasLt_version()
if cublaslt_version >= 130600:
return True, ""
return (
False,
(
"MXFP8 on sm_120 / sm_121 requires cuBLASLt >= 13.6.0.2 for NN/NT GEMM "
f"support (loaded cuBLASLt={cublaslt_version})."
),
)
if get_device_compute_capability() >= (10, 0): # blackwell and above
return True, ""
return False, "Device compute capability 10.0 or higher required for MXFP8 execution."


def _compute_mxfp8_grouped_gemm_support() -> Tuple[bool, str]:
"""Return if MXFP8 *grouped*-GEMM support is available.

This is strictly a subset of single-GEMM MXFP8 support: it inherits the
requirements of :func:`_compute_mxfp8_support` and then additionally
requires that the loaded cuBLASLt implements grouped MXFP8 GEMM on the
current device. On sm_120 / sm_121 the cuBLASLt 13.6.0.2 release supports
single MXFP8 GEMM (TN/NN/NT) but does NOT yet implement grouped MXFP8 GEMM;
callers should treat that case as unsupported until a future cuBLAS adds it.
"""
base_ok, base_reason = _compute_mxfp8_support()
if not base_ok:
return False, base_reason
if get_device_compute_capability() in ((12, 0), (12, 1)):
cublaslt_version = tex.get_cublasLt_version()
return (
False,
(
"MXFP8 grouped GEMM is not yet supported on sm_120 / sm_121 by the loaded "
f"cuBLASLt={cublaslt_version} (single-GEMM MXFP8 is supported with "
"cuBLASLt >= 13.6.0.2). Use a non-grouped module or switch to a "
"non-MXFP8 recipe (e.g. Float8CurrentScaling) for grouped-GEMM workloads."
),
)
return True, ""


def _compute_nvfp4_support() -> Tuple[bool, str]:
"""Return if nvfp4 support is available"""
if get_device_compute_capability() >= (10, 0): # blackwell and above
Expand Down Expand Up @@ -196,13 +238,22 @@ def check_fp8_support() -> Tuple[bool, str]:

@torch.compiler.assume_constant_result
def check_mxfp8_support() -> Tuple[bool, str]:
"""Return if MXFP8 support is available."""
"""Return if MXFP8 single-GEMM support is available."""
global _MXFP8_SUPPORT
if _MXFP8_SUPPORT is None:
_MXFP8_SUPPORT = _compute_mxfp8_support()
return _MXFP8_SUPPORT


@torch.compiler.assume_constant_result
def check_mxfp8_grouped_gemm_support() -> Tuple[bool, str]:
"""Return if MXFP8 grouped-GEMM support is available."""
global _MXFP8_GROUPED_GEMM_SUPPORT
if _MXFP8_GROUPED_GEMM_SUPPORT is None:
_MXFP8_GROUPED_GEMM_SUPPORT = _compute_mxfp8_grouped_gemm_support()
return _MXFP8_GROUPED_GEMM_SUPPORT


@torch.compiler.assume_constant_result
def check_nvfp4_support() -> Tuple[bool, str]:
"""Return if NVFP4 support is available."""
Expand Down Expand Up @@ -323,7 +374,15 @@ def is_fp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str

def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
"""
Determine if support is available for the MXFP8 recipe.
Determine if support is available for the MXFP8 recipe (single GEMM).

This reports support for the single-GEMM MXFP8 dispatch (the common TN/NN/NT
fwd/dgrad/wgrad path used by ``te.Linear`` / ``te.LayerNormLinear`` /
``te.LayerNormMLP`` / ``te.TransformerLayer``). Grouped MXFP8 GEMM
(e.g. ``te.GroupedLinear``, ``general_grouped_gemm``,
``general_grouped_gemm_for_grouped_tensor``) is gated separately by
:func:`is_mxfp8_grouped_gemm_available` because it may be unsupported on
some device + cuBLASLt combinations even when single-GEMM MXFP8 is supported.

Parameters
----------
Expand All @@ -339,6 +398,34 @@ def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, s
return check_mxfp8_support()[0]


def is_mxfp8_grouped_gemm_available(
return_reason: bool = False,
) -> Union[bool, Tuple[bool, str]]:
"""
Determine if support is available for MXFP8 grouped GEMM.

MXFP8 grouped GEMM is a strict superset of single-GEMM MXFP8 in terms of
requirements: the underlying cuBLASLt must implement the grouped MXFP8
GEMM heuristic for the current device. Use this check to gate
``te.GroupedLinear`` / ``general_grouped_gemm`` /
``general_grouped_gemm_for_grouped_tensor`` dispatch and to skip
MXFP8 grouped-GEMM tests on devices where the underlying cuBLASLt
does not (yet) implement that path.

Parameters
----------
return_reason : bool, optional
If ``False`` (default), return only a boolean indicating availability.
If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides
a human-readable explanation when required support is not available. The reason
will be an empty string if support is available.

"""
if return_reason:
return check_mxfp8_grouped_gemm_support()
return check_mxfp8_grouped_gemm_support()[0]


def is_fp8_block_scaling_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
"""
Determine if support is available for the FP8 block scaling recipe.
Expand Down Expand Up @@ -424,6 +511,11 @@ def is_mxfp8_available(cls) -> Tuple[bool, str]:
"""Return if MXFP8/current scaling support is available."""
return check_mxfp8_support()

@classmethod
def is_mxfp8_grouped_gemm_available(cls) -> Tuple[bool, str]:
"""Return if MXFP8 grouped-GEMM support is available."""
return check_mxfp8_grouped_gemm_support()

@classmethod
def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]:
"""Return if Float8 block scaling support is available."""
Expand Down
Loading