Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions deepmd/dpmodel/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
from copy import (
deepcopy,
)
from itertools import (
pairwise,
)
from typing import (
Any,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
Array,
)
Expand All @@ -14,6 +20,7 @@
)
from deepmd.dpmodel.common import (
NativeOP,
to_numpy_array,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
Expand Down Expand Up @@ -87,7 +94,51 @@ def call(
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
batch: Array | None = None,
ptr: Array | None = None,
extended_atype: Array | None = None,
extended_batch: Array | None = None,
extended_image: Array | None = None,
extended_ptr: Array | None = None,
mapping: Array | None = None,
central_ext_index: Array | None = None,
nlist: Array | None = None,
nlist_ext: Array | None = None,
a_nlist: Array | None = None,
a_nlist_ext: Array | None = None,
nlist_mask: Array | None = None,
a_nlist_mask: Array | None = None,
edge_index: Array | None = None,
angle_index: Array | None = None,
) -> dict[str, Array]:
if batch is not None or ptr is not None:
if batch is None or ptr is None:
raise ValueError("Both batch and ptr are required for mixed batches.")
return self.call_flat(
coord=coord,
atype=atype,
batch=batch,
ptr=ptr,
box=box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
extended_atype=extended_atype,
extended_batch=extended_batch,
extended_image=extended_image,
extended_ptr=extended_ptr,
mapping=mapping,
central_ext_index=central_ext_index,
nlist=nlist,
nlist_ext=nlist_ext,
a_nlist=a_nlist,
a_nlist_ext=a_nlist_ext,
nlist_mask=nlist_mask,
a_nlist_mask=a_nlist_mask,
edge_index=edge_index,
angle_index=angle_index,
)

model_ret = self.call_common(
coord,
atype,
Expand All @@ -111,6 +162,130 @@ def call(
model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-3)
return model_predict

def call_flat(
    self,
    coord: Array,
    atype: Array,
    batch: Array,
    ptr: Array,
    box: Array | None = None,
    fparam: Array | None = None,
    aparam: Array | None = None,
    do_atomic_virial: bool = False,
    extended_atype: Array | None = None,
    extended_batch: Array | None = None,
    extended_image: Array | None = None,
    extended_ptr: Array | None = None,
    mapping: Array | None = None,
    central_ext_index: Array | None = None,
    nlist: Array | None = None,
    nlist_ext: Array | None = None,
    a_nlist: Array | None = None,
    a_nlist_ext: Array | None = None,
    nlist_mask: Array | None = None,
    a_nlist_mask: Array | None = None,
    edge_index: Array | None = None,
    angle_index: Array | None = None,
) -> dict[str, Array]:
    """Evaluate a flattened mixed-nloc batch with the dpmodel backend.

    The dpmodel backend reuses the regular one-frame call path for each
    segment described by ``ptr`` and merges the translated outputs back into
    the flat mixed-batch layout.

    Parameters
    ----------
    coord
        Coordinates of all atoms, concatenated over frames. The leading axis
        is the total number of atoms; each segment is reshaped to
        ``(1, nloc * 3)`` before being evaluated.
    atype
        Atom types, concatenated over frames. Each segment is reshaped to
        ``(1, nloc)``.
    batch
        Per-atom array whose leading length must equal the number of atoms.
        Only its length is checked here; frame membership is taken from
        ``ptr``.
    ptr
        1D, non-decreasing frame boundaries. Atoms ``ptr[i]:ptr[i + 1]``
        belong to frame ``i``; it must start at 0 and end at the total
        number of atoms.
    box
        Optional per-frame array, sliced along its leading axis by frame
        index.
    fparam
        Optional per-frame array, sliced along its leading axis by frame
        index.
    aparam
        Optional per-atom array, sliced along its leading axis with the same
        bounds as ``coord``.
    do_atomic_virial
        Forwarded unchanged to the per-frame ``call``.
    extended_atype, extended_batch, extended_image, extended_ptr, mapping, \
central_ext_index, nlist, nlist_ext, a_nlist, a_nlist_ext, nlist_mask, \
a_nlist_mask, edge_index, angle_index
        Accepted for signature compatibility and ignored by this backend.

    Returns
    -------
    dict[str, Array]
        The per-frame outputs of ``call`` merged by
        ``_merge_flat_frame_outputs``.

    Raises
    ------
    NotImplementedError
        If the model has hessian output enabled.
    ValueError
        If ``ptr`` is missing, is not 1D with at least two entries, does not
        start at 0 and end at the number of atoms, or is not non-decreasing;
        or if ``batch`` does not have one entry per atom.
    """
    # The precomputed graph inputs are not consumed by the dpmodel backend;
    # each frame goes through the regular ``call`` path instead.
    del (
        extended_atype,
        extended_batch,
        extended_image,
        extended_ptr,
        mapping,
        central_ext_index,
        nlist,
        nlist_ext,
        a_nlist,
        a_nlist_ext,
        nlist_mask,
        a_nlist_mask,
        edge_index,
        angle_index,
    )
    if self._enable_hessian:
        raise NotImplementedError(
            "Hessian is not implemented for dpmodel mixed-batch flat calls."
        )

    xp = array_api_compat.array_namespace(coord, atype)
    ptr_np = to_numpy_array(ptr)
    if ptr_np is None:
        raise ValueError("ptr is required for mixed batches.")
    ptr_np = np.asarray(ptr_np, dtype=np.int64)
    if ptr_np.ndim != 1 or ptr_np.size < 2:
        raise ValueError("ptr must be a 1D array with at least two entries.")

    total_atoms = coord.shape[0]
    if ptr_np[0] != 0 or ptr_np[-1] != total_atoms:
        raise ValueError("ptr must start at 0 and end at the number of atoms.")
    # A decreasing boundary would give a negative atom count for that frame
    # and surface later as an opaque reshape failure; reject it up front.
    if np.any(ptr_np[1:] < ptr_np[:-1]):
        raise ValueError("ptr must be non-decreasing.")
    if batch.shape[0] != total_atoms:
        raise ValueError("batch length must match the number of atoms.")

    # Plain Python ints keep the slicing below independent of how a given
    # array backend treats NumPy scalar indices.
    bounds = ptr_np.tolist()

    frame_outputs = []
    for frame_idx, (start, end) in enumerate(pairwise(bounds)):
        nloc = end - start
        frame_coord = xp.reshape(coord[start:end], (1, nloc * 3))
        frame_atype = xp.reshape(atype[start:end], (1, nloc))
        # Keep the leading frame axis (length 1) on per-frame inputs.
        frame_box = box[frame_idx : frame_idx + 1] if box is not None else None
        frame_fparam = (
            fparam[frame_idx : frame_idx + 1] if fparam is not None else None
        )
        frame_aparam = (
            xp.reshape(aparam[start:end], (1, nloc, *aparam.shape[1:]))
            if aparam is not None
            else None
        )
        frame_outputs.append(
            self.call(
                frame_coord,
                frame_atype,
                box=frame_box,
                fparam=frame_fparam,
                aparam=frame_aparam,
                do_atomic_virial=do_atomic_virial,
            )
        )

    return self._merge_flat_frame_outputs(frame_outputs)

@staticmethod
def _merge_flat_frame_outputs(
    frame_outputs: list[dict[str, Array]],
) -> dict[str, Array]:
    """Concatenate per-frame output dicts into one flat mixed-batch dict.

    The keys of the first frame's dict decide which entries are merged; every
    other frame is indexed with the same keys.

    Parameters
    ----------
    frame_outputs
        One output dict per frame, in frame order.

    Returns
    -------
    dict[str, Array]
        For ``"energy"`` and ``"virial"`` the per-frame arrays are
        concatenated along axis 0 as they are. ``"mask"`` arrays are fully
        flattened first. Any other array has its first two axes collapsed
        into one when it has three or more dimensions, and is fully
        flattened otherwise, before concatenation along axis 0.

    Raises
    ------
    ValueError
        If ``frame_outputs`` is empty.
    """
    if not frame_outputs:
        raise ValueError("mixed-batch input must contain at least one frame.")

    per_frame_keys = ("energy", "virial")
    merged: dict[str, Array] = {}
    for name in frame_outputs[0]:
        pieces = [single[name] for single in frame_outputs]
        xp = array_api_compat.array_namespace(pieces[0])
        if name not in per_frame_keys:
            flattened = []
            for piece in pieces:
                # "mask" is always flattened completely; other entries keep
                # their trailing axes when they have any beyond the first two.
                if name != "mask" and piece.ndim >= 3:
                    target_shape = (-1, *piece.shape[2:])
                else:
                    target_shape = (-1,)
                flattened.append(xp.reshape(piece, target_shape))
            pieces = flattened
        merged[name] = xp.concat(pieces, axis=0)
    return merged

def call_lower(
self,
extended_coord: Array,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,25 @@ def _make_dp_loader_set(
# LMDB path: single string → LmdbDataset
if isinstance(training_systems, str) and is_lmdb(training_systems):
auto_prob = training_dataset_params.get("auto_prob", None)
mixed_batch = training_dataset_params.get("mixed_batch", False)
train_data_single = LmdbDataset(
training_systems,
model_params_single["type_map"],
training_dataset_params["batch_size"],
mixed_batch=mixed_batch,
auto_prob_style=auto_prob,
)
if (
validation_systems is not None
and isinstance(validation_systems, str)
and is_lmdb(validation_systems)
):
val_mixed_batch = validation_dataset_params.get("mixed_batch", False)
validation_data_single = LmdbDataset(
validation_systems,
model_params_single["type_map"],
validation_dataset_params["batch_size"],
mixed_batch=val_mixed_batch,
)
elif validation_systems is not None:
validation_data_single = _make_dp_loader_set(
Expand Down
Loading
Loading