|
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | 16 | import os |
| 17 | +import tempfile as _tempfile |
17 | 18 | from collections.abc import Sequence |
18 | 19 |
|
19 | 20 | import numpy as np |
20 | 21 | import onnx |
21 | 22 | import onnx_graphsurgeon as gs |
| 23 | +import pytest |
22 | 24 | from _test_utils.onnx.lib_test_models import find_init |
23 | 25 |
|
24 | 26 | import modelopt.onnx.quantization as moq |
@@ -225,3 +227,57 @@ def is_quant_scale_with_right_shape(model, quant_axis, block_size): |
225 | 227 | ) |
226 | 228 |
|
227 | 229 | # Ensure above tests pass. |
| 230 | + |
| 231 | + |
@pytest.mark.parametrize("calibration_method", ["awq_lite", "awq_clip"])
@pytest.mark.parametrize("use_external_data_format", [True, False])
def test_awq_no_temp_file_leak(tmp_path, monkeypatch, calibration_method, use_external_data_format):
    """Check that temp files made during AWQ quantization are removed when it fails mid-run.

    The test builds a small MatMul model, wraps ``mkstemp`` so every path it hands
    out is recorded, and replaces ``create_inference_session`` in the int4 module
    with a stub that raises ``RuntimeError``. ``quantize_int4`` is then expected to
    propagate that error, after which none of the recorded paths (nor the matching
    ``<path>_data`` file when external data format is enabled) may still exist.

    Per the original author's note, the injected failure stands in for real-world
    errors (OOM, bad EP, driver error) at ORT session creation, which is said to
    happen after the augmented ONNX model has been written to disk -- this ordering
    is not visible from this test and should be confirmed against the int4 module.

    Only the exact paths returned by ``mkstemp`` during this test are checked,
    instead of globbing the temp directory; this is intended to keep files created
    by other, concurrently running tests from being reported as leaks.
    """
    weights = np.random.rand(288, 16).astype(np.float32)
    onnx_path = _matmul_model(
        w=weights,
        in_shape=(96, 288),
        out_shape=(96, 16),
        tmp_path=tmp_path,
    )

    # Wrap mkstemp so that each temp path handed out while the patch is active is remembered.
    # NOTE(review): the patch target goes through ``int4.tempfile``; presumably that is the
    # stdlib ``tempfile`` module itself, in which case mkstemp is replaced process-wide for
    # the duration of the test -- confirm.
    recorded_paths = []
    original_mkstemp = _tempfile.mkstemp

    def _recording_mkstemp(*args, **kwargs):
        result = original_mkstemp(*args, **kwargs)
        recorded_paths.append(result[1])  # result is (fd, path)
        return result

    monkeypatch.setattr("modelopt.onnx.quantization.int4.tempfile.mkstemp", _recording_mkstemp)

    # Make session creation fail so quantization aborts part-way through.
    def _fail_session_creation(*args, **kwargs):
        raise RuntimeError("injected ORT session failure")

    monkeypatch.setattr(
        "modelopt.onnx.quantization.int4.create_inference_session",
        _fail_session_creation,
    )

    quantize_kwargs = {
        "calibration_method": calibration_method,
        "use_external_data_format": use_external_data_format,
        "block_size": 8,
    }
    with pytest.raises(RuntimeError, match="injected ORT session failure"):
        quantize_int4(onnx_path, **quantize_kwargs)

    # Guard against a vacuous pass: the loop below proves nothing if no path was recorded.
    assert recorded_paths, "Expected mkstemp to be called but it was not"
    for temp_file in recorded_paths:
        assert not os.path.exists(temp_file), f"Leaked: {temp_file}"
        if not use_external_data_format:
            continue
        # The external-data sidecar is expected next to the model as "<path>_data".
        assert not os.path.exists(temp_file + "_data"), f"Leaked: {temp_file}_data"
0 commit comments