diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index a0d4ac3530..acfb6ddff8 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Iterable -from contextlib import contextmanager, AbstractContextManager, ContextDecorator +from contextlib import contextmanager, AbstractContextManager, ContextDecorator, nullcontext from functools import lru_cache from dataclasses import dataclass import math @@ -918,7 +918,10 @@ def fork(self, name: str = "model-parallel-rng"): def reduce_scatter_along_first_dim( - inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False + inp: torch.Tensor, + tp_group: dist_group_type, + async_op: bool = False, + output: torch.Tensor = None, ) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """Reduce-scatter the input tensor across model parallel group.""" world_size = get_distributed_world_size(tp_group) @@ -936,7 +939,8 @@ def reduce_scatter_along_first_dim( dim_size[0] = dim_size[0] // world_size - output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) + if output is None: + output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) handle = torch.distributed.reduce_scatter_tensor( output, inp.contiguous(), group=tp_group, async_op=async_op ) @@ -1281,11 +1285,13 @@ def _post_process_nvfp4_gather( handle = None # Fix the interleaved transposed data from gathering along first dim. - out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) - out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + # In-place .copy_() (not `=` rebind) to keep the storage address stable + # for CUDA graph capture — replays see the same pointer they captured. + out._columnwise_scale_inv.copy_(_swap_first_dims(columnwise_scale_inv_interleaved, world_size)) + out._columnwise_data.copy_(_swap_first_dims(columnwise_data_interleaved, world_size)) - # Optionally pad the scaling inverse if needed. - out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + # Optionally pad the scaling inverse if needed (same in-place pattern). + out._columnwise_scale_inv.copy_(pad_columnwise_scale_inv(out._columnwise_scale_inv)) @dataclass @@ -1299,17 +1305,25 @@ class _NVFP4AllGatherAsyncHandle: async_handle: torch.distributed.Work _synchronized: bool = False - def wait(self) -> None: - """Wait for the async operation to complete and post-process the tensor.""" - if self._synchronized: - return - self.async_handle.wait() + def post_process_nvfp4_gather(self) -> None: + """Fix interleaved transposed data + pad scale_inv after the async AG completes. + + Idempotent: gated by ``_synchronized`` in :meth:`wait`. + """ _post_process_nvfp4_gather( self.output, self.columnwise_data_interleaved, self.columnwise_scale_inv_interleaved, self.world_size, ) + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + if self.async_handle is not None: + self.async_handle.wait() + self.post_process_nvfp4_gather() self._synchronized = True @@ -1320,6 +1334,8 @@ def _all_gather_nvfp4( async_op: bool = False, quantizer: NVFP4Quantizer, out_shape: Optional[list[int]] = None, + output_tensor=None, + grouped=False, ) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]: """All-gather NVFP4 tensor along first dimension.""" @@ -1383,6 +1399,12 @@ def _all_gather_nvfp4( out = quantizer(out) return out, None + # Construct NVFP4 output tensor + if output_tensor is not None: + out = output_tensor + else: + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + # Cast input tensor to NVFP4 with required data if not isinstance(inp, NVFP4TensorStorage): inp = quantizer(inp) @@ -1395,17 +1417,19 @@ def _all_gather_nvfp4( ) inp = quantizer(inp.dequantize(dtype=dtype)) - # Construct NVFP4 output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - - # Coalesce NCCL collectives for gathering data and scale inverses. - with torch.distributed._coalescing_manager( - group=process_group, - device=device, - async_ops=async_op, - ) as gather_coalescing_manager: + if not grouped: + # Coalesce NCCL collectives for gathering data and scale inverses. + gather_coalescing_manager = torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) + else: + gather_coalescing_manager = nullcontext() + with gather_coalescing_manager as coalesced_handle: # Gather NVFP4 data for row-wise usage + out_columnwise_data = None if quantizer.rowwise_usage: # Remove padding from NVFP4 scale-inverses @@ -1433,8 +1457,9 @@ def _all_gather_nvfp4( group=process_group, ) - # Transfer amax to output. - out._amax_rowwise = inp._amax_rowwise + # Transfer amax to output via in-place .copy_() so the storage + # address stays stable for CUDA graph capture. + out._amax_rowwise.copy_(inp._amax_rowwise) # Gather the transposed NVFP4 data along first dimension. Fix format later. if quantizer.columnwise_usage: @@ -1483,17 +1508,24 @@ def _all_gather_nvfp4( ) # Transfer amax to output. - out._amax_columnwise = inp._amax_columnwise + out._amax_columnwise.copy_(inp._amax_columnwise) - handle = gather_coalescing_manager if async_op else None + handle = coalesced_handle if async_op else None # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. - if async_op and quantizer.columnwise_usage: - handle = _NVFP4AllGatherAsyncHandle( - out, out_columnwise_data, out_scale_inv, world_size, handle - ) - elif quantizer.columnwise_usage: - _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + if quantizer.columnwise_usage: + if async_op or grouped: + # Defer post-processing: either the async op hasn't completed yet, or an + # external coalescing manager owns the NCCL ops and hasn't flushed them. + inner_handle = handle if async_op else None + handle = _NVFP4AllGatherAsyncHandle( + out, out_columnwise_data, out_scale_inv, world_size, inner_handle + ) + else: + _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + else: + if handle is not None: + handle.output = out return out, handle @@ -1505,6 +1537,8 @@ def _all_gather_mxfp8( async_op: bool = False, quantizer: MXFP8Quantizer, out_shape: Optional[list[int]] = None, + output_tensor: torch.Tensor = None, + grouped: bool = False, ) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]: """All-gather MXFP8 tensor along first dimension.""" @@ -1570,15 +1604,22 @@ def _all_gather_mxfp8( inp = quantizer(inp.dequantize(dtype=dtype)) # Construct MXFP8 output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + if output_tensor is not None: + out = output_tensor + else: + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - # Coalesce NCCL collectives - with torch.distributed._coalescing_manager( - group=process_group, - device=device, - async_ops=async_op, - ) as coalescing_manager: + if not grouped: + # Coalesce NCCL collectives for gathering data and scale inverses. + gather_coalescing_manager = torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) + else: + gather_coalescing_manager = nullcontext() + with gather_coalescing_manager as coalesced_handle: # Gather MXFP8 data for row-wise usage if quantizer.rowwise_usage: @@ -1625,7 +1666,7 @@ def _all_gather_mxfp8( group=process_group, ) - handle = coalescing_manager if async_op else None + handle = coalesced_handle if async_op else None return out, handle @@ -1634,6 +1675,8 @@ def gather_along_first_dim( process_group: dist_group_type, async_op: bool = False, quantizer: Optional[Quantizer] = None, + output_tensor: torch.Tensor = None, + grouped: bool = False, ) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """ All-gather tensors and concatenate along first dimension. @@ -1724,6 +1767,8 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + output_tensor=output_tensor, + grouped=grouped, ) # NVFP4 case @@ -1738,6 +1783,8 @@ def gather_along_first_dim( async_op=async_op, quantizer=quantizer, out_shape=out_shape, + output_tensor=output_tensor, + grouped=grouped, ) # High-precision communication for quantized tensors @@ -1767,19 +1814,20 @@ def gather_along_first_dim( inp = inp.dequantize() # Communication for plain PyTorch tensors - out = torch.empty( - out_shape, - dtype=inp.dtype, - device=inp.device, - memory_format=torch.contiguous_format, - ) + if output_tensor is None: + output_tensor = torch.empty( + out_shape, + dtype=inp.dtype, + device=inp.device, + memory_format=torch.contiguous_format, + ) handle = torch.distributed.all_gather_into_tensor( - out, + output_tensor, inp.contiguous(), group=process_group, async_op=async_op, ) - return out, handle + return output_tensor, handle # Global cache to store symmetric memory tensors diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 746177ec78..64853c54c4 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -61,7 +61,13 @@ from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled -__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] +__all__ = [ + "initialize_ub", + "destroy_ub", + "UserBufferQuantizationMode", + "register_gtp_hooks", + "maybe_wrap_gtp", +] _2X_ACC_FPROP = False _2X_ACC_DGRAD = True @@ -72,6 +78,47 @@ layers_atomic_ring_exchange = [] +# GTP hook slots. An external integrator (currently ``megatron.experimental.gtp``) +# populates these via ``register_gtp_hooks`` at its own import time. When the +# slots stay ``None``, the ``gtp_group=`` codepath in TE modules is a no-op +# and TE has no ``from megatron...`` dependency. +_gtp_slice_fn = None +_gtp_finalize_fn = None +_gtp_wrap_fn = None + + +def register_gtp_hooks(*, slice_fn=None, finalize_fn=None, wrap_fn=None): + """Register GTP integration hooks. Hooks left as ``None`` are unchanged. + + slice_fn(module, name, param, *, expert_idx) -> GTPShardedParam | None + Fires per weight during ``reset_parameters``, before FP8 quantize. + finalize_fn(module, weight_names) -> None + Fires after the per-weight loop in ``reset_parameters``. + wrap_fn(module, weight_names, gtp_group, is_grouped=False) -> None + Fires at the end of a module's ``__init__`` to finalize GTP wiring. + """ + global _gtp_slice_fn, _gtp_finalize_fn, _gtp_wrap_fn + if slice_fn is not None: + _gtp_slice_fn = slice_fn + if finalize_fn is not None: + _gtp_finalize_fn = finalize_fn + if wrap_fn is not None: + _gtp_wrap_fn = wrap_fn + + +def maybe_wrap_gtp(module, weight_names, gtp_group, is_grouped=False): + """Finalize GTP wiring on a module if a wrap hook is registered. + + No-op when ``gtp_group`` is None or no GTP integrator has called + ``register_gtp_hooks``. Called from each TE module's ``__init__`` after + ``reset_parameters`` finishes; the per-weight slice already happened + inside ``reset_parameters`` via ``_gtp_slice_fn``. + """ + if gtp_group is None or _gtp_wrap_fn is None: + return + _gtp_wrap_fn(module, weight_names, gtp_group, is_grouped=is_grouped) + + class UserBufferQuantizationMode(Enum): """ UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer. @@ -1604,7 +1651,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: if defer_init: return - for name, param in self.named_parameters(recurse=False): + # Names of GTP-sharded weights, for GroupedLinear's post-loop finalize. + _gtp_sharded_weight_names = [] + + for idx, (name, param) in enumerate(self.named_parameters(recurse=False)): # Check if parameter is a DTensor (FSDP2) or regular tensor is_dtensor = isinstance(param, DTensor) dtensor_param = param if is_dtensor else None @@ -1626,10 +1676,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) + # GTP slice: shard the freshly-init weight into a GTPShardedParam; + # the FP8 quantize block below is skipped for it. + gtp_sharded = None + if ( + not is_dtensor + and getattr(self, "_gtp_group", None) is not None + and _gtp_slice_fn is not None + ): + gtp_sharded = _gtp_slice_fn(self, name, param, expert_idx=idx) + if gtp_sharded is not None: + param = gtp_sharded + _gtp_sharded_weight_names.append(name) + # Wrap parameters in QuantizedTensor if needed fp8_meta_index = self.param_init_meta[name].fp8_meta_index high_precision_init_val = None - if self.primary_weights_in_fp8 and fp8_meta_index is not None: + if self.primary_weights_in_fp8 and fp8_meta_index is not None and gtp_sharded is None: # Keep high-precision values on CPU if needed if self.preserve_high_precision_init_val: @@ -1657,6 +1720,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. + # Skip the wrap for GTPShardedParam (Parameter.__new__ would drop attrs). if is_dtensor: # recreate the DTensor from the parameter. dtensor_param = DTensor.from_local( @@ -1667,7 +1731,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: stride=dtensor_param.stride(), ) dtensor_param = torch.nn.Parameter(dtensor_param) - else: + elif gtp_sharded is None: param = torch.nn.Parameter(param) # Keep high-precision values on CPU if needed @@ -1705,6 +1769,10 @@ def clear(self): else: self.module_setattr(name, dtensor_param) + # GroupedLinear post-loop finalize hook (no-op outside GroupedLinear). + if _gtp_sharded_weight_names and _gtp_finalize_fn is not None: + _gtp_finalize_fn(self, _gtp_sharded_weight_names) + @abstractmethod def forward(self): """Needs override.""" diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 627144345c..5f85f9c56e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -24,6 +24,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore +from .base import maybe_wrap_gtp from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( divide, @@ -100,6 +101,7 @@ def forward( skip_fp8_weight_update, save_original_input, debug, + gtp_size, ) = non_tensor_args if fp8: backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override @@ -114,6 +116,14 @@ def forward( device = inp.device weight_requires_grad = weights[0].requires_grad + weights_gtp_sharded = weights + if gtp_size > 1: + weights = weights[0].batched_all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=skip_fp8_weight_update, + ) + # Configure quantizers if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): if FP8GlobalStateManager.get_fp8_recipe().custom(): @@ -276,12 +286,19 @@ def forward( else: inputmats = [None] * num_gemms - tensors_to_save, tensor_objects = prepare_for_saving( - *inputmats, - *weights_fp8, - *weights, - *biases, - ) + if gtp_size == 1: + tensors_to_save, tensor_objects = prepare_for_saving( + *inputmats, + *weights_fp8, + *weights, + *biases, + ) + else: + tensors_to_save, tensor_objects = prepare_for_saving( + *inputmats, + *weights_gtp_sharded, + *biases, + ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects @@ -303,6 +320,10 @@ def forward( if hasattr(weights[0], "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] + elif gtp_size > 1: + ctx.main_grad_funcs = [ + weights_gtp_sharded[i].get_wgrad_tensor for i in range(num_gemms) + ] else: ctx.main_grad_funcs = [ lambda j=i: weights[j].main_grad for i in range(num_gemms) @@ -332,6 +353,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + ctx.gtp_size = gtp_size # backward overrides if backward_override is not None: @@ -357,17 +379,30 @@ def backward( with get_nvtx_range_context("_GroupedLinear_backward"): saved_tensors = restore_from_func_ctx(ctx) N = ctx.num_gemms - inputmats = saved_tensors[:N] - weights = saved_tensors[N : 2 * N] - saved_weights = saved_tensors[2 * N : 3 * N] - biases = saved_tensors[3 * N : 4 * N] + if ctx.gtp_size == 1: + inputmats = saved_tensors[:N] + weights = saved_tensors[N : 2 * N] + saved_weights = saved_tensors[2 * N : 3 * N] + biases = saved_tensors[3 * N : 4 * N] + gtp_origin_weights = None + else: + inputmats = saved_tensors[:N] + gtp_origin_weights = saved_tensors[N : 2 * N] + biases = saved_tensors[2 * N : 3 * N] + weights = None # Restore from weakrefs to get original weight python objects # (preserves attributes like main_grad, grad_added_to_main_grad, etc.) # Only needed when fuse_wgrad_accumulation is enabled. origin_weights = [None] * N main_grads = [None] * N - if ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad: + if ctx.gtp_size > 1: + # GTP: origin_weights come from saved tensors; main_grads are + # get_wgrad_tensor scratch (do not assign to param.main_grad). + origin_weights = gtp_origin_weights + if ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad: + main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + elif ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad: origin_weight_refs = ctx.origin_weight_refs ctx.origin_weight_refs = None origin_weights = [ref() if ref is not None else None for ref in origin_weight_refs] @@ -428,13 +463,18 @@ def backward( ctx.m_splits, ) - if ctx.is_first_microbatch is not None: + if ctx.gtp_size > 1: + accumulate_wgrad_into_param_main_grad = False + elif ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + if ctx.gtp_size > 1: + weights = origin_weights[0].batched_all_gather_and_prefetch_bwd() + if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD if ctx.fp8 or ctx.debug: @@ -485,6 +525,14 @@ def backward( use_split_accumulator=dgrad_gemm_use_split_accumulator, ) + # Gathered weights are no longer needed after dgrad GEMM. + # For nvfp4, the NVFP4TensorStorage and its sub-tensors (scale_inv etc.) + # would otherwise survive until function return via this local ref. + w_shape = None + if ctx.gtp_size > 1: + w_shape = list(weights[0].size()) + del weights + if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD if ctx.fp8: @@ -496,7 +544,7 @@ def backward( if ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: - weight_shape = list(weights[0].size()) + weight_shape = w_shape if ctx.gtp_size > 1 else list(weights[0].size()) wgrad_list = tex.bulk_allocate( [weight_shape] * ctx.num_gemms, [ctx.activation_dtype] * ctx.num_gemms, @@ -553,7 +601,8 @@ def backward( use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=( accumulate_wgrad_into_param_main_grad - if not getattr(ctx, "origin_weights_overwrite_main_grad", False) + if ctx.gtp_size == 1 + and not getattr(ctx, "origin_weights_overwrite_main_grad", False) else False ), ) @@ -595,10 +644,19 @@ def handle_custom_ddp_from_mcore(weight, main_grad, wgrad): wgrad = None return wgrad - wgrad_list = [ - handle_custom_ddp_from_mcore(weight, main_grad, wgrad) - for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list) - ] + if ctx.gtp_size > 1: + wgrad_list = origin_weights[0].batched_wgrad_reduce_scatter(wgrad_list) + # Drop Python refs to wgrad input buffers. The async RS on rs_stream + # still holds C++ refs (via NCCL Work); those are released when + # _wait_reduce_scatter calls handle.wait() + self.handle = None. + # Without this del, main_grads keeps the tensors alive until function + # return, wasting memory during graph capture warmup. + del main_grads + else: + wgrad_list = [ + handle_custom_ddp_from_mcore(weight, main_grad, wgrad) + for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list) + ] else: wgrad_list = [None] * ctx.num_gemms @@ -716,6 +774,7 @@ def __init__( single_grouped_weight: bool = False, single_grouped_bias: bool = False, name: Optional[str] = None, + gtp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -771,6 +830,11 @@ def __init__( "Because the TP communication is handled outside of this module." ) + if gtp_group is None: + self.gtp_size = 1 + else: + self.gtp_size = get_distributed_world_size(gtp_group) + self.parallel_mode = parallel_mode if self.parallel_mode not in GemmParallelModes: raise ValueError( @@ -822,9 +886,18 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata(num_gemms=self.num_gemms) + self.weight_names = [f"weight{idx}" for idx in range(self.num_gemms)] is_meta = torch.device(device).type == "meta" + if gtp_group is not None: + # Stashed before reset_parameters so the slice hook can see it; + # _gtp_is_grouped routes through the GroupedLinear finalize path. + self._gtp_group = gtp_group + self._gtp_is_grouped = True + self.reset_parameters(defer_init=is_meta) + maybe_wrap_gtp(self, self.weight_names, gtp_group, is_grouped=True) + if self.wgrad_store.delay_wgrad_compute(): for name, param in self.named_parameters(): if name in ("weight", "bias"): @@ -1148,6 +1221,11 @@ def forward( weight_tensors = self._get_weight_tensors() bias_tensors = self._get_bias_tensors() + if self.gtp_size > 1: + weight_tensors[0].setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() if debug: @@ -1202,6 +1280,7 @@ def forward( None, # skip_fp8_weight_update self.save_original_input, debug, + self.gtp_size, ) out, new_workspaces = linear_fn( *autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 8c88f3ee82..8e000e3764 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -28,6 +28,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) +from .base import maybe_wrap_gtp from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( assert_dim_for_fp8_exec, @@ -143,6 +144,7 @@ def forward( symmetric_ar_type, debug, is_fsdp2, + gtp_size, ) = non_tensor_args if fp8: backward_override = FP8GlobalStateManager.get_fp8_recipe().backward_override @@ -297,6 +299,16 @@ def forward( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + + weight_gtp_sharded = weight + if gtp_size > 1: + weight = weight.all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=skip_fp8_weight_update, + ) + out_features = weight.shape[0] + new_weight_workspace = None weightmat = weight is_weight_param_quantized = False @@ -484,8 +496,9 @@ def forward( wt_save = None tensors_to_save, tensor_objects = prepare_for_saving( inputmat, - wt_save, - weight, + # GTP: save the sharded reference only; backward re-gathers it. + wt_save if gtp_size == 1 else None, + weight if gtp_size == 1 else weight_gtp_sharded, bias, ln_weight, ln_out_to_save, @@ -512,6 +525,8 @@ def forward( if hasattr(weight, "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward ctx.main_grad_func = weight.get_main_grad + elif gtp_size > 1: + ctx.main_grad_func = weight_gtp_sharded.get_wgrad_tensor else: ctx.main_grad_func = lambda: weight.main_grad ctx.grad_input_quantizer = grad_input_quantizer @@ -554,6 +569,7 @@ def forward( qstate.is_first_fp8_module = _first_fp8_module ctx.wgrad_store = wgrad_store ctx.debug = debug + ctx.gtp_size = gtp_size # backward overrides if backward_override is not None: @@ -605,6 +621,9 @@ def backward( rsigma, ) = restore_from_func_ctx(ctx) + if ctx.gtp_size > 1: + weight = saved_weight.all_gather_and_prefetch_bwd() + # Restore from weakref to get original weight python object # (preserves attributes like main_grad, grad_added_to_main_grad, etc.) # Only needed when fuse_wgrad_accumulation is enabled. @@ -622,7 +641,7 @@ def backward( ), "weight was removed while fuse_wgrad_accumulation=True" # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ctx.main_grad_func() if weight is not None else None - if main_grad is not None: + if main_grad is not None and ctx.gtp_size == 1: origin_weight.main_grad = main_grad # Gather intermediate/activation tensors if needed @@ -929,7 +948,10 @@ def backward( use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if ctx.is_first_microbatch is not None: + if ctx.gtp_size > 1: + # GTP: accumulation happens downstream in wgrad_reduce_scatter. + accumulate_wgrad_into_param_main_grad = False + elif ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) @@ -1001,6 +1023,9 @@ def wgrad_gemm( # Call wgrad GEMM now wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output) + if ctx.gtp_size > 1: + wgrad = saved_weight.wgrad_reduce_scatter(wgrad) + # Update grad bias if needed if grad_bias is None: grad_bias = grad_bias_ @@ -1080,7 +1105,10 @@ def wgrad_gemm( if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): + if ctx.gtp_size > 1: + # GTP: skip — wgrad RS already produced the correct shard. + pass + elif ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( @@ -1247,6 +1275,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: Optional[str] = None, + gtp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -1277,6 +1306,11 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if gtp_group is None: + self.gtp_size = 1 + else: + self.gtp_size = get_distributed_world_size(gtp_group) + self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -1471,8 +1505,18 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() + if gtp_group is not None: + # Stashed before reset_parameters so the slice hook can see it. + self._gtp_group = gtp_group + self._gtp_is_grouped = False + self.reset_parameters(defer_init=device == "meta") + maybe_wrap_gtp(self, self.weight_names, gtp_group) + if gtp_group is not None: + # Free the full-size backing buffer; GTP replaced it with a sharded param. + del weight_tensor + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: @@ -1635,6 +1679,11 @@ def forward( # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + if self.gtp_size > 1: + weight_tensor.setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = ( self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) if not debug @@ -1705,6 +1754,7 @@ def forward( self.symmetric_ar_type, debug, self.is_fsdp2, + self.gtp_size, ) out, ln_out, new_weight_workspace = fwd_fn( *autograd_ctx, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index dcbb9eaf93..7b4e18f066 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -28,6 +28,7 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore +from .base import maybe_wrap_gtp from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, @@ -153,6 +154,9 @@ class LinearFwdArgs: cpu_offloading: bool is_grad_enabled: bool + # --- Generalized tensor parallelism --- + gtp_size: int = 1 + @dataclass(slots=True) class LinearBwdArgs: @@ -222,6 +226,9 @@ class LinearBwdArgs: cpu_offloading: bool = False owns_input: bool = False + # --- Generalized tensor parallelism --- + gtp_size: int = 1 + # --- Per-backward scratch state (populated inside _linear_backward) --- ub_obj_gradout: Optional[Any] = None @@ -399,6 +406,15 @@ def _linear_forward_impl( # ------------------------------------------------------ # Prepare weight tensor # ------------------------------------------------------ + # GTP: rebind `weight` to the all-gathered tensor; `args.weight` keeps + # the GTPShardedParam reference for backward re-gather / wgrad RS. + if args.gtp_size > 1: + weight = weight.all_gather_and_prefetch( + fwd=True, + skip_weight_cast=is_first_microbatch is False, + cast_noop_flag=args.skip_fp8_weight_update, + ) + new_weight_workspace = None weightmat = weight if fp8 or debug: @@ -576,6 +592,9 @@ def _linear_forward_impl( wt_save = weightmat if is_fsdp2 and weightmat is not weight: wt_save = None + # GTP: don't save the workspace; backward re-gathers it. + if args.gtp_size > 1: + wt_save = None # Dedup save slots that alias forward inputs; ``_linear_setup_ctx`` # rebuilds the refs from ``inp`` / ``weight`` / ``bias``. @@ -680,11 +699,14 @@ def _linear_setup_ctx( bwd_args.origin_weight_overwrites_main_grad = getattr(weight, "overwrite_main_grad", False) if hasattr(weight, "__fsdp_param__"): bwd_args.main_grad_func = weight.get_main_grad + elif fwd_args.gtp_size > 1: + bwd_args.main_grad_func = weight.get_wgrad_tensor else: bwd_args.main_grad_func = lambda: weight.main_grad # Misc bwd_args.cpu_offloading = fwd_args.cpu_offloading + bwd_args.gtp_size = fwd_args.gtp_size if backward_override is not None: bwd_args.fp8 = False @@ -751,7 +773,8 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. origin_weight_python_object is not None ), "weight was removed while fuse_wgrad_accumulation=True" main_grad = bwd_args.main_grad_func() - origin_weight_python_object.main_grad = main_grad + if bwd_args.gtp_size == 1: + origin_weight_python_object.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when bwd_args.fp8 == False and torch.disttributed.FSDP already @@ -921,6 +944,12 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. dgrad = None dgrad_work = None + + # GTP: re-gather the sharded weight; runs even when requires_dgrad=False + # so the prev_w prefetch is issued for the next layer's bwd. + if bwd_args.gtp_size > 1: + weight_fp8 = saved_weight.all_gather_and_prefetch_bwd() + if bwd_args.requires_dgrad: # FSDP2: Re-create workspace from all-gathered weight when @@ -1102,7 +1131,10 @@ def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], .. use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if bwd_args.is_first_microbatch is not None: + if bwd_args.gtp_size > 1: + # GTP: accumulation happens downstream in wgrad_reduce_scatter. + accumulate_wgrad_into_param_main_grad = False + elif bwd_args.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( bwd_args.fuse_wgrad_accumulation and not bwd_args.is_first_microbatch ) @@ -1178,6 +1210,11 @@ def wgrad_gemm( # Call wgrad GEMM now wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output) + # GTP: reduce-scatter the freshly computed wgrad (async; overlap + # with the next layer's bwd via the cascade). + if bwd_args.gtp_size > 1: + wgrad = saved_weight.wgrad_reduce_scatter(wgrad) + # Update grad bias if needed if grad_bias is None: grad_bias = grad_bias_ @@ -1223,15 +1260,19 @@ def wgrad_gemm( origin_weight_python_object, "grad_added_to_main_grad" ): origin_weight_python_object.grad_added_to_main_grad = True + # Use the param's local shape (sharded under GTP) so the dummy wgrad + # matches the saved weight shape; main_grad_func() under GTP returns + # an unsharded scratch and would otherwise mismatch. + wgrad_shape = list(origin_weight_python_object.shape) if getattr(origin_weight_python_object, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( - list(main_grad.shape), + wgrad_shape, origin_weight_python_object.dtype, zero=True, ) else: wgrad = get_dummy_wgrad( - list(main_grad.shape), + wgrad_shape, origin_weight_python_object.dtype, ) elif bwd_args.fuse_wgrad_accumulation: @@ -1447,6 +1488,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, + gtp_group: Optional[dist_group_type] = None, ) -> None: super().__init__(name) @@ -1475,6 +1517,11 @@ def __init__( self.set_tensor_parallel_group(tp_group) self.set_nccl_overlap_warning_if_tp() + if gtp_group is None: + self.gtp_size = 1 + else: + self.gtp_size = get_distributed_world_size(gtp_group) + self.parallel_mode = parallel_mode assert ( self.parallel_mode in GemmParallelModes @@ -1644,8 +1691,18 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() + if gtp_group is not None: + # Stashed before reset_parameters so the slice hook can see it. + self._gtp_group = gtp_group + self._gtp_is_grouped = False + self.reset_parameters(defer_init=device == "meta") + maybe_wrap_gtp(self, self.weight_names, gtp_group) + if gtp_group is not None: + # Free the full-size backing buffer; GTP replaced it with a sharded param. + del weight_tensor + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.parallel_mode == "row" and self.apply_bias: @@ -1776,6 +1833,11 @@ def forward( try: weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + if self.gtp_size > 1: + weight_tensor.setup( + weight_quantizer=self._get_weight_quantizers(), + ) + quantizers = ( self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) if not debug @@ -1894,6 +1956,8 @@ def forward( # misc cpu_offloading=is_cpu_offload_enabled(), is_grad_enabled=is_grad_enabled, + # generalized tensor parallelism + gtp_size=self.gtp_size, ) out, new_weight_workspace = linear_fn( *autograd_ctx,