diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1ced32e1a5..9541e94fd0 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -61,6 +61,9 @@ # Check for supported quantization schemes fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = te.is_mxfp8_grouped_gemm_available( + return_reason=True +) nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) # Supported data types @@ -89,6 +92,17 @@ def _reset_rng_states_per_test(): yield +def maybe_skip_quantization_for_grouped_gemm(quantization: Optional[str]) -> None: + """Skip MXFP8 grouped-GEMM cases on devices where they're not yet supported. + + cuBLASLt 13.6.0.2 supports single-GEMM MXFP8 on sm_120 but not grouped + MXFP8 GEMM; the grouped-MXFP8 dispatch in ``general_grouped_gemm`` / + ``general_grouped_gemm_for_grouped_tensor`` will refuse those inputs. 
+ """ + if quantization == "mxfp8" and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) + + def maybe_skip_quantization( quantization: Optional[str], *, @@ -2111,6 +2125,7 @@ def test_grouped_linear( # Skip invalid configurations maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) + maybe_skip_quantization_for_grouped_gemm(quantization) if quantization is None and (quantized_compute or quantized_weight): pytest.skip("Quantization scheme is not specified") if quantization is not None and not (quantized_compute or quantized_weight): @@ -2316,6 +2331,7 @@ def test_grouped_linear_cuda_graph_safe( pytest.skip("quantized_weight requires a quantization recipe") if single_grouped_bias and not bias: pytest.skip("single_grouped_bias requires bias=True") + maybe_skip_quantization_for_grouped_gemm(quantization) # Split sizes (statically pinned for graph capture) split_sizes = [split_alignment * (i + 1) for i in range(group_size)] @@ -3741,6 +3757,7 @@ def test_grouped_mlp( raise ValueError(f"Unexpected grouped MLP activation ({activation})") activation_is_glu = is_glu_activation(scaled_act) maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) + maybe_skip_quantization_for_grouped_gemm(quantization) if single_grouped_weight and quantization != "mxfp8": pytest.skip("single_grouped_weight is only supported for MXFP8 quantization") if single_grouped_bias and not bias: diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 5f82bfcba2..ab2cbaf87d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -41,6 +41,7 @@ get_device_compute_capability, is_fp8_available, is_mxfp8_available, + is_mxfp8_grouped_gemm_available, is_fp8_block_scaling_available, is_bf16_available, is_nvfp4_available, @@ -60,6 +61,9 @@ # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) +mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = is_mxfp8_grouped_gemm_available( + return_reason=True +) fp8_block_scaling_available = is_fp8_block_scaling_available() nvfp4_available = is_nvfp4_available() @@ -1954,6 +1958,8 @@ def test_grouped_linear_accuracy( pytest.skip("FP8 parameters are not supported in debug mode.") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) skip_unsupported_backward_override( "grouped_linear", recipe, getattr(recipe, "backward_override", None) ) @@ -2101,6 +2107,8 @@ def test_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) skip_unsupported_backward_override( "grouped_linear", recipe, getattr(recipe, "backward_override", None) ) @@ -2310,6 +2318,8 @@ def test_padding_grouped_linear_accuracy( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) skip_unsupported_backward_override( "grouped_linear", recipe, getattr(recipe, "backward_override", None) ) @@ -2390,6 +2400,8 @@ def test_padding_grouped_linear_accuracy_save_original_input( pytest.skip("FP8 parameters are not supported in debug mode.") if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") 
+ if fp8 and recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) skip_unsupported_backward_override( "grouped_linear", recipe, getattr(recipe, "backward_override", None) ) @@ -3095,6 +3107,8 @@ def test_grouped_gemm_grouped_tensor_zero_work(layout, accumulate, quant_type) - pytest.skip("bfloat16 is required for grouped GEMM test.") if quant_type == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quant_type == "mxfp8" and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) z = 4 k, n = 256, 256 @@ -3263,6 +3277,8 @@ def test_grouped_gemm_grouped_tensor_mxfp8( pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") if dtype == torch.bfloat16 and not is_bf16_available(): pytest.skip("bfloat16 is required for grouped GEMM test.") + if not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) torch.manual_seed(0) z, m, k, n = shape diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 27eafbecdc..2a4c672ce8 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -44,6 +44,9 @@ fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +mxfp8_grouped_gemm_available, reason_for_no_mxfp8_grouped_gemm = te.is_mxfp8_grouped_gemm_available( + return_reason=True +) nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) # Record initial RNG state from script run. 
@@ -607,6 +610,8 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.mxfp8() and not mxfp8_grouped_gemm_available: + pytest.skip(reason_for_no_mxfp8_grouped_gemm) if fp8_recipe.nvfp4(): if not getattr(fp8_recipe, "row_scaled_activation", False): pytest.skip("NVFP4 not supported for grouped linear") diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 7653d5992e..533d64fd4d 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -45,6 +45,7 @@ from transformer_engine.pytorch.quantization import quantized_model_init from transformer_engine.pytorch.quantization import is_fp8_available from transformer_engine.pytorch.quantization import is_mxfp8_available +from transformer_engine.pytorch.quantization import is_mxfp8_grouped_gemm_available from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available from transformer_engine.pytorch.quantization import is_nvfp4_available from transformer_engine.pytorch.quantization import get_default_recipe diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index edf2c1e1c2..13c25bb986 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -13,9 +13,11 @@ from ..constants import TE_DType from ..utils import get_sm_count, _empty_tensor +from ..quantization import check_mxfp8_grouped_gemm_support from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..tensor.storage.grouped_tensor_storage import GroupedTensorStorage +from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm 
import custom_gemm
@@ -76,6 +78,54 @@ def _is_nvfp4_row_scaled_tensor(tensor: torch.Tensor) -> bool:
     return isinstance(tensor, NVFP4TensorStorage) and tensor._row_scaled_nvfp4
 
 
+def _is_mxfp8_storage(tensor) -> bool:
+    """Whether ``tensor`` is MXFP8-quantized storage (per-tensor or grouped).
+
+    An ``MXFP8TensorStorage`` is recognized directly. A ``GroupedTensorStorage``
+    is classified through the recipe reported by its ``quantizer`` attribute.
+    Anything else (including ``None``), and any grouped tensor whose recipe
+    cannot be determined, is reported as not MXFP8.
+    """
+    if isinstance(tensor, MXFP8TensorStorage):
+        return True
+    if not isinstance(tensor, GroupedTensorStorage):
+        return False
+    quantizer = getattr(tensor, "quantizer", None)
+    if quantizer is None:
+        return False
+    try:
+        recipe = quantizer._get_compatible_recipe()
+    except (AttributeError, NotImplementedError):
+        return False
+    # NOTE(review): presumably a quantizer can report ``None`` when it has no
+    # compatible recipe -- confirm. Guarding here keeps this probe from raising
+    # an unrelated ``AttributeError`` out of the grouped GEMM entry points.
+    return recipe is not None and bool(recipe.mxfp8())
+
+
+def _check_mxfp8_grouped_gemm_inputs(*tensor_iterables) -> None:
+    """Raise ``NotImplementedError`` if MXFP8 grouped GEMM is requested but
+    not supported on the current device / cuBLASLt combination.
+
+    Accepts one or more iterables of tensors (e.g. ``A``, ``B`` from
+    :func:`general_grouped_gemm`) so callers don't have to flatten lists.
+    ``None`` arguments are skipped, and an argument that is not a list/tuple
+    is checked as a single tensor. The exception message is the reason string
+    returned by :func:`check_mxfp8_grouped_gemm_support`.
+    """
+    supported, reason = check_mxfp8_grouped_gemm_support()
+    if supported:
+        # Nothing can be rejected, so skip scanning the inputs on every call.
+        return
+    for tensors in tensor_iterables:
+        if tensors is None:
+            continue
+        if not isinstance(tensors, (list, tuple)):
+            tensors = (tensors,)
+        if any(_is_mxfp8_storage(t) for t in tensors):
+            raise NotImplementedError(reason)
+
+
 def _nvfp4_row_scaled_gemm_inputs(
     A: NVFP4TensorStorage,
     B: NVFP4TensorStorage,
@@ -294,6 +344,8 @@ def general_grouped_gemm(
     """
     TN layout Grouped GEMM with fp8 inputs.
     """
+    _check_mxfp8_grouped_gemm_inputs(A, B)
+
     num_gemms = len(A)
 
     transa = layout[0] == "T"
@@ -470,6 +522,8 @@ def general_grouped_gemm_for_grouped_tensor(
     The caller must ensure that GroupedTensor metadata is already compatible with the
     underlying GEMM implementation (e.g., aligned offsets and output metadata layout).
""" + _check_mxfp8_grouped_gemm_inputs(A, B) + assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." if grad: raise NotImplementedError("grad is not supported for grouped_tensor GEMM yet.") diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index e503b4b560..6afba9b0f3 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -37,6 +37,7 @@ "quantized_model_init", "is_fp8_available", "is_mxfp8_available", + "is_mxfp8_grouped_gemm_available", "is_fp8_block_scaling_available", "is_nvfp4_available", "get_default_recipe", @@ -49,6 +50,7 @@ _FP8_SUPPORT: Optional[Tuple[bool, str]] = None _MXFP8_SUPPORT: Optional[Tuple[bool, str]] = None +_MXFP8_GROUPED_GEMM_SUPPORT: Optional[Tuple[bool, str]] = None _NVFP4_SUPPORT: Optional[Tuple[bool, str]] = None _FP8_BLOCK_SCALING_SUPPORT: Optional[Tuple[bool, str]] = None @@ -160,14 +162,54 @@ def _compute_fp8_support() -> Tuple[bool, str]: def _compute_mxfp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" - if get_device_compute_capability() >= (12, 0): - return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." + """Return if MXFP8 single-GEMM support is available. + + On sm_120 / sm_121 this covers the single-GEMM TN/NN/NT paths via cuBLASLt >= 13.6.0.2; + grouped MXFP8 GEMM is gated separately by :func:`_compute_mxfp8_grouped_gemm_support`. + """ + if get_device_compute_capability() in ((12, 0), (12, 1)): + cublaslt_version = tex.get_cublasLt_version() + if cublaslt_version >= 130600: + return True, "" + return ( + False, + ( + "MXFP8 on sm_120 / sm_121 requires cuBLASLt >= 13.6.0.2 for NN/NT GEMM " + f"support (loaded cuBLASLt={cublaslt_version})." + ), + ) if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" return False, "Device compute capability 10.0 or higher required for MXFP8 execution." 
+def _compute_mxfp8_grouped_gemm_support() -> Tuple[bool, str]: + """Return if MXFP8 *grouped*-GEMM support is available. + + This is strictly a subset of single-GEMM MXFP8 support: it inherits the + requirements of :func:`_compute_mxfp8_support` and then additionally + requires that the loaded cuBLASLt implements grouped MXFP8 GEMM on the + current device. On sm_120 / sm_121 the cuBLASLt 13.6.0.2 release supports + single MXFP8 GEMM (TN/NN/NT) but does NOT yet implement grouped MXFP8 GEMM; + callers should treat that case as unsupported until a future cuBLAS adds it. + """ + base_ok, base_reason = _compute_mxfp8_support() + if not base_ok: + return False, base_reason + if get_device_compute_capability() in ((12, 0), (12, 1)): + cublaslt_version = tex.get_cublasLt_version() + return ( + False, + ( + "MXFP8 grouped GEMM is not yet supported on sm_120 / sm_121 by the loaded " + f"cuBLASLt={cublaslt_version} (single-GEMM MXFP8 is supported with " + "cuBLASLt >= 13.6.0.2). Use a non-grouped module or switch to a " + "non-MXFP8 recipe (e.g. Float8CurrentScaling) for grouped-GEMM workloads." 
+ ), + ) + return True, "" + + def _compute_nvfp4_support() -> Tuple[bool, str]: """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -196,13 +238,22 @@ def check_fp8_support() -> Tuple[bool, str]: @torch.compiler.assume_constant_result def check_mxfp8_support() -> Tuple[bool, str]: - """Return if MXFP8 support is available.""" + """Return if MXFP8 single-GEMM support is available.""" global _MXFP8_SUPPORT if _MXFP8_SUPPORT is None: _MXFP8_SUPPORT = _compute_mxfp8_support() return _MXFP8_SUPPORT +@torch.compiler.assume_constant_result +def check_mxfp8_grouped_gemm_support() -> Tuple[bool, str]: + """Return if MXFP8 grouped-GEMM support is available.""" + global _MXFP8_GROUPED_GEMM_SUPPORT + if _MXFP8_GROUPED_GEMM_SUPPORT is None: + _MXFP8_GROUPED_GEMM_SUPPORT = _compute_mxfp8_grouped_gemm_support() + return _MXFP8_GROUPED_GEMM_SUPPORT + + @torch.compiler.assume_constant_result def check_nvfp4_support() -> Tuple[bool, str]: """Return if NVFP4 support is available.""" @@ -323,7 +374,15 @@ def is_fp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ - Determine if support is available for the MXFP8 recipe. + Determine if support is available for the MXFP8 recipe (single GEMM). + + This reports support for the single-GEMM MXFP8 dispatch (the common TN/NN/NT + fwd/dgrad/wgrad path used by ``te.Linear`` / ``te.LayerNormLinear`` / + ``te.LayerNormMLP`` / ``te.TransformerLayer``). Grouped MXFP8 GEMM + (e.g. ``te.GroupedLinear``, ``general_grouped_gemm``, + ``general_grouped_gemm_for_grouped_tensor``) is gated separately by + :func:`is_mxfp8_grouped_gemm_available` because it may be unsupported on + some device + cuBLASLt combinations even when single-GEMM MXFP8 is supported. 
Parameters ---------- @@ -339,6 +398,34 @@ def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, s return check_mxfp8_support()[0] +def is_mxfp8_grouped_gemm_available( + return_reason: bool = False, +) -> Union[bool, Tuple[bool, str]]: + """ + Determine if support is available for MXFP8 grouped GEMM. + + MXFP8 grouped GEMM is a strict superset of single-GEMM MXFP8 in terms of + requirements: the underlying cuBLASLt must implement the grouped MXFP8 + GEMM heuristic for the current device. Use this check to gate + ``te.GroupedLinear`` / ``general_grouped_gemm`` / + ``general_grouped_gemm_for_grouped_tensor`` dispatch and to skip + MXFP8 grouped-GEMM tests on devices where the underlying cuBLASLt + does not (yet) implement that path. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support is available. + + """ + if return_reason: + return check_mxfp8_grouped_gemm_support() + return check_mxfp8_grouped_gemm_support()[0] + + def is_fp8_block_scaling_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ Determine if support is available for the FP8 block scaling recipe. @@ -424,6 +511,11 @@ def is_mxfp8_available(cls) -> Tuple[bool, str]: """Return if MXFP8/current scaling support is available.""" return check_mxfp8_support() + @classmethod + def is_mxfp8_grouped_gemm_available(cls) -> Tuple[bool, str]: + """Return if MXFP8 grouped-GEMM support is available.""" + return check_mxfp8_grouped_gemm_support() + @classmethod def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: """Return if Float8 block scaling support is available."""