1 change: 1 addition & 0 deletions src/liger_kernel/ops/__init__.py
@@ -77,6 +77,7 @@
from liger_kernel.ops.swiglu import swiglu_backward # noqa: F401
from liger_kernel.ops.swiglu import swiglu_forward # noqa: F401
from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunction # noqa: F401
from liger_kernel.ops.tiled_mlp import LigerTiledMLPFunctionDDP # noqa: F401
from liger_kernel.ops.tiled_mlp import apply_tiled_mlp # noqa: F401
from liger_kernel.ops.tvd import LigerTVDLossFunction # noqa: F401

101 changes: 101 additions & 0 deletions src/liger_kernel/ops/tiled_mlp.py
@@ -98,12 +98,103 @@ def backward(ctx, *grads) -> tuple:
return (None, None, x_grad, None, None)


class LigerTiledMLPFunctionDDP(torch.autograd.Function):
"""
DDP-compatible variant of LigerTiledMLPFunction.

Accumulates parameter gradients across shards and only assigns .grad after
the last shard, so DDP's gradient reduction runs once per backward.
Use via apply_tiled_mlp(..., ddp_safe=True).

See: https://github.com/linkedin/Liger-Kernel/issues/893
"""

@staticmethod
@ensure_contiguous
def forward(
ctx,
fn: Callable,
mlp_module: torch.nn.Module,
x: torch.Tensor,
shards: int,
compute_params: Optional[List[torch.nn.Parameter]] = None,
) -> torch.Tensor:
ctx.fn = fn
ctx.mlp_module = mlp_module
ctx.shards = shards
ctx.compute_params = [p for p in (compute_params or []) if p.requires_grad]
ctx.save_for_backward(x)

x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
with torch.no_grad():
output_shards = [fn(mlp_module, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=-2)

return output_unsharded

@staticmethod
@ensure_contiguous
def backward(ctx, *grads) -> tuple:
fn = ctx.fn
(x,) = ctx.saved_tensors
mlp_module = ctx.mlp_module
shards = ctx.shards
compute_params = ctx.compute_params

x_requires_grad = x.requires_grad
x = x.detach()
x.requires_grad_(x_requires_grad)

hidden_size = x.shape[-1]
x_shape_orig = x.shape
x = x.view(-1, hidden_size)
incoming_grad = grads[0].view(-1, hidden_size)
x_grad = torch.zeros_like(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=0))

# Accumulate param grads across shards; assign only after last shard (DDP-safe)
accumulated = {
p: torch.zeros_like(p, dtype=p.dtype, device=p.device)
for p in compute_params
}

for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)
shard_step = x_shards[i].shape[0]
shard_offset = i * x_shards[0].shape[0]
x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(
x_shard
)

# Clear param.grad so this shard's backward fills it; we'll accumulate below
for p in compute_params:
if p.grad is not None:
p.grad.zero_()

with torch.enable_grad():
output = fn(mlp_module, x_shard)
torch.autograd.backward(output, incoming_grad_shard)

for p in compute_params:
if p.grad is not None:
accumulated[p].add_(p.grad)

# Assign accumulated gradients only once (after last shard)
for p in compute_params:
p.grad = accumulated[p]

x_grad = x_grad.view(x_shape_orig)
return (None, None, x_grad, None, None)


def apply_tiled_mlp(
fn: Callable,
mlp_module: torch.nn.Module,
x: torch.Tensor,
num_shards: Optional[int] = None,
compute_params: Optional[List[torch.nn.Parameter]] = None,
ddp_safe: bool = False,
) -> torch.Tensor:
"""
Apply tiled MLP computation for memory efficiency.
@@ -114,6 +205,8 @@ def apply_tiled_mlp(
x: the input tensor with shape [bs, seqlen, hidden_size] or [seqlen, hidden_size]
num_shards: number of shards to use. If None, automatically calculated as ceil(seqlen / hidden_size)
compute_params: list of parameters for DeepSpeed ZeRO optimization
ddp_safe: if True, accumulate parameter gradients across shards and assign only after
the last shard, making the backward pass compatible with PyTorch DDP.

Returns:
output tensor with the same shape as input
@@ -127,6 +220,14 @@
# Ensure num_shards is at least 1
num_shards = max(1, num_shards)

if ddp_safe:
return LigerTiledMLPFunctionDDP.apply(
fn,
mlp_module,
x,
num_shards,
compute_params,
)
return LigerTiledMLPFunction.apply(
fn,
mlp_module,
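A minimal usage sketch of the new ddp_safe path through apply_tiled_mlp. The SimpleMLP module and mlp_forward helper below are illustrative stand-ins (not part of this PR); any MLP whose forward maps [..., hidden_size] -> [..., hidden_size] should work the same way.

import torch
import torch.nn as nn
import torch.nn.functional as F

from liger_kernel.ops.tiled_mlp import apply_tiled_mlp


class SimpleMLP(nn.Module):
    """Illustrative MLP; stands in for the model's real MLP block."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down(F.gelu(self.up(x)))


def mlp_forward(module, x_shard):
    # apply_tiled_mlp calls fn(mlp_module, x_shard) once per sequence shard.
    return module(x_shard)


mlp = SimpleMLP(hidden_size=128, intermediate_size=256)
x = torch.randn(2, 1024, 128, requires_grad=True)  # [bs, seqlen, hidden_size]

out = apply_tiled_mlp(
    fn=mlp_forward,
    mlp_module=mlp,
    x=x,
    num_shards=4,  # None would derive ceil(seqlen / hidden_size)
    compute_params=list(mlp.parameters()),
    ddp_safe=True,  # accumulate per-shard grads; assign .grad once after the last shard
)
out.mean().backward()  # parameter and input grads match the ddp_safe=False path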
12 changes: 10 additions & 2 deletions src/liger_kernel/transformers/tiled_mlp.py
@@ -19,14 +19,17 @@ class LigerTiledGEGLUMLP(nn.Module):
config: Model configuration with hidden_size and intermediate_size attributes
num_shards: Number of shards to split the sequence. If None, automatically
calculated as ceil(seqlen / hidden_size)
ddp_safe: If True, use the DDP-compatible backward (gradients are assigned only after the
last shard). Set this to True when using the module with torch.nn.parallel.DistributedDataParallel.
"""

def __init__(self, config, num_shards: Optional[int] = None):
def __init__(self, config, num_shards: Optional[int] = None, ddp_safe: bool = False):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.num_shards = num_shards
self.ddp_safe = ddp_safe

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -65,6 +68,7 @@ def forward(self, x):
x=x,
num_shards=self.num_shards,
compute_params=compute_params,
ddp_safe=self.ddp_safe,
)


@@ -80,14 +84,17 @@ class LigerTiledSwiGLUMLP(nn.Module):
config: Model configuration with hidden_size and intermediate_size attributes
num_shards: Number of shards to split the sequence. If None, automatically
calculated as ceil(seqlen / hidden_size)
ddp_safe: If True, use the DDP-compatible backward (gradients are assigned only after the
last shard). Set this to True when using the module with torch.nn.parallel.DistributedDataParallel.
"""

def __init__(self, config, num_shards: Optional[int] = None):
def __init__(self, config, num_shards: Optional[int] = None, ddp_safe: bool = False):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.num_shards = num_shards
self.ddp_safe = ddp_safe

self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
@@ -122,4 +129,5 @@ def forward(self, x):
x=x,
num_shards=self.num_shards,
compute_params=compute_params,
ddp_safe=self.ddp_safe,
)
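For context, a minimal sketch of the intended end-to-end usage under DistributedDataParallel. The process-group bootstrap and the config sizes below are illustrative assumptions (a torchrun-style launch that sets LOCAL_RANK), not part of this diff.

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import LlamaConfig

from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP

# Assumes torchrun set RANK / WORLD_SIZE / LOCAL_RANK and a CUDA backend is available.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

config = LlamaConfig(hidden_size=2048, intermediate_size=8192, hidden_act="silu")
mlp = LigerTiledSwiGLUMLP(config=config, ddp_safe=True).to(local_rank)

# With ddp_safe=True the backward accumulates per-shard parameter gradients and
# assigns .grad once after the last shard, so DDP reduces the full gradient.
model = DDP(mlp, device_ids=[local_rank])

x = torch.randn(4, 4096, config.hidden_size, device=local_rank, requires_grad=True)
model(x).mean().backward()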
50 changes: 50 additions & 0 deletions test/transformers/test_tiled_mlp.py
@@ -195,3 +195,53 @@ def test_tiled_swiglu_correctness(
)

torch.testing.assert_close(x1.grad, x2.grad, atol=atol, rtol=rtol, msg="Input gradients don't match")


@pytest.mark.parametrize("num_shards", [2, 4])
def test_tiled_swiglu_ddp_safe_gradient_parity(num_shards):
"""Test that ddp_safe=True produces the same gradients as ddp_safe=False."""
bsz, seq_len, hidden_size, intermediate_size = 2, 256, 128, 256
dtype = torch.float32
atol, rtol = 1e-5, 1e-4

config = LlamaConfig(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
hidden_act="silu",
)

_input = torch.randn(bsz, seq_len, hidden_size, device=device, dtype=dtype) * 0.1
x1 = _input.detach().clone().requires_grad_(True)
x2 = _input.detach().clone().requires_grad_(True)

G = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
U = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
D = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)

tiled_default = LigerTiledSwiGLUMLP(config=config, num_shards=num_shards, ddp_safe=False).to(device).to(dtype)
tiled_default.gate_proj.weight.data = G.clone()
tiled_default.up_proj.weight.data = U.clone()
tiled_default.down_proj.weight.data = D.clone()

tiled_ddp = LigerTiledSwiGLUMLP(config=config, num_shards=num_shards, ddp_safe=True).to(device).to(dtype)
tiled_ddp.gate_proj.weight.data = G.clone()
tiled_ddp.up_proj.weight.data = U.clone()
tiled_ddp.down_proj.weight.data = D.clone()

y1 = tiled_default(x1)
y2 = tiled_ddp(x2)
torch.testing.assert_close(y1, y2, atol=atol, rtol=rtol, msg="Forward outputs don't match")

dy = torch.randn_like(y1)
y1.backward(dy.clone())
y2.backward(dy.clone())

for p1, p2 in zip(tiled_default.parameters(), tiled_ddp.parameters()):
torch.testing.assert_close(
p1.grad,
p2.grad,
atol=atol,
rtol=rtol,
msg="DDP-safe and default gradients do not match",
)
torch.testing.assert_close(x1.grad, x2.grad, atol=atol, rtol=rtol, msg="Input gradients don't match")