-
Notifications
You must be signed in to change notification settings - Fork 521
Expand file tree
/
Copy pathtiled_mlp.py
More file actions
133 lines (107 loc) · 4.94 KB
/
tiled_mlp.py
File metadata and controls
133 lines (107 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from typing import Optional
import torch.nn as nn
from liger_kernel.ops import LigerGELUMulFunction
from liger_kernel.ops import LigerSiLUMulFunction
from liger_kernel.ops import apply_tiled_mlp
class LigerTiledGEGLUMLP(nn.Module):
    """
    Memory-efficient GEGLU MLP using tiled computation.

    This module combines GEGLU activation with tiled processing to handle
    very long sequences efficiently. The forward pass is recomputed during
    backward to save memory.

    Args:
        config: Model configuration with ``hidden_size`` and ``intermediate_size``
            attributes. If it also has a ``hidden_act`` attribute, that value
            must be one of ``"gelu"``, ``"gelu_new"`` or ``"gelu_pytorch_tanh"``.
        num_shards: Number of shards to split the sequence. If None, automatically
            calculated as ceil(seqlen / hidden_size)
        ddp_safe: If True, use DDP-compatible backward (gradients assigned only after
            last shard). Set True when using with torch.nn.parallel.DistributedDataParallel.

    Raises:
        ValueError: If ``config.hidden_act`` is present and is not a supported
            GELU variant.
    """

    # Values of ``config.hidden_act`` this module accepts.
    _SUPPORTED_ACTIVATIONS = ("gelu", "gelu_new", "gelu_pytorch_tanh")

    def __init__(self, config, num_shards: Optional[int] = None, ddp_safe: bool = False) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_shards = num_shards
        self.ddp_safe = ddp_safe

        # Validate the activation *before* building the projections so that an
        # unsupported config fails immediately instead of first allocating and
        # initialising three (potentially very large) weight matrices.
        # A config without ``hidden_act`` is accepted unchanged.
        if hasattr(config, "hidden_act") and config.hidden_act not in self._SUPPORTED_ACTIVATIONS:
            raise ValueError(f"LigerTiledGEGLUMLP requires GELU activation, got {config.hidden_act}")

        # Bias-free projections; registration order (gate, up, down) is kept so
        # parameter ordering stays the same.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def _mlp_forward(self, module, x):
        """
        Internal MLP forward function for tiled computation.

        Args:
            module: Object providing ``gate_proj``, ``up_proj`` and ``down_proj``
                (``forward`` passes this module itself as ``mlp_module``).
            x: Input handed in by ``apply_tiled_mlp``
                (presumably one shard of the sequence - TODO confirm).

        Returns:
            ``down_proj`` applied to ``LigerGELUMulFunction.apply(gate, up)``.
        """
        gate = module.gate_proj(x)
        up = module.up_proj(x)
        return module.down_proj(LigerGELUMulFunction.apply(gate, up))

    def forward(self, x):
        """
        Forward pass with tiled computation.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size]
                or [seq_len, hidden_size]

        Returns:
            Output tensor of the same shape as input
        """
        # Only parameters that currently require gradients are handed to the
        # tiled implementation; the list is rebuilt on every call so later
        # changes to ``requires_grad`` (e.g. freezing) are picked up.
        compute_params = [p for p in self.parameters() if p.requires_grad]
        return apply_tiled_mlp(
            fn=self._mlp_forward,
            mlp_module=self,
            x=x,
            num_shards=self.num_shards,
            compute_params=compute_params,
            ddp_safe=self.ddp_safe,
        )
class LigerTiledSwiGLUMLP(nn.Module):
    """
    Memory-efficient SwiGLU MLP using tiled computation.

    This module combines SwiGLU activation with tiled processing to handle
    very long sequences efficiently. The forward pass is recomputed during
    backward to save memory.

    Args:
        config: Model configuration with ``hidden_size`` and ``intermediate_size``
            attributes. If it also has a ``hidden_act`` attribute, that value
            must be ``"silu"`` or ``"swish"``.
        num_shards: Number of shards to split the sequence. If None, automatically
            calculated as ceil(seqlen / hidden_size)
        ddp_safe: If True, use DDP-compatible backward (gradients assigned only after
            last shard). Set True when using with torch.nn.parallel.DistributedDataParallel.

    Raises:
        ValueError: If ``config.hidden_act`` is present and is neither
            ``"silu"`` nor ``"swish"``.
    """

    # Values of ``config.hidden_act`` this module accepts.
    _SUPPORTED_ACTIVATIONS = ("silu", "swish")

    def __init__(self, config, num_shards: Optional[int] = None, ddp_safe: bool = False) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_shards = num_shards
        self.ddp_safe = ddp_safe

        # Validate the activation *before* building the projections so that an
        # unsupported config fails immediately instead of first allocating and
        # initialising three (potentially very large) weight matrices.
        # A config without ``hidden_act`` is accepted unchanged.
        if hasattr(config, "hidden_act") and config.hidden_act not in self._SUPPORTED_ACTIVATIONS:
            raise ValueError(f"LigerTiledSwiGLUMLP requires SiLU/Swish activation, got {config.hidden_act}")

        # Bias-free projections; registration order (gate, up, down) is kept so
        # parameter ordering stays the same.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def _mlp_forward(self, module, x):
        """
        Internal MLP forward function for tiled computation.

        Args:
            module: Object providing ``gate_proj``, ``up_proj`` and ``down_proj``
                (``forward`` passes this module itself as ``mlp_module``).
            x: Input handed in by ``apply_tiled_mlp``
                (presumably one shard of the sequence - TODO confirm).

        Returns:
            ``down_proj`` applied to ``LigerSiLUMulFunction.apply(gate, up)``.
        """
        gate = module.gate_proj(x)
        up = module.up_proj(x)
        return module.down_proj(LigerSiLUMulFunction.apply(gate, up))

    def forward(self, x):
        """
        Forward pass with tiled computation.

        Args:
            x: Input tensor of shape [batch_size, seq_len, hidden_size]
                or [seq_len, hidden_size]

        Returns:
            Output tensor of the same shape as input
        """
        # Only parameters that currently require gradients are handed to the
        # tiled implementation; the list is rebuilt on every call so later
        # changes to ``requires_grad`` (e.g. freezing) are picked up.
        compute_params = [p for p in self.parameters() if p.requires_grad]
        return apply_tiled_mlp(
            fn=self._mlp_forward,
            mlp_module=self,
            x=x,
            num_shards=self.num_shards,
            compute_params=compute_params,
            ddp_safe=self.ddp_safe,
        )