 
 """Support quantization for Transformer Engine layers."""
 
+import warnings
+
 import torch
 import transformer_engine as te
 import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
+import transformer_engine.pytorch.module.layernorm_linear as te_layernorm_linear
 import transformer_engine.pytorch.module.linear as te_linear
+from packaging.version import Version
+
+from modelopt.torch.quantization.utils import replace_function
 
 from ..nn import QuantModuleRegistry
 from .custom import _ParallelLinear
 
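+# ModelOpt picks up these registrations automatically when quantizing a model that contains TE
+# layers. A rough usage sketch (the config choice `mtq.FP8_DEFAULT_CFG` and the calibration
+# callable `calibrate_fn` are illustrative, not prescribed by this module):
+#
+#   import modelopt.torch.quantization as mtq
+#
+#   model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_fn)
+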
+_TE_VERSION = Version(te.__version__)
+
+
+def _assert_te_fp8_enabled():
+    """Raise an error if Transformer Engine FP8 autocast is enabled."""
+    try:
+        from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+
+        if FP8GlobalStateManager.is_fp8_enabled():
+            raise RuntimeError(
+                "Transformer Engine FP8 training (fp8_autocast) is enabled, which conflicts with "
+                "ModelOpt quantization. Please disable TE FP8 autocast when using ModelOpt "
+                "quantization, or use ModelOpt's FP8 quantization instead."
+            )
+    except ImportError:
+        pass  # Older TE versions may not have this API
+
 
 @QuantModuleRegistry.register({te.pytorch.Linear: "te_Linear"})
 class _QuantTELinear(_ParallelLinear):
-    _functionals_to_replace = [
-        (
-            te_linear._Linear,
-            "apply" if torch.is_grad_enabled() else "forward",
-        ),
-    ]
+    @property
+    def _functionals_to_replace(self):
+        return (
+            [(te_linear._Linear, "apply")]
+            if torch.is_grad_enabled()
+            else [(te_linear._Linear, "forward")]
+        )
59+
60+ @_functionals_to_replace .setter
61+ def _functionals_to_replace (self , value ):
62+ self ._functionals_to_replace = value
63+
64+ def _setup (self ):
65+ super ()._setup ()
66+ if getattr (self , "fuse_wgrad_accumulation" , False ):
67+ warnings .warn (
68+ "fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
69+ "Setting fuse_wgrad_accumulation to False."
70+ )
71+ self .fuse_wgrad_accumulation = False
 
     @staticmethod
     def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
         """Quantized version specifically for TE with weight first, then input."""
-        if te.__version__ >= "2.0":
-            weight, inputs = args[0], args[1]
-            remaining_args = args[2:]
+        _assert_te_fp8_enabled()
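+        # When func_name is "_forward", args[0] is the autograd ctx, so the tensor arguments
+        # shift by one and ctx must be re-prepended to the rewritten argument tuple below.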
+        if Version("2.0") <= _TE_VERSION:
+            idx = 1 if func_name == "_forward" else 0
+            weight, inputs = args[idx], args[idx + 1]
+            remaining_args = args[idx + 2 :]
+            weight = self.weight_quantizer(weight)
+            inputs = self.input_quantizer(inputs)
+            new_args = (weight, inputs, *remaining_args)
+            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
             output = getattr(package, func_name)(
-                self.weight_quantizer(weight),
-                self.input_quantizer(inputs),
-                *remaining_args,
+                *new_args,
                 **kwargs,
             )
         else:
-            weight, weight_fp8, inputs = args[0], args[1], args[2]
-            remaining_args = args[3:]
+            idx = 1 if func_name == "_forward" else 0
+            weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
+            remaining_args = args[idx + 3 :]
+            weight = self.weight_quantizer(weight)
+            inputs = self.input_quantizer(inputs)
+            new_args = (weight, weight_fp8, inputs, *remaining_args)
+            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
             output = getattr(package, func_name)(
-                self.weight_quantizer(weight),
-                weight_fp8,
-                self.input_quantizer(inputs),
-                *remaining_args,
+                *new_args,
                 **kwargs,
             )
         return self.output_quantizer(output)
@@ -64,10 +107,17 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
 # Register the public te.pytorch.GroupedLinear class
 @QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
 class _QuantTEGroupedLinear(_ParallelLinear):
-    _functionals_to_replace = [
-        (te_grouped_linear._GroupedLinear, "forward"),
-        (te_grouped_linear._GroupedLinear, "apply"),
-    ]
+    @property
+    def _functionals_to_replace(self):
+        return (
+            [(te_grouped_linear._GroupedLinear, "apply")]
+            if torch.is_grad_enabled()
+            else [(te_grouped_linear._GroupedLinear, "forward")]
+        )
+
+    @_functionals_to_replace.setter
+    def _functionals_to_replace(self, value):
+        # Write to the instance dict directly; assigning through the property would recurse.
+        self.__dict__["_functionals_to_replace"] = value
 
     def _setup(self):
         # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
@@ -93,6 +143,7 @@ def modelopt_post_restore(self, prefix: str = ""):
 
     @staticmethod
     def te_grouped_quantized_linear_fn(package, func_name, self, *args):
+        _assert_te_fp8_enabled()
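+        # As in te_quantized_linear_fn, args[0] is the autograd ctx when func_name is "_forward".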
         idx = 1 if func_name == "_forward" else 0
         inp = args[idx]
         num_gemms = len(args[idx + 1])
@@ -116,3 +167,120 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):
 
     # Override the quantized linear function
     _quantized_linear_fn = te_grouped_quantized_linear_fn
+
+
+class _QuantLayerNormLinearFunc(torch.autograd.Function):
+    """Patched version of _LayerNormLinear to quantize the input to the GEMM operation."""
+
+    @staticmethod
+    def _get_original_gemm():
+        if Version("2.0") <= _TE_VERSION:
+            return te_layernorm_linear.general_gemm
+        else:
+            return te_layernorm_linear.tex.gemm
+
+    @staticmethod
+    def _gemm_replace_args():
+        if Version("2.0") <= _TE_VERSION:
+            return (te_layernorm_linear, "general_gemm")
+        else:
+            return (te_layernorm_linear.tex, "gemm")
+
+    @staticmethod
+    def forward(ctx, inp, ln_weight, ln_bias, weight, *args, **kwargs):
+        input_quantizer, weight_quantizer = _QuantLayerNormLinearFunc.modelopt_quantizers
+
+        qweight = weight_quantizer(weight)
+        qweight.requires_grad = weight.requires_grad
+        if ctx is not None:
+            # We need to recompute the quantized input for the backward pass, so we save the input_quantizer
+            ctx.modelopt_input_quantizer = input_quantizer
+
+        original_gemm = _QuantLayerNormLinearFunc._get_original_gemm()
+
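+        # The GEMM inside _LayerNormLinear consumes the layernorm output rather than the module
+        # input, so the activation is quantized right at the GEMM call instead of quantizing `inp`.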
+        def _patched_general_gemm(weight, input, *gemm_args, **gemm_kwargs):
+            qinput = input_quantizer(input)
+            return original_gemm(weight, qinput, *gemm_args, **gemm_kwargs)
+
+        with replace_function(
+            *_QuantLayerNormLinearFunc._gemm_replace_args(),
+            _patched_general_gemm,  # type: ignore[call-arg]
+        ):
+            outputs = te_layernorm_linear._og_LayerNormLinear.forward(
+                ctx, inp, ln_weight, ln_bias, qweight, *args, **kwargs
+            )
+        return outputs
+
+    # TODO: Support non-pass-through backward behavior for activation quantization
+    @staticmethod
+    def backward(ctx, *grad_outputs):
217+ """Backward pass for _QuantLayerNormLinearFunc functional.
218+
219+ The backward pass input and weight gradient estimation uses straight through estimator (STE).
220+ We should add support for advanced gradient estimation techniques like STE with clipping.
221+ However this is a low priority item.
222+ """
+        gemm_call_counter = {"count": 0}
+
+        original_gemm = _QuantLayerNormLinearFunc._get_original_gemm()
+
+        def _patched_general_gemm(a, b, *gemm_args, **gemm_kwargs):
+            # The first call is used for the dgrad calculation:
+            # dgrad GEMM; dx = dy * qw; called as gemm(qw, dy, ...)
+            if gemm_call_counter["count"] == 0:
+                gemm_call_counter["count"] += 1
+                return original_gemm(a, b, *gemm_args, **gemm_kwargs)
+
+            # The second call is used for the wgrad calculation:
+            # wgrad GEMM; dqw = dy^T * x; called as gemm(x, dy, ...)
+
+            # Per the chain rule, x should be the quantized input (qinput) here, but gemm is
+            # called with the unquantized input (a), so quantize the input first and then call
+            # the original gemm.
+            qinput = ctx.modelopt_input_quantizer(a)
+            return original_gemm(qinput, b, *gemm_args, **gemm_kwargs)
+
+        with replace_function(
+            *_QuantLayerNormLinearFunc._gemm_replace_args(),
+            _patched_general_gemm,  # type: ignore[call-arg]
+        ):
+            # The module-level patch that swaps in _QuantLayerNormLinearFunc is only active during
+            # the forward pass; autograd still dispatches here for backward, so delegate explicitly
+            # to the original TE _LayerNormLinear.backward.
+            outputs = te_layernorm_linear._LayerNormLinear.backward(ctx, *grad_outputs)
+
+        delattr(ctx, "modelopt_input_quantizer")
+        return outputs
+
+
+@QuantModuleRegistry.register({te.pytorch.LayerNormLinear: "te_LayerNormLinear"})
+class _QuantTELayerNormLinear(_ParallelLinear):
+    _functionals_to_replace = []
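+    # No functionals are patched through the generic mechanism; forward() below swaps the entire
+    # _LayerNormLinear autograd function for _QuantLayerNormLinearFunc instead.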
+
+    def _setup(self):
+        super()._setup()
+        if getattr(self, "fuse_wgrad_accumulation", False):
+            warnings.warn(
+                "fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
+                "Setting fuse_wgrad_accumulation to False."
+            )
+            self.fuse_wgrad_accumulation = False
+
+    def forward(self, *args, **kwargs):
+        """Call the ModelOpt patch for the _LayerNormLinear functional."""
+        _assert_te_fp8_enabled()
+        # Storing the quantizers on the class is multi-process safe (e.g. in torch distributed
+        # jobs) but not multi-thread safe.
+        _QuantLayerNormLinearFunc.modelopt_quantizers = (
+            self.input_quantizer,
+            self.weight_quantizer,
+        )
+        with replace_function(
+            te_layernorm_linear,
+            "_LayerNormLinear",
+            _QuantLayerNormLinearFunc,
+            "_og_LayerNormLinear",
+        ):
+            outputs = super().forward(*args, **kwargs)
+        delattr(_QuantLayerNormLinearFunc, "modelopt_quantizers")
+        if isinstance(outputs, tuple):
+            return (self.output_quantizer(outputs[0]), *outputs[1:])
+        return self.output_quantizer(outputs)