Commit e5ce0ae

[NVBug 6102977] Add _disable_use_cache context manager to fix PTQ AttributeError on custom configs (#1324)
### What does this PR do?

**Type of change:** Bug fix

**Summary:**

- Running `hf_ptq.py` on stepfun-ai/Step-3.5-Flash (and any model whose custom HF config doesn't assign `use_cache`) crashed in `get_max_batch_size()` with `AttributeError: 'Step3p5Config' object has no attribute 'use_cache'` before calibration could start.
- Extract the existing "disable KV cache during calibration" logic into a `_disable_use_cache(model)` context manager, and apply it to both `get_max_batch_size` and `_forward_loop`. The CM sets `config.use_cache = False` unconditionally (not only when the attribute exists) and restores the prior value on exit if one was set.
- Behavior is unchanged for normal configs; the NemotronH hybrid-cache correctness guarantee from #1251 is preserved.

### Usage

```python
# Add a code snippet demonstrating how to use this
```

(An illustrative sketch follows this description.)

### Testing

Step-3.5-Flash PTQ now passes `get_max_batch_size`.

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ / ❌ / N/A
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: ✅ / ❌ / N/A
- Did you write any new necessary tests?: ✅ / ❌ / N/A
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅ / ❌ / N/A

## Summary by CodeRabbit

* **Refactor**
  * Improved memory handling during model evaluation and calibration by consistently disabling the KV cache for both single-batch probes and full dataloader runs, simplifying and stabilizing the inference flow and ensuring cache state is managed reliably.
* **Tests**
  * Added unit tests verifying cache-state handling across models with and without cache settings, including correct restoration behavior even when errors occur.

---

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
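As an illustration, a minimal sketch of the behavior `_disable_use_cache` guarantees (mirroring the unit tests below), using a hypothetical bare config class to mimic a custom HF config that never assigns `use_cache`:

```python
import torch

from modelopt.torch.utils.dataset_utils import _disable_use_cache


class _BareConfig:
    """Hypothetical stand-in for a custom HF config without `use_cache`."""


model = torch.nn.Linear(4, 4)
model.config = _BareConfig()

with _disable_use_cache(model):
    # Inside the block the attribute is guaranteed to exist and be False,
    # so forward code that reads `model.config.use_cache` cannot raise.
    assert model.config.use_cache is False

# The attribute did not exist beforehand, so it is removed again on exit,
# leaving the config exactly as it was found.
assert not hasattr(model.config, "use_cache")
```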
1 parent 8eec6d4 commit e5ce0ae

2 files changed: 161 additions & 55 deletions

modelopt/torch/utils/dataset_utils.py (74 additions & 54 deletions)

```diff
@@ -18,7 +18,8 @@
 import copy
 import json
 import os
-from collections.abc import Callable
+from collections.abc import Callable, Iterator
+from contextlib import contextmanager, suppress
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 from warnings import warn
@@ -437,6 +438,36 @@ def get_supported_datasets() -> list[str]:
     return list(SUPPORTED_DATASET_CONFIG.keys())
 
 
+@contextmanager
+def _disable_use_cache(model: torch.nn.Module) -> Iterator[None]:
+    """Set ``model.config.use_cache = False`` for the duration of the block.
+
+    KV caching is unwanted during calibration / memory-probe forward passes:
+    it wastes memory, and for hybrid Mamba/attention models (e.g., NemotronH)
+    the cache state is mutated in-place and breaks correctness. Setting
+    ``use_cache`` unconditionally (rather than only when it was already
+    present) also sidesteps configs that never assign the attribute at all
+    — e.g., ``Step3p5Config`` from stepfun-ai/Step-3.5-Flash — where forward
+    code that reads ``self.config.use_cache`` would otherwise raise
+    ``AttributeError``. The prior value is restored on exit if one existed.
+    """
+    config = getattr(model, "config", None)
+    if config is None:
+        yield
+        return
+    had_attr = hasattr(config, "use_cache")
+    prev = config.use_cache if had_attr else None
+    config.use_cache = False
+    try:
+        yield
+    finally:
+        if had_attr:
+            config.use_cache = prev
+        else:
+            with suppress(AttributeError):
+                delattr(config, "use_cache")
+
+
 def get_max_batch_size(
     model: torch.nn.Module,
     max_sample_length: int = 512,
@@ -467,42 +498,43 @@ def _get_free_gpu_mem():
         torch.ones([1, max_sample_length], dtype=torch.int32, device=model.device) * 100
     )
 
-    # Calculate single batch inference with dummy input.
-    with torch.set_grad_enabled(enable_grad):
-        infer_method(sample_input_single_batch)
-    free_mem_after, max_allocated_after = _get_free_gpu_mem()
+    with _disable_use_cache(model):
+        # Calculate single batch inference with dummy input.
+        with torch.set_grad_enabled(enable_grad):
+            infer_method(sample_input_single_batch)
+        free_mem_after, max_allocated_after = _get_free_gpu_mem()
 
-    mem_diff_per_data_batch = (
-        max(
-            (free_mem_before - free_mem_after),
-            (max_allocated_after - max_allocated_before),
+        mem_diff_per_data_batch = (
+            max(
+                (free_mem_before - free_mem_after),
+                (max_allocated_after - max_allocated_before),
+            )
+            * sample_memory_usage_ratio
         )
-        * sample_memory_usage_ratio
-    )
-    if mem_diff_per_data_batch <= 0:
-        print(
-            "Warning: No measurable memory usage found for a single batch. "
-            "Falling back to batch_size=1."
+        if mem_diff_per_data_batch <= 0:  # pragma: no cover - GPU memory probe edge case
+            print(  # pragma: no cover
+                "Warning: No measurable memory usage found for a single batch. "
+                "Falling back to batch_size=1."
+            )
+            target_data_batch = 1  # pragma: no cover
+        else:
+            target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
+        target_input = sample_input_single_batch.expand(
+            [
+                target_data_batch if index == 0 else dim
+                for index, dim in enumerate(sample_input_single_batch.shape)
+            ]
         )
-        target_data_batch = 1
-    else:
-        target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
-    target_input = sample_input_single_batch.expand(
-        [
-            target_data_batch if index == 0 else dim
-            for index, dim in enumerate(sample_input_single_batch.shape)
-        ]
-    )
 
-    # For some models on multi GPU, we observe the memory per batch is not a constant.
-    # So we just test the target batch size and make sure we do not go OOM.
-    while target_data_batch > 1:
-        with torch.set_grad_enabled(enable_grad):
-            try:
-                infer_method(target_input)
-                break
-            except torch.cuda.OutOfMemoryError:
-                target_data_batch = target_data_batch // 2
+        # For some models on multi GPU, we observe the memory per batch is not a constant.
+        # So we just test the target batch size and make sure we do not go OOM.
+        while target_data_batch > 1:
+            with torch.set_grad_enabled(enable_grad):
+                try:
+                    infer_method(target_input)
+                    break
+                except torch.cuda.OutOfMemoryError:  # pragma: no cover - GPU OOM retry path
+                    target_data_batch = target_data_batch // 2  # pragma: no cover
 
     # Regulate the data batch target to be 1, 2, 4, 8, 12, ..., capped at 64
     if target_data_batch < 2:
@@ -601,28 +633,16 @@ def _forward_loop(
         dataloader: DataLoader containing the batched input data
        allowed_non_tensor_keys: Set of key names whose values may be non-tensor types
     """
-    # Disable KV caching during calibration — it is unnecessary overhead and causes
-    # correctness issues with hybrid Mamba/attention models whose cache state is mutated
-    # in-place (e.g., NemotronH).
-    config = getattr(model, "config", None)
-    prev_use_cache = getattr(config, "use_cache", None)
-    if config is not None and prev_use_cache is not None:
-        config.use_cache = False
+    with _disable_use_cache(model), torch.no_grad():
+        is_enc_dec = model_type_is_enc_dec(model)
+        infer_method = model.generate if is_enc_dec else model.forward
+        max_working_batch_size = None  # Initialize max working batch size as None
 
-    try:
-        with torch.no_grad():
-            is_enc_dec = model_type_is_enc_dec(model)
-            infer_method = model.generate if is_enc_dec else model.forward
-            max_working_batch_size = None  # Initialize max working batch size as None
-
-            for _, data in enumerate(tqdm(dataloader)):
-                # Process batch and update max working batch size
-                max_working_batch_size = _process_batch(
-                    data, infer_method, max_working_batch_size, allowed_non_tensor_keys
-                )
-    finally:
-        if config is not None and prev_use_cache is not None:
-            config.use_cache = prev_use_cache
+        for _, data in enumerate(tqdm(dataloader)):
+            # Process batch and update max working batch size
+            max_working_batch_size = _process_batch(
+                data, infer_method, max_working_batch_size, allowed_non_tensor_keys
+            )
 
 
 def create_forward_loop(
```
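To make the probe arithmetic in `get_max_batch_size` concrete, a toy walk-through with made-up numbers (illustrative only; real values come from the `_get_free_gpu_mem()` probe shown above):

```python
# Illustrative figures: the single-batch probe cost ~2 GiB (after scaling by
# sample_memory_usage_ratio) and 30 GiB were free before the probe ran.
free_mem_before = 30 * 2**30
mem_diff_per_data_batch = 2 * 2**30

# Initial estimate, exactly as computed in the diff above:
target_data_batch = max(int(free_mem_before / mem_diff_per_data_batch), 1)
assert target_data_batch == 15

# The estimate is then trial-run and halved on CUDA OOM, and finally snapped
# onto the 1, 2, 4, 8, 12, ... grid capped at 64 (per the code's comment).
```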

tests/unit/torch/utils/test_dataset_utils.py (87 additions & 1 deletion)

```diff
@@ -17,8 +17,14 @@
 
 import pytest
 import torch
+from torch.utils.data import DataLoader
 
-from modelopt.torch.utils.dataset_utils import _process_batch, get_dataset_samples
+from modelopt.torch.utils.dataset_utils import (
+    _disable_use_cache,
+    _forward_loop,
+    _process_batch,
+    get_dataset_samples,
+)
 
 
 def setup_test_data():
@@ -145,6 +151,86 @@ def mock_infer(**kwargs):
     _process_batch(batch_data, mock_infer, allowed_non_tensor_keys={"base_model_outputs"})
 
 
+class _Config:
+    """Minimal config stand-in; instances start with no `use_cache` attribute."""
+
+
+def test_disable_use_cache_no_config_attr():
+    """Model without a `config` attribute: CM is a no-op and does not raise."""
+    model = torch.nn.Linear(4, 4)
+    assert not hasattr(model, "config")
+
+    with _disable_use_cache(model):
+        assert not hasattr(model, "config")
+
+    assert not hasattr(model, "config")
+
+
+@pytest.mark.parametrize("prev_value", [True, False])
+def test_disable_use_cache_with_existing_attr(prev_value):
+    """Config that already has `use_cache`: forced to False inside, restored on exit."""
+    model = torch.nn.Linear(4, 4)
+    model.config = _Config()
+    model.config.use_cache = prev_value
+
+    with _disable_use_cache(model):
+        assert model.config.use_cache is False
+
+    assert model.config.use_cache is prev_value
+
+
+def test_disable_use_cache_without_existing_attr():
+    """Config that lacks `use_cache`: set to False inside, attribute removed on exit (no leak)."""
+    model = torch.nn.Linear(4, 4)
+    model.config = _Config()
+    assert not hasattr(model.config, "use_cache")
+
+    with _disable_use_cache(model):
+        assert model.config.use_cache is False
+
+    assert not hasattr(model.config, "use_cache")
+
+
+def test_forward_loop_runs_under_disabled_use_cache():
+    """`_forward_loop` runs forward on every batch and restores `use_cache` on exit."""
+    seen_use_cache: list[bool] = []
+
+    class _Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.config = _Config()
+            self.config.use_cache = True
+
+        def forward(self, **kwargs):
+            seen_use_cache.append(self.config.use_cache)
+
+    model = _Model()
+
+    def _collate(samples):
+        return {"input_ids": torch.stack([s["input_ids"] for s in samples])}
+
+    data = [{"input_ids": torch.zeros(8, dtype=torch.long)} for _ in range(3)]
+    loader = DataLoader(data, batch_size=1, collate_fn=_collate)
+
+    _forward_loop(model, loader)
+
+    assert seen_use_cache == [False, False, False]
+    assert model.config.use_cache is True
+
+
+def test_disable_use_cache_restores_on_exception():
+    """Restore must run even if the with-block raises."""
+    model = torch.nn.Linear(4, 4)
+    model.config = _Config()
+    model.config.use_cache = True
+
+    with pytest.raises(RuntimeError, match="boom"), _disable_use_cache(model):
+        assert model.config.use_cache is False
+        raise RuntimeError("boom")
+
+    assert model.config.use_cache is True
+
+
 @pytest.mark.parametrize("test_local_path", [True, False])
 def test_get_dataset_samples_with_unsupported_minipile_dataset(tmp_path, test_local_path):
     pytest.importorskip("datasets")
```
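The new tests cover every path through the context manager: no `config` attribute, an existing `use_cache` value (restored), a missing `use_cache` attribute (removed again, no leak), restoration on exception, plus an integration check that `_forward_loop` sees `use_cache=False` on every batch. Assuming a standard repository checkout, they can be run in isolation with `pytest tests/unit/torch/utils/test_dataset_utils.py -k "disable_use_cache or forward_loop"`.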
