
Commit 391f6cb

[5750013][5591945][5360813]: AutoCast standalone implementation for type inference (#719)
## What does this PR do?

**Type of change:** New feature

**Overview:** AutoCast runs full type inference to get the new types after adding casts. ONNX doesn't have a separate function for type inference; it is done as part of shape inference. Shape inference is a much more complex task than type inference, especially when dynamic shapes are involved. We're seeing some shape-inference-related bugs in AutoCast. Typically we can work around them, but it's cumbersome. A local implementation might allow users to work around shape inference issues. This is opt-in and marked as experimental.

## Usage

    python -m modelopt.onnx.autocast --onnx_path /path/to/input.onnx [options] --use_standalone_type_inference

## Testing

Added `use_standalone_type_inference=True` to all existing PrecisionConverter tests.

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: Yes
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes

## Additional Information

A more permanent fix would be to decouple type and shape inference in ONNX; we should invest in that when we have the resources (see onnx/onnx#7100). This is a quick fix, which is also why it is opt-in and not the default mode.

## Summary by CodeRabbit

* **New Features**
  * Added `--use_standalone_type_inference` flag to ONNX AutoCast, enabling type-only inference as an alternative to standard shape inference. Useful as a workaround when shape inference fails or to reduce computational overhead.
* **Documentation**
  * Added a "Type Inference Control" section with usage examples and caveats for the new standalone type inference option.
* **Tests**
  * Extended test coverage to validate both standard and standalone type inference paths.

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent 38fb120 commit 391f6cb

7 files changed

Lines changed: 664 additions & 98 deletions
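
For context, a minimal sketch of the programmatic path this PR extends. This is hedged, not verbatim from the PR: the import of `convert_to_mixed_precision` and its `onnx_path`/`low_precision_type` arguments follow the docs snippet further down this page, and only `use_standalone_type_inference` is new in this commit.

```python
# Minimal sketch, assuming convert_to_mixed_precision is importable from
# modelopt.onnx.autocast as the docs in this diff suggest, and that
# onnx_path is accepted as a keyword argument.
import onnx

from modelopt.onnx.autocast import convert_to_mixed_precision

converted = convert_to_mixed_precision(
    onnx_path="/path/to/input.onnx",
    low_precision_type="fp16",
    # Opt-in, experimental flag added by this commit: type-only inference
    # instead of ONNX's infer_shapes (which does shape + type inference).
    use_standalone_type_inference=True,
)
onnx.save(converted, "/path/to/output.onnx")
```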


CHANGELOG.rst

Lines changed: 8 additions & 0 deletions
@@ -1,6 +1,14 @@
 NVIDIA Model Optimizer Changelog (Linux)
 ========================================
 
+0.42 (TBD)
+^^^^^^^^^^^^^^^^^
+
+**Bug Fixes**
+
+**New Features**
+- Add standalone type inference option (``--use_standalone_type_inference``) in ONNX AutoCast as an alternative to ONNX's ``infer_shapes``. This experimental feature performs type-only inference without shape inference, useful as a workaround when shape inference fails or to avoid unnecessary shape inference overhead.
+
 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^

docs/source/guides/8_autocast.rst

Lines changed: 18 additions & 0 deletions
@@ -42,6 +42,7 @@ AutoCast can also be used programmatically through its Python API:
     trt_plugins=[], # list of TensorRT plugin library paths in .so format
     max_depth_of_reduction=None, # maximum depth of reduction allowed in low precision
     opset=None, # optional target ONNX opset version (default: 13 for fp16, 22 for bf16)
+    use_standalone_type_inference=False, # use standalone type inference instead of ONNX's infer_shapes (WAR)
 )
 
 # Save the converted model
@@ -82,6 +83,9 @@ AutoCast follows these steps to convert a model:
 - Converts eligible nodes to lower precision
 - Automatically inserts necessary cast operations
 - Automatically replaces initializers with lower precision values
+- Performs type inference to propagate types through the graph
+- By default, uses ONNX's ``infer_shapes``, which performs both shape and type inference.
+- Use ``use_standalone_type_inference=True`` to use a standalone type-only inference implementation (experimental).
 
 #. **Validation and Export**:
 
@@ -145,6 +149,14 @@ Best Practices
 - A warning will be issued if you specify an opset lower than the original model's opset, as downgrading opset versions may cause compatibility issues.
 - The opset may be automatically increased beyond your specified value if certain operations require it (e.g., quantization nodes require opset >= 19).
 
+#. **Type Inference Control**
+
+   - By default, AutoCast uses ONNX's ``infer_shapes``, which performs both shape and type inference.
+   - Use ``--use_standalone_type_inference`` to enable a standalone type-only inference implementation.
+   - This is a workaround for cases where shape inference fails for any reason, allowing AutoCast to bypass the dependency on ONNX's shape inference logic.
+   - The standalone implementation uses graphsurgeon for topological sorting and handles special operators like Cast, QuantizeLinear, DequantizeLinear, Constant and ConstantOfShape.
+   - Note: the standalone type inference may be less robust than ONNX's implementation for edge cases, but it avoids unnecessary shape inference overhead and possible failures.
+
 Limitations and Restrictions
 ----------------------------
 - AutoCast does not yet support quantized models.
@@ -198,3 +210,9 @@ Convert to BF16 with a specific opset:
 .. code-block:: bash
 
     python -m modelopt.onnx.autocast --onnx_path model.onnx --low_precision_type bf16 --opset 22
+
+Use standalone type inference instead of ONNX's infer_shapes:
+
+.. code-block:: bash
+
+    python -m modelopt.onnx.autocast --onnx_path model.onnx --use_standalone_type_inference
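
To make the standalone type-only pass concrete, here is a small hypothetical sketch of type propagation over a topologically sorted graph. It is not the modelopt implementation; per the docs above, the real one uses graphsurgeon for the topological sort and also special-cases QuantizeLinear, DequantizeLinear and ConstantOfShape.

```python
# Hypothetical illustration of type-only propagation; not the actual
# modelopt implementation. Assumes model.graph.node is already
# topologically sorted (modelopt reportedly uses graphsurgeon for this).
import onnx
from onnx import TensorProto


def propagate_types(model: onnx.ModelProto) -> dict[str, int]:
    """Map tensor names to elem_type with a single forward pass."""
    types: dict[str, int] = {
        inp.name: inp.type.tensor_type.elem_type for inp in model.graph.input
    }
    for init in model.graph.initializer:
        types[init.name] = init.data_type

    for node in model.graph.node:
        if node.op_type == "Cast":
            # Cast output type comes from the 'to' attribute, not the input.
            out_type = next(a.i for a in node.attribute if a.name == "to")
        elif node.op_type == "Constant":
            # Simplified: only handles the tensor-valued 'value' attribute.
            out_type = next(
                a.t.data_type for a in node.attribute if a.name == "value"
            )
        elif node.input:
            # Default rule: outputs inherit the first input's type. The real
            # implementation also special-cases QuantizeLinear,
            # DequantizeLinear and ConstantOfShape.
            out_type = types.get(node.input[0], TensorProto.UNDEFINED)
        else:
            out_type = TensorProto.UNDEFINED
        for out in node.output:
            types[out] = out_type
    return types
```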

modelopt/onnx/autocast/__main__.py

Lines changed: 11 additions & 0 deletions
@@ -185,6 +185,16 @@ def get_parser() -> argparse.ArgumentParser:
             "higher version."
         ),
     )
+    parser.add_argument(
+        "--use_standalone_type_inference",
+        action="store_true",
+        default=False,
+        help=(
+            "Use local type inference implementation instead of ONNX's infer_shapes (experimental). "
+            "This is a workaround for cases where shape inference fails for any reason. "
+            "Default: False (uses ONNX's infer_shapes which does both shape and type inference)."
+        ),
+    )
 
     return parser
 
@@ -218,6 +228,7 @@ def main(argv=None):
         trt_plugins_precision=args.trt_plugins_precision,
         max_depth_of_reduction=args.max_depth_of_reduction,
         opset=args.opset,
+        use_standalone_type_inference=args.use_standalone_type_inference,
     )
 
     output_path = args.output_path

modelopt/onnx/autocast/convert.py

Lines changed: 12 additions & 2 deletions
@@ -61,6 +61,7 @@ def convert_to_mixed_precision(
     trt_plugins_precision: list[str] = [],
     max_depth_of_reduction: int | None = None,
     opset: int | None = None,
+    use_standalone_type_inference: bool = False,
 ) -> onnx.ModelProto:
     """Convert model to mixed precision.
 
@@ -85,6 +86,9 @@ def convert_to_mixed_precision(
         opset: Target ONNX opset version. If None, uses default minimum opset based on low_precision_type
             (22 for bf16, 13 for fp16). The opset may be automatically increased if certain operations
             require a higher version.
+        use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
+            infer_shapes. This is a workaround (WAR) when only type inference is
+            needed without shape inference. Default: False.
 
     Returns:
         onnx.ModelProto: The converted mixed precision model.
@@ -132,7 +136,7 @@ def convert_to_mixed_precision(
     model = graph_sanitizer.model
 
     # Setup internal mappings
-    model = onnx_utils.infer_shapes(model)
+    model = onnx_utils.infer_types(model, use_standalone_type_inference)
     value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
 
     # Automatically add 'trt' to list of providers if custom ops are detected
@@ -164,6 +168,7 @@ def convert_to_mixed_precision(
         low_precision_type=low_precision_type,
         init_conversion_max_bytes=init_conversion_max_bytes,
         custom_ops=graph_sanitizer.custom_ops,
+        use_standalone_type_inference=use_standalone_type_inference,
     )
 
     # Obtain reference data
@@ -196,6 +201,7 @@ def convert_to_f16(
     op_block_list: list[str] = [],
     tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     trt_plugins: list[str] | None = [],
+    use_standalone_type_inference: bool = False,
 ) -> onnx.ModelProto:
     """Convert model to mixed precision, using PrecisionConverter.
 
@@ -208,6 +214,9 @@ def convert_to_f16(
         op_block_list: List of operation types that should remain in FP32.
         tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
         trt_plugins: List of TensorRT plugin library paths in .so format (compiled shared library).
+        use_standalone_type_inference: If True, use standalone type inference implementation instead of ONNX's
+            infer_shapes. This is a workaround (WAR) when only type inference is
+            needed without shape inference. Default: False.
     """
     assert low_precision_type in ["fp16", "bf16"], "low_precision_type must be either fp16 or bf16"
 
@@ -225,7 +234,7 @@ def convert_to_f16(
     model = sanitizer.model
 
     # Setup internal mappings
-    model = onnx_utils.infer_shapes(model)
+    model = onnx_utils.infer_types(model, use_standalone_type_inference)
     value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model)
 
     precision_converter = PrecisionConverter(
@@ -237,6 +246,7 @@ def convert_to_f16(
         low_precision_type=low_precision_type,
         custom_ops=sanitizer.custom_ops,
         tensor_block_dict=tensor_block_dict,
+        use_standalone_type_inference=use_standalone_type_inference,
     )
     high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list]
     low_precision_nodes = [

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 40 additions & 17 deletions
@@ -97,6 +97,7 @@ def __init__(
         max_ir_version: int | None = None,
         trt_plugins: list[str] | None = [],
         tensor_block_dict: dict[str, dict[str, list[int]]] = {},
+        use_standalone_type_inference: bool = False,
     ) -> None:
         """Initialize PrecisionConverter.
 
@@ -114,6 +115,7 @@ def __init__(
             max_ir_version: Max IR version for conversion.
             trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library).
             tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
+            use_standalone_type_inference: Use standalone type inference instead of ONNX's infer_shapes.
         """
         self.model = deepcopy(model)
         self.value_info_map = value_info_map
@@ -140,6 +142,7 @@ def __init__(
         self.min_opset = min_opset
         self.max_ir_version = max_ir_version
         self.trt_plugins = trt_plugins
+        self.use_standalone_type_inference = use_standalone_type_inference
 
         # Detect additional ops not supported in low precision according to the model's opset version
         self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + (
@@ -254,10 +257,14 @@ def convert(
         # Clear type/shape information for intermediates and outputs (including subgraphs)
         self._clear_types_and_shapes_recursive(self.model.graph)
         # Populate type information with inferred types
-        self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
+        self.model = onnx_utils.infer_types(
+            self.model, self.use_standalone_type_inference, strict_mode=True, check_type=False
+        )
         self._ensure_types_are_defined()
         # Sanity check: Verify type correctness
-        self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True)
+        self.model = onnx_utils.infer_types(
+            self.model, self.use_standalone_type_inference, strict_mode=True, check_type=True
+        )
 
         # Update value_info_map and initializer_map with casts we added
         self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings(
@@ -282,9 +289,9 @@ def _clear_types_and_shapes_recursive(
     ) -> None:
         """Recursively clear type/shape information for a graph and all its subgraphs.
 
-        This is necessary for control flow operators (Scan, If, Loop) which have subgraphs.
-        For subgraphs, preserve value_info for outer scope variables (not produced by nodes in subgraph).
-        For main graph, clear all value_info.
+        If use_standalone_type_inference is True, we clear only types, not shapes.
+        For subgraphs, input types/shapes are cleared, so that the input types/shapes are propagated
+        from the main graph.
 
         Args:
             graph: The ONNX graph to clear types and shapes for.
@@ -301,9 +308,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
            for inp in g.input:
                if inp.type.HasField("tensor_type"):
                    inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                   for idx, d in enumerate(inp.type.tensor_type.shape.dim):
-                       if d.dim_value:
-                           inp.type.tensor_type.shape.dim[idx].dim_param = "unk"
+                   if not self.use_standalone_type_inference:
+                       for idx, d in enumerate(inp.type.tensor_type.shape.dim):
+                           if d.dim_value:
+                               inp.type.tensor_type.shape.dim[idx].dim_param = "unk"
 
            if is_sub:
                # Identify which tensors are produced by nodes in this subgraph
@@ -315,9 +323,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
                for vi in g.value_info:
                    if vi.name in subgraph_outputs:
                        vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-                       for idx, d in enumerate(vi.type.tensor_type.shape.dim):
-                           if d.dim_value:
-                               vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
+                       if not self.use_standalone_type_inference:
+                           for idx, d in enumerate(vi.type.tensor_type.shape.dim):
+                               if d.dim_value:
+                                   vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
            else:
                for vi in g.value_info:
                    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
@@ -328,9 +337,10 @@ def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) ->
            # Clear outputs for both main graph and subgraphs
            for out in g.output:
                out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
-               for idx, d in enumerate(out.type.tensor_type.shape.dim):
-                   if d.dim_value:
-                       out.type.tensor_type.shape.dim[idx].dim_param = "unk"
+               if not self.use_standalone_type_inference:
+                   for idx, d in enumerate(out.type.tensor_type.shape.dim):
+                       if d.dim_value:
+                           out.type.tensor_type.shape.dim[idx].dim_param = "unk"
 
         utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph)
 
@@ -1177,8 +1187,16 @@ def _remove_redundant_casts(self):
         if self.custom_ops:
             self.model = self._propagate_types_shapes_custom_ops(self.model)
         else:
-            self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)
-            self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True)
+            self.model = onnx_utils.infer_types(
+                self.model, self.use_standalone_type_inference, strict_mode=True
+            )
+            if not self.use_standalone_type_inference:
+                self.model = onnx_utils.infer_types(
+                    self.model,
+                    self.use_standalone_type_inference,
+                    strict_mode=True,
+                    check_type=True,
+                )
 
         nodes_to_remove = []
         for node in self.model.graph.node:
@@ -1263,7 +1281,12 @@ def _fix_network_output_names(self):
         if self.custom_ops:
             self.model = self._propagate_types_shapes_custom_ops(self.model)
         else:
-            self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=True)
+            self.model = onnx_utils.infer_types(
+                self.model,
+                self.use_standalone_type_inference,
+                strict_mode=True,
+                check_type=True,
+            )
         self.value_info_map, self.initializer_map, self.node_to_init_map = utils.setup_mappings(
             self.model
         )
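
The `onnx_utils.infer_types` helper that the hunks above call is defined in one of the two changed files this page does not display. The following is a plausible sketch of its dispatch behavior inferred from the call sites; `standalone_type_inference` is a hypothetical placeholder, not the real helper.

```python
# Plausible sketch of the onnx_utils.infer_types dispatcher, inferred from
# the call sites in this diff; the actual implementation is in a changed
# file not shown on this page.
import onnx


def standalone_type_inference(model: onnx.ModelProto) -> onnx.ModelProto:
    # Hypothetical placeholder for the experimental type-only path; see the
    # propagation sketch earlier on this page for what such a pass might do.
    raise NotImplementedError


def infer_types(
    model: onnx.ModelProto,
    use_standalone_type_inference: bool = False,
    strict_mode: bool = False,
    check_type: bool = False,
) -> onnx.ModelProto:
    if use_standalone_type_inference:
        # Experimental path: propagate only elem_types, skip shape inference.
        return standalone_type_inference(model)
    # Default path: ONNX's infer_shapes performs shape AND type inference.
    return onnx.shape_inference.infer_shapes(
        model, check_type=check_type, strict_mode=strict_mode
    )
```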
