diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index 99da68fe4f..93170162c3 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -22,10 +22,11 @@ # as it's a stateless utility class register_dpmodel_mapping(EnvMat, lambda v: v) -# Register opaque deepmd_export::border_op wrapper (used by GNN MPI -# parallel inference; see comm.py module docstring). -# Register fake tensor implementations for custom tabulate ops -from deepmd.pt_expt.utils import comm # noqa: F401 +# Register fake tensor implementations for custom tabulate ops. +# comm.py (border_op fake/autograd) is NOT imported here — its +# ensure_comm_registered() is called lazily from the with_comm_dict +# export path in serialization.py to avoid eager libdeepmd_op_pt.so +# loading that breaks fake-op registration order in tests. from deepmd.pt_expt.utils import tabulate_ops # noqa: F401 __all__ = [ diff --git a/deepmd/pt_expt/utils/comm.py b/deepmd/pt_expt/utils/comm.py index 434d2a97b0..cb77c6a335 100644 --- a/deepmd/pt_expt/utils/comm.py +++ b/deepmd/pt_expt/utils/comm.py @@ -35,6 +35,8 @@ import torch +_registered: bool = False + def _check_underlying_ops_loaded() -> None: """Surface a clearer error when libdeepmd_op_pt.so isn't loaded. @@ -76,15 +78,11 @@ def _check_underlying_ops_loaded() -> None: ) -_check_underlying_ops_loaded() - - # --------------------------------------------------------------------------- # Fake (meta) impls — let make_fx / torch.export trace through. # --------------------------------------------------------------------------- -@torch.library.register_fake("deepmd_export::border_op") def _border_op_fake( sendlist: torch.Tensor, sendproc: torch.Tensor, @@ -99,7 +97,6 @@ def _border_op_fake( return torch.empty_like(g1) -@torch.library.register_fake("deepmd_export::border_op_backward") def _border_op_backward_fake( sendlist: torch.Tensor, sendproc: torch.Tensor, @@ -180,8 +177,37 @@ def _border_op_backward( ) -torch.library.register_autograd( - "deepmd_export::border_op", - _border_op_backward, - setup_context=_border_op_setup_context, -) +def ensure_comm_registered() -> None: + """Load libdeepmd_op_pt.so and register fake/autograd metadata for border_op. + + Idempotent — safe to call multiple times. Must be called before any + ``make_fx`` / ``torch.export`` trace that passes through border_op (i.e. + before the ``with_comm_dict=True`` export path in serialization.py). + + Kept lazy (not called at import time) so that merely importing + ``deepmd.pt_expt.utils`` does not force-load libdeepmd_op_pt.so and + disrupt fake-op registration order in tests that don't exercise the comm + path at all. + """ + global _registered + if _registered: + return + _check_underlying_ops_loaded() + try: + torch.library.register_fake("deepmd_export::border_op")(_border_op_fake) + except RuntimeError as e: + if "already has" not in str(e) and "already registered" not in str(e): + raise + try: + torch.library.register_fake("deepmd_export::border_op_backward")( + _border_op_backward_fake + ) + except RuntimeError as e: + if "already has" not in str(e) and "already registered" not in str(e): + raise + torch.library.register_autograd( + "deepmd_export::border_op", + _border_op_backward, + setup_context=_border_op_setup_context, + ) + _registered = True diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index d85a334493..2f99c5ca73 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -670,6 +670,14 @@ def _trace_and_export( # matter for tracing — only that they're valid tensors of the right # shape and dtype. See ``_make_comm_sample_inputs``. if with_comm_dict: + # Load libdeepmd_op_pt.so and register border_op fake/autograd + # metadata now — deferred from import time so normal utils imports + # don't force-load the op library and break fake-op ordering. + from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, + ) + + ensure_comm_registered() if not _needs_with_comm_artifact(model): raise ValueError( "with_comm_dict=True requested but the model's descriptor " diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py index d4d987fe95..228c6104ae 100644 --- a/source/tests/pt_expt/conftest.py +++ b/source/tests/pt_expt/conftest.py @@ -17,9 +17,10 @@ _get_current_function_mode_stack, ) -# ``deepmd.pt_expt.utils.comm`` self-bootstraps libdeepmd_op_pt.so via -# ``_check_underlying_ops_loaded()``, so we no longer need to preload -# ``deepmd.pt`` here. +# ``deepmd.pt_expt.utils.comm`` is now lazy: libdeepmd_op_pt.so is only +# loaded when ``ensure_comm_registered()`` is explicitly called from the +# with_comm_dict export path. Tests that don't exercise that path never +# load the op library, preserving fake-op registration order. def _pop_device_contexts() -> list: