Skip to content

Commit 8ed7250

Browse files
add unit test for checking any leak of temporary augmented onnx files during onnx int4 awq quantization
Signed-off-by: vipandya <vipandya@nvidia.com>
1 parent 9d2e608 commit 8ed7250

1 file changed

Lines changed: 56 additions & 0 deletions

File tree

tests/unit/onnx/quantization/test_quantize_zint4.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
# limitations under the License.
1515

1616
import os
17+
import tempfile as _tempfile
1718
from collections.abc import Sequence
1819

1920
import numpy as np
2021
import onnx
2122
import onnx_graphsurgeon as gs
23+
import pytest
2224
from _test_utils.onnx.lib_test_models import find_init
2325

2426
import modelopt.onnx.quantization as moq
@@ -225,3 +227,57 @@ def is_quant_scale_with_right_shape(model, quant_axis, block_size):
225227
)
226228

227229
# Ensure above tests pass.
230+
231+
232+
@pytest.mark.parametrize("calibration_method", ["awq_lite", "awq_clip"])
@pytest.mark.parametrize("use_external_data_format", [True, False])
def test_awq_no_temp_file_leak(tmp_path, monkeypatch, calibration_method, use_external_data_format):
    """Check that temp files created via mkstemp are removed when quantization fails.

    The test replaces ``create_inference_session`` in the int4 quantization
    module with a stub that raises ``RuntimeError``, so ``quantize_int4`` is
    expected to abort with that error. This is meant to mimic a mid-run
    failure (e.g. OOM, bad EP, driver error) occurring after the augmented
    ONNX model has been written to a temporary file.

    Instead of scanning the temp directory, the test wraps ``mkstemp`` and
    records every path it hands out while the test runs. Only those recorded
    paths (plus their ``_data`` siblings when external data is enabled) are
    checked, so unrelated files in the temp directory cannot affect the result.
    """
    weights = np.random.rand(288, 16).astype(np.float32)
    onnx_path = _matmul_model(
        w=weights,
        in_shape=(96, 288),
        out_shape=(96, 16),
        tmp_path=tmp_path,
    )

    # Wrap mkstemp so every temp path handed out during the run is remembered.
    recorded_paths = []
    original_mkstemp = _tempfile.mkstemp

    def _recording_mkstemp(*args, **kwargs):
        handle, temp_name = original_mkstemp(*args, **kwargs)
        recorded_paths.append(temp_name)
        return handle, temp_name

    def _fail_session_creation(*args, **kwargs):
        raise RuntimeError("injected ORT session failure")

    monkeypatch.setattr("modelopt.onnx.quantization.int4.tempfile.mkstemp", _recording_mkstemp)
    # Force the failure at session creation time.
    monkeypatch.setattr(
        "modelopt.onnx.quantization.int4.create_inference_session",
        _fail_session_creation,
    )

    with pytest.raises(RuntimeError, match="injected ORT session failure"):
        quantize_int4(
            onnx_path,
            calibration_method=calibration_method,
            use_external_data_format=use_external_data_format,
            block_size=8,
        )

    # An empty record would make the leak checks below vacuous, so guard against it.
    assert recorded_paths, "Expected mkstemp to be called but it was not"
    for leftover in recorded_paths:
        assert not os.path.exists(leftover), f"Leaked: {leftover}"
        if not use_external_data_format:
            continue
        # External-data runs may also leave a "<path>_data" companion file.
        assert not os.path.exists(leftover + "_data"), f"Leaked: {leftover}_data"

0 commit comments

Comments
 (0)