Skip to content
16 changes: 16 additions & 0 deletions deepmd/pt_expt/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,22 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict:
# (per-layer ghost-feature MPI exchange via deepmd_export::border_op).
# The C++ DeepPotPTExpt / DeepSpinPTExpt loaders branch on this flag.
meta["has_comm_artifact"] = _needs_with_comm_artifact(model)
# Whether the model's regular .pt2 graph consumes the ``mapping``
# tensor to gather per-layer ghost-atom features from local atoms.
# Mirrors the descriptor's ``has_message_passing()`` API: True for
# any message-passing descriptor (DPA2, DPA3, hybrids over those);
# False for non-message-passing descriptors (se_e2_a, DPA1, etc.).
# The C++ side gates its fail-fast on this — an absent mapping is
# fatal only for models that would silently corrupt ghost features
# otherwise.
desc = getattr(getattr(model, "atomic_model", None), "descriptor", None)
Comment thread
wanghan-iapcm marked this conversation as resolved.
Outdated
if desc is not None and hasattr(desc, "has_message_passing"):
try:
meta["has_message_passing"] = bool(desc.has_message_passing())
except (AttributeError, NotImplementedError):
meta["has_message_passing"] = False
else:
meta["has_message_passing"] = False
return meta


Expand Down
9 changes: 9 additions & 0 deletions source/api_cc/include/DeepPotPTExpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,15 @@ class DeepPotPTExpt : public DeepPotBackend {
// passing. ``with_comm_tempfile_`` owns the extracted nested .pt2
// for the lifetime of ``with_comm_loader``.
bool has_comm_artifact_ = false;
// Whether the regular .pt2 graph consumes the mapping tensor for
// ghost-feature gather (true for any message-passing descriptor:
// DPA2/DPA3/hybrids; false for se_e2_a/DPA1/etc.). Mirrors the
// descriptor's ``has_message_passing()`` API; read from the
// ``has_message_passing`` metadata field. Defaults to false for
// pre-PR .pt2 archives that lack the field so non-GNN archives
// continue to work; GNN archives must be regenerated to opt into
// the fail-fast guard against the silent-corruption bug.
bool has_message_passing_ = false;
std::unique_ptr<deepmd::ptexpt::TempFile> with_comm_tempfile_;
std::unique_ptr<torch::inductor::AOTIModelPackageLoader> with_comm_loader;

Expand Down
3 changes: 3 additions & 0 deletions source/api_cc/include/DeepSpinPTExpt.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ class DeepSpinPTExpt : public DeepSpinBackend {
std::unique_ptr<torch::inductor::AOTIModelPackageLoader> loader;
// Optional with-comm artifact for multi-rank GNN spin inference.
bool has_comm_artifact_ = false;
// Mirrors descriptor's has_message_passing(). See DeepPotPTExpt.h
// for the full rationale and gating role.
bool has_message_passing_ = false;
std::unique_ptr<deepmd::ptexpt::TempFile> with_comm_tempfile_;
std::unique_ptr<torch::inductor::AOTIModelPackageLoader> with_comm_loader;

Expand Down
74 changes: 66 additions & 8 deletions source/api_cc/src/DeepPotPTExpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ void DeepPotPTExpt::init(const std::string& model,
// exchange and producing wrong results.
has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
metadata["has_comm_artifact"].as_bool();
// Whether the regular .pt2 graph consumes ``mapping`` for ghost-atom
// feature gather. Mirrors the descriptor's ``has_message_passing()``
// API: true for message-passing descriptors (DPA2, DPA3, hybrids
// over those), false for non-message-passing descriptors (se_e2_a,
// DPA1, etc.). Pre-PR .pt2 archives lack this field; default to
// false so they retain their previous behaviour (non-GNN archives
// continue to work; GNN archives that had the original
// silent-corruption bug must be regenerated to opt into the fail-
// fast guard). All in-tree fixtures are regenerated by the gen
// scripts and carry the explicit value.
has_message_passing_ = metadata.obj_val.count("has_message_passing") &&
metadata["has_message_passing"].as_bool();
if (has_comm_artifact_) {
try {
// Extract the nested ``extra/forward_lower_with_comm.pt2`` into a
Expand Down Expand Up @@ -353,6 +365,49 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);

// Dispatch decision: use the with-comm artifact when LAMMPS is running
// multi-rank. ``lmp_list.nswap > 0`` is the proxy for "multi-rank with
// cross-domain communication"; in single-rank LAMMPS (processors 1 1 1,
// including with PBC) the C++ side sees nswap == 0. api_cc is not
// linked against MPI directly, so we cannot call MPI_Comm_size; the
// proxy is set by LAMMPS's CommBrick at setup time.
//
// The regular artifact uses ``mapping`` to gather ghost-atom features
// from local-atom embeddings (``index_select(node_ebd[1, nloc, dim],
// mapping)``). Identity-mapping for ghost slots is silently wrong,
// so fail-fast when the regular path would be taken without a real
// mapping — applies uniformly to every caller (LAMMPS pair, ctest
// fixtures, direct C++ API users). Callers that want the regular
// path must populate ``lmp_list.mapping``.
bool multi_rank = (lmp_list.nswap > 0);
bool atom_map_present = (lmp_list.mapping != nullptr);
bool use_with_comm = has_comm_artifact_ && multi_rank;
// Fail-fast conditions:
// - ``has_message_passing_``: only models whose regular graph
// actually consumes ``mapping`` for ghost-feature gather can be
// silently corrupted by an absent mapping. Skip for non-GNN
// models (se_e2_a, DPA1, ...).
// - ``nghost > 0``: with no ghost atoms, identity mapping over
// [0, nloc) is trivially correct.
if (has_message_passing_ && !use_with_comm && !atom_map_present &&
nghost > 0) {
Comment thread
wanghan-iapcm marked this conversation as resolved.
Outdated
if (multi_rank) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS .pt2 inference requires the model to be "
"exported with `use_loc_mapping=False`, which compiles a "
"with-comm artifact for cross-rank ghost-feature exchange. "
"Re-export the model with use_loc_mapping=False and try again.");
} else {
throw deepmd::deepmd_exception(
"Single-rank LAMMPS .pt2 inference requires `atom_modify map "
"yes` in the LAMMPS input (so InputNlist.mapping is populated "
"from the LAMMPS atom-map). The model gathers ghost-atom "
"features via this mapping; without it the C++ side has no "
"safe way to resolve ghost indices to local owners. C++ API "
"callers must set inlist.mapping explicitly before compute().");
}
}

// LAMMPS sets ago=0 on every nlist rebuild (neighbor rebuild, re-partition,
// atom exchange between subdomains), so `ago > 0` implies the cached
// mapping and nlist tensors are still valid. Rebuild only on ago==0.
Expand All @@ -372,7 +427,13 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);
} else {
// Default identity mapping for local atoms
// Identity fallback reached only on the with-comm path (where the
// model graph fills ghost features via border_op and ignores this
// tensor for ghost gather — see deepmd/pt_expt/descriptor/
// repflows.py::_exchange_ghosts) or for trusted direct C++ callers
// (world == nullptr, see the dispatch carve-out above). Any other
Comment thread
wanghan-iapcm marked this conversation as resolved.
Outdated
// path that reaches here would have been rejected by the fail-fast
// throw, so identity values are safe.
std::vector<std::int64_t> mapping(nall_real);
for (int ii = 0; ii < nall_real; ii++) {
mapping[ii] = ii;
Expand Down Expand Up @@ -428,14 +489,11 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener,
aparam_tensor = torch::zeros({0}, options).to(device);
}

// Phase 4 dispatch: use the with-comm artifact when LAMMPS is
// running multi-rank. ``lmp_list.nswap > 0`` is the proxy for
// "multi-rank with cross-domain communication"; in single-rank
// mode LAMMPS sets nswap=0. Falling back to the regular artifact
// for nswap=0 is correct because that artifact uses the mapping
// tensor to gather ghost embeddings from local atoms.
// ``use_with_comm`` was computed earlier alongside the fail-fast
// dispatch check. Use the with-comm artifact for the multi-rank case
// (the regular artifact uses the mapping tensor to gather ghost
// embeddings, which only works in single-rank).
std::vector<torch::Tensor> flat_outputs;
bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
if (use_with_comm && !with_comm_loader) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS requires the with-comm artifact, but it failed "
Expand Down
40 changes: 39 additions & 1 deletion source/api_cc/src/DeepSpinPTExpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ void DeepSpinPTExpt::init(const std::string& model,
// dropping the MPI exchange.
has_comm_artifact_ = metadata.obj_val.count("has_comm_artifact") &&
metadata["has_comm_artifact"].as_bool();
// See DeepPotPTExpt::init for rationale. Defaults to false for
// pre-PR archives so they retain their previous behaviour.
has_message_passing_ = metadata.obj_val.count("has_message_passing") &&
metadata["has_message_passing"].as_bool();
if (has_comm_artifact_) {
try {
with_comm_tempfile_ = std::make_unique<deepmd::ptexpt::TempFile>(
Expand Down Expand Up @@ -372,6 +376,34 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);

// Dispatch decision: see DeepPotPTExpt.cc for the full rationale.
// Single-rank without atom-map cannot drive the regular path (no safe
// ghost→local mapping); multi-rank without a with-comm artifact cannot
// drive border_op (no inter-rank exchange tensor). Both unsupported
// combinations fail-fast for every caller.
bool multi_rank = (lmp_list.nswap > 0);
bool atom_map_present = (lmp_list.mapping != nullptr);
bool use_with_comm = has_comm_artifact_ && multi_rank;
// See DeepPotPTExpt::compute_inner for the rationale on these guards.
if (has_message_passing_ && !use_with_comm && !atom_map_present &&
nghost > 0) {
if (multi_rank) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS .pt2 inference requires the model to be "
"exported with `use_loc_mapping=False`, which compiles a "
"with-comm artifact for cross-rank ghost-feature exchange. "
"Re-export the model with use_loc_mapping=False and try again.");
} else {
throw deepmd::deepmd_exception(
"Single-rank LAMMPS .pt2 inference requires `atom_modify map "
"yes` in the LAMMPS input (so InputNlist.mapping is populated "
"from the LAMMPS atom-map). The model gathers ghost-atom "
"features via this mapping; without it the C++ side has no "
"safe way to resolve ghost indices to local owners. C++ API "
"callers must set inlist.mapping explicitly before compute().");
}
}

// LAMMPS sets ago=0 on every nlist rebuild, so ago>0 implies the cached
// mapping and nlist tensors are still valid — see DeepPotPTExpt.cc for
// the same rationale.
Expand All @@ -391,6 +423,10 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
.clone()
.to(device);
} else {
// Identity fallback: only reached on the with-comm path (which
// fills ghost features via border_op and ignores this tensor for
// ghost gather) or for trusted direct C++ callers (world ==
// nullptr). Other paths were rejected by the fail-fast above.
std::vector<std::int64_t> mapping(nall_real);
for (int ii = 0; ii < nall_real; ii++) {
mapping[ii] = ii;
Expand Down Expand Up @@ -452,8 +488,10 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener,
// _with_comm), so C++ supplies the same 8 comm tensors as the
// non-spin path. ``nlocal``/``nghost`` carry the real-atom counts
// (pre atom-doubling); the spin override halves them internally.
//
// ``use_with_comm`` was computed earlier alongside the fail-fast
// dispatch check.
std::vector<torch::Tensor> flat_outputs;
bool use_with_comm = has_comm_artifact_ && lmp_list.nswap > 0;
if (use_with_comm && !with_comm_loader) {
throw deepmd::deepmd_exception(
"Multi-rank LAMMPS requires the with-comm artifact, but it failed "
Expand Down
17 changes: 15 additions & 2 deletions source/lmp/tests/run_mpi_pair_deepmd_dpa3_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@
"trigger nlist rebuilds on every step (and run a small ``--nsteps`` "
"to keep wall time low while still exercising the rebuild path).",
)
parser.add_argument(
"--no-atom-map",
action="store_true",
help="When set, omit ``atom_modify map yes`` from the LAMMPS input. "
"Used by the no-atom-map fail-fast / with-comm fallback tests; "
"with this flag the C++ DeepPotPTExpt sees inlist.mapping == "
"nullptr and either fails fast (no with-comm artifact) or routes "
"to with-comm (multi-rank, with-comm artifact present).",
)
parser.add_argument(
"--null-vx",
type=float,
Expand Down Expand Up @@ -124,8 +133,12 @@
# ``atom_modify map yes`` is required when single-rank dispatch goes
# through the regular artifact of a use_loc_mapping=False .pt2: the
# C++ side needs the LAMMPS global-id->local-index map to build the
# ``mapping`` tensor. It is harmless under multi-rank.
lammps.atom_modify("map yes")
# ``mapping`` tensor. It is harmless under multi-rank. The
# ``--no-atom-map`` flag skips this line so the no-atom-map fallback
# (multi-rank with-comm path) and fail-fast (no with-comm artifact)
# branches can be exercised.
if not args.no_atom_map:
lammps.atom_modify("map yes")
lammps.neighbor("2.0 bin")
lammps.neigh_modify(f"every {args.neigh_every} delay 0 check no")
lammps.read_data(args.DATAFILE)
Expand Down
Loading
Loading