Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions deepmd/pt_expt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
# as it's a stateless utility class
register_dpmodel_mapping(EnvMat, lambda v: v)

# Register opaque deepmd_export::border_op wrapper (used by GNN MPI
# parallel inference; see comm.py module docstring).
# Register fake tensor implementations for custom tabulate ops
from deepmd.pt_expt.utils import comm # noqa: F401
# Register fake tensor implementations for custom tabulate ops.
# comm.py (border_op fake/autograd) is NOT imported here — its
# ensure_comm_registered() is called lazily from the with_comm_dict
# export path in serialization.py to avoid eager libdeepmd_op_pt.so
# loading that breaks fake-op registration order in tests.
Comment on lines +26 to +29
from deepmd.pt_expt.utils import tabulate_ops # noqa: F401

__all__ = [
Expand Down
46 changes: 36 additions & 10 deletions deepmd/pt_expt/utils/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

import torch

_registered: bool = False


def _check_underlying_ops_loaded() -> None:
"""Surface a clearer error when libdeepmd_op_pt.so isn't loaded.
Expand Down Expand Up @@ -76,15 +78,11 @@
)


_check_underlying_ops_loaded()


# ---------------------------------------------------------------------------
# Fake (meta) impls — let make_fx / torch.export trace through.
# ---------------------------------------------------------------------------


@torch.library.register_fake("deepmd_export::border_op")
def _border_op_fake(
sendlist: torch.Tensor,
sendproc: torch.Tensor,
Expand All @@ -99,7 +97,6 @@
return torch.empty_like(g1)


@torch.library.register_fake("deepmd_export::border_op_backward")
def _border_op_backward_fake(
sendlist: torch.Tensor,
sendproc: torch.Tensor,
Expand Down Expand Up @@ -180,8 +177,37 @@
)


torch.library.register_autograd(
"deepmd_export::border_op",
_border_op_backward,
setup_context=_border_op_setup_context,
)
def ensure_comm_registered() -> None:
"""Load libdeepmd_op_pt.so and register fake/autograd metadata for border_op.

Idempotent — safe to call multiple times. Must be called before any
``make_fx`` / ``torch.export`` trace that passes through border_op (i.e.
before the ``with_comm_dict=True`` export path in serialization.py).

Kept lazy (not called at import time) so that merely importing
``deepmd.pt_expt.utils`` does not force-load libdeepmd_op_pt.so and
disrupt fake-op registration order in tests that don't exercise the comm
path at all.
Comment on lines +180 to +190
"""
global _registered
if _registered:
return
_check_underlying_ops_loaded()
try:
torch.library.register_fake("deepmd_export::border_op")(_border_op_fake)
except RuntimeError as e:
if "already has" not in str(e) and "already registered" not in str(e):
raise
try:
torch.library.register_fake("deepmd_export::border_op_backward")(
_border_op_backward_fake
)
except RuntimeError as e:
if "already has" not in str(e) and "already registered" not in str(e):
raise
torch.library.register_autograd(
"deepmd_export::border_op",
_border_op_backward,
setup_context=_border_op_setup_context,
)
Comment on lines +208 to +212
_registered = True

Check notice

Code scanning / CodeQL

Unused global variable Note

The global variable '_registered' is not used.
Comment on lines +192 to +213
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Make the lazy-registration guard atomic.

ensure_comm_registered() is documented as idempotent, but two callers can both pass Line 193 before _registered flips and then race through the global registration path. Please protect the block with a module-level lock and re-check _registered inside it.

Suggested fix
+import threading
+
 import torch
 
 _registered: bool = False
+_register_lock = threading.Lock()
 ...
 def ensure_comm_registered() -> None:
     ...
     global _registered
     if _registered:
         return
-    _check_underlying_ops_loaded()
-    try:
-        torch.library.register_fake("deepmd_export::border_op")(_border_op_fake)
-    except RuntimeError as e:
-        if "already has" not in str(e) and "already registered" not in str(e):
-            raise
-    try:
-        torch.library.register_fake("deepmd_export::border_op_backward")(
-            _border_op_backward_fake
-        )
-    except RuntimeError as e:
-        if "already has" not in str(e) and "already registered" not in str(e):
-            raise
-    torch.library.register_autograd(
-        "deepmd_export::border_op",
-        _border_op_backward,
-        setup_context=_border_op_setup_context,
-    )
-    _registered = True
+    with _register_lock:
+        if _registered:
+            return
+        _check_underlying_ops_loaded()
+        try:
+            torch.library.register_fake("deepmd_export::border_op")(_border_op_fake)
+        except RuntimeError as e:
+            if "already has" not in str(e) and "already registered" not in str(e):
+                raise
+        try:
+            torch.library.register_fake("deepmd_export::border_op_backward")(
+                _border_op_backward_fake
+            )
+        except RuntimeError as e:
+            if "already has" not in str(e) and "already registered" not in str(e):
+                raise
+        torch.library.register_autograd(
+            "deepmd_export::border_op",
+            _border_op_backward,
+            setup_context=_border_op_setup_context,
+        )
+        _registered = True
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
global _registered
if _registered:
return
_check_underlying_ops_loaded()
try:
torch.library.register_fake("deepmd_export::border_op")(_border_op_fake)
except RuntimeError as e:
if "already has" not in str(e) and "already registered" not in str(e):
raise
try:
torch.library.register_fake("deepmd_export::border_op_backward")(
_border_op_backward_fake
)
except RuntimeError as e:
if "already has" not in str(e) and "already registered" not in str(e):
raise
torch.library.register_autograd(
"deepmd_export::border_op",
_border_op_backward,
setup_context=_border_op_setup_context,
)
_registered = True
global _registered
if _registered:
return
with _register_lock:
if _registered:
return
_check_underlying_ops_loaded()
try:
torch.library.register_fake("deepmd_export::border_op")(_border_op_fake)
except RuntimeError as e:
if "already has" not in str(e) and "already registered" not in str(e):
raise
try:
torch.library.register_fake("deepmd_export::border_op_backward")(
_border_op_backward_fake
)
except RuntimeError as e:
if "already has" not in str(e) and "already registered" not in str(e):
raise
torch.library.register_autograd(
"deepmd_export::border_op",
_border_op_backward,
setup_context=_border_op_setup_context,
)
_registered = True
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/utils/comm.py` around lines 192 - 213, ensure_comm_registered
currently checks global _registered without synchronization, allowing a race;
wrap the registration block in a module-level lock (e.g., _register_lock) inside
ensure_comm_registered, acquire the lock, re-check _registered, then perform the
fake registrations and torch.library.register_autograd calls (referencing
_border_op_fake, _border_op_backward_fake, _border_op_backward, and
_border_op_setup_context) and set _registered = True before releasing the lock
to make the lazy-registration atomic and idempotent.

8 changes: 8 additions & 0 deletions deepmd/pt_expt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,14 @@ def _trace_and_export(
# matter for tracing — only that they're valid tensors of the right
# shape and dtype. See ``_make_comm_sample_inputs``.
if with_comm_dict:
# Load libdeepmd_op_pt.so and register border_op fake/autograd
# metadata now — deferred from import time so normal utils imports
# don't force-load the op library and break fake-op ordering.
from deepmd.pt_expt.utils.comm import (
ensure_comm_registered,
)

ensure_comm_registered()
if not _needs_with_comm_artifact(model):
Comment on lines 672 to 679
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Reject non-comm models before loading the comm op library.

Lines 676-678 force libdeepmd_op_pt.so to load before the code checks whether this model even needs a comm artifact. That brings back the side effect this PR is trying to avoid on invalid with_comm_dict=True calls, and it can replace the intended ValueError with an unrelated op-loading failure. Please move the _needs_with_comm_artifact(model) check ahead of ensure_comm_registered().

Suggested fix
     if with_comm_dict:
-        # Load libdeepmd_op_pt.so and register border_op fake/autograd
-        # metadata now — deferred from import time so normal utils imports
-        # don't force-load the op library and break fake-op ordering.
-        from deepmd.pt_expt.utils.comm import ensure_comm_registered
-
-        ensure_comm_registered()
         if not _needs_with_comm_artifact(model):
             raise ValueError(
                 "with_comm_dict=True requested but the model's descriptor "
                 "does not need cross-rank message passing "
                 "(has_message_passing_across_ranks() is False) — "
                 "there's nothing to compile."
             )
+        # Load libdeepmd_op_pt.so and register border_op fake/autograd
+        # metadata only for models that actually need the comm path.
+        from deepmd.pt_expt.utils.comm import ensure_comm_registered
+
+        ensure_comm_registered()
         nloc_sample = nlist_t.shape[1]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@deepmd/pt_expt/utils/serialization.py` around lines 672 - 679, The code
currently imports and calls ensure_comm_registered() before checking whether the
model actually requires a comm artifact, which can prematurely load
libdeepmd_op_pt.so; move the _needs_with_comm_artifact(model) check before
importing/calling ensure_comm_registered so you reject non-comm models first.
Concretely, evaluate if not _needs_with_comm_artifact(model) and raise the
existing ValueError (or return) before executing the from
deepmd.pt_expt.utils.comm import ensure_comm_registered and
ensure_comm_registered() calls; keep the import/call only in the branch where
_needs_with_comm_artifact(model) is True.

raise ValueError(
Comment on lines 672 to 682
"with_comm_dict=True requested but the model's descriptor "
Expand Down
7 changes: 4 additions & 3 deletions source/tests/pt_expt/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
_get_current_function_mode_stack,
)

# ``deepmd.pt_expt.utils.comm`` self-bootstraps libdeepmd_op_pt.so via
# ``_check_underlying_ops_loaded()``, so we no longer need to preload
# ``deepmd.pt`` here.
# ``deepmd.pt_expt.utils.comm`` is now lazy: libdeepmd_op_pt.so is only
# loaded when ``ensure_comm_registered()`` is explicitly called from the
# with_comm_dict export path. Tests that don't exercise that path never
# load the op library, preserving fake-op registration order.


def _pop_device_contexts() -> list:
Expand Down
Loading