diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index bb4a4e3857..8e9bb2bc23 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -12,6 +12,7 @@ LayerNormLinear, LayerNormMLP, Linear, + GroupedLinear, MultiheadAttention, TransformerLayer, autocast, @@ -23,6 +24,7 @@ is_nvfp4_available, is_bf16_available, ) +from transformer_engine.pytorch.module.grouped_linear import _GroupedLinear from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe @@ -157,6 +159,52 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) torch.testing.assert_close(t1, t2, rtol=0, atol=0) +def test_grouped_linear_forwards_fp8_graph_skip_tensor(monkeypatch) -> None: + """GroupedLinear should propagate the dynamic FP8 weight-update skip flag.""" + mod = GroupedLinear(1, 4, 4, bias=False, device="cpu") + skip_fp8_weight_update = torch.tensor([1.0]) + captured_non_tensor_args = {} + + monkeypatch.setattr(FP8GlobalStateManager, "fp8_graph_capturing", lambda: True) + monkeypatch.setattr( + FP8GlobalStateManager.quantization_state, + "skip_fp8_weight_update_tensor", + skip_fp8_weight_update, + ) + monkeypatch.setattr(mod, "is_debug_iter", lambda: False) + monkeypatch.setattr(mod, "prepare_forward", lambda inp, num_gemms=1: inp) + monkeypatch.setattr(mod, "end_forward", lambda: None) + monkeypatch.setattr(mod, "_get_weight_tensors", lambda: [torch.empty(4, 4)]) + monkeypatch.setattr(mod, "_get_bias_tensors", lambda: []) + monkeypatch.setattr( + mod, + "_get_quantizers", + lambda: ([None], [None], [None], [None], [None], [None]), + ) + + def _capture_forward(ctx, inp, m_splits, non_tensor_args, *weights_and_biases): + del ctx, m_splits, weights_and_biases + captured_non_tensor_args["value"] = non_tensor_args + return inp, [None] + + monkeypatch.setattr(_GroupedLinear, "forward", staticmethod(_capture_forward)) + + with torch.no_grad(): + mod(torch.empty(2, 4), [2], is_first_microbatch=True) + + non_tensor_args = captured_non_tensor_args["value"] + # non_tensor_args layout (see GroupedLinear.forward): + # 0: apply_bias, 1: is_first_microbatch, 2: fp8, 3: fp8_calibration, + # 4: wgrad_store, 5-10: quantizers (x6), 11: fuse_wgrad_accumulation, + # 12: is_cpu_offload, 13: sequence_parallel, 14: activation_dtype, + # 15: is_grad_enabled, 16: weight_workspaces, 17: cache_weight, + # 18: skip_fp8_weight_update, 19: save_original_input, 20: debug + _IS_FIRST_MICROBATCH_IDX = 1 + _SKIP_FP8_WEIGHT_UPDATE_IDX = 18 + assert non_tensor_args[_IS_FIRST_MICROBATCH_IDX] is False + assert non_tensor_args[_SKIP_FP8_WEIGHT_UPDATE_IDX] is skip_fp8_weight_update + + def generate_data( model_config: ModelConfig, dtype: torch.dtype, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 15ec3fe322..c1d45511df 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -1696,6 +1696,15 @@ def forward( f"does not match number of GEMMs ({num_gemms})." ) + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = ( + FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor + ) + else: + skip_fp8_weight_update = None + if skip_fp8_weight_update is not None: + is_first_microbatch = False + # Preprocess input tensor if isinstance(inp, QuantizedTensorStorage): raise TypeError("GroupedLinear doesn't support input tensor in FP8.") @@ -1754,7 +1763,7 @@ def forward( is_grad_enabled, weight_workspaces, cache_weight, - None, # skip_fp8_weight_update + skip_fp8_weight_update, self.save_original_input, debug, )