Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,25 @@ def _make_dp_loader_set(
# LMDB path: single string → LmdbDataset
if isinstance(training_systems, str) and is_lmdb(training_systems):
auto_prob = training_dataset_params.get("auto_prob", None)
mixed_batch = training_dataset_params.get("mixed_batch", False)
train_data_single = LmdbDataset(
training_systems,
model_params_single["type_map"],
training_dataset_params["batch_size"],
mixed_batch=mixed_batch,
auto_prob_style=auto_prob,
)
if (
validation_systems is not None
and isinstance(validation_systems, str)
and is_lmdb(validation_systems)
):
val_mixed_batch = validation_dataset_params.get("mixed_batch", False)
validation_data_single = LmdbDataset(
validation_systems,
model_params_single["type_map"],
validation_dataset_params["batch_size"],
mixed_batch=val_mixed_batch,
)
elif validation_systems is not None:
validation_data_single = _make_dp_loader_set(
Expand Down
18 changes: 14 additions & 4 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,20 @@ def forward(
more_loss = {}
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
atom_norm = 1.0 / natoms
# Normalization exponent controls loss scaling with system size:
# - norm_exp=2 (intensive_ener_virial=True): loss uses 1/N² scaling, making it independent of system size
# - norm_exp=1 (intensive_ener_virial=False, legacy): loss uses 1/N scaling, which varies with system size

# Detect mixed batch format
is_mixed_batch = "ptr" in input_dict and input_dict["ptr"] is not None

# For mixed batch, compute per-frame atom_norm and average
if is_mixed_batch:
ptr = input_dict["ptr"]
nframes = ptr.numel() - 1
# Compute natoms for each frame
natoms_per_frame = ptr[1:] - ptr[:-1] # [nframes]
# Average atom_norm across frames
atom_norm = torch.mean(1.0 / natoms_per_frame.float())
else:
atom_norm = 1.0 / natoms
Comment thread
coderabbitai[bot] marked this conversation as resolved.
norm_exp = 2 if self.intensive_ener_virial else 1
if self.has_e and "energy" in model_pred and "energy" in label:
energy_pred = model_pred["energy"]
Expand Down
131 changes: 131 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,137 @@ def forward_atomic(
)
return fit_ret

def forward_common_atomic_flat(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_batch: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor,
batch: torch.Tensor,
ptr: torch.Tensor,
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
extended_ptr: torch.Tensor | None = None,
central_ext_index: torch.Tensor | None = None,
nlist_ext: torch.Tensor | None = None,
a_nlist: torch.Tensor | None = None,
a_nlist_ext: torch.Tensor | None = None,
nlist_mask: torch.Tensor | None = None,
a_nlist_mask: torch.Tensor | None = None,
edge_index: torch.Tensor | None = None,
angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Forward pass with flat batch format.

Parameters
----------
extended_coord : torch.Tensor
Extended coordinates [total_extended_atoms, 3].
extended_atype : torch.Tensor
Extended atom types [total_extended_atoms].
extended_batch : torch.Tensor
Frame assignment for extended atoms [total_extended_atoms].
nlist : torch.Tensor
Neighbor list [total_atoms, nnei].
mapping : torch.Tensor
Extended atom -> local flat index mapping [total_extended_atoms].
batch : torch.Tensor
Frame assignment for local atoms [total_atoms].
ptr : torch.Tensor
Frame boundaries [nframes + 1].
fparam : torch.Tensor | None
Frame parameters [nframes, ndf].
aparam : torch.Tensor | None
Atomic parameters [total_atoms, nda].
central_ext_index : torch.Tensor | None
Extended-atom indices corresponding to local atoms.
nlist_ext, a_nlist_ext : torch.Tensor | None
Edge and angle neighbor lists indexing concatenated extended atoms.
nlist_mask, a_nlist_mask : torch.Tensor | None
Valid-neighbor masks for flat edge and angle neighbor lists.
edge_index, angle_index : torch.Tensor | None
Dynamic graph indices produced by the flat graph preprocessor.

Returns
-------
result_dict : dict[str, torch.Tensor]
Model predictions in flat format.
"""
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)

# Descriptor and fitting both consume the flat atom layout.
descriptor_out = self.descriptor.forward_flat(
extended_coord,
extended_atype,
extended_batch,
nlist,
mapping,
batch,
ptr,
fparam=fparam if self.add_chg_spin_ebd else None,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
central_ext_index=central_ext_index,
nlist_ext=nlist_ext,
a_nlist=a_nlist,
a_nlist_ext=a_nlist_ext,
nlist_mask=nlist_mask,
a_nlist_mask=a_nlist_mask,
edge_index=edge_index,
angle_index=angle_index,
)

descriptor = descriptor_out.get("descriptor")
rot_mat = descriptor_out.get("rot_mat")
g2 = descriptor_out.get("g2")
h2 = descriptor_out.get("h2")

if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor.detach())

if central_ext_index is None:
from deepmd.pt.utils.nlist import get_central_ext_index

central_ext_index = get_central_ext_index(extended_batch, ptr)
atype = extended_atype[central_ext_index]
else:
atype = extended_atype[central_ext_index]

fit_ret = self.fitting_net.forward_flat(
descriptor,
atype,
batch,
ptr,
gr=rot_mat,
g2=g2,
h2=h2,
fparam=fparam,
aparam=aparam,
)
fit_ret = self.apply_out_stat(fit_ret, atype)

atom_mask = self.make_atom_mask(atype).to(torch.int32)
if self.atom_excl is not None:
atom_mask *= self.atom_excl(atype.unsqueeze(0)).squeeze(0)

for kk in fit_ret.keys():
out_shape = fit_ret[kk].shape
out_shape2 = 1
for ss in out_shape[1:]:
out_shape2 *= ss
fit_ret[kk] = (
fit_ret[kk].reshape([out_shape[0], out_shape2]) * atom_mask[:, None]
).view(out_shape)
fit_ret["mask"] = atom_mask

if self.enable_eval_fitting_last_layer_hook:
if "middle_output" in fit_ret:
self.eval_fitting_last_layer_list.append(
fit_ret.pop("middle_output").detach()
)

return fit_ret

def compute_or_load_stat(
self,
sampled_func: Callable[[], list[dict]],
Expand Down
126 changes: 126 additions & 0 deletions deepmd/pt/model/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,132 @@ def forward(
sw.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION) if sw is not None else None,
)

def forward_flat(
self,
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
extended_batch: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor,
batch: torch.Tensor,
ptr: torch.Tensor,
fparam: torch.Tensor | None = None,
central_ext_index: torch.Tensor | None = None,
nlist_ext: torch.Tensor | None = None,
a_nlist: torch.Tensor | None = None,
a_nlist_ext: torch.Tensor | None = None,
nlist_mask: torch.Tensor | None = None,
a_nlist_mask: torch.Tensor | None = None,
edge_index: torch.Tensor | None = None,
angle_index: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Compute the descriptor with flat batch format.

Parameters
----------
extended_coord : torch.Tensor
Extended coordinates [total_extended_atoms, 3].
extended_atype : torch.Tensor
Extended atom types [total_extended_atoms].
extended_batch : torch.Tensor
Frame assignment for extended atoms [total_extended_atoms].
nlist : torch.Tensor
Neighbor list [total_atoms, nnei].
mapping : torch.Tensor
Extended atom -> local flat index mapping [total_extended_atoms].
batch : torch.Tensor
Frame assignment for local atoms [total_atoms].
ptr : torch.Tensor
Frame boundaries [nframes + 1].
fparam : torch.Tensor | None
Frame parameters [nframes, ndf].
central_ext_index : torch.Tensor | None
Extended-atom indices corresponding to local atoms.
nlist_ext, a_nlist_ext : torch.Tensor | None
Edge and angle neighbor lists indexing concatenated extended atoms.
nlist_mask, a_nlist_mask : torch.Tensor | None
Valid-neighbor masks for flat edge and angle neighbor lists.
edge_index, angle_index : torch.Tensor | None
Dynamic graph indices produced by the flat graph preprocessor.

Returns
-------
result : dict[str, torch.Tensor]
Dictionary containing:
- 'descriptor': [total_atoms, descriptor_dim]
- 'rot_mat': [total_atoms, e_dim, 3] or None
- 'g2': edge embedding or None
- 'h2': pair representation or None
"""
extended_coord = extended_coord.to(dtype=self.prec)

# Flat batches embed all extended atoms, then gather central atoms.
node_ebd_ext = self.type_embedding(
extended_atype
) # [total_extended_atoms, tebd_dim]

if self.add_chg_spin_ebd:
assert fparam is not None
assert self.chg_embedding is not None
assert self.spin_embedding is not None

# Expand frame-level charge/spin parameters to extended atoms.
charge = fparam[extended_batch, 0].to(dtype=torch.int64) + 100
spin = fparam[extended_batch, 1].to(dtype=torch.int64)
chg_ebd = self.chg_embedding(charge)
spin_ebd = self.spin_embedding(spin)
sys_cs_embd = self.act(
self.mix_cs_mlp(torch.cat((chg_ebd, spin_ebd), dim=-1))
)
node_ebd_ext = node_ebd_ext + sys_cs_embd

if central_ext_index is None:
from deepmd.pt.utils.nlist import get_central_ext_index

central_ext_index = get_central_ext_index(extended_batch, ptr)
node_ebd_inp = node_ebd_ext[central_ext_index]

node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows.forward_flat(
nlist,
extended_coord,
extended_atype,
extended_batch,
node_ebd_ext,
mapping,
batch,
ptr,
central_ext_index=central_ext_index,
nlist_ext=nlist_ext,
a_nlist=a_nlist,
a_nlist_ext=a_nlist_ext,
nlist_mask=nlist_mask,
a_nlist_mask=a_nlist_mask,
edge_index=edge_index,
angle_index=angle_index,
)

if self.concat_output_tebd:
node_ebd = torch.cat([node_ebd, node_ebd_inp], dim=-1)

return {
"descriptor": node_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION),
"rot_mat": (
rot_mat.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if rot_mat is not None
else None
),
"g2": (
edge_ebd.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if edge_ebd is not None
else None
),
"h2": (
h2.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION)
if h2 is not None
else None
),
}

@classmethod
def update_sel(
cls,
Expand Down
Loading