diff --git a/src/liger_kernel/ops/__init__.py b/src/liger_kernel/ops/__init__.py
index 6a34b18b4..58b987962 100644
--- a/src/liger_kernel/ops/__init__.py
+++ b/src/liger_kernel/ops/__init__.py
@@ -77,6 +77,7 @@
 from liger_kernel.ops.swiglu import swiglu_backward  # noqa: F401
 from liger_kernel.ops.swiglu import swiglu_forward  # noqa: F401
 from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction  # noqa: F401
+from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunctionDDP  # noqa: F401
 from liger_kernel.ops.tiled_mlp import apply_tiled_mlp  # noqa: F401
 from liger_kernel.ops.tvd import LigerTVDLossFunction  # noqa: F401
diff --git a/src/liger_kernel/ops/tiled_mlp.py b/src/liger_kernel/ops/tiled_mlp.py
index 2c1943c3a..40dca7434 100644
--- a/src/liger_kernel/ops/tiled_mlp.py
+++ b/src/liger_kernel/ops/tiled_mlp.py
@@ -98,12 +98,103 @@ def backward(ctx, *grads) -> tuple:
         return (None, None, x_grad, None, None)
 
 
+class LigerTiledMLPFunctionDDP(torch.autograd.Function):
+    """
+    DDP-compatible variant of LigerTiledMLPFunction.
+
+    Accumulates parameter gradients across shards and only assigns .grad after
+    the last shard, so DDP's gradient reduction runs once per backward.
+    Use via apply_tiled_mlp(..., ddp_safe=True).
+
+    See: https://github.com/linkedin/Liger-Kernel/issues/893
+    """
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(
+        ctx,
+        fn: Callable,
+        mlp_module: torch.nn.Module,
+        x: torch.Tensor,
+        shards: int,
+        compute_params: Optional[List[torch.nn.Parameter]] = None,
+    ) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.mlp_module = mlp_module
+        ctx.shards = shards
+        ctx.compute_params = [p for p in (compute_params or []) if p.requires_grad]
+        ctx.save_for_backward(x)
+
+        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
+        with torch.no_grad():
+            output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
+        output_unsharded = torch.cat(output_shards, dim=-2)
+
+        return output_unsharded
+
+    @staticmethod
+    @ensure_contiguous
+    def backward(ctx, *grads) -> tuple:
+        fn = ctx.fn
+        (x,) = ctx.saved_tensors
+        mlp_module = ctx.mlp_module
+        shards = ctx.shards
+        compute_params = ctx.compute_params
+
+        x_requires_grad = x.requires_grad
+        x = x.detach()
+        x.requires_grad_(x_requires_grad)
+
+        hidden_size = x.shape[-1]
+        x_shape_orig = x.shape
+        x = x.view(-1, hidden_size)
+        incoming_grad = grads[0].view(-1, hidden_size)
+        x_grad = torch.zeros_like(x)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=0))
+
+        # Accumulate param grads across shards; assign only after last shard (DDP-safe)
+        accumulated = {
+            p: torch.zeros_like(p, dtype=p.dtype, device=p.device)
+            for p in compute_params
+        }
+
+        for i, x_shard in enumerate(x_shards):
+            x_shard.requires_grad_(x_requires_grad)
+            shard_step = x_shards[i].shape[0]
+            shard_offset = i * x_shards[0].shape[0]
+            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
+            incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(
+                x_shard
+            )
+
+            # Clear param.grad so this shard's backward fills it; we'll accumulate below
+            for p in compute_params:
+                if p.grad is not None:
+                    p.grad.zero_()
+
+            with torch.enable_grad():
+                output = fn(mlp_module, x_shard)
+                torch.autograd.backward(output, incoming_grad_shard)
+
+            for p in compute_params:
+                if p.grad is not None:
+                    accumulated[p].add_(p.grad)
+
+        # Assign accumulated gradients only once (after last shard)
+        for p in compute_params:
+            p.grad = accumulated[p]
+
+        x_grad = x_grad.view(x_shape_orig)
+        return (None, None, x_grad, None, None)
+
+
 def apply_tiled_mlp(
     fn: Callable,
     mlp_module: torch.nn.Module,
     x: torch.Tensor,
     num_shards: Optional[int] = None,
     compute_params: Optional[List[torch.nn.Parameter]] = None,
+    ddp_safe: bool = False,
 ) -> torch.Tensor:
     """
     Apply tiled MLP computation for memory efficiency.
@@ -114,6 +205,8 @@ def apply_tiled_mlp(
         x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
         num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
         compute_params: list of parameters for DeepSpeed ZeRO optimization
+        ddp_safe: if True, accumulate parameter gradients across shards and assign only after
+            the last shard, making the backward pass compatible with PyTorch DDP.
 
     Returns:
         output tensor with the same shape as input
@@ -127,6 +220,14 @@ def apply_tiled_mlp(
     # Ensure num_shards is at least 1
     num_shards = max(1, num_shards)
 
+    if ddp_safe:
+        return LigerTiledMLPFunctionDDP.apply(
+            fn,
+            mlp_module,
+            x,
+            num_shards,
+            compute_params,
+        )
     return LigerTiledMLPFunction.apply(
         fn,
         mlp_module,
diff --git a/src/liger_kernel/transformers/tiled_mlp.py b/src/liger_kernel/transformers/tiled_mlp.py
index b72507b2e..77830b824 100644
--- a/src/liger_kernel/transformers/tiled_mlp.py
+++ b/src/liger_kernel/transformers/tiled_mlp.py
@@ -19,14 +19,17 @@ class LigerTiledGEGLUMLP(nn.Module):
         config: Model configuration with hidden_size and intermediate_size attributes
         num_shards: Number of shards to split the sequence. If None, automatically
             calculated as ceil(seqlen / hidden_size)
+        ddp_safe: If True, use DDP-compatible backward (gradients assigned only after
+            last shard). Set True when using with torch.nn.parallel.DistributedDataParallel.
     """
 
-    def __init__(self, config, num_shards: Optional[int] = None):
+    def __init__(self, config, num_shards: Optional[int] = None, ddp_safe: bool = False):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.num_shards = num_shards
+        self.ddp_safe = ddp_safe
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -65,6 +68,7 @@ def forward(self, x):
             x=x,
             num_shards=self.num_shards,
             compute_params=compute_params,
+            ddp_safe=self.ddp_safe,
         )
@@ -80,14 +84,17 @@ class LigerTiledSwiGLUMLP(nn.Module):
         config: Model configuration with hidden_size and intermediate_size attributes
         num_shards: Number of shards to split the sequence. If None, automatically
             calculated as ceil(seqlen / hidden_size)
+        ddp_safe: If True, use DDP-compatible backward (gradients assigned only after
+            last shard). Set True when using with torch.nn.parallel.DistributedDataParallel.
     """
 
-    def __init__(self, config, num_shards: Optional[int] = None):
+    def __init__(self, config, num_shards: Optional[int] = None, ddp_safe: bool = False):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.num_shards = num_shards
+        self.ddp_safe = ddp_safe
         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -122,4 +129,5 @@ def forward(self, x):
             x=x,
             num_shards=self.num_shards,
             compute_params=compute_params,
+            ddp_safe=self.ddp_safe,
         )
diff --git a/test/transformers/test_tiled_mlp.py b/test/transformers/test_tiled_mlp.py
index bb9ecda09..9a309fb39 100644
--- a/test/transformers/test_tiled_mlp.py
+++ b/test/transformers/test_tiled_mlp.py
@@ -195,3 +195,53 @@ def test_tiled_swiglu_correctness(
     )
 
     torch.testing.assert_close(x1.grad, x2.grad, atol=atol, rtol=rtol, msg="Input gradients don't match")
+
+
+@pytest.mark.parametrize("num_shards", [2, 4])
+def test_tiled_swiglu_ddp_safe_gradient_parity(num_shards):
+    """Test that ddp_safe=True produces the same gradients as ddp_safe=False."""
+    bsz, seq_len, hidden_size, intermediate_size = 2, 256, 128, 256
+    dtype = torch.float32
+    atol, rtol = 1e-5, 1e-4
+
+    config = LlamaConfig(
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        hidden_act="silu",
+    )
+
+    _input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) * 0.1
+    x1 = _input.detach().clone().requires_grad_(True)
+    x2 = _input.detach().clone().requires_grad_(True)
+
+    G = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
+    U = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
+    D = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)
+
+    tiled_default = LigerTiledSwiGLUMLP(config=config, num_shards=num_shards, ddp_safe=False).to(device).to(dtype)
+    tiled_default.gate_proj.weight.data = G.clone()
+    tiled_default.up_proj.weight.data = U.clone()
+    tiled_default.down_proj.weight.data = D.clone()
+
+    tiled_ddp = LigerTiledSwiGLUMLP(config=config, num_shards=num_shards, ddp_safe=True).to(device).to(dtype)
+    tiled_ddp.gate_proj.weight.data = G.clone()
+    tiled_ddp.up_proj.weight.data = U.clone()
+    tiled_ddp.down_proj.weight.data = D.clone()
+
+    y1 = tiled_default(x1)
+    y2 = tiled_ddp(x2)
+    torch.testing.assert_close(y1, y2, atol=atol, rtol=rtol, msg="Forward outputs don't match")
+
+    dy = torch.randn_like(y1)
+    y1.backward(dy.clone())
+    y2.backward(dy.clone())
+
+    for p1, p2 in zip(tiled_default.parameters(), tiled_ddp.parameters()):
+        torch.testing.assert_close(
+            p1.grad,
+            p2.grad,
+            atol=atol,
+            rtol=rtol,
+            msg="DDP-safe and default gradients do not match",
+        )
+    torch.testing.assert_close(x1.grad, x2.grad, atol=atol, rtol=rtol, msg="Input gradients don't match")