Skip to content
7 changes: 7 additions & 0 deletions .gitignore
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This modification should be removed.

Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,10 @@ frozen_model.*

# Test system directories
system/

temp/
test_mptraj/
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
pkl/
history/
deepmd-kit/
*.hdf5
4 changes: 4 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,25 @@ def _make_dp_loader_set(
# LMDB path: single string → LmdbDataset
if isinstance(training_systems, str) and is_lmdb(training_systems):
auto_prob = training_dataset_params.get("auto_prob", None)
mixed_batch = training_dataset_params.get("mixed_batch", False)
train_data_single = LmdbDataset(
training_systems,
model_params_single["type_map"],
training_dataset_params["batch_size"],
mixed_batch=mixed_batch,
auto_prob_style=auto_prob,
)
if (
validation_systems is not None
and isinstance(validation_systems, str)
and is_lmdb(validation_systems)
):
val_mixed_batch = validation_dataset_params.get("mixed_batch", False)
validation_data_single = LmdbDataset(
validation_systems,
model_params_single["type_map"],
validation_dataset_params["batch_size"],
mixed_batch=val_mixed_batch,
)
elif validation_systems is not None:
validation_data_single = _make_dp_loader_set(
Expand Down
15 changes: 14 additions & 1 deletion deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,20 @@ def forward(
more_loss = {}
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
atom_norm = 1.0 / natoms

# Detect mixed batch format
is_mixed_batch = "ptr" in input_dict and input_dict["ptr"] is not None

# For mixed batch, compute per-frame atom_norm and average
if is_mixed_batch:
ptr = input_dict["ptr"]
nframes = ptr.numel() - 1
# Compute natoms for each frame
natoms_per_frame = ptr[1:] - ptr[:-1] # [nframes]
# Average atom_norm across frames
atom_norm = torch.mean(1.0 / natoms_per_frame.float())
else:
atom_norm = 1.0 / natoms
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if self.has_e and "energy" in model_pred and "energy" in label:
energy_pred = model_pred["energy"]
energy_label = label["energy"]
Expand Down
131 changes: 131 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,137 @@ def forward_atomic(
)
return fit_ret

def forward_common_atomic_flat(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_batch: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor,
batch: torch.Tensor,
ptr: torch.Tensor,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
extended_ptr: torch.Tensor | None = None,
central_ext_index: torch.Tensor | None = None,
nlist_ext: torch.Tensor | None = None,
a_nlist: torch.Tensor | None = None,
a_nlist_ext: torch.Tensor | None = None,
nlist_mask: torch.Tensor | None = None,
a_nlist_mask: torch.Tensor | None = None,
edge_index: torch.Tensor | None = None,
angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Forward pass with flat batch format.

Parameters
----------
extended_coord : torch.Tensor
Extended coordinates [total_extended_atoms, 3].
extended_atype : torch.Tensor
Extended atom types [total_extended_atoms].
extended_batch : torch.Tensor
Frame assignment for extended atoms [total_extended_atoms].
nlist : torch.Tensor
Neighbor list [total_atoms, nnei].
mapping : torch.Tensor
Extended atom -> local flat index mapping [total_extended_atoms].
batch : torch.Tensor
Frame assignment for local atoms [total_atoms].
ptr : torch.Tensor
Frame boundaries [nframes + 1].
fparam : torch.Tensor | None
Frame parameters [nframes, ndf].
aparam : torch.Tensor | None
Atomic parameters [total_atoms, nda].
central_ext_index : torch.Tensor | None
Extended-atom indices corresponding to local atoms.
nlist_ext, a_nlist_ext : torch.Tensor | None
Edge and angle neighbor lists indexing concatenated extended atoms.
nlist_mask, a_nlist_mask : torch.Tensor | None
Valid-neighbor masks for flat edge and angle neighbor lists.
edge_index, angle_index : torch.Tensor | None
Dynamic graph indices produced by the flat graph preprocessor.

Returns
-------
result_dict : dict[str, torch.Tensor]
Model predictions in flat format.
"""
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)

# Descriptor and fitting both consume the flat atom layout.
descriptor_out = self.descriptor.forward_flat(
extended_coord,
extended_atype,
extended_batch,
nlist,
mapping,
batch,
ptr,
fparam=fparam if self.add_chg_spin_ebd else None,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
central_ext_index=central_ext_index,
nlist_ext=nlist_ext,
a_nlist=a_nlist,
a_nlist_ext=a_nlist_ext,
nlist_mask=nlist_mask,
a_nlist_mask=a_nlist_mask,
edge_index=edge_index,
angle_index=angle_index,
)

descriptor = descriptor_out.get("descriptor")
rot_mat = descriptor_out.get("rot_mat")
g2 = descriptor_out.get("g2")
h2 = descriptor_out.get("h2")

if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor.detach())

if central_ext_index is None:
from deepmd.pt.utils.nlist import get_central_ext_index

central_ext_index = get_central_ext_index(extended_batch, ptr)
atype = extended_atype[central_ext_index]
else:
atype = extended_atype[central_ext_index]

fit_ret = self.fitting_net.forward_flat(
descriptor,
atype,
batch,
ptr,
gr=rot_mat,
g2=g2,
h2=h2,
fparam=fparam,
aparam=aparam,
)
fit_ret = self.apply_out_stat(fit_ret, atype)

atom_mask = self.make_atom_mask(atype).to(torch.int32)
if self.atom_excl is not None:
atom_mask *= self.atom_excl(atype.unsqueeze(0)).squeeze(0)

for kk in fit_ret.keys():
out_shape = fit_ret[kk].shape
out_shape2 = 1
for ss in out_shape[1:]:
out_shape2 *= ss
fit_ret[kk] = (
fit_ret[kk].reshape([out_shape[0], out_shape2]) * atom_mask[:, None]
).view(out_shape)
fit_ret["mask"] = atom_mask

if self.enable_eval_fitting_last_layer_hook:
if "middle_output" in fit_ret:
self.eval_fitting_last_layer_list.append(
fit_ret.pop("middle_output").detach()
)

return fit_ret

def compute_or_load_stat(
self,
sampled_func: Callable[[], list[dict]],
Expand Down
126 changes: 126 additions & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,132 @@ def forward(
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if sw is not None else None,
)

def forward_flat(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_batch: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor,
batch: torch.Tensor,
ptr: torch.Tensor,
fparam: torch.Tensor | None = None,
central_ext_index: torch.Tensor | None = None,
nlist_ext: torch.Tensor | None = None,
a_nlist: torch.Tensor | None = None,
a_nlist_ext: torch.Tensor | None = None,
nlist_mask: torch.Tensor | None = None,
a_nlist_mask: torch.Tensor | None = None,
edge_index: torch.Tensor | None = None,
angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Compute the descriptor with flat batch format.

Parameters
----------
extended_coord : torch.Tensor
Extended coordinates [total_extended_atoms, 3].
extended_atype : torch.Tensor
Extended atom types [total_extended_atoms].
extended_batch : torch.Tensor
Frame assignment for extended atoms [total_extended_atoms].
nlist : torch.Tensor
Neighbor list [total_atoms, nnei].
mapping : torch.Tensor
Extended atom -> local flat index mapping [total_extended_atoms].
batch : torch.Tensor
Frame assignment for local atoms [total_atoms].
ptr : torch.Tensor
Frame boundaries [nframes + 1].
fparam : torch.Tensor | None
Frame parameters [nframes, ndf].
central_ext_index : torch.Tensor | None
Extended-atom indices corresponding to local atoms.
nlist_ext, a_nlist_ext : torch.Tensor | None
Edge and angle neighbor lists indexing concatenated extended atoms.
nlist_mask, a_nlist_mask : torch.Tensor | None
Valid-neighbor masks for flat edge and angle neighbor lists.
edge_index, angle_index : torch.Tensor | None
Dynamic graph indices produced by the flat graph preprocessor.

Returns
-------
result : dict[str, torch.Tensor]
Dictionary containing:
- 'descriptor': [total_atoms, descriptor_dim]
- 'rot_mat': [total_atoms, e_dim, 3] or None
- 'g2': edge embedding or None
- 'h2': pair representation or None
"""
extended_coord = extended_coord.to(dtype=self.prec)

# Flat batches embed all extended atoms, then gather central atoms.
node_ebd_ext = self.type_embedding(
extended_atype
) # [total_extended_atoms, tebd_dim]

if self.add_chg_spin_ebd:
assert fparam is not None
assert self.chg_embedding is not None
assert self.spin_embedding is not None

# Expand frame-level charge/spin parameters to extended atoms.
charge = fparam[extended_batch, 0].to(dtype=torch.int64) + 100
spin = fparam[extended_batch, 1].to(dtype=torch.int64)
chg_ebd = self.chg_embedding(charge)
spin_ebd = self.spin_embedding(spin)
sys_cs_embd = self.act(
self.mix_cs_mlp(torch.cat((chg_ebd, spin_ebd), dim=-1))
)
node_ebd_ext = node_ebd_ext + sys_cs_embd

if central_ext_index is None:
from deepmd.pt.utils.nlist import get_central_ext_index

central_ext_index = get_central_ext_index(extended_batch, ptr)
node_ebd_inp = node_ebd_ext[central_ext_index]

node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows.forward_flat(
nlist,
extended_coord,
extended_atype,
extended_batch,
node_ebd_ext,
mapping,
batch,
ptr,
central_ext_index=central_ext_index,
nlist_ext=nlist_ext,
a_nlist=a_nlist,
a_nlist_ext=a_nlist_ext,
nlist_mask=nlist_mask,
a_nlist_mask=a_nlist_mask,
edge_index=edge_index,
angle_index=angle_index,
)

if self.concat_output_tebd:
node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)

return {
"descriptor": node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
"rot_mat": (
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if rot_mat is not None
else None
),
"g2": (
edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if edge_ebd is not None
else None
),
"h2": (
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if h2 is not None
else None
),
}

@classmethod
def update_sel(
cls,
Expand Down
Loading