Commit c1956b8

[5763424][ONNX][Autocast] Fix ConstantOfShape layer output precision (#789)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Fixed the output precision of ConstantOfShape layers in models with custom ops.

## Usage

```python
$ python -m modelopt.onnx.quantization --onnx_path=$MODEL_NAME.onnx
```

## Testing

See bug 5763424.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Additional Information

This issue only affects models with custom ops.

## Summary by CodeRabbit

* **Bug Fixes**
  * Improved type propagation handling for ConstantOfShape operations in ONNX autocast, ensuring correct precision type conversion across related operations.

Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
1 parent e6e4efd commit c1956b8

1 file changed

Lines changed: 2 additions & 0 deletions

modelopt/onnx/autocast/precisionconverter.py

@@ -347,6 +347,8 @@ def _get_np_type(node, inp, opset=onnx.defs.onnx_opset_version()):
                 return node.inputs[1].dtype  # scale type
             elif node.op == "QuantizeLinear":
                 return node.inputs[2].dtype  # zero_point type
+            elif node.op == "ConstantOfShape":
+                return node.attrs["value"].dtype
             elif not inp.dtype or inp.dtype == onnx.TensorProto.UNDEFINED:
                 return None
             elif node.op not in self.custom_ops:
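
The new branch can be illustrated with a small stand-alone sketch. It is not the code in `precisionconverter.py`: `Tensor`, `Node`, and `get_np_type` are hypothetical stand-ins, dtypes are plain strings, and the final fallthrough is a simplification of the remaining branches. The point it shows is that a ConstantOfShape node's only input is an int64 shape tensor, so the output precision has to be read from the `value` attribute instead. The float32 fallback for a missing `value` follows the default in the ONNX operator specification and is an assumption added here; the commit indexes `value` directly.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Tensor:
    """Hypothetical stand-in for a graph tensor; only the dtype matters here."""
    dtype: Optional[str]


@dataclass
class Node:
    """Hypothetical stand-in for a graph node with an op type, inputs, and attributes."""
    op: str
    inputs: list = field(default_factory=list)
    attrs: dict = field(default_factory=dict)


def get_np_type(node: Node, inp: Tensor) -> Optional[str]:
    """Simplified dispatch mirroring the branches visible in the diff."""
    if node.op == "DequantizeLinear":
        return node.inputs[1].dtype  # scale type
    elif node.op == "QuantizeLinear":
        return node.inputs[2].dtype  # zero_point type
    elif node.op == "ConstantOfShape":
        # The input is an int64 shape, so the fill value decides the output type.
        value = node.attrs.get("value")
        # Assumption: fall back to the ONNX default (float32) when `value` is absent.
        return value.dtype if value is not None else "float32"
    elif not inp.dtype:
        return None
    # Simplification: the real function has further branches (e.g. custom ops).
    return inp.dtype


shape = Tensor("int64")
fill = Node("ConstantOfShape", inputs=[shape], attrs={"value": Tensor("float16")})
print(get_np_type(fill, shape))
```

Without the ConstantOfShape branch, this simplified dispatch would report the input's `int64` type for `fill` rather than the `float16` of its fill value.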
