Skip to content

Commit 7e82a5c

Browse files
[Serialization]: remove explicit weights_only default from safe_load to allow user to bypass if needed (#1279)
## Summary - Remove the `kwargs.setdefault("weights_only", True)` call from `safe_load`, deferring to torch's built-in default (which is `True` for torch>=2.6) - This allows users to override via the `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1` env var when they trust a checkpoint but hit `pickle.UnpicklingError` - Add a test that verifies the default fails on unsafe objects and the env var bypass works ## Test plan - [x] `python -m pytest tests/unit/torch/utils/test_serialization.py -v` 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Serialization utility now respects PyTorch's default behavior and environment-variable configuration instead of forcibly enforcing parameter overrides, providing greater configuration flexibility. * **Tests** * Added test coverage validating environment-variable override functionality and default behavior in the serialization utility. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3162ff0 commit 7e82a5c

2 files changed

Lines changed: 28 additions & 2 deletions

File tree

modelopt/torch/utils/serialization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ def safe_save(obj: Any, f: str | os.PathLike | BinaryIO, **kwargs) -> None:
5454

5555

5656
def safe_load(f: str | os.PathLike | BinaryIO | bytes, **kwargs) -> Any:
57-
"""Load a checkpoint securely using weights_only=True by default."""
58-
kwargs.setdefault("weights_only", True)
57+
"""Load a checkpoint securely using ``weights_only=True`` by default.
5958
59+
NOTE: We dont set default ``weights_only`` (interpret as True for torch>=2.6) so you can override it with
60+
``export TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1`` if you see ``pickle.UnpicklingError`` and trust the checkpoint.
61+
"""
6062
if isinstance(f, (bytes, bytearray)):
6163
f = BytesIO(f)
6264

tests/unit/torch/utils/test_serialization.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
"""Tests for Modelopt's serialization utilities."""
1717

1818
from io import BytesIO
19+
from pickle import UnpicklingError
1920

21+
import pytest
2022
import torch
2123

2224
from modelopt.torch.opt.config import ModeloptBaseConfig
@@ -70,3 +72,25 @@ def test_safe_load_with_path(tmp_path):
7072
loaded_state = safe_load(file_path)
7173

7274
assert loaded_state["data"] == 42
75+
76+
77+
class _UnsafeObj:
78+
"""Not registered in torch safe globals — unpickling fails with weights_only=True."""
79+
80+
def __init__(self, v):
81+
self.v = v
82+
83+
84+
def test_safe_load_env_var_bypasses_weights_only(tmp_path, monkeypatch):
85+
"""Verify TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 allows safe_load to load objects unsafe for weights_only."""
86+
file_path = tmp_path / "unsafe.pt"
87+
torch.save({"obj": _UnsafeObj(42)}, file_path)
88+
89+
# Always fails when weights_only is not set (default=True)
90+
with pytest.raises(UnpicklingError):
91+
safe_load(file_path)
92+
93+
# With the env var, safe_load (no explicit weights_only) defers to torch's default=False
94+
monkeypatch.setenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "1")
95+
loaded = safe_load(file_path)
96+
assert loaded["obj"].v == 42

0 commit comments

Comments
 (0)