Skip to content

Commit 610d9a9

Browse files
committed
Integrate Automated QDQ placement tool - part 2.3
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent 81b67dd commit 610d9a9

1 file changed

Lines changed: 203 additions & 0 deletions

File tree

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Region search inspection tool for ONNX models."""
17+
18+
import argparse
19+
import logging
20+
import sys
21+
from collections import Counter
22+
23+
import onnx
24+
import onnx_graphsurgeon as gs
25+
26+
from modelopt.onnx.logging_config import logger
27+
from modelopt.onnx.quantization.autotune.common import Region, RegionType
28+
from modelopt.onnx.quantization.autotune.insertion_points import has_quantizable_operations
29+
from modelopt.onnx.quantization.autotune.region_search import (
30+
DEFAULT_MAX_STEPS,
31+
CombinedRegionSearch,
32+
)
33+
34+
35+
def inspect_region_search(
    onnx_path: str,
    max_sequence_size: int = 10,
    include_all_regions: bool = False,
) -> list[Region]:
    """Run region search on an ONNX model and log a report of the result.

    The model is loaded from ``onnx_path``, imported into onnx_graphsurgeon,
    cleaned up and topologically sorted, and then passed to
    ``CombinedRegionSearch``. Every top-level region returned by
    ``search_regions()`` is logged and printed with ``print_tree``; afterwards
    summary statistics (region counts per type, node coverage and a histogram
    of LEAF region sizes) are logged.

    Unless ``include_all_regions`` is set, top-level regions and their direct
    children are dropped when ``has_quantizable_operations`` returns a falsy
    value for them. The child filtering is applied in place: each inspected
    region's ``children`` attribute is overwritten with the filtered list.

    Args:
        onnx_path: Path to the ONNX model file.
        max_sequence_size: Forwarded to ``CombinedRegionSearch`` as
            ``maximum_sequence_region_size`` (default: 10).
        include_all_regions: Keep every region, including those for which
            ``has_quantizable_operations`` is false. Default: False.

    Returns:
        The kept top-level regions; each COMPOSITE region is immediately
        followed by its (filtered) direct children.
    """
    logger.info(f"Loading model: {onnx_path}")
    graph = gs.import_onnx(onnx.load(onnx_path))
    graph.cleanup().toposort()
    logger.info(
        f"Loaded graph: {len(graph.nodes)} nodes, {len(graph.inputs)} inputs, "
        f"{len(graph.outputs)} outputs"
    )

    logger.debug(
        f"Search parameters: max_steps={DEFAULT_MAX_STEPS}, max_sequence_size={max_sequence_size}"
    )
    combined_search = CombinedRegionSearch(graph, maximum_sequence_region_size=max_sequence_size)

    logger.info("Running region search")
    found = combined_search.search_regions()

    logger.info("Analyzing region structure")
    kept = []
    for index, region in enumerate(found):
        if include_all_regions:
            # No filtering requested: keep every child as-is.
            region.children = list(region.get_children())
        else:
            # Children are pruned before the region itself is tested, so the
            # in-place update happens even for regions that get skipped.
            region.children = [
                child for child in region.get_children() if has_quantizable_operations(child, graph)
            ]
            if not has_quantizable_operations(region, graph):
                logger.debug(f"Filtered out region {index} (no quantizable operations)")
                continue

        logger.debug(
            f"Region {index}: {region.type.value}, "
            f"{len(region.get_region_nodes_and_descendants())} nodes, "
            f"{len(region.inputs)} inputs, {len(region.outputs)} outputs"
        )
        kept.append(region)
        if region.type == RegionType.COMPOSITE:
            logger.debug(f" {len(region.get_children())} child regions")
            kept.extend(region.get_children())
        # NOTE(review): assumed to run for every kept region, not only for
        # COMPOSITE ones - confirm against the upstream file.
        combined_search.print_tree(region, indent=2)

    # Region counts per type.
    kind_tally = Counter()
    for region in kept:
        kind_tally[region.type] += 1
    leaf_total = kind_tally[RegionType.LEAF]
    composite_total = kind_tally[RegionType.COMPOSITE]

    # Node coverage: a set, so nodes shared between a region and its listed
    # children are only counted once.
    covered = set()
    for region in kept:
        covered.update(region.get_region_nodes_and_descendants())
    covered_total = len(covered)
    if graph.nodes:
        coverage_pct = 100 * covered_total / len(graph.nodes)
    else:
        coverage_pct = 0

    logger.info(
        f"Summary: {len(kept)} regions ({leaf_total} LEAF, {composite_total} COMPOSITE), "
        f"{covered_total}/{len(graph.nodes)} nodes ({coverage_pct:.1f}%)"
    )

    # Histogram of LEAF region sizes; nothing to report when there are none.
    leaf_only = [region for region in kept if region.type == RegionType.LEAF]
    sizes = [len(region.get_region_nodes_and_descendants()) for region in leaf_only]
    if not sizes:
        return kept

    smallest = min(sizes)
    largest = max(sizes)
    mean_size = sum(sizes) / len(sizes)
    logger.info(f"LEAF region sizes: min={smallest}, max={largest}, avg={mean_size:.1f}")

    logger.debug("Size distribution:")
    for size, count in sorted(Counter(sizes).items()):
        # Cap the bar at 50 glyphs so very common sizes stay on one line.
        bar = "█" * min(count, 50)
        logger.debug(f" {size:4d} nodes: {bar} ({count} regions)")

    return kept
145+
146+
147+
def main(argv=None) -> int:
    """Command-line entry point for region search inspection.

    Parses the command line, configures logging, and runs
    ``inspect_region_search`` with the parsed options.

    Args:
        argv: Optional list of argument strings, without the program name.
            When ``None`` (the default), argparse reads ``sys.argv[1:]``, so a
            plain ``main()`` call behaves exactly as before. Passing a list
            allows the tool to be invoked programmatically.

    Returns:
        0 if the inspection completed, 1 if it raised an exception.

    Raises:
        SystemExit: Raised by argparse for ``--help`` (code 0) or for invalid
            or missing arguments (code 2).
    """
    parser = argparse.ArgumentParser(
        prog="modelopt.onnx.quantization.autotune.region_inspect",
        description="Inspect region search results for ONNX models",
        # Raw formatter keeps the hand-formatted examples in the epilog intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic inspection
  python -m modelopt.onnx.quantization.autotune.region_inspect --model model.onnx

  # Verbose mode for debug logging
  python -m modelopt.onnx.quantization.autotune.region_inspect \\
      --model model.onnx --verbose

  # Custom maximum sequence size
  python -m modelopt.onnx.quantization.autotune.region_inspect \\
      --model model.onnx --max-sequence-size 20
""",
    )

    parser.add_argument("--model", "-m", type=str, required=True, help="Path to ONNX model file")
    parser.add_argument(
        "--max-sequence-size",
        type=int,
        default=10,
        help="Maximum size for sequence regions during refinement (default: 10)",
    )
    parser.add_argument(
        "--include-all-regions",
        action="store_true",
        help="Include all regions, even those without major quantizable operations. "
        "Default: False (skips such regions)",
    )
    parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug logging")

    # parse_args(None) falls back to sys.argv[1:], preserving the CLI behaviour.
    args = parser.parse_args(argv)

    log_level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s")
    logger.setLevel(log_level)

    # Top-level boundary: report any failure as a non-zero exit status rather
    # than a traceback, unless --verbose asked for the full details.
    try:
        regions = inspect_region_search(
            onnx_path=args.model,
            max_sequence_size=args.max_sequence_size,
            include_all_regions=args.include_all_regions,
        )
        logger.info(f"✓ Inspection complete: {len(regions)} regions discovered")
        return 0
    except Exception as e:
        logger.error(f"Inspection failed: {e}", exc_info=args.verbose)
        return 1
200+
201+
202+
if __name__ == "__main__":
    # Hand main()'s status code to the interpreter as the process exit code.
    exit_status = main()
    sys.exit(exit_status)

0 commit comments

Comments
 (0)