Skip to content
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion deepmd/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,14 @@ def calculate(

fparam = self.atoms.info.get("fparam", None)
aparam = self.atoms.info.get("aparam", None)
charge_spin = self.atoms.info.get("charge_spin", None)
e, f, v = self.dp.eval(
coords=coord, cells=cell, atom_types=atype, fparam=fparam, aparam=aparam
coords=coord,
cells=cell,
atom_types=atype,
fparam=fparam,
aparam=aparam,
charge_spin=charge_spin,
)[:3]
self.results["energy"] = e[0][0]
# see https://gitlab.com/ase/ase/-/merge_requests/2485
Expand Down
20 changes: 20 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,18 @@ def get_default_fparam(self) -> list[float] | None:
"""Get the default frame parameters."""
return None

def has_chg_spin_ebd(self) -> bool:
"""Check if the model has charge spin embedding."""
return False

def has_default_chg_spin(self) -> bool:
"""Check if the model has default charge_spin values."""
return False

def get_default_chg_spin(self) -> list[float] | None:
"""Get the default charge_spin values."""
return None

def reinit_atom_exclude(
self,
exclude_types: list[int] = [],
Expand Down Expand Up @@ -232,6 +244,7 @@ def forward_common_atomic(
fparam: Array | None = None,
aparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> dict[str, Array]:
"""Common interface for atomic inference.

Expand Down Expand Up @@ -284,6 +297,7 @@ def forward_common_atomic(
fparam=fparam,
aparam=aparam,
comm_dict=comm_dict,
charge_spin=charge_spin,
)
ret_dict = self.apply_out_stat(ret_dict, atype)

Expand Down Expand Up @@ -312,6 +326,7 @@ def call(
mapping: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
charge_spin: Array | None = None,
) -> dict[str, Array]:
return self.forward_common_atomic(
extended_coord,
Expand All @@ -320,6 +335,7 @@ def call(
mapping=mapping,
fparam=fparam,
aparam=aparam,
charge_spin=charge_spin,
)

def get_intensive(self) -> bool:
Expand Down Expand Up @@ -524,6 +540,7 @@ def model_forward(
box: np.ndarray | None,
fparam: np.ndarray | None = None,
aparam: np.ndarray | None = None,
charge_spin: np.ndarray | None = None,
) -> dict[str, np.ndarray]:
# Get reference array to determine the target array type and device
# Use out_bias as reference since it's always present
Expand All @@ -543,6 +560,8 @@ def model_forward(
fparam = xp.asarray(fparam, device=device)
if aparam is not None:
aparam = xp.asarray(aparam, device=device)
if charge_spin is not None:
charge_spin = xp.asarray(charge_spin, device=device)

(
extended_coord,
Expand All @@ -564,6 +583,7 @@ def model_forward(
mapping=mapping,
fparam=fparam,
aparam=aparam,
charge_spin=charge_spin,
)
# Convert outputs back to numpy arrays
return {kk: to_numpy_array(vv) for kk, vv in atomic_ret.items()}
Expand Down
66 changes: 41 additions & 25 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,28 @@ def __init__(
)
super().init_out_stat()

def has_chg_spin_ebd(self) -> bool:
"""Check if the model has charge spin embedding."""
return self.add_chg_spin_ebd

def get_dim_chg_spin(self) -> int:
"""Get the dimension of charge_spin input."""
if self.add_chg_spin_ebd:
return self.descriptor.get_dim_chg_spin()
return 0

def has_default_chg_spin(self) -> bool:
"""Check if the model has default charge_spin values."""
if self.add_chg_spin_ebd:
return self.descriptor.has_default_chg_spin()
return False

def get_default_chg_spin(self) -> list[float] | None:
"""Get the default charge_spin values."""
if self.add_chg_spin_ebd and self.descriptor.has_default_chg_spin():
return self.descriptor.get_default_chg_spin()
return None

def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
return self.fitting_net.output_def()
Expand Down Expand Up @@ -158,6 +180,7 @@ def forward_atomic(
fparam: Array | None = None,
aparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> dict[str, Array]:
"""Models' atomic predictions.

Expand All @@ -178,6 +201,8 @@ def forward_atomic(
comm_dict
MPI communication metadata for parallel inference. ``None`` for
non-parallel inference (default). Forwarded to the descriptor.
charge_spin
charge and spin parameter for descriptor. nf x 2

Returns
-------
Expand All @@ -188,38 +213,29 @@ def forward_atomic(
nframes, nloc, nnei = nlist.shape
atype = xp_take_first_n(extended_atype, 1, nloc)

# Handle default fparam if fitting net supports it
if (
hasattr(self.fitting_net, "get_dim_fparam")
and self.fitting_net.get_dim_fparam() > 0
and fparam is None
):
# use default fparam
from deepmd.dpmodel.array_api import (
array_api_compat,
)

default_fparam = self.fitting_net.get_default_fparam()
assert default_fparam is not None
xp = array_api_compat.array_namespace(extended_coord)
default_fparam_array = xp.asarray(
default_fparam,
dtype=extended_coord.dtype,
device=array_api_compat.device(extended_coord),
)
fparam_input_for_des = xp.tile(
xp.reshape(default_fparam_array, (1, -1)), (nframes, 1)
)
else:
fparam_input_for_des = fparam
# Handle default charge_spin if descriptor supports it
if self.add_chg_spin_ebd and charge_spin is None:
default_cs = self.descriptor.get_default_chg_spin()
if default_cs is not None:
from deepmd.dpmodel.array_api import (
array_api_compat,
)

xp = array_api_compat.array_namespace(extended_coord)
cs_array = xp.asarray(
default_cs,
dtype=extended_coord.dtype,
device=array_api_compat.device(extended_coord),
)
charge_spin = xp.tile(xp.reshape(cs_array, (1, -1)), (nframes, 1))

descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
comm_dict=comm_dict,
charge_spin=charge_spin if self.add_chg_spin_ebd else None,
)
ret = self.fitting_net(
descriptor,
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def forward_atomic(
fparam: Array | None = None,
aparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> dict[str, Array]:
"""Return atomic prediction.

Expand Down Expand Up @@ -286,6 +287,7 @@ def forward_atomic(
fparam,
aparam,
comm_dict,
charge_spin=charge_spin,
)["energy"]
)
weights = self._compute_weight(extended_coord, extended_atype, nlists_)
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def fwd(
mapping: t_tensor | None = None,
fparam: t_tensor | None = None,
aparam: t_tensor | None = None,
charge_spin: t_tensor | None = None,
) -> dict[str, t_tensor]:
pass

Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def forward_atomic(
fparam: Array | None = None,
aparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> dict[str, Array]:
del comm_dict # pairtab is local; no MPI ghost exchange needed.
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ def call(
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> Array:
"""Compute the descriptor.

Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,7 @@ def call(
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> tuple[Array, Array, Array, Array, Array]:
"""Compute the descriptor.
Expand Down
26 changes: 23 additions & 3 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def __init__(
use_loc_mapping: bool = True,
type_map: list[str] | None = None,
add_chg_spin_ebd: bool = False,
default_chg_spin: list[float] | None = None,
Comment thread
wanghan-iapcm marked this conversation as resolved.
) -> None:
super().__init__()

Expand Down Expand Up @@ -433,6 +434,11 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:

self.use_econf_tebd = use_econf_tebd
self.add_chg_spin_ebd = add_chg_spin_ebd
if default_chg_spin is not None and len(default_chg_spin) != 2:
raise ValueError(
"default_chg_spin must have exactly 2 values [charge, spin]"
)
self.default_chg_spin = default_chg_spin
self.use_tebd_bias = use_tebd_bias
self.use_loc_mapping = use_loc_mapping
self.type_map = type_map
Expand Down Expand Up @@ -499,6 +505,18 @@ def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut

def get_dim_chg_spin(self) -> int:
"""Returns the dimension of charge_spin input."""
return 2 if self.add_chg_spin_ebd else 0

def has_default_chg_spin(self) -> bool:
"""Returns whether default charge_spin values are set."""
return self.default_chg_spin is not None

def get_default_chg_spin(self) -> list[float] | None:
"""Returns the default charge_spin values."""
return self.default_chg_spin

def get_rcut_smth(self) -> float:
"""Returns the radius where the neighbor information starts to smoothly decay to 0."""
return self.rcut_smth
Expand Down Expand Up @@ -647,6 +665,7 @@ def call(
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> tuple[Array, Array, Array, Array, Array]:
"""Compute the descriptor.

Expand Down Expand Up @@ -702,13 +721,13 @@ def call(
)

if self.add_chg_spin_ebd:
assert fparam is not None
assert charge_spin is not None
assert self.chg_embedding is not None
assert self.spin_embedding is not None
chg_tebd = self.chg_embedding.call()
spin_tebd = self.spin_embedding.call()
charge = xp.astype(fparam[:, 0], xp.int64) + 100
spin = xp.astype(fparam[:, 1], xp.int64)
charge = xp.astype(charge_spin[:, 0], xp.int64) + 100
spin = xp.astype(charge_spin[:, 1], xp.int64)
chg_ebd = xp.reshape(
xp.take(chg_tebd, xp.reshape(charge, (-1,)), axis=0),
(nframes, self.tebd_dim),
Expand Down Expand Up @@ -753,6 +772,7 @@ def serialize(self) -> dict:
"use_tebd_bias": self.use_tebd_bias,
"use_loc_mapping": self.use_loc_mapping,
"add_chg_spin_ebd": self.add_chg_spin_ebd,
"default_chg_spin": self.default_chg_spin,
"type_map": self.type_map,
"type_embedding": self.type_embedding.serialize(),
}
Expand Down
26 changes: 25 additions & 1 deletion deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return np.max([descrpt.get_rcut() for descrpt in self.descrpt_list]).item()

def get_dim_chg_spin(self) -> int:
"""Returns the dimension of charge_spin input (0 if not supported)."""
return max(
(descrpt.get_dim_chg_spin() for descrpt in self.descrpt_list), default=0
)

def has_default_chg_spin(self) -> bool:
"""Returns whether the descriptor has a default charge_spin value."""
return any(descrpt.has_default_chg_spin() for descrpt in self.descrpt_list)

def get_default_chg_spin(self) -> list[float] | None:
"""Returns the default charge_spin value, or None."""
for descrpt in self.descrpt_list:
if descrpt.has_default_chg_spin():
return descrpt.get_default_chg_spin()
return None

def get_rcut_smth(self) -> float:
"""Returns the radius where the neighbor information starts to smoothly decay to 0."""
# may not be a good idea...
Expand Down Expand Up @@ -287,6 +304,7 @@ def call(
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> tuple[
Array,
Array | None,
Expand Down Expand Up @@ -344,7 +362,13 @@ def call(
assert nl_distinguish_types is not None
nl = nl_distinguish_types[:, :, nci]
odescriptor, gr, _g2, _h2, _sw = descrpt(
coord_ext, atype_ext, nl, mapping, comm_dict=comm_dict
coord_ext,
atype_ext,
nl,
mapping,
fparam=fparam,
comm_dict=comm_dict,
charge_spin=charge_spin,
)
out_descriptor.append(odescriptor)
if gr is not None:
Expand Down
13 changes: 13 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ def get_dim_emb(self) -> int:
"""Returns the embedding dimension of g2."""
pass

def get_dim_chg_spin(self) -> int:
"""Returns the dimension of charge_spin input (0 if not supported)."""
return 0

def has_default_chg_spin(self) -> bool:
"""Returns whether the descriptor has a default charge_spin value."""
return False

def get_default_chg_spin(self) -> Any:
"""Returns the default charge_spin value, or None."""
return None

@abstractmethod
def mixed_types(self) -> bool:
"""Returns if the descriptor requires a neighbor list that distinguish different
Expand Down Expand Up @@ -205,6 +217,7 @@ def fwd(
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
charge_spin: Array | None = None,
) -> Array:
"""Calculate descriptor."""
pass
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def call(
mapping: Array | None = None,
fparam: Array | None = None,
comm_dict: dict | None = None,
charge_spin: Array | None = None,
) -> Array:
"""Compute the descriptor.
Expand Down
Loading
Loading