Skip to content

Commit b286165

Browse files
authored
[5680954][ONNX][Autocast] Fix 0-dim scalar constant issue (#691)
## What does this PR do? **Type of change:** Bug fix **Overview:** Ops with 0-dim scalar constants were being forced to a shape of (1,) instead of remaining 0-dimensional. This PR fixes that issue. ## Usage ```python $ python -m modelopt.onnx.autocast --onnx_path=$MODEL_NAME.onnx ``` ## Testing Added unittest. ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes - **Did you write any new necessary tests?**: Yes - **Did you add or update any necessary documentation?**: No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
1 parent fc7ebe2 commit b286165

3 files changed

Lines changed: 9 additions & 9 deletions

File tree

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,10 +1336,6 @@ def _convert_constant_values(self, const_node, cast_node: onnx.NodeProto) -> Non
13361336
else:
13371337
casted_data = original_data.astype(cast_dtype)
13381338

1339-
# Workaround for 0-dimensional tensors (scalars)
1340-
if casted_data.ndim == 0:
1341-
casted_data = casted_data.reshape(1)
1342-
13431339
# Create a new constant node with casted data
13441340
if cast_to_type == onnx.TensorProto.BFLOAT16:
13451341
# Create TensorProto manually for bfloat16

tests/_test_utils/onnx/lib_test_models.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def forward(self, x):
7575
class SimpleMLP(nn.Module):
7676
"""Simple toy model."""
7777

78-
def __init__(self, fi=16, f1=18, f2=20, fo=22):
78+
def __init__(self, fi=16, f1=18, f2=20, fo=22, bias_add=False):
7979
super().__init__()
8080
self.net = nn.Sequential(
8181
nn.Linear(fi, f1, bias=False),
@@ -84,10 +84,13 @@ def __init__(self, fi=16, f1=18, f2=20, fo=22):
8484
nn.ReLU(),
8585
nn.Linear(f2, fo, bias=False),
8686
)
87+
self.bias_add = bias_add
8788

8889
def forward(self, x):
8990
for mod in self.net:
9091
x = mod(x)
92+
if self.bias_add:
93+
x += 1e-4
9194
return x
9295

9396

tests/unit/onnx/test_autocast_quantize.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,21 @@ def assert_nodes_are_quantized(nodes):
3636

3737

3838
@pytest.mark.parametrize("keep_io_types", [True, False])
39-
def test_autocast_quantize_int8(tmp_path, keep_io_types):
40-
model_torch = SimpleMLP()
39+
@pytest.mark.parametrize("bias_add", [True, False])
40+
def test_autocast_quantize_int8(tmp_path, keep_io_types, bias_add):
41+
model_torch = SimpleMLP(bias_add=bias_add)
4142
input_tensor = torch.randn(2, 16, 16)
4243
low_precision_type = "fp16"
4344

44-
onnx_path = os.path.join(tmp_path, "model.onnx")
45+
onnx_path = os.path.join(tmp_path, f"model{'_biasAdd' if bias_add else ''}.onnx")
4546
export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path)
4647

4748
# Convert model to low precision
4849
converted_model = convert_to_mixed_precision(
4950
onnx_path, keep_io_types=keep_io_types, low_precision_type=low_precision_type
5051
)
5152
converted_model_path = onnx_path.replace(
52-
".onnx", f".{low_precision_type}.{'keepIOTypes' if keep_io_types else ''}.onnx"
53+
".onnx", f".{low_precision_type}{'_keepIOTypes' if keep_io_types else ''}.onnx"
5354
)
5455
onnx.save(converted_model, converted_model_path)
5556

0 commit comments

Comments (0)