Skip to content

Commit 9fc8306

Browse files
committed
[Fix]: Patch zero FP16 scales in INT4_AWQ ONNX export (NVBug 6110209)
replace_zero_scale_with_smallest_nonzero() in qdq_utils.py only inspected QuantizeLinear consumers and Constant-node producers, which made it a no-op for INT4_AWQ exports — those use DequantizeLinear (default and trt:: domain) consumers and store scales as graph initializers. Zero scales produced when the FP32→FP16 cast underflows therefore reached TensorRT, causing trtexec --stronglyTyped to fail with "Scale coefficients must all be positive". Extend the sanitizer to also walk DequantizeLinear / TRT_INT4DequantizeLinear nodes and to patch initializer-backed scales, while preserving dtype. Add regression tests for both the initializer + DQ path (default and trt:: domain) and the legacy Constant + QuantizeLinear path.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 5b41ba4 commit 9fc8306

2 files changed

Lines changed: 114 additions & 4 deletions

File tree

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 27 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1011,14 +1011,37 @@ def replace_zero_scale_with_smallest_nonzero(onnx_model: onnx.ModelProto) -> onn
10111011
"""Replace zero scale values with smallest nonzero fp16 value in the ONNX model."""
10121012
graph = onnx_model.graph
10131013
fp16_smallest_nonzero = np.float16(6e-08)
1014-
scale_nodes = [node.input[1] for node in graph.node if node.op_type == "QuantizeLinear"]
1014+
qdq_op_types = {
1015+
"QuantizeLinear",
1016+
"DequantizeLinear",
1017+
"TRT_INT4QuantizeLinear",
1018+
"TRT_INT4DequantizeLinear",
1019+
}
1020+
scale_tensor_names = {
1021+
node.input[1]
1022+
for node in graph.node
1023+
if node.op_type in qdq_op_types and len(node.input) >= 2
1024+
}
1025+
# Scales stored as graph initializers (e.g. INT4_AWQ / TRT_INT4DequantizeLinear exports).
1026+
for init in graph.initializer:
1027+
if init.name in scale_tensor_names:
1028+
tensor = numpy_helper.to_array(init)
1029+
if tensor.dtype.kind == "f":
1030+
new_tensor = np.where(tensor == 0, fp16_smallest_nonzero, tensor).astype(
1031+
tensor.dtype
1032+
)
1033+
init.CopyFrom(numpy_helper.from_array(new_tensor, init.name))
1034+
# Scales emitted by Constant nodes (legacy QDQ export path).
10151035
for node in graph.node:
1016-
if node.op_type == "Constant" and node.output[0] in scale_nodes:
1036+
if node.op_type == "Constant" and node.output[0] in scale_tensor_names:
10171037
for attr in node.attribute:
10181038
if attr.name == "value":
10191039
tensor = numpy_helper.to_array(attr.t)
1020-
new_tensor = np.where(tensor == 0, fp16_smallest_nonzero, tensor)
1021-
attr.t.CopyFrom(numpy_helper.from_array(new_tensor, attr.t.name))
1040+
if tensor.dtype.kind == "f":
1041+
new_tensor = np.where(tensor == 0, fp16_smallest_nonzero, tensor).astype(
1042+
tensor.dtype
1043+
)
1044+
attr.t.CopyFrom(numpy_helper.from_array(new_tensor, attr.t.name))
10221045
return onnx_model
10231046

10241047

tests/unit/onnx/quantization/test_qdq_utils.py

Lines changed: 87 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1021,3 +1021,90 @@ def test_column_major_gemm_trans_b_flip(self):
10211021

10221022
print(f"transB flipped: 1 -> {trans_b_value}")
10231023
print(f"Transpose nodes: {len(transpose_nodes)}")
1024+
1025+
1026+
def _build_model_with_zero_scale_initializer(dq_op_type: str):
1027+
"""Build an ONNX model whose scale initializer feeds a (Quantize|Dequantize)Linear node.
1028+
1029+
Mirrors the INT4_AWQ failure mode from NVBug 6110209: scales live in graph initializers
1030+
(not Constant nodes) and feed DequantizeLinear (default or trt:: domain) consumers.
1031+
"""
1032+
weight_data = np.random.randint(-8, 8, size=(6, 8), dtype=np.int8)
1033+
weight_tensor = numpy_helper.from_array(weight_data, "weight")
1034+
1035+
scale_data = np.array([1e-3, 0.0, 5e-4, 0.0, 0.0, 2e-3], dtype=np.float16).reshape(6, 1)
1036+
scale_tensor = numpy_helper.from_array(scale_data, "scale")
1037+
1038+
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT16, [None, 6])
1039+
dq_node = helper.make_node(
1040+
dq_op_type, inputs=["weight", "scale"], outputs=["dq_output"], name="weight_dq"
1041+
)
1042+
matmul_node = helper.make_node(
1043+
"MatMul", inputs=["input", "dq_output"], outputs=["output"], name="matmul"
1044+
)
1045+
graph = helper.make_graph(
1046+
nodes=[dq_node, matmul_node],
1047+
name="test_graph",
1048+
inputs=[input_tensor],
1049+
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT16, [None, 8])],
1050+
initializer=[weight_tensor, scale_tensor],
1051+
)
1052+
return helper.make_model(graph)
1053+
1054+
1055+
class TestReplaceZeroScaleWithSmallestNonzero:
1056+
"""Regression tests for ``replace_zero_scale_with_smallest_nonzero`` (NVBug 6110209)."""
1057+
1058+
@pytest.mark.parametrize("dq_op_type", ["DequantizeLinear", "TRT_INT4DequantizeLinear"])
1059+
def test_zero_scale_initializer_fed_to_dq_is_patched(self, dq_op_type):
1060+
from modelopt.onnx.quantization.qdq_utils import replace_zero_scale_with_smallest_nonzero
1061+
1062+
model = _build_model_with_zero_scale_initializer(dq_op_type)
1063+
scale_before = numpy_helper.to_array(
1064+
next(init for init in model.graph.initializer if init.name == "scale")
1065+
)
1066+
assert (scale_before == 0).any(), "fixture must contain zeros to exercise the fix"
1067+
1068+
patched = replace_zero_scale_with_smallest_nonzero(model)
1069+
1070+
scale_after_init = next(init for init in patched.graph.initializer if init.name == "scale")
1071+
scale_after = numpy_helper.to_array(scale_after_init)
1072+
assert not (scale_after == 0).any()
1073+
assert (scale_after > 0).all()
1074+
assert scale_after_init.data_type == TensorProto.FLOAT16
1075+
1076+
def test_constant_node_scale_path_still_patched(self):
1077+
"""Legacy Constant-node QDQ path must continue to be patched."""
1078+
from modelopt.onnx.quantization.qdq_utils import replace_zero_scale_with_smallest_nonzero
1079+
1080+
scale_data = np.array([1e-3, 0.0, 2e-3], dtype=np.float16)
1081+
scale_const = helper.make_node(
1082+
"Constant",
1083+
inputs=[],
1084+
outputs=["scale_out"],
1085+
value=numpy_helper.from_array(scale_data),
1086+
name="scale_constant",
1087+
)
1088+
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [3])
1089+
q_node = helper.make_node(
1090+
"QuantizeLinear",
1091+
inputs=["input", "scale_out"],
1092+
outputs=["q_output"],
1093+
name="q",
1094+
)
1095+
graph = helper.make_graph(
1096+
nodes=[scale_const, q_node],
1097+
name="test_graph",
1098+
inputs=[input_tensor],
1099+
outputs=[helper.make_tensor_value_info("q_output", TensorProto.INT8, [3])],
1100+
initializer=[],
1101+
)
1102+
model = helper.make_model(graph)
1103+
1104+
patched = replace_zero_scale_with_smallest_nonzero(model)
1105+
1106+
const = next(n for n in patched.graph.node if n.op_type == "Constant")
1107+
value_attr = next(a for a in const.attribute if a.name == "value")
1108+
scale_arr = numpy_helper.to_array(value_attr.t)
1109+
assert not (scale_arr == 0).any()
1110+
assert (scale_arr > 0).all()

0 commit comments

Comments (0)