Commit f731379

Added support for TE linear ops (#632)
Co-authored-by: Asma Kuriparambil Thekkumpate <akuriparambi@nvidia.com>

## What does this PR do?

**Type of change:** New Feature

**Overview:** This PR adds support for quantizing TE ops in Megatron, specifically TERowParallelLinear, TEColumnParallelLinear, and TELayerNormColumnParallelLinear.

## Usage

It can be used by enabling the TE spec in Megatron (see the usage sketch after this description).

## Testing

Added unit tests for the new functionality:

- `test_homogeneous_sharded_state_dict_te_spec`
- `test_convert_mcore_te_gpt_model`
- `test_quantize_forward_backward`

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Additional Information

---------

Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
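For readers unfamiliar with the flow, here is a minimal usage sketch. It assumes a Megatron-Core GPT model built with the TE layer spec and a calibration dataloader; the model/dataloader construction, the `calib_dataloader` name, and the choice of `mtq.FP8_DEFAULT_CFG` are illustrative assumptions, not part of this PR.

```python
import modelopt.torch.quantization as mtq

# Placeholders: in practice `model` is a Megatron-Core GPT model built with the TE layer
# spec (so its linears are TERowParallelLinear / TEColumnParallelLinear /
# TELayerNormColumnParallelLinear), and `calib_dataloader` yields a few representative
# batches. Both are assumptions here and are elided for brevity.
model = ...
calib_dataloader = ...

def forward_loop(model):
    # Run calibration batches through the model so the inserted quantizers collect statistics.
    for batch in calib_dataloader:
        model(batch)

# quantize() swaps the registered TE modules for their quantized counterparts and calibrates them.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
mtq.print_quant_summary(model)
```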
1 parent fd66be2 commit f731379

11 files changed

Lines changed: 434 additions & 31 deletions


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for PyTorch Geometric quantization.
 - Add per tensor and per channel MSE calibrator support.
 - Added support for PTQ/QAT checkpoint export and loading for running fakequant evaluation in vLLM. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
+- Add support for Transformer Engine quantization for Megatron Core models.

 **Documentation**

modelopt/torch/quantization/model_quant.py

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,7 @@
 from modelopt.torch.opt.utils import forward_with_reshard
 from modelopt.torch.quantization.config import QuantizeConfig
 from modelopt.torch.quantization.conversion import set_quantizer_by_cfg
+from modelopt.torch.utils import atomic_print

 from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe
 from .config import QuantizeAlgoCfgType
@@ -506,6 +507,7 @@ def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable):
     set_quantizer_attribute(model, wildcard_or_filter_func, {"enable": True})


+@atomic_print
 def print_quant_summary(model: nn.Module):
     """Print summary of all quantizer modules in the model."""
     count = 0

modelopt/torch/quantization/plugins/custom.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ class _QuantFunctionalMixin(QuantModule):
     def functionals_to_replace(self) -> Iterator[tuple[ModuleType, str, Callable]]:
         return (
             (package, func_name, quantized_func)
-            for package, func_name, quantized_func in self._functionals_to_replace
+            for package, func_name, quantized_func in getattr(self, "_functionals_to_replace", [])
             if hasattr(package, func_name)
         )

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 21 additions & 1 deletion
@@ -47,11 +47,14 @@
 try:
     from megatron.core.extensions.transformer_engine import (
         TEColumnParallelGroupedLinear,
+        TEColumnParallelLinear,
         TEDotProductAttention,
+        TELayerNormColumnParallelLinear,
         TERowParallelGroupedLinear,
+        TERowParallelLinear,
     )

-    from .transformer_engine import _QuantTEGroupedLinear
+    from .transformer_engine import _QuantTEGroupedLinear, _QuantTELayerNormLinear, _QuantTELinear

     HAS_TE = True
 except ImportError:
@@ -549,6 +552,23 @@ def sync_moe_local_experts_amax(self):


 if HAS_TE:
+
+    @QuantModuleRegistry.register({TERowParallelLinear: "te_mcore_RowParallelLinear"})
+    class _QuantTEMCoreRowParallelLinear(_QuantTELinear, _MegatronRowParallelLinear):
+        pass
+
+    @QuantModuleRegistry.register({TEColumnParallelLinear: "te_mcore_ColumnParallelLinear"})
+    class _QuantTEMCoreColumnParallelLinear(_QuantTELinear, _MegatronColumnParallelLinear):
+        pass
+
+    @QuantModuleRegistry.register(
+        {TELayerNormColumnParallelLinear: "te_mcore_LayerNormColumnParallelLinear"}
+    )
+    class _QuantTELayerNormColumnParallelLinear(
+        _QuantTELayerNormLinear, _MegatronColumnParallelLinear
+    ):
+        pass
+
     # Quantized subclasses to support TEGroupedMLP quantization
     class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear):
         def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):

modelopt/torch/quantization/plugins/transformer_engine.py

Lines changed: 190 additions & 22 deletions
@@ -15,44 +15,87 @@

 """Support quantization for Transformer Engine layers."""

+import warnings
+
 import torch
 import transformer_engine as te
 import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
+import transformer_engine.pytorch.module.layernorm_linear as te_layernorm_linear
 import transformer_engine.pytorch.module.linear as te_linear
+from packaging.version import Version
+
+from modelopt.torch.quantization.utils import replace_function

 from ..nn import QuantModuleRegistry
 from .custom import _ParallelLinear

+_TE_VERSION = Version(te.__version__)
+
+
+def _assert_te_fp8_enabled():
+    """Check if Transformer Engine FP8 autocast is enabled and raise error if so."""
+    try:
+        from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+
+        if FP8GlobalStateManager.is_fp8_enabled():
+            raise RuntimeError(
+                "Transformer Engine FP8 training (fp8_autocast) is enabled, which conflicts with "
+                "ModelOpt quantization. Please disable TE FP8 autocast when using ModelOpt "
+                "quantization, or use ModelOpt's FP8 quantization instead."
+            )
+    except ImportError:
+        pass  # Older TE versions may not have this API
+

 @QuantModuleRegistry.register({te.pytorch.Linear: "te_Linear"})
 class _QuantTELinear(_ParallelLinear):
-    _functionals_to_replace = [
-        (
-            te_linear._Linear,
-            "apply" if torch.is_grad_enabled() else "forward",
-        ),
-    ]
+    @property
+    def _functionals_to_replace(self):
+        return (
+            [(te_linear._Linear, "apply")]
+            if torch.is_grad_enabled()
+            else [(te_linear._Linear, "forward")]
+        )
+
+    @_functionals_to_replace.setter
+    def _functionals_to_replace(self, value):
+        self._functionals_to_replace = value
+
+    def _setup(self):
+        super()._setup()
+        if getattr(self, "fuse_wgrad_accumulation", False):
+            warnings.warn(
+                "fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
+                "Setting fuse_wgrad_accumulation to False."
+            )
+            self.fuse_wgrad_accumulation = False

     @staticmethod
     def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
         """Quantized version specifically for TE with weight first, then input."""
-        if te.__version__ >= "2.0":
-            weight, inputs = args[0], args[1]
-            remaining_args = args[2:]
+        _assert_te_fp8_enabled()
+        if Version("2.0") <= _TE_VERSION:
+            idx = 1 if func_name == "_forward" else 0
+            weight, inputs = args[idx], args[idx + 1]
+            remaining_args = args[idx + 2 :]
+            weight = self.weight_quantizer(weight)
+            inputs = self.input_quantizer(inputs)
+            new_args = (weight, inputs, *remaining_args)
+            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
             output = getattr(package, func_name)(
-                self.weight_quantizer(weight),
-                self.input_quantizer(inputs),
-                *remaining_args,
+                *new_args,
                 **kwargs,
             )
         else:
-            weight, weight_fp8, inputs = args[0], args[1], args[2]
-            remaining_args = args[3:]
+            idx = 1 if func_name == "_forward" else 0
+            weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
+            remaining_args = args[idx + 3 :]
+            weight = self.weight_quantizer(weight)
+            inputs = self.input_quantizer(inputs)
+            new_args = (weight, weight_fp8, inputs, *remaining_args)
+            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
             output = getattr(package, func_name)(
-                self.weight_quantizer(weight),
-                weight_fp8,
-                self.input_quantizer(inputs),
-                *remaining_args,
+                *new_args,
                 **kwargs,
             )
         return self.output_quantizer(output)
@@ -64,10 +107,17 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs):
 # Register the public te.pytorch.GroupedLinear class
 @QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
 class _QuantTEGroupedLinear(_ParallelLinear):
-    _functionals_to_replace = [
-        (te_grouped_linear._GroupedLinear, "forward"),
-        (te_grouped_linear._GroupedLinear, "apply"),
-    ]
+    @property
+    def _functionals_to_replace(self):
+        return (
+            [(te_grouped_linear._GroupedLinear, "apply")]
+            if torch.is_grad_enabled()
+            else [(te_grouped_linear._GroupedLinear, "forward")]
+        )
+
+    @_functionals_to_replace.setter
+    def _functionals_to_replace(self, value):
+        self._functionals_to_replace = value

     def _setup(self):
         # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
@@ -93,6 +143,7 @@ def modelopt_post_restore(self, prefix: str = ""):

     @staticmethod
     def te_grouped_quantized_linear_fn(package, func_name, self, *args):
+        _assert_te_fp8_enabled()
         idx = 1 if func_name == "_forward" else 0
         inp = args[idx]
         num_gemms = len(args[idx + 1])
@@ -116,3 +167,120 @@ def te_grouped_quantized_linear_fn(package, func_name, self, *args):

     # Override the quantized linear function
     _quantized_linear_fn = te_grouped_quantized_linear_fn
+
+
+class _QuantLayerNormLinearFunc(torch.autograd.Function):
+    """Patched version of _LayerNormLinear to quantize the input to the GEMM operation."""
+
+    @staticmethod
+    def _get_original_gemm():
+        if Version("2.0") <= _TE_VERSION:
+            return te_layernorm_linear.general_gemm
+        else:
+            return te_layernorm_linear.tex.gemm
+
+    @staticmethod
+    def _gemm_replace_args():
+        if Version("2.0") <= _TE_VERSION:
+            return (te_layernorm_linear, "general_gemm")
+        else:
+            return (te_layernorm_linear.tex, "gemm")
+
+    @staticmethod
+    def forward(ctx, inp, ln_weight, ln_bias, weight, *args, **kwargs):
+        input_quantizer, weight_quantizer = _QuantLayerNormLinearFunc.modelopt_quantizers
+
+        qweight = weight_quantizer(weight)
+        qweight.requires_grad = weight.requires_grad
+        if ctx is not None:
+            # We need to recompute the quantized input for the backward pass, so we save the input_quantizer
+            ctx.modelopt_input_quantizer = input_quantizer
+
+        original_gemm = _QuantLayerNormLinearFunc._get_original_gemm()
+
+        def _patched_general_gemm(weight, input, *gemm_args, **gemm_kwargs):
+            qinput = input_quantizer(input)
+            return original_gemm(weight, qinput, *gemm_args, **gemm_kwargs)
+
+        with replace_function(
+            *_QuantLayerNormLinearFunc._gemm_replace_args(),
+            _patched_general_gemm,  # type: ignore[call-arg]
+        ):
+            outputs = te_layernorm_linear._og_LayerNormLinear.forward(
+                ctx, inp, ln_weight, ln_bias, qweight, *args, **kwargs
+            )
+        return outputs
+
+    # TODO: Support non-pass-through backward behavior for activation quantization
+    @staticmethod
+    def backward(ctx, *grad_outputs):
+        """Backward pass for the _QuantLayerNormLinearFunc functional.
+
+        The backward pass estimates input and weight gradients with a straight-through estimator (STE).
+        We should add support for advanced gradient estimation techniques like STE with clipping.
+        However, this is a low-priority item.
+        """
+        gemm_call_counter = {"count": 0}
+
+        original_gemm = _QuantLayerNormLinearFunc._get_original_gemm()
+
+        def _patched_general_gemm(a, b, *gemm_args, **gemm_kwargs):
+            # The first time, gemm is used for the dgrad calculation
+            # dgrad GEMM; dx = dy * qw; called as gemm(qw, dy, ...)
+            if gemm_call_counter["count"] == 0:
+                gemm_call_counter["count"] += 1
+                return original_gemm(a, b, *gemm_args, **gemm_kwargs)
+
+            # The second time, gemm is used for the wgrad calculation
+            # wgrad GEMM; dqw = dy^T * x; called as gemm(x, dy, ...)
+
+            # x should be the quantized input (qinput) for the backward pass as per the chain rule,
+            # but gemm is called with the unquantized input (a).
+            # So let's first get the quantized input (qinput) and then call the gemm.
+            qinput = ctx.modelopt_input_quantizer(a)
+            return original_gemm(qinput, b, *gemm_args, **gemm_kwargs)
+
+        with replace_function(
+            *_QuantLayerNormLinearFunc._gemm_replace_args(),
+            _patched_general_gemm,  # type: ignore[call-arg]
+        ):
+            # During backward, the patch does not exist; autograd will automatically use
+            # _QuantLayerNormLinearFunc.backward
+            outputs = te_layernorm_linear._LayerNormLinear.backward(ctx, *grad_outputs)

+        delattr(ctx, "modelopt_input_quantizer")
+        return outputs
+
+
+@QuantModuleRegistry.register({te.pytorch.LayerNormLinear: "te_LayerNormLinear"})
+class _QuantTELayerNormLinear(_ParallelLinear):
+    _functionals_to_replace = []
+
+    def _setup(self):
+        super()._setup()
+        if getattr(self, "fuse_wgrad_accumulation", False):
+            warnings.warn(
+                "fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
+                "Setting fuse_wgrad_accumulation to False."
+            )
+            self.fuse_wgrad_accumulation = False
+
+    def forward(self, *args, **kwargs):
+        """Call ModelOpt patch for the _LayerNormLinear functional."""
+        _assert_te_fp8_enabled()
+        # This is multi-process safe (such as in torch distributed jobs), not multi-thread safe
+        _QuantLayerNormLinearFunc.modelopt_quantizers = (
+            self.input_quantizer,
+            self.weight_quantizer,
+        )
+        with replace_function(
+            te_layernorm_linear,
+            "_LayerNormLinear",
+            _QuantLayerNormLinearFunc,
+            "_og_LayerNormLinear",
+        ):
+            outputs = super().forward(*args, **kwargs)
+        delattr(_QuantLayerNormLinearFunc, "modelopt_quantizers")
+        if isinstance(outputs, tuple):
+            return (self.output_quantizer(outputs[0]), *outputs[1:])
+        return self.output_quantizer(outputs)
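As context for the straight-through estimator (STE) mentioned in the backward docstring above, here is a self-contained toy sketch (not ModelOpt or TE code) of STE behavior in a custom autograd function: the forward pass quantizes, while the backward pass treats the quantization step as the identity.

```python
import torch


class FakeQuantSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        # Quantize-dequantize in the forward pass.
        return torch.round(x / scale) * scale

    @staticmethod
    def backward(ctx, grad_output):
        # STE: the incoming gradient passes straight through; no gradient for the scale.
        return grad_output, None


x = torch.randn(4, requires_grad=True)
y = FakeQuantSTE.apply(x, 0.1)
y.sum().backward()
print(x.grad)  # all ones: the rounding step was treated as identity in backward
```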

modelopt/torch/quantization/utils.py

Lines changed: 5 additions & 3 deletions
@@ -337,14 +337,16 @@ def disable_lora_quantizers_in_config(config, layers):


 @contextmanager
-def replace_function(package, name, new_func):
+def replace_function(package, name, new_func, og_func_cache_name=None):
     """Replace a function with a new one within a context."""
+    if og_func_cache_name is None:
+        og_func_cache_name = "_" + name
     old_func = getattr(package, name)
     setattr(package, name, new_func)
-    setattr(package, "_" + name, old_func)
+    setattr(package, og_func_cache_name, old_func)
     yield
     setattr(package, name, old_func)
-    delattr(package, "_" + name)
+    delattr(package, og_func_cache_name)


 @contextmanager
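With the new `og_func_cache_name` parameter, callers can choose the attribute under which the original function is cached while the patch is active (the TE plugin above uses `_og_LayerNormLinear`). A minimal sketch of the behavior, using a made-up `mymodule`/`greet` namespace for illustration:

```python
import types

from modelopt.torch.quantization.utils import replace_function

# Illustration only: patch a module-level function for the duration of the context,
# keeping the original reachable under a caller-chosen attribute name.
mymodule = types.SimpleNamespace(greet=lambda name: f"hello {name}")

def shouty_greet(name):
    return f"HELLO {name.upper()}"

with replace_function(mymodule, "greet", shouty_greet, og_func_cache_name="_og_greet"):
    assert mymodule.greet("te") == "HELLO TE"      # patched function is in place
    assert mymodule._og_greet("te") == "hello te"  # original is still reachable

assert mymodule.greet("te") == "hello te"          # original restored on exit
assert not hasattr(mymodule, "_og_greet")          # cache attribute cleaned up
```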
