
Commit 1d0ee04

[OMNIML-2932] Fusing pre_quant_scale for NVFP4 AWQ (#421)
## What does this PR do?

**Type of change:** ?

**Overview:** This PR and NVIDIA/TensorRT-LLM#8698 enable NVFP4 AWQ deployment for TRT-LLM. Specifically, this PR fuses pre_quant_scale in the following two cases:

* For MLP, the pre_quant_scale of the down_proj layer is fused into up_proj's weight, so no extra handling is needed in downstream fused MoE kernels.
* For attention, we try to fuse the pre_quant_scale of o_proj into v_proj when their dimensions match, which means fusion is skipped for MQA/GQA models.

## Testing

Unit tests, plus e2e tests for Qwen3 dense and MoE models.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
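To see why the fusion is output-preserving, here is a small standalone sketch (illustrative only, not code from this PR; layer sizes and weights are made up) that checks the MLP case numerically: folding down_proj's pre_quant_scale into up_proj's output channels leaves the MLP output unchanged.

```python
# Minimal sketch of the MLP-side equivalence for a SwiGLU-style MLP
# (gate_proj / up_proj / down_proj). Sizes and weights are illustrative.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, inter = 64, 128
w_gate = torch.randn(inter, hidden) * 0.02
w_up = torch.randn(inter, hidden) * 0.02
w_down = torch.randn(hidden, inter) * 0.02
scale = torch.rand(inter) + 0.5  # stand-in for down_proj's pre_quant_scale
x = torch.randn(4, hidden)

# Before fusion: the scale is applied to down_proj's input activations
ref = F.linear((F.silu(F.linear(x, w_gate)) * F.linear(x, w_up)) * scale, w_down)

# After fusion: the scale is folded into up_proj's output channels (rows of its weight)
w_up_fused = w_up * scale.view(-1, 1)
out = F.linear(F.silu(F.linear(x, w_gate)) * F.linear(x, w_up_fused), w_down)

print(torch.allclose(ref, out, atol=1e-6))  # True: the MLP output is unchanged
```

The attention case is analogous: o_proj's pre_quant_scale is folded into the output channels of v_proj's weight, provided the dimensions line up.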
1 parent a8641bc commit 1d0ee04

4 files changed

Lines changed: 330 additions & 19 deletions


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
```diff
@@ -47,6 +47,7 @@ Model Optimizer Changelog (Linux)
 - Enabled native Modelopt quantization support for FP8 and NVFP4 formats in SGLang. See `SGLang quantization documentation <https://github.com/sgl-project/sglang/blob/main/docs/advanced_features/quantization.md#using-nvidia-modelopt>`_ for more details.
 - Added modelopt quantized checkpoints in vLLM/SGLang CI/CD pipelines (PRs are under review).
 - Add support for exporting QLoRA checkpoint fintuned using ModelOpt.
+- Update NVFP4 AWQ checkpoint export. It now fuses scaling factors of o_proj and down_proj layers into the model when possible to facilitate deployment.
 
 **Documentation**
 
```

modelopt/torch/export/quant_utils.py

Lines changed: 131 additions & 19 deletions
```diff
@@ -489,7 +489,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
 
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_layernorm", False):
+    if getattr(layer, "fused_with_prequant", False):
        return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
```
```diff
@@ -959,18 +959,145 @@ def all_items_same(item_list):
     return all(x == item_list[0] for x in item_list)
 
 
+def _update_pre_quant_scale(module, new_pre_quant_scale):
+    old_pre_quant_scale = module.input_quantizer._pre_quant_scale
+    # do the processing in fp32 for numerical stability
+    dtype = module.weight.dtype
+    module.weight = nn.Parameter(
+        (
+            module.weight.to(torch.float32)
+            * old_pre_quant_scale.to(dtype=torch.float32, device=module.weight.device)
+            / new_pre_quant_scale.to(dtype=torch.float32, device=module.weight.device)
+        ).to(dtype)
+    )
+    module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+
+    # Redo weights collection
+    module.weight_quantizer.reset_amax()
+    enable_stats_collection(module.weight_quantizer)
+    module.weight_quantizer(module.weight)
+    finish_stats_collection(module.weight_quantizer)
+
+
+# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
+PQS_FUSE_MODULE_MAPPING = [
+    # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
+    # Mathematical equivalence:
+    # Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
+    # After: o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
+    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
+    # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
+    # Mathematical equivalence:
+    # Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
+    # After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
+    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
+]
+
+
+def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
+    """Fuse pre_quant_scale to the linear weights if possible.
+
+    Args:
+        model: The model to fuse pre_quant_scale to.
+        fuse_grouped_heads: If True, fuse the pre_quant_scale even if dimension between pre_quant_scale
+            and linear weights is not the same.
+
+    Returns:
+        fused_modules: A list of modules of which pre_quant_scale is fused to the previous linear layer.
+    """
+    # Fuse pre_quant_scale to the linear weights
+    for _, module in model.named_modules():
+        for module_map in PQS_FUSE_MODULE_MAPPING:
+            target_module_list = module_map[0]
+            linear_pair = module_map[1]
+            if any(module_name in type(module).__name__ for module_name in target_module_list):
+                linear_fuse_into = module.get_submodule(linear_pair[0])
+                linear_pqs_from = module.get_submodule(linear_pair[1])
+                if hasattr(linear_pqs_from, "input_quantizer") and hasattr(
+                    linear_pqs_from.input_quantizer, "_pre_quant_scale"
+                ):
+                    pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale
+
+                    # for GQA/MQA models, we can apply averaging to the pre_quant_scale for shared head groups
+                    if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
+                        if (
+                            not fuse_grouped_heads
+                            or "attention" not in type(module).__name__.lower()
+                        ):
+                            warn(
+                                f"Skipping pattern fuse prequant for {type(module).__name__}"
+                                f"pre_quant_scale dim {pre_quant_scale.numel()} != "
+                                f"out_channel dim {linear_fuse_into.weight.shape[-2]}"
+                            )
+                            continue
+                        config = module.config
+                        num_kv_heads = config.num_key_value_heads
+                        kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
+                        n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
+
+                        # Reshape:(num_kv_heads, n_rep, kv_head_dim)
+                        # n_rep is the number of query group
+                        averaged_scale = pre_quant_scale.view(
+                            num_kv_heads, n_rep, kv_head_dim
+                        ).mean(dim=1)
+
+                        # To update o_proj, we need to repeat back to original shape
+                        repeated_scale = (
+                            averaged_scale.unsqueeze(1)
+                            .expand(num_kv_heads, n_rep, kv_head_dim)
+                            .reshape(-1)
+                        )
+                        # Update o_proj's pre_quant_scale
+                        _update_pre_quant_scale(linear_pqs_from, repeated_scale)
+
+                        # Use averaged scale (flattened) for v_proj fusion
+                        pre_quant_scale = averaged_scale.reshape(-1)
+
+                    # Fuse the pre_quant_scale to weight
+                    linear_fuse_into.weight = torch.nn.Parameter(
+                        linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
+                    )
+                    if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None:
+                        linear_fuse_into.bias = torch.nn.Parameter(
+                            linear_fuse_into.bias * pre_quant_scale
+                        )
+
+                    # Recalibrate the weight quantizer for linear_fuse_into
+                    linear_fuse_into.weight_quantizer.reset_amax()
+                    enable_stats_collection(linear_fuse_into.weight_quantizer)
+                    linear_fuse_into.weight_quantizer(linear_fuse_into.weight)
+                    finish_stats_collection(linear_fuse_into.weight_quantizer)
+
+                    delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale")
+                    setattr(linear_pqs_from, "fused_with_prequant", True)
+
+
 def fuse_prequant_layernorm(
     layernorm_module: torch.nn.Module,
     modules: list[torch.Tensor],
 ):
-    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
+    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted.
+
+    original:
+        layernorm_output = (normalization(input) * weight) + bias
+        layernorm_output_scaled = layernorm_output * pre_quant_scale
+
+    fused:
+        fused_weight = weight * avg_pre_quant_scale
+        fused_bias = bias * avg_pre_quant_scale
+        layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
+    """
     layernorm_module.weight = torch.nn.Parameter(
         layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
     )
+    if hasattr(layernorm_module, "bias") and layernorm_module.bias is not None:
+        layernorm_module.bias = torch.nn.Parameter(
+            layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
+        )
     # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
     for module in modules:
         delattr(module.input_quantizer, "_pre_quant_scale")
-        setattr(module, "fused_with_layernorm", True)
+        setattr(module, "fused_with_prequant", True)
 
 
 def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):
```
```diff
@@ -992,22 +1119,7 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
 
     for module in modules:
         if not torch.equal(module.input_quantizer.pre_quant_scale, avg_prequant_scale):
-            module.weight = nn.Parameter(
-                module.weight
-                * module.input_quantizer.pre_quant_scale.to(
-                    dtype=module.weight.dtype, device=module.weight.device
-                )
-                / avg_prequant_scale.to(
-                    dtype=module.weight.dtype, device=module.weight.device
-                )
-            )
-            module.input_quantizer.pre_quant_scale = avg_prequant_scale
-
-            # Redo weights collection
-            module.weight_quantizer.reset_amax()
-            enable_stats_collection(module.weight_quantizer)
-            module.weight_quantizer(module.weight)
-            finish_stats_collection(module.weight_quantizer)
+            _update_pre_quant_scale(module, avg_prequant_scale)
 
     if resmooth_only:
         return
```
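The grouped-head path above (taken when fuse_grouped_heads=True and the module is an attention block) averages o_proj's pre_quant_scale over the query heads that share each KV head, then repeats it back for o_proj itself. A standalone sketch of just that reshape/mean/repeat step, with made-up head counts, might look like this:

```python
# Standalone sketch of the GQA/MQA scale averaging; head counts are illustrative.
import torch

num_kv_heads, n_rep, kv_head_dim = 4, 2, 8  # 8 query heads sharing 4 KV heads
pre_quant_scale = torch.rand(num_kv_heads * n_rep * kv_head_dim) + 0.5  # o_proj input scale

# Average the scale over the query heads that share each KV head
averaged_scale = pre_quant_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)

# Repeat back to o_proj's input width so o_proj keeps a consistent pre_quant_scale
repeated_scale = (
    averaged_scale.unsqueeze(1).expand(num_kv_heads, n_rep, kv_head_dim).reshape(-1)
)

# averaged_scale.reshape(-1) has num_kv_heads * kv_head_dim entries and is what gets
# folded into v_proj's output channels; repeated_scale has the original width.
print(averaged_scale.reshape(-1).shape, repeated_scale.shape)
```

In the PR itself, `_update_pre_quant_scale` first rewrites o_proj's weight so the repeated scale is numerically equivalent to the original one, and the averaged scale is then folded into v_proj.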

modelopt/torch/export/unified_export_hf.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -60,6 +60,7 @@
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
+    fuse_prequant_to_linear,
     get_activation_scaling_factor,
     get_quant_config,
     get_quantization_format,
@@ -107,6 +108,10 @@ def _output_hook(module, input, output):
     fused_linears = {}
     module_names = set()
 
+    # Fuse pre_quant_scale to the linear weights if possible
+    if quantization_format is not None and "nvfp4_awq" in quantization_format.lower():
+        fuse_prequant_to_linear(model)
+
     for name, module in model.named_modules():
         module_names.add(name)
```
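For context, an end-to-end flow that would exercise this new export hook could look roughly like the sketch below; the config name (NVFP4_AWQ_LITE_CFG), model choice, and calibration loop are assumptions for illustration, not part of this commit.

```python
# Hedged sketch of a quantize-then-export flow that reaches fuse_prequant_to_linear;
# the config name and calibration details are assumptions, not code from this PR.
import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B").cuda()

def forward_loop(model):
    # run a small calibration set through the model here
    ...

# NVFP4 AWQ calibration attaches pre_quant_scale tensors to the input quantizers
model = mtq.quantize(model, mtq.NVFP4_AWQ_LITE_CFG, forward_loop)

# During unified HF export, fuse_prequant_to_linear folds those scales into
# v_proj / up_proj weights before the checkpoint is written
export_hf_checkpoint(model, export_dir="qwen3_nvfp4_awq")
```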
