diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 12b4703469..6ed66b4102 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -252,6 +252,9 @@ inline std::string grouped_gemm_skip_reason(const TestParams& params) { return "Grouped GEMM on Hopper (SM90) requires cuBLAS 13.4+, but run-time cuBLAS " "version is " + std::to_string(cublas_ver) + "."; } + if (cc == 120 || cc == 121) { + return "Grouped GEMM is currently unsupported on SM12x architectures."; + } if (params.recipe != InputRecipe::kBF16) { const bool is_blackwell_plus = cc >= blackwellComputeCapability; const bool fp8_block = is_fp8_block_recipe(params.recipe); diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index 285ec7ba0c..78f45286ea 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -48,6 +48,19 @@ fp8_available = is_fp8_available() +def _cmp_dist(ground_truth, output, parallel_mode): + if parallel_mode == "column" and torch.cuda.get_device_capability() == (12, 0): + # SM120: distributed column-parallel path may show a single-element + # activation outlier slightly above default fp32 atol, while grads match. + torch.testing.assert_close( + ground_truth["activation"], output["activation"], atol=1.2e-5, rtol=1.3e-6 + ) + torch.testing.assert_close(ground_truth["wgrad"], output["wgrad"]) + torch.testing.assert_close(ground_truth["dgrad"], output["dgrad"]) + else: + _cmp(ground_truth, output) + + def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None): if tp_size is None: tp_size = WORLD_SIZE @@ -445,7 +458,7 @@ def test_disable_fp8_gemms(fprop_fp8, dgrad_fp8, wgrad_fp8, parallel_mode, **kwa x.grad.zero_() ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs) - _cmp(ground_truth, output) + _cmp_dist(ground_truth, output, parallel_mode) @run_debug_test @@ -466,7 +479,7 @@ def test_disable_fp8_layer(parallel_mode, **kwargs): y = _run_forward_backward(x, model, parallel_mode) output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()} - _cmp(ground_truth, output) + _cmp_dist(ground_truth, output, parallel_mode) @run_debug_test @@ -554,7 +567,7 @@ def test_per_tensor_scaling( x, weight, parallel_mode=parallel_mode, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs ) - _cmp(ground_truth, output) + _cmp_dist(ground_truth, output, parallel_mode) @run_debug_test @@ -617,7 +630,7 @@ def test_fake_quant_fp8( _get_current_scale(x, wgrad_input) if not fp8_kwargs["wgrad_fp8"] else None ) ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs) - _cmp(ground_truth, output) + _cmp_dist(ground_truth, output, parallel_mode) def _init_distributed(): diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 53c7a5e7cc..b1883a3bc9 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -30,6 +30,11 @@ warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +FP8_DEFAULT_RTOL_ATOL = (0.125, 0.0625) +FP8_CS_SM120_DETERMINISTIC_RTOL_ATOL = (0.4, 0.25) +BF16_DEFAULT_RTOL_ATOL = (0.025, 0.00125) +BF16_SM120_DETERMINISTIC_OVERLAP_RTOL_ATOL = (0.05, 0.01) + class multi_module_model(torch.nn.Module): def __init__(self, module, num_layers, *args, **kwargs): @@ -551,9 +556,33 @@ def run_fwd_bwd(model, x): # Now validate accuracy if not bool(numerics_failed.item()): + is_sm120 = torch.cuda.get_device_capability() == (12, 0) + is_deterministic_mode = os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1") == "0" for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 - atol = 0.0625 if opts.fp8 else 0.00125 + if opts.fp8: + if ( + opts.quantization == "fp8_current_scaling" + and is_sm120 + and is_deterministic_mode + ): + # SM120 deterministic mode disables fused attn, so rt uses alternate attn backends. + # Combined with FP8 CS, this path needs the looser distributed fp8_cs tolerance policy. + rtol, atol = FP8_CS_SM120_DETERMINISTIC_RTOL_ATOL + else: + rtol, atol = FP8_DEFAULT_RTOL_ATOL + else: + rtol, atol = BF16_DEFAULT_RTOL_ATOL + if ( + is_sm120 + and is_deterministic_mode + and opts.layer_type == te.TransformerLayer + and opts.num_layers > 1 + and opts.overlap_rs_dgrad + ): + # SM120 + deterministic training disables fused attn . + # Rt then selects an alternate attn backend, and + # the overlap path can show tiny BF16 accumulation-order drift vs reference. + rtol, atol = BF16_SM120_DETERMINISTIC_OVERLAP_RTOL_ATOL grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py index d46a874695..def86a2f77 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py @@ -28,6 +28,19 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +def _reference_scale_for_layout( + ref_unswizzled: torch.Tensor, + split_m: int, + n: int, + columnwise: bool, + with_gemm_swizzled_scales: bool, +) -> torch.Tensor: + """Return reference scale in expected backend-reported layout.""" + if with_gemm_swizzled_scales: + return swizzle_nvfp4_scale(split_m, n, ref_unswizzled.clone(), columnwise=columnwise) + return ref_unswizzled + + def fused_grouped_quantize( x: torch.Tensor, split_section_tensor: torch.Tensor, quantizer: NVFP4Quantizer ): @@ -56,7 +69,6 @@ def check_grouped_tensor_nvfp4_versus_reference( ) -> None: te_dtype = tex.DType.kFloat4E2M1 - split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") # Setup device and random seed @@ -98,6 +110,14 @@ def check_grouped_tensor_nvfp4_versus_reference( group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) # get a list of nvfp4 quantized tensors for testing split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + expected_swizzled_layout = bool(group_quantized_output._with_gemm_swizzled_scales) + for i, output in enumerate(split_quantize_outputs): + split_flag = bool(output._with_gemm_swizzled_scales) + assert split_flag == expected_swizzled_layout, ( + "Grouped output and split output disagree on swizzled-scale metadata " + f"(split {i}: grouped={expected_swizzled_layout}, split={split_flag})" + ) + scale_atol, scale_rtol = 0.0, 0.0 if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] @@ -121,11 +141,15 @@ def check_grouped_tensor_nvfp4_versus_reference( ), "The scale shape is not correctly aligned" x_sx_i = x_sx[i].clone() x_sx_ref_i = x_sx_ref[i].clone() - if optimize_for_gemm: - x_sx_ref_i = swizzle_nvfp4_scale( - split_sections[i], N, x_sx_ref_i, columnwise=False - ) - torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + # Swizzle the reference scale based on expected_swizzled_layout + x_sx_ref_i = _reference_scale_for_layout( + ref_unswizzled=x_sx_ref_i, + split_m=split_sections[i], + n=N, + columnwise=False, + with_gemm_swizzled_scales=expected_swizzled_layout, + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=scale_atol, rtol=scale_rtol) if return_transpose: x_qx_t = [ @@ -151,11 +175,14 @@ def check_grouped_tensor_nvfp4_versus_reference( ), "The scale shape is not correctly aligned" x_sx_t_i = x_sx_t[i].clone() x_sx_t_ref_i = x_sx_t_ref[i].clone() - if optimize_for_gemm: - x_sx_t_ref_i = swizzle_nvfp4_scale( - split_sections[i], N, x_sx_t_ref_i, columnwise=True - ) - torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + x_sx_t_ref_i = _reference_scale_for_layout( + ref_unswizzled=x_sx_t_ref_i, + split_m=split_sections[i], + n=N, + columnwise=True, + with_gemm_swizzled_scales=expected_swizzled_layout, + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=scale_atol, rtol=scale_rtol) def check_grouped_tensor_nvfp4_with_paged_stashing( @@ -173,7 +200,6 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( ) -> None: te_dtype = tex.DType.kFloat4E2M1 - assert valid_M is not None, "valid_M must be provided when with_paged_stashing is True" assert valid_M < M, "valid_M must be less than M when with_paged_stashing is True" @@ -225,6 +251,14 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( # get a list of nvfp4 quantized tensors for testing split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + expected_swizzled_layout = bool(group_quantized_output._with_gemm_swizzled_scales) + for i, output in enumerate(split_quantize_outputs): + split_flag = bool(output._with_gemm_swizzled_scales) + assert split_flag == expected_swizzled_layout, ( + "Grouped output and split output disagree on swizzled-scale metadata " + f"(split {i}: grouped={expected_swizzled_layout}, split={split_flag})" + ) + scale_atol, scale_rtol = 0.0, 0.0 if return_rowwise: x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] @@ -248,11 +282,15 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( ), "The scale shape is not correctly aligned" x_sx_i = x_sx[i].clone() x_sx_ref_i = x_sx_ref[i].clone() - if optimize_for_gemm: - x_sx_ref_i = swizzle_nvfp4_scale( - split_sections[i], N, x_sx_ref_i, columnwise=False - ) - torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + # Swizzle the reference scale based on expected swizzled layout + x_sx_ref_i = _reference_scale_for_layout( + ref_unswizzled=x_sx_ref_i, + split_m=split_sections[i], + n=N, + columnwise=False, + with_gemm_swizzled_scales=expected_swizzled_layout, + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=scale_atol, rtol=scale_rtol) if return_transpose: x_qx_t = [ @@ -275,11 +313,14 @@ def check_grouped_tensor_nvfp4_with_paged_stashing( valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True) x_sx_t_i = x_sx_t[i].clone() x_sx_t_ref_i = x_sx_t_ref[i].clone() - if optimize_for_gemm: - x_sx_t_ref_i = swizzle_nvfp4_scale( - split_sections[i], N, x_sx_t_ref_i, columnwise=True - ) - torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + x_sx_t_ref_i = _reference_scale_for_layout( + ref_unswizzled=x_sx_t_ref_i, + split_m=split_sections[i], + n=N, + columnwise=True, + with_gemm_swizzled_scales=expected_swizzled_layout, + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=scale_atol, rtol=scale_rtol) @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @@ -402,6 +443,11 @@ def test_grouped_tensor_nvfp4_with_paged_stashing( with_rht: bool, optimize_for_gemm: bool, ) -> None: + if torch.cuda.get_device_capability() == (12, 0): + pytest.skip( + "SM120: paged-stashing grouped NVFP4 path is currently unsupported. " + "group_hadamard_transform_amax assumes sum(split_sections) == input rows)." + ) # paged stashing means that the sum of total tokens is less than # or equal to the buffer size, you can have buffer [2048, 1024] diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index b14eeb815b..7a6fd5b43a 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -14,11 +14,31 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +SM120_SR_EQUIVALENCE_ATOL = 2e-7 + seed = 12345 torch.manual_seed(seed) torch.cuda.manual_seed(seed) +def _assert_sr_vs_rn_behavior( + me_sr: torch.Tensor, + me_rn: torch.Tensor, + me_t_sr: torch.Tensor, + me_t_rn: torch.Tensor, +) -> None: + if torch.cuda.get_device_capability() == (12, 0): + # SM120 currently disables NVFP4 stochastic rounding in backend paths, + # so SR and RN should be numerically equivalent. + torch.testing.assert_close(me_sr, me_rn, atol=SM120_SR_EQUIVALENCE_ATOL, rtol=0.0) + torch.testing.assert_close(me_t_sr, me_t_rn, atol=SM120_SR_EQUIVALENCE_ATOL, rtol=0.0) + else: + assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." + assert ( + me_t_sr < me_t_rn + ), "Stochastic rounding failed - error larger than the round to nearest." + + def unpack_fp4(x: torch.Tensor) -> torch.Tensor: repeated = x.repeat_interleave(2, dim=1) repeated[:, 0::2] &= 0x0F @@ -247,7 +267,7 @@ def check_quantization_nvfp4_versus_reference( me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean()) sr_result = torch.zeros_like(x).float() sr_t_result = torch.zeros_like(x).float().t().contiguous() - for i in range(n_iters): + for _ in range(n_iters): q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4( x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT ) @@ -278,8 +298,7 @@ def check_quantization_nvfp4_versus_reference( print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") - assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." - assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest." + _assert_sr_vs_rn_behavior(me_sr, me_rn, me_t_sr, me_t_rn) def check_group_quantization_nvfp4_versus_reference( @@ -362,10 +381,7 @@ def check_group_quantization_nvfp4_versus_reference( print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") - assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." - assert ( - me_t_sr < me_t_rn - ), "Stochastic rounding failed - error larger than the round to nearest." + _assert_sr_vs_rn_behavior(me_sr, me_rn, me_t_sr, me_t_rn) @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 62a6291797..4cf5c6ec1b 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -128,7 +128,7 @@ def test_custom_recipe_grouped_linear_sanity(): in_features = 64 out_features = 64 # Each per-GEMM M dim must be a multiple of 16 to satisfy cuBLAS FP8 GEMM's - # leading-dimension alignment requirement on Hopper (sm_90). + # leading-dimension alignment requirement on Hopper and SM120 paths. m_splits = [16] * num_gemms batch = sum(m_splits) @@ -281,7 +281,7 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): in_features = 64 out_features = 64 # batch must be a multiple of 16 to satisfy cuBLAS FP8 GEMM's leading-dim - # alignment requirement on Hopper (sm_90). + # alignment requirement on Hopper and SM120 paths. batch = 16 op = Linear(in_features, out_features, params_dtype=torch.bfloat16) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1ced32e1a5..74ff3cb248 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2132,6 +2132,16 @@ def test_grouped_linear( "with quantized_model_init" ) + # SM120: cuBLASLt grouped GEMM (graph-safe path) is unsupported. + # BF16/FP16 unquantized compute and MXFP8 quantized compute select that + # path on SM100+; FP32 and non-MXFP8 quantized compute use legacy GEMM. + if torch.cuda.get_device_capability() == (12, 0) and _uses_cublaslt_grouped_gemm_path( + dtype=dtype, + quantization=quantization, + quantized_compute=quantized_compute, + ): + pytest.skip("Grouped cuBLASLt GEMM is currently unsupported on SM120.") + # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, @@ -2317,6 +2327,15 @@ def test_grouped_linear_cuda_graph_safe( if single_grouped_bias and not bias: pytest.skip("single_grouped_bias requires bias=True") + # SM120: this test only exercises the cuBLASLt grouped GEMM path (BF16/FP16 + # or MXFP8 autocast), which is unsupported on SM120. + if torch.cuda.get_device_capability() == (12, 0) and _uses_cublaslt_grouped_gemm_path( + dtype=dtype, + quantization=quantization, + quantized_compute=quantization is not None, + ): + pytest.skip("Grouped cuBLASLt GEMM is currently unsupported on SM120.") + # Split sizes (statically pinned for graph capture) split_sizes = [split_alignment * (i + 1) for i in range(group_size)] random.shuffle(split_sizes) @@ -3673,6 +3692,16 @@ def test_layernorm_mlp( # Check values tols = {"rtol": 0.25, "atol": 0.5} # Loose tols for sanity checking + # SM120 + NVFP4 quantized compute has wider per-element numerical spread + # in the bwd path (notably ffn1.bias.grad / ffn2.weight.grad with + # bias=True) as it falls back to RN instead of SR, uses unfused path + # instead of fused RHT grouped kernel, uses non-TMA gated-act kernels. + if ( + quantization == "nvfp4" + and quantized_compute + and torch.cuda.get_device_capability() == (12, 0) + ): + tols["atol"] = max(tols["atol"], 0.75) assert_close(y_test, y_ref, **tols) assert_close(x_test.grad, x_ref.grad, **tols) assert_close_grads(norm.weight, norm_w_ref, **tols) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index e087f1e1cd..d39777bf66 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -53,6 +53,7 @@ nvfp4_available = is_nvfp4_available() sm_80plus = get_device_compute_capability() >= (8, 0) +sm_120 = get_device_compute_capability() == (12, 0) seed = 1234 # Reset RNG states. @@ -2115,9 +2116,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): max_seqlen_kv=config.max_seqlen_kv, ) + tols = dtype_tols(dtype) + if sm_120: + # sm120 FusedAttention does not support T3HD/TH3D layouts, so for T3HD/TH3D, the test falls back to using Flash Attn backend + # whereas for BSHD/SBHD, the test uses FusedAttention backend by default. Hence, relaxing the atol tolerance for T3HD/TH3D. + tols["atol"] = max(tols["atol"], 4e-3) torch.testing.assert_close( y_bshd, y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(), + **tols, ) diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 63c1b046ff..0568fc521b 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -39,7 +39,7 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t break; } case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { + if (is_supported_by_CC_100_or_newer()) { mxfp8::dequantize(input, output, stream); } else { NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); @@ -62,7 +62,7 @@ inline void group_dequantize_helper(const GroupedTensor &input, GroupedTensor *o switch (input.scaling_mode) { case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { + if (is_supported_by_CC_100_or_newer()) { mxfp8::group_dequantize(&input, output, stream); } else { NVTE_ERROR("MXFP8 Grouped Dequantization is NOT supported by architectures < 10.0"); diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index 06e8f0e306..11b28c2483 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -46,7 +46,10 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // SM120 has lower shared-memory headroom than SM100 for this kernel family. + // Keep TMA kernels disabled on SM120 and use the non-TMA fallback path. + const bool use_tma_kernels = + (cols % 32 == 0) && is_supported_by_CC_100_or_newer() && !is_supported_by_CC_120(); if (use_tma_kernels) { Tensor dummy_grad_tensor; fp8::cast_gated_tma(input, dummy_grad_tensor, @@ -83,7 +86,7 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), "The type of the columnwise output tensor should be FP8."); } - NVTE_CHECK(is_supported_by_CC_100(), + NVTE_CHECK(is_supported_by_CC_100_or_newer(), "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); Tensor dummy_grad_tensor; mxfp8::quantize_gated(input, dummy_grad_tensor, @@ -137,7 +140,10 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { - const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + // SM120 has lower shared-memory headroom than SM100 for this kernel family. + // Keep TMA kernels disabled on SM120 and use the non-TMA fallback path. + const bool use_tma_kernels = + (cols % 32 == 0) && is_supported_by_CC_100_or_newer() && !is_supported_by_CC_120(); if (use_tma_kernels) { fp8::cast_gated_tma(gated_input, grad, output, p, stream); @@ -173,7 +179,7 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), "The type of the columnwise output tensor should be FP8."); } - NVTE_CHECK(is_supported_by_CC_100(), + NVTE_CHECK(is_supported_by_CC_100_or_newer(), "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); mxfp8::quantize_gated(gated_input, grad, output, p, diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index bad10c954e..1cddf889ce 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -531,7 +531,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); // Supported by the Arch >= 10.0 - if (is_supported_by_CC_100()) { + if (is_supported_by_CC_100_or_newer()) { if (!IS_DBIAS && !IS_DACT) { if (common::full_tile_1D_tensor(output, ELEMS_PER_BLOCK) && is_fp8_dtype(output->dtype()) && is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 5e71a30e83..f599c0a245 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -491,6 +491,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } + // Ensure async shared->global copy is done reading shared source before reuse. + ptx::cp_async_bulk_wait_group_read<0>(); + // Ensure all warps reach the reuse boundary before DBIAS scratch writes. + __syncthreads(); + parity ^= 1; if constexpr (IS_DBIAS) { diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 1bdd80a369..52c1f2bb3b 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -281,12 +281,18 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); } -bool is_supported_by_CC_100() { +bool is_supported_by_CC_100_or_newer() { int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); return deviceComputeCapability >= 100; } +bool is_supported_by_CC_120() { + int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); + + return deviceComputeCapability == 120; +} + std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size) { std::vector> ret; diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index eb4dcc055c..4eb1858d05 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1185,7 +1185,9 @@ void create_2D_tensor_map( const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits, const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); -bool is_supported_by_CC_100(); +bool is_supported_by_CC_100_or_newer(); + +bool is_supported_by_CC_120(); std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 6df7ad35c8..e8f113bff6 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -385,7 +385,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { + if (use_ragged_stats) { Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { Stats->set_stride({h * s_q, s_q, 1, 1}); @@ -1142,7 +1142,8 @@ void fused_attn_arbitrary_seqlen_fwd( Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) && + (sm_arch_ != 120)) { output_S->data.shape = {num_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index f064af2478..e270af2087 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -345,6 +345,7 @@ inline void check_grouped_gemm_requirements(const char *api_name) { } #else NVTE_CHECK(sm >= 100, api_name, " requires Blackwell (SM100) or newer architecture."); + NVTE_CHECK(sm != 120 && sm != 121, api_name, " is currently unsupported on SM12x architectures."); NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name, " requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver); #endif diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 989b65f190..036b80c135 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -468,11 +468,20 @@ def get_attention_backend( if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: logger.debug("Disabling FlashAttention 3 for compute capability != sm90") use_flash_attention_3 = False - # FA4 supports SM80, SM90, SM100, SM120 + # FA4 supports SM80, SM90, SM100 if device_compute_capability < (8, 0): if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed: logger.debug("Disabling FlashAttention 4 for compute capability < sm80") use_flash_attention_4 = False + # FA4 is temporarily disabled on SM120 due to failures observed with + # SplitKV, Block sparsity / paged KV and likely FAv4/DSL integration issues. + if device_compute_capability == (12, 0): + if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed: + logger.warning( + "Disabling FlashAttention 4 on sm120 due to missings bits of support for SplitKV," + " Block sparsity / paged KV and likely FAv4/DSL integration issues." + ) + use_flash_attention_4 = False # On SM90, prefer FA3 over FA4 when FA3 is available. # FA3 is more mature on Hopper; FA4's SM90 backward has limitations # (MLA, non-standard head dims, SplitKV). diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index d1a9cd8587..e12d2dba6d 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -16,9 +16,11 @@ #include #include "../extensions.h" +#include "../util.h" #include "common.h" #include "common/util/system.h" #include "pybind.h" +#include "transformer_engine/swizzle.h" #include "transformer_engine/transformer_engine.h" namespace transformer_engine { @@ -75,6 +77,52 @@ py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector namespace { +void split_quantize_nvfp4_impl(const TensorWrapper &input, + const std::vector &input_list, + std::vector &output_list, + const std::vector &split_sections, + const std::vector &quantizers); + +// Converts the per-group GPU row counts (first_dims, int64 CUDA tensor) +// into a host vector of per-group row counts and returns it. +// The returned vector is used by NVFP4 grouped-quantize to split the input +// tensor into per-group sub-tensors. +// Currently, only used for SM120 NVFP4 grouped-quantize fallback. +std::vector get_split_sections(std::optional first_dims, size_t num_tensors) { + auto first_dims_tensor = first_dims.value(); + NVTE_CHECK(first_dims_tensor.scalar_type() == at::kLong, + "Expected first_dims dtype=int64, got scalar_type enum=", + static_cast(first_dims_tensor.scalar_type())); + // D2H copy to CPU + auto first_dims_cpu = first_dims_tensor.contiguous().to(at::kCPU); + NVTE_CHECK(static_cast(first_dims_cpu.numel()) == num_tensors, "Expected ", num_tensors, + " first_dims entries, but got ", first_dims_cpu.numel(), "."); + std::vector split_sections(num_tensors, 0); + const int64_t *first_dims_ptr = first_dims_cpu.data_ptr(); + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(first_dims_ptr[i] >= 0, "first_dims must be non-negative, got ", first_dims_ptr[i], + " at index ", i, "."); + split_sections[i] = static_cast(first_dims_ptr[i]); + } + return split_sections; +} + +// Converts the Python GroupedTensor into a C++ vector of TensorWrappers, +// which are used by NVFP4 grouped-quantize to store the quantized output tensors. +// Currently, only used for SM120 NVFP4 grouped-quantize fallback. +std::vector get_grouped_outputs(const py::object &grouped_output_py, + size_t num_tensors) { + py::list split_outputs = grouped_output_py.attr("split_into_quantized_tensors")(); + NVTE_CHECK(static_cast(py::len(split_outputs)) == num_tensors, "Expected ", num_tensors, + " output tensors, but got ", py::len(split_outputs), "."); + std::vector output_list; + output_list.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + output_list.emplace_back(makeTransformerEngineTensor(split_outputs[i], py::none())); + } + return output_list; +} + // helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy) void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, GroupedTensorWrapper &grouped_output_tensor, @@ -157,10 +205,11 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const using namespace transformer_engine::pytorch::detail; init_extension(); - NVTE_CHECK(tensor.dim() == 2, "Tensor must be 2D"); + auto input_contiguous = tensor.contiguous(); + NVTE_CHECK(input_contiguous.dim() == 2, "Tensor must be 2D"); std::vector logical_shape; - for (const auto &d : tensor.sizes()) { + for (const auto &d : input_contiguous.sizes()) { logical_shape.push_back(d); } const auto logical_first_dim = logical_shape[0]; @@ -172,8 +221,9 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const // Create input GroupedTensor. auto grouped_input_tensor = GroupedTensorWrapper(num_tensors, logical_shape); - grouped_input_tensor.set_rowwise_data( - tensor.data_ptr(), GetTransformerEngineDType(tensor.scalar_type()), getTensorShape(tensor)); + grouped_input_tensor.set_rowwise_data(input_contiguous.data_ptr(), + GetTransformerEngineDType(input_contiguous.scalar_type()), + getTensorShape(input_contiguous)); // Create output GroupedTensor. auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor( @@ -206,8 +256,41 @@ py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const case GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE: { // NVFP4 grouped quantization NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); - group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp, - nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream()); + const bool enable_sm120_grouped_nvfp4_fallback = is_sm120_device() && first_dims.has_value(); + if (enable_sm120_grouped_nvfp4_fallback) { + // SM120 fallback: the fused grouped NVFP4 kernel does not run here, so we cast + // per split with the unfused split-quantize path. split_quantize_nvfp4_impl owns + // the SM120-specific bookkeeping for optimize_for_gemm (compact-only cast + + // post-cast in-place swizzle), so the public contract holds at every layer. + auto split_sections = get_split_sections(first_dims, num_tensors); + std::vector input_list; + input_list.reserve(num_tensors); + auto *input_dptr = reinterpret_cast(input_contiguous.data_ptr()); + const auto input_dtype = GetTransformerEngineDType(input_contiguous.scalar_type()); + const size_t dim0_stride = logical_first_dim == 0 + ? 0 + : static_cast(input_contiguous.element_size()) * + static_cast(input_contiguous.numel()) / + logical_first_dim; + size_t dim0_offset = 0; + for (size_t i = 0; i < num_tensors; ++i) { + NVTE_CHECK(dim0_offset + split_sections[i] <= logical_first_dim, + "Split sections exceed input tensor first dimension."); + std::vector split_shape = {split_sections[i], logical_last_dim}; + void *split_dptr = static_cast(input_dptr + dim0_offset * dim0_stride); + input_list.emplace_back( + makeTransformerEngineTensor(split_dptr, split_shape, input_dtype)); + dim0_offset += split_sections[i]; + } + auto output_list = get_grouped_outputs(grouped_output_py, num_tensors); + std::vector quantizers(num_tensors, nvfp4_quantizer_cpp); + auto input_tensor_cpp = makeTransformerEngineTensor(input_contiguous); + split_quantize_nvfp4_impl(input_tensor_cpp, input_list, output_list, split_sections, + quantizers); + } else { + group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp, + nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream()); + } break; } case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE: { @@ -1008,6 +1091,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, const bool nvfp4_use_4over6 = quantizer.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; NVTE_CHECK(!nvfp4_use_4over6, "NVFP4 4over6 quantization is not supported with RHT split quantization."); + const bool sm120_device = is_sm120_device(); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1020,6 +1104,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, bool all_aligned_token_dim = std::all_of(split_sections.begin(), split_sections.end(), [](size_t split_section) { return split_section % 128 == 0; }); + // SM120 fallback: avoid the fully fused grouped row+col RHT kernel path. + all_aligned_token_dim = all_aligned_token_dim && !sm120_device; // in the case when rowwise and colwise cannot be fused, we have to generate the RNG states twice // so that rowwise and colwise will have different random numbers @@ -1038,7 +1124,7 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, bool with_bulk_generate_rng_states = true; // Stochastic rounding - bool need_stochastic_rounding = quantizer.stochastic_rounding; + bool need_stochastic_rounding = quantizer.stochastic_rounding && !sm120_device; auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, need_separate_rng_states, quant_config_list, quant_config_list_colwise); @@ -1147,6 +1233,8 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, if (quantizer.columnwise_usage) { std::vector out_transpose_list; std::vector nvte_tensor_out_transpose_list; + std::vector rht_output_t_tensors; + rht_output_t_tensors.reserve(num_tensors); for (size_t i = 0; i < num_tensors; i++) { bool is_empty_split = input_list[i].numel() == 0; auto out_columnwise_data = output_list[i].get_columnwise_data(); @@ -1172,10 +1260,35 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, out_transpose_list.emplace_back(std::move(out_transpose)); nvte_tensor_out_transpose_list.push_back(out_transpose_list.back().data()); } - nvte_group_hadamard_transform_cast_fusion_columnwise( - input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), - rht_matrix_nvte.data(), split_sections.data(), num_tensors, - quant_config_list_colwise_to_use[0], stream); + if (sm120_device) { + // SM120 fallback: avoid grouped columnwise RHT fusion path and run unfused per split. + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) { + continue; + } + const int rows = static_cast(split_sections[i]); + const int cols = static_cast(input_list[i].size(input_list[i].ndim() - 1)); + auto rht_output_t = allocateTorchTensor(cols, rows, input_list[i].dtype()); + rht_output_t_tensors.push_back(rht_output_t); + TensorWrapper rht_output_t_cpp; + rht_output_t_cpp.set_rowwise_data( + rht_output_t.data_ptr(), input_list[i].dtype(), + std::vector{static_cast(cols), static_cast(rows)}); + // SM120 unfused columnwise path (per split): + // 1) Apply RHT on the input and write the result in transposed layout (shape [cols, rows]) into rht_output_t_cpp. + // Columnwise NVFP4 scales are obtained by running rowwise NVFP4 on x_t, so we need the transposed layout here. + nvte_hadamard_transform(input_list[i].data(), rht_output_t_cpp.data(), 0, + quantizer.rht_matrix_random_sign_mask_t, stream); + // 2) NVFP4-quantize the RHT(x_t) output into the columnwise (out_transpose) slot. + nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose_list[i].data(), + quant_config_list_colwise_to_use[i], stream); + } + } else { + nvte_group_hadamard_transform_cast_fusion_columnwise( + input.data(), reinterpret_cast(nvte_tensor_out_transpose_list.data()), + rht_matrix_nvte.data(), split_sections.data(), num_tensors, + quant_config_list_colwise_to_use[0], stream); + } } } } @@ -1191,6 +1304,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, const bool nvfp4_use_4over6 = quantizer.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; NVTE_CHECK(!nvfp4_use_4over6 || !quantizer.stochastic_rounding, "NVFP4 4over6 quantization does not support stochastic rounding."); + const bool sm120_device = is_sm120_device(); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1213,7 +1327,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, // so that we can generate all rng states at once bool with_bulk_generate_rng_states = false; - bool need_stochastic_rounding = quantizer.stochastic_rounding; + bool need_stochastic_rounding = quantizer.stochastic_rounding && !sm120_device; // place holder for colwise rng states, which are not needed in this case std::vector dummy_quant_config_list_colwise; @@ -1275,6 +1389,51 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, } } +// Swizzle the rowwise or columnwise scale slot of a single per-split NVFP4 output in +// place. nvte_swizzle_scaling_factors is not in-place safe, so we allocate a scratch +// buffer the same shape/dtype as the slot, swizzle out-of-place into the scratch, then +// memcpy the swizzled bytes back into the original slot. Used to honor optimize_for_gemm +// on SM120 where the NVFP4 split-quantize cast kernels only emit compact-layout scales. +void swizzle_per_split_scale_slot_in_place(TensorWrapper &out, bool rowwise, cudaStream_t stream) { + const auto scales = rowwise ? out.get_rowwise_scale_inv() : out.get_columnwise_scale_inv(); + if (scales.data_ptr == nullptr) { + return; + } + const auto data = rowwise ? out.get_rowwise_data() : out.get_columnwise_data(); + if (data.data_ptr == nullptr) { + return; + } + const auto scaling_mode = out.scaling_mode(); + const auto scales_dtype = static_cast(scales.dtype); + const auto data_dtype = static_cast(data.dtype); + + auto scratch = allocateSpace(scales.shape, scales_dtype, /*init_to_zeros=*/false); + void *scratch_ptr = getDataPtr(scratch); + + TensorWrapper input_nvte(scaling_mode); + TensorWrapper output_nvte(scaling_mode); + if (rowwise) { + input_nvte.set_rowwise_data(nullptr, data_dtype, data.shape); + input_nvte.set_rowwise_scale_inv(scales.data_ptr, scales_dtype, scales.shape); + output_nvte.set_rowwise_data(nullptr, data_dtype, data.shape); + output_nvte.set_rowwise_scale_inv(scratch_ptr, scales_dtype, scales.shape); + } else { + input_nvte.set_columnwise_data(nullptr, data_dtype, data.shape); + input_nvte.set_columnwise_scale_inv(scales.data_ptr, scales_dtype, scales.shape); + output_nvte.set_columnwise_data(nullptr, data_dtype, data.shape); + output_nvte.set_columnwise_scale_inv(scratch_ptr, scales_dtype, scales.shape); + } + output_nvte.set_with_gemm_swizzled_scales(true); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_swizzle_scaling_factors(input_nvte.data(), output_nvte.data(), stream); }); + + const size_t nbytes = + static_cast(scratch.numel()) * static_cast(scratch.element_size()); + NVTE_CHECK_CUDA( + cudaMemcpyAsync(scales.data_ptr, scratch_ptr, nbytes, cudaMemcpyDeviceToDevice, stream)); +} + void split_quantize_nvfp4_impl(const TensorWrapper &input, const std::vector &input_list, std::vector &output_list, @@ -1328,6 +1487,30 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, // CUDA stream auto stream = at::cuda::getCurrentCUDAStream(); + // SM120 only emits scales in the compact (unswizzled) layout from the NVFP4 split- + // quantize cast kernels. To honor optimize_for_gemm at this boundary, we: + // 1. Note which per-split outputs were requested with swizzled scales + // 2. Override their C++ with_gemm_swizzled_scales flag to false so the cast + // kernels' compact-only assertion passes + // 3. Run the cast + // 4. Post-cast, swizzle each requested per-split scale slot in place so the + // data matches the (already-true) Python _with_gemm_swizzled_scales metadata. + // This makes both call sites of split_quantize_nvfp4_impl (group_quantize SM120 + // fallback and standalone split_quantize) honor the public optimize_for_gemm + // contract uniformly. + const bool sm120 = is_sm120_device(); + std::vector wanted_swizzled; + if (sm120) { + wanted_swizzled.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + const bool want = quantizers[i]->optimize_for_gemm; + wanted_swizzled.push_back(want); + if (want) { + output_list[i].set_with_gemm_swizzled_scales(false); + } + } + } + // Perform multi-tensor quantization NVTE_SCOPED_GIL_RELEASE({ if (quantizer.with_rht) { // Quantize row-wise data, RHT+quantize column-wise data @@ -1342,6 +1525,19 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, stream); } }); + + // Post-cast swizzle on SM120 for outputs that requested optimize_for_gemm=true. + // Each call swizzles one scale direction (rowwise or columnwise) in place; the + // helper itself handles GIL release for the underlying CUDA work. + if (sm120) { + for (size_t i = 0; i < num_tensors; ++i) { + if (!wanted_swizzled[i] || input_list[i].numel() == 0) { + continue; + } + swizzle_per_split_scale_slot_in_place(output_list[i], /*rowwise=*/true, stream); + swizzle_per_split_scale_slot_in_place(output_list[i], /*rowwise=*/false, stream); + } + } } } // namespace diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index bc87b54ba8..f4bf5fb7ad 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -11,6 +11,7 @@ #include "common/util/system.h" #include "pybind.h" #include "torch/torch.h" +#include "util.h" namespace transformer_engine::pytorch { @@ -2008,7 +2009,12 @@ std::pair NVFP4Quantizer::create_grouped_tenso getTensorShape(*tensor_offsets)); } - out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); + // The grouped tensor metadata always follows the quantizer's optimize_for_gemm. + // Architecture-specific cast paths (e.g. the SM120 grouped NVFP4 fallback in + // pytorch/csrc/extensions/cast.cpp) are responsible for producing scales that + // match this metadata at the API boundary. + const bool with_gemm_swizzled_scales = this->optimize_for_gemm; + out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); py::handle GroupedTensorClass = grouped_tensor_python_class(this->internal); py::dict kwargs; @@ -2031,7 +2037,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["first_dims"] = first_dims.has_value() ? py::cast(*first_dims) : py::none(); kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); - kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; + kwargs["with_gemm_swizzled_scales"] = with_gemm_swizzled_scales; kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); @@ -2314,7 +2320,11 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou quant_config_columnwise.set_noop_tensor(noop_flag->data()); } quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); - quant_config.set_stochastic_rounding(this->stochastic_rounding); + // Disable stochastic-rounding FP4 cast path for SM120, which relies on arch-specific PTX + // instructions. + const bool sm120_device = is_sm120_device(); + const bool use_stochastic_rounding = this->stochastic_rounding && !sm120_device; + quant_config.set_stochastic_rounding(use_stochastic_rounding); quant_config.set_nvfp4_4over6_mode(this->nvfp4_4over6_mode); quant_config_columnwise.set_nvfp4_4over6_mode(this->nvfp4_4over6_mode); @@ -2361,11 +2371,11 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 3. Columnwise usage is enabled // 4. Rowwise and columnwise quantization are not fused, // because within a single kernel we can generate two different random numbers for rowwise and columnwise - const bool need_separate_columnwise_rng = this->stochastic_rounding && this->with_rht && + const bool need_separate_columnwise_rng = use_stochastic_rounding && this->with_rht && this->columnwise_usage && (!eligible_for_rht_cast_fusion); - if (this->stochastic_rounding) { + if (use_stochastic_rounding) { const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened auto gen = at::get_generator_or_default( std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 132db4075f..5eb51721df 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -13,6 +13,7 @@ #include #include +#include "common/util/cuda_runtime.h" #include "transformer_engine/transformer_engine.h" namespace transformer_engine { @@ -65,6 +66,11 @@ std::optional maybe_swizzle_grouped_tensor(GroupedTensorW */ at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool rowwise); +/*! \brief Check whether the current CUDA device is SM120. */ +inline bool is_sm120_device() { + return transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) == 120; +} + } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 438e124021..de2fc89e49 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -1094,7 +1094,9 @@ def split_into_quantized_tensors( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=quantizer.dtype, quantizer=quantizer, - with_gemm_swizzled_scales=quantizer.optimize_for_gemm, + # Use the actual grouped-output layout. This can differ from the requested + # quantizer flag if the backend produces a different layout (e.g. sm120) + with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, ) result.append(tensor) @@ -1229,7 +1231,9 @@ def split_into_quantized_tensors( amax_columnwise=amax_columnwise, fp4_dtype=quantizer.dtype, quantizer=quantizer, - with_gemm_swizzled_scales=quantizer.optimize_for_gemm, + # Use the actual grouped-output layout. This can differ from the requested + # quantizer flag if the backend produces a different layout (e.g. sm120) + with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, nvfp4_use_4over6=nvfp4_use_4over6, nvfp4_e4m3_max=nvfp4_e4m3_max,