Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions deepmd/pt_expt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,19 @@ def _deserialize_to_file_pt2(
data, model_json_override
)

# Compile via AOTInductor into a .pt2 package
aoti_compile_and_package(exported, package_path=model_file)
# AOTInductor's lowering code internally creates tensors (e.g.
# ``torch.zeros``) without an explicit ``device=`` argument. If a
# non-CPU default device is active (e.g. tests/pt/__init__.py sets
# ``torch.set_default_device("cuda:9999999")``), the compilation fails
# on CPU-only builds. Temporarily clear the default device so the
# inductor always targets CPU.
prev_device = torch.get_default_device()
Comment thread
wanghan-iapcm marked this conversation as resolved.
Outdated
torch.set_default_device(None)
try:
# Compile via AOTInductor into a .pt2 package
aoti_compile_and_package(exported, package_path=model_file)
finally:
torch.set_default_device(prev_device)

# Embed metadata into the .pt2 ZIP archive
model_def_script = data.get("model_def_script") or {}
Expand Down
10 changes: 10 additions & 0 deletions source/tests/pt_expt/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,21 @@
"""

import pytest
import torch._inductor.config as _inductor_config
import torch.utils._device as _device
from torch.overrides import (
_get_current_function_mode_stack,
)

# Reduce AOTInductor (.pt2) compile time for unit tests.
# Tests only validate correctness, not runtime performance, so we can
# skip expensive C++ optimizations. This cuts compile time by ~50%.
_inductor_config.max_fusion_size = 8
_inductor_config.epilogue_fusion = False
_inductor_config.pattern_matcher = False
_inductor_config.aot_inductor.package_cpp_only = True
_inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"


def _pop_device_contexts() -> list:
"""Pop all stale DeviceContext modes from the torch function mode stack."""
Expand Down
34 changes: 8 additions & 26 deletions source/tests/pt_expt/infer/test_deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,13 +543,7 @@ def setUpClass(cls) -> None:
cls.model_data = {"model": cls.model.serialize()}
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
cls.tmpfile.close()
# Temporarily clear default device to avoid poisoning AOTInductor
# compilation (tests/pt/__init__.py sets it to "cuda:9999999").
torch.set_default_device(None)
try:
deserialize_to_file(cls.tmpfile.name, cls.model_data)
finally:
torch.set_default_device("cuda:9999999")
deserialize_to_file(cls.tmpfile.name, cls.model_data)

# Also save to .pte for cross-format comparison
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
Expand Down Expand Up @@ -606,15 +600,11 @@ def test_get_model_def_script_with_params(self) -> None:
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
tmpfile2 = f.name
try:
torch.set_default_device(None)
try:
data_with_config = {
**self.model_data,
"model_def_script": training_config,
}
deserialize_to_file(tmpfile2, data_with_config)
finally:
torch.set_default_device("cuda:9999999")
data_with_config = {
**self.model_data,
"model_def_script": training_config,
}
deserialize_to_file(tmpfile2, data_with_config)
dp2 = DeepPot(tmpfile2)
mds = dp2.deep_eval.get_model_def_script()
self.assertEqual(mds, training_config)
Expand Down Expand Up @@ -970,11 +960,7 @@ def setUpClass(cls) -> None:
cls.model_data = {"model": cls.model.serialize()}
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
cls.tmpfile.close()
torch.set_default_device(None)
try:
deserialize_to_file(cls.tmpfile.name, cls.model_data)
finally:
torch.set_default_device("cuda:9999999")
deserialize_to_file(cls.tmpfile.name, cls.model_data)

# Also save .pte for cross-format comparison
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
Expand Down Expand Up @@ -1185,11 +1171,7 @@ def setUpClass(cls) -> None:
cls.model_data = {"model": cls.model.serialize()}
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
cls.tmpfile.close()
torch.set_default_device(None)
try:
deserialize_to_file(cls.tmpfile.name, cls.model_data)
finally:
torch.set_default_device("cuda:9999999")
deserialize_to_file(cls.tmpfile.name, cls.model_data)

# Also save .pte for cross-format comparison
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
Expand Down
23 changes: 3 additions & 20 deletions source/tests/pt_expt/infer/test_deep_eval_spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,7 @@ def spin_model_files():
tmpdir = tempfile.mkdtemp()
for ext in (".pt2", ".pte"):
path = os.path.join(tmpdir, f"spin_test{ext}")
# AOTInductor (.pt2) internally creates tensors using the PyTorch
# default device. Clear it so compilation stays on CPU.
prev = torch.get_default_device()
torch.set_default_device(None)
try:
deserialize_to_file(path, copy.deepcopy(data))
finally:
torch.set_default_device(prev)
deserialize_to_file(path, copy.deepcopy(data))
files[ext] = path
yield files, ref_pbc, ref_nopbc
for path in files.values():
Expand Down Expand Up @@ -362,12 +355,7 @@ def spin_fparam_model_files():
tmpdir = tempfile.mkdtemp()
for ext in (".pt2", ".pte"):
path = os.path.join(tmpdir, f"spin_fparam_test{ext}")
prev = torch.get_default_device()
torch.set_default_device(None)
try:
deserialize_to_file(path, copy.deepcopy(data))
finally:
torch.set_default_device(prev)
deserialize_to_file(path, copy.deepcopy(data))
files[ext] = path
yield files
for path in files.values():
Expand Down Expand Up @@ -426,12 +414,7 @@ def spin_aparam_model_files():
tmpdir = tempfile.mkdtemp()
for ext in (".pt2", ".pte"):
path = os.path.join(tmpdir, f"spin_aparam_test{ext}")
prev = torch.get_default_device()
torch.set_default_device(None)
try:
deserialize_to_file(path, copy.deepcopy(data))
finally:
torch.set_default_device(prev)
deserialize_to_file(path, copy.deepcopy(data))
files[ext] = path
yield files
for path in files.values():
Expand Down
9 changes: 1 addition & 8 deletions source/tests/pt_expt/test_change_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,7 @@ def setUpClass(cls) -> None:
cls.shared_pte = os.path.join(cls.tmpdir, "shared.pte")
freeze(model=cls.model_path, output=cls.shared_pte)
cls.shared_pt2 = os.path.join(cls.tmpdir, "shared.pt2")
# Clear default device: tests/pt/__init__.py may set a fake device
# for CPU fallback, which poisons AOTInductor compilation.
saved_device = torch.get_default_device()
torch.set_default_device(None)
try:
freeze(model=cls.model_path, output=cls.shared_pt2)
finally:
torch.set_default_device(saved_device)
freeze(model=cls.model_path, output=cls.shared_pt2)

@classmethod
def tearDownClass(cls) -> None:
Expand Down
Loading