Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions deepmd/dpmodel/utils/dist_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Minimum pairwise distance check for frame validity filtering."""

from __future__ import (
annotations,
)

import numpy as np


def compute_min_pair_dist_single(
coord: np.ndarray,
box: np.ndarray | None,
atype: np.ndarray,
) -> float:
"""Compute the minimum pairwise atomic distance for a single frame.

Parameters
----------
coord : np.ndarray
Atomic coordinates, flattened with shape (natoms * 3,)
or reshaped as (natoms, 3).
box : np.ndarray or None
Box vectors with shape (9,) for PBC, or None for non-PBC.
atype : np.ndarray
Atom types with shape (natoms,). Virtual atoms (type < 0)
are excluded from the distance check.

Returns
-------
float
Minimum pairwise distance. Returns inf if fewer than 2
real atoms exist.
"""
coord = coord.reshape(-1, 3)

# === Step 1. Filter out virtual atoms ===
real_mask = atype.ravel() >= 0
real_coord = coord[real_mask]
n_real = real_coord.shape[0]
if n_real < 2:
return float("inf")

# === Step 2. Compute pairwise displacement vectors ===
diff = real_coord[np.newaxis, :, :] - real_coord[:, np.newaxis, :]

# === Step 3. Apply minimum image convention for PBC ===
if box is not None:
cell = box.reshape(3, 3)
inv_cell = np.linalg.inv(cell)
frac_diff = diff @ inv_cell
frac_diff -= np.round(frac_diff)
diff = frac_diff @ cell

# === Step 4. Compute distances and exclude self-pairs ===
dist_sq = np.sum(diff * diff, axis=-1)
Comment thread
OutisLi marked this conversation as resolved.
Outdated
np.fill_diagonal(dist_sq, np.inf)

return float(np.sqrt(dist_sq.min()))
13 changes: 13 additions & 0 deletions deepmd/dpmodel/utils/lmdb_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
import msgpack
import numpy as np

from deepmd.dpmodel.utils.dist_check import (
compute_min_pair_dist_single,
)
from deepmd.env import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -597,6 +600,16 @@ def __getitem__(self, index: int) -> dict[str, Any]:
frame["natoms"] = fallback
frame["real_natoms_vec"] = fallback

if "min_pair_dist" in self._data_requirements and "min_pair_dist" not in frame:
box = frame.get("box")
if box is not None and np.allclose(box, 0.0):
box = None
frame["find_min_pair_dist"] = np.float32(1.0)
frame["min_pair_dist"] = np.array(
[compute_min_pair_dist_single(frame["coord"], box, frame["atype"])],
dtype=self._resolve_dtype("min_pair_dist"),
)

# Add find_* flags for all data keys present in the frame.
# Core structural keys and metadata are excluded — only label-like
# and auxiliary data keys get find_* flags.
Expand Down
14 changes: 11 additions & 3 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,17 @@ def extend_coord_with_ghosts(
shift_idx = xp.take(xyz, xp.argsort(xp.linalg.vector_norm(xyz, axis=1)), axis=0)
ns, _ = shift_idx.shape
nall = ns * nloc
# shift_vec = xp.einsum("sd,fdk->fsk", shift_idx, cell)
shift_vec = xp.tensordot(shift_idx, cell, axes=([1], [1]))
shift_vec = xp.permute_dims(shift_vec, (1, 0, 2))
xp_name = getattr(xp, "__name__", "")
if "jax" in xp_name:
Comment thread
OutisLi marked this conversation as resolved.
Outdated
# Avoid JAX internal errors in tensordot.
shift_vec = xp.sum(
shift_idx[xp.newaxis, :, :, xp.newaxis] * cell[:, xp.newaxis, :, :],
axis=2,
Comment thread
OutisLi marked this conversation as resolved.
)
else:
# shift_vec = xp.einsum("sd,fdk->fsk", shift_idx, cell)
shift_vec = xp.tensordot(shift_idx, cell, axes=([1], [1]))
shift_vec = xp.permute_dims(shift_vec, (1, 0, 2))
extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :]
extend_atype = xp.tile(atype[:, :, xp.newaxis], (1, ns, 1))
extend_aidx = xp.tile(aidx[:, :, xp.newaxis], (1, ns, 1))
Expand Down
Loading
Loading