Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions deepmd/pt_expt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,22 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict:
# (per-layer ghost-feature MPI exchange via deepmd_export::border_op).
# The C++ DeepPotPTExpt / DeepSpinPTExpt loaders branch on this flag.
meta["has_comm_artifact"] = _needs_with_comm_artifact(model)
# Whether the model's regular .pt2 graph consumes the ``mapping``
# tensor to gather per-layer ghost-atom features from local atoms.
# Mirrors the descriptor's ``has_message_passing()`` API: True for
# any message-passing descriptor (DPA2, DPA3, hybrids over those);
# False for non-message-passing descriptors (se_e2_a, DPA1, etc.).
# The C++ side gates its fail-fast on this — an absent mapping is
# fatal only for models that would silently corrupt ghost features
# otherwise.
desc = getattr(getattr(model, "atomic_model", None), "descriptor", None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking: this currently derives has_message_passing only from model.atomic_model.descriptor.has_message_passing(). That is fine for the normal DPA2/DPA3 export path, but it may under-report for future/alternate wrappers where the top-level model (or atomic_model) exposes has_message_passing() without a directly exposed descriptor. Would it be safer to first try model.has_message_passing() / model.atomic_model.has_message_passing() and only then fall back to atomic_model.descriptor.has_message_passing()?

— OpenClaw 2026.5.12 (model: custom-chat-jinzhezeng-group/gpt-5.5)

if desc is not None and hasattr(desc, "has_message_passing"):
try:
meta["has_message_passing"] = bool(desc.has_message_passing())
except (AttributeError, NotImplementedError):
meta["has_message_passing"] = False
else:
meta["has_message_passing"] = False
return meta


Expand Down
9 changes: 9 additions & 0 deletions source/api_cc/include/DeepPotPTExpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,15 @@ class DeepPotPTExpt : public DeepPotBackend {
// passing. ``with_comm_tempfile_`` owns the extracted nested .pt2
// for the lifetime of ``with_comm_loader``.
bool has_comm_artifact_ = false;
// Whether the regular .pt2 graph consumes the mapping tensor for
// ghost-feature gather (true for any message-passing descriptor:
// DPA2/DPA3/hybrids; false for se_e2_a/DPA1/etc.). Mirrors the
// descriptor's ``has_message_passing()`` API; read from the
// ``has_message_passing`` metadata field. Defaults to false for
// older .pt2 archives exported before this field existed, so non-GNN
// archives continue to work; GNN archives must be regenerated to opt
// into the fail-fast guard against the silent-corruption bug.
bool has_message_passing_ = false;
std::unique_ptr<deepmd::ptexpt::TempFile> with_comm_tempfile_;
std::unique_ptr<torch::inductor::AOTIModelPackageLoader> with_comm_loader;

Expand Down
3 changes: 3 additions & 0 deletions source/api_cc/include/DeepSpinPTExpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ class DeepSpinPTExpt : public DeepSpinBackend {
std::unique_ptr<torch::inductor::AOTIModelPackageLoader> loader;
// Optional with-comm artifact for multi-rank GNN spin inference.
bool has_comm_artifact_ = false;
// Mirrors descriptor's has_message_passing(). See DeepPotPTExpt.h
// for the full rationale and gating role.
bool has_message_passing_ = false;
std::unique_ptr<deepmd::ptexpt::TempFile> with_comm_tempfile_;
std::unique_ptr<torch::inductor::AOTIModelPackageLoader> with_comm_loader;

Expand Down
74 changes: 66 additions & 8 deletions source/api_cc/src/DeepPotPTExpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ void DeepPotPTExpt::init(const std::string& model,
// exchange and producing wrong results.
has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
metadata["has_comm_artifact"].as_bool();
// Whether the regular .pt2 graph consumes ``mapping`` for ghost-atom
// feature gather. Mirrors the descriptor's ``has_message_passing()``
// API: true for message-passing descriptors (DPA2, DPA3, hybrids
// over those), false for non-message-passing descriptors (se_e2_a,
// DPA1, etc.). Pre-PR .pt2 archives lack this field; default to
// false so they retain their previous behaviour (non-GNN archives
// continue to work; GNN archives that had the original
// silent-corruption bug must be regenerated to opt into the fail-
// fast guard). All in-tree fixtures are regenerated by the gen
// scripts and carry the explicit value.
has_message_passing_ = metadata.obj_val.count("has_message_passing") &&
metadata["has_message_passing"].as_bool();
if (has_comm_artifact_) {
try {
// Extract the nested ``extra/forward_lower_with_comm.pt2`` into a
Expand Down Expand Up @@ -353,6 +365,49 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);

// Dispatch decision: use the with-comm artifact when LAMMPS is running
// multi-rank. ``lmp_list.nswap > 0`` is the proxy for "multi-rank with
// cross-domain communication"; in single-rank LAMMPS (processors 1 1 1,
// including with PBC) the C++ side sees nswap == 0. api_cc is not
// linked against MPI directly, so we cannot call MPI_Comm_size; the
// proxy is set by LAMMPS's CommBrick at setup time.
//
// The regular artifact uses ``mapping`` to gather ghost-atom features
// from local-atom embeddings (``index_select(node_ebd[1, nloc, dim],
// mapping)``). Identity-mapping for ghost slots is silently wrong,
// so fail-fast when the regular path would be taken without a real
// mapping — applies uniformly to every caller (LAMMPS pair, ctest
// fixtures, direct C++ API users). Callers that want the regular
// path must populate ``lmp_list.mapping``.
bool multi_rank = (lmp_list.nswap > 0);
bool atom_map_present = (lmp_list.mapping != nullptr);
bool use_with_comm = has_comm_artifact_ && multi_rank;
// Fail-fast conditions:
// - ``has_message_passing_``: only models whose regular graph
// actually consumes ``mapping`` for ghost-feature gather can be
// silently corrupted by an absent mapping. Skip for non-GNN
// models (se_e2_a, DPA1, ...).
// - ``nghost > 0``: with no ghost atoms, identity mapping over
// [0, nloc) is trivially correct.
if (has_message_passing_ && !use_with_comm && !atom_map_present &&
nghost > 0) {
Comment on lines +392 to +393
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Fail fast for multi-rank regular path regardless of atom-map presence

The new predicate only throws when !atom_map_present, so a multi-rank caller that does provide InputNlist.mapping can bypass this guard and still run the regular artifact (use_with_comm == false). In multi-rank runs, mapping lookups can validly resolve to ghost indices (>= nlocal), while the regular message-passing path gathers from local-only embeddings, which can still produce out-of-bounds indexing or corrupted forces. This leaves the original corruption class unblocked for a reachable configuration; the multi-rank fail-fast should not depend on whether a mapping pointer exists.

Useful? React with 👍 / 👎.

if (multi_rank) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS .pt2 inference requires the model to be "
"exported with `use_loc_mapping=False`, which compiles a "
"with-comm artifact for cross-rank ghost-feature exchange. "
"Re-export the model with use_loc_mapping=False and try again.");
} else {
throw deepmd::deepmd_exception(
"Single-rank LAMMPS .pt2 inference requires `atom_modify map "
"yes` in the LAMMPS input (so InputNlist.mapping is populated "
"from the LAMMPS atom-map). The model gathers ghost-atom "
"features via this mapping; without it the C++ side has no "
"safe way to resolve ghost indices to local owners. C++ API "
"callers must set inlist.mapping explicitly before compute().");
}
}

// LAMMPS sets ago=0 on every nlist rebuild (neighbor rebuild, re-partition,
// atom exchange between subdomains), so `ago > 0` implies the cached
// mapping and nlist tensors are still valid. Rebuild only on ago==0.
Expand All @@ -372,7 +427,13 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);
} else {
// Default identity mapping for local atoms
// Identity fallback reached only when the fail-fast check above did
// not throw: on the with-comm path (where the model graph fills ghost
// features via border_op and ignores this tensor for ghost gather —
// see deepmd/pt_expt/descriptor/repflows.py::_exchange_ghosts), for
// models without message passing (has_message_passing_ == false), or
// when there are no ghost atoms (nghost == 0). Any other
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking: this comment still mentions world == nullptr and a dispatch carve-out, but the fail-fast predicate above does not actually special-case world == nullptr; a direct C++ caller with nghost > 0, no mapping, and no with-comm path will now throw too. That behavior may be exactly what we want, but the comment should match it to avoid future misreads.

— OpenClaw 2026.5.12 (model: custom-chat-jinzhezeng-group/gpt-5.5)

// path that reaches here would have been rejected by the fail-fast
// throw, so identity values are safe.
std::vector<std::int64_t> mapping(nall_real);
for (int ii = 0; ii < nall_real; ii++) {
mapping[ii] = ii;
Expand Down Expand Up @@ -428,14 +489,11 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
aparam_tensor = torch::zeros({0}, options).to(device);
}

// Phase 4 dispatch: use the with-comm artifact when LAMMPS is
// running multi-rank. ``lmp_list.nswap > 0`` is the proxy for
// "multi-rank with cross-domain communication"; in single-rank
// mode LAMMPS sets nswap=0. Falling back to the regular artifact
// for nswap=0 is correct because that artifact uses the mapping
// tensor to gather ghost embeddings from local atoms.
// ``use_with_comm`` was computed earlier alongside the fail-fast
// dispatch check. Use the with-comm artifact for the multi-rank case
// (the regular artifact uses the mapping tensor to gather ghost
// embeddings, which only works in single-rank).
std::vector<torch::Tensor> flat_outputs;
bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
if (use_with_comm && !with_comm_loader) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS requires the with-comm artifact, but it failed "
Expand Down
40 changes: 39 additions & 1 deletion source/api_cc/src/DeepSpinPTExpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ void DeepSpinPTExpt::init(const std::string& model,
// dropping the MPI exchange.
has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
metadata["has_comm_artifact"].as_bool();
// See DeepPotPTExpt::init for rationale. Defaults to false for
// pre-PR archives so they retain their previous behaviour.
has_message_passing_ = metadata.obj_val.count("has_message_passing") &&
metadata["has_message_passing"].as_bool();
if (has_comm_artifact_) {
try {
with_comm_tempfile_ = std::make_unique<deepmd::ptexpt::TempFile>(
Expand Down Expand Up @@ -372,6 +376,34 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);

// Dispatch decision: see DeepPotPTExpt.cc for the full rationale.
// Single-rank without atom-map cannot drive the regular path (no safe
// ghost→local mapping); multi-rank without a with-comm artifact cannot
// drive border_op (no inter-rank exchange tensor). Both unsupported
// combinations fail-fast for every caller.
bool multi_rank = (lmp_list.nswap > 0);
bool atom_map_present = (lmp_list.mapping != nullptr);
bool use_with_comm = has_comm_artifact_ && multi_rank;
// See DeepPotPTExpt::compute_inner for the rationale on these guards.
if (has_message_passing_ && !use_with_comm && !atom_map_present &&
nghost > 0) {
if (multi_rank) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS .pt2 inference requires the model to be "
"exported with `use_loc_mapping=False`, which compiles a "
"with-comm artifact for cross-rank ghost-feature exchange. "
"Re-export the model with use_loc_mapping=False and try again.");
} else {
throw deepmd::deepmd_exception(
"Single-rank LAMMPS .pt2 inference requires `atom_modify map "
"yes` in the LAMMPS input (so InputNlist.mapping is populated "
"from the LAMMPS atom-map). The model gathers ghost-atom "
"features via this mapping; without it the C++ side has no "
"safe way to resolve ghost indices to local owners. C++ API "
"callers must set inlist.mapping explicitly before compute().");
}
}

// LAMMPS sets ago=0 on every nlist rebuild, so ago>0 implies the cached
// mapping and nlist tensors are still valid — see DeepPotPTExpt.cc for
// the same rationale.
Expand All @@ -391,6 +423,10 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);
} else {
// Identity fallback: only reached when the fail-fast above did not
// throw, i.e. on the with-comm path (which fills ghost features via
// border_op and ignores this tensor for ghost gather), for models
// without message passing, or when nghost == 0. Other paths were
// rejected by the fail-fast above.
std::vector<std::int64_t> mapping(nall_real);
for (int ii = 0; ii < nall_real; ii++) {
mapping[ii] = ii;
Expand Down Expand Up @@ -452,8 +488,10 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
// _with_comm), so C++ supplies the same 8 comm tensors as the
// non-spin path. ``nlocal``/``nghost`` carry the real-atom counts
// (pre atom-doubling); the spin override halves them internally.
//
// ``use_with_comm`` was computed earlier alongside the fail-fast
// dispatch check.
std::vector<torch::Tensor> flat_outputs;
bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
if (use_with_comm && !with_comm_loader) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS requires the with-comm artifact, but it failed "
Expand Down
17 changes: 15 additions & 2 deletions source/lmp/tests/run_mpi_pair_deepmd_dpa3_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@
"trigger nlist rebuilds on every step (and run a small ``--nsteps`` "
"to keep wall time low while still exercising the rebuild path).",
)
parser.add_argument(
"--no-atom-map",
action="store_true",
help="When set, omit ``atom_modify map yes`` from the LAMMPS input. "
"Used by the no-atom-map fail-fast / with-comm fallback tests; "
"with this flag the C++ DeepPotPTExpt sees inlist.mapping == "
"nullptr and either fails fast (no with-comm artifact) or routes "
"to with-comm (multi-rank, with-comm artifact present).",
)
parser.add_argument(
"--null-vx",
type=float,
Expand Down Expand Up @@ -124,8 +133,12 @@
# ``atom_modify map yes`` is required when single-rank dispatch goes
# through the regular artifact of a use_loc_mapping=False .pt2: the
# C++ side needs the LAMMPS global-id->local-index map to build the
# ``mapping`` tensor. It is harmless under multi-rank.
lammps.atom_modify("map yes")
# ``mapping`` tensor. It is harmless under multi-rank. The
# ``--no-atom-map`` flag skips this line so the no-atom-map fallback
# (multi-rank with-comm path) and fail-fast (no with-comm artifact)
# branches can be exercised.
if not args.no_atom_map:
lammps.atom_modify("map yes")
lammps.neighbor("2.0 bin")
lammps.neigh_modify(f"every {args.neigh_every} delay 0 check no")
lammps.read_data(args.DATAFILE)
Expand Down
Loading
Loading