diff --git a/.github/workflows/copilot-setup-steps.yml b/.github/workflows/copilot-setup-steps.yml index 55b456877f..ba111a43df 100644 --- a/.github/workflows/copilot-setup-steps.yml +++ b/.github/workflows/copilot-setup-steps.yml @@ -52,7 +52,7 @@ jobs: run: uv pip install --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu - name: Build Python package - run: uv pip install -e .[cpu,test] + run: uv pip install -e .[cpu,test,torch] - name: Install prek tools run: uv tool install prek diff --git a/.github/workflows/test_cc.yml b/.github/workflows/test_cc.yml index b5c9166a9b..65a504fd88 100644 --- a/.github/workflows/test_cc.yml +++ b/.github/workflows/test_cc.yml @@ -46,7 +46,10 @@ jobs: run: | source/install/uv_with_retry.sh pip install --system --group pin_tensorflow_cpu --group pin_pytorch_cpu --group pin_jax_cpu --torch-backend cpu export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') - source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax] mpi4py mpich + export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') + source/install/uv_with_retry.sh pip install --system -e .[cpu,test,lmp,jax,torch] mpi4py mpich + env: + DP_ENABLE_PYTORCH: 1 - name: Convert models run: source/tests/infer/convert-models.sh # https://github.com/actions/runner-images/issues/9491 diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index 461d972f57..c723390266 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -31,7 +31,7 @@ jobs: source/install/uv_with_retry.sh pip install --system openmpi --group pin_tensorflow_cpu --group pin_pytorch_cpu --torch-backend cpu export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') - source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py --group pin_jax_cpu + source/install/uv_with_retry.sh pip install --system -e .[test,jax,torch] mpi4py --group pin_jax_cpu source/install/uv_with_retry.sh pip install --system --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cpu/paddlepaddle/" --index-url https://pypi.org/simple --trusted-host www.paddlepaddle.org.cn --trusted-host paddlepaddle.org.cn paddlepaddle==3.4.0.dev20260310 env: # Please note that uv has some issues with finding diff --git a/backend/find_pytorch.py b/backend/find_pytorch.py index d50f57bf5e..ca9ac4daf1 100644 --- a/backend/find_pytorch.py +++ b/backend/find_pytorch.py @@ -136,6 +136,7 @@ def get_pt_requirement(pt_version: str = "") -> dict: if pt_version != "" # https://github.com/pytorch/pytorch/commit/7e0c26d4d80d6602aed95cb680dfc09c9ce533bc else "torch>=2.1.0", + "e3nn>=0.5.9", *mpi_requirement, *cibw_requirement, ], diff --git a/deepmd/dpmodel/utils/dist_check.py b/deepmd/dpmodel/utils/dist_check.py new file mode 100644 index 0000000000..3c5eb25681 --- /dev/null +++ b/deepmd/dpmodel/utils/dist_check.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Minimum pairwise distance check for frame validity filtering.""" + +from __future__ import ( + annotations, +) + +import numpy as np + +_MIN_PAIR_DIST_BLOCK_PAIRS = 262_144 + + +def compute_min_pair_dist_single( + coord: np.ndarray, + box: np.ndarray | None, + atype: np.ndarray, + stop_below: float | None = None, +) -> float: + """Compute the minimum pairwise atomic distance for a single frame. + + Parameters + ---------- + coord : np.ndarray + Atomic coordinates, flattened with shape (natoms * 3,) + or reshaped as (natoms, 3). + box : np.ndarray or None + Box vectors with shape (9,) for PBC, or None for non-PBC. + atype : np.ndarray + Atom types with shape (natoms,). Virtual atoms (type < 0) + are excluded from the distance check. + stop_below : float or None + Optional early-stop threshold. If a block has any pair closer + than this value, the block minimum is returned immediately. + + Returns + ------- + float + Minimum pairwise distance. Returns inf if fewer than 2 + real atoms exist. + """ + coord = coord.reshape(-1, 3) + + # === Step 1. Filter out virtual atoms === + real_mask = atype.ravel() >= 0 + real_coord = coord[real_mask] + n_real = real_coord.shape[0] + if n_real < 2: + return float("inf") + + # === Step 2. Prepare minimum image convention for PBC === + if box is not None: + cell = box.reshape(3, 3) + inv_cell = np.linalg.inv(cell) + else: + cell = None + inv_cell = None + + # === Step 3. Compute distances in bounded row blocks === + block_size = max(1, min(n_real, _MIN_PAIR_DIST_BLOCK_PAIRS // n_real)) + min_dist_sq = float("inf") + stop_dist_sq = ( + float(stop_below) * float(stop_below) + if stop_below is not None and stop_below > 0.0 + else None + ) + for start in range(0, n_real, block_size): + stop = min(start + block_size, n_real) + diff = real_coord[np.newaxis, :, :] - real_coord[start:stop, np.newaxis, :] + + if cell is not None and inv_cell is not None: + frac_diff = diff @ inv_cell + frac_diff -= np.round(frac_diff) + diff = frac_diff @ cell + + dist_sq = np.sum(diff * diff, axis=-1) + rows = np.arange(stop - start, dtype=np.int64) + dist_sq[rows, start + rows] = np.inf + min_dist_sq = min(min_dist_sq, float(dist_sq.min())) + if min_dist_sq == 0.0 or ( + stop_dist_sq is not None and min_dist_sq < stop_dist_sq + ): + break + + return float(np.sqrt(min_dist_sq)) diff --git a/deepmd/dpmodel/utils/lmdb_data.py b/deepmd/dpmodel/utils/lmdb_data.py index 29253263a6..196ea4f291 100644 --- a/deepmd/dpmodel/utils/lmdb_data.py +++ b/deepmd/dpmodel/utils/lmdb_data.py @@ -21,6 +21,9 @@ import msgpack import numpy as np +from deepmd.dpmodel.utils.dist_check import ( + compute_min_pair_dist_single, +) from deepmd.env import ( GLOBAL_ENER_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION, @@ -597,6 +600,29 @@ def __getitem__(self, index: int) -> dict[str, Any]: frame["natoms"] = fallback frame["real_natoms_vec"] = fallback + if "min_pair_dist" in self._data_requirements and "min_pair_dist" not in frame: + box = frame.get("box") + if box is not None and np.allclose(box, 0.0): + box = None + req = self._data_requirements["min_pair_dist"] + min_pair_dist = float( + req.get("default", 0.0) + if isinstance(req, dict) + else getattr(req, "default", 0.0) + ) + frame["find_min_pair_dist"] = np.float32(1.0) + frame["min_pair_dist"] = np.array( + [ + compute_min_pair_dist_single( + frame["coord"], + box, + frame["atype"], + stop_below=min_pair_dist, + ) + ], + dtype=self._resolve_dtype("min_pair_dist"), + ) + # Add find_* flags for all data keys present in the frame. # Core structural keys and metadata are excluded — only label-like # and auxiliary data keys get find_* flags. diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index cc299be147..b7b493f342 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -355,9 +355,16 @@ def extend_coord_with_ghosts( shift_idx = xp.take(xyz, xp.argsort(xp.linalg.vector_norm(xyz, axis=1)), axis=0) ns, _ = shift_idx.shape nall = ns * nloc - # shift_vec = xp.einsum("sd,fdk->fsk", shift_idx, cell) - shift_vec = xp.tensordot(shift_idx, cell, axes=([1], [1])) - shift_vec = xp.permute_dims(shift_vec, (1, 0, 2)) + if array_api_compat.is_jax_namespace(xp): + # Avoid JAX internal errors in tensordot. + shift_vec = xp.sum( + shift_idx[xp.newaxis, :, :, xp.newaxis] * cell[:, xp.newaxis, :, :], + axis=2, + ) + else: + # shift_vec = xp.einsum("sd,fdk->fsk", shift_idx, cell) + shift_vec = xp.tensordot(shift_idx, cell, axes=([1], [1])) + shift_vec = xp.permute_dims(shift_vec, (1, 0, 2)) extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] extend_atype = xp.tile(atype[:, :, xp.newaxis], (1, ns, 1)) extend_aidx = xp.tile(aidx[:, :, xp.newaxis], (1, ns, 1)) diff --git a/deepmd/pt/entrypoints/freeze_pt2.py b/deepmd/pt/entrypoints/freeze_pt2.py new file mode 100644 index 0000000000..7c643e023e --- /dev/null +++ b/deepmd/pt/entrypoints/freeze_pt2.py @@ -0,0 +1,568 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""DPA4 / SeZM → AOTInductor ``.pt2`` freeze path for the pt backend. + +SeZM relies on a nested ``autograd.grad(create_graph=True)`` inside +``fit_output_to_model_output``; TorchScript cannot represent that +graph, so DPA4 / SeZM checkpoints are routed through AOTInductor instead. +The output archive layout matches the ``pt_expt`` convention and is +consumed directly by ``DeepPotPTExpt.cc`` without any C++ change. + +Tracing runs on CPU (``make_fx`` with ``_allow_non_fake_inputs=True`` +is brittle on CUDA because the proxy-tensor dispatcher does not set +up CUDA streams for the captured parameters). The compiled package +is moved to the target device via ``move_to_device_pass`` before +``aoti_compile_and_package``. + +``.pt2`` I/O is always float64, matching the C++ contract in +``DeepPotPTExpt::compute`` where LAMMPS coordinates are unconditionally +cast to ``torch::kFloat64``. SeZM's own ``_input_type_cast`` bridges +fp64 inputs to whatever internal compute dtype the checkpoint uses. +""" + +from __future__ import ( + annotations, +) + +import json +import logging +import zipfile +from copy import ( + deepcopy, +) +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt.utils.env import ( + DEVICE, +) +from deepmd.utils.model_branch_dict import ( + get_model_dict, +) + +log = logging.getLogger(__name__) + + +def _model_has_spin(model: torch.nn.Module) -> bool: + """Return whether ``model`` uses the spin lower interface.""" + has_spin = getattr(model, "has_spin", False) + return bool(has_spin() if callable(has_spin) else has_spin) + + +def _get_model_ntypes(model: torch.nn.Module) -> int: + """Return atom type count even when the exported type map is empty.""" + type_map = list(model.get_type_map()) + if type_map: + return len(type_map) + descriptor = model.get_descriptor() + return int(descriptor.get_ntypes()) + + +def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: + """Remove deferred shape assertions from spin export graphs. + + The spin lower path slices tensors using both ``nall`` and ``nloc`` after + virtual atom expansion. ``torch.export`` may turn valid dynamic cases into + deferred ``Ne(nall, nloc)`` assertions, even though the graph works for both + NoPBC and ghost-atom inputs. The generic pt_expt spin exporter applies the + same cleanup. + """ + graph = graph_module.graph + for node in list(graph.nodes): + if ( + node.op == "call_function" + and node.target is torch.ops.aten._assert_scalar.default + ): + graph.erase_node(node) + graph.eliminate_dead_code() + graph_module.recompile() + + +def _extract_state_and_params( + ckpt: Any, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Unwrap a ``torch.load`` result into ``(state_dict, model_params)``. + + Accepts both the training-wrapper layout (weights under a top-level + ``"model"`` key) and a bare state dict. + """ + inner = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt + if not isinstance(inner, dict): + raise ValueError("Unsupported checkpoint: expected a dict-like state dict.") + extra = inner.get("_extra_state") or {} + params = extra.get("model_params") + if not isinstance(params, dict): + raise ValueError("Unsupported checkpoint: missing '_extra_state.model_params'.") + return inner, params + + +def is_sezm_checkpoint(ckpt_path: str) -> bool: + """Best-effort detection used by the CLI to route DPA4 / SeZM checkpoints. + + Returns ``False`` for unreadable files or non-SeZM checkpoints; no + exception leaks out so the caller can treat this as a pure routing + signal. + """ + try: + raw = torch.load(ckpt_path, map_location="cpu", weights_only=False) + except Exception: + return False + try: + _, params = _extract_state_and_params(raw) + except ValueError: + return False + if "model_dict" in params: + return any( + str(branch_params.get("type", "")).lower() in ("sezm", "dpa4") + for branch_params in params["model_dict"].values() + ) + return str(params.get("type", "")).lower() in ("sezm", "dpa4") + + +def _select_model_head( + state_dict: dict[str, Any], + params: dict[str, Any], + head: str | None, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Extract a single selected model branch from a checkpoint.""" + if "model_dict" not in params: + if head is not None: + raise NotImplementedError( + "SeZM .pt2 freeze does not yet support head selection for single-task checkpoints; pass head=None." + ) + return state_dict, params + + model_alias_dict, _ = get_model_dict(params["model_dict"]) + model_keys = list(params["model_dict"]) + if head is None and "Default" in model_alias_dict: + head = "Default" + log.info( + "Using default head %s for multitask SeZM freeze.", model_alias_dict[head] + ) + if head is None: + raise ValueError( + "Head must be set for multitask SeZM/DPA4 freeze. " + f"Available heads are: {model_keys}." + ) + if head not in model_alias_dict: + head_lower = head.lower() + for key in model_alias_dict: + if key.lower() == head_lower: + head = key + break + if head not in model_alias_dict: + raise ValueError( + f"No head or alias named {head!r} in model. Available heads are: {model_keys}." + ) + + branch = model_alias_dict[head] + branch_params = deepcopy(params["model_dict"][branch]) + branch_state: dict[str, Any] = { + "_extra_state": deepcopy(state_dict.get("_extra_state", {})), + } + branch_state["_extra_state"]["model_params"] = branch_params + prefix = f"model.{branch}." + for key, value in state_dict.items(): + if key.startswith(prefix): + branch_state[key.replace(prefix, "model.Default.")] = value + return branch_state, branch_params + + +def _to_py_list(value: Any) -> Any: + """Coerce torch / numpy scalars into JSON-friendly Python values.""" + if value is None: + return None + if isinstance(value, torch.Tensor): + return value.detach().cpu().tolist() + if isinstance(value, np.ndarray): + return value.tolist() + if isinstance(value, (list, tuple)): + return list(value) + if isinstance(value, (int, float, bool, str)): + return value + raise TypeError(f"Cannot JSON-serialize value of type {type(value)!r}") + + +def _collect_metadata( + model: torch.nn.Module, + output_keys: list[str], + is_spin: bool | None = None, +) -> dict: + """Assemble the flat metadata dict expected by :class:`DeepPotPTExpt`. + + Mirrors the reader contract at ``source/api_cc/src/DeepPotPTExpt.cc`` and + the metadata-only load path in ``deepmd.pt_expt.infer.deep_eval.DeepEval``: + every field consumed by C++ LAMMPS inference **and** every field + consumed by ``DeepEval._init_from_metadata`` must be present here. + + ``output_keys`` is the insertion order that the loader zips with + ``AOTIModelPackageLoader::run``'s flat output vector. + """ + if is_spin is None: + is_spin = _model_has_spin(model) + fitting_output_defs: list[dict[str, Any]] = [] + for vdef in model.atomic_output_def().get_data().values(): + fitting_output_defs.append( + { + "name": vdef.name, + "shape": list(vdef.shape), + "reducible": vdef.reducible, + "r_differentiable": vdef.r_differentiable, + "c_differentiable": vdef.c_differentiable, + "atomic": vdef.atomic, + # OutputVariableCategory is an IntEnum; force plain int for + # deterministic JSON serialisation across Python versions. + "category": int(vdef.category), + "r_hessian": vdef.r_hessian, + "magnetic": bool(vdef.magnetic or (is_spin and vdef.name == "energy")), + "intensive": vdef.intensive, + } + ) + metadata = { + "type_map": list(model.get_type_map()), + "ntypes": _get_model_ntypes(model), + "rcut": float(model.get_rcut()), + "sel": [int(s) for s in model.get_sel()], + "dim_fparam": int(model.get_dim_fparam()), + "dim_aparam": int(model.get_dim_aparam()), + "dim_chg_spin": int(model.get_dim_chg_spin()), + "mixed_types": bool(model.mixed_types()), + "has_default_fparam": bool(model.has_default_fparam()), + "default_fparam": _to_py_list(model.get_default_fparam()), + "default_chg_spin": _to_py_list(model.get_default_chg_spin()), + "output_keys": list(output_keys), + "fitting_output_defs": fitting_output_defs, + # sel_type feeds DeepEval.get_sel_type() in metadata-only mode. + # SeZM energy models return [] (every type selected). + "sel_type": [int(t) for t in model.get_sel_type()], + "is_spin": bool(is_spin), + } + if is_spin: + metadata["ntypes_spin"] = int(model.spin.get_ntypes_spin()) + metadata["use_spin"] = [bool(v) for v in model.spin.use_spin] + return metadata + + +def _make_sample_inputs( + model: torch.nn.Module, + nframes: int, + nloc: int, + device: torch.device, + has_spin: bool = False, +) -> tuple[torch.Tensor | None, ...]: + """Build representative ``forward_common_lower`` inputs for tracing. + + Tensors are float64 / int64 (matching the ``.pt2`` I/O contract). + """ + rcut = float(model.get_rcut()) + sel = list(model.get_sel()) + ntypes = len(model.get_type_map()) + if ntypes == 0: + ntypes = int(model.get_descriptor().get_ntypes()) + if ntypes <= 0: + raise ValueError("SeZM .pt2 freeze requires at least one atom type.") + dim_fparam = int(model.get_dim_fparam()) + dim_aparam = int(model.get_dim_aparam()) + dim_chg_spin = int(model.get_dim_chg_spin()) + mixed_types = bool(model.mixed_types()) + + box_size = rcut * 3.0 + box = np.eye(3, dtype=np.float64) * box_size + box_np = box.reshape(1, 9) + + rng = np.random.default_rng(42) + coord_np = rng.random((nframes, nloc, 3), dtype=np.float64) * box_size * 0.5 + coord_np += box_size * 0.25 # centre roughly in the middle of the cell + + atype_np = np.zeros((nframes, nloc), dtype=np.int32) + for i in range(nloc): + atype_np[:, i] = i % ntypes + spin_np = np.zeros((nframes, nloc, 3), dtype=np.float64) + if has_spin: + atom_idx = np.arange(nloc, dtype=np.float64).reshape(1, nloc) + spin_np[:, :, 0] = 0.10 + 0.01 * atom_idx + spin_np[:, :, 1] = 0.20 + 0.02 * atom_idx + spin_np[:, :, 2] = 0.05 + + coord_normalized = normalize_coord( + coord_np.reshape(nframes, nloc, 3), + np.tile(box.reshape(1, 3, 3), (nframes, 1, 1)), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype_np, np.tile(box_np, (nframes, 1)), rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=not mixed_types, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + + ext_coord = torch.tensor(extended_coord, dtype=torch.float64, device=device) + ext_atype = torch.tensor(extended_atype, dtype=torch.int64, device=device) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=device) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=device) + if has_spin: + extended_spin = np.take_along_axis(spin_np, mapping[..., None], axis=1) + ext_spin = torch.tensor(extended_spin, dtype=torch.float64, device=device) + fparam = ( + torch.zeros(nframes, dim_fparam, dtype=torch.float64, device=device) + if dim_fparam > 0 + else None + ) + aparam = ( + torch.zeros(nframes, nloc, dim_aparam, dtype=torch.float64, device=device) + if dim_aparam > 0 + else None + ) + charge_spin = None + if dim_chg_spin > 0: + default_chg_spin = model.get_default_chg_spin() + if default_chg_spin is None: + raise ValueError( + "SeZM .pt2 freeze requires default_chg_spin when charge/spin " + "conditioning is enabled; runtime charge_spin input is not exposed." + ) + charge_spin = ( + default_chg_spin.to(device=device, dtype=torch.float64) + .view(1, dim_chg_spin) + .expand(nframes, -1) + .contiguous() + ) + if has_spin: + if charge_spin is not None: + return ( + ext_coord, + ext_atype, + ext_spin, + nlist_t, + mapping_t, + fparam, + aparam, + charge_spin, + ) + return ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam + if charge_spin is not None: + return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin + return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam + + +def _resolve_nframes( + model: torch.nn.Module, + nloc: int, + device: torch.device, + start: int = 2, + has_spin: bool = False, +) -> tuple[int, tuple[torch.Tensor | None, ...]]: + """Pick an ``nframes`` that does not collide with any other dim size. + + ``torch.export``'s duck-sizing unifies symbolic dims whose concrete + sample values match; if ``nframes`` happens to equal, say, the + spatial ``3`` or the virial ``9``, the ExportedProgram rejects + later calls whose ``nframes`` differs. Bumping ``nframes`` until + no collision is left keeps the export safe. + """ + nframes = start + sample = _make_sample_inputs( + model, + nframes=nframes, + nloc=nloc, + device=device, + has_spin=has_spin, + ) + other_dims: set[int] = set() + for t in sample: + if t is not None: + other_dims.update(t.shape[1:]) + while nframes in other_dims: + nframes += 1 + if nframes != start: + sample = _make_sample_inputs( + model, + nframes=nframes, + nloc=nloc, + device=device, + has_spin=has_spin, + ) + return nframes, sample + + +def _build_dynamic_shapes( + sample_inputs: tuple[torch.Tensor | None, ...], +) -> tuple: + """Positional ``dynamic_shapes`` for the traced + ``(ext_coord, ext_atype, nlist, mapping, fparam, aparam)`` signature. + """ + nframes_dim = torch.export.Dim("nframes", min=1) + has_spin = ( + len(sample_inputs) >= 7 + and sample_inputs[2] is not None + and sample_inputs[2].is_floating_point() + ) + has_charge_spin = (has_spin and len(sample_inputs) == 8) or ( + not has_spin and len(sample_inputs) == 7 + ) + # Spin export currently generates a valid lower-bound guard from its + # virtual-atom split/concat pattern. Matching the bound keeps export strict, + # while `_strip_shape_assertions` removes the spurious deferred guards later. + nall_dim = torch.export.Dim("nall", min=4 if has_spin else 1) + nloc_dim = torch.export.Dim("nloc", min=1) + fparam = sample_inputs[5] if has_spin else sample_inputs[4] + aparam = sample_inputs[6] if has_spin else sample_inputs[5] + if has_spin: + shapes = ( + {0: nframes_dim, 1: nall_dim}, # extended_coord + {0: nframes_dim, 1: nall_dim}, # extended_atype + {0: nframes_dim, 1: nall_dim}, # extended_spin + {0: nframes_dim, 1: nloc_dim}, # nlist + {0: nframes_dim, 1: nall_dim}, # mapping + {0: nframes_dim} if fparam is not None else None, + {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, + ) + if has_charge_spin: + shapes = (*shapes, {0: nframes_dim}) + return shapes + shapes = ( + {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) + {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) + {0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei) + {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) + {0: nframes_dim} if fparam is not None else None, + {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, + ) + if has_charge_spin: + shapes = (*shapes, {0: nframes_dim}) + return shapes + + +def freeze_sezm_to_pt2( + ckpt_path: str, + out_path: str, + *, + device: torch.device | None = None, + head: str | None = None, +) -> None: + """Freeze a SeZM checkpoint into an AOTInductor ``.pt2`` archive. + + Parameters + ---------- + ckpt_path + Path to the SeZM training checkpoint (``.pt``). + out_path + Destination file. A ``.pt2`` suffix is expected. + device + Target device for the compiled shared library. Defaults to + :data:`DEVICE`. Tracing itself always runs on CPU. + head + Model head to export from a multi-task checkpoint. If omitted, the + ``Default`` head is used when present; otherwise multi-task checkpoints + must pass an explicit head. Single-task checkpoints must pass ``None``. + """ + from torch._inductor import ( + aoti_compile_and_package, + ) + + target_device = device if device is not None else DEVICE + + raw = torch.load(ckpt_path, map_location="cpu", weights_only=False) + state_dict, params = _extract_state_and_params(raw) + state_dict, params = _select_model_head(state_dict, params, head) + + model_type = str(params.get("type", "")).lower() + if model_type not in ("sezm", "dpa4"): + raise ValueError( + f"freeze_sezm_to_pt2 expects a SeZM/DPA4 checkpoint, got type={params.get('type')!r}." + ) + + model = get_model(params) + is_spin = _model_has_spin(model) + ModelWrapper(model).load_state_dict(state_dict) + model.eval() + model.to("cpu") + + _, sample_inputs_cpu = _resolve_nframes( + model, + nloc=7, + device=torch.device("cpu"), + has_spin=is_spin, + ) + + # do_atomic_virial=True pulls every key that DeepPotPTExpt may read + # (energy, energy_redu, energy_derv_r, energy_derv_c, energy_derv_c_redu) + # into the traced graph. + traced = model.forward_common_lower_exportable( + *sample_inputs_cpu, + do_atomic_virial=True, + ) + + # Output key order is taken from a concrete run; Python dict order + # is stable and matches what DeepPotPTExpt::extract_outputs zips + # against AOTIModelPackageLoader::run's output vector. + with torch.no_grad(): + sample_out = traced(*sample_inputs_cpu) + output_keys = list(sample_out.keys()) + + exported = torch.export.export( + traced, + sample_inputs_cpu, + dynamic_shapes=_build_dynamic_shapes(sample_inputs_cpu), + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + if is_spin: + _strip_shape_assertions(exported.graph_module) + + # move_to_device_pass handles FakeTensor device propagation cleanly; + # a naive .to(device) on the exported program does not. + if target_device.type != "cpu": + from torch.export.passes import ( + move_to_device_pass, + ) + + exported = move_to_device_pass(exported, target_device) + + out_path_str = str(out_path) + aoti_compile_and_package(exported, package_path=out_path_str) + + metadata = _collect_metadata(model, output_keys=output_keys, is_spin=is_spin) + with zipfile.ZipFile(out_path_str, "a") as zf: + zf.writestr("model/extra/metadata.json", json.dumps(metadata)) + # The raw training params are preserved so `dp change-bias` and + # other downstream tooling can recover the exact training config. + # ``default=str`` is a safety net for exotic nested values. + zf.writestr( + "model/extra/model_def_script.json", + json.dumps(params, default=str), + ) + + log.info( + "Saved SeZM .pt2 to %s (device=%s, output_keys=%s)", + out_path_str, + target_device, + output_keys, + ) + + +__all__ = [ + "freeze_sezm_to_pt2", + "is_sezm_checkpoint", +] diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 7b45c46333..bf5c49a4d6 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -454,6 +454,26 @@ def freeze( output: str = "frozen_model.pth", head: str | None = None, ) -> None: + # DPA4 / SeZM checkpoints are routed to the AOTInductor .pt2 exporter + from deepmd.pt.entrypoints.freeze_pt2 import ( + freeze_sezm_to_pt2, + is_sezm_checkpoint, + ) + + output_path = Path(output) + if is_sezm_checkpoint(model): + out_pt2 = str(output_path.with_suffix(".pt2")) + freeze_sezm_to_pt2(model, out_pt2, head=head) + log.info( + "Detected DPA4 / SeZM checkpoint '%s'; saved AOTInductor archive to %s", + model, + out_pt2, + ) + return + + # TorchScript frozen models use the .pth suffix by convention. + output = str(output_path.with_suffix(".pth")) + tester = inference.Tester(model, head=head) model = tester.model model.eval() @@ -644,7 +664,9 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None: FLAGS.model = str(checkpoint_path.joinpath(latest_ckpt_file)) else: FLAGS.model = FLAGS.checkpoint_folder - FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth")) + # Output suffix is decided inside freeze(): SeZM checkpoints + # produce ``.pt2`` (AOTInductor), every other backend produces + # the legacy ``.pth`` (TorchScript). freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head) elif FLAGS.command == "change-bias": change_bias( diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 2e30b8574a..695d87a6f0 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -84,6 +84,18 @@ log = logging.getLogger(__name__) +def _is_sezm_model_params(model_params: dict[str, Any]) -> bool: + """Return whether the params describe a SeZM / DPA4 model.""" + model_type = str(model_params.get("type", "")).lower() + if model_type in {"sezm", "dpa4", "sezm_spin"}: + return True + descriptor = model_params.get("descriptor") + if isinstance(descriptor, dict): + descriptor_type = str(descriptor.get("type", "")).lower() + return descriptor_type in {"sezm", "dpa4"} + return False + + class DeepEval(DeepEvalBackend): """PyTorch backend implementation of DeepEval. @@ -167,7 +179,8 @@ def __init__( ] = state_dict[item].clone() state_dict = state_dict_head model = get_model(self.input_param).to(DEVICE) - if not self.input_param.get("hessian_mode") and not no_jit: + disable_jit = no_jit or _is_sezm_model_params(self.input_param) + if not self.input_param.get("hessian_mode") and not disable_jit: model = torch.jit.script(model) self.dp = ModelWrapper(model) missing, unexpected = self.dp.load_state_dict(state_dict, strict=False) @@ -709,7 +722,10 @@ def eval_typeebd(self) -> np.ndarray: """ out = [] for mm in self.dp.model["Default"].modules(): - if mm.original_name == TypeEmbedNetConsistent.__name__: + if ( + getattr(mm, "original_name", type(mm).__name__) + == TypeEmbedNetConsistent.__name__ + ): out.append(mm(DEVICE)) if not out: raise KeyError("The model has no type embedding networks.") diff --git a/deepmd/pt/loss/__init__.py b/deepmd/pt/loss/__init__.py index 1d25c1e52f..0d4e55a5fa 100644 --- a/deepmd/pt/loss/__init__.py +++ b/deepmd/pt/loss/__init__.py @@ -2,6 +2,9 @@ from .denoise import ( DenoiseLoss, ) +from .dens import ( + DeNSLoss, +) from .dos import ( DOSLoss, ) @@ -24,6 +27,7 @@ __all__ = [ "DOSLoss", + "DeNSLoss", "DenoiseLoss", "EnergyHessianStdLoss", "EnergySpinLoss", diff --git a/deepmd/pt/loss/dens.py b/deepmd/pt/loss/dens.py new file mode 100644 index 0000000000..03e1c297e4 --- /dev/null +++ b/deepmd/pt/loss/dens.py @@ -0,0 +1,477 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch +import torch.nn.functional as F + +from deepmd.pt.loss.ener import ( + EnergyStdLoss, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + GLOBAL_PT_FLOAT_PRECISION, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +class DeNSLoss(EnergyStdLoss): + """ + Joint energy and direct-force/denoising loss for SeZM `dens` mode. + + This loss follows the EquiformerV3 DeNS training semantics: + + - energy is supervised in one global normalized space + - clean atoms predict globally normalized direct forces + - corrupted atoms predict normalized Gaussian noise `epsilon / sigma` + + A batch enters the denoising path with probability `dens_prob`. Otherwise the + batch falls back to clean direct-force supervision while still using the `dens` + head. When only part of the batch is corrupted, each subset loss is weighted by + its atom fraction so the mixed objective matches one full-batch per-atom average. + """ + + def __init__( + self, + starter_learning_rate: float = 1.0, + start_pref_e: float = 20, + limit_pref_e: float = 20, + start_pref_f: float = 20, + limit_pref_f: float = 20, + loss_func: str = "mae", + inference: bool = False, + dens_prob: float = 0.5, + dens_fixed_noise_std: bool = True, + dens_std: float = 0.025, + dens_corrupt_ratio: float | None = 0.5, + dens_denoising_pos_coefficient: float = 10.0, + start_pref_v: float = 0.0, + limit_pref_v: float = 0.0, + start_pref_ae: float = 0.0, + limit_pref_ae: float = 0.0, + start_pref_pf: float = 0.0, + limit_pref_pf: float = 0.0, + start_pref_gf: float = 0.0, + limit_pref_gf: float = 0.0, + numb_generalized_coord: int = 0, + **kwargs: Any, + ) -> None: + unsupported = sorted(key for key in kwargs if key not in {"type"}) + if unsupported: + unsupported_str = ", ".join(unsupported) + raise ValueError(f"Unsupported `dens` loss options: {unsupported_str}.") + if not dens_fixed_noise_std: + raise NotImplementedError( + "`dens_fixed_noise_std=false` is not supported. " + "This matches the current EquiformerV3 DeNS trainer path, " + "which only uses the fixed-noise-std setting." + ) + if not 0.0 <= float(dens_prob) <= 1.0: + raise ValueError("`dens_prob` must be within [0, 1].") + if ( + dens_corrupt_ratio is not None + and not 0.0 <= float(dens_corrupt_ratio) <= 1.0 + ): + raise ValueError("`dens_corrupt_ratio` must be within [0, 1] or None.") + if float(dens_std) <= 0.0: + raise ValueError("`dens_std` must be > 0.") + if float(dens_denoising_pos_coefficient) < 0.0: + raise ValueError("`dens_denoising_pos_coefficient` must be >= 0.") + unsupported_prefactors = ( + float(start_pref_v), + float(limit_pref_v), + float(start_pref_ae), + float(limit_pref_ae), + float(start_pref_pf), + float(limit_pref_pf), + float(start_pref_gf), + float(limit_pref_gf), + float(numb_generalized_coord), + ) + if any(value != 0.0 for value in unsupported_prefactors): + raise ValueError( + "`dens` loss currently supports only energy and force/noise supervision." + ) + super().__init__( + starter_learning_rate=starter_learning_rate, + start_pref_e=start_pref_e, + limit_pref_e=limit_pref_e, + start_pref_f=start_pref_f, + limit_pref_f=limit_pref_f, + start_pref_v=0.0, + limit_pref_v=0.0, + start_pref_ae=0.0, + limit_pref_ae=0.0, + start_pref_pf=0.0, + limit_pref_pf=0.0, + relative_f=None, + enable_atom_ener_coeff=False, + start_pref_gf=0.0, + limit_pref_gf=0.0, + numb_generalized_coord=0, + loss_func=loss_func, + inference=inference, + use_huber=False, + f_use_norm=(loss_func == "mae"), + huber_delta=0.01, + ) + self.dens_prob = float(dens_prob) + self.dens_fixed_noise_std = bool(dens_fixed_noise_std) + self.dens_std = float(dens_std) + self.dens_corrupt_ratio = ( + None if dens_corrupt_ratio is None else float(dens_corrupt_ratio) + ) + self.dens_denoising_pos_coefficient = float(dens_denoising_pos_coefficient) + + @staticmethod + def _canonicalize_vec3_tensor( + tensor: torch.Tensor, + *, + nf: int, + nloc: int, + name: str, + ) -> torch.Tensor: + """Convert `(nf, nloc*3)` or `(nf, nloc, 3)` to `(nf, nloc, 3)`.""" + if tensor.ndim == 3: + if tensor.shape != (nf, nloc, 3): + raise ValueError( + f"`{name}` must have shape ({nf}, {nloc}, 3), got {tuple(tensor.shape)}." + ) + return tensor + if tensor.ndim == 2: + if tensor.shape != (nf, nloc * 3): + raise ValueError( + f"`{name}` must have shape ({nf}, {nloc * 3}) when flattened, got {tuple(tensor.shape)}." + ) + return tensor.view(nf, nloc, 3) + raise ValueError( + f"`{name}` must have shape ({nf}, {nloc}, 3) or ({nf}, {nloc * 3})." + ) + + def _prepare_dens_inputs( + self, + input_dict: dict[str, torch.Tensor], + label: dict[str, torch.Tensor], + *, + enable_dens: bool, + ) -> tuple[ + dict[str, torch.Tensor], + torch.Tensor, + torch.Tensor, + torch.Tensor, + bool, + ]: + """Build noisy coordinates and mixed targets for one forward pass.""" + atype = input_dict["atype"] + nf, nloc = atype.shape[:2] + coord_raw = input_dict["coord"] + coord = self._canonicalize_vec3_tensor( + coord_raw, nf=nf, nloc=nloc, name="coord" + ) + force_label = self._canonicalize_vec3_tensor( + label["force"], nf=nf, nloc=nloc, name="force" + ).to(device=coord.device, dtype=coord.dtype) + + use_dens = bool( + enable_dens + and self.dens_prob > 0.0 + and torch.rand( + (), dtype=GLOBAL_PT_FLOAT_PRECISION, device=coord.device + ).item() + < self.dens_prob + ) + noise_mask = torch.zeros((nf, nloc), dtype=torch.bool, device=coord.device) + noise_vec = torch.zeros_like(coord) + if use_dens: + if self.dens_corrupt_ratio is None: + noise_mask = torch.ones( + (nf, nloc), dtype=torch.bool, device=coord.device + ) + else: + noise_mask = ( + torch.rand( + (nf, nloc), dtype=GLOBAL_PT_FLOAT_PRECISION, device=coord.device + ) + < self.dens_corrupt_ratio + ) + noise_vec = torch.randn_like(coord) * self.dens_std + noise_vec = noise_vec * noise_mask.unsqueeze(-1) + coord_model = coord + noise_vec + + # DeNS predicts normalized noise epsilon / sigma for corrupted atoms. + noise_target = noise_vec / self.dens_std + + model_input = dict(input_dict) + if coord_raw.ndim == 2: + model_input["coord"] = coord_model.view(nf, nloc * 3) + else: + model_input["coord"] = coord_model + model_input["noise_mask"] = noise_mask + if use_dens: + model_input["force_input"] = force_label + return model_input, force_label, noise_target, noise_mask, use_dens + + @staticmethod + def _get_sezm_atomic_model(model: torch.nn.Module) -> Any: + """Return the SeZM atomic model used by `dens` training.""" + atomic_model = getattr(model, "atomic_model", None) + if atomic_model is None: + raise TypeError("SeZM `dens` loss expects `model.atomic_model` to exist.") + required = ( + "norm_dens_energy", + "denorm_dens_energy", + "norm_dens_force", + "denorm_dens_force", + ) + missing = [name for name in required if not hasattr(atomic_model, name)] + if missing: + missing_str = ", ".join(sorted(missing)) + raise TypeError( + f"SeZM `dens` loss requires atomic_model methods: {missing_str}." + ) + return atomic_model + + def _compute_force_subset_loss( + self, + force_pred: torch.Tensor, + force_target: torch.Tensor, + coefficient: float | torch.Tensor, + ) -> torch.Tensor: + """Compute one clean-force or denoising-force subset loss.""" + if force_pred.numel() == 0: + return force_pred.new_zeros((), dtype=GLOBAL_PT_FLOAT_PRECISION) + diff_f = (force_target - force_pred).reshape(-1) + if self.loss_func == "mse": + subset_loss = torch.mean(torch.square(diff_f)) + elif self.loss_func == "mae": + subset_loss = torch.linalg.vector_norm( + (force_target - force_pred).reshape(-1, 3), + ord=2, + dim=1, + keepdim=True, + ).mean() + else: + raise NotImplementedError( + f"Loss type {self.loss_func} is not implemented for `dens` force loss." + ) + return (coefficient * subset_loss).to(GLOBAL_PT_FLOAT_PRECISION) + + def forward( + self, + input_dict: dict[str, torch.Tensor], + model: torch.nn.Module, + label: dict[str, torch.Tensor], + natoms: int, + learning_rate: float, + mae: bool = False, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]: + """Return loss on SeZM `dens` energy and direct-force/noise outputs.""" + model_input, force_label, noise_target, noise_mask, use_dens = ( + self._prepare_dens_inputs( + input_dict, + label, + enable_dens=model.training, + ) + ) + model_pred = model(**model_input) + atomic_model = self._get_sezm_atomic_model(model) + + coef = learning_rate / self.starter_learning_rate + pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef + pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef + denoise_pref = self.dens_denoising_pos_coefficient + + loss = force_label.new_zeros((), dtype=env.GLOBAL_PT_FLOAT_PRECISION) + more_loss: dict[str, torch.Tensor] = {} + atom_norm = 1.0 / natoms + + if self.has_e and "energy" in model_pred and "energy" in label: + energy_pred = model_pred.get("energy_norm", model_pred["energy"]) + energy_label = label["energy"].to( + device=energy_pred.device, dtype=energy_pred.dtype + ) + energy_label_norm = atomic_model.norm_dens_energy( + energy_label, + input_dict["atype"], + ) + if "energy_norm" in model_pred: + energy_pred_phys = model_pred["energy"].to( + device=energy_pred.device, + dtype=energy_pred.dtype, + ) + else: + energy_pred_phys = atomic_model.denorm_dens_energy( + energy_pred, + input_dict["atype"], + ) + find_energy = label.get("find_energy", 0.0) + pref_e = pref_e * find_energy + if self.loss_func == "mse": + l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label_norm)) + if not self.inference: + more_loss["l2_ener_loss"] = self.display_if_exist( + l2_ener_loss.detach(), + find_energy, + ) + loss += atom_norm * (pref_e * l2_ener_loss) + rmse_e = ( + torch.mean(torch.square(energy_pred_phys - energy_label)).sqrt() + * atom_norm + ) + more_loss["rmse_e"] = self.display_if_exist( + rmse_e.detach(), + find_energy, + ) + elif self.loss_func == "mae": + l1_ener_loss = F.l1_loss( + energy_pred.reshape(-1), + energy_label_norm.reshape(-1), + reduction="mean", + ) + loss += atom_norm * (pref_e * l1_ener_loss) + mae_e = ( + torch.mean(torch.abs(energy_pred_phys - energy_label)) * atom_norm + ) + more_loss["mae_e"] = self.display_if_exist( + mae_e.detach(), + find_energy, + ) + else: + raise NotImplementedError( + f"Loss type {self.loss_func} is not implemented for `dens` energy loss." + ) + if mae: + mae_e = ( + torch.mean(torch.abs(energy_pred_phys - energy_label)) * atom_norm + ) + more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy) + mae_e_all = torch.mean(torch.abs(energy_pred_phys - energy_label)) + more_loss["mae_e_all"] = self.display_if_exist( + mae_e_all.detach(), + find_energy, + ) + + if self.has_f and "force" in model_pred and "force" in label: + find_force = label.get("find_force", 0.0) + clean_force_pred_norm = self._canonicalize_vec3_tensor( + model_pred.get( + "clean_force_norm", + model_pred.get("force_norm", model_pred["force"]), + ), + nf=force_label.shape[0], + nloc=force_label.shape[1], + name="predicted normalized clean force", + ) + denoising_force_pred_norm = self._canonicalize_vec3_tensor( + model_pred.get( + "denoising_force_norm", + model_pred.get("force_norm", model_pred["force"]), + ), + nf=force_label.shape[0], + nloc=force_label.shape[1], + name="predicted normalized denoising force", + ) + if "force_norm" in model_pred: + force_pred_phys = self._canonicalize_vec3_tensor( + model_pred["force"], + nf=force_label.shape[0], + nloc=force_label.shape[1], + name="predicted physical force", + ) + else: + force_pred_phys = atomic_model.denorm_dens_force(clean_force_pred_norm) + force_target_norm = atomic_model.norm_dens_force(force_label) + clean_mask = ~noise_mask + noise_only_mask = noise_mask if use_dens else torch.zeros_like(noise_mask) + clean_fraction = clean_mask.to(dtype=GLOBAL_PT_FLOAT_PRECISION).mean() + noise_fraction = noise_only_mask.to(dtype=GLOBAL_PT_FLOAT_PRECISION).mean() + clean_force_loss = self._compute_force_subset_loss( + clean_force_pred_norm[clean_mask].reshape(-1, 3), + force_target_norm[clean_mask].reshape(-1, 3), + coefficient=(pref_f * find_force) * clean_fraction, + ) + loss += clean_force_loss + if use_dens: + noise_force_loss = self._compute_force_subset_loss( + denoising_force_pred_norm[noise_only_mask].reshape(-1, 3), + noise_target[noise_only_mask].reshape(-1, 3), + coefficient=(denoise_pref * find_force) * noise_fraction, + ) + loss += noise_force_loss + if self.loss_func == "mse": + diff_clean = clean_force_pred_norm[clean_mask].reshape( + -1, 3 + ) - force_target_norm[clean_mask].reshape(-1, 3) + diff_noise = denoising_force_pred_norm[noise_only_mask].reshape( + -1, 3 + ) - noise_target[noise_only_mask].reshape(-1, 3) + l2_num = torch.sum(torch.square(diff_clean)) + l2_den = max(diff_clean.numel(), 1) + if noise_count := int(noise_only_mask.sum().item()): + l2_num = l2_num + torch.sum(torch.square(diff_noise)) + l2_den += diff_noise.numel() + l2_force_loss = l2_num / l2_den + if not self.inference: + more_loss["l2_force_loss"] = self.display_if_exist( + l2_force_loss.detach(), + find_force, + ) + elif self.loss_func == "mae": + pass + clean_count = int(clean_mask.sum().item()) + if clean_count > 0: + clean_force_pred_phys = force_pred_phys[clean_mask].reshape(-1, 3) + clean_force_label_phys = force_label[clean_mask].reshape(-1, 3) + if self.loss_func == "mse": + clean_rmse_f = torch.mean( + torch.square(clean_force_pred_phys - clean_force_label_phys) + ).sqrt() + more_loss["rmse_f"] = self.display_if_exist( + clean_rmse_f.detach(), + find_force, + ) + elif self.loss_func == "mae": + clean_mae_f = torch.linalg.vector_norm( + clean_force_pred_phys - clean_force_label_phys, + ord=2, + dim=1, + keepdim=True, + ).mean() + more_loss["mae_f"] = self.display_if_exist( + clean_mae_f.detach(), + find_force, + ) + if not self.inference: + more_loss["rmse"] = torch.sqrt(loss.detach()) + return model_pred, loss, more_loss + + def serialize(self) -> dict: + """Serialize the `dens` loss.""" + return { + "@class": "DeNSLoss", + "@version": 1, + "starter_learning_rate": self.starter_learning_rate, + "start_pref_e": self.start_pref_e, + "limit_pref_e": self.limit_pref_e, + "start_pref_f": self.start_pref_f, + "limit_pref_f": self.limit_pref_f, + "loss_func": self.loss_func, + "dens_prob": self.dens_prob, + "dens_fixed_noise_std": self.dens_fixed_noise_std, + "dens_std": self.dens_std, + "dens_corrupt_ratio": self.dens_corrupt_ratio, + "dens_denoising_pos_coefficient": self.dens_denoising_pos_coefficient, + } + + @classmethod + def deserialize(cls, data: dict) -> "DeNSLoss": + """Deserialize the `dens` loss.""" + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + return cls(**data) diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index 4da9bf781b..218031ea81 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -42,6 +42,9 @@ from .property_atomic_model import ( DPPropertyAtomicModel, ) +from .sezm_atomic_model import ( + SeZMAtomicModel, +) __all__ = [ "BaseAtomicModel", @@ -54,4 +57,5 @@ "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", "PairTabAtomicModel", + "SeZMAtomicModel", ] diff --git a/deepmd/pt/model/atomic_model/sezm_atomic_model.py b/deepmd/pt/model/atomic_model/sezm_atomic_model.py new file mode 100644 index 0000000000..cfc1de910b --- /dev/null +++ b/deepmd/pt/model/atomic_model/sezm_atomic_model.py @@ -0,0 +1,788 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""SeZM atomic model definitions.""" + +from __future__ import ( + annotations, +) + +import copy +import math +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np +import torch + +from deepmd.pt.model.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt.model.descriptor.sezm_nn import ( + SeZMDeNSFittingNet, +) +from deepmd.pt.model.task.base_fitting import ( + BaseFitting, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, + EnergyFittingNetDirect, + InvarFitting, +) +from deepmd.pt.model.task.sezm_ener import ( + SeZMEnergyFittingNet, +) +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +if TYPE_CHECKING: + from deepmd.dpmodel import ( + FittingOutputDef, + ) + from deepmd.utils.path import ( + DPPath, + ) + + +class SeZMAtomicModel(DPAtomicModel): + """Atomic model scaffold for SeZM parallel `ener` / `dens` fitting. + + Parameters + ---------- + descriptor + Descriptor instance. + fitting + Standard `ener` fitting network instance. + dens_fitting + Optional parallel `dens` fitting network instance. + type_map + Atom type map. + active_mode + Default active execution mode. + **kwargs + Additional keyword arguments forwarded to DPAtomicModel. + + Raises + ------ + TypeError + If fitting is not an energy fitting network. + """ + + def __init__( + self, + descriptor: Any, + fitting: Any, + type_map: Any, + dens_fitting: Any | None = None, + active_mode: str | None = None, + **kwargs: Any, + ) -> None: + if not ( + isinstance(fitting, EnergyFittingNet) + or isinstance(fitting, EnergyFittingNetDirect) + or isinstance(fitting, InvarFitting) + ): + raise TypeError( + "fitting must be an instance of EnergyFittingNet, EnergyFittingNetDirect or InvarFitting for SeZMAtomicModel" + ) + if dens_fitting is not None and not isinstance( + dens_fitting, SeZMDeNSFittingNet + ): + raise TypeError( + "dens_fitting must be an instance of SeZMDeNSFittingNet for SeZMAtomicModel" + ) + super().__init__(descriptor, fitting, type_map, **kwargs) + self.register_buffer( + "dens_force_rmsd", + self.out_std.new_tensor(1.0), + ) + self.dens_fitting_net = dens_fitting + # Start unlocked when `active_mode` is not provided. + # The mode will be decided later by training setup (`loss.type`) + # or inferred from checkpoint contents during state_dict loading. + self._mode_locked = active_mode is not None + self._active_mode = "ener" + if active_mode is not None: + self.set_active_mode(active_mode) + + def _load_from_state_dict( + self, + state_dict: dict[str, torch.Tensor], + prefix: str, + local_metadata: dict[str, Any], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ) -> None: + """Materialize the optional `dens` head before recursive loading.""" + dens_rmsd_key = prefix + "dens_force_rmsd" + if dens_rmsd_key not in state_dict: + state_dict[dens_rmsd_key] = self.dens_force_rmsd.data.clone() + has_dens_state = any( + key.startswith(prefix + "dens_fitting_net.") for key in state_dict + ) + if self.dens_fitting_net is None and has_dens_state: + self._ensure_dens_fitting_net() + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + # Training mode should normally come from `loss.type`. + # This is only a fallback for bare state_dict loads when mode was not restored. + if has_dens_state and not self._mode_locked: + self._active_mode = "dens" + + def get_active_mode(self) -> str: + """Return the current SeZM execution mode.""" + return str(getattr(self, "_active_mode", "ener")) + + def _compute_or_load_dens_force_stat( + self, + sampled_func: Any, + stat_file_path: DPPath | None = None, + ) -> None: + """ + Compute or load the SeZM `dens` direct-force RMSD scale. + + Parameters + ---------- + sampled_func + Packed statistics samples or a lazy callable that returns them. + stat_file_path + Statistics file path. + + Raises + ------ + ValueError + If force labels are unavailable for SeZM `dens` statistics. + """ + force_stat_path = ( + None if stat_file_path is None else stat_file_path / "rmsd_dforce" + ) + if force_stat_path is not None and force_stat_path.is_file(): + force_rmsd = float(np.asarray(force_stat_path.load_numpy()).reshape(-1)[0]) + else: + sampled = sampled_func() if callable(sampled_func) else sampled_func + force_square_sum = 0.0 + force_atom_count = 0 + for sample in sampled: + find_force = sample.get("find_force", 0.0) + if isinstance(find_force, torch.Tensor): + find_force = float(find_force.detach().cpu().item()) + if not bool(find_force): + continue + + force = sample.get("force") + atype = sample.get("atype") + if force is None or atype is None: + continue + + force_np = ( + force.detach().cpu().numpy() + if isinstance(force, torch.Tensor) + else np.asarray(force) + ) + atype_np = ( + atype.detach().cpu().numpy() + if isinstance(atype, torch.Tensor) + else np.asarray(atype) + ) + if force_np.ndim == 2 and atype_np.ndim == 2: + force_np = force_np.reshape(*atype_np.shape, 3) + if force_np.ndim != 3 or atype_np.ndim != 2: + raise ValueError( + "SeZM `dens` force statistics expect `force` with shape " + "(nf, nloc, 3) or (nf, nloc*3)." + ) + + atom_mask = atype_np >= 0 + exclude_types = sample.get("atom_exclude_types", []) + for type_idx in exclude_types: + atom_mask &= atype_np != type_idx + valid_force = force_np[atom_mask] + if valid_force.size == 0: + continue + force_square_sum += float(np.square(valid_force).sum()) + force_atom_count += int(valid_force.shape[0]) + + if force_atom_count == 0: + raise ValueError( + "SeZM `dens` statistics require atomic `force` labels so that " + "the global direct-force RMSD can be computed." + ) + force_rmsd = math.sqrt(force_square_sum / force_atom_count) + if force_stat_path is not None: + force_stat_path.save_numpy(np.asarray([force_rmsd], dtype=np.float64)) + + if force_rmsd <= 0.0: + raise ValueError("SeZM `dens` direct-force RMSD must be positive.") + self.dens_force_rmsd.copy_(self.dens_force_rmsd.new_tensor(force_rmsd)) + + def _get_dens_energy_stat_tensors( + self, + atype: torch.Tensor, + *, + dtype: torch.dtype, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Return the SeZM `dens` energy bias/std tensors derived from `out_stat`. + + Parameters + ---------- + atype + Local atom types with shape `(nf, nloc)`. + dtype + Target floating-point dtype. + device + Target device. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + Per-atom energy bias, per-atom broadcast energy std, and system-level + global energy std. + """ + out_bias, out_std = self._fetch_out_stat(["energy"]) + atom_mask = self.make_atom_mask(atype) + if self.atom_excl is not None: + atom_mask *= self.atom_excl(atype) + safe_atype = atype.clamp_min(0) + energy_bias_atom = out_bias["energy"][safe_atype].to(device=device, dtype=dtype) + energy_std_atom = out_std["energy"][safe_atype].to(device=device, dtype=dtype) + atom_mask_float = atom_mask.to(device=device, dtype=dtype).unsqueeze(-1) + energy_bias_atom = energy_bias_atom * atom_mask_float + energy_std_atom = energy_std_atom * atom_mask_float + energy_std = out_std["energy"][0].to(device=device, dtype=dtype).view(1, -1) + return energy_bias_atom, energy_std_atom, energy_std + + def norm_dens_energy( + self, + energy: torch.Tensor, + atype: torch.Tensor, + ) -> torch.Tensor: + """ + Normalize `dens` system energies using the standard energy bias and + the global residual std. + + Parameters + ---------- + energy + System energy tensor. + atype + Local atom types with shape `(nf, nloc)`. + + Returns + ------- + torch.Tensor + Normalized energy tensor. + """ + energy_bias_atom, _, energy_std = self._get_dens_energy_stat_tensors( + atype, + dtype=energy.dtype, + device=energy.device, + ) + energy_bias = energy_bias_atom.sum(dim=1) + return (energy - energy_bias) / energy_std + + def denorm_dens_energy( + self, + energy: torch.Tensor, + atype: torch.Tensor, + ) -> torch.Tensor: + """ + Denormalize `dens` system energies using the standard energy bias + and the global residual std. + + Parameters + ---------- + energy + Normalized system energy tensor. + atype + Local atom types with shape `(nf, nloc)`. + + Returns + ------- + torch.Tensor + Physical energy tensor. + """ + energy_bias_atom, _, energy_std = self._get_dens_energy_stat_tensors( + atype, + dtype=energy.dtype, + device=energy.device, + ) + energy_bias = energy_bias_atom.sum(dim=1) + return energy * energy_std + energy_bias + + def norm_dens_force(self, force: torch.Tensor) -> torch.Tensor: + """ + Normalize `dens` direct-force targets with the global RMSD. + + Parameters + ---------- + force + Physical direct-force tensor. + + Returns + ------- + torch.Tensor + Normalized force tensor. + """ + force_rmsd = self.dens_force_rmsd.to(device=force.device, dtype=force.dtype) + return force / force_rmsd + + def denorm_dens_force(self, force: torch.Tensor) -> torch.Tensor: + """ + Denormalize `dens` direct-force predictions with the global RMSD. + + Parameters + ---------- + force + Normalized direct-force tensor. + + Returns + ------- + torch.Tensor + Physical direct-force tensor. + """ + force_rmsd = self.dens_force_rmsd.to(device=force.device, dtype=force.dtype) + return force * force_rmsd + + def apply_out_stat_dens( + self, + ret: dict[str, torch.Tensor], + atype: torch.Tensor, + *, + noise_mask: torch.Tensor, + energy_redu_dtype: torch.dtype, + ) -> dict[str, torch.Tensor]: + """ + Apply SeZM `dens` output-stat semantics for both normalized training + outputs and public physical predictions. + + Parameters + ---------- + ret + Raw normalized `dens` outputs with keys `energy`, `clean_dforce`, and + `denoising_dforce`. + atype + Local atom types with shape `(nf, nloc)`. + noise_mask + Corruption mask with shape `(nf, nloc)`. + energy_redu_dtype + Reduction dtype used for summed system energies. + + Returns + ------- + dict[str, torch.Tensor] + Outputs carrying normalized tensors for loss calculation together + with public DeePMD-style physical predictions. + """ + atom_mask = self.make_atom_mask(atype).to(torch.int32) + if self.atom_excl is not None: + atom_mask *= self.atom_excl(atype) + + atom_mask_float = atom_mask.to(dtype=ret["energy"].dtype) + energy_bias_atom, energy_std_atom, _ = self._get_dens_energy_stat_tensors( + atype, + dtype=ret["energy"].dtype, + device=ret["energy"].device, + ) + energy_norm = ret["energy"] * atom_mask_float.unsqueeze(-1) + energy = energy_norm * energy_std_atom + energy_bias_atom + energy_redu_norm = torch.sum(energy_norm.to(energy_redu_dtype), dim=1) + energy_redu = torch.sum(energy.to(energy_redu_dtype), dim=1) + + clean_dforce_norm = ret["clean_dforce"] * atom_mask.to( + dtype=ret["clean_dforce"].dtype + ).unsqueeze(-1) + denoising_dforce_norm = ret["denoising_dforce"] * atom_mask.to( + dtype=ret["denoising_dforce"].dtype + ).unsqueeze(-1) + dforce_norm = torch.where( + noise_mask.unsqueeze(-1), + denoising_dforce_norm, + clean_dforce_norm, + ) + clean_dforce = self.denorm_dens_force(clean_dforce_norm) + return { + "energy": energy, + "energy_redu": energy_redu, + "dforce": clean_dforce, + "energy_norm": energy_redu_norm, + "atom_energy_norm": energy_norm, + "dforce_norm": dforce_norm, + "clean_dforce_norm": clean_dforce_norm, + "denoising_dforce_norm": denoising_dforce_norm, + "mask": atom_mask, + } + + def _ensure_dens_fitting_net(self) -> SeZMDeNSFittingNet: + """ + Materialize the optional `dens` fitting head from the current energy head. + + Returns + ------- + SeZMDeNSFittingNet + The existing or newly created `dens` fitting head. + """ + dens_fitting = getattr(self, "dens_fitting_net", None) + if dens_fitting is not None: + return dens_fitting + self.dens_fitting_net = SeZMDeNSFittingNet(**self._build_dens_fitting_kwargs()) + return self.dens_fitting_net + + def get_dens_fitting_net(self) -> SeZMDeNSFittingNet: + """Return the `dens` fitting head, materializing it on demand.""" + return self._ensure_dens_fitting_net() + + def set_active_mode(self, mode: str) -> None: + """ + Switch the active SeZM execution mode. + + Parameters + ---------- + mode + Target mode. Must be `ener` or `dens`. + """ + normalized = str(mode).lower() + if normalized not in {"ener", "dens"}: + raise ValueError(f"Unsupported SeZM mode: {mode!r}") + if normalized == "dens": + self._ensure_dens_fitting_net() + self._mode_locked = True + self._active_mode = normalized + + def get_active_fitting_net(self) -> Any: + """Return the fitting network selected by the current active mode.""" + if self.get_active_mode() == "dens": + return self._ensure_dens_fitting_net() + return self.fitting_net + + def reset_head_for_mode(self, mode: str) -> None: + """ + Reinitialize the fitting head of certain mode from stored kwargs. + + Parameters + ---------- + mode + Target mode to reset. + """ + normalized = str(mode).lower() + if normalized == "ener": + self.fitting_net = SeZMEnergyFittingNet(**self._build_ener_fitting_kwargs()) + elif normalized == "dens": + self.dens_fitting_net = None + self._ensure_dens_fitting_net() + else: + raise ValueError(f"Unsupported SeZM mode: {mode!r}") + + @torch.jit.unused + def fitting_output_def(self) -> FittingOutputDef: + """Return the fitting output definition of the active SeZM mode.""" + active_fitting = self.get_active_fitting_net() + if active_fitting is None: + return super().fitting_output_def() + return active_fitting.output_def() + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + """ + Set the fitting-last-layer evaluation hook for the active fitting path. + + Parameters + ---------- + enable + Whether to enable the hook. + """ + self.enable_eval_fitting_last_layer_hook = enable + active_fitting = self.get_active_fitting_net() + if active_fitting is not None and hasattr( + active_fitting, "set_return_middle_output" + ): + active_fitting.set_return_middle_output(enable) + self.eval_fitting_last_layer_list.clear() + + def change_type_map( + self, + type_map: list[str], + model_with_new_type_stat: SeZMAtomicModel | None = None, + ) -> None: + """ + Change the type map for the descriptor and both SeZM fitting heads. + + Parameters + ---------- + type_map + New atom type map. + model_with_new_type_stat + Optional reference model that carries new-type statistics. + """ + super().change_type_map( + type_map=type_map, + model_with_new_type_stat=model_with_new_type_stat, + ) + if self.dens_fitting_net is not None: + ref_dens = ( + None + if model_with_new_type_stat is None + else model_with_new_type_stat.dens_fitting_net + ) + self.dens_fitting_net.change_type_map( + type_map=type_map, + model_with_new_type_stat=ref_dens, + ) + + def compute_or_load_stat( + self, + sampled_func: Any, + stat_file_path: Any = None, + compute_or_load_out_stat: bool = True, + preset_observed_type: list[str] | None = None, + ) -> None: + """ + Compute/load SeZM statistics for the active execution mode. + + Parameters + ---------- + sampled_func + Lazy sampler providing training frames. + stat_file_path + Statistics file path. + compute_or_load_out_stat + Whether to compute or load output statistics. `dens` mode keeps the + standard `ener`-branch statistics intact and additionally fits one + global direct-force RMSD scale for the normalized DeNS training + path. The `dens` energy path reuses the standard per-type energy bias + and the broadcast global residual std already stored in `out_stat`. + preset_observed_type + Optional observed-type override. + """ + original_mode = self.get_active_mode() + if stat_file_path is not None and self.type_map is not None: + stat_file_path /= " ".join(self.type_map) + + wrapped_sampler = self._make_wrapped_sampler(sampled_func) + self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) + self.compute_fitting_input_stat(wrapped_sampler, stat_file_path) + if compute_or_load_out_stat: + self.set_active_mode("ener") + try: + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + finally: + self.set_active_mode(original_mode) + if original_mode == "dens": + self._compute_or_load_dens_force_stat(wrapped_sampler, stat_file_path) + + self._collect_and_set_observed_type( + wrapped_sampler, + stat_file_path, + preset_observed_type, + ) + + def apply_out_stat( + self, + ret: dict[str, torch.Tensor], + atype: torch.Tensor, + ) -> dict[str, torch.Tensor]: + """ + Apply SeZM-specific output statistics. + + Parameters + ---------- + ret + Atomic fitting outputs. + atype + Local atom types with shape `(nf, nloc)`. + + Returns + ------- + dict[str, torch.Tensor] + Outputs after SeZM output-stat post-processing. + """ + if "energy" in ret: + out_bias, _ = self._fetch_out_stat(["energy"]) + ret["energy"] = ret["energy"] + out_bias["energy"][atype] + return ret + + def get_dim_fparam(self) -> int: + """Return frame-parameter width of the active SeZM branch.""" + active_fitting = self.get_active_fitting_net() + if active_fitting is not None and hasattr(active_fitting, "get_dim_fparam"): + return active_fitting.get_dim_fparam() + return super().get_dim_fparam() + + def has_default_fparam(self) -> bool: + """Return whether the active SeZM branch has default frame parameters.""" + active_fitting = self.get_active_fitting_net() + if active_fitting is not None and hasattr(active_fitting, "has_default_fparam"): + return active_fitting.has_default_fparam() + return super().has_default_fparam() + + def get_default_fparam(self) -> torch.Tensor | None: + """Return default frame parameters of the active SeZM branch.""" + active_fitting = self.get_active_fitting_net() + if active_fitting is not None and hasattr(active_fitting, "get_default_fparam"): + return active_fitting.get_default_fparam() + return super().get_default_fparam() + + def has_chg_spin_ebd(self) -> bool: + """Return whether charge/spin condition embedding is enabled.""" + return bool(getattr(self.descriptor, "add_chg_spin_ebd", False)) + + def get_dim_chg_spin(self) -> int: + """Return charge/spin condition width.""" + if self.has_chg_spin_ebd() and hasattr(self.descriptor, "get_dim_chg_spin"): + return self.descriptor.get_dim_chg_spin() + return 0 + + def has_default_chg_spin(self) -> bool: + """Return whether default charge/spin conditions are configured.""" + if self.has_chg_spin_ebd() and hasattr(self.descriptor, "has_default_chg_spin"): + return self.descriptor.has_default_chg_spin() + return False + + def get_default_chg_spin(self) -> torch.Tensor | None: + """Return default charge/spin conditions as a tensor.""" + if self.has_chg_spin_ebd() and hasattr(self.descriptor, "get_default_chg_spin"): + default_chg_spin = self.descriptor.get_default_chg_spin() + if default_chg_spin is not None: + return self.out_std.new_tensor(default_chg_spin) + return None + + def get_dim_aparam(self) -> int: + """Return atomic-parameter width of the active SeZM branch.""" + active_fitting = self.get_active_fitting_net() + if active_fitting is not None and hasattr(active_fitting, "get_dim_aparam"): + return active_fitting.get_dim_aparam() + return super().get_dim_aparam() + + def get_sel_type(self) -> list[int]: + """Return selected atom types of the active SeZM branch.""" + active_fitting = self.get_active_fitting_net() + if active_fitting is not None and hasattr(active_fitting, "get_sel_type"): + return active_fitting.get_sel_type() + return super().get_sel_type() + + def serialize(self) -> dict: + """Serialize the SeZM atomic model including the optional `dens` head.""" + data = DPAtomicModel.serialize(self) + data["@variables"]["dens_force_rmsd"] = ( + self.dens_force_rmsd.detach().cpu().numpy() + ) + data.update( + { + "@version": 3, + "type": "sezm_atomic", + "dens_fitting": None + if self.dens_fitting_net is None + else self.dens_fitting_net.serialize(), + "active_mode": self.get_active_mode(), + } + ) + return data + + def _build_ener_fitting_kwargs(self) -> dict[str, Any]: + """Reconstruct SeZM energy-head kwargs from the current fitting head.""" + fitting = self.fitting_net + return { + "ntypes": int(fitting.ntypes), + "dim_descrpt": int(fitting.dim_descrpt), + "neuron": copy.deepcopy(list(fitting.neuron)), + "bias_atom_e": None + if fitting.bias_atom_e is None + else fitting.bias_atom_e.detach().cpu().numpy().copy(), + "resnet_dt": bool(fitting.resnet_dt), + "numb_fparam": int(fitting.numb_fparam), + "numb_aparam": int(fitting.numb_aparam), + "dim_case_embd": int(fitting.dim_case_embd), + "case_film_embd": bool(getattr(fitting, "case_film_embd", False)), + "activation_function": str(fitting.activation_function), + "bias_out": bool(getattr(fitting, "bias_out", False)), + "precision": str(fitting.precision), + "mixed_types": bool(fitting.mixed_types), + "seed": copy.deepcopy(fitting.seed), + "type_map": None if fitting.type_map is None else list(fitting.type_map), + "default_fparam": copy.deepcopy(fitting.default_fparam), + "rcond": fitting.rcond, + "exclude_types": copy.deepcopy(fitting.exclude_types), + "trainable": copy.deepcopy(fitting.trainable), + "atom_ener": copy.deepcopy(fitting.atom_ener), + "use_aparam_as_mask": bool(fitting.use_aparam_as_mask), + } + + def _build_dens_fitting_kwargs(self) -> dict[str, Any]: + """Reconstruct SeZM `dens`-head kwargs from energy head and descriptor.""" + descriptor = self.descriptor + kwargs = self._build_ener_fitting_kwargs() + kwargs["condition_lmax"] = int(descriptor.l_schedule[0]) + kwargs["latent_lmax"] = int(descriptor.l_schedule[-1]) + kwargs["channels"] = int(descriptor.channels) + return kwargs + + @classmethod + def deserialize(cls, data: dict) -> SeZMAtomicModel: + """ + Deserialize the SeZM atomic model. + + Parameters + ---------- + data + Serialized atomic-model data. + + Returns + ------- + SeZMAtomicModel + Deserialized SeZM atomic model. + """ + payload = data.copy() + version = int(payload.pop("@version", 2)) + check_version_compatibility(version, 3, 2) + payload.pop("@class", None) + payload.pop("type", None) + + descriptor_obj = BaseDescriptor.deserialize(payload.pop("descriptor")) + fitting_payload = payload.pop("fitting") + fitting_obj = BaseFitting.deserialize(fitting_payload) + dens_payload = payload.pop("dens_fitting", None) + dens_obj = ( + None + if dens_payload is None + else SeZMDeNSFittingNet.deserialize(dens_payload) + ) + active_mode = payload.pop("active_mode", None) + payload["descriptor"] = descriptor_obj + payload["fitting"] = fitting_obj + payload["dens_fitting"] = dens_obj + payload["active_mode"] = active_mode + variables = payload.pop("@variables", None) + obj = cls(**payload) + variables = ( + {"out_bias": None, "out_std": None} if variables is None else variables + ) + obj["out_bias"] = ( + to_torch_tensor(variables["out_bias"]) + if variables["out_bias"] is not None + else obj._default_bias() + ) + obj["out_std"] = ( + to_torch_tensor(variables["out_std"]) + if variables["out_std"] is not None + else obj._default_std() + ) + dens_force_rmsd = variables.get("dens_force_rmsd") + if dens_force_rmsd is not None: + obj.dens_force_rmsd.copy_(to_torch_tensor(dens_force_rmsd)) + return obj diff --git a/deepmd/pt/model/descriptor/__init__.py b/deepmd/pt/model/descriptor/__init__.py index 9f3468d1db..6da6c3b864 100644 --- a/deepmd/pt/model/descriptor/__init__.py +++ b/deepmd/pt/model/descriptor/__init__.py @@ -42,6 +42,9 @@ DescrptBlockSeTTebd, DescrptSeTTebd, ) +from .sezm import ( + DescrptSeZM, +) __all__ = [ "BaseDescriptor", @@ -59,6 +62,7 @@ "DescrptSeR", "DescrptSeT", "DescrptSeTTebd", + "DescrptSeZM", "make_default_type_embedding", "prod_env_mat", ] diff --git a/deepmd/pt/model/descriptor/sezm.py b/deepmd/pt/model/descriptor/sezm.py new file mode 100644 index 0000000000..1821317feb --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm.py @@ -0,0 +1,1953 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +SeZM: The descriptor of smooth equivariant Zone-bridging Model. + +PyTorch backend + +This implementation is designed around two non-negotiables: + +1) Conservative forces: the descriptor is computed from differentiable energy. +2) Speed-first inference: edge geometry and Wigner-D rotation blocks are computed + exactly once per `forward()` and reused by all interaction blocks. + +Shared descriptor building blocks are re-exported by `sezm_nn/__init__.py`. + +Runtime flow at a glance: +1) Build edge cache and radial features once. +2) Run interaction blocks with shared geometric caches. +3) Return scalar (`l=0`) descriptor channels for fitting. + +Layout notes +------------ +- Node-level backbone features use contiguous `(N, D, 1, C)` where + `D=(lmax+1)^2` and `C=channels`. +- The singleton focus axis is kept only to reuse the existing equivariant + operators; real multi-focus structure lives strictly inside `SO2Convolution`. +- Edge-level SO(2) internal operators keep m-major reduced layout + `(E, F, D_m_trunc, Cf)` with `F=n_focus` and `Cf=focus_dim` inside the + SO(2) branch only. +""" + +from __future__ import ( + annotations, +) + +import math +import os +from contextlib import ( + contextmanager, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +import torch.nn as nn +from einops import ( + rearrange, +) + +from deepmd.dpmodel.utils import EnvMat as DPEnvMat +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pt.utils.update_sel import ( + UpdateSel, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_descriptor import ( + BaseDescriptor, +) +from .sezm_nn import ( + ATTN_RES_MODES, + BridgingSwitch, + C3CutoffEnvelope, + ChargeSpinEmbedding, + DepthAttnRes, + EdgeFeatureCache, + EnvironmentInitialEmbedding, + EquivariantFFN, + GeometricInitialEmbedding, + InnerClamp, + RadialBasis, + RadialMLP, + ScalarRMSNorm, + SeZMInteractionBlock, + SeZMTypeEmbedding, + WignerDCalculator, + build_edge_cache, + build_edge_cache_from_edges, + edge_cache_to_dtype, + fold_lora_state_dict_keys, + get_promoted_dtype, + get_so3_dim_of_lmax, + has_lora, + np_safe, + nvtx_range, + safe_numpy_to_tensor, +) + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Generator, + ) + + from deepmd.utils.data_system import ( + DeepmdDataSystem, + ) + from deepmd.utils.path import ( + DPPath, + ) + + +@BaseDescriptor.register("SeZM") +@BaseDescriptor.register("sezm") +@BaseDescriptor.register("dpa4") +class DescrptSeZM(BaseDescriptor, nn.Module): + """ + SeZM: The descriptor of smooth equivariant Zone-bridging Model for DeePMD-kit. + + Execution outline + ----------------- + 1. Build a per-forward `EdgeFeatureCache` (geometry, envelope, Wigner-D). + 2. Build radial/type edge features once and reuse across blocks. + 3. Run `SeZMInteractionBlock` stack with optional l/m schedules. + 4. Extract scalar channels and apply the final scalar FFN. + + Parameters + ---------- + ntypes + Number of element types. + sel + Maximum number of neighbors per type within `rcut`. + - int: broadcast to all types, e.g. sel=100 with ntypes=2 → [100, 100] + - list[int]: sel[i] is the maximum number of type i atoms within `rcut` + rcut + Cutoff radius in Å. + env_exp + C^3 cutoff envelope exponents `[rbf_env_exp, edge_env_exp]`. + - `rbf_env_exp`: Controls radial basis function envelope decay. + - `edge_env_exp`: Controls message passing edge weight envelope decay. + Larger values give weaker suppression (values stay near 1.0 longer). + channels + Total channels per (l,m) coefficient. + basis_type + Radial basis type. Supported values are ``"bessel"`` and ``"gaussian"``. + n_radial + Number of radial basis functions. + radial_mlp + Hidden layer sizes for radial networks. An output layer of size + `(l_schedule[0]+1)*channels` will be automatically appended. + use_env_seed + If True, apply environment matrix initial embedding as FiLM conditioning + on l=0 features using 4D `[s, s*r_hat]` representation. FiLM deltas are + normalized and scaled with learnable strengths initialized to small values. + Internal dimensions are derived from `channels`: + `embed_dim=min(channels, 128)`, + `axis_dim=min(4 if embed_dim < 64 else 8, embed_dim-1)`, + `type_dim=clamp(channels//4, 8, 32)`, + `rbf_out_dim=max(32, embed_dim-2*type_dim)`, + `hidden_dim=min(256, max(2*embed_dim, rbf_out_dim+2*type_dim))`. + random_gamma + If True, apply a random roll about the edge-aligned local ``+Z`` axis + before building the Wigner-D blocks. The roll is sampled independently + per edge and per forward call. + lmax + Maximum degree, only used when `l_schedule` is None. + l_schedule + Pyramid schedule of lmax per block, e.g. [3, 3, 2]. Must be non-increasing. + If set, lmax and n_blocks will be ignored. + mmax + Maximum SO(2) order (|m|), only used when `m_schedule` is None. + If None, defaults to the per-block `lmax` (i.e. `m_schedule = l_schedule`). + m_schedule + Schedule of mmax per block, e.g. [2, 2, 1, 0]. Must satisfy + `m_schedule[i] <= l_schedule[i]` for every block. A non-increasing schedule is + recommended but not required. If set, `mmax` will be ignored. + n_blocks + Number of blocks (only used when `l_schedule` is None). + so2_norm + If True, apply intermediate ReducedEquivariantRMSNorm between SO(2) mixing layers. + When False (default), no normalization is applied between layers. + so2_layers + Number of SO(2) mixing layers per block. + so2_attn_res + SO(2)-internal depth-wise attention residual mode inside each interaction + block. Must be one of ``"none"``, ``"independent"``, or ``"dependent"``. + radial_so2_mode + Dynamic radial degree mixer mode inside SO(2) convolution. ``"none"`` + applies elementwise radial modulation, ``"degree"`` uses a + channel-shared edge-conditioned cross-degree kernel, and + ``"degree_channel"`` uses a per-channel cross-degree kernel. + radial_so2_rank + Low-rank channel factorization rank for + ``radial_so2_mode="degree_channel"``. ``0`` uses the full + per-channel dynamic degree kernel. + n_focus + Number of parallel focus streams used only inside the SO(2) convolution. + Node-level backbone tensors still keep a singleton focus axis. + focus_dim + Hidden width per focus stream inside the SO(2) convolution. + ``focus_dim=0`` means using ``channels``. + n_atten_head + Number of attention heads when aggregating messages in SO(2) convolution. + 0 applies a plain envelope-weighted scatter-sum; >0 enables + envelope-gated grouped softmax attention with output-side head gate. + Attention uses ``w**2 * exp(logit)`` in the numerator and + ``zeta + sum(w**2 * exp(logit))`` in the denominator. + atten_f_mix + If True, merge all SO(2) focus streams into one attention stream after + rotate-back. Attention heads split ``n_focus * focus_dim`` instead of + each focus stream independently. + atten_v_proj + If True, apply an explicit degree-aware value projection inside SO(2) + attention. + atten_o_proj + If True, apply an explicit degree-aware output projection inside SO(2) + attention. + ffn_neurons + Hidden width for block FFNs and the final scalar output FFN. + If ``>0``, both paths use this width. + If ``=0``, each path resolves its own width from ``channels`` and its + effective GLU setting: ``4 * channels`` without GLU, ``(8 / 3) * channels`` + with GLU, then round up to a multiple of 32. + grid_mlp + If True, use the optional grid-MLP structure for the block-internal FFN + units. The final scalar output head is unchanged. + ffn_blocks + Number of FFN subblocks per interaction block. + sandwich_norm + Pre/post-norm switches for [SO(2), FFN] residual branches in order: + [so2_pre, so2_post, ffn_pre, ffn_post], shared across all blocks. + mlp_bias + Whether to use bias in equivariant layers. When False, removes bias from: + - SO3Linear: l=0 bias + - SO2Linear: l=0 bias + - GatedActivation: gate linear bias + - DepthAttnRes: input-dependent query projection + - EnvironmentInitialEmbedding: + rbf_proj_layer1/2 and g_layer1/2 + Attention projections in SO2Convolution + (attn_radial_logit_proj, attn_output_gate_proj) are always bias-free. + layer_scale + If True, apply learnable LayerScale (init 1e-3) on residual branches: + - SO(2) branch: per-focus-channel scales `(n_focus, focus_dim)` + on each SO(2) mixing layer. + - FFN branch: per-channel scales `(channels,)` on each FFN subblock. + full_attn_res + Descriptor-level full attention residual mode over the unit history + `[x0, so2_0, ffn_0_0, ffn_0_1, ..., so2_1, ffn_1_0, ffn_1_1, ...]`, + where each FFN subblock contributes its own completed unit + representation. `independent` uses learned query vectors, while + `dependent` derives queries from the current SeZM state before the + SO(2) unit, before each FFN unit, and before the final aggregation. + Must be one of ``"none"``, ``"independent"``, or ``"dependent"``. + block_attn_res + Descriptor-level block attention residual mode over the block history + `[x0, b1, b2, ...]`, where each `b_i` is the sum of all unit outputs + inside one `SeZMInteractionBlock`. `independent` uses learned query + vectors, while `dependent` derives queries from the current SeZM state + before the SO(2) unit, before each FFN unit, and before the final block + aggregation. Must be one of ``"none"``, ``"independent"``, or + ``"dependent"``. Cannot be enabled together with `full_attn_res`. + s2_activation + Two booleans ``[so2_enabled, ffn_enabled]``. + ``so2_enabled=True`` makes the SO(2) gated activation path use + ``activation_function="silu"``. + ``ffn_enabled=True`` makes the block-internal FFN path use + ``activation_function="silu"`` and ``glu_activation=True``. + S2-grid resolutions are resolved automatically per block. The e3nn + product grid uses ``[2 * mmax + 4, ceil_even(3 * lmax + 2)]`` in the + SO(2) branch, and the FFN branch lifts it to a square + ``[max(R_phi, R_theta), max(R_phi, R_theta)]`` grid. Lebedev branches + use the smallest packaged rule with precision at least ``3 * lmax``. + The final ``l=0`` output FFN is unchanged. + lebedev_quadrature + Either one boolean applied to both S2 branches, or two booleans + ``[so2_enabled, ffn_enabled]`` aligned with ``s2_activation``. If + enabled for a branch, that branch uses Lebedev quadrature instead of + the e3nn product grid in its S2 projector. + activation_function + Base activation function for helper MLPs, the SO(2) gated activation + path, and the final ``l=0`` output FFN. + It is overridden to ``"silu"`` only on paths whose ``s2_activation`` + switch is enabled. + glu_activation + Base GLU switch for FFN. The block-internal FFN path overrides it to + ``True`` only when ``s2_activation[1]=True``. The final ``l=0`` output + FFN always keeps this user-provided value. + use_amp + If True, use automatic mixed precision (AMP) with bfloat16 on CUDA. + This does not provide accelerations under fp32 precision but will decrease + the memory usage, while preserving model accuracy. + exclude_types + List of excluded type pairs. + precision + Precision for neural network parameters and computations. Geometry computations + (edge distances, Wigner-D matrices, rotations, GIE) always run in fp32+ to + provide accurate geometric information for better convergence. Only the + interaction blocks use this precision. + eps + Small epsilon for numerical stability in division and normalization. + trainable + Whether parameters are trainable. + seed + Random seed(s). + type_map + Type names. + inner_clamp_r_inner + Inner radius for distance saturation in Å. If both inner and outer radii + are set, the descriptor freezes short-range descriptor geometry inside + the zone-bridging window. + inner_clamp_r_outer + Outer radius for distance saturation in Å. + add_chg_spin_ebd + If True, add frame-level charge/spin condition embedding to scalar type + features before edge features are built. + default_chg_spin + Default frame-level charge/spin condition `[charge, spin]`. This value is + used when `add_chg_spin_ebd=True` and no explicit `charge_spin` tensor is + provided at the descriptor or SeZM model boundary. + + Notes + ----- + SeZM does not use the traditional environment matrix (r, a_x, a_y, a_z). + Instead, it uses radial basis functions and spherical harmonics directly. + The mean/stddev statistics are kept for interface compatibility but are not + actively used in the forward pass. + """ + + _ENV_DIM: int = 1 # Use se_r style (radial only) for EnvMatStatSe compatibility + + def __init__( + self, + ntypes: int, + sel: list[int] | int, + rcut: float = 6.0, + env_exp: list[int] | None = None, + channels: int = 64, + basis_type: str = "bessel", + n_radial: int = 16, + radial_mlp: list[int] | None = None, + use_env_seed: bool = True, + random_gamma: bool = True, + lmax: int = 3, + l_schedule: list[int] | None = None, + mmax: int | None = 1, + m_schedule: list[int] | None = None, + n_blocks: int = 3, + so2_norm: bool = False, + so2_layers: int = 4, + so2_attn_res: str = "none", + radial_so2_mode: str = "degree_channel", + radial_so2_rank: int = 1, + n_focus: int = 1, + focus_dim: int = 0, + n_atten_head: int = 1, + atten_f_mix: bool = False, + atten_v_proj: bool = False, + atten_o_proj: bool = False, + ffn_neurons: int = 0, + grid_mlp: bool = False, + ffn_blocks: int = 1, + sandwich_norm: list[bool] | None = None, + mlp_bias: bool = False, + layer_scale: bool = False, + full_attn_res: str = "none", + block_attn_res: str = "none", + s2_activation: list[bool] | None = None, + lebedev_quadrature: bool | list[bool] | None = True, + activation_function: str = "silu", + glu_activation: bool = True, + use_amp: bool = True, + exclude_types: list[tuple[int, int]] | None = None, + precision: str = "float32", + eps: float = 1e-7, + trainable: bool = True, + seed: int | list[int] | None = None, + type_map: list[str] | None = None, + inner_clamp_r_inner: float | None = None, + inner_clamp_r_outer: float | None = None, + add_chg_spin_ebd: bool = False, + default_chg_spin: list[float] | None = None, + **kwargs: Any, + ) -> None: + super().__init__() + + self.rcut = float(rcut) + if env_exp is None: + env_exp = [7, 5] + if len(env_exp) != 2: + raise ValueError( + "`env_exp` must be a list of two integers: [rbf_env_exp, edge_env_exp]" + ) + self.env_exp = [int(x) for x in env_exp] + self.eps = float(eps) + + if isinstance(sel, int): + sel = [sel] + self.ntypes = int(ntypes) + self.sel = [int(x) for x in sel] + self.type_map = type_map + self.nnei = int(sum(self.sel)) + self.ndescrpt = int(self.nnei * self._ENV_DIM) + + self.channels = int(channels) + self.n_focus = int(n_focus) + if self.n_focus < 1: + raise ValueError("`n_focus` must be >= 1") + self.focus_dim = int(focus_dim) + if self.focus_dim < 0: + raise ValueError("`focus_dim` must be >= 0") + self.basis_type = str(basis_type).lower() + self.n_radial = int(n_radial) + if radial_mlp is None: + radial_mlp = [0] + self.radial_mlp = [self.channels if x == 0 else int(x) for x in radial_mlp] + if sandwich_norm is None: + sandwich_norm = [False, True, True, False] + if not isinstance(sandwich_norm, (list, tuple)) or len(sandwich_norm) != 4: + raise ValueError( + "sandwich_norm must be a list[bool] of length 4: [so2_pre, so2_post, ffn_pre, ffn_post]" + ) + self.sandwich_norm = [bool(x) for x in sandwich_norm] + self.so2_pre_norm = self.sandwich_norm[0] + self.so2_post_norm = self.sandwich_norm[1] + self.ffn_pre_norm = self.sandwich_norm[2] + self.ffn_post_norm = self.sandwich_norm[3] + if s2_activation is None: + s2_activation = [False, True] + if not isinstance(s2_activation, list) or len(s2_activation) != 2: + raise ValueError( + "`s2_activation` must be a list[bool] of length 2: [so2_activation, ffn_activation]" + ) + if any(not isinstance(flag, bool) for flag in s2_activation): + raise ValueError( + "`s2_activation` must be a list[bool] of length 2: [so2_activation, ffn_activation]" + ) + self.s2_activation = list(s2_activation) + if lebedev_quadrature is None: + lebedev_quadrature = [False, False] + elif isinstance(lebedev_quadrature, bool): + lebedev_quadrature = [lebedev_quadrature, lebedev_quadrature] + if not isinstance(lebedev_quadrature, list) or len(lebedev_quadrature) != 2: + raise ValueError( + "`lebedev_quadrature` must be a bool or a list[bool] of length 2: [so2_quadrature, ffn_quadrature]" + ) + if any(not isinstance(flag, bool) for flag in lebedev_quadrature): + raise ValueError( + "`lebedev_quadrature` must be a bool or a list[bool] of length 2: [so2_quadrature, ffn_quadrature]" + ) + self.lebedev_quadrature = list(lebedev_quadrature) + self.activation_function = str(activation_function) + self.glu_activation = bool(glu_activation) + + # === Split effective activation config by branch === + self.so2_s2_activation = self.s2_activation[0] + self.ffn_s2_activation = self.s2_activation[1] + self.so2_lebedev_quadrature = self.lebedev_quadrature[0] + self.ffn_lebedev_quadrature = self.lebedev_quadrature[1] + self.so2_activation_function = ( + "silu" if self.so2_s2_activation else self.activation_function + ) + self.ffn_activation_function = ( + "silu" if self.ffn_s2_activation else self.activation_function + ) + self.ffn_glu_activation = ( + True if self.ffn_s2_activation else self.glu_activation + ) + self.out_activation_function = self.activation_function + self.out_glu_activation = self.glu_activation + self.precision = str(precision) + self.dtype = PRECISION_DICT[self.precision] + self.device = env.DEVICE + self.compute_dtype = get_promoted_dtype(self.dtype) + self.mlp_bias = bool(mlp_bias) + self.layer_scale = bool(layer_scale) + self.use_amp = bool(use_amp) # and self.training + self.trainable = bool(trainable) + self.use_triton = os.environ.get("DP_TRITON", "0").lower() in ( + "1", + "true", + "yes", + "on", + ) + self.seed = seed + self.random_gamma = bool(random_gamma) + self.add_chg_spin_ebd = bool(add_chg_spin_ebd) + if default_chg_spin is not None and len(default_chg_spin) != 2: + raise ValueError("`default_chg_spin` must contain [charge, spin].") + self.default_chg_spin = ( + None if default_chg_spin is None else [float(x) for x in default_chg_spin] + ) + + # === Zone bridging: InnerClamp + Source Freeze Propagation Gate === + # Both the geometry clamp (``InnerClamp``) and the message-passing + # switch (``BridgingSwitch``) are activated together on the same + # ``[r_inner, r_outer]`` window. The clamp freezes scalar distance + # on every ``(j, k)`` edge with ``r_{jk} < r_inner``; the switch + # feeds a per-edge C3 amplitude into ``compute_edge_src_gate`` so + # that any node with a frozen neighbor cannot propagate + # information through the GNN, closing the direction / multi-hop + # leakage channels that a pure ``InnerClamp`` cannot reach. Both + # modules are parameter-free, so enabling bridging does not add + # any keys to the descriptor's state dict. + self.inner_clamp_r_inner = ( + float(inner_clamp_r_inner) if inner_clamp_r_inner is not None else None + ) + self.inner_clamp_r_outer = ( + float(inner_clamp_r_outer) if inner_clamp_r_outer is not None else None + ) + if ( + self.inner_clamp_r_inner is not None + and self.inner_clamp_r_outer is not None + ): + self.inner_clamp: InnerClamp | None = InnerClamp( + self.inner_clamp_r_inner, self.inner_clamp_r_outer + ) + self.bridging_switch: BridgingSwitch | None = BridgingSwitch( + self.inner_clamp_r_inner, self.inner_clamp_r_outer + ) + else: + self.inner_clamp = None + self.bridging_switch = None + + # === Env seed parameters === + self.use_env_seed = bool(use_env_seed) + self.env_seed_embed_dim = min(self.channels, 128) + self.env_seed_type_dim = min(32, max(8, self.channels // 4)) + axis_dim = 4 if self.env_seed_embed_dim < 64 else 8 + self.env_seed_axis_dim = min(axis_dim, max(1, self.env_seed_embed_dim - 1)) + rbf_out_dim = max(32, self.env_seed_embed_dim - 2 * self.env_seed_type_dim) + g_in_dim = rbf_out_dim + 2 * self.env_seed_type_dim + self.env_seed_hidden_dim = min(256, max(2 * self.env_seed_embed_dim, g_in_dim)) + + # === Split deterministic seeds at the descriptor top-level === + seed_type_embedding = child_seed(self.seed, 0) + seed_blocks = child_seed(self.seed, 1) + seed_out = child_seed(self.seed, 2) + seed_radial_embedding = child_seed(self.seed, 3) + seed_env_seed = child_seed(self.seed, 4) + seed_full_attn = child_seed(self.seed, 5) + seed_block_attn = child_seed(self.seed, 6) + seed_charge_spin = child_seed(self.seed, 7) + + # === L/M schedules === + self._init_lm_schedules(lmax, n_blocks, l_schedule, mmax, m_schedule) + self.ebed_dims = [get_so3_dim_of_lmax(l) for l in self.l_schedule] + self.rad_sizes_per_block = [l + 1 for l in self.l_schedule] + + self.so2_norm = bool(so2_norm) + self.so2_layers = int(so2_layers) + self.so2_attn_res_mode = str(so2_attn_res).lower() + if self.so2_attn_res_mode not in ATTN_RES_MODES: + raise ValueError( + "`so2_attn_res` must be one of 'none', 'independent', or 'dependent'" + ) + self.radial_so2_mode = str(radial_so2_mode).lower() + if self.radial_so2_mode not in {"none", "degree", "degree_channel"}: + raise ValueError( + "`radial_so2_mode` must be one of 'none', 'degree', or 'degree_channel'" + ) + self.radial_so2_rank = int(radial_so2_rank) + if self.radial_so2_rank < 0: + raise ValueError("`radial_so2_rank` must be non-negative") + self.ffn_neurons = int(ffn_neurons) + self.block_ffn_neurons = self._resolve_ffn_neurons( + self.ffn_neurons, + glu_activation=self.ffn_glu_activation, + ) + self.out_ffn_neurons = self._resolve_ffn_neurons( + self.ffn_neurons, + glu_activation=self.out_glu_activation, + ) + self.grid_mlp = bool(grid_mlp) + self.ffn_blocks = int(ffn_blocks) + if self.ffn_blocks < 1: + raise ValueError("`ffn_blocks` must be >= 1") + self.full_attn_res_mode = str(full_attn_res).lower() + if self.full_attn_res_mode not in ATTN_RES_MODES: + raise ValueError( + "`full_attn_res` must be one of 'none', 'independent', or 'dependent'" + ) + self.block_attn_res_mode = str(block_attn_res).lower() + if self.block_attn_res_mode not in ATTN_RES_MODES: + raise ValueError( + "`block_attn_res` must be one of 'none', 'independent', or 'dependent'" + ) + self.use_full_attn_res = self.full_attn_res_mode != "none" + self.use_block_attn_res = self.block_attn_res_mode != "none" + if self.use_full_attn_res and self.use_block_attn_res: + raise ValueError( + "`full_attn_res` and `block_attn_res` cannot both be enabled" + ) + self.n_atten_head = int(n_atten_head) + self.atten_f_mix = bool(atten_f_mix) + self.use_atten_v_proj = bool(atten_v_proj) + self.use_atten_o_proj = bool(atten_o_proj) + so2_focus_dim = self.channels if self.focus_dim == 0 else self.focus_dim + attn_focus_dim = ( + self.n_focus * so2_focus_dim if self.atten_f_mix else so2_focus_dim + ) + if self.n_atten_head > 0 and attn_focus_dim % self.n_atten_head != 0: + raise ValueError( + "`n_atten_head` must divide the attention width " + "(`focus_dim` or `n_focus * focus_dim` when `atten_f_mix=True`)" + ) + + # === Excluded type pairs === + self.reinit_exclude(exclude_types) + + # === Type embedding === + self.type_embedding = SeZMTypeEmbedding( + ntypes=self.ntypes, + embed_dim=self.channels, + dtype=self.compute_dtype, # force fp32+ + seed=seed_type_embedding, + trainable=self.trainable, + ) + if self.add_chg_spin_ebd: + self.charge_spin_embedding: ChargeSpinEmbedding | None = ( + ChargeSpinEmbedding( + embed_dim=self.channels, + activation_function=self.activation_function, + dtype=self.compute_dtype, + seed=seed_charge_spin, + trainable=self.trainable, + ) + ) + else: + self.charge_spin_embedding = None + + # === Env FiLM embedding (optional) === + if self.use_env_seed: + self.env_seed_embedding: EnvironmentInitialEmbedding | None = ( + EnvironmentInitialEmbedding( + ntypes=self.ntypes, + n_radial=self.n_radial, + channels=self.channels, + embed_dim=self.env_seed_embed_dim, + axis_dim=self.env_seed_axis_dim, + type_dim=self.env_seed_type_dim, + hidden_dim=self.env_seed_hidden_dim, + mlp_bias=self.mlp_bias, + activation_function=self.activation_function, + eps=self.eps, + dtype=self.compute_dtype, # force fp32+ + trainable=self.trainable, + seed=seed_env_seed, + ) + ) + self.film_scale_norm = ScalarRMSNorm( + channels=self.channels, + n_focus=1, + eps=self.eps, + dtype=self.compute_dtype, + trainable=self.trainable, + ) + self.film_shift_norm = ScalarRMSNorm( + channels=self.channels, + n_focus=1, + eps=self.eps, + dtype=self.compute_dtype, + trainable=self.trainable, + ) + film_strength_init = 0.01 + # Use 1D tensor (not scalar) for FSDP2 compatibility + self.film_scale_strength_log = nn.Parameter( + torch.full( + (1,), + math.log(film_strength_init), + dtype=self.compute_dtype, + device=self.device, + ), + requires_grad=self.trainable, + ) + self.film_shift_strength_log = nn.Parameter( + torch.full( + (1,), + math.log(film_strength_init), + dtype=self.compute_dtype, + device=self.device, + ), + requires_grad=self.trainable, + ) + else: + self.env_seed_embedding = None + self.film_scale_norm = None + self.film_shift_norm = None + self.film_scale_strength_log = None + self.film_shift_strength_log = None + + self.radial_basis = RadialBasis( + rcut=self.rcut, + basis_type=self.basis_type, + n_radial=self.n_radial, + dtype=self.compute_dtype, # force fp32+ + exponent=self.env_exp[0], + ) + + # === Shared radial embedding: RBF -> per-l radial features === + # Output dimension is (lmax+1)*channels, directly usable by GIE and SO2Conv. + # radial_mlp specifies hidden layer sizes; input/output layers are prepended/appended. + # Use fp32+ precision (same as RBF output) for numerical stability. + radial_out_dim = (self.lmax + 1) * self.channels + radial_mlp_layers = [self.n_radial, *self.radial_mlp, radial_out_dim] + self.radial_embedding = RadialMLP( + radial_mlp_layers, + activation_function=self.activation_function, + dtype=self.compute_dtype, # force fp32+ + trainable=self.trainable, + seed=seed_radial_embedding, + ) + + # === C^3 cutoff envelope for edge weight === + self.edge_envelope = C3CutoffEnvelope(rcut=self.rcut, exponent=self.env_exp[1]) + + wigner_lmax = self.l_schedule[0] + # force fp32+ + self.wigner_calc = WignerDCalculator( + lmax=wigner_lmax, + eps=self.eps, + dtype=self.compute_dtype, + ) + + self.use_gie = self.l_schedule[0] > 0 + if self.use_gie: + self.gie = GeometricInitialEmbedding( + lmax=self.l_schedule[0], + channels=self.channels, + dtype=self.compute_dtype, # force fp32+ + ) + else: + self.gie = None + + blocks: list[SeZMInteractionBlock] = [] + for block_idx, (l_b, m_b) in enumerate(zip(self.l_schedule, self.m_schedule)): + blocks.append( + SeZMInteractionBlock( + lmax=l_b, + mmax=m_b, + channels=self.channels, + n_focus=self.n_focus, + focus_dim=self.focus_dim, + so2_norm=self.so2_norm, + so2_layers=self.so2_layers, + so2_attn_res=self.so2_attn_res_mode, + radial_so2_mode=self.radial_so2_mode, + radial_so2_rank=self.radial_so2_rank, + ffn_neurons=self.block_ffn_neurons, + grid_mlp=self.grid_mlp, + ffn_blocks=self.ffn_blocks, + layer_scale=self.layer_scale, + full_attn_res=self.full_attn_res_mode, + block_attn_res=self.block_attn_res_mode, + so2_s2_activation=self.so2_s2_activation, + ffn_s2_activation=self.ffn_s2_activation, + so2_lebedev_quadrature=self.so2_lebedev_quadrature, + ffn_lebedev_quadrature=self.ffn_lebedev_quadrature, + n_atten_head=self.n_atten_head, + atten_f_mix=self.atten_f_mix, + atten_v_proj=self.use_atten_v_proj, + atten_o_proj=self.use_atten_o_proj, + so2_pre_norm=self.so2_pre_norm, + so2_post_norm=self.so2_post_norm, + so2_activation_function=self.so2_activation_function, + ffn_pre_norm=self.ffn_pre_norm, + ffn_post_norm=self.ffn_post_norm, + ffn_activation_function=self.ffn_activation_function, + ffn_glu_activation=self.ffn_glu_activation, + mlp_bias=self.mlp_bias, + use_triton=self.use_triton, + eps=self.eps, + dtype=self.dtype, + seed=child_seed(seed_blocks, block_idx), + trainable=self.trainable, + ) + ) + self.blocks = nn.ModuleList(blocks) + + # === Optional descriptor-level attention residuals === + self.final_block_attn_res = None + if self.use_full_attn_res: + self.final_full_attn_res: DepthAttnRes | None = DepthAttnRes( + channels=self.channels, + input_dependent=self.full_attn_res_mode == "dependent", + eps=self.eps, + bias=self.mlp_bias, + dtype=self.compute_dtype, + trainable=self.trainable, + seed=child_seed(seed_full_attn, 2000), + ) + else: + self.final_full_attn_res = None + if self.use_block_attn_res: + self.final_block_attn_res: DepthAttnRes | None = DepthAttnRes( + channels=self.channels, + input_dependent=self.block_attn_res_mode == "dependent", + eps=self.eps, + bias=self.mlp_bias, + dtype=self.compute_dtype, + trainable=self.trainable, + seed=child_seed(seed_block_attn, 2000), + ) + + # === Final FFN for l=0 output mixing === + self.output_ffn = EquivariantFFN( + lmax=0, + channels=self.channels, + hidden_channels=self.out_ffn_neurons, + grid_mlp=False, + dtype=self.compute_dtype, + s2_activation=False, + activation_function=self.out_activation_function, + glu_activation=self.out_glu_activation, + mlp_bias=self.mlp_bias, + trainable=self.trainable, + seed=seed_out, + ) + + for p in self.parameters(): + p.requires_grad = self.trainable + + # Pre-allocate empty tensor for interface compatibility (torch.compile + DDP) + self.register_buffer( + "_empty_tensor", + torch.empty(0, device=env.DEVICE, dtype=env.GLOBAL_PT_FLOAT_PRECISION), + persistent=True, + ) + + # === Statistics buffers (interface compatibility) === + self.stats: dict[str, Any] | None = None + self.register_buffer( + "mean", + torch.zeros(0, dtype=self.dtype, device=self.device), + persistent=True, + ) + self.register_buffer( + "stddev", + torch.ones(0, dtype=self.dtype, device=self.device), + persistent=True, + ) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + edge_index: torch.Tensor | None = None, + edge_vec: torch.Tensor | None = None, + edge_mask: torch.Tensor | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, + fparam: torch.Tensor | None = None, + force_embedding: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + Compute the descriptor. + + Parameters + ---------- + extended_coord + Extended coordinates of atoms with shape (nf, nall*3) or (nf, nall, 3) in Å. + extended_atype + Extended atom types with shape (nf, nall). + nlist + Neighbor list with shape (nf, nloc, nnei). + mapping + Extended-to-local mapping with shape (nf, nall), or None. + edge_index + Fixed-shape edge indices with shape (2, E). If provided, the descriptor + uses the edge-list path and ignores `nlist` and `mapping`. + edge_vec + Fixed-shape edge vectors with shape (E, 3) in Å. Required when + `edge_index` is provided. + edge_mask + Fixed-shape edge mask with shape (E,). Required when `edge_index` + is provided. + comm_dict + Communication dictionary for parallel inference (unused). + fparam + Frame parameters with shape (nf, nfp). Not used by SeZM, kept for + interface compatibility. + force_embedding + Optional precomputed equivariant force embedding with shape + ``(nf * nloc, D, 1, channels)``, where + ``D = (l_schedule[0] + 1) ** 2``. This tensor is added to the + initial SO(3) backbone state before the interaction blocks. + charge_spin + Frame-level charge and spin conditions with shape (nf, 2). + + Returns + ------- + descriptor + Descriptor with shape (nf, nloc, channels). Only l=0 is returned. + rot_mat + Empty tensor (not used). + g2 + Empty tensor (not used). + h2 + Empty tensor (not used). + sw + Empty tensor (not used). + """ + if extended_coord.ndim == 2: + extended_coord = rearrange(extended_coord, "nf (nall c) -> nf nall c", c=3) + elif extended_coord.ndim != 3: + raise ValueError( + "extended_coord must have shape (nf, nall*3) or (nf, nall, 3)" + ) + + if edge_index is not None: + nf_edge = extended_atype.shape[0] + charge_spin = self._canonicalize_charge_spin( + charge_spin, + nf=nf_edge, + dtype=extended_coord.dtype, + device=extended_coord.device, + ) + descriptor, _ = self.forward_with_edges( + extended_coord=extended_coord, + extended_atype=extended_atype, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + force_embedding=force_embedding, + charge_spin=charge_spin, + ) + return ( + descriptor, + self._empty_tensor, + self._empty_tensor, + self._empty_tensor, + self._empty_tensor, + ) + + # === Step 1. Setup dimensions === + extended_coord = extended_coord.to(self.compute_dtype) + nf, nloc, nnei = nlist.shape + nall = extended_coord.shape[1] + n_nodes = int(nf * nloc) + charge_spin = self._canonicalize_charge_spin( + charge_spin, + nf=nf, + dtype=extended_coord.dtype, + device=extended_coord.device, + ) + + # === Step 2. Excluded type pairs === + if self.exclude_types: + # (nf, nloc, nnei), True means keep. + pair_keep_mask = self.emask(nlist, extended_atype).to(dtype=torch.bool) + else: + pair_keep_mask = torch.ones_like(nlist, dtype=torch.bool) + + # === Step 3. Type embedding (l=0) === + with nvtx_range("type_embedding"): + atype_loc = extended_atype[:, :nloc] # (nf, nloc) + type_ebed = self.type_embedding(atype_loc).reshape( + n_nodes, self.channels + ) # (N, C) + if self.charge_spin_embedding is not None: + type_ebed = self._apply_charge_spin_embedding( + type_ebed, + charge_spin, + nf=nf, + nloc=nloc, + ) + + # === Step 4. Build edge cache once (geometry + RBF + Wigner-D) === + # Zone bridging (InnerClamp + SFPG + ZBL) is not routed through the + # standard DeePMD path: bridging only makes physical sense when + # paired with the ZBL energy that ``SeZMModel`` injects on the + # sparse-edge path, so ``forward`` keeps the original + # bridging-free aggregation semantics. + with nvtx_range("build_edge_cache"): + edge_cache = build_edge_cache( + type_ebed=type_ebed, + extended_coord=extended_coord, + nlist=nlist, + mapping=mapping, + pair_keep_mask=pair_keep_mask, + eps=self.eps, + edge_envelope=self.edge_envelope, + radial_basis=self.radial_basis, + n_radial=self.radial_basis.n_radial, + random_gamma=self.random_gamma, + wigner_calc=self.wigner_calc, + use_geometry_rbf_triton=(self.use_triton and not self.training), + ) + + lmax_0 = self.l_schedule[0] + ebed_dim_0 = get_so3_dim_of_lmax(lmax_0) # (lmax+1)^2 + x0 = type_ebed # (N, C) + x0_out = x0 # (N, C) + + # === Step 5. Compute radial features once (fp32+) === + # Shape: (E, (lmax+1)*C) -> (E, lmax+1, C) + radial_feat = None + with nvtx_range("radial_embedding"): + if edge_cache.src.numel() > 0: + radial_feat = rearrange( + self.radial_embedding(edge_cache.edge_rbf), + "E (L C) -> E L C", + L=self.lmax + 1, + C=self.channels, + ) # (E, lmax+1, C) + + # === Step 6. Env FiLM conditioning (optional, fp32+) === + with nvtx_range("env_film"): + if self.use_env_seed and edge_cache.src.numel() > 0: + atype_flat = atype_loc.reshape(-1) # (N,) + film = self.env_seed_embedding( + edge_cache=edge_cache, + atype_flat=atype_flat, + n_nodes=n_nodes, + ) # (N, 2*C) + scale_logits = film[:, : self.channels] # (N, C) + shift_logits = film[:, self.channels :] # (N, C) + scale_hat = self.film_scale_norm(scale_logits) # (N, C) + shift_hat = self.film_shift_norm(shift_logits) # (N, C) + scale_strength = torch.exp(self.film_scale_strength_log) + shift_strength = torch.exp(self.film_shift_strength_log) + scale = 1.0 + scale_strength * torch.tanh(scale_hat) # (N, C) + shift = shift_strength * torch.tanh(shift_hat) # (N, C) + x0_out = x0 * scale + shift + + # === Step 7. Build backbone l=0 features === + x = type_ebed.new_zeros(n_nodes, ebed_dim_0, 1, self.channels) # (N, D, 1, C) + x[:, 0, 0, :] = x0_out + + # === Step 8. Geometric Initial Embedding (fp32+) === + with nvtx_range("gie"): + if self.use_gie and radial_feat is not None: + # GIE only needs l>=1, slice radial_feat[:, 1:, :] + x = x + self.gie( + n_nodes=n_nodes, + edge_cache=edge_cache, + radial_feat=radial_feat[:, 1:, :], + ).unsqueeze(2) + + # === Step 9. Fuse edge type features into radial features (fp32+) === + with nvtx_range("radial_fuse"): + if radial_feat is not None: + radial_feat = radial_feat + rearrange( + edge_cache.edge_type_feat, "E C -> E 1 C" + ) + radial_feat = radial_feat.to(dtype=self.dtype) + rad_feat_per_block = [ + radial_feat[:, :rad_len, :] for rad_len in self.rad_sizes_per_block + ] # list of (E, lmax+1, C) + else: + rad_feat_per_block = [] + + # === Step 10. Convert to self.dtype and run blocks === + with nvtx_range("blocks"): + x = x.to(dtype=self.dtype) # (N, D, 1, C) + if force_embedding is not None: + x = x + force_embedding.to(dtype=self.dtype) + if edge_cache.src.numel() > 0: + edge_cache = edge_cache_to_dtype(edge_cache, self.dtype) + with self._compute_mode_ctx(extended_coord.device): + x = self._forward_blocks(x, edge_cache, rad_feat_per_block) + + # === Step 11. Final l=0 output mixing === + # Extract l=0 scalar features and apply FFN in promoted dtype. + # Residual keeps the output close to identity with zero-initialized FFN output. + with nvtx_range("output_ffn"): + x_scalar = ( + x[:, 0:1, :, :] + .reshape(n_nodes, 1, 1, self.channels) + .to(dtype=self.compute_dtype) + ) # (N, 1, 1, C) + x_scalar = x_scalar + self.output_ffn(x_scalar) + + # === Step 12. Reshape to (nf, nloc, channels) and return === + descriptor = rearrange( + x_scalar, "(nf nloc) 1 1 C -> nf nloc C", nf=nf, nloc=nloc + ) # (nf, nloc, C) + return ( + descriptor.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), + self._empty_tensor, + self._empty_tensor, + self._empty_tensor, + self._empty_tensor, + ) + + def forward_with_edges( + self, + *, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + edge_index: torch.Tensor, + edge_vec: torch.Tensor, + edge_mask: torch.Tensor, + force_embedding: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute the descriptor from a sparse edge list. + + Parameters + ---------- + extended_coord + Coordinates with shape (nf, nloc*3) or (nf, nloc, 3) in Å. + extended_atype + Atom types with shape (nf, nloc). + edge_index + Edge indices with shape (2, E). + edge_vec + Edge vectors with shape (E, 3) in Å. + edge_mask + Edge mask with shape (E,). + force_embedding + Optional precomputed equivariant force embedding with shape + ``(nf * nloc, D, 1, channels)``, where + ``D = (l_schedule[0] + 1) ** 2``. This tensor is added to the + initial SO(3) backbone state before the interaction blocks. + charge_spin + Frame-level charge and spin conditions with shape (nf, 2). + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + The scalar descriptor with shape ``(nf, nloc, channels)`` and the + final equivariant latent with shape ``(nf * nloc, D_final, 1, channels)``. + """ + # === Step 1. Setup dimensions === + extended_coord = extended_coord.to(self.compute_dtype) + nf, nloc = extended_atype.shape[:2] + + # === Step 2. Type embedding (l=0) === + with nvtx_range("type_embedding"): + atype_loc = extended_atype[:, :nloc] # (nf, nloc) + type_ebed = self.type_embedding(atype_loc).reshape( + -1, self.channels + ) # (N, C) + if self.charge_spin_embedding is not None: + type_ebed = self._apply_charge_spin_embedding( + type_ebed, + charge_spin, + nf=nf, + nloc=nloc, + ) + n_nodes = type_ebed.shape[0] + + # === Step 3. Build edge cache once (sparse edges) === + with nvtx_range("build_edge_cache"): + edge_cache = build_edge_cache_from_edges( + type_ebed=type_ebed, + atype_flat=atype_loc.reshape(-1), + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + compute_dtype=self.compute_dtype, + eps=self.eps, + inner_clamp=self.inner_clamp, + bridging_switch=self.bridging_switch, + edge_envelope=self.edge_envelope, + radial_basis=self.radial_basis, + has_exclude_types=bool(self.exclude_types), + edge_type_keep_mask=self._edge_type_keep_mask, + random_gamma=self.random_gamma, + wigner_calc=self.wigner_calc, + ) + + lmax_0 = self.l_schedule[0] + ebed_dim_0 = get_so3_dim_of_lmax(lmax_0) # (lmax+1)^2 + x0 = type_ebed # (N, C) + x0_out = x0 # (N, C) + + # === Step 4. Compute radial features once (fp32+) === + with nvtx_range("radial_embedding"): + radial_feat_flat = self.radial_embedding(edge_cache.edge_rbf) + radial_feat = radial_feat_flat.reshape( + radial_feat_flat.shape[0], self.lmax + 1, self.channels + ) # (E, lmax+1, C) + + # === Step 5. Env FiLM conditioning (optional, fp32+) === + with nvtx_range("env_film"): + if self.use_env_seed: + atype_flat = atype_loc.reshape(-1) # (N,) + film = self.env_seed_embedding( + edge_cache=edge_cache, + atype_flat=atype_flat, + n_nodes=n_nodes, + ) # (N, 2*C) + scale_logits = film[:, : self.channels] # (N, C) + shift_logits = film[:, self.channels :] # (N, C) + scale_hat = self.film_scale_norm(scale_logits) # (N, C) + shift_hat = self.film_shift_norm(shift_logits) # (N, C) + scale_strength = torch.exp(self.film_scale_strength_log) + shift_strength = torch.exp(self.film_shift_strength_log) + scale = 1.0 + scale_strength * torch.tanh(scale_hat) # (N, C) + shift = shift_strength * torch.tanh(shift_hat) # (N, C) + x0_out = x0 * scale + shift + + # === Step 6. Build backbone l=0 features === + x = type_ebed.new_zeros(n_nodes, ebed_dim_0, 1, self.channels) # (N, D, 1, C) + x[:, 0, 0, :] = x0_out + + # === Step 7. Geometric Initial Embedding (fp32+) === + with nvtx_range("gie"): + if self.use_gie: + x = x + self.gie( + n_nodes=n_nodes, + edge_cache=edge_cache, + radial_feat=radial_feat[:, 1:, :], + ).unsqueeze(2) + + # === Step 8. Fuse edge type features into radial features (fp32+) === + with nvtx_range("radial_fuse"): + radial_feat = radial_feat.to(dtype=self.dtype) + radial_feat = radial_feat + rearrange( + edge_cache.edge_type_feat.to(dtype=self.dtype), "E C -> E 1 C" + ) + rad_feat_per_block = [ + radial_feat[:, :rad_len, :] for rad_len in self.rad_sizes_per_block + ] + + # === Step 9. Convert to self.dtype and run blocks === + with nvtx_range("blocks"): + x = x.to(dtype=self.dtype) # (N, D, 1, C) + if force_embedding is not None: + x = x + force_embedding.to(dtype=self.dtype) + edge_cache = edge_cache_to_dtype(edge_cache, self.dtype) + with self._compute_mode_ctx(extended_coord.device): + x = self._forward_blocks(x, edge_cache, rad_feat_per_block) + + # === Step 10. Final l=0 output mixing === + with nvtx_range("output_ffn"): + x_scalar = ( + x[:, 0:1, :, :] + .reshape(n_nodes, 1, 1, self.channels) + .to(dtype=self.compute_dtype) + ) # (N, 1, 1, C) + x_scalar = x_scalar + self.output_ffn(x_scalar) + + # === Step 11. Reshape to (nf, nloc, channels) and return === + descriptor = x_scalar.reshape(nf, nloc, self.channels) # (nf, nloc, C) + return descriptor.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), x.contiguous() + + def _forward_blocks( + self, + x: torch.Tensor, + edge_cache: EdgeFeatureCache, + radial_feat_per_block: list[torch.Tensor], + ) -> torch.Tensor: + """ + Run the interaction blocks with optional depth attention. + + Parameters + ---------- + x + Initial node features with shape (N, D, 1, C). + edge_cache + Per-edge cache. + radial_feat_per_block + List of per-block radial features already truncated to l_schedule[i]+1. + + Returns + ------- + torch.Tensor + Output features with shape (N, D, 1, C). + """ + if not self.use_full_attn_res and not self.use_block_attn_res: + # === Fast path without descriptor-level attention residuals === + for i, block in enumerate(self.blocks): + x = x[:, : self.ebed_dims[i], :, :] + blk_radial = radial_feat_per_block[i] + with nvtx_range(f"block_{i}"): + x, _, _, _ = block(x, edge_cache, blk_radial) + return x + + n_node = x.shape[0] + + def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: + """Extract scalar features from global SO(3) layout.""" + return v[:, 0, :, :].reshape(n_node, self.channels) + + if self.use_full_attn_res: + # === Step 1. Maintain descriptor-level unit history === + unit_history = [x] + + # === Step 2. Run each block with selective unit-history aggregation === + for i, block in enumerate(self.blocks): + current_dim = self.ebed_dims[i] + current_x = x[:, :current_dim, :, :] + truncated_unit_history = [ + source[:, :current_dim, :, :] for source in unit_history + ] + blk_radial = radial_feat_per_block[i] + with nvtx_range(f"block_{i}"): + block_output, _, so2_unit_output, ffn_unit_outputs = block( + current_x, + edge_cache, + blk_radial, + unit_history=truncated_unit_history, + ) + unit_history.append(so2_unit_output) + unit_history.extend(ffn_unit_outputs) + x = block_output + + # === Step 3. Final aggregation over all completed unit representations === + final_dim = self.ebed_dims[-1] + final_sources = [source[:, :final_dim, :, :] for source in unit_history] + x = self.final_full_attn_res( + sources=final_sources, + scalar_extractor=node_l0_extractor, + current_x=x, + ).to(dtype=self.dtype) + return x + + # === Step 1. Maintain descriptor-level block history === + block_history = [x] + + # === Step 2. Run each block with selective block-history aggregation === + for i, block in enumerate(self.blocks): + current_dim = self.ebed_dims[i] + current_x = x[:, :current_dim, :, :] + truncated_block_history = [ + source[:, :current_dim, :, :] for source in block_history + ] + blk_radial = radial_feat_per_block[i] + with nvtx_range(f"block_{i}"): + block_output, block_summary, _, _ = block( + current_x, + edge_cache, + blk_radial, + unit_history=truncated_block_history, + ) + block_history.append(block_summary) + x = block_output + + # === Step 3. Final aggregation over all completed block summaries === + final_dim = self.ebed_dims[-1] + final_sources = [source[:, :final_dim, :, :] for source in block_history] + x = self.final_block_attn_res( + sources=final_sources, + scalar_extractor=node_l0_extractor, + current_x=x, + ).to(dtype=self.dtype) + return x + + def _apply_charge_spin_embedding( + self, + type_ebed: torch.Tensor, + charge_spin: torch.Tensor, + *, + nf: int, + nloc: int, + ) -> torch.Tensor: + """ + Add frame-level charge and spin conditions to scalar type features. + + Parameters + ---------- + type_ebed + Flattened type embeddings with shape (nf * nloc, channels). + charge_spin + Frame-level charge and spin conditions with shape (nf, 2). + nf + Number of frames. + nloc + Number of local atoms. + + Returns + ------- + torch.Tensor + Conditioned type embeddings with shape (nf * nloc, channels). + """ + condition = self.charge_spin_embedding(charge_spin.to(dtype=type_ebed.dtype)) + condition = condition[:, None, :].expand(nf, nloc, self.channels) + return type_ebed + condition.reshape_as(type_ebed) + + def _edge_type_keep_mask( + self, + atype_flat: torch.Tensor, + src: torch.Tensor, + dst: torch.Tensor, + ) -> torch.Tensor: + """ + Build keep mask for edge pairs based on excluded type pairs. + + Parameters + ---------- + atype_flat + Flattened local atom types with shape (N,). + src + Source indices with shape (E,). + dst + Destination indices with shape (E,). + + Returns + ------- + torch.Tensor + Boolean mask with shape (E,), True means keep. + """ + if self.emask.no_exclusion: + return torch.ones_like(src, dtype=torch.bool, device=src.device) + type_i = atype_flat.index_select(0, dst) + type_j = atype_flat.index_select(0, src) + type_i = torch.where(type_i >= 0, type_i, self.ntypes) + type_j = torch.where(type_j >= 0, type_j, self.ntypes) + type_ij = type_i * (self.ntypes + 1) + type_j + type_mask = self.emask.type_mask.to(device=atype_flat.device) + keep = type_mask.index_select(0, type_ij.to(dtype=torch.long)) + return keep.to(dtype=torch.bool) + + def _resolve_ffn_neurons( + self, + ffn_neurons: int, + *, + glu_activation: bool, + ) -> int: + """Resolve one FFN hidden width from the descriptor config.""" + resolved = int(ffn_neurons) + if resolved < 0: + raise ValueError("`ffn_neurons` must be >= 0") + if resolved > 0: + return resolved + base_width = ( + (8.0 * float(self.channels) / 3.0) + if glu_activation + else (4.0 * float(self.channels)) + ) + return int(32 * math.ceil(base_width / 32.0)) + + def _init_lm_schedules( + self, + lmax: int, + n_blocks: int, + l_schedule: list[int] | None, + mmax: int | None, + m_schedule: list[int] | None, + ) -> None: + """Parse and validate L/M schedules, setting self.l_schedule/m_schedule/lmax/mmax.""" + # === L schedule === + if l_schedule is None: + self.l_schedule = [int(lmax)] * int(n_blocks) + else: + self.l_schedule = [int(x) for x in l_schedule] + if len(self.l_schedule) == 0: + raise ValueError("`l_schedule` must be non-empty") + if any(x < 0 for x in self.l_schedule): + raise ValueError("`l_schedule` entries must be non-negative") + if any( + self.l_schedule[i] < self.l_schedule[i + 1] + for i in range(len(self.l_schedule) - 1) + ): + raise ValueError("`l_schedule` must be non-increasing (pyramid schedule)") + + self.lmax = int(self.l_schedule[0]) + self.n_blocks = len(self.l_schedule) + + # === M schedule === + if m_schedule is None: + if mmax is None: + self.m_schedule = [int(l) for l in self.l_schedule] + else: + mmax_i = int(mmax) + if mmax_i < 0: + raise ValueError("`mmax` must be non-negative") + self.m_schedule = [min(mmax_i, int(l)) for l in self.l_schedule] + else: + self.m_schedule = [int(x) for x in m_schedule] + if len(self.m_schedule) == 0: + raise ValueError("`m_schedule` must be non-empty") + if len(self.m_schedule) != len(self.l_schedule): + raise ValueError("`m_schedule` must have the same length as `l_schedule`") + if any(x < 0 for x in self.m_schedule): + raise ValueError("`m_schedule` entries must be non-negative") + if any(m > l for m, l in zip(self.m_schedule, self.l_schedule)): + raise ValueError( + "`m_schedule` entries must satisfy `m_schedule[i] <= l_schedule[i]`" + ) + + self.mmax = int(self.m_schedule[0]) + + def _canonicalize_charge_spin( + self, + charge_spin: torch.Tensor | None, + *, + nf: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor | None: + """ + Canonicalize charge/spin conditions for the public descriptor path. + + Parameters + ---------- + charge_spin + Optional frame-level charge and spin conditions. + nf + Number of frames. + dtype + Target floating-point dtype. + device + Target device. + + Returns + ------- + torch.Tensor or None + Tensor with shape (nf, 2) when condition embedding is enabled. + """ + if self.charge_spin_embedding is None: + return None + if charge_spin is None: + if self.default_chg_spin is None: + raise ValueError("`charge_spin` is required for this SeZM descriptor.") + charge_spin = torch.tensor( + self.default_chg_spin, + dtype=dtype, + device=device, + ).view(1, 2) + else: + charge_spin = charge_spin.to(dtype=dtype, device=device) + + if charge_spin.ndim == 1: + if charge_spin.numel() != 2: + raise ValueError("`charge_spin` must contain [charge, spin].") + charge_spin = charge_spin.view(1, 2) + elif charge_spin.ndim != 2 or charge_spin.shape[-1] != 2: + raise ValueError("`charge_spin` must have shape (nf, 2).") + + if charge_spin.shape[0] == 1 and nf != 1: + charge_spin = charge_spin.expand(nf, -1) + elif charge_spin.shape[0] != nf: + raise ValueError("`charge_spin` first dimension must match nframes.") + return charge_spin + + @contextmanager + def _compute_mode_ctx(self, device: torch.device) -> Generator[None, None, None]: + """ + Context manager that applies automatic mixed precision (AMP) for forward(). + + Parameters + ---------- + device + The device of the input tensors (used to determine if CUDA ops apply). + + Notes + ----- + - When `use_amp=True` and the model is in training mode, enables + torch.autocast with bfloat16 on CUDA. + - Only affects autocast-eligible operations (matmul, conv, etc.). + - Does nothing during inference (`self.training=False`), on non-CUDA + devices, or when `use_amp=False`. + + Yields + ------ + None + Runs the wrapped region under the configured AMP setting. + """ + if not self.use_amp or device.type != "cuda" or not self.training: + yield + return + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + yield + + # === DeePMD descriptor interface === + def get_rcut(self) -> float: + return self.rcut + + def get_rcut_smth(self) -> float: + return self.rcut + + def get_sel(self) -> list[int]: + return self.sel + + def get_nsel(self) -> int: + return sum(self.sel) + + def get_ntypes(self) -> int: + return self.ntypes + + def get_type_map(self) -> list[str]: + return self.type_map if self.type_map is not None else [] + + def get_dim_chg_spin(self) -> int: + """Return the charge/spin condition width.""" + return 2 if self.add_chg_spin_ebd else 0 + + def has_default_chg_spin(self) -> bool: + """Return whether default charge/spin conditions are configured.""" + return self.default_chg_spin is not None + + def get_default_chg_spin(self) -> list[float] | None: + """Return default charge/spin conditions.""" + return self.default_chg_spin + + def get_dim_out(self) -> int: + return self.channels + + def get_dim_emb(self) -> int: + return self.get_dim_out() + + def mixed_types(self) -> bool: + """ + If true, the descriptor + 1. assumes total number of atoms aligned across frames; + 2. requires a neighbor list that does not distinguish different atomic types. + + If false, the descriptor + 1. assumes total number of atoms of each atom type aligned across frames; + 2. requires a neighbor list that distinguishes different atomic types. + + SeZM uses SeZMTypeEmbedding for type handling, so it does not require + a type-distinguished neighbor list. + """ + return True + + def has_message_passing(self) -> bool: + return bool(len(self.blocks) > 0 and self.lmax > 0) + + def need_sorted_nlist_for_lower(self) -> bool: + return False + + def get_env_protection(self) -> float: + return self.eps + + @property + def dim_out(self) -> int: + return self.get_dim_out() + + @property + def dim_emb(self) -> int: + return self.get_dim_emb() + + def share_params( + self, base_class: Any, shared_level: int, resume: bool = False + ) -> None: + """ + Share the parameters of self to the base_class with shared_level during multitask training. + + SeZM does not rely on running mean/stddev statistics in ``forward`` + (``EquivariantRMSNorm`` is used instead), so only submodules and + the optional FiLM strength parameters need to be linked. + + Parameters + ---------- + base_class + The base class to share parameters with. Must be the same class as self. + + shared_level + The level of sharing. + + - ``0``: share every learnable submodule and FiLM strength parameter + (type_embedding, env_seed_embedding, film_*_norm, + film_*_strength_log, radial_basis, radial_embedding, + edge_envelope, wigner_calc, gie, blocks, final_*_attn_res, + output_ffn). + - ``1``: share ``type_embedding`` and optional condition embedding. + + resume + Unused for SeZM; kept for interface compatibility. + + Raises + ------ + NotImplementedError + If ``shared_level`` is not ``0`` or ``1``. + """ + del resume + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + # NOTE: ``nn.Module.__setattr__`` routes plain assignment of a + # child ``nn.Module`` through the ``_modules`` dict, so iterating + # that dict covers every learnable submodule registered by + # ``__init__`` (type_embedding, env_seed_embedding, film norms, + # radial_*, edge_envelope, wigner_calc, gie, blocks, final attn + # residuals, output_ffn). Raw ``nn.Parameter`` attributes + # (``film_*_strength_log``) live in ``_parameters`` instead and + # are linked explicitly below. + for item in self._modules: + self._modules[item] = base_class._modules[item] + for name in ("film_scale_strength_log", "film_shift_strength_log"): + if self._parameters.get(name) is not None: + self._parameters[name] = base_class._parameters[name] + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + if self.charge_spin_embedding is not None: + self._modules["charge_spin_embedding"] = base_class._modules[ + "charge_spin_embedding" + ] + else: + raise NotImplementedError + + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + raise NotImplementedError("Compression is unsupported for SeZM.") + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat: Any | None = None + ) -> None: + raise NotImplementedError("change_type_map is not supported for SeZM") + + def reinit_exclude( + self, exclude_types: list[tuple[int, int]] | None = None + ) -> None: + if exclude_types is None: + exclude_types = [] + self.exclude_types = exclude_types + self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + + # ========================================================================= + # Statistics interface (interface compatibility only) + # ------------------------------------------------------------------------- + # SeZM uses EquivariantRMSNorm inside blocks for feature normalization, + # so mean/stddev are NOT used in forward(). These methods are kept for: + # 1. Interface compatibility with BaseDescriptor + # 2. Consistent serialization format (davg/dstd in checkpoint) + # ========================================================================= + + def set_stat_mean_and_stddev( + self, mean: torch.Tensor, stddev: torch.Tensor + ) -> None: + """Set mean and stddev (interface compatibility, not used in forward).""" + self.mean = mean + self.stddev = stddev + + def get_stat_mean_and_stddev(self) -> tuple[torch.Tensor, torch.Tensor]: + """Get mean and stddev (interface compatibility, not used in forward).""" + return self.mean, self.stddev + + def compute_input_stats( + self, + merged: Callable[[], list[dict]] | list[dict], + path: DPPath | None = None, + ) -> None: + """ + Compute statistics (interface compatibility, not used in forward). + + SeZM uses learnable EquivariantRMSNorm for normalization, so these + statistics do not affect the forward pass. This is a no-op that keeps + mean/stddev at their initialized values (zero/one) for interface consistency. + """ + # No-op: mean and stddev are already initialized to zero/one in __init__ + # and are not used in forward() due to EquivariantRMSNorm. + + def serialize(self) -> dict[str, Any]: + state = self.state_dict() + return { + "@class": "Descriptor", + "type": "SeZM", + "@version": 1, + "config": { + "ntypes": self.ntypes, + "sel": self.sel, + "rcut": self.rcut, + "env_exp": self.env_exp, + "type_map": self.type_map, + "lmax": self.lmax, + "n_blocks": self.n_blocks, + "l_schedule": self.l_schedule, + "mmax": self.mmax, + "m_schedule": self.m_schedule, + "channels": self.channels, + "basis_type": self.basis_type, + "n_radial": self.n_radial, + "radial_mlp": self.radial_mlp, + "use_env_seed": self.use_env_seed, + "random_gamma": self.random_gamma, + "so2_norm": self.so2_norm, + "so2_layers": self.so2_layers, + "so2_attn_res": self.so2_attn_res_mode, + "radial_so2_mode": self.radial_so2_mode, + "radial_so2_rank": self.radial_so2_rank, + "n_focus": self.n_focus, + "focus_dim": self.focus_dim, + "ffn_neurons": self.ffn_neurons, + "grid_mlp": self.grid_mlp, + "ffn_blocks": self.ffn_blocks, + "layer_scale": self.layer_scale, + "n_atten_head": self.n_atten_head, + "atten_f_mix": self.atten_f_mix, + "atten_v_proj": self.use_atten_v_proj, + "atten_o_proj": self.use_atten_o_proj, + "sandwich_norm": self.sandwich_norm, + "full_attn_res": self.full_attn_res_mode, + "block_attn_res": self.block_attn_res_mode, + "s2_activation": self.s2_activation, + "lebedev_quadrature": self.lebedev_quadrature, + "activation_function": self.activation_function, + "glu_activation": self.glu_activation, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "mlp_bias": self.mlp_bias, + "exclude_types": self.exclude_types, + "eps": self.eps, + "trainable": self.trainable, + "seed": self.seed, + "inner_clamp_r_inner": self.inner_clamp_r_inner, + "inner_clamp_r_outer": self.inner_clamp_r_outer, + "add_chg_spin_ebd": self.add_chg_spin_ebd, + "default_chg_spin": self.default_chg_spin, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + "env_mat": DPEnvMat(self.rcut, self.rcut, self.eps).serialize(), + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> DescrptSeZM: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "Descriptor": + raise ValueError(f"Invalid class for DescrptSeZM: {data_cls}") + type_val = data.pop("type") + if type_val not in ("SeZM", "sezm", "dpa4"): + raise ValueError(f"Invalid type for DescrptSeZM: {type_val}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + data.pop("env_mat", None) + config.pop("s2_grid_resolution", None) + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: list[str] | None, + local_jdata: dict, + ) -> tuple[dict, float | None]: + """ + Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + Data used to do neighbor statistics. + type_map : list[str] | None + The name of each type of atoms. + local_jdata : dict + The local data refer to the current class. + + Returns + ------- + dict + The updated local data. + float | None + The minimum distance between two atoms. + """ + local_jdata_cpy = local_jdata.copy() + min_nbor_dist, sel = UpdateSel().update_one_sel( + train_data, + type_map, + local_jdata_cpy["rcut"], + local_jdata_cpy["sel"], + True, # mixed_type=True for unified sel + ) + local_jdata_cpy["sel"] = sel[0] + return local_jdata_cpy, min_nbor_dist + + def _load_from_state_dict( + self, + state_dict: dict[str, torch.Tensor], + prefix: str, + local_metadata: dict[str, Any], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ) -> None: + """Fold LoRA adapters and drop transient state before loading. + + When a LoRA-trained checkpoint is loaded into a plain (non-LoRA) + descriptor, any ``A_by_l``/``B_by_l`` (SO3) and + ``A_m0``/``B_m0``/``A_m.*``/``B_m.*`` (SO2) keys are folded into + their corresponding base weight keys (``weight``, ``weight_m0``, + ``weight_m.*``) using ``ΔW = einsum(B, A) * scaling``. The LoRA + keys are then removed so the load proceeds as if the checkpoint + were a plain SeZM. This enables resume, finetune, and full-train + from any LoRA checkpoint without manual merging. + + When the current descriptor is itself LoRA-injected, however, the + incoming ``A_*`` / ``B_*`` / ``lora_scaling`` tensors are + first-class parameters this descriptor already owns, *not* + redundant adapters to be folded away. Folding in that case would + consume the LoRA keys and then ``super()._load_from_state_dict`` + would report them as ``Missing key(s)`` against the target + module. ``has_lora(self)`` gates the fold step so the + LoRA-to-plain merge still runs when appropriate, while + LoRA-to-LoRA loads (full training, ckpt resume, tests, and + cross-instance copies via ``model_a.load_state_dict( + model_b.state_dict())``) pass the adapter keys through + unchanged. + """ + # === Step 1. Fold any LoRA keys into base weights === + # Only fold when the current descriptor has no LoRA adapters + # (see docstring). + if not has_lora(self): + fold_lora_state_dict_keys(state_dict, prefix) + + # === Step 2. Drop transient descriptor state rebuilt at construction === + expected_keys = {prefix + key for key in self.state_dict().keys()} + for full_key in list(state_dict.keys()): + if full_key.startswith(prefix) and full_key not in expected_keys: + state_dict.pop(full_key) + + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/__init__.py new file mode 100644 index 0000000000..9faa82ee97 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/__init__.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Public building blocks for the SeZM descriptor. + +This package re-exports the helper functions, embeddings, equivariant layers, +and quaternion-based Wigner-D utilities used by the SeZM descriptor and model. +""" + +from .activation import ( + GatedActivation, + S2GridProjector, + SwiGLU, + SwiGLUS2Activation, + resolve_s2_grid_resolution, +) +from .attention import ( + segment_envelope_gated_softmax, +) +from .attn_res import ( + DepthAttnRes, +) +from .block import ( + SeZMInteractionBlock, +) +from .dens import ( + ForceEmbedding, + SeZMDenoisingHead, + SeZMDeNSFittingNet, + SeZMDirectForceHead, +) +from .edge_cache import ( + EdgeFeatureCache, + build_edge_cache, + build_edge_cache_from_edges, + build_edge_type_feat, + compute_edge_src_gate, + edge_cache_to_dtype, +) +from .embedding import ( + ChargeSpinEmbedding, + EnvironmentInitialEmbedding, + GeometricInitialEmbedding, + SeZMTypeEmbedding, +) +from .ffn import ( + EquivariantFFN, +) +from .indexing import ( + build_l_major_index, + build_m_major_index, + build_m_major_l_index, + build_rotate_inv_rescale, + get_so3_dim_of_lmax, + map_degree_idx, + project_D_to_m, + project_Dt_from_m, + so3_packed_index, +) +from .lebedev import ( + LEBEDEV_PRECISION_TO_NPOINTS, + load_lebedev_rule, +) +from .lora import ( + LoRASO2, + LoRASO3, + apply_lora_to_sezm, + build_merged_state_dict, + fold_lora_state_dict_keys, + has_lora, + merge_lora_into_base, + strip_lora_from_extra_state, +) +from .norm import ( + EquivariantRMSNorm, + ReducedEquivariantRMSNorm, + RMSNorm, + ScalarRMSNorm, +) +from .radial import ( + BridgingSwitch, + C3CutoffEnvelope, + InnerClamp, + RadialBasis, + RadialMLP, +) +from .so2 import ( + DynamicRadialDegreeMixer, + SO2Convolution, + SO2Linear, +) +from .so3 import ( + ChannelLinear, + FocusLinear, + SO3Linear, +) +from .utils import ( + ATTN_RES_MODES, + get_promoted_dtype, + init_trunc_normal_fan_in_out, + np_safe, + nvtx_range, + safe_norm, + safe_numpy_to_tensor, +) +from .wignerd import ( + WignerDCalculator, + build_edge_quaternion, + quaternion_multiply, + quaternion_nlerp, + quaternion_normalize, + quaternion_to_rotation_matrix, + quaternion_z_rotation, +) + +__all__ = [ + "ATTN_RES_MODES", + "LEBEDEV_PRECISION_TO_NPOINTS", + "BridgingSwitch", + "C3CutoffEnvelope", + "ChannelLinear", + "ChargeSpinEmbedding", + "DepthAttnRes", + "DynamicRadialDegreeMixer", + "EdgeFeatureCache", + "EnvironmentInitialEmbedding", + "EquivariantFFN", + "EquivariantRMSNorm", + "FocusLinear", + "ForceEmbedding", + "GatedActivation", + "GeometricInitialEmbedding", + "InnerClamp", + "LoRASO2", + "LoRASO3", + "RMSNorm", + "RadialBasis", + "RadialMLP", + "ReducedEquivariantRMSNorm", + "S2GridProjector", + "SO2Convolution", + "SO2Linear", + "SO3Linear", + "ScalarRMSNorm", + "SeZMDeNSFittingNet", + "SeZMDenoisingHead", + "SeZMDirectForceHead", + "SeZMInteractionBlock", + "SeZMTypeEmbedding", + "SwiGLU", + "SwiGLUS2Activation", + "WignerDCalculator", + "apply_lora_to_sezm", + "build_edge_cache", + "build_edge_cache_from_edges", + "build_edge_quaternion", + "build_edge_type_feat", + "build_l_major_index", + "build_m_major_index", + "build_m_major_l_index", + "build_merged_state_dict", + "build_rotate_inv_rescale", + "compute_edge_src_gate", + "edge_cache_to_dtype", + "fold_lora_state_dict_keys", + "get_promoted_dtype", + "get_so3_dim_of_lmax", + "has_lora", + "init_trunc_normal_fan_in_out", + "load_lebedev_rule", + "map_degree_idx", + "merge_lora_into_base", + "np_safe", + "nvtx_range", + "project_D_to_m", + "project_Dt_from_m", + "quaternion_multiply", + "quaternion_nlerp", + "quaternion_normalize", + "quaternion_to_rotation_matrix", + "quaternion_z_rotation", + "resolve_s2_grid_resolution", + "safe_norm", + "safe_numpy_to_tensor", + "segment_envelope_gated_softmax", + "so3_packed_index", + "strip_lora_from_extra_state", +] diff --git a/deepmd/pt/model/descriptor/sezm_nn/activation.py b/deepmd/pt/model/descriptor/sezm_nn/activation.py new file mode 100644 index 0000000000..1ce567f72b --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/activation.py @@ -0,0 +1,807 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Activation and S2-grid helper modules for SeZM. + +This module contains SeZM nonlinear operators, including GatedActivation, +point-wise SwiGLU, and the S2-grid projection helper used by the +S2 activation path. +""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + Any, +) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from e3nn.o3 import ( + FromS2Grid, + ToS2Grid, + spherical_harmonics, +) + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + ActivationFn, + get_generator, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .indexing import ( + build_l_major_index, + build_m_major_index, + build_m_major_l_index, + map_degree_idx, +) +from .lebedev import ( + LEBEDEV_PRECISION_TO_NPOINTS, + load_lebedev_rule, +) +from .so3 import ( + FocusLinear, +) +from .utils import ( + np_safe, + safe_numpy_to_tensor, +) + + +class GatedActivation(nn.Module): + """ + Gated activation for SO(3) equivariant features with per-l independent gates. + + Standard mode (gate=None in forward): + - l=0: Uses the specified activation function + - l>0: Each degree l has an independent gate derived from the l=0 scalar features. + The gate for each l is expanded to all m components within that l-block. + + GLU mode (gate provided in forward, e.g., from split linear output): + - l=0: x0 * act(g0) (SwiGLU-style when act=silu, GeGLU when act=gelu, etc.) + - l>0: Uses gate's scalar (g0) to generate sigmoid gates for x's vector components. + This preserves SO(3) equivariance (scalar gates vector, not vector gates vector). + + This module also supports the m-major reduced layout used inside SO(2) blocks. + If `mmax` is provided, the coefficient axis is assumed to follow the truncated + m-major order built by `build_m_major_index(lmax, mmax)`; otherwise, it is assumed + to be the full packed (l, m) layout with D=(lmax+1)^2. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + mmax + Maximum order (|m|) for the m-major reduced layout. If None, use the full + packed layout with D=(lmax+1)^2. + channels + Number of channels per focus stream. + n_focus + Number of focus streams. + dtype + Internal compute dtype used by the gate projection and sigmoid path. + activation_function + Activation function for l=0 components (e.g., "silu", "tanh", "gelu"). + mlp_bias + Whether to use bias in the gate linear layer. + layout + Tensor layout convention. ``"nfdc"`` means input shape (N, F, D, C); + ``"ndfc"`` means input shape (N, D, F, C). + trainable + Whether parameters are trainable. + seed + Random seed for weight initialization. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + channels: int, + n_focus: int = 1, + dtype: torch.dtype, + activation_function: str = "silu", + mlp_bias: bool = False, + layout: str = "nfdc", + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = None if mmax is None else int(mmax) + if self.mmax is not None: + if self.mmax < 0: + raise ValueError("`mmax` must be non-negative") + if self.mmax > self.lmax: + raise ValueError("`mmax` must be <= `lmax`") + self.channels = int(channels) + self.n_focus = int(n_focus) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + self.mlp_bias = bool(mlp_bias) + self.layout = str(layout).lower() + if self.layout not in {"nfdc", "ndfc"}: + raise ValueError("`layout` must be either 'nfdc' or 'ndfc'") + + self.scalar_act = ActivationFn(activation_function) + + # === Build expand_index for mapping per-l gates to all m components === + if self.lmax > 0: + if self.mmax is None: + expand_index = map_degree_idx(self.lmax, device=self.device)[1:] - 1 + else: + degree_index = build_m_major_l_index( + self.lmax, self.mmax, device=self.device + ) + expand_index = degree_index[1:] - 1 + self.gate_linear: nn.Module = FocusLinear( + in_channels=self.channels, + out_channels=self.lmax * self.channels, + n_focus=self.n_focus, + dtype=self.dtype, + bias=self.mlp_bias, + seed=seed, + trainable=trainable, + ) + + gen_gate = get_generator(child_seed(seed, 1)) + nn.init.normal_( + self.gate_linear.weight, mean=0.0, std=0.01, generator=gen_gate + ) + if self.gate_linear.bias is not None: + nn.init.zeros_(self.gate_linear.bias) + else: + expand_index = torch.zeros(0, dtype=torch.long, device=self.device) + self.gate_linear = nn.Identity() + self.register_buffer("expand_index", expand_index, persistent=True) + + for p in self.parameters(): + p.requires_grad = trainable + + def forward( + self, x: torch.Tensor, gate: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Parameters + ---------- + x + Value features. Shape is (N, F, D, C) when ``layout='nfdc'``, + or (N, D, F, C) when ``layout='ndfc'``. + gate + Optional gate features with the same layout as ``x``. + When provided, enables GLU mode: + - l=0: x0 * act(g0) (e.g., SwiGLU when act=silu) + - l>0: sigmoid(Linear(g0)) gates x's vector components + When None (default), uses standard mode where gates are derived from x itself. + + Returns + ------- + torch.Tensor + Gated features with the same layout as ``x``. + """ + degree_axis = 1 if self.layout == "ndfc" else 2 + + if gate is not None: + gate_scalar_source = gate.select(dim=degree_axis, index=0) + else: + gate_scalar_source = x.select(dim=degree_axis, index=0) + + if gate is not None: + x0 = x.narrow(degree_axis, 0, 1) * self.scalar_act( + gate.narrow(degree_axis, 0, 1) + ) + else: + x0 = self.scalar_act(x.narrow(degree_axis, 0, 1)) + + if self.lmax == 0: + return x0 + + input_dtype = gate_scalar_source.dtype + gating_scalars = torch.sigmoid( + self.gate_linear(gate_scalar_source.to(dtype=self.dtype)) + ).to(dtype=input_dtype) + gating_scalars = gating_scalars.reshape( + x.shape[0], gate_scalar_source.shape[1], self.lmax, self.channels + ) + gates = gating_scalars.index_select(dim=2, index=self.expand_index) + if self.layout == "ndfc": + gates = gates.transpose(1, 2) + + out = x.new_empty(x.shape) + out.narrow(degree_axis, 0, 1).copy_(x0) + out.narrow(degree_axis, 1, x.shape[degree_axis] - 1).copy_( + x.narrow(degree_axis, 1, x.shape[degree_axis] - 1) * gates + ) + return out + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "GatedActivation", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "channels": self.channels, + "n_focus": self.n_focus, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "activation_function": self.scalar_act.activation, + "mlp_bias": self.mlp_bias, + "layout": self.layout, + "trainable": trainable, + "seed": None, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> GatedActivation: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "GatedActivation": + raise ValueError(f"Invalid class for GatedActivation: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj + + +class SwiGLU(nn.Module): + """Point-wise SwiGLU on the last feature axis.""" + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + gate, value = torch.chunk(inputs, chunks=2, dim=-1) + return F.silu(gate) * value + + +class S2GridProjector(nn.Module): + """ + Project SO(3) coefficients to/from a flattened S2 grid. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + mmax + Maximum order kept in the coefficient layout. If None, use ``lmax``. + dtype + Buffer dtype used by the projection matrices. + grid_resolution_list + Two-element resolution list. For ``grid_method='e3nn'`` it is + ``[R_phi, R_theta]`` and is converted to the ``e3nn`` + ``(lat, long) = (R_theta, R_phi)`` ordering. For + ``grid_method='lebedev'`` it is ``[precision, n_points]``. + coefficient_layout + Coefficient ordering expected by the caller: + - ``"packed"``: packed ``(l, m)`` order, optionally truncated by ``mmax``. + - ``"m_major"``: reduced m-major order used inside ``SO2Convolution``. + grid_method + S2 quadrature backend. Must be ``"e3nn"`` or ``"lebedev"``. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + dtype: torch.dtype, + grid_resolution_list: list[int] | None = None, + coefficient_layout: str = "packed", + grid_method: str = "e3nn", + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(self.lmax if mmax is None else mmax) + if self.mmax < 0: + raise ValueError("`mmax` must be non-negative") + if self.mmax > self.lmax: + raise ValueError("`mmax` must be <= `lmax`") + self.dtype = dtype + self.device = env.DEVICE + self.coefficient_layout = str(coefficient_layout).lower() + if self.coefficient_layout not in {"packed", "m_major"}: + raise ValueError( + "`coefficient_layout` must be either 'packed' or 'm_major'" + ) + self.grid_method = str(grid_method).lower() + if self.grid_method not in {"e3nn", "lebedev"}: + raise ValueError("`grid_method` must be either 'e3nn' or 'lebedev'") + + self.grid_resolution_list = _normalize_s2_grid_resolution( + self.lmax, + self.mmax, + grid_resolution_list, + method=self.grid_method, + ) + if self.grid_method == "e3nn": + self.phi_resolution, self.theta_resolution = self.grid_resolution_list + self.lebedev_precision = 0 + self.lebedev_npoints = 0 + else: + self.phi_resolution = 0 + self.theta_resolution = 0 + self.lebedev_precision, self.lebedev_npoints = self.grid_resolution_list + + coeff_index = self._build_coefficient_index(device=torch.device("cpu")) + self.coeff_dim = int(coeff_index.numel()) + to_grid_mat, from_grid_mat = self._build_projection_mats(coeff_index) + to_grid_mat = to_grid_mat.to(device=self.device, dtype=self.dtype) + from_grid_mat = from_grid_mat.to(device=self.device, dtype=self.dtype) + self.register_buffer("to_grid_mat", to_grid_mat, persistent=True) + self.register_buffer("from_grid_mat", from_grid_mat, persistent=True) + + def _build_coefficient_index(self, device: torch.device) -> torch.Tensor: + if self.coefficient_layout == "m_major": + return build_m_major_index(self.lmax, self.mmax, device=device) + if self.mmax == self.lmax: + return torch.arange((self.lmax + 1) ** 2, device=device, dtype=torch.long) + return build_l_major_index(self.lmax, self.mmax, device=device) + + def _rescale_truncated_orders(self, mat: torch.Tensor) -> None: + if self.lmax == self.mmax: + return + for l in range(self.lmax + 1): + if l <= self.mmax: + continue + start_idx = l * l + length = 2 * l + 1 + rescale = math.sqrt(length / float(2 * self.mmax + 1)) + mat[:, :, start_idx : start_idx + length].mul_(rescale) + + def _rescale_truncated_matrix(self, mat: torch.Tensor) -> None: + if self.lmax == self.mmax: + return + for l in range(self.lmax + 1): + if l <= self.mmax: + continue + start_idx = l * l + length = 2 * l + 1 + rescale = math.sqrt(length / float(2 * self.mmax + 1)) + mat[:, start_idx : start_idx + length].mul_(rescale) + + def _build_projection_mats( + self, coeff_index: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.grid_method == "lebedev": + return self._build_lebedev_projection_mats(coeff_index) + return self._build_e3nn_projection_mats(coeff_index) + + def _build_e3nn_projection_mats( + self, coeff_index: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + with torch.device("cpu"): + to_grid = ToS2Grid( + self.lmax, + (self.theta_resolution, self.phi_resolution), + normalization="component", + device="cpu", + ) + to_grid_mat = torch.einsum("mbi,am->bai", to_grid.shb, to_grid.sha).detach() + self._rescale_truncated_orders(to_grid_mat) + + from_grid = FromS2Grid( + (self.theta_resolution, self.phi_resolution), + self.lmax, + normalization="component", + device="cpu", + ) + from_grid_mat = torch.einsum( + "am,mbi->bai", from_grid.sha, from_grid.shb + ).detach() + self._rescale_truncated_orders(from_grid_mat) + + to_grid_mat = to_grid_mat.flatten(0, 1).index_select(1, coeff_index) + from_grid_mat = ( + from_grid_mat.flatten(0, 1).permute(1, 0).index_select(0, coeff_index) + ) + return to_grid_mat, from_grid_mat + + def _build_lebedev_projection_mats( + self, coeff_index: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + with torch.device("cpu"): + points, weights = load_lebedev_rule( + self.lebedev_precision, + dtype=torch.float64, + device=torch.device("cpu"), + ) + harmonics = spherical_harmonics( + list(range(self.lmax + 1)), + points, + normalize=True, + normalization="norm", + ) + # e3nn's ``norm`` harmonics are ``component / sqrt(2*l+1)``. + # ``ToS2Grid(..., normalization="component")`` additionally divides + # every degree block by ``sqrt(lmax+1)``; keep the same convention so + # the Lebedev backend can replace the e3nn product-grid backend. + scale = math.sqrt(float(self.lmax + 1)) + degree_factors = harmonics.new_tensor( + [ + float(2 * l + 1) + for l in range(self.lmax + 1) + for _ in range(2 * l + 1) + ] + ) + to_grid_mat = harmonics / scale + # The packaged Lebedev weights sum to one. For ``norm`` harmonics, + # ``sum_a w_a Y_j(a) Y_k(a) = delta_jk / (2*l+1)``; the + # degree_factors and ``scale`` invert this normalization. + from_grid_mat = harmonics * ( + weights[:, None] * scale * degree_factors[None, :] + ) + self._rescale_truncated_matrix(to_grid_mat) + self._rescale_truncated_matrix(from_grid_mat) + + to_grid_mat = to_grid_mat.index_select(1, coeff_index) + from_grid_mat = from_grid_mat.index_select(1, coeff_index).transpose(0, 1) + return to_grid_mat, from_grid_mat + + def to_grid(self, embedding: torch.Tensor) -> torch.Tensor: + """Project coefficients ``(N, D, C)`` to a flattened grid ``(N, A, C)``.""" + return torch.einsum("aj,njc->nac", self.to_grid_mat, embedding) + + def from_grid(self, grid: torch.Tensor) -> torch.Tensor: + """Project a flattened grid ``(N, A, C)`` back to coefficients ``(N, D, C)``.""" + return torch.einsum("ja,nac->njc", self.from_grid_mat, grid) + + def serialize(self) -> dict[str, Any]: + return { + "@class": "S2GridProjector", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "grid_resolution_list": self.grid_resolution_list, + "coefficient_layout": self.coefficient_layout, + "grid_method": self.grid_method, + }, + "@variables": {}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> S2GridProjector: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "S2GridProjector": + raise ValueError(f"Invalid class for S2GridProjector: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + data.pop("@variables", None) + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + return cls(**config) + + +class SwiGLUS2Activation(nn.Module): + """ + Apply the merged scalar/grid SwiGLU-S2 activation to SO(3) coefficients. + + The degree-0 slice provides two scalar paths: + + - a scalar ``SwiGLU`` branch that is merged back into the output ``l=0`` part + - a learned sigmoid gate that modulates the full output reconstructed from + the S2 grid path + + The equivariant branch projects the full ``2 * channels`` coefficients to the + S2 grid, multiplies the two channel halves point-wise on the grid, projects + back to coefficients, and applies the scalar sigmoid gate. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + mmax + Maximum order kept in the coefficient layout. If None, use ``lmax``. + channels + Output channel count after SwiGLU. The input is expected to have + ``2 * channels`` on the last axis. + dtype + Projection buffer dtype. + n_focus + Number of focus streams in the input layout. + layout + Tensor layout convention: + - ``"ndfc"`` for ``(N, D, F, C)`` + - ``"nfdc"`` for ``(N, F, D, C)`` + grid_resolution_list + Two-element list ``[R_phi, R_theta]``. + coefficient_layout + Coefficient ordering: ``"packed"`` or ``"m_major"``. + grid_method + S2 quadrature backend. Must be ``"e3nn"`` or ``"lebedev"``. + mlp_bias + Whether the scalar sigmoid projection uses bias. + trainable + Whether parameters are trainable. + seed + Random seed for the scalar sigmoid projection. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + channels: int, + dtype: torch.dtype, + n_focus: int = 1, + layout: str = "ndfc", + grid_resolution_list: list[int] | None = None, + coefficient_layout: str = "packed", + grid_method: str = "e3nn", + mlp_bias: bool = False, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(self.lmax if mmax is None else mmax) + self.channels = int(channels) + self.dtype = dtype + self.n_focus = int(n_focus) + self.mlp_bias = bool(mlp_bias) + self.layout = str(layout).lower() + if self.layout not in {"ndfc", "nfdc"}: + raise ValueError("`layout` must be either 'ndfc' or 'nfdc'") + self.coefficient_layout = str(coefficient_layout).lower() + self.grid_method = str(grid_method).lower() + self.grid_resolution_list = _normalize_s2_grid_resolution( + self.lmax, + self.mmax, + grid_resolution_list, + method=self.grid_method, + ) + self.scalar_act = SwiGLU() + self.scalar_gate = FocusLinear( + in_channels=2 * self.channels, + out_channels=self.channels, + n_focus=self.n_focus, + dtype=self.dtype, + bias=self.mlp_bias, + trainable=trainable, + seed=child_seed(seed, 0), + init_std=0.01, + ) + self.projector: S2GridProjector | None + if self.lmax == 0: + self.projector = None + self.coeff_dim = 1 + else: + self.projector = S2GridProjector( + lmax=self.lmax, + mmax=self.mmax, + dtype=self.dtype, + grid_resolution_list=self.grid_resolution_list, + coefficient_layout=self.coefficient_layout, + grid_method=self.grid_method, + ) + self.coeff_dim = self.projector.coeff_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input tensor with last dimension ``2 * channels``. + + Returns + ------- + torch.Tensor + Activated tensor with the same coefficient layout and ``channels`` on + the last axis. + """ + input_dtype = x.dtype + # Promote before slicing to avoid the TorchInductor AMP compile bug on + # the scalar SwiGLU branch in PyTorch 2.11. + scalar_inputs = self._extract_scalar_inputs(x.to(dtype=self.dtype)) + scalar_outputs = self.scalar_act(scalar_inputs) + + if self.projector is None: + return self._restore_scalar_outputs(scalar_outputs.to(dtype=input_dtype)) + + gate_scalars = torch.sigmoid(self.scalar_gate(scalar_inputs)) + x_flat, shape_info = self._flatten_inputs(x) + x_grid = self.projector.to_grid(x_flat.to(dtype=self.dtype)) + x_grid_1, x_grid_2 = torch.chunk(x_grid, chunks=2, dim=-1) + out_flat = self.projector.from_grid(x_grid_1 * x_grid_2) + outputs = self._restore_outputs(out_flat, shape_info) + outputs = outputs * self._broadcast_scalar_gate(gate_scalars) + self._merge_scalar_outputs(outputs, scalar_outputs) + return outputs.to(dtype=input_dtype) + + def _extract_scalar_inputs(self, x: torch.Tensor) -> torch.Tensor: + if self.layout == "ndfc": + return x.select(dim=1, index=0) + return x.select(dim=2, index=0) + + def _broadcast_scalar_gate(self, gate_scalars: torch.Tensor) -> torch.Tensor: + if self.layout == "ndfc": + return gate_scalars.unsqueeze(1) + return gate_scalars.unsqueeze(2) + + def _restore_scalar_outputs(self, scalar_outputs: torch.Tensor) -> torch.Tensor: + if self.layout == "ndfc": + return scalar_outputs.unsqueeze(1) + return scalar_outputs.unsqueeze(2) + + def _flatten_inputs( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, tuple[int, int, int]]: + if self.layout == "ndfc": + n_batch, coeff_dim, n_focus, _ = x.shape + return ( + x.permute(0, 2, 1, 3).reshape( + n_batch * n_focus, coeff_dim, x.shape[-1] + ), + (n_batch, coeff_dim, n_focus), + ) + n_batch, n_focus, coeff_dim, _ = x.shape + return ( + x.reshape(n_batch * n_focus, coeff_dim, x.shape[-1]), + (n_batch, coeff_dim, n_focus), + ) + + def _restore_outputs( + self, x: torch.Tensor, shape_info: tuple[int, int, int] + ) -> torch.Tensor: + n_batch, coeff_dim, n_focus = shape_info + if self.layout == "ndfc": + return x.reshape(n_batch, n_focus, coeff_dim, self.channels).permute( + 0, 2, 1, 3 + ) + return x.reshape(n_batch, n_focus, coeff_dim, self.channels) + + def _merge_scalar_outputs( + self, outputs: torch.Tensor, scalar_outputs: torch.Tensor + ) -> None: + if self.layout == "ndfc": + outputs[:, 0, :, :].add_(scalar_outputs) + else: + outputs[:, :, 0, :].add_(scalar_outputs) + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "SwiGLUS2Activation", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "channels": self.channels, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "n_focus": self.n_focus, + "layout": self.layout, + "grid_resolution_list": self.grid_resolution_list, + "coefficient_layout": self.coefficient_layout, + "grid_method": self.grid_method, + "mlp_bias": self.mlp_bias, + "trainable": trainable, + "seed": None, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SwiGLUS2Activation: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SwiGLUS2Activation": + raise ValueError(f"Invalid class for SwiGLUS2Activation: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj + + +def resolve_s2_grid_resolution( + lmax: int, + mmax: int, + *, + method: str = "e3nn", +) -> list[int]: + """ + Resolve the default S2 grid resolution. + + For ``method='e3nn'``, the automatic default uses even azimuthal sampling + ``R_phi = 2 * mmax + 4`` and even polar sampling + ``R_theta = ceil_even(3 * lmax + 2)``. + + For ``method='lebedev'``, the automatic default picks the smallest packaged + Lebedev rule whose algebraic precision is at least ``3 * lmax`` and returns + ``[precision, n_points]``. + """ + method = str(method).lower() + if method not in {"e3nn", "lebedev"}: + raise ValueError("`method` must be either 'e3nn' or 'lebedev'") + if method == "lebedev": + required_precision = 3 * int(lmax) + for precision, n_points in LEBEDEV_PRECISION_TO_NPOINTS.items(): + if precision >= required_precision: + return [precision, n_points] + raise ValueError( + f"No packaged Lebedev rule has precision >= {required_precision}" + ) + + phi_resolution = 2 * mmax + 4 + theta_resolution = 3 * lmax + 2 + theta_resolution += theta_resolution % 2 + return [phi_resolution, theta_resolution] + + +def _normalize_s2_grid_resolution( + lmax: int, + mmax: int, + grid_resolution_list: list[int] | None, + *, + method: str, +) -> list[int]: + """Resolve default grids or validate already-resolved low-level grids.""" + method = str(method).lower() + if grid_resolution_list is None: + return resolve_s2_grid_resolution(lmax, mmax, method=method) + if method == "lebedev": + if len(grid_resolution_list) != 2: + raise ValueError( + "Lebedev `grid_resolution_list` must be [precision, n_points]" + ) + precision = int(grid_resolution_list[0]) + n_points = int(grid_resolution_list[1]) + expected_n_points = LEBEDEV_PRECISION_TO_NPOINTS.get(precision) + if expected_n_points != n_points: + raise ValueError( + "Lebedev `grid_resolution_list` must match a packaged " + f"[precision, n_points] pair; got [{precision}, {n_points}]" + ) + return [precision, n_points] + + if len(grid_resolution_list) != 2: + raise ValueError("`grid_resolution_list` must contain two integers") + resolution = [int(grid_resolution_list[0]), int(grid_resolution_list[1])] + if resolution[0] < 1 or resolution[1] < 1: + raise ValueError("grid resolutions must be positive") + return resolution diff --git a/deepmd/pt/model/descriptor/sezm_nn/attention.py b/deepmd/pt/model/descriptor/sezm_nn/attention.py new file mode 100644 index 0000000000..4f42188c2e --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/attention.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Attention utilities for SeZM message passing. + +This module implements the destination-wise envelope-gated softmax used by the +SO(2) attention path in the SeZM descriptor. +""" + +from __future__ import ( + annotations, +) + +import torch +import torch.nn.functional as F + + +@torch.amp.autocast("cuda", enabled=False) +def segment_envelope_gated_softmax( + logits: torch.Tensor, + edge_env: torch.Tensor, + dst: torch.Tensor, + n_nodes: int, + z_bias_raw: torch.Tensor, + eps: float, + src_weight: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Compute destination-wise envelope-gated softmax attention. + + Parameters + ---------- + logits + Attention logits with shape (E, F, H). + edge_env + Cutoff envelope weights with shape (E, 1) or (E,). + dst + Destination node indices with shape (E,). + n_nodes + Number of nodes. + z_bias_raw + Unconstrained denominator bias with shape (F, H). + Softplus is applied to keep the bias strictly positive. + eps + Small epsilon for denominator stability. + src_weight + Optional per-edge source-side multiplier with shape (E, 1) or + (E,). When provided the per-edge weight becomes + ``edge_env**2 * src_weight`` and the attention reduces to + ``edge_env**2 * src_weight * exp(logits) / + (zeta + sum(edge_env**2 * src_weight * exp(logits)))``. + ``src_weight = 0`` therefore removes the source from both the + numerator and the denominator, which is what SFPG needs so that + a muted source does not even leak through the softmax + normalization. + + Returns + ------- + torch.Tensor + Normalized edge weights with shape (E, F, H). + """ + n_edge, n_focus, n_head = logits.shape + n_channel = n_focus * n_head + eps_f = float(eps) + + # === Step 1. Flatten (F, H) and build the effective per-edge weight === + logits_2d = logits.reshape(n_edge, n_channel) + edge_env_1d = edge_env.squeeze(-1).to(dtype=logits.dtype).clamp_min(0.0) + # edge_weight_sq acts as the non-negative multiplier applied to every + # ``exp(logit)`` term. Folding ``src_weight`` here guarantees that any + # edge with ``src_weight = 0`` is excluded from the group max, the + # numerator, and the denominator in a single pass. + edge_weight_sq = edge_env_1d.square() + if src_weight is not None: + edge_weight_sq = edge_weight_sq * src_weight.reshape(n_edge).to( + dtype=logits.dtype + ).clamp_min(0.0) + zeta = F.softplus(z_bias_raw).reshape(1, n_channel).to(dtype=logits.dtype) + dst_index = dst.reshape(n_edge, 1).expand(n_edge, n_channel) + has_weight = edge_weight_sq > 0.0 + logits_for_max = torch.where( + has_weight.reshape(n_edge, 1), + logits_2d, + torch.full_like(logits_2d, float("-inf")), + ) + + # === Step 2. Destination-wise max for stable exponentials === + group_max = torch.full( + (n_nodes, n_channel), + float("-inf"), + dtype=logits.dtype, + device=logits.device, + ) + group_max = torch.scatter_reduce( + group_max, + 0, + dst_index, + logits_for_max, + reduce="amax", + include_self=True, + ) + edge_max = group_max.index_select(0, dst) + edge_max = torch.where( + torch.isfinite(edge_max), edge_max, torch.zeros_like(edge_max) + ) + group_max_safe = torch.where( + torch.isfinite(group_max), group_max, torch.zeros_like(group_max) + ) + + # === Step 3. Envelope/SFPG-gated exponential terms === + exp_shifted = torch.exp(logits_2d - edge_max) + edge_weighted_exp = edge_weight_sq.reshape(n_edge, 1) * exp_shifted + + # === Step 4. Destination-wise normalization with positive denominator bias === + denom_sum = torch.zeros( + n_nodes, + n_channel, + dtype=logits.dtype, + device=logits.device, + ) + denom_sum = torch.scatter_add(denom_sum, 0, dst_index, edge_weighted_exp) + denom = denom_sum + zeta * torch.exp(-group_max_safe) + + alpha = edge_weighted_exp / (denom.index_select(0, dst) + eps_f) + return alpha.reshape(n_edge, n_focus, n_head) diff --git a/deepmd/pt/model/descriptor/sezm_nn/attn_res.py b/deepmd/pt/model/descriptor/sezm_nn/attn_res.py new file mode 100644 index 0000000000..1a8d883299 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/attn_res.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Attention-residual layers for the SeZM descriptor. + +This module defines the depth-wise attention residual aggregator used to +combine equivariant states across descriptor and block histories. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +import torch.nn as nn + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .norm import ( + ScalarRMSNorm, +) +from .so3 import ( + ChannelLinear, +) +from .utils import ( + np_safe, + safe_numpy_to_tensor, +) + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + ) + + +class DepthAttnRes(nn.Module): + """ + Depth-wise attention residual aggregation for equivariant tensors. + + Attention logits are computed only from scalar ``l=0`` channels, while the + resulting scalar weights are broadcast to the full equivariant value tensors. + This keeps the aggregation equivariant as long as all sources share the same + representation space. + + Query modes + ----------- + - ``input_dependent=True``: query comes from the current scalar state. + - ``input_dependent=False``: use a learned pseudo-query shared across inputs. + + Both query paths are zero-initialized so the initial aggregation is a uniform + average over all provided sources. + + Parameters + ---------- + channels + Scalar feature dimension used by query and key. + input_dependent + Whether to project the current scalar state into a query vector. + eps + Small epsilon for key RMS normalization. + bias + Whether to use bias in the input-dependent query projection. Only + effective when ``input_dependent=True``. + dtype + Parameter and compute dtype. Caller should pass compute_dtype (fp32+). + trainable + Whether parameters are trainable. + seed + Random seed reserved for consistency with other modules. + """ + + if TYPE_CHECKING: + query_proj: ChannelLinear + adamw_pseudo_query: torch.Tensor + + def __init__( + self, + *, + channels: int, + input_dependent: bool = True, + eps: float = 1e-7, + bias: bool = True, + dtype: torch.dtype, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.channels = int(channels) + self.input_dependent = bool(input_dependent) + self.eps = float(eps) + self.query_bias = bool(bias) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + + self.key_norm = ScalarRMSNorm( + channels=self.channels, + n_focus=1, + eps=self.eps, + dtype=self.dtype, + trainable=trainable, + ) + if self.input_dependent: + self.query_proj = ChannelLinear( + in_channels=self.channels, + out_channels=self.channels, + dtype=self.dtype, + bias=self.query_bias, + trainable=trainable, + seed=seed, + init_std=0.0, + ) + else: + self.adamw_pseudo_query = nn.Parameter( + torch.zeros(self.channels, dtype=self.dtype, device=self.device), + requires_grad=trainable, + ) + + for p in self.parameters(): + p.requires_grad = trainable + + def forward( + self, + *, + sources: list[torch.Tensor], + scalar_extractor: Callable[[torch.Tensor], torch.Tensor], + current_x: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Aggregate same-shape sources with depth attention. + + Parameters + ---------- + sources + Source tensors with identical shape ``(B, ...)``. + scalar_extractor + Function that extracts scalar features from each source with shape + ``(B, C)`` where ``C=channels``. + current_x + Current tensor state. Required when ``input_dependent=True`` and + converted to scalar query features via ``scalar_extractor``. + + Returns + ------- + torch.Tensor + Aggregated tensor with the same shape as each source. + """ + source0 = sources[0] + if len(sources) == 1: + return source0 + batch_size = int(source0.shape[0]) + value_dtype = source0.dtype + + # === Step 1. Build the query vector === + if self.input_dependent: + current_x_scalar = scalar_extractor(current_x) + query = self.query_proj(current_x_scalar.to(dtype=self.dtype)) + else: + query = self.adamw_pseudo_query.unsqueeze(0).expand(batch_size, -1) + + # === Step 2. Extract and normalize scalar keys === + source_count = len(sources) + raw_keys = torch.stack( + [scalar_extractor(source).to(dtype=self.dtype) for source in sources], + dim=1, + ) # (B, S, C) + keys = self.key_norm(raw_keys) + logits = torch.einsum("bc,bsc->bs", query, keys) + alpha = torch.softmax(logits, dim=1) # (B, S) + + # === Step 3. Broadcast scalar weights to equivariant values === + value_stack = torch.stack( + [source.to(dtype=self.dtype) for source in sources], + dim=1, + ) + alpha = alpha.reshape( + batch_size, + source_count, + *([1] * (value_stack.ndim - 2)), + ) + aggregated = (alpha * value_stack).sum(dim=1) + return aggregated.to(dtype=value_dtype) + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "DepthAttnRes", + "@version": 1, + "config": { + "channels": self.channels, + "input_dependent": self.input_dependent, + "eps": self.eps, + "bias": self.query_bias, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "trainable": trainable, + "seed": None, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> DepthAttnRes: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "DepthAttnRes": + raise ValueError(f"Invalid class for DepthAttnRes: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj diff --git a/deepmd/pt/model/descriptor/sezm_nn/block.py b/deepmd/pt/model/descriptor/sezm_nn/block.py new file mode 100644 index 0000000000..ddc5d9847e --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/block.py @@ -0,0 +1,837 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Interaction blocks for the SeZM descriptor. + +This module defines the SeZM interaction block that combines SO(2) +message passing, equivariant feed-forward subblocks, and optional +attention-residual history aggregation. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .attn_res import ( + DepthAttnRes, +) +from .ffn import ( + EquivariantFFN, +) +from .norm import ( + EquivariantRMSNorm, +) +from .so2 import ( + SO2Convolution, +) +from .utils import ( + ATTN_RES_MODES, + get_promoted_dtype, + np_safe, + nvtx_range, + safe_numpy_to_tensor, +) + +if TYPE_CHECKING: + from .edge_cache import ( + EdgeFeatureCache, + ) + + +class SeZMInteractionBlock(nn.Module): + """ + SeZM interaction block with SO(2) message passing and equivariant FFN stack. + + Branch order: + 1. SO(2) branch: optional pre-norm -> `SO2Convolution` -> optional post-norm. + 2. FFN branch: repeated subblocks of + optional pre-norm -> `EquivariantFFN` -> optional post-norm. + + In the baseline path, outer residual shortcuts are applied around the SO(2) + unit and each FFN subblock. In AttnRes paths, these shortcuts are replaced by + selective depth-wise aggregation before each unit. + + `SO2Convolution` internally handles the real multi-focus expansion, so this + block keeps a singleton-focus backbone layout `(N, D, 1, C)` at boundaries. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + mmax + Maximum SO(2) order (|m|) mixed inside SO(2) convolution. + channels + Total channels per (l, m) coefficient. + n_focus + Number of multi-focus streams used only by the internal SO(2) branch. + focus_dim + Hidden width per focus stream used inside the SO(2) branch. + ``focus_dim=0`` means using ``channels``. + focus_compete + If True, enable cross-focus softmax competition in SO(2) convolution. + so2_norm + If True, apply intermediate ReducedEquivariantRMSNorm between SO(2) mixing layers. + When False (default), no normalization is applied between layers. + so2_layers + Number of SO(2) mixing layers. + so2_attn_res + Depth-wise attention residual mode across the internal SO(2) layer + history. Must be one of ``"none"``, ``"independent"``, or + ``"dependent"``. + radial_so2_mode + Dynamic radial degree mixer mode inside SO(2) convolution. ``"none"`` + applies elementwise radial modulation, ``"degree"`` uses a + channel-shared edge-conditioned cross-degree kernel, and + ``"degree_channel"`` uses a per-channel cross-degree kernel. + radial_so2_rank + Low-rank channel factorization rank for + ``radial_so2_mode="degree_channel"``. ``0`` uses the full + per-channel dynamic degree kernel. + n_atten_head + Number of attention heads when aggregating messages in SO(2) convolution. + 0 means no attention is used; >0 enables envelope-gated grouped softmax + attention with output-side head gate. + atten_f_mix + If True, merge SO(2) focus streams into one attention stream after + rotate-back. This gives each attention head access to the full + multi-focus hidden width. + atten_v_proj + If True, apply an explicit degree-aware value projection inside SO(2) + attention. + atten_o_proj + If True, apply an explicit degree-aware output projection inside SO(2) + attention. + so2_pre_norm + If True, apply pre-norm before SO(2) convolution. + so2_post_norm + If True, apply post-norm on SO(2) output before the residual add. + ffn_pre_norm + If True, apply pre-norm before each FFN subblock. + ffn_post_norm + If True, apply post-norm on each FFN subblock output before the residual add. + ffn_neurons + Hidden dimension for each FFN subblock. + grid_mlp + If True, use the optional grid-MLP structure for the block-internal FFN + units. The final descriptor output head is unaffected. + ffn_blocks + Number of FFN subblocks per block. + layer_scale + If True, apply learnable LayerScale (init 1e-3) on residual branches: + - SO(2) branch: per-focus-channel scales `(n_focus, focus_dim)` + on each SO(2) mixing layer. + - FFN branch: per-channel scales `(channels,)` on each FFN subblock. + full_attn_res + Descriptor-level full attention residual mode for this block wrapper. + When enabled, the block uses external unit history to build the SO(2) + input and the input of each FFN unit. + block_attn_res + Descriptor-level block attention residual mode for this block wrapper. + When enabled, the block uses external block history plus an intra-block + partial sum to build the SO(2) input and the input of each FFN unit. + so2_s2_activation + If True, enable the merged scalar/grid SwiGLU-S2 activation in the SO(2) + branch. + ffn_s2_activation + If True, enable the merged scalar/grid SwiGLU-S2 activation in the + default FFN activation path. + so2_lebedev_quadrature + If True, use Lebedev quadrature for the SO(2) S2 activation projector. + ffn_lebedev_quadrature + If True, use Lebedev quadrature for the FFN S2 activation projector. + so2_activation_function + Activation function for the block-internal SO(2) l=0 gated activation + path when ``so2_s2_activation=False``. + ffn_activation_function + Activation function for the block-internal FFN l=0 components. + ffn_glu_activation + If True, use GLU-style gating in the block-internal FFN + (e.g., silu -> swiglu, gelu -> geglu). + mlp_bias + Whether to use bias in equivariant layers. Controls: + - SO3Linear: l=0 bias + - SO2Linear: l=0 bias + - GatedActivation: gate linear bias + use_triton + If True, opt into fused Triton SO(2) rotation kernels inside + ``SO2Convolution`` when the runtime supports them. + eps + Small epsilon for numerical stability. + dtype + Parameter dtype. + seed + Random seed for weight initialization. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + channels: int, + n_focus: int = 1, + focus_dim: int = 0, + focus_compete: bool = True, + so2_norm: bool = False, + so2_layers: int = 4, + so2_attn_res: str = "none", + radial_so2_mode: str = "none", + radial_so2_rank: int = 0, + n_atten_head: int = 1, + atten_f_mix: bool = False, + atten_v_proj: bool = False, + atten_o_proj: bool = False, + so2_pre_norm: bool = True, + so2_post_norm: bool = False, + ffn_pre_norm: bool = True, + ffn_post_norm: bool = False, + ffn_neurons: int = 96, + grid_mlp: bool = False, + ffn_blocks: int = 1, + layer_scale: bool = False, + full_attn_res: str = "none", + block_attn_res: str = "none", + so2_s2_activation: bool = False, + ffn_s2_activation: bool = False, + so2_lebedev_quadrature: bool = False, + ffn_lebedev_quadrature: bool = False, + so2_activation_function: str = "silu", + ffn_activation_function: str, + ffn_glu_activation: bool = True, + mlp_bias: bool = False, + use_triton: bool = False, + eps: float = 1e-7, + dtype: torch.dtype, + seed: int | list[int] | None, + trainable: bool, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(self.lmax if mmax is None else mmax) + if self.mmax < 0: + raise ValueError("`mmax` must be non-negative") + if self.mmax > self.lmax: + raise ValueError("`mmax` must be <= `lmax`") + self.channels = int(channels) + self.n_focus = int(n_focus) + if self.n_focus < 1: + raise ValueError("`n_focus` must be >= 1") + self.focus_dim = int(focus_dim) + if self.focus_dim < 0: + raise ValueError("`focus_dim` must be >= 0") + self.focus_compete = bool(focus_compete) + self.so2_norm = bool(so2_norm) + self.so2_layers = int(so2_layers) + self.so2_attn_res_mode = str(so2_attn_res).lower() + if self.so2_attn_res_mode not in ATTN_RES_MODES: + raise ValueError( + "`so2_attn_res` must be one of 'none', 'independent', or 'dependent'" + ) + self.radial_so2_mode = str(radial_so2_mode).lower() + self.radial_so2_rank = int(radial_so2_rank) + self.n_atten_head = int(n_atten_head) + self.atten_f_mix = bool(atten_f_mix) + self.use_atten_v_proj = bool(atten_v_proj) + self.use_atten_o_proj = bool(atten_o_proj) + self.so2_pre_norm = bool(so2_pre_norm) + self.so2_post_norm = bool(so2_post_norm) + self.ffn_pre_norm = bool(ffn_pre_norm) + self.ffn_post_norm = bool(ffn_post_norm) + self.ffn_neurons = int(ffn_neurons) + self.grid_mlp = bool(grid_mlp) + self.ffn_blocks = int(ffn_blocks) + if self.ffn_blocks < 1: + raise ValueError("`ffn_blocks` must be >= 1") + self.layer_scale = bool(layer_scale) + self.full_attn_res_mode = str(full_attn_res).lower() + if self.full_attn_res_mode not in ATTN_RES_MODES: + raise ValueError( + "`full_attn_res` must be one of 'none', 'independent', or 'dependent'" + ) + self.block_attn_res_mode = str(block_attn_res).lower() + if self.block_attn_res_mode not in ATTN_RES_MODES: + raise ValueError( + "`block_attn_res` must be one of 'none', 'independent', or 'dependent'" + ) + self.use_full_attn_res = self.full_attn_res_mode != "none" + self.use_block_attn_res = self.block_attn_res_mode != "none" + if self.use_full_attn_res and self.use_block_attn_res: + raise ValueError( + "`full_attn_res` and `block_attn_res` cannot both be enabled" + ) + self.so2_s2_activation = bool(so2_s2_activation) + self.ffn_s2_activation = bool(ffn_s2_activation) + self.so2_lebedev_quadrature = bool(so2_lebedev_quadrature) + self.ffn_lebedev_quadrature = bool(ffn_lebedev_quadrature) + self.so2_activation_function = str(so2_activation_function) + self.ffn_activation_function = str(ffn_activation_function) + self.ffn_glu_activation = bool(ffn_glu_activation) + self.mlp_bias = bool(mlp_bias) + self.use_triton = bool(use_triton) + self.eps = float(eps) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + self.compute_dtype = get_promoted_dtype(self.dtype) + + # === Step 0. Split deterministic seeds at the block top-level === + seed_so2_conv = child_seed(seed, 0) + seed_ffn = child_seed(seed, 1) + seed_full_attn = child_seed(seed, 2) + seed_block_attn = child_seed(seed, 3) + + # === Step 1. SO(2) convolution branch norms === + if self.so2_pre_norm: + self.pre_so2_norm: nn.Module = EquivariantRMSNorm( + self.lmax, + self.channels, + n_focus=1, + dtype=self.compute_dtype, + trainable=trainable, + ) + else: + self.pre_so2_norm = nn.Identity() + + if self.so2_post_norm: + self.post_so2_norm: nn.Module = EquivariantRMSNorm( + self.lmax, + self.channels, + n_focus=1, + dtype=self.compute_dtype, + trainable=trainable, + ) + else: + self.post_so2_norm = nn.Identity() + + self.so2_conv = SO2Convolution( + lmax=self.lmax, + mmax=self.mmax, + channels=self.channels, + n_focus=self.n_focus, + focus_dim=self.focus_dim, + focus_compete=self.focus_compete, + so2_norm=self.so2_norm, + so2_layers=self.so2_layers, + so2_attn_res=self.so2_attn_res_mode, + radial_so2_mode=self.radial_so2_mode, + radial_so2_rank=self.radial_so2_rank, + layer_scale=self.layer_scale, + n_atten_head=n_atten_head, + atten_f_mix=self.atten_f_mix, + atten_v_proj=self.use_atten_v_proj, + atten_o_proj=self.use_atten_o_proj, + s2_activation=self.so2_s2_activation, + lebedev_quadrature=self.so2_lebedev_quadrature, + activation_function=self.so2_activation_function, + mlp_bias=self.mlp_bias, + use_triton=self.use_triton, + eps=self.eps, + dtype=dtype, + seed=seed_so2_conv, + trainable=trainable, + ) + + # === Step 2. FFN subblock sequence === + pre_ffn_norms: list[nn.Module] = [] + post_ffn_norms: list[nn.Module] = [] + ffns: list[EquivariantFFN] = [] + + for i in range(self.ffn_blocks): + seed_ffn_i = child_seed(seed_ffn, i) + + if self.ffn_pre_norm: + pre_ffn_norms.append( + EquivariantRMSNorm( + self.lmax, + self.channels, + n_focus=1, + dtype=self.compute_dtype, + trainable=trainable, + ) + ) + else: + pre_ffn_norms.append(nn.Identity()) + + if self.ffn_post_norm: + post_ffn_norms.append( + EquivariantRMSNorm( + self.lmax, + self.channels, + n_focus=1, + dtype=self.compute_dtype, + trainable=trainable, + ) + ) + else: + post_ffn_norms.append(nn.Identity()) + + ffns.append( + EquivariantFFN( + lmax=self.lmax, + channels=self.channels, + hidden_channels=ffn_neurons, + grid_mlp=self.grid_mlp, + dtype=dtype, + s2_activation=self.ffn_s2_activation, + lebedev_quadrature=self.ffn_lebedev_quadrature, + activation_function=self.ffn_activation_function, + glu_activation=self.ffn_glu_activation, + mlp_bias=self.mlp_bias, + trainable=trainable, + seed=seed_ffn_i, + ) + ) + + self.pre_ffn_norms = nn.ModuleList(pre_ffn_norms) + self.post_ffn_norms = nn.ModuleList(post_ffn_norms) + self.ffns = nn.ModuleList(ffns) + + # Optional per-channel LayerScale on each FFN residual branch + if self.layer_scale: + self.adam_ffn_layer_scales = nn.ParameterList( + [ + nn.Parameter( + torch.ones(self.channels, dtype=self.dtype, device=self.device) + * 1e-3, + requires_grad=trainable, + ) + for _ in range(self.ffn_blocks) + ] + ) + else: + self.adam_ffn_layer_scales = None + + # === Step 3. Optional full attention residuals for block inputs === + if self.use_full_attn_res: + self.full_attn_res_so2: DepthAttnRes | None = DepthAttnRes( + channels=self.channels, + input_dependent=self.full_attn_res_mode == "dependent", + eps=self.eps, + bias=self.mlp_bias, + dtype=self.compute_dtype, + trainable=trainable, + seed=child_seed(seed_full_attn, 0), + ) + self.full_attn_res_ffns: nn.ModuleList | None = nn.ModuleList( + [ + DepthAttnRes( + channels=self.channels, + input_dependent=self.full_attn_res_mode == "dependent", + eps=self.eps, + bias=self.mlp_bias, + dtype=self.compute_dtype, + trainable=trainable, + seed=child_seed(seed_full_attn, i + 1), + ) + for i in range(self.ffn_blocks) + ] + ) + self.block_attn_res_so2 = None + self.block_attn_res_ffns = None + self._forward_impl = self._forward_with_full_attn_res + elif self.use_block_attn_res: + self.full_attn_res_so2 = None + self.full_attn_res_ffns = None + self.block_attn_res_so2: DepthAttnRes | None = DepthAttnRes( + channels=self.channels, + input_dependent=self.block_attn_res_mode == "dependent", + eps=self.eps, + bias=self.mlp_bias, + dtype=self.compute_dtype, + trainable=trainable, + seed=child_seed(seed_block_attn, 0), + ) + self.block_attn_res_ffns: nn.ModuleList | None = nn.ModuleList( + [ + DepthAttnRes( + channels=self.channels, + input_dependent=self.block_attn_res_mode == "dependent", + eps=self.eps, + bias=self.mlp_bias, + dtype=self.compute_dtype, + trainable=trainable, + seed=child_seed(seed_block_attn, i + 1), + ) + for i in range(self.ffn_blocks) + ] + ) + self._forward_impl = self._forward_with_block_attn_res + else: + self.full_attn_res_so2 = None + self.full_attn_res_ffns = None + self.block_attn_res_so2 = None + self.block_attn_res_ffns = None + self._forward_impl = self._forward_with_residual_shortcuts + + def forward( + self, + x: torch.Tensor, + edge_cache: EdgeFeatureCache, + radial_feat: torch.Tensor, + unit_history: list[torch.Tensor] | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + list[torch.Tensor] | None, + ]: + """ + Parameters + ---------- + x + Features with shape `(N, D, 1, C)`. + edge_cache + Edge cache. + radial_feat + Per-edge radial features with shape (E, lmax+1, C). + unit_history + Optional truncated depth history in canonical node layout. When + `full_attn_res != "none"`, it is interpreted as completed unit + history. When `block_attn_res != "none"`, it is interpreted as + completed block history. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, list[torch.Tensor] | None] + Tuple `(block_output, block_summary, so2_unit_output, ffn_unit_outputs)` + in canonical node layout. `block_output` is always returned. + Auxiliary outputs are mode-dependent and may be `None` when the + current caller does not need them: + + - baseline path returns `(block_output, None, None, None)` + - full AttnRes path returns `(block_output, None, so2_unit_output, ffn_unit_outputs)` + - block AttnRes path returns `(block_output, block_summary, None, None)` + """ + return self._forward_impl(x, edge_cache, radial_feat, unit_history) + + def _extract_l0_from_canonical(self, value: torch.Tensor) -> torch.Tensor: + """ + Extract scalar channels from canonical node layout. + + Parameters + ---------- + value + Canonical node features with shape `(N, D, 1, C)`. + + Returns + ------- + torch.Tensor + Scalar channels with shape (N, channels). + """ + return value[:, 0, :, :].reshape(value.shape[0], self.channels) + + def _run_so2_unit( + self, + x: torch.Tensor, + edge_cache: EdgeFeatureCache, + radial_feat: torch.Tensor, + ) -> torch.Tensor: + """ + Run the SO(2) unit without an outer block-level residual shortcut. + + Parameters + ---------- + x + Canonical node features with shape `(N, D, 1, C)`. + edge_cache + Edge cache. + radial_feat + Per-edge radial features with shape (E, lmax+1, C). + + Returns + ------- + torch.Tensor + SO(2) unit output with shape `(N, D, 1, C)`. + """ + n_node = x.shape[0] + ebed_dim = x.shape[1] + channels = self.channels + x_pre = self.pre_so2_norm(x) + so2_unit_output = self.so2_conv( + x_pre.reshape(n_node, ebed_dim, channels), edge_cache, radial_feat + ) + return self.post_so2_norm(so2_unit_output.unsqueeze(2)) + + def _run_ffn_unit(self, x: torch.Tensor, unit_idx: int) -> torch.Tensor: + """ + Run one FFN subblock without the outer unit-level residual shortcut. + + Parameters + ---------- + x + Canonical node features with shape `(N, D, 1, C)`. + unit_idx + FFN subblock index. + + Returns + ------- + torch.Tensor + FFN unit output with shape `(N, D, 1, C)`. + """ + n_node = x.shape[0] + ebed_dim = x.shape[1] + channels = self.channels + x_ffn = x.reshape(n_node, ebed_dim, 1, channels) # (N, D, 1, C) + x_pre = self.pre_ffn_norms[unit_idx](x_ffn) + y: torch.Tensor = self.ffns[unit_idx](x_pre) + y = self.post_ffn_norms[unit_idx](y) + if self.layer_scale: + y = y * self.adam_ffn_layer_scales[unit_idx] + return y + + def _forward_with_residual_shortcuts( + self, + x: torch.Tensor, + edge_cache: EdgeFeatureCache, + radial_feat: torch.Tensor, + unit_history: list[torch.Tensor] | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + list[torch.Tensor] | None, + ]: + """ + Run the original residual-connected block path. + + Parameters + ---------- + x + Canonical node features with shape `(N, D, 1, C)`. + edge_cache + Edge cache. + radial_feat + Per-edge radial features with shape (E, lmax+1, C). + unit_history + Unused in the residual-connected path. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, list[torch.Tensor] | None] + Tuple `(block_output, None, None, None)`. + """ + with nvtx_range("so2_conv"): + so2_unit_output = self._run_so2_unit(x, edge_cache, radial_feat) + so2_state = x + so2_unit_output + + with nvtx_range("ffn"): + ffn_state = so2_state + for i in range(self.ffn_blocks): + ffn_unit_output = self._run_ffn_unit(ffn_state, i) + ffn_state = ffn_state + ffn_unit_output + + block_output = ffn_state + return block_output, None, None, None + + def _forward_with_full_attn_res( + self, + x: torch.Tensor, + edge_cache: EdgeFeatureCache, + radial_feat: torch.Tensor, + unit_history: list[torch.Tensor] | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + list[torch.Tensor] | None, + ]: + """ + Run the block with full attention residuals over unit history. + + Parameters + ---------- + x + Current block input with shape `(N, D, 1, C)`. + edge_cache + Edge cache. + radial_feat + Per-edge radial features with shape (E, lmax+1, C). + unit_history + Truncated history in canonical node layout. Each source has shape + `(N, D, 1, C)`. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, list[torch.Tensor] | None] + Tuple `(block_output, None, so2_unit_output, ffn_unit_outputs)`. + """ + with nvtx_range("so2_conv"): + with nvtx_range("full_attn_res"): + so2_input = self.full_attn_res_so2( + sources=unit_history, + scalar_extractor=self._extract_l0_from_canonical, + current_x=x, + ) + so2_unit_output = self._run_so2_unit(so2_input, edge_cache, radial_feat) + + with nvtx_range("ffn"): + completed_units = [*unit_history, so2_unit_output] + current_x = so2_unit_output + ffn_unit_outputs: list[torch.Tensor] = [] + for i in range(self.ffn_blocks): + with nvtx_range("full_attn_res"): + ffn_input: torch.Tensor = self.full_attn_res_ffns[i]( + sources=completed_units, + scalar_extractor=self._extract_l0_from_canonical, + current_x=current_x, + ) + ffn_unit_output = self._run_ffn_unit(ffn_input, i) + ffn_unit_outputs.append(ffn_unit_output) + completed_units.append(ffn_unit_output) + current_x = ffn_unit_output + + block_output = current_x + return block_output, None, so2_unit_output, ffn_unit_outputs + + def _forward_with_block_attn_res( + self, + x: torch.Tensor, + edge_cache: EdgeFeatureCache, + radial_feat: torch.Tensor, + unit_history: list[torch.Tensor] | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + list[torch.Tensor] | None, + ]: + """ + Run the block with block attention residuals over block history. + + Parameters + ---------- + x + Current block input with shape `(N, D, 1, C)`. + edge_cache + Edge cache. + radial_feat + Per-edge radial features with shape (E, lmax+1, C). + unit_history + Truncated block history in canonical node layout. Each source has shape + `(N, D, 1, C)`. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, list[torch.Tensor] | None] + Tuple `(block_output, block_summary, None, None)`. + """ + with nvtx_range("so2_conv"): + with nvtx_range("block_attn_res"): + so2_input = self.block_attn_res_so2( + sources=unit_history, + scalar_extractor=self._extract_l0_from_canonical, + current_x=x, + ) + so2_unit_output = self._run_so2_unit(so2_input, edge_cache, radial_feat) + + with nvtx_range("ffn"): + partial_block = so2_unit_output + current_x = so2_unit_output + for i in range(self.ffn_blocks): + with nvtx_range("block_attn_res"): + ffn_input: torch.Tensor = self.block_attn_res_ffns[i]( + sources=[*unit_history, partial_block], + scalar_extractor=self._extract_l0_from_canonical, + current_x=current_x, + ) + ffn_unit_output = self._run_ffn_unit(ffn_input, i) + partial_block = partial_block + ffn_unit_output + current_x = ffn_unit_output + + block_output = current_x + block_summary = partial_block + return block_output, block_summary, None, None + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "SeZMInteractionBlock", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "channels": self.channels, + "n_focus": self.n_focus, + "focus_dim": self.focus_dim, + "focus_compete": self.focus_compete, + "so2_norm": self.so2_norm, + "so2_layers": self.so2_layers, + "so2_attn_res": self.so2_attn_res_mode, + "radial_so2_mode": self.radial_so2_mode, + "radial_so2_rank": self.radial_so2_rank, + "n_atten_head": self.n_atten_head, + "atten_f_mix": self.atten_f_mix, + "atten_v_proj": self.use_atten_v_proj, + "atten_o_proj": self.use_atten_o_proj, + "so2_pre_norm": self.so2_pre_norm, + "so2_post_norm": self.so2_post_norm, + "ffn_pre_norm": self.ffn_pre_norm, + "ffn_post_norm": self.ffn_post_norm, + "ffn_neurons": self.ffn_neurons, + "grid_mlp": self.grid_mlp, + "ffn_blocks": self.ffn_blocks, + "full_attn_res": self.full_attn_res_mode, + "block_attn_res": self.block_attn_res_mode, + "so2_s2_activation": self.so2_s2_activation, + "ffn_s2_activation": self.ffn_s2_activation, + "so2_lebedev_quadrature": self.so2_lebedev_quadrature, + "ffn_lebedev_quadrature": self.ffn_lebedev_quadrature, + "so2_activation_function": self.so2_activation_function, + "ffn_activation_function": self.ffn_activation_function, + "ffn_glu_activation": self.ffn_glu_activation, + "mlp_bias": self.mlp_bias, + "layer_scale": self.layer_scale, + "eps": self.eps, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "trainable": trainable, + "seed": None, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SeZMInteractionBlock: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SeZMInteractionBlock": + raise ValueError(f"Invalid class for SeZMInteractionBlock: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj diff --git a/deepmd/pt/model/descriptor/sezm_nn/dens.py b/deepmd/pt/model/descriptor/sezm_nn/dens.py new file mode 100644 index 0000000000..e08c6bccf7 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/dens.py @@ -0,0 +1,759 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +DeNS-specific SeZM modules. + +This module provides the force embedding together with the +parallel SeZM `dens` fitting branches: + +1. An energy head operating on the scalar descriptor. +2. A clean-force head operating on the final equivariant latent. +3. A denoising head operating on the same latent. +""" + +from __future__ import ( + annotations, +) + +import copy +import math +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.task.sezm_ener import ( + SeZMEnergyFittingNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .so3 import ( + SO3Linear, +) +from .utils import ( + np_safe, + safe_numpy_to_tensor, +) + +_SQRT_2 = math.sqrt(2.0) +_SQRT_INV_3 = 1.0 / math.sqrt(3.0) +_SQRT_4PI_OVER_3 = math.sqrt(4.0 * math.pi / 3.0) + + +def _build_real_sh_norm(lmax: int, *, device: torch.device) -> torch.Tensor: + """Precompute real-spherical-harmonic normalization factors.""" + norm = torch.zeros(lmax + 1, lmax + 1, dtype=torch.float64, device=device) + for l in range(lmax + 1): + for m in range(l + 1): + norm[l, m] = math.sqrt( + (2 * l + 1) + / (4.0 * math.pi) + * math.exp(math.lgamma(l - m + 1) - math.lgamma(l + m + 1)) + ) + return norm + + +def _associated_legendre_all( + lmax: int, + x: torch.Tensor, +) -> torch.Tensor: + """ + Evaluate associated Legendre polynomials `P_l^m(x)` up to `lmax`. + + Parameters + ---------- + lmax + Maximum angular degree. + x + Cosine values with shape `(N,)`. + + Returns + ------- + torch.Tensor + Tensor with shape `(lmax + 1, lmax + 1, N)` where the second axis is + `m`. Entries with `m > l` stay zero. + """ + n_sample = x.shape[0] + out = x.new_zeros((lmax + 1, lmax + 1, n_sample)) + out[0, 0] = 1.0 + if lmax == 0: + return out + + sin_theta = torch.sqrt((1.0 - x * x).clamp_min(0.0)) + for m in range(1, lmax + 1): + out[m, m] = -(2 * m - 1) * sin_theta * out[m - 1, m - 1] + for m in range(lmax): + out[m + 1, m] = (2 * m + 1) * x * out[m, m] + for m in range(lmax + 1): + for l in range(m + 2, lmax + 1): + out[l, m] = ( + (2 * l - 1) * x * out[l - 1, m] - (l + m - 1) * out[l - 2, m] + ) / float(l - m) + return out + + +def _real_spherical_harmonics( + lmax: int, + unit_vec: torch.Tensor, + sh_norm: torch.Tensor, + sqrt_2: torch.Tensor, +) -> torch.Tensor: + """ + Compute packed real spherical harmonics in the SeZM `(l, m)` layout. + + Parameters + ---------- + lmax + Maximum angular degree. + unit_vec + Unit vectors with shape `(N, 3)`. + + Returns + ------- + torch.Tensor + Packed real spherical harmonics with shape `(N, (lmax + 1) ** 2)`. + """ + x = unit_vec[:, 0] + y = unit_vec[:, 1] + z = unit_vec[:, 2].clamp(-1.0, 1.0) + phi = torch.atan2(y, x) + legendre = _associated_legendre_all(lmax, z) + + out = unit_vec.new_zeros((unit_vec.shape[0], (lmax + 1) ** 2)) + for l in range(lmax + 1): + for m in range(l + 1): + base = legendre[l, m] * sh_norm[l, m] + zero_idx = l * l + l + if m == 0: + out[:, zero_idx] = base + continue + sin_term = torch.sin(float(m) * phi) + cos_term = torch.cos(float(m) * phi) + out[:, zero_idx - m] = sqrt_2 * base * sin_term + out[:, zero_idx + m] = sqrt_2 * base * cos_term + return out + + +class ForceEmbedding(torch.nn.Module): + """ + Embed atom-wise force inputs into the SeZM SO(3) latent space. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree of the receiving backbone state. + channels + Number of channels per `(l, m)` coefficient. + precision + Module precision. + mlp_bias + Whether the final SO(3) projection uses an `l=0` bias. + trainable + Whether the projection weights are trainable. + seed + Initialization seed. + eps + Numerical epsilon used for vector normalization. + """ + + def __init__( + self, + *, + lmax: int, + channels: int, + precision: str = DEFAULT_PRECISION, + mlp_bias: bool = True, + trainable: bool = True, + seed: int | list[int] | None = None, + eps: float = 1e-7, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.channels = int(channels) + self.precision = str(precision) + self.dtype = PRECISION_DICT[self.precision] + self.device = env.DEVICE + self.eps = float(eps) + self.register_buffer( + "sqrt_inv_3", + torch.tensor(_SQRT_INV_3, dtype=self.dtype, device=self.device), + persistent=True, + ) + self.register_buffer( + "sqrt_2", + torch.tensor(_SQRT_2, dtype=self.dtype, device=self.device), + persistent=True, + ) + self.register_buffer( + "sh_norm", + _build_real_sh_norm(self.lmax, device=self.device).to(dtype=self.dtype), + persistent=True, + ) + self.proj = SO3Linear( + lmax=self.lmax, + in_channels=1, + out_channels=self.channels, + n_focus=1, + dtype=self.dtype, + mlp_bias=mlp_bias, + trainable=trainable, + seed=seed, + ) + + def forward( + self, + force_input: torch.Tensor, + noise_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Project atom-wise force inputs into the SeZM SO(3) layout. + + Parameters + ---------- + force_input + Force tensor with shape `(nf, nloc, 3)` or `(N, 3)`. + noise_mask + Optional corruption mask with shape `(nf, nloc)` or `(N,)`. + Only masked atoms contribute non-zero embeddings. + + Returns + ------- + torch.Tensor + Force embedding with shape `(nf * nloc, D, 1, channels)`. + """ + if force_input.ndim == 3: + force_input = force_input.reshape(-1, 3) + elif force_input.ndim != 2 or force_input.shape[-1] != 3: + raise ValueError( + "`force_input` must have shape (nf, nloc, 3) or (N, 3) for force embedding." + ) + + if noise_mask is None: + mask = torch.ones( + force_input.shape[0], + device=force_input.device, + dtype=torch.bool, + ) + else: + mask = noise_mask.reshape(-1).to( + dtype=torch.bool, device=force_input.device + ) + if mask.shape[0] != force_input.shape[0]: + raise ValueError( + "`noise_mask` must match the flattened atom dimension of `force_input`." + ) + + force_input = force_input.to(dtype=self.dtype) + force_norm = torch.linalg.vector_norm(force_input, dim=-1) + safe_norm = force_norm.clamp_min(self.eps) + unit_vec = force_input / safe_norm.unsqueeze(-1) + sh = _real_spherical_harmonics( + self.lmax, + unit_vec, + self.sh_norm, + self.sqrt_2, + ) + sh = sh * (force_norm * self.sqrt_inv_3).unsqueeze(-1) + sh = sh.view(force_input.shape[0], -1, 1, 1) + embedded = self.proj(sh) + return embedded * mask.view(-1, 1, 1, 1).to(dtype=embedded.dtype) + + +class _SeZMVectorHead(torch.nn.Module): + """ + Read a Cartesian vector from the `l=1` SeZM latent block. + + Parameters + ---------- + lmax + Maximum angular degree of the input latent. + channels + Number of input channels per `(l, m)` coefficient. + precision + Module precision. + mlp_bias + Whether the SO(3) projection uses an `l=0` bias. + trainable + Whether parameters are trainable. + seed + Initialization seed. + """ + + def __init__( + self, + *, + lmax: int, + channels: int, + precision: str = DEFAULT_PRECISION, + mlp_bias: bool = False, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.lmax = int(lmax) + if self.lmax < 1: + raise ValueError("`lmax` must be >= 1 for a vector-valued SeZM head.") + self.channels = int(channels) + self.precision = str(precision) + self.dtype = PRECISION_DICT[self.precision] + self.device = env.DEVICE + self.register_buffer( + "cartesian_scale", + torch.tensor(_SQRT_4PI_OVER_3, dtype=self.dtype, device=self.device), + persistent=True, + ) + self.proj = SO3Linear( + lmax=self.lmax, + in_channels=self.channels, + out_channels=1, + n_focus=1, + dtype=self.dtype, + mlp_bias=mlp_bias, + trainable=trainable, + seed=seed, + ) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + """ + Predict Cartesian vectors from the final SeZM equivariant latent. + + Parameters + ---------- + latent + Final equivariant latent with shape `(nf * nloc, D, 1, channels)`. + + Returns + ------- + torch.Tensor + Cartesian vectors with shape `(nf * nloc, 3)`. + """ + projected = self.proj(latent.to(dtype=self.dtype)) + l1 = projected[:, 1:4, 0, 0] + # SeZM keeps the l=1 packed basis as (-y, z, -x), so decode back to + # Cartesian order (x, y, z) with two sign flips and one permutation. + return self.cartesian_scale * torch.stack( + [-l1[:, 2], -l1[:, 0], l1[:, 1]], + dim=-1, + ) + + +class SeZMDirectForceHead(_SeZMVectorHead): + """Predict clean direct forces from the final SeZM latent.""" + + +class SeZMDenoisingHead(_SeZMVectorHead): + """Predict denoising vectors from the final SeZM latent.""" + + +class SeZMDeNSEnergyHead(SeZMEnergyFittingNet): + """Energy head used by the SeZM `dens` fitting network.""" + + +class SeZMDeNSFittingNet(torch.nn.Module): + """ + Parallel SeZM fitting branches for the `dens` mode. + + Parameters + ---------- + ntypes + Number of atom types. + dim_descrpt + Scalar descriptor width. + condition_lmax + Maximum spherical harmonic degree of the descriptor entry state that + receives the external force embedding. + latent_lmax + Maximum spherical harmonic degree of the final equivariant latent. + channels + Number of latent channels per `(l, m)` coefficient. + neuron + Hidden widths of the scalar energy branch. + bias_atom_e + Optional per-type atomic energy bias for the scalar energy branch. + resnet_dt + Residual time-step flag for the scalar energy branch. + numb_fparam + Number of frame parameters. + numb_aparam + Number of atomic parameters. + dim_case_embd + Case embedding width for the scalar energy branch. + case_film_embd + Whether the scalar energy branch uses case FiLM conditioning. + activation_function + Activation function of the scalar energy branch. + bias_out + Whether the scalar energy branch uses output bias. + precision + Module precision. + mixed_types + Whether the scalar energy branch shares parameters across atom types. + seed + Initialization seed. + type_map + Atom type names. + default_fparam + Default frame parameters for the scalar energy branch. + rcond + Optional condition number used by the scalar energy branch. + exclude_types + Atom types excluded by the scalar energy branch. + trainable + Whether the `dens` fitting parameters are trainable. + atom_ener + Optional vacuum atomic energy contribution for the scalar energy branch. + use_aparam_as_mask + Whether atomic parameters act as masks in the scalar energy branch. + """ + + def __init__( + self, + *, + ntypes: int, + dim_descrpt: int, + condition_lmax: int, + latent_lmax: int, + channels: int, + neuron: list[int] | None = None, + bias_atom_e: torch.Tensor | None = None, + resnet_dt: bool = False, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + case_film_embd: bool = False, + activation_function: str = "silu", + bias_out: bool = False, + precision: str = DEFAULT_PRECISION, + mixed_types: bool = True, + seed: int | list[int] | None = None, + type_map: list[str] | None = None, + default_fparam: list[float] | None = None, + rcond: float | None = None, + exclude_types: list[int] | None = None, + trainable: bool | list[bool] = True, + atom_ener: list[torch.Tensor | None] | None = None, + use_aparam_as_mask: bool = False, + ) -> None: + super().__init__() + if neuron is None: + neuron = [0] + self.ntypes = int(ntypes) + self.dim_descrpt = int(dim_descrpt) + self.condition_lmax = int(condition_lmax) + self.latent_lmax = int(latent_lmax) + self.channels = int(channels) + self.neuron = [int(width) for width in neuron] + self.activation_function = str(activation_function) + self.precision = str(precision) + self.mixed_types = bool(mixed_types) + self.numb_fparam = int(numb_fparam) + self.numb_aparam = int(numb_aparam) + self.dim_case_embd = int(dim_case_embd) + self.case_film_embd = bool(case_film_embd and self.dim_case_embd > 0) + self.bias_out = bool(bias_out) + self.resnet_dt = bool(resnet_dt) + self.type_map = None if type_map is None else list(type_map) + self.default_fparam = default_fparam + self.rcond = None if rcond is None else float(rcond) + self.exclude_types = [] if exclude_types is None else list(exclude_types) + self.trainable = copy.deepcopy(trainable) + self.atom_ener = atom_ener + self.use_aparam_as_mask = bool(use_aparam_as_mask) + self._return_middle_output = False + self.has_force_embedding_latent = self.condition_lmax >= 1 + self.has_vector_latent = self.latent_lmax >= 1 + trainable_flag = ( + all(self.trainable) + if isinstance(self.trainable, list) + else bool(self.trainable) + ) + + # === Step 1. Build the scalar energy branch === + self.energy_head = SeZMDeNSEnergyHead( + ntypes=self.ntypes, + dim_descrpt=self.dim_descrpt, + neuron=self.neuron, + bias_atom_e=bias_atom_e, + resnet_dt=self.resnet_dt, + numb_fparam=self.numb_fparam, + numb_aparam=self.numb_aparam, + dim_case_embd=self.dim_case_embd, + case_film_embd=self.case_film_embd, + activation_function=self.activation_function, + bias_out=self.bias_out, + precision=self.precision, + mixed_types=self.mixed_types, + seed=child_seed(seed, 0), + type_map=self.type_map, + default_fparam=self.default_fparam, + rcond=self.rcond, + exclude_types=self.exclude_types, + trainable=self.trainable, + atom_ener=self.atom_ener, + use_aparam_as_mask=self.use_aparam_as_mask, + ) + + # === Step 2. Build force-embedding and vector heads === + if self.has_force_embedding_latent: + self.force_embedding = ForceEmbedding( + lmax=self.condition_lmax, + channels=self.channels, + precision=self.precision, + mlp_bias=True, + trainable=trainable_flag, + seed=child_seed(seed, 1), + ) + else: + self.force_embedding = None + + if self.has_vector_latent: + self.direct_force_head = SeZMDirectForceHead( + lmax=self.latent_lmax, + channels=self.channels, + precision=self.precision, + mlp_bias=False, + trainable=trainable_flag, + seed=child_seed(seed, 2), + ) + self.denoising_head = SeZMDenoisingHead( + lmax=self.latent_lmax, + channels=self.channels, + precision=self.precision, + mlp_bias=False, + trainable=trainable_flag, + seed=child_seed(seed, 3), + ) + else: + self.direct_force_head = None + self.denoising_head = None + + def output_def(self) -> FittingOutputDef: + """Return the public fitting output contract for `dens` mode.""" + return FittingOutputDef( + [ + OutputVariableDef( + "energy", + [1], + reducible=True, + r_differentiable=False, + c_differentiable=False, + ), + OutputVariableDef( + "dforce", + [3], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), + ] + ) + + def get_dim_fparam(self) -> int: + """Return the frame-parameter width of the energy branch.""" + return self.energy_head.get_dim_fparam() + + def has_default_fparam(self) -> bool: + """Return whether the energy branch has default frame parameters.""" + return self.energy_head.has_default_fparam() + + def get_default_fparam(self) -> torch.Tensor | None: + """Return default frame parameters of the energy branch.""" + return self.energy_head.get_default_fparam() + + def get_dim_aparam(self) -> int: + """Return the atomic-parameter width of the energy branch.""" + return self.energy_head.get_dim_aparam() + + def get_sel_type(self) -> list[int]: + """Return selected atom types of the energy branch.""" + return self.energy_head.get_sel_type() + + def set_return_middle_output(self, enable: bool) -> None: + """Enable or disable forwarding of the scalar energy hidden activations.""" + self._return_middle_output = bool(enable) + self.energy_head.set_return_middle_output(enable) + + def build_force_embedding( + self, + force_input: torch.Tensor, + noise_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Build the descriptor-entry force embedding from atom-wise force inputs. + + Parameters + ---------- + force_input + Force tensor with shape `(nf, nloc, 3)` or `(N, 3)`. + noise_mask + Optional corruption mask. + + Returns + ------- + torch.Tensor + Force embedding with shape `(nf * nloc, D_cond, 1, channels)`. + """ + if self.force_embedding is None: + raise RuntimeError( + f"SeZM `dens` mode requires descriptor condition_lmax >= 1. Got condition_lmax={self.condition_lmax}." + ) + return self.force_embedding(force_input, noise_mask=noise_mask) + + def change_type_map( + self, + type_map: list[str], + model_with_new_type_stat: Any | None = None, + ) -> None: + """ + Update type-related metadata for the scalar energy branch. + + Parameters + ---------- + type_map + New atom type map. + model_with_new_type_stat + Optional reference model carrying new-type statistics. + """ + self.type_map = list(type_map) + ref_energy_head = ( + None + if model_with_new_type_stat is None + else model_with_new_type_stat.energy_head + ) + self.energy_head.change_type_map( + type_map=type_map, + model_with_new_type_stat=ref_energy_head, + ) + + def forward( + self, + descriptor: torch.Tensor, + latent: torch.Tensor, + atype: torch.Tensor, + *, + noise_mask: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + return_components: bool = False, + ) -> dict[str, torch.Tensor]: + """ + Run the parallel `dens` fitting branches. + + Parameters + ---------- + descriptor + Scalar descriptor with shape `(nf, nloc, dim_descrpt)`. + latent + Final equivariant latent with shape `(nf * nloc, D, 1, channels)`. + atype + Atom types with shape `(nf, nloc)`. + noise_mask + Optional corruption mask with shape `(nf, nloc)`. + fparam + Optional frame parameters. + aparam + Optional atomic parameters. + return_components + If true, also return the clean-force and denoising branches. + + Returns + ------- + dict[str, torch.Tensor] + Public outputs contain `energy` and mixed `dforce`. + """ + if self.direct_force_head is None or self.denoising_head is None: + raise RuntimeError( + f"SeZM `dens` mode requires descriptor latent_lmax >= 1. Got latent_lmax={self.latent_lmax}." + ) + nf, nloc = atype.shape[:2] + energy_ret = self.energy_head( + descriptor, + atype, + fparam=fparam, + aparam=aparam, + ) + clean_force = self.direct_force_head(latent).view(nf, nloc, 3) + denoising_force = self.denoising_head(latent).view(nf, nloc, 3) + + if noise_mask is None: + mixed_force = clean_force + else: + mask = noise_mask.to(dtype=torch.bool, device=clean_force.device).unsqueeze( + -1 + ) + mixed_force = torch.where(mask, denoising_force, clean_force) + + result = { + "energy": energy_ret["energy"], + "dforce": mixed_force.to(dtype=descriptor.dtype), + } + if "middle_output" in energy_ret: + result["middle_output"] = energy_ret["middle_output"] + if return_components: + result["clean_dforce"] = clean_force + result["denoising_dforce"] = denoising_force + return result + + def serialize(self) -> dict[str, Any]: + """Serialize the SeZM `dens` fitting network.""" + state = self.state_dict() + return { + "@class": "SeZMDeNSFittingNet", + "@version": 1, + "config": { + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + "condition_lmax": self.condition_lmax, + "latent_lmax": self.latent_lmax, + "channels": self.channels, + "neuron": self.neuron.copy(), + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, + "case_film_embd": self.case_film_embd, + "activation_function": self.activation_function, + "bias_out": self.bias_out, + "precision": self.precision, + "mixed_types": self.mixed_types, + "type_map": self.type_map, + "default_fparam": self.default_fparam, + "rcond": self.rcond, + "exclude_types": self.exclude_types.copy(), + "trainable": self.trainable, + "atom_ener": self.atom_ener, + "use_aparam_as_mask": self.use_aparam_as_mask, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SeZMDeNSFittingNet: + """Deserialize the SeZM `dens` fitting network.""" + data = data.copy() + if data.pop("@class") != "SeZMDeNSFittingNet": + raise ValueError("Invalid class for SeZMDeNSFittingNet deserialization.") + version = int(data.pop("@version", 1)) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + obj = cls(**config) + state = {key: safe_numpy_to_tensor(value) for key, value in variables.items()} + obj.load_state_dict(state) + return obj diff --git a/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py b/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py new file mode 100644 index 0000000000..6aa73f475e --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/edge_cache.py @@ -0,0 +1,878 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Edge cache construction utilities for SeZM. + +This module defines the shared procedures that assemble per-edge geometry, +radial features, rotation blocks, and normalization terms used by the SeZM +descriptor. +""" + +from __future__ import ( + annotations, +) + +import math +from collections.abc import ( + Callable, +) +from typing import ( + NamedTuple, +) + +import torch +from einops import ( + rearrange, +) + +from .triton import ( + edge_geometry_rbf_triton, +) +from .utils import ( + get_promoted_dtype, + nvtx_range, + safe_norm, +) +from .wignerd import ( + build_edge_quaternion, + quaternion_multiply, + quaternion_z_rotation, +) + +WignerCalculatorFn = Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]] +EdgeTypeKeepMaskFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] + + +class EdgeFeatureCache(NamedTuple): + """ + Global edge feature cache created once per forward(). + + All tensors are aligned on the same edge axis (E = number of valid edges). + + Parameters + ---------- + src + Source node indices with shape (E,). + dst + Destination node indices with shape (E,). + edge_type_feat + Per-edge type embeddings with shape (E, C), computed as src+dst. + edge_vec + Edge vectors with shape (E, 3) in Å. + edge_rbf + Radial basis with shape (E, n_radial). + The C^3 cutoff envelope is already baked in. + edge_env + C^3 cutoff envelope weights with shape (E, 1). + deg + Envelope-squared smooth degree with shape (N,), computed as + ``sum(edge_env**2)`` over incoming edges. + Used for smooth normalization in EnvironmentInitialEmbedding. + inv_sqrt_deg + Inverse square root smooth degree normalization with shape (N, 1, 1). + D_full + Block-diagonal Wigner-D matrix with shape (E, D, D) where D=(lmax+1)^2. + Used for efficient batched rotation. None if not available. + Dt_full + Transpose of D_full with shape (E, D, D). None if not available. + D_to_m_cache + Lazy cache for projected D matrices keyed by a normalized + ``"lmax:mmax"`` identifier. + Dt_from_m_cache + Lazy cache for projected Dt matrices keyed by a normalized + ``"lmax:mmax"`` identifier. + edge_src_gate + Optional per-edge Source Freeze Propagation Gate (SFPG) weight with + shape (E, 1). Equals ``eta[src]`` where + ``eta[j] = prod_{k in N(j)} w(r_{jk})`` and ``w`` is the + :class:`BridgingSwitch` C3 switching amplitude. Present only when + the model runs in bridging mode; ``None`` otherwise. Aggregation + sites (``GeometricInitialEmbedding``, ``EnvironmentInitialEmbedding``, + ``SO2Convolution``) multiply their per-edge message contribution + by this gate to forbid any node whose local neighborhood enters + the frozen zone from propagating information along its outgoing + edges. + """ + + src: torch.Tensor + dst: torch.Tensor + edge_type_feat: torch.Tensor + edge_vec: torch.Tensor + edge_rbf: torch.Tensor + edge_env: torch.Tensor + deg: torch.Tensor + inv_sqrt_deg: torch.Tensor + D_full: torch.Tensor | None = None + Dt_full: torch.Tensor | None = None + D_to_m_cache: dict[str, torch.Tensor] | None = None + Dt_from_m_cache: dict[str, torch.Tensor] | None = None + edge_src_gate: torch.Tensor | None = None + + +def compute_edge_src_gate( + *, + edge_len: torch.Tensor, + src: torch.Tensor, + n_nodes: int, + bridging_switch: Callable[[torch.Tensor], torch.Tensor], + edge_keep_f: torch.Tensor | None = None, +) -> torch.Tensor: + """ + Compute the per-edge source gate for SFPG from edge lengths. + + The gate implements a per-node "non-frozen confidence" and broadcasts + it back to edges along the source axis:: + + w_e = bridging_switch(edge_len_e) in [0, 1] + eta_j = prod_{e: src_e = j} w_e in [0, 1] + gate_e = eta_{src_e} in [0, 1] + + ``w_e = 0`` at ``r_{jk} <= r_inner`` ensures ``eta_j = 0`` for any + node with at least one neighbor in the frozen zone. Masked edges + (padding, excluded type pairs) must contribute the multiplicative + identity ``1`` so they never spuriously mute a valid source node; + callers supply ``edge_keep_f`` for this. + + The product is **not** realised by ``scatter_reduce(reduce="prod")``: + its registered backward handles exact zeros with a data-dependent + "count leave-one-out" branch that creates unbacked symints under + ``make_fx(tracing_mode="symbolic")`` and breaks the SeZM compile + path's double-backward tracing. Instead, the product is decomposed + into a log-sum on non-zero contributions combined with an explicit + "any zero per group" indicator that routes the frozen case through + ``torch.where``. Both branches use only shape-preserving standard + ops (``scatter_add``, ``where``, ``exp``, ``log``) with backed + symints, so the graph survives symbolic tracing cleanly. + + The gradient consequence at the plateau is exact: ``BridgingSwitch`` + places ``w'(r) = 0`` for every ``r <= r_inner``, so the chain rule + ``d eta / d r = (leave-one-out factor) * w'(r) = anything * 0 = 0`` + holds regardless of how the muted ``torch.where`` branch treats the + upstream gradient. In the transition zone every edge has strictly + positive ``w`` and the log-sum branch gives the standard product + gradient. + + Parameters + ---------- + edge_len + Per-edge distances with shape (E, 1). + src + Source node indices with shape (E,). + n_nodes + Total number of nodes N. + bridging_switch + Callable ``r -> w(r)`` with ``w: [0, ∞) -> [0, 1]``, typically a + :class:`BridgingSwitch` instance. + edge_keep_f + Optional per-edge keep weights with shape (E, 1), with ``0`` on + masked edges and ``1`` on kept edges. If provided, masked edges + are rewritten to ``w = 1`` before the product reduction. + + Returns + ------- + torch.Tensor + Per-edge source gate with shape (E, 1), aligned on the same edge + axis as the rest of the cache. + """ + # === Step 1. Per-edge switching amplitude w(r) in [0, 1] === + edge_w = bridging_switch(edge_len) # (E, 1) + if edge_keep_f is not None: + # Force w = 1 on masked edges so they are neutral for the product. + edge_w = edge_w * edge_keep_f + (1.0 - edge_keep_f) + + edge_w_flat = edge_w.squeeze(-1) # (E,) + is_zero = edge_w_flat <= 0.0 # (E,) bool + + # === Step 2. Log-sum reduction on non-zero contributions === + # Replace exact zeros with the multiplicative identity 1 so their + # ``log`` contribution is 0 and the group-wise sum equals the log of + # the product of non-zero ``w`` values. + safe_w = torch.where(is_zero, torch.ones_like(edge_w_flat), edge_w_flat) + log_safe = torch.log(safe_w) + log_eta = torch.zeros( + n_nodes, dtype=edge_w.dtype, device=edge_w.device + ).scatter_add(0, src, log_safe) + eta_nonzero_path = torch.exp(log_eta) + + # === Step 3. Exact-zero indicator per source node === + # ``scatter_add`` over an ``int64`` cast of the zero mask counts how + # many frozen edges each source node owns. A strictly positive count + # means the product is 0 by the hard-freeze rule. + zero_count = torch.zeros( + n_nodes, dtype=torch.int64, device=edge_w.device + ).scatter_add(0, src, is_zero.to(torch.int64)) + any_zero = zero_count > 0 + + # === Step 4. Combine and broadcast back to edges via source === + eta = torch.where(any_zero, torch.zeros_like(eta_nonzero_path), eta_nonzero_path) + return eta.index_select(0, src).unsqueeze(-1) + + +@torch.amp.autocast("cuda", enabled=False) +def build_edge_cache( + *, + type_ebed: torch.Tensor, + extended_coord: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + pair_keep_mask: torch.Tensor, + eps: float, + edge_envelope: Callable[[torch.Tensor], torch.Tensor], + radial_basis: Callable[[torch.Tensor], torch.Tensor], + n_radial: int, + random_gamma: bool, + wigner_calc: WignerCalculatorFn, + use_geometry_rbf_triton: bool = False, +) -> EdgeFeatureCache: + """ + Build the global edge cache from DeePMD padded neighbor list. + + This converts DeePMD's per-frame padded neighbor list into a flat list of + valid edges used by message passing, and computes all per-edge tensors that + are reused across blocks. + + The resulting cache contains: + + - per-edge endpoints: ``src``, ``dst`` and per-edge type features: ``edge_type_feat`` (src+dst) + - per-edge geometry: ``edge_vec`` + - per-edge smooth weights: C^3 cutoff envelope ``edge_env`` + - per-edge radial basis: ``edge_rbf`` (envelope already baked in) + - per-edge rotation blocks: block-diagonal Wigner-D matrices ``D_full`` and ``Dt_full`` + - destination-node smooth normalization: ``inv_sqrt_deg`` from + envelope-squared degree ``sum(edge_env**2)`` + + Notes + ----- + Input formats follow DeePMD conventions: + + - ``extended_coord`` has shape ``(nf, nall, 3)``. + - ``nlist`` has shape ``(nf, nloc, nnei)`` and stores indices into the extended axis + (``0..nall-1``), with ``-1`` indicating padding. + - ``mapping`` (when provided) maps extended indices to local indices ``0..nloc-1``. + When ``mapping`` is ``None``, the function assumes the neighbor indices are already local. + + This function builds the edge cache directly on the valid edge set, so + padded or excluded neighbor slots never enter the geometry, radial basis, + or Wigner-D evaluation. + + Parameters + ---------- + type_ebed + Per-node type embedding with shape (N, C), where N=nf*nloc. + extended_coord + Extended coordinates with shape (nf, nall, 3). + nlist + Neighbor list with shape (nf, nloc, nnei). + mapping + Mapping from extended indices to local indices with shape (nf, nall), or None. + pair_keep_mask + Pair keep mask from `PairExcludeMask` with shape (nf, nloc, nnei). True means keep. + eps + Small positive epsilon for safe norm and degree normalization. + edge_envelope + C^3 edge envelope module. + radial_basis + Radial basis module. + n_radial + Number of radial basis channels used for empty-cache allocation. + random_gamma + Whether to apply a random roll around the local +Z axis before + constructing Wigner-D blocks. + wigner_calc + Callable that converts edge-aligned quaternions into packed Wigner-D + blocks. + use_geometry_rbf_triton + Whether to allow the standard-path fused Triton geometry/RBF chain + ``gather -> vec -> len -> env -> rbf``. + + Returns + ------- + EdgeFeatureCache + Per-edge cache. + """ + nf, nloc, nnei = nlist.shape + n_nodes = int(nf * nloc) + + # === Step 1. Force fp32+ for geometry === + geom_dtype = get_promoted_dtype(extended_coord.dtype) + coord = extended_coord.to(dtype=geom_dtype) # (nf, nall, 3) + nall = coord.shape[1] + + # === Step 2. Build valid edge indices once === + with nvtx_range("index"): + src, dst, center_coord_index, neighbor_coord_index = _build_standard_edge_index( + nlist=nlist, + mapping=mapping, + pair_keep_mask=pair_keep_mask, + nall=nall, + ) + + if src.numel() == 0: + return _get_empty_edge_cache( + n_nodes=n_nodes, + n_radial=n_radial, + n_channel=type_ebed.shape[1], + device=extended_coord.device, + dtype=extended_coord.dtype, + ) + + # === Step 3-5. Edge geometry/RBF chain === + # This segment covers: + # gather -> edge_vec -> edge_len -> edge_env -> edge_rbf + # The Triton path is only used on the standard non-compile path when the + # caller explicitly allows it (descriptor eval/inference path). Bridging + # primitives never enter here; they are owned by the sparse-edge path. + coord_flat = coord.reshape(nf * nall, 3) + use_bessel_triton = ( + use_geometry_rbf_triton + and getattr(radial_basis, "basis_type", "bessel") == "bessel" + ) + if use_bessel_triton: + with nvtx_range("edge_geometry_rbf_triton"): + edge_vec, edge_len, edge_env, edge_rbf = edge_geometry_rbf_triton( + coord_flat=coord_flat, + center_coord_index=center_coord_index, + neighbor_coord_index=neighbor_coord_index, + edge_envelope=edge_envelope, + radial_basis=radial_basis, + eps=eps, + inner_clamp=None, + ) + else: + # === Step 3. Gather per-edge geometry === + # edge_vec points from center -> neighbor: r_ij = r_j - r_i (in Å). + # edge_len is the scalar distance. + with nvtx_range("edge_geom"): + center_pos = coord_flat.index_select(0, center_coord_index) + neighbor_pos = coord_flat.index_select(0, neighbor_coord_index) + edge_vec = neighbor_pos - center_pos # (E, 3) + edge_len = safe_norm(edge_vec, eps) # (E, 1) + + # === Step 4. C^3 envelope weight === + # Edges with r >= rcut are not removed from the cache. Their envelope is + # exactly zero, so messages vanish naturally while degree normalization + # remains smooth at the cutoff boundary. + with nvtx_range("envelope"): + edge_env = edge_envelope(edge_len) # (E, 1) + + # === Step 5. Radial basis (envelope already baked in) === + with nvtx_range("radial_basis"): + edge_rbf = radial_basis(edge_len) # (E, n_radial) + + # === Step 6. Edge quaternion -> Wigner-D blocks === + with nvtx_range("wigner_d"): + D_full, Dt_full = _build_edge_wigner( + edge_vec=edge_vec, + edge_len=edge_len, + eps=eps, + random_gamma=random_gamma, + wigner_calc=wigner_calc, + ) # (E, D, D), (E, D, D) + + edge_type_feat = build_edge_type_feat(type_ebed, src, dst) # (E, C) + + return _finalize_edge_cache( + n_nodes=n_nodes, + src=src, + dst=dst, + edge_type_feat=edge_type_feat, + edge_vec=edge_vec, + edge_rbf=edge_rbf, + edge_env=edge_env, + D_full=D_full, + Dt_full=Dt_full, + eps=eps, + ) + + +@torch.amp.autocast("cuda", enabled=False) +def build_edge_cache_from_edges( + *, + type_ebed: torch.Tensor, + atype_flat: torch.Tensor, + edge_index: torch.Tensor, + edge_vec: torch.Tensor, + edge_mask: torch.Tensor, + compute_dtype: torch.dtype, + eps: float, + inner_clamp: Callable[[torch.Tensor], torch.Tensor] | None, + bridging_switch: Callable[[torch.Tensor], torch.Tensor] | None, + edge_envelope: Callable[[torch.Tensor], torch.Tensor], + radial_basis: Callable[[torch.Tensor], torch.Tensor], + has_exclude_types: bool, + edge_type_keep_mask: EdgeTypeKeepMaskFn, + random_gamma: bool, + wigner_calc: WignerCalculatorFn, +) -> EdgeFeatureCache: + """ + Build the global edge cache from a sparse edge list. + + Parameters + ---------- + type_ebed + Per-node type embedding with shape (N, C), where N=nf*nloc. + atype_flat + Flattened local atom types with shape (N,). + edge_index + Edge indices with shape (2, E). + edge_vec + Edge vectors with shape (E, 3) in Å. + edge_mask + Edge mask with shape (E,). True means keep. + compute_dtype + Promoted compute dtype used for geometry and radial features. + eps + Small positive epsilon for safe norm and degree normalization. + inner_clamp + Optional inner clamp used to freeze short-range geometry below `r_inner`. + bridging_switch + Optional C3 switching amplitude ``w(r) -> [0, 1]`` that drives + the Source Freeze Propagation Gate. When provided, a per-edge + ``edge_src_gate`` is computed from the node-wise product of + ``w(r_{jk})`` along each source node's outgoing edges. Masked + edges (``edge_keep=False``) are forced to ``w=1`` so they never + leak into the product. + edge_envelope + C^3 edge envelope module. + radial_basis + Radial basis module. + has_exclude_types + Whether excluded type pairs should be filtered in this path. + edge_type_keep_mask + Callable that builds the keep mask for edge type exclusions. + random_gamma + Whether to apply a random roll around the local +Z axis before + constructing Wigner-D blocks. + wigner_calc + Callable that converts edge-aligned quaternions into packed Wigner-D + blocks. + + Returns + ------- + EdgeFeatureCache + Per-edge cache. + """ + n_nodes = type_ebed.shape[0] + src = edge_index[0].to(dtype=torch.long) + dst = edge_index[1].to(dtype=torch.long) + + # === Step 1. Normalize mask and apply type exclusions === + edge_keep = edge_mask.to(dtype=torch.bool) + if has_exclude_types: + edge_keep = edge_keep & edge_type_keep_mask(atype_flat, src, dst) + + # === Step 2. Promote geometry dtype === + edge_vec = edge_vec.to(dtype=compute_dtype) + edge_keep_f = edge_keep.to(dtype=compute_dtype).unsqueeze(-1) + edge_vec = edge_vec * edge_keep_f + edge_vec = edge_vec + (1.0 - edge_keep_f) * edge_vec.new_tensor([0.0, 0.0, 1.0]) + + # === Step 3. Edge length, envelope, and radial basis === + with nvtx_range("envelope"): + edge_len = safe_norm(edge_vec, eps) + if inner_clamp is not None: + clamped = inner_clamp(edge_len) + scale = clamped / edge_len + edge_vec = edge_vec * scale + edge_len = clamped + edge_env = edge_envelope(edge_len) * edge_keep_f # (E, 1) + edge_rbf = radial_basis(edge_len) * edge_keep_f # (E, n_radial) + + # === Step 4. Edge quaternion -> Wigner-D blocks === + with nvtx_range("wigner_d"): + D_full, Dt_full = _build_edge_wigner( + edge_vec=edge_vec, + edge_len=edge_len, + eps=eps, + random_gamma=random_gamma, + wigner_calc=wigner_calc, + ) # (E, D, D), (E, D, D) + + # === Step 5. Edge type features === + edge_type_feat = build_edge_type_feat(type_ebed, src, dst) + edge_type_feat = edge_type_feat * edge_keep_f.to(dtype=edge_type_feat.dtype) + + # === Step 6. Source Freeze Propagation Gate (optional) === + # The sparse-edge path packs one dummy masked edge per frame so the + # compiled graph sees a statically non-empty tensor. ``edge_keep_f`` + # rewrites any such slot to ``w=1`` inside ``compute_edge_src_gate``, + # keeping the product reduction unaffected by padding. + edge_src_gate: torch.Tensor | None = None + if bridging_switch is not None: + with nvtx_range("src_gate"): + edge_src_gate = compute_edge_src_gate( + edge_len=edge_len, + src=src, + n_nodes=n_nodes, + bridging_switch=bridging_switch, + edge_keep_f=edge_keep_f, + ) + + return _finalize_edge_cache( + n_nodes=n_nodes, + src=src, + dst=dst, + edge_type_feat=edge_type_feat, + edge_vec=edge_vec, + edge_rbf=edge_rbf, + edge_env=edge_env, + D_full=D_full, + Dt_full=Dt_full, + eps=eps, + edge_src_gate=edge_src_gate, + ) + + +def _build_edge_wigner( + *, + edge_vec: torch.Tensor, + edge_len: torch.Tensor, + eps: float, + random_gamma: bool, + wigner_calc: WignerCalculatorFn, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Build packed Wigner-D blocks from edge vectors. + + Parameters + ---------- + edge_vec + Edge vectors with shape (E, 3) in Å. + edge_len + Edge lengths with shape (E, 1). + eps + Small positive epsilon used in quaternion construction. + random_gamma + Whether to apply a random roll around the local +Z axis. + wigner_calc + Callable that converts edge-aligned quaternions into packed Wigner-D + blocks. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + Packed Wigner-D matrices ``(D_full, Dt_full)`` with shape ``(E, D, D)``. + """ + # === Step 1. Build edge-aligned quaternions === + edge_quat = build_edge_quaternion( + edge_vec, + edge_len=edge_len, + eps=eps, + ) + + # === Step 2. Apply optional random local-Z roll === + if random_gamma: + gamma = torch.rand( + edge_quat.shape[0], + dtype=edge_quat.dtype, + device=edge_quat.device, + ) * (2.0 * math.pi) + edge_quat = quaternion_multiply(quaternion_z_rotation(gamma), edge_quat) + + # === Step 3. Convert quaternions to packed Wigner-D blocks === + return wigner_calc(edge_quat) + + +def _finalize_edge_cache( + *, + n_nodes: int, + src: torch.Tensor, + dst: torch.Tensor, + edge_type_feat: torch.Tensor, + edge_vec: torch.Tensor, + edge_rbf: torch.Tensor, + edge_env: torch.Tensor, + D_full: torch.Tensor, + Dt_full: torch.Tensor, + eps: float, + edge_src_gate: torch.Tensor | None = None, +) -> EdgeFeatureCache: + """ + Assemble the shared `EdgeFeatureCache` layout. + + Parameters + ---------- + n_nodes + Number of local nodes in the flattened frame-major layout. + src + Source node indices with shape (E,). + dst + Destination node indices with shape (E,). + edge_type_feat + Per-edge type features with shape (E, C). + edge_vec + Edge vectors with shape (E, 3). + edge_rbf + Radial basis features with shape (E, n_radial). + edge_env + Smooth edge envelope weights with shape (E, 1). + D_full + Packed Wigner-D matrices with shape (E, D, D). + Dt_full + Transposed packed Wigner-D matrices with shape (E, D, D). + eps + Small positive epsilon used in degree normalization. + edge_src_gate + Optional per-edge SFPG weight with shape (E, 1). ``None`` in + non-bridging mode. + + Returns + ------- + EdgeFeatureCache + Finalized per-edge cache shared by eager and compile paths. + """ + # === Step 1. Build smooth destination degrees === + with nvtx_range("degree"): + deg = torch.zeros(n_nodes, dtype=edge_vec.dtype, device=edge_vec.device) # (N,) + deg.index_add_(0, dst, edge_env.squeeze(-1).to(dtype=edge_vec.dtype).square()) + eps_tensor = deg.new_tensor(eps) + inv_sqrt_deg = rearrange( + torch.rsqrt(deg + eps_tensor), "N -> N 1 1" + ) # (N, 1, 1) + + return EdgeFeatureCache( + src=src, + dst=dst, + edge_type_feat=edge_type_feat, + edge_vec=edge_vec, + edge_rbf=edge_rbf, + edge_env=edge_env, + deg=deg, + inv_sqrt_deg=inv_sqrt_deg, + D_full=D_full, + Dt_full=Dt_full, + D_to_m_cache={}, + Dt_from_m_cache={}, + edge_src_gate=edge_src_gate, + ) + + +def _get_empty_edge_cache( + *, + n_nodes: int, + n_radial: int, + n_channel: int, + device: torch.device, + dtype: torch.dtype, +) -> EdgeFeatureCache: + """ + Allocate an empty edge cache for one SeZM forward pass. + + Parameters + ---------- + n_nodes + Number of local nodes in the flattened frame-major layout. + n_radial + Number of radial basis channels. + n_channel + Edge type feature width. + device + Target device for the cache tensors. + dtype + Target floating-point dtype for the cache tensors. + + Returns + ------- + EdgeFeatureCache + Empty cache with valid tensor shapes and neutral degree normalization. + """ + empty_long = torch.empty(0, dtype=torch.long, device=device) + empty_vec = torch.empty(0, 3, dtype=dtype, device=device) + empty_rbf = torch.empty(0, n_radial, dtype=dtype, device=device) + empty_type_feat = torch.empty(0, n_channel, dtype=dtype, device=device) + deg = torch.zeros(n_nodes, dtype=dtype, device=device) + inv_sqrt_deg = torch.ones(n_nodes, 1, 1, dtype=dtype, device=device) + return EdgeFeatureCache( + src=empty_long, + dst=empty_long, + edge_type_feat=empty_type_feat, + edge_vec=empty_vec, + edge_rbf=empty_rbf, + edge_env=torch.empty(0, 1, dtype=dtype, device=device), + deg=deg, + inv_sqrt_deg=inv_sqrt_deg, + D_full=None, + Dt_full=None, + D_to_m_cache={}, + Dt_from_m_cache={}, + edge_src_gate=None, + ) + + +def _build_standard_edge_index( + *, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + pair_keep_mask: torch.Tensor, + nall: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Flatten DeePMD valid neighbor slots into per-edge indices. + + This helper keeps the original edge semantics used by the eager standard path: + + - padding slots (``nlist == -1``) are removed + - excluded type pairs are removed + - no distance-based filtering is applied here; edges beyond ``rcut`` remain + in the cache and are later zeroed naturally by the smooth envelope + + Parameters + ---------- + nlist + DeePMD neighbor list with shape ``(nf, nloc, nnei)``. + mapping + Optional extended-to-local mapping with shape ``(nf, nall)``. + pair_keep_mask + Pair exclusion keep mask with shape ``(nf, nloc, nnei)``. + nall + Number of atoms on the extended axis per frame. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + ``(src, dst, center_coord_index, neighbor_coord_index)`` for the valid + standard-path edges. All tensors have shape ``(E,)``. + """ + nf, nloc, nnei = nlist.shape + nlist_flat = nlist.reshape(-1) + + # === Step 1. Identify valid edge slots === + # An edge is valid if: + # - it is not padding (nlist >= 0) + # - the type pair is allowed (pair_keep_mask) + # Note: We do NOT filter by distance here. Edges beyond rcut stay in the + # cache and will later get edge_env=0 from the cutoff envelope. + valid_nlist = nlist >= 0 + edge_keep = (valid_nlist & pair_keep_mask).reshape(-1) + edge_slot = torch.nonzero(edge_keep).squeeze(-1).to(dtype=torch.long) + + if edge_slot.numel() == 0: + empty = torch.empty(0, dtype=torch.long, device=nlist.device) + return empty, empty, empty, empty + + # === Step 2. Decode flat edge slots === + # edge_slot indexes the flattened (nf, nloc, nnei) axis in row-major order. + # Convert it back to: + # frame_idx in [0, nf) + # center_local in [0, nloc) + # neighbor_ext from the extended axis in [0, nall) + frame_idx = edge_slot // (nloc * nnei) + rem = edge_slot % (nloc * nnei) + center_local = rem // nnei + neighbor_ext = nlist_flat.index_select(0, edge_slot) + + if mapping is None: + # Neighbor indices are already local indices in [0, nloc). + src_local = neighbor_ext + else: + # Map extended index -> local index for each frame. + # mapping_flat packs (nf, nall), so frame k uses offset k * nall. + mapping_flat = mapping.reshape(-1) + src_local = mapping_flat.index_select(0, frame_idx * nall + neighbor_ext) + + src_ok = (src_local >= 0) & (src_local < nloc) + if not bool(src_ok.all()): + # Drop edges that map outside the local range, e.g. broken mapping + # or ghost-only neighbors. + frame_idx = frame_idx[src_ok] + center_local = center_local[src_ok] + neighbor_ext = neighbor_ext[src_ok] + src_local = src_local[src_ok] + + if src_local.numel() == 0: + empty = torch.empty(0, dtype=torch.long, device=nlist.device) + return empty, empty, empty, empty + + # === Step 3. Build node and coordinate indices === + # dst is the center atom: per-frame local index -> global node index. + # src is the neighbor atom: per-frame local index -> global node index. + # The coordinate indices still point to the extended coordinate tensor. + src = frame_idx * nloc + src_local + dst = frame_idx * nloc + center_local + center_coord_index = frame_idx * nall + center_local + neighbor_coord_index = frame_idx * nall + neighbor_ext + return src, dst, center_coord_index, neighbor_coord_index + + +def build_edge_type_feat( + type_ebed: torch.Tensor, + src: torch.Tensor, + dst: torch.Tensor, +) -> torch.Tensor: + """ + Build per-edge type features by summing src/dst embeddings. + + Parameters + ---------- + type_ebed + Per-node type embedding with shape (N, C). + src + Source node indices with shape (E,). + dst + Destination node indices with shape (E,). + + Returns + ------- + torch.Tensor + Per-edge type features with shape (E, C). + """ + # === Step 1. Normalize index dtypes === + if src.dtype != torch.long: + src = src.to(dtype=torch.long) + if dst.dtype != torch.long: + dst = dst.to(dtype=torch.long) + + # === Step 2. Sum source and destination embeddings === + return type_ebed.index_select(0, src) + type_ebed.index_select(0, dst) + + +def edge_cache_to_dtype( + cache: EdgeFeatureCache, dtype: torch.dtype +) -> EdgeFeatureCache: + """ + Convert all floating-point tensors in EdgeFeatureCache to the specified dtype. + + Integer tensors (src, dst) are unchanged. This is a standalone function + (not a method) to keep it side-effect free. + + Parameters + ---------- + cache + The edge feature cache to convert. + dtype + Target dtype for floating-point tensors. + + Returns + ------- + EdgeFeatureCache + New cache with converted tensors. + """ + # Handle Optional tensors explicitly. + # Use local variables with explicit None check and assignment. + _D_full = cache.D_full + _Dt_full = cache.Dt_full + _edge_src_gate = cache.edge_src_gate + D_full: torch.Tensor | None = None + Dt_full: torch.Tensor | None = None + edge_src_gate: torch.Tensor | None = None + if _D_full is not None: + D_full = _D_full.to(dtype=dtype) + if _Dt_full is not None: + Dt_full = _Dt_full.to(dtype=dtype) + if _edge_src_gate is not None: + edge_src_gate = _edge_src_gate.to(dtype=dtype) + + return EdgeFeatureCache( + src=cache.src, + dst=cache.dst, + edge_type_feat=cache.edge_type_feat.to(dtype=dtype), + edge_vec=cache.edge_vec.to(dtype=dtype), + edge_rbf=cache.edge_rbf.to(dtype=dtype), + edge_env=cache.edge_env.to(dtype=dtype), + deg=cache.deg.to(dtype=dtype), + inv_sqrt_deg=cache.inv_sqrt_deg.to(dtype=dtype), + D_full=D_full, + Dt_full=Dt_full, + D_to_m_cache=None if cache.D_to_m_cache is None else {}, + Dt_from_m_cache=None if cache.Dt_from_m_cache is None else {}, + edge_src_gate=edge_src_gate, + ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/embedding.py b/deepmd/pt/model/descriptor/sezm_nn/embedding.py new file mode 100644 index 0000000000..5e6862cfe3 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/embedding.py @@ -0,0 +1,698 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Embedding layers for the SeZM descriptor. + +This module defines the type embedding, geometric initial embedding, and +environment-seed embedding used to initialize SeZM node features. +""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + get_generator, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .indexing import ( + get_so3_dim_of_lmax, + map_degree_idx, +) +from .utils import ( + np_safe, + safe_numpy_to_tensor, +) + +if TYPE_CHECKING: + from .edge_cache import ( + EdgeFeatureCache, + ) + + +class SeZMTypeEmbedding(nn.Module): + """ + Minimal SeZM type embedding with Adam-routed parameter naming. + + Parameters + ---------- + ntypes + Number of atom types. + embed_dim + Embedding dimension. + dtype + Parameter dtype. + seed + Random seed for initialization. + trainable + Whether parameters are trainable. + padding + Whether to append one all-zero padding row. + + Notes + ----- + The parameter is named with ``adam_`` prefix so HybridMuon routes it to Adam. + """ + + def __init__( + self, + *, + ntypes: int, + embed_dim: int, + dtype: torch.dtype, + seed: int | list[int] | None = None, + trainable: bool, + padding: bool = True, + ) -> None: + super().__init__() + self.ntypes = int(ntypes) + self.embed_dim = int(embed_dim) + self.dtype = dtype + self.seed = seed + self.device = env.DEVICE + self.padding = bool(padding) + if self.ntypes <= 0: + raise ValueError("`ntypes` must be positive") + if self.embed_dim <= 0: + raise ValueError("`embed_dim` must be positive") + + # === Step 1. Build embedding table parameter === + n_rows = self.ntypes + int(self.padding) + self.adam_type_embedding = nn.Parameter( + torch.empty( + n_rows, + self.embed_dim, + device=self.device, + dtype=self.dtype, + ) + ) + + # === Step 2. Initialize active type rows with default normal scale === + init_std = 1.0 / math.sqrt(float(self.ntypes + self.embed_dim)) + nn.init.normal_( + self.adam_type_embedding[: self.ntypes], + mean=0.0, + std=init_std, + generator=get_generator(child_seed(seed, 0)), + ) + if self.padding: + with torch.no_grad(): + self.adam_type_embedding[self.ntypes].zero_() + + for p in self.parameters(): + p.requires_grad = trainable + + def forward(self, atype: torch.Tensor) -> torch.Tensor: + """ + Gather type embeddings. + + Parameters + ---------- + atype + Atom types with shape (...,). Valid type range is [0, ntypes-1]. + + Returns + ------- + torch.Tensor + Type embeddings with shape (..., embed_dim). + """ + return torch.embedding(self.adam_type_embedding, atype) + + +class GeometricInitialEmbedding(nn.Module): + """ + Geometric initial embedding that adds zonal (m=0) rotated features. + + This module rotates pre-computed radial features for each degree l >= 1 using the + zonal (m=0) column of the cached inverse Wigner-D blocks (local->global). + The l=0 component is not computed here since it comes from type embedding. + + Parameters + ---------- + lmax + Maximum degree, should match ``l_schedule[0]``. + channels + Number of channels per (l, m) coefficient. + dtype + Parameter dtype. + """ + + def __init__( + self, + *, + lmax: int, + channels: int, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.channels = int(channels) + self.ebed_dim = get_so3_dim_of_lmax(self.lmax) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + if self.lmax > 0: + packed_degree_by_row = map_degree_idx(self.lmax, device=self.device) + # These aligned arrays describe one packed non-scalar row at a time. + # non_scalar_row_index[k] picks the output row in the packed SO(3) layout. + # zonal_m0_col_index_for_row[k] picks the matching m=0 column in Dt_full. + # radial_slot_index_for_row[k] picks the matching degree slot in radial_feat. + non_scalar_row_index = torch.arange( + 1, self.ebed_dim, device=self.device, dtype=torch.long + ) + non_scalar_degree_by_row = packed_degree_by_row[1:] + zonal_m0_col_index_for_row = non_scalar_degree_by_row * ( + non_scalar_degree_by_row + 1 + ) + radial_slot_index_for_row = non_scalar_degree_by_row - 1 + self.register_buffer( + "non_scalar_row_index", non_scalar_row_index, persistent=True + ) + self.register_buffer( + "zonal_m0_col_index_for_row", + zonal_m0_col_index_for_row, + persistent=True, + ) + self.register_buffer( + "radial_slot_index_for_row", + radial_slot_index_for_row, + persistent=True, + ) + else: + self.register_buffer( + "non_scalar_row_index", + torch.empty(0, device=self.device, dtype=torch.long), + persistent=True, + ) + self.register_buffer( + "zonal_m0_col_index_for_row", + torch.empty(0, device=self.device, dtype=torch.long), + persistent=True, + ) + self.register_buffer( + "radial_slot_index_for_row", + torch.empty(0, device=self.device, dtype=torch.long), + persistent=True, + ) + + def forward( + self, + *, + n_nodes: int, + edge_cache: EdgeFeatureCache, + radial_feat: torch.Tensor, + ) -> torch.Tensor: + """ + Parameters + ---------- + n_nodes + Number of nodes (nf*nloc). + edge_cache + Per-edge cache containing geometry, weights, and Wigner-D blocks. + radial_feat + Per-edge radial features with shape (E, lmax, C) for l=1..lmax. + + Returns + ------- + torch.Tensor + Initial features to add with shape (N, D, C). l=0 is guaranteed zero. + """ + # === Step 1. Initialize output === + device = edge_cache.edge_vec.device + dtype = edge_cache.edge_vec.dtype + out = torch.zeros( + n_nodes, self.ebed_dim, self.channels, device=device, dtype=dtype + ) # (N, D, C) + if self.lmax == 0: + return out + + # === Step 2. Gather all m=0 columns (l >= 1) in one shot === + # Advanced indexing pairs one packed non-scalar row with the zonal m=0 column + # from the same degree block in Dt_full. + Dt_full = edge_cache.Dt_full # (E, D, D) + zonal_m0_value_for_row = Dt_full[ + :, + self.non_scalar_row_index, + self.zonal_m0_col_index_for_row, + ] # (E, D-1) + + # === Step 3. Broadcast radial features per row === + # Each non-scalar packed row reuses the radial feature of its degree l. + radial_value_for_row = radial_feat.index_select( + 1, self.radial_slot_index_for_row + ) # (E, D-1, C) + non_scalar_message = ( + zonal_m0_value_for_row.unsqueeze(-1) * radial_value_for_row + ) # (E, D-1, C) + + # === Step 4. Source Freeze Propagation Gate (optional) === + # Mute messages emitted by nodes whose local neighborhood enters + # the frozen zone. ``edge_src_gate`` is ``None`` outside bridging + # mode so this is a no-op in normal training. + src_gate = edge_cache.edge_src_gate + if src_gate is not None: + non_scalar_message = non_scalar_message * src_gate.to( + dtype=non_scalar_message.dtype + ).unsqueeze(-1) + + # === Step 5. Scatter to nodes and normalize === + # Avoid advanced-index writeback (out[:, non_scalar_row_index, :]) which produces a copy. + non_scalar_out = out.new_zeros( + n_nodes, self.non_scalar_row_index.numel(), self.channels + ) # (N, D-1, C) + non_scalar_out.index_add_(0, edge_cache.dst, non_scalar_message) + out[:, self.non_scalar_row_index, :] = non_scalar_out + out.mul_(edge_cache.inv_sqrt_deg) + return out + + def serialize(self) -> dict[str, Any]: + return { + "@class": "GeometricInitialEmbedding", + "@version": 1, + "lmax": self.lmax, + "channels": self.channels, + "precision": RESERVED_PRECISION_DICT[self.dtype], + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> GeometricInitialEmbedding: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "GeometricInitialEmbedding": + raise ValueError(f"Invalid class for GeometricInitialEmbedding: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + precision = data.pop("precision") + data["dtype"] = PRECISION_DICT[precision] + return cls(**data) + + +class EnvironmentInitialEmbedding(nn.Module): + """ + Environment matrix initial embedding for l=0 features. + + Computes an initial embedding based on the 4D environment matrix:: + + [s, s * rx, s * ry, s * rz] + + Combined with independent type embeddings (individual type embedding), + providing physical inductive bias for l=0 features. + + The computation follows the environment matrix approach where:: + + 1. Build `r_tilde = [s, s*r_hat]` where `s = edge_env / r` and `r_hat = edge_vec / r` + 2. G network: `g = G(rbf_proj(edge_rbf), type_src, type_dst)` produces per-edge features + - Uses independent `env_type_embed` instead of projecting from main type embedding + - Uses `rbf_proj` to project edge_rbf to `rbf_out_dim` + 3. env_agg: aggregate outer product `r_tilde ⊗ g` by destination node + 4. D matrix: `D = env_agg^T @ env_agg[:, :, :axis_dim]` + 5. Output: projection of flattened D matrix into FiLM logits + + Parameters + ---------- + ntypes : int + Number of atom types. + n_radial : int + Number of radial basis functions. + channels : int + Output channel dimension per FiLM branch (final output is 2*channels). + embed_dim : int + G network output dimension (filter width). + axis_dim : int + D matrix axis dimension (must be < embed_dim). + type_dim : int + Dimension for independent type embeddings in env_seed. + hidden_dim : int + Hidden layer size for G network. + mlp_bias : bool + Whether to enable bias terms in env-seed MLP layers + (`rbf_proj_layer1/2` and `g_layer1/2`). + activation_function : str + Activation function for G network hidden layer. + eps : float + Small epsilon for numerical stability. + dtype : torch.dtype + Parameter dtype. + trainable : bool + Whether parameters are trainable. + seed : int | list[int] | None + Random seed for reproducibility. + """ + + def __init__( + self, + *, + ntypes: int, + n_radial: int, + channels: int, + embed_dim: int = 64, + axis_dim: int = 8, + type_dim: int = 16, + hidden_dim: int = 64, + mlp_bias: bool = False, + activation_function: str = "silu", + eps: float = 1e-7, + dtype: torch.dtype, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + + # === Validate parameters === + if axis_dim >= embed_dim: + raise ValueError( + f"`axis_dim` ({axis_dim}) must be < `embed_dim` ({embed_dim})" + ) + + self.ntypes = int(ntypes) + self.n_radial = int(n_radial) + self.channels = int(channels) + self.embed_dim = int(embed_dim) + self.axis_dim = int(axis_dim) + self.type_dim = int(type_dim) + self.hidden_dim = int(hidden_dim) + self.mlp_bias = bool(mlp_bias) + self.activation_function = str(activation_function) + self.eps = float(eps) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + self.register_buffer( + "eps_tensor", + torch.tensor(self.eps, dtype=self.dtype, device=self.device), + persistent=False, + ) + self.register_buffer( + "eps_sq_tensor", + torch.tensor(self.eps * self.eps, dtype=self.dtype, device=self.device), + persistent=False, + ) + + # === RBF projection: n_radial -> rbf_out_dim (two-layer MLP) === + # rbf_out_dim = max(32, embed_dim - 2*type_dim) to align G-network width to embed_dim + # First layer: n_radial -> rbf_out_dim with activation + # Second layer: rbf_out_dim -> rbf_out_dim linear + self.rbf_out_dim = max(32, self.embed_dim - 2 * self.type_dim) + seed_rbf_proj = child_seed(seed, 0) + self.rbf_proj_layer1 = MLPLayer( + self.n_radial, + self.rbf_out_dim, + bias=self.mlp_bias, + activation_function=self.activation_function, + precision=self.precision, + seed=child_seed(seed_rbf_proj, 0), + ) + self.rbf_proj_layer2 = MLPLayer( + self.rbf_out_dim, + self.rbf_out_dim, + bias=self.mlp_bias, + activation_function=None, + precision=self.precision, + seed=child_seed(seed_rbf_proj, 1), + ) + + # === Independent type embedding: ntypes -> type_dim === + # Individual type embedding + seed_type_embed = child_seed(seed, 1) + self.env_type_embed = SeZMTypeEmbedding( + ntypes=self.ntypes, + embed_dim=self.type_dim, + dtype=self.dtype, + seed=seed_type_embed, + trainable=trainable, + ) + + # === G network: (rbf_out_dim + 2*type_dim) -> hidden_dim -> embed_dim === + seed_g_net = child_seed(seed, 2) + g_in_dim = self.rbf_out_dim + 2 * self.type_dim + self.g_layer1 = MLPLayer( + g_in_dim, + self.hidden_dim, + bias=self.mlp_bias, + activation_function=self.activation_function, + precision=self.precision, + seed=child_seed(seed_g_net, 0), + ) + self.g_layer2 = MLPLayer( + self.hidden_dim, + self.embed_dim, + bias=self.mlp_bias, + activation_function=None, + precision=self.precision, + seed=child_seed(seed_g_net, 1), + ) + + # === Output projection: embed_dim * axis_dim -> 2*channels === + # Zero init so FiLM logits start at zero; strengths control magnitude. + seed_out = child_seed(seed, 3) + self.output_proj = MLPLayer( + self.embed_dim * self.axis_dim, + 2 * self.channels, + bias=False, + activation_function=None, + init="final", + precision=self.precision, + seed=seed_out, + ) + + for p in self.parameters(): + p.requires_grad = trainable + + def forward( + self, + *, + edge_cache: EdgeFeatureCache, + atype_flat: torch.Tensor, + n_nodes: int, + ) -> torch.Tensor: + """ + Compute environment FiLM logits for l=0 conditioning. + + Parameters + ---------- + edge_cache : EdgeFeatureCache + Edge cache containing src, dst, edge_vec, edge_rbf, edge_env. + atype_flat : torch.Tensor + Flattened atom types with shape (N,), where N = nf * nloc. + n_nodes : int + Number of nodes (N = nf * nloc). + + Returns + ------- + torch.Tensor + FiLM logits with shape (N, 2*channels). + """ + src, dst = edge_cache.src, edge_cache.dst + edge_vec = edge_cache.edge_vec # (E, 3) + edge_rbf = edge_cache.edge_rbf # (E, n_radial) + edge_env = edge_cache.edge_env # (E, 1) + + # === Step 1. Construct r_tilde = [s, s*r_hat] === + # s = edge_env * (1/r), r_hat = edge_vec / r + r_sq = (edge_vec * edge_vec).sum(dim=-1, keepdim=True) # (E, 1) + inv_r = torch.rsqrt(r_sq + self.eps_sq_tensor) # (E, 1) + s = edge_env * inv_r # (E, 1) + r_hat = edge_vec * inv_r # (E, 3) + r_tilde = torch.cat([s, s * r_hat], dim=-1) # (E, 4) + + # === Step 2. Compute G network input and output === + # Use independent type embeddings (decoupled from main type embedding) + atype_src = atype_flat.index_select(0, src) # (E,) + atype_dst = atype_flat.index_select(0, dst) # (E,) + type_src = self.env_type_embed(atype_src) # (E, type_dim) + type_dst = self.env_type_embed(atype_dst) # (E, type_dim) + + # Project edge_rbf to rbf_out_dim (two-layer MLP) + rbf_proj = self.rbf_proj_layer2( + self.rbf_proj_layer1(edge_rbf) + ) # (E, rbf_out_dim) + + # G network input: concat projected RBF and type embeddings + g_input = torch.cat([rbf_proj, type_src, type_dst], dim=-1) # (E, g_in_dim) + g = self.g_layer2(self.g_layer1(g_input)) # (E, embed_dim) + + # === Step 3. Aggregate outer product by destination node === + # outer = r_tilde[:, :, None] * g[:, None, :] # (E, 4, embed_dim) + outer = torch.einsum("ei,ej->eij", r_tilde, g) # (E, 4, embed_dim) + outer_flat = outer.reshape(-1, 4 * self.embed_dim) # (E, 4*embed_dim) + # Source Freeze Propagation Gate: mute the outer-product contribution + # of any edge whose source node has a neighbor in the frozen zone. + src_gate = edge_cache.edge_src_gate + if src_gate is not None: + outer_flat = outer_flat * src_gate.to(dtype=outer_flat.dtype) + env_agg = outer_flat.new_zeros(n_nodes, 4 * self.embed_dim) # (N, 4*embed_dim) + env_agg.index_add_(0, dst, outer_flat) + env_agg = env_agg.reshape(n_nodes, 4, self.embed_dim) # (N, 4, embed_dim) + + # === Step 4. Smooth normalization by envelope-squared degree === + deg_scale = torch.rsqrt(edge_cache.deg + self.eps_tensor).reshape( + -1, 1, 1 + ) # (N, 1, 1) + env_agg = env_agg * deg_scale + + # === Step 5. D matrix construction: D = env_agg^T @ env_agg[:,:,:axis_dim] === + env_agg_t = env_agg.permute(0, 2, 1) # (N, embed_dim, 4) + env_agg_axis = env_agg[:, :, : self.axis_dim] # (N, 4, axis_dim) + D = torch.bmm(env_agg_t, env_agg_axis) # (N, embed_dim, axis_dim) + + # === Step 6. Output projection for FiLM logits === + D_flat = D.reshape( + n_nodes, self.embed_dim * self.axis_dim + ) # (N, embed_dim*axis_dim) + return self.output_proj(D_flat) + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "EnvironmentInitialEmbedding", + "@version": 1, + "config": { + "ntypes": self.ntypes, + "n_radial": self.n_radial, + "channels": self.channels, + "embed_dim": self.embed_dim, + "axis_dim": self.axis_dim, + "type_dim": self.type_dim, + "hidden_dim": self.hidden_dim, + "mlp_bias": self.mlp_bias, + "activation_function": self.activation_function, + "eps": self.eps, + "precision": self.precision, + "trainable": trainable, + "seed": None, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> EnvironmentInitialEmbedding: + """Deserialize from dictionary.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "EnvironmentInitialEmbedding": + raise ValueError(f"Invalid class: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj + + +class ChargeSpinEmbedding(nn.Module): + """ + Frame-level charge and spin embedding for scalar type features. + + Parameters + ---------- + embed_dim + Embedding dimension. + activation_function + Activation function used by the mixing layer. + dtype + Parameter dtype. + seed + Random seed for initialization. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + embed_dim: int, + activation_function: str, + dtype: torch.dtype, + seed: int | list[int] | None = None, + trainable: bool, + ) -> None: + super().__init__() + self.embed_dim = int(embed_dim) + self.activation_function = str(activation_function) + self.dtype = dtype + self.precision = RESERVED_PRECISION_DICT[dtype] + if self.embed_dim <= 0: + raise ValueError("`embed_dim` must be positive") + + self.charge_embedding = SeZMTypeEmbedding( + ntypes=200, + embed_dim=self.embed_dim, + dtype=self.dtype, + seed=child_seed(seed, 0), + trainable=trainable, + padding=False, + ) + self.spin_embedding = SeZMTypeEmbedding( + ntypes=100, + embed_dim=self.embed_dim, + dtype=self.dtype, + seed=child_seed(seed, 1), + trainable=trainable, + padding=False, + ) + self.mix_layer = MLPLayer( + 2 * self.embed_dim, + self.embed_dim, + activation_function=self.activation_function, + precision=self.precision, + seed=child_seed(seed, 2), + trainable=trainable, + ) + + for p in self.parameters(): + p.requires_grad = trainable + + def forward(self, charge_spin: torch.Tensor) -> torch.Tensor: + """ + Embed frame-level charge and spin. + + Parameters + ---------- + charge_spin + Frame charge and spin values with shape (nf, 2). + + Returns + ------- + torch.Tensor + Mixed condition embedding with shape (nf, embed_dim). + """ + charge = charge_spin[:, 0].to(dtype=torch.int64) + 100 + spin = charge_spin[:, 1].to(dtype=torch.int64) + charge_embed = self.charge_embedding(charge) + spin_embed = self.spin_embedding(spin) + return self.mix_layer(torch.cat((charge_embed, spin_embed), dim=-1)) diff --git a/deepmd/pt/model/descriptor/sezm_nn/ffn.py b/deepmd/pt/model/descriptor/sezm_nn/ffn.py new file mode 100644 index 0000000000..0e06162bf5 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/ffn.py @@ -0,0 +1,406 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Equivariant feed-forward layers for SeZM. + +This module defines the full SO(3)-equivariant feed-forward network used +inside SeZM interaction blocks and the descriptor output head. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + Any, +) + +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .activation import ( + GatedActivation, + S2GridProjector, + SwiGLU, + SwiGLUS2Activation, + resolve_s2_grid_resolution, +) +from .so3 import ( + ChannelLinear, + SO3Linear, +) +from .utils import ( + get_promoted_dtype, + np_safe, + safe_numpy_to_tensor, +) + + +class EquivariantFFN(nn.Module): + """ + Full equivariant FFN operating on all spherical harmonic degrees. + + Default structure (glu_activation=False): + SO3 linear (in -> hidden) -> GatedActivation -> SO3 linear (hidden -> out) + + Default structure (glu_activation=True): + SO3 linear (in -> 2*hidden) -> split -> GatedActivation(val, gate) -> SO3 linear (hidden -> out) + + Optional grid-FFN structure (grid_mlp=True): + SO3 linear (in -> hidden) + -> project packed SO(3) coefficients to the S2 grid + -> packed S2-grid point-wise MLP on hidden features + -> project grid features back to packed SO(3) coefficients + -> add scalar LinearSwiGLU branch to l=0 + -> SO3 linear (hidden -> out) + + GatedActivation serves as the unified "activation" for equivariant networks, + analogous to SiLU in standard MLPs, but respecting SO(3) equivariance: + - l=0: Uses the specified activation function (or GLU variant when glu_activation=True) + - l>0: sigmoid gate from l=0 scalar features + + When glu_activation=True, the first linear outputs 2*hidden_channels, then splits into + value and gate branches. This transforms activations like silu->swiglu, gelu->geglu. + The split approach is more efficient than two separate linear layers. + + Parameters + ---------- + lmax + Maximum degree. + channels + Number of channels per (l, m) coefficient. + hidden_channels + Hidden dimension for the FFN. + grid_mlp + If True, use the optional grid-MLP FFN structure on the block-internal + FFN path. This path takes precedence over the simpler activation-only + path inside this module. + dtype + Parameter dtype. + s2_activation + If True and ``grid_mlp=False``, replace the default GatedActivation path + with the merged scalar/grid SwiGLU-S2 activation. + lebedev_quadrature + If True, use Lebedev quadrature for the S2 projector in this FFN. + activation_function + Activation function for l=0 components (e.g., "silu", "tanh", "gelu"). + glu_activation + If True, use GLU-style gating (e.g., silu -> swiglu, gelu -> geglu). + mlp_bias + Whether to use bias in SO3Linear (l=0 bias), GatedActivation + (gate linear bias), and the scalar point-wise projection when + ``grid_mlp=True``. + trainable + Whether parameters are trainable. + seed + Random seed for weight initialization. + """ + + def __init__( + self, + *, + lmax: int, + channels: int, + hidden_channels: int, + grid_mlp: bool = False, + dtype: torch.dtype, + s2_activation: bool = False, + lebedev_quadrature: bool = False, + activation_function: str = "silu", + glu_activation: bool = True, + mlp_bias: bool = False, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.channels = int(channels) + self.hidden_channels = int(hidden_channels) + self.use_grid_mlp = bool(grid_mlp) + self.s2_activation = bool(s2_activation) + self.lebedev_quadrature = bool(lebedev_quadrature) + self.s2_grid_method = "lebedev" if self.lebedev_quadrature else "e3nn" + base_grid = resolve_s2_grid_resolution( + self.lmax, + self.lmax, + method=self.s2_grid_method, + ) + self.s2_grid_resolution = ( + [max(base_grid), max(base_grid)] + if self.s2_grid_method == "e3nn" + else base_grid + ) + self.activation_function = activation_function + self.glu_activation = bool(glu_activation) + self.mlp_bias = bool(mlp_bias) + self.dtype = dtype + self.compute_dtype = get_promoted_dtype(self.dtype) + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + + # === Step 0. Split deterministic seeds at the module top-level === + seed_so3_in = child_seed(seed, 0) + seed_act = child_seed(seed, 1) + seed_so3_out = child_seed(seed, 2) + + # === First SO3Linear for channel mixing === + # Grid-FFN keeps the hidden width and performs the nonlinear expansion + # inside the scalar/grid point-wise MLPs. + linear1_out_channels = self.hidden_channels + if not self.use_grid_mlp: + linear1_out_channels = ( + 2 * self.hidden_channels + if self.glu_activation + else self.hidden_channels + ) + self.so3_linear_1 = SO3Linear( + lmax=self.lmax, + in_channels=self.channels, + out_channels=linear1_out_channels, + n_focus=1, + dtype=dtype, + mlp_bias=self.mlp_bias, + trainable=trainable, + seed=seed_so3_in, + ) + + # === Equivariant nonlinearity path === + self.scalar_mlp: nn.Module | None = None + self.grid_projector: S2GridProjector | None = None + self.pointwise_grid_mlp: nn.Module | None = None + if self.use_grid_mlp: + self.scalar_mlp = nn.Sequential( + ChannelLinear( + in_channels=self.channels, + out_channels=2 * self.hidden_channels, + dtype=dtype, + bias=self.mlp_bias, + trainable=trainable, + seed=child_seed(seed_act, 0), + ), + SwiGLU(), + ) + self.grid_projector = S2GridProjector( + lmax=self.lmax, + mmax=self.lmax, + dtype=dtype, + grid_resolution_list=self.s2_grid_resolution, + coefficient_layout="packed", + grid_method=self.s2_grid_method, + ) + self.pointwise_grid_mlp = PointwiseGridMLP( + channels=self.hidden_channels, + dtype=dtype, + trainable=trainable, + seed=child_seed(seed_act, 1), + ) + self.act = nn.Identity() + elif self.s2_activation: + self.act = SwiGLUS2Activation( + lmax=self.lmax, + channels=self.hidden_channels, + dtype=self.compute_dtype, + n_focus=1, + layout="ndfc", + grid_resolution_list=self.s2_grid_resolution, + coefficient_layout="packed", + grid_method=self.s2_grid_method, + mlp_bias=self.mlp_bias, + trainable=trainable, + seed=seed_act, + ) + else: + self.act = GatedActivation( + lmax=self.lmax, + channels=self.hidden_channels, + dtype=self.compute_dtype, + activation_function=activation_function, + mlp_bias=self.mlp_bias, + layout="ndfc", + trainable=trainable, + seed=seed_act, + ) + + # === Second SO3Linear for channel mixing === + # Zero-initialized so residual path starts near-identity. + self.so3_linear_2 = SO3Linear( + lmax=self.lmax, + in_channels=self.hidden_channels, + out_channels=self.channels, + n_focus=1, + dtype=dtype, + mlp_bias=self.mlp_bias, + trainable=trainable, + seed=seed_so3_out, + init_std=0.0, + ) + + for p in self.parameters(): + p.requires_grad = trainable + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input with shape (N, D, F, C) where D=(lmax+1)^2. + + Returns + ------- + torch.Tensor + Output with shape (N, D, F, C). + """ + # === Step 1. Input up projection === + x_input = x + x = self.so3_linear_1(x) + + # === Step 2. Equivariant nonlinearity === + if self.use_grid_mlp: + scalar_outputs = self.scalar_mlp(x_input.select(dim=1, index=0)) + x_flat, shape_info = self._flatten_grid_inputs(x) + x_grid = self.grid_projector.to_grid(x_flat.to(dtype=self.dtype)) + x_grid = self.pointwise_grid_mlp(x_grid) + x = self._restore_grid_outputs( + self.grid_projector.from_grid(x_grid), shape_info + ) + x[:, 0, :, :].add_(scalar_outputs) + elif self.s2_activation: + x = self.act(x) + elif self.glu_activation: + # Split into value and gate branches along channel dimension + x_val, x_gate = x.chunk(2, dim=-1) + # Pass gate to GatedActivation for GLU-style gating + x = self.act(x_val, gate=x_gate) + else: + x = self.act(x) + + # === Step 3. Per-degree output projection === + x = self.so3_linear_2(x) + + return x + + def _flatten_grid_inputs( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, tuple[int, int, int]]: + n_batch, coeff_dim, n_focus, _ = x.shape + return ( + x.permute(0, 2, 1, 3).reshape(n_batch * n_focus, coeff_dim, x.shape[-1]), + (n_batch, coeff_dim, n_focus), + ) + + def _restore_grid_outputs( + self, x: torch.Tensor, shape_info: tuple[int, int, int] + ) -> torch.Tensor: + n_batch, coeff_dim, n_focus = shape_info + return x.reshape(n_batch, n_focus, coeff_dim, self.hidden_channels).permute( + 0, 2, 1, 3 + ) + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "EquivariantFFN", + "@version": 1, + "config": { + "lmax": self.lmax, + "channels": self.channels, + "hidden_channels": self.hidden_channels, + "grid_mlp": self.use_grid_mlp, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "s2_activation": self.s2_activation, + "lebedev_quadrature": self.lebedev_quadrature, + "activation_function": self.activation_function, + "glu_activation": self.glu_activation, + "mlp_bias": self.mlp_bias, + "trainable": trainable, + "seed": None, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> EquivariantFFN: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "EquivariantFFN": + raise ValueError(f"Invalid class for EquivariantFFN: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj + + +class PointwiseGridMLP(nn.Module): + """ + Apply a two-layer point-wise MLP on flattened S2 grid features. + + Parameters + ---------- + channels + Hidden feature dimension on the grid. + dtype + Parameter dtype. + trainable + Whether parameters are trainable. + seed + Random seed for weight initialization. + """ + + def __init__( + self, + *, + channels: int, + dtype: torch.dtype, + trainable: bool, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + self.channels = int(channels) + self.linear_1 = ChannelLinear( + in_channels=self.channels, + out_channels=2 * self.channels, + dtype=dtype, + bias=False, + trainable=trainable, + seed=child_seed(seed, 0), + ) + self.act = SwiGLU() + self.linear_2 = ChannelLinear( + in_channels=self.channels, + out_channels=self.channels, + dtype=dtype, + bias=False, + trainable=trainable, + seed=child_seed(seed, 1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the point-wise grid MLP.""" + x = self.act(self.linear_1(x)) + return self.linear_2(x) diff --git a/deepmd/pt/model/descriptor/sezm_nn/indexing.py b/deepmd/pt/model/descriptor/sezm_nn/indexing.py new file mode 100644 index 0000000000..e550c9053b --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/indexing.py @@ -0,0 +1,423 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +SO(3) packed-index and projection helpers for SeZM. + +This module defines the packed `(l, m)` indexing helpers and the projection +utilities used by the SeZM equivariant operators. +""" + +from __future__ import ( + annotations, +) + +import torch + + +def get_so3_dim_of_lmax(lmax: int) -> int: + """ + Return SO(3) representation dimension for given lmax. + + The dimension equals:: + + sum_{l<=lmax} (2l+1) = (lmax+1)^2 + + which is the number of spherical harmonics basis functions. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + + Returns + ------- + int + The SO(3) dimension D = (lmax+1)^2. + """ + return int((int(lmax) + 1) ** 2) + + +def map_degree_idx(lmax: int, *, device: torch.device) -> torch.Tensor: + """ + Build degree (l) index for each position in the packed (l, m) layout. + + For each spherical harmonic coefficient position in the packed tensor, + returns the corresponding angular momentum quantum number l. + + Examples + -------- + For lmax=2, the packed layout has D=9 positions: + - Position 0: l=0, m=0 + - Positions 1-3: l=1, m=-1,0,+1 + - Positions 4-8: l=2, m=-2,-1,0,+1,+2 + + Returns: [0, 1,1,1, 2,2,2,2,2] + + Parameters + ---------- + lmax + Maximum angular momentum degree. + device + Device for the returned tensor. + + Returns + ------- + torch.Tensor + Integer tensor with shape (D,), where D=(lmax+1)^2. + Each element is the l value for that position. + """ + lmax = int(lmax) + counts = torch.tensor( + [2 * degree + 1 for degree in range(lmax + 1)], + device=device, + dtype=torch.long, + ) + return torch.repeat_interleave( + torch.arange(lmax + 1, device=device, dtype=torch.long), counts + ) + + +def project_D_to_m( + D_full: torch.Tensor, + coeff_index_m: torch.Tensor, + ebed_dim_full: int, + cache: dict[str, torch.Tensor] | None, + key_lmax: int, + key_mmax: int, +) -> torch.Tensor: + """ + Row-project block-diagonal Wigner-D to the m-major truncated layout. + + Parameters + ---------- + D_full + Block-diagonal Wigner-D with shape (E, D, D). + coeff_index_m + Indices for m-major reduced layout with shape (D_m_trunc,). + ebed_dim_full + Full SO(3) dimension D_full = (lmax+1)^2 to slice the block. + cache + Optional cache mapping (lmax, mmax) -> projected matrix. + key_lmax + lmax used to build coeff_index_m (cache key). + key_mmax + mmax used to build coeff_index_m (cache key). + + Returns + ------- + torch.Tensor + Projected rotation matrix with shape (E, D_m_trunc, D). + + Examples + -------- + For lmax=2, mmax=1 (D=9, D_m_trunc=7), coeff_index_m selects + [0,2,6,1,5,3,7] in packed (l,m) order. The returned tensor keeps only those + rows of ``D_full`` while retaining all columns, so that rotating and truncating + is done in a single bmm: ``x_local = D_to_m @ x_global``. + """ + cache_key = f"{int(key_lmax)}:{int(key_mmax)}" + if cache is not None: + cached = cache.get(cache_key) + if cached is not None: + return cached + + D_block = D_full[:, :ebed_dim_full, :ebed_dim_full] + proj = D_block.index_select(1, coeff_index_m) + if cache is not None: + cache[cache_key] = proj + return proj + + +def project_Dt_from_m( + Dt_full: torch.Tensor, + coeff_index_m: torch.Tensor, + ebed_dim_full: int, + cache: dict[str, torch.Tensor] | None, + key_lmax: int, + key_mmax: int, +) -> torch.Tensor: + """ + Column-project block-diagonal Wigner-D^T for inverse rotation. + + Parameters + ---------- + Dt_full + Block-diagonal Wigner-D^T with shape (E, D, D). + coeff_index_m + Indices for m-major reduced layout with shape (D_m_trunc,). + ebed_dim_full + Full SO(3) dimension D_full = (lmax+1)^2 to slice the block. + cache + Optional cache mapping (lmax, mmax) -> projected matrix. + key_lmax + lmax used to build coeff_index_m (cache key). + key_mmax + mmax used to build coeff_index_m (cache key). + + Returns + ------- + torch.Tensor + Projected inverse rotation matrix with shape (E, D, D_m_trunc). + + Examples + -------- + Continuing lmax=2, mmax=1, the projection selects the same column subset + [0,2,6,1,5,3,7] from ``Dt_full``. This enables inverse rotation with missing + coefficients implicitly zeroed: ``x_global = Dt_from_m @ x_local``. + """ + cache_key = f"{int(key_lmax)}:{int(key_mmax)}" + if cache is not None: + cached = cache.get(cache_key) + if cached is not None: + return cached + + Dt_block = Dt_full[:, :ebed_dim_full, :ebed_dim_full] + proj = Dt_block.index_select(2, coeff_index_m) + if cache is not None: + cache[cache_key] = proj + return proj + + +def so3_packed_index(degree: int, m: int) -> int: + """ + Compute packed (l, m) index for real spherical harmonics layout. + + The packed layout is l-primary with m ordered as ``-l..+l`` inside each l-block. + The index formula is:: + + idx(l, m) = l^2 + l + m + + Parameters + ---------- + degree + Degree l. + m + Order m, must satisfy ``-l <= m <= l``. + + Returns + ------- + int + Packed index. + """ + degree = int(degree) + m = int(m) + return degree * degree + degree + m + + +def build_l_major_index(lmax: int, mmax: int, *, device: torch.device) -> torch.Tensor: + """ + Build coefficient indices for l-major layout truncated by mmax. + + The returned indices select coefficients with ``|m| <= min(mmax, l)`` in the + standard packed (l, m) layout. The order is l-major: + + - l = 0..lmax + - within each l, m = -min(mmax, l) .. +min(mmax, l) + + Parameters + ---------- + lmax + Maximum degree. + mmax + Maximum order (|m|). Must satisfy ``0 <= mmax <= lmax``. + device + Device for the returned tensor. + + Returns + ------- + torch.Tensor + Long tensor of indices with shape (D_m_trunc,), selecting coefficients + from the full packed layout with D=(lmax+1)^2, where D_m_trunc is + the number of coefficients kept under ``|m| <= min(mmax, l)``. + + Examples + -------- + For lmax=2, mmax=1: + - Full packed layout: l=0(0), l=1(1-3), l=2(4-8) + - Truncated by mmax=1: skip (l=2, m=±2) at indices 4,8 + - Returns: [0, 1, 2, 3, 5, 6, 7] + """ + lmax_i = int(lmax) + mmax_i = int(mmax) + if lmax_i < 0: + raise ValueError("`lmax` must be non-negative") + if mmax_i < 0: + raise ValueError("`mmax` must be non-negative") + if mmax_i > lmax_i: + raise ValueError("`mmax` must be <= `lmax`") + + indices: list[int] = [] + for degree in range(lmax_i + 1): + m_keep = min(mmax_i, degree) + for m in range(-m_keep, m_keep + 1): + indices.append(so3_packed_index(degree, m)) + return torch.tensor(indices, device=device, dtype=torch.long) + + +def build_m_major_index(lmax: int, mmax: int, *, device: torch.device) -> torch.Tensor: + """ + Build coefficient indices for m-major layout truncated by mmax. + + This layout minimizes rotation cost and avoids gather-heavy indexing: + + - m = 0: l = 0..lmax (single coefficient per l) + - for each m = 1..mmax: + - negative part: l = m..lmax, coefficient (l, -m) + - positive part: l = m..lmax, coefficient (l, +m) + + Parameters + ---------- + lmax + Maximum degree. + mmax + Maximum order (|m|). Must satisfy ``0 <= mmax <= lmax``. + device + Device for the returned tensor. + + Returns + ------- + torch.Tensor + Long tensor of indices with shape (D_m_trunc,), selecting coefficients + from the full packed layout with D=(lmax+1)^2, where D_m_trunc is + the number of coefficients kept under ``|m| <= min(mmax, l)``. + + Examples + -------- + For lmax=2, mmax=1: + - m=0 group: (l=0,m=0)→0, (l=1,m=0)→2, (l=2,m=0)→6 + - m=1 neg group: (l=1,m=-1)→1, (l=2,m=-1)→5 + - m=1 pos group: (l=1,m=+1)→3, (l=2,m=+1)→7 + - Returns: [0, 2, 6, 1, 5, 3, 7] + """ + lmax_i = int(lmax) + mmax_i = int(mmax) + if lmax_i < 0: + raise ValueError("`lmax` must be non-negative") + if mmax_i < 0: + raise ValueError("`mmax` must be non-negative") + if mmax_i > lmax_i: + raise ValueError("`mmax` must be <= `lmax`") + + indices: list[int] = [] + # === Step 1. m = 0 group (l = 0..lmax) === + for degree in range(lmax_i + 1): + indices.append(so3_packed_index(degree, 0)) + + # === Step 2. m > 0 groups (neg then pos) === + for m in range(1, mmax_i + 1): + for degree in range(m, lmax_i + 1): + indices.append(so3_packed_index(degree, -m)) + for degree in range(m, lmax_i + 1): + indices.append(so3_packed_index(degree, m)) + + return torch.tensor(indices, device=device, dtype=torch.long) + + +def build_m_major_l_index( + lmax: int, mmax: int, *, device: torch.device +) -> torch.Tensor: + """ + Build degree (l) index aligned with `build_m_major_index`. + + Parameters + ---------- + lmax + Maximum degree. + mmax + Maximum order (|m|). Must satisfy ``0 <= mmax <= lmax``. + device + Device for the returned tensor. + + Returns + ------- + torch.Tensor + Long tensor of degrees with shape (D_m_trunc,). Entry i is the degree + l for the i-th coefficient in the m-major layout. + + Examples + -------- + For lmax=2, mmax=1: + - m=0 group: l=0,1,2 + - m=1 neg group: l=1,2 + - m=1 pos group: l=1,2 + - Returns: [0, 1, 2, 1, 2, 1, 2] + """ + lmax_i = int(lmax) + mmax_i = int(mmax) + if lmax_i < 0: + raise ValueError("`lmax` must be non-negative") + if mmax_i < 0: + raise ValueError("`mmax` must be non-negative") + if mmax_i > lmax_i: + raise ValueError("`mmax` must be <= `lmax`") + + degrees: list[int] = [] + # === Step 1. m = 0 group === + for degree in range(lmax_i + 1): + degrees.append(degree) + + # === Step 2. m > 0 groups (neg then pos) === + for m in range(1, mmax_i + 1): + for degree in range(m, lmax_i + 1): + degrees.append(degree) + for degree in range(m, lmax_i + 1): + degrees.append(degree) + + return torch.tensor(degrees, device=device, dtype=torch.long) + + +def build_rotate_inv_rescale( + lmax: int, + mmax: int, + degree_index: torch.Tensor, + *, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Build reduced-layout inverse-rotation rescale factors. + + When ``mmax < lmax``, the reduced local layout keeps only ``2*mmax+1`` orders + for each degree ``l > mmax``. The inverse rotation rescales those truncated + degrees by ``sqrt((2*l+1)/(2*mmax+1))`` so the reduced representation matches + the amplitude expected by the full SO(3) basis. + + Parameters + ---------- + lmax + Maximum degree. + mmax + Maximum order (|m|). Must satisfy ``0 <= mmax <= lmax``. + degree_index + Degree index aligned with the reduced coefficient layout, typically + returned by ``build_m_major_l_index``. + device + Device for the returned tensor. + dtype + Floating-point dtype for the returned tensor. + + Returns + ------- + torch.Tensor + Rescale vector with shape (D_m_trunc,), aligned with the reduced + coefficient layout. + """ + lmax_i = int(lmax) + mmax_i = int(mmax) + if lmax_i < 0: + raise ValueError("`lmax` must be non-negative") + if mmax_i < 0: + raise ValueError("`mmax` must be non-negative") + if mmax_i > lmax_i: + raise ValueError("`mmax` must be <= `lmax`") + + degrees = degree_index.to(device=device, dtype=torch.long) + rescale = torch.ones(degrees.shape[0], device=device, dtype=dtype) + if mmax_i == lmax_i: + return rescale + + mask = degrees > mmax_i + if mask.any(): + denom = float(2 * mmax_i + 1) + degree_values = degrees[mask].to(dtype=dtype) + rescale[mask] = torch.sqrt((2.0 * degree_values + 1.0) / denom) + return rescale diff --git a/deepmd/pt/model/descriptor/sezm_nn/lebedev.py b/deepmd/pt/model/descriptor/sezm_nn/lebedev.py new file mode 100644 index 0000000000..6e7105677e --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/lebedev.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Lebedev quadrature data loader for SeZM S2 projections.""" + +from __future__ import ( + annotations, +) + +from pathlib import ( + Path, +) + +import numpy as np +import torch + +# See: https://people.sc.fsu.edu/~jburkardt/datasets/sphere_lebedev_rule/sphere_lebedev_rule.html +LEBEDEV_RULES_FILE = Path(__file__).with_name("lebedev_rules.npz") +LEBEDEV_PRECISION_TO_NPOINTS = { + 3: 6, + 5: 14, + 7: 26, + 9: 38, + 11: 50, + 13: 74, + 15: 86, + 17: 110, + 19: 146, + 21: 170, + 23: 194, + 25: 230, + 27: 266, + 29: 302, + 31: 350, + 35: 434, + 41: 590, + 47: 770, + 53: 974, + 59: 1202, + 65: 1454, + 71: 1730, + 77: 2030, + 83: 2354, + 89: 2702, + 95: 3074, + 101: 3470, + 107: 3890, + 113: 4334, + 119: 4802, + 125: 5294, + 131: 5810, +} + + +def load_lebedev_rule( + precision: int, + *, + dtype: torch.dtype, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Load one Lebedev rule from the packaged compressed data file. + + Parameters + ---------- + precision + Algebraic precision of the requested Lebedev rule. + dtype + Output tensor dtype. + device + Output tensor device. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + Cartesian unit points with shape ``(A, 3)`` and normalized weights with + shape ``(A,)``. The weights sum to one, so the sphere integral is + ``4*pi*sum(weights*f(points))``. + """ + rule_key = f"{int(precision):03d}" + if not LEBEDEV_RULES_FILE.exists(): + raise FileNotFoundError( + f"Lebedev quadrature data file is missing: {LEBEDEV_RULES_FILE}" + ) + with np.load(LEBEDEV_RULES_FILE) as rules: + point_key = f"points_{rule_key}" + weight_key = f"weights_{rule_key}" + if point_key not in rules or weight_key not in rules: + raise ValueError(f"Lebedev rule with precision {precision} is not packaged") + points_np = rules[point_key] + weights_np = rules[weight_key] + points = torch.as_tensor(points_np, dtype=dtype, device=device) + weights = torch.as_tensor(weights_np, dtype=dtype, device=device) + return points, weights diff --git a/deepmd/pt/model/descriptor/sezm_nn/lebedev_rules.npz b/deepmd/pt/model/descriptor/sezm_nn/lebedev_rules.npz new file mode 100644 index 0000000000..6e5ba699bd Binary files /dev/null and b/deepmd/pt/model/descriptor/sezm_nn/lebedev_rules.npz differ diff --git a/deepmd/pt/model/descriptor/sezm_nn/lora.py b/deepmd/pt/model/descriptor/sezm_nn/lora.py new file mode 100644 index 0000000000..3db8befae2 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/lora.py @@ -0,0 +1,810 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""LoRA low-rank fine-tuning support for SeZM. + +This module adds two things: + +* ``LoRASO3`` and ``LoRASO2`` subclasses that wrap the corresponding base + equivariant linear operators (``SO3Linear`` / ``SO2Linear``). Each one + freezes the pre-trained weights and registers rank-``R`` adapter + parameters ``A``/``B`` whose shapes share the base's batch layout + (per-``l`` for SO(3), per-``|m|``-group for SO(2)). The LoRA delta is + folded into the *effective* weight before the single large einsum that + already exists in the base module; forward FLOPs are therefore identical + to the base, and the overhead comes only from an ``O(R)`` weight-side + matmul that does not depend on the number of edges or nodes. + +* ``apply_lora_to_sezm``, ``merge_lora_into_base`` and a few helpers that + drive the fine-tune policy (which submodules stay trainable, which ones + remain frozen) and the merged-checkpoint export used by + ``Trainer.save_model_merged``. + +Naming convention: the LoRA parameter names -- ``A_by_l``, ``B_by_l``, +``A_m0``, ``B_m0``, ``A_m``, ``B_m`` -- intentionally do **not** start with +``adam_`` / ``adamw_`` and do not contain ``bias``. ``HybridMuon.get_adam_route`` +therefore classifies them as ``muon`` and, because the tensors have the +same rank structure as the corresponding base weights, the slice-mode +matrix view gives per-``l`` / per-``|m|``-group Newton-Schulz updates that +match the base training recipe. +""" + +from __future__ import ( + annotations, +) + +import math +from copy import ( + deepcopy, +) +from typing import ( + Any, +) + +import torch +import torch.nn as nn + +from .activation import ( + GatedActivation, +) +from .so2 import ( + SO2Linear, +) +from .so3 import ( + SO3Linear, +) + +# --------------------------------------------------------------------------- +# LoRA adapter modules +# --------------------------------------------------------------------------- + + +class LoRASO3(SO3Linear): + """ + Per-l ELoRA adapter for ``SO3Linear``. + + The pre-trained weight ``self.weight`` (``(lmax+1, C_in, F*C_out)``) is + frozen. Two new 3D parameters ``A_by_l`` (``(lmax+1, rank, C_in)``) and + ``B_by_l`` (``(lmax+1, F*C_out, rank)``) share the same ``lmax+1`` batch + axis as the base so that ``muon_mode="slice"`` updates every ``l``-block + independently. SO(3) equivariance is preserved because the per-``l`` + delta only rotates within each ``l``-block (no cross-``l`` mixing). + + Parameters + ---------- + base + Pre-trained ``SO3Linear`` to adapt. Its weights are copied and then + frozen. + rank + LoRA rank. Must satisfy ``rank >= 1``. + alpha + Scaling numerator; the effective scaling is ``alpha / rank``. + ``None`` defaults to ``alpha = rank`` (scaling ``1.0``). + """ + + def __init__( + self, + base: SO3Linear, + *, + rank: int, + alpha: float | None = None, + ) -> None: + if rank < 1: + raise ValueError(f"LoRASO3 requires rank >= 1, got {rank}") + # Construct a same-shape SO3Linear, then overwrite its weight with + # base's state. ``init_std=0.0`` skips the expensive random init. + super().__init__( + lmax=base.lmax, + in_channels=base.in_channels, + out_channels=base.out_channels, + n_focus=base.n_focus, + dtype=base.dtype, + mlp_bias=base.mlp_bias, + trainable=False, + seed=None, + init_std=0.0, + ) + self.load_state_dict(base.state_dict()) + # Defensive: ensure the base weight is frozen even if ``base`` was + # trainable at serialize time. ``self.bias`` is intentionally *kept* + # trainable — ``apply_lora_to_sezm`` re-enables every leaf whose name + # contains ``"bias"`` so the LoRA-preserved bias can absorb the + # downstream mean shift alongside the low-rank ``ΔW``. The assignment + # here is only a "known starting state" before that policy step runs. + self.weight.requires_grad_(False) + if self.bias is not None: + self.bias.requires_grad_(False) + + self.rank = int(rank) + alpha_value = float(alpha) if alpha is not None else float(rank) + self.scaling = alpha_value / float(rank) + self.register_buffer( + "lora_scaling", + torch.tensor(self.scaling, dtype=self.dtype, device=self.device), + persistent=True, + ) + + num_l = self.lmax + 1 + self.A_by_l = nn.Parameter( + torch.empty( + num_l, + self.rank, + self.in_channels, + dtype=self.dtype, + device=self.device, + ) + ) + # B is zero-initialised so that the initial forward is an exact + # identity to the base module; training backprop updates B first + # (gradA is zero while B is zero), which is the standard LoRA + # two-step unlock pattern and is compatible with Newton-Schulz on + # rectangular matrices. + self.B_by_l = nn.Parameter( + torch.zeros( + num_l, + self.n_focus * self.out_channels, + self.rank, + dtype=self.dtype, + device=self.device, + ) + ) + nn.init.normal_(self.A_by_l, mean=0.0, std=1.0 / math.sqrt(self.rank)) + + def extra_repr(self) -> str: + return f"rank={self.rank}, scaling={self.scaling}" + + def _compute_delta_weight(self) -> torch.Tensor: + """Return ``ΔW`` with shape ``(lmax+1, C_in, F*C_out)``.""" + return torch.einsum("lor,lri->lio", self.B_by_l, self.A_by_l) * self.scaling + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input features with shape ``(N, D, F, C_in)`` where ``D=(lmax+1)^2``. + + Returns + ------- + torch.Tensor + Output features with shape ``(N, D, F, C_out)``. + """ + delta_w = self._compute_delta_weight() + weight = (self.weight + delta_w).view( + self.lmax + 1, + self.in_channels, + self.n_focus, + self.out_channels, + ) + weight_expanded = torch.index_select(weight, dim=0, index=self.expand_index) + out = torch.einsum("ndfi,difo->ndfo", x, weight_expanded) + if self.mlp_bias: + bias = self.bias.view(self.n_focus, self.out_channels) + out[:, 0, :, :] = out[:, 0, :, :] + bias.unsqueeze(0) + return out + + def merge_into_base(self) -> SO3Linear: + """Build a plain ``SO3Linear`` whose weight has absorbed the LoRA delta.""" + base = SO3Linear( + lmax=self.lmax, + in_channels=self.in_channels, + out_channels=self.out_channels, + n_focus=self.n_focus, + dtype=self.dtype, + mlp_bias=self.mlp_bias, + trainable=True, + seed=None, + init_std=0.0, + ) + with torch.no_grad(): + merged = self.weight.detach() + self._compute_delta_weight().detach() + base.weight.copy_(merged) + if self.bias is not None: + assert base.bias is not None + base.bias.copy_(self.bias.detach()) + return base + + +class LoRASO2(SO2Linear): + """ + Per-``|m|``-group LoRA adapter for ``SO2Linear``. + + ``weight_m0`` (``(num_in_m0, F*num_out_m0)``) and each + ``weight_m[i]`` (``(num_in_m, F*2*num_out_m)``) get an independent 2D + LoRA pair ``A``/``B``. SO(2) equivariance is preserved because the + ``|m|>0`` 2x2 complex block ``[[W_u, -W_v], [W_v, W_u]]`` stays intact + when ``ΔW_m`` is absorbed into the concatenated ``[W_u | W_v]`` layout + before ``_build_so2_weight`` splits it (the shared input basis ``A`` + splits naturally into ``ΔW_u = B_u A`` and ``ΔW_v = B_v A``). + + The base ``forward``/``_cached_weight``/``train`` logic is inherited + unchanged; only ``_build_so2_weight`` is overridden to fold the LoRA + delta into each base block prior to assembling the block-diagonal + weight. The ``ΔW_m`` construction does not depend on the edge count + ``E``, so the forward FLOPs remain identical to the base. + + Parameters + ---------- + base + Pre-trained ``SO2Linear`` to adapt. + rank + LoRA rank. + alpha + Scaling numerator; scaling is ``alpha / rank``. ``None`` defaults + to ``alpha = rank`` (scaling ``1.0``). + """ + + def __init__( + self, + base: SO2Linear, + *, + rank: int, + alpha: float | None = None, + ) -> None: + if rank < 1: + raise ValueError(f"LoRASO2 requires rank >= 1, got {rank}") + super().__init__( + lmax=base.lmax, + mmax=base.mmax, + in_channels=base.in_channels, + out_channels=base.out_channels, + n_focus=base.n_focus, + dtype=base.dtype, + mlp_bias=base.mlp_bias, + seed=None, + trainable=False, + ) + self.load_state_dict(base.state_dict()) + # Defensive: the base matrices are frozen here, but ``self.bias0`` is + # intentionally re-enabled later by ``apply_lora_to_sezm`` via the + # "any leaf containing 'bias' is trainable" rule (``"bias" in "bias0"`` + # is ``True``) so the LoRA-preserved scalar offset can absorb the + # downstream mean shift alongside the low-rank ``ΔW``. + self.weight_m0.requires_grad_(False) + if self.bias0 is not None: + self.bias0.requires_grad_(False) + for w in self.weight_m: + w.requires_grad_(False) + # Any cached block-diagonal from the base is stale now; force rebuild. + self._cached_weight = None + + self.rank = int(rank) + alpha_value = float(alpha) if alpha is not None else float(rank) + self.scaling = alpha_value / float(rank) + self.register_buffer( + "lora_scaling", + torch.tensor(self.scaling, dtype=self.dtype, device=self.device), + persistent=True, + ) + + num_in_m0 = (self.lmax + 1) * self.in_channels + num_out_m0_per_focus = (self.lmax + 1) * self.out_channels + focus_num_out_m0 = self.n_focus * num_out_m0_per_focus + self.A_m0 = nn.Parameter( + torch.empty( + self.rank, + num_in_m0, + dtype=self.dtype, + device=self.device, + ) + ) + self.B_m0 = nn.Parameter( + torch.zeros( + focus_num_out_m0, + self.rank, + dtype=self.dtype, + device=self.device, + ) + ) + nn.init.normal_(self.A_m0, mean=0.0, std=1.0 / math.sqrt(self.rank)) + + self.A_m = nn.ParameterList() + self.B_m = nn.ParameterList() + for w in self.weight_m: + num_in, focus_two_num_out = w.shape + a_m = nn.Parameter( + torch.empty( + self.rank, + num_in, + dtype=self.dtype, + device=self.device, + ) + ) + b_m = nn.Parameter( + torch.zeros( + focus_two_num_out, + self.rank, + dtype=self.dtype, + device=self.device, + ) + ) + nn.init.normal_(a_m, mean=0.0, std=1.0 / math.sqrt(self.rank)) + self.A_m.append(a_m) + self.B_m.append(b_m) + + def extra_repr(self) -> str: + return f"rank={self.rank}, scaling={self.scaling}" + + def _compute_delta_m0(self) -> torch.Tensor: + """Return ``ΔW_m0`` with shape ``(num_in_m0, F*num_out_m0)``.""" + return torch.einsum("ri,or->io", self.A_m0, self.B_m0) * self.scaling + + def _compute_delta_m(self, m_idx: int) -> torch.Tensor: + """Return ``ΔW_m[m_idx]`` with the same shape as ``weight_m[m_idx]``.""" + return ( + torch.einsum("ri,or->io", self.A_m[m_idx], self.B_m[m_idx]) * self.scaling + ) + + def _build_so2_weight(self) -> torch.Tensor: + """Assemble the block-diagonal weight with LoRA delta folded in.""" + in_total = self.reduced_dim * self.in_channels + out_total = self.reduced_dim * self.out_channels + weight = self.weight_m0.new_zeros(in_total, self.n_focus, out_total) + num_in_m0 = (self.lmax + 1) * self.in_channels + num_out_m0 = (self.lmax + 1) * self.out_channels + + # m=0 block: fold ΔW_m0 into the base weight before the view. + w_m0_eff = (self.weight_m0 + self._compute_delta_m0()).view( + num_in_m0, self.n_focus, num_out_m0 + ) + weight[: self._m0_in, :, : self._m0_out] = w_m0_eff + + # |m|>0 blocks: same 2x2 coupling assembly as the base, but with + # ΔW_m folded into the concatenated [W_u | W_v] layout first. + for m_idx, w_base in enumerate(self.weight_m): + ni0, ni1, pi0, pi1, no0, no1, po0, po1 = self._block_slices[m_idx] + ib = ni1 - ni0 + ob = no1 - no0 + w_eff = (w_base + self._compute_delta_m(m_idx)).view( + ib, self.n_focus, 2 * ob + ) + w_u = w_eff[:, :, :ob] + w_v = w_eff[:, :, ob:] + weight[ni0:ni1, :, no0:no1] = w_u + weight[ni0:ni1, :, po0:po1] = w_v + weight[pi0:pi1, :, no0:no1] = -w_v + weight[pi0:pi1, :, po0:po1] = w_u + return weight + + def merge_into_base(self) -> SO2Linear: + """Build a plain ``SO2Linear`` whose weights have absorbed every LoRA delta.""" + base = SO2Linear( + lmax=self.lmax, + mmax=self.mmax, + in_channels=self.in_channels, + out_channels=self.out_channels, + n_focus=self.n_focus, + dtype=self.dtype, + mlp_bias=self.mlp_bias, + seed=None, + trainable=True, + ) + with torch.no_grad(): + base.weight_m0.copy_( + self.weight_m0.detach() + self._compute_delta_m0().detach() + ) + if self.bias0 is not None: + assert base.bias0 is not None + base.bias0.copy_(self.bias0.detach()) + for m_idx, w in enumerate(self.weight_m): + base.weight_m[m_idx].copy_( + w.detach() + self._compute_delta_m(m_idx).detach() + ) + return base + + +# --------------------------------------------------------------------------- +# Fine-tune policy: freeze / unfreeze rules +# --------------------------------------------------------------------------- + +# Leaf parameter names that stay trainable during LoRA fine-tune. These are small +# scalar / per-l scales / attention gating weights whose full-rank update costs +# are negligible but directly absorb the domain shift of the downstream dataset. +_UNFREEZE_LEAF_NAMES: frozenset[str] = frozenset( + { + "adam_scale", + "adam_so2_layer_scales", + "adam_ffn_layer_scales", + "film_scale_strength_log", + "film_shift_strength_log", + "adamw_attn_logit_w", + "adamw_attn_z_bias_raw", + "adamw_attn_gate_w", + "adamw_focus_compete_w", + "adamw_pseudo_query", + "focus_compete_bias", + } +) + +# Leaf names that stay frozen (override any unfreeze rule above). The backbone +# pre-training has already converged on these quantities for all-element +# datasets; downstream fine-tuning should keep them fixed. +_OVERRIDE_FREEZE_LEAF_NAMES: frozenset[str] = frozenset( + { + "adam_type_embedding", + "adam_freqs", + } +) + +# Submodule paths (rooted at the SeZMModel) that get fully unfrozen. +_UNFREEZE_SUBMODULE_PATHS: tuple[str, ...] = ( + "atomic_model.fitting_net", + "atomic_model.dens_fitting_net", + "atomic_model.descriptor.radial_embedding", + "atomic_model.descriptor.env_seed_embedding", + "atomic_model.descriptor.film_scale_norm", + "atomic_model.descriptor.film_shift_norm", + "atomic_model.descriptor.final_full_attn_res", + "atomic_model.descriptor.final_block_attn_res", +) + +# Per-interaction-block submodule paths that get fully unfrozen. The +# descriptor stores the block list at ``atomic_model.descriptor.blocks``. +_UNFREEZE_PER_BLOCK_SUBPATHS: tuple[str, ...] = ( + "full_attn_res_so2", + "full_attn_res_ffns", + "block_attn_res_so2", + "block_attn_res_ffns", + "so2_conv.attn_q_proj", + "so2_conv.attn_k_proj", + "so2_conv.attn_qk_norm", + "so2_conv.attn_output_gate_norm", + "so2_conv.focus_compete_norm", + "so2_conv.radial_hidden_proj", + "so2_conv.so2_layer_attn_res", +) + +_BLOCKS_PATH: str = "atomic_model.descriptor.blocks" + + +def _leaf_name(param_name: str) -> str: + """Return the trailing non-numeric segment of a parameter name. + + ``nn.ParameterList`` children show up as ``foo.0``, ``foo.1``, ...; + ``get_adam_route`` strips those numeric indices before routing, so this + helper keeps the policy in sync. + """ + parts = param_name.split(".") + i = len(parts) - 1 + while i > 0 and parts[i].isdigit(): + i -= 1 + return parts[i] + + +def _get_submodule_or_none(root: nn.Module, path: str) -> nn.Module | None: + if not path: + return root + try: + return root.get_submodule(path) + except AttributeError: + return None + + +def _clear_sezm_compile_cache(model: nn.Module) -> None: + """Invalidate any ``compiled_core_compute_cache`` / ``compiled_dens_compute``. + + LoRA injection or merge replaces submodules, which changes the Python + object graph that ``torch.compile`` had captured. Without clearing the + cache the next forward would reuse the stale compiled callable and + crash or silently skip LoRA parameters. Mirrors the pattern used in + :meth:`SeZMModel.reset_head_for_mode`. + """ + for m in model.modules(): + core_cache = getattr(m, "compiled_core_compute_cache", None) + if isinstance(core_cache, dict): + core_cache.clear() + if hasattr(m, "_core_compute_pending_compile_t0"): + m._core_compute_pending_compile_t0 = None + if hasattr(m, "_core_compute_pending_compile_key"): + m._core_compute_pending_compile_key = None + if hasattr(m, "compiled_dens_compute"): + object.__setattr__(m, "compiled_dens_compute", None) + if hasattr(m, "_dens_compiled"): + m._dens_compiled = False + if hasattr(m, "_dens_pending_compile_t0"): + m._dens_pending_compile_t0 = None + + +def _swap_submodule(parent: nn.Module, attr: str, new_module: nn.Module) -> None: + """Replace ``parent.attr`` with ``new_module``. + + Uses ``parent._modules[attr]`` so that numeric attribute names (for + ``nn.ModuleList`` / ``nn.ParameterList`` children) work as well. + """ + parent._modules[attr] = new_module + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def has_lora(module: nn.Module) -> bool: + """Return ``True`` iff any submodule is a LoRA adapter.""" + return any(isinstance(m, (LoRASO3, LoRASO2)) for m in module.modules()) + + +def apply_lora_to_sezm( + model: nn.Module, + *, + rank: int, + alpha: float | None = None, +) -> nn.Module: + """ + Inject LoRA adapters into every ``SO3Linear`` / ``SO2Linear`` of a SeZM + model and apply the SeZM fine-tune freeze/unfreeze policy in place. + + This function is idempotent-safe: the ``type(mod) is SO3Linear`` (exact + type) test prevents re-wrapping a LoRASO3 that is already present. + + Parameters + ---------- + model + A ``SeZMModel`` instance (or any ``nn.Module`` containing SeZM + ``SO3Linear`` / ``SO2Linear`` submodules). + rank + LoRA rank applied uniformly to every adapter. + alpha + LoRA scaling numerator; scaling is ``alpha / rank``. ``None`` + defaults to ``alpha = rank`` (scaling ``1.0``). + + Returns + ------- + nn.Module + The same ``model`` after injection (returned for chaining). + """ + # === Step 1. Freeze all parameters === + for p in model.parameters(): + p.requires_grad_(False) + + # === Step 2. Replace SO3Linear / SO2Linear with LoRA subclasses === + # Snapshot named_modules() first so the later in-place replacement does + # not invalidate the iterator. ``type(...) is ...`` is deliberate: it + # matches only the exact base class, skipping any pre-existing LoRA + # adapter so apply_lora_to_sezm remains idempotent. + replacements: list[tuple[nn.Module, str, nn.Module]] = [] + for name, mod in list(model.named_modules()): + if type(mod) is SO3Linear: + parent_name, _, attr = name.rpartition(".") + parent = model.get_submodule(parent_name) if parent_name else model + replacements.append((parent, attr, LoRASO3(mod, rank=rank, alpha=alpha))) + elif type(mod) is SO2Linear: + parent_name, _, attr = name.rpartition(".") + parent = model.get_submodule(parent_name) if parent_name else model + replacements.append((parent, attr, LoRASO2(mod, rank=rank, alpha=alpha))) + for parent, attr, new_mod in replacements: + _swap_submodule(parent, attr, new_mod) + + # === Step 3. Unfreeze whole submodules (descriptor-level and per-block) === + for path in _UNFREEZE_SUBMODULE_PATHS: + sub = _get_submodule_or_none(model, path) + if sub is None: + continue + for p in sub.parameters(): + p.requires_grad_(True) + + blocks = _get_submodule_or_none(model, _BLOCKS_PATH) + if blocks is not None: + for block in blocks: + for subpath in _UNFREEZE_PER_BLOCK_SUBPATHS: + sub = _get_submodule_or_none(block, subpath) + if sub is None: + continue + for p in sub.parameters(): + p.requires_grad_(True) + + # === Step 4. Unfreeze small parameters by leaf name === + # Any name ending in a LoRA-listed leaf or containing ``bias`` becomes + # trainable. The ``"bias" in leaf`` rule deliberately also re-enables the + # base biases that ``LoRASO3.__init__`` / ``LoRASO2.__init__`` had frozen + # (``SO3Linear.bias``, ``SO2Linear.bias0``); keeping those trainable lets + # the LoRA-preserved offsets absorb the downstream mean shift alongside + # the low-rank ``ΔW``. The same rule also unfreezes norm biases + # (``EquivariantRMSNorm.bias``, ``ReducedEquivariantRMSNorm.bias0``) + # anywhere in the model -- tiny parameter counts, large domain-shift + # headroom. ``adam_scale`` is listed similarly: every RMSNorm scale in + # the backbone (per-block ``pre/post_so2_norm``, ``pre/post_ffn_norms``, + # ``so2_inter_norms``, etc.) becomes trainable, again at negligible cost. + for name, p in model.named_parameters(): + leaf = _leaf_name(name) + if leaf in _UNFREEZE_LEAF_NAMES or "bias" in leaf: + p.requires_grad_(True) + + # === Step 5. Override-freeze converged parameters by leaf name === + # Must run after steps 3/4 because earlier whole-module unfreezes may + # have turned them back on (e.g. ``adam_type_embedding`` inside the + # unfrozen ``env_seed_embedding``). + for name, p in model.named_parameters(): + leaf = _leaf_name(name) + if leaf in _OVERRIDE_FREEZE_LEAF_NAMES: + p.requires_grad_(False) + + # === Step 6. Override-freeze every GatedActivation submodule === + # Stable gate patterns; avoids turning on gate_linear.bias via the + # step-4 "bias" rule. + for mod in model.modules(): + if isinstance(mod, GatedActivation): + for p in mod.parameters(): + p.requires_grad_(False) + + return model + + +def fold_lora_state_dict_keys(state_dict: dict[str, torch.Tensor], prefix: str) -> None: + """Fold LoRA adapter keys into base weight keys in *state_dict* (in-place). + + Scans for SO3-style ``A_by_l``/``B_by_l`` pairs and SO2-style + ``A_m0``/``B_m0``/``A_m.*``/``B_m.*`` groups under *prefix*. For each + pair whose corresponding base weight key also exists, the delta + ``einsum(B, A) * scaling`` is added to the weight and the adapter keys + are popped. ``lora_scaling`` is read from *state_dict* when present; + otherwise ``1.0`` is assumed (the default when ``alpha == rank``). + + Called by ``DescrptSeZM._load_from_state_dict`` so that a LoRA-trained + checkpoint can be loaded into a plain (non-LoRA) descriptor transparently. + + Parameters + ---------- + state_dict + Flat state dict to mutate in place. + prefix + Key prefix that scopes the scan (e.g. ``"model.Default.atomic_model.descriptor."``). + """ + # === SO3: fold A_by_l / B_by_l into weight === + so3_prefixes = [ + k[: -len("A_by_l")] + for k in list(state_dict) + if k.startswith(prefix) and k.endswith(".A_by_l") + ] + for sp in so3_prefixes: + a_key, b_key, w_key = sp + "A_by_l", sp + "B_by_l", sp + "weight" + if b_key not in state_dict or w_key not in state_dict: + continue + a = state_dict.pop(a_key) + b = state_dict.pop(b_key) + scaling_tensor = state_dict.pop(sp + "lora_scaling", None) + scaling = float(scaling_tensor) if scaling_tensor is not None else 1.0 + state_dict[w_key] = ( + state_dict[w_key] + torch.einsum("lor,lri->lio", b, a) * scaling + ) + + # === SO2: fold A_m0 / B_m0 and A_m.* / B_m.* into weight_m0 / weight_m.* === + so2_prefixes = [ + k[: -len("A_m0")] + for k in list(state_dict) + if k.startswith(prefix) and k.endswith(".A_m0") + ] + for sp in so2_prefixes: + a0_key, b0_key, w0_key = sp + "A_m0", sp + "B_m0", sp + "weight_m0" + if b0_key not in state_dict or w0_key not in state_dict: + continue + scaling_tensor = state_dict.pop(sp + "lora_scaling", None) + scaling = float(scaling_tensor) if scaling_tensor is not None else 1.0 + a0 = state_dict.pop(a0_key) + b0 = state_dict.pop(b0_key) + state_dict[w0_key] = ( + state_dict[w0_key] + torch.einsum("ri,or->io", a0, b0) * scaling + ) + m_idx = 0 + while True: + a_key = sp + f"A_m.{m_idx}" + b_key = sp + f"B_m.{m_idx}" + w_key = sp + f"weight_m.{m_idx}" + if a_key not in state_dict: + break + a_m = state_dict.pop(a_key) + b_m = state_dict.pop(b_key) + state_dict[w_key] = ( + state_dict[w_key] + torch.einsum("ri,or->io", a_m, b_m) * scaling + ) + m_idx += 1 + + +def build_merged_state_dict( + module: nn.Module, + state_dict: dict[str, torch.Tensor] | None = None, + *, + prefix: str = "", +) -> dict[str, torch.Tensor]: + """ + Produce a plain (LoRA-free) state dict from a LoRA-augmented module. + + Walks ``module.named_modules()`` and, for every ``LoRASO3`` / + ``LoRASO2`` submodule, folds ``ΔW = BA·scaling`` into the base weight + key and removes the ``A``/``B`` keys. The returned dict has the same + key set as a same-topology SeZM that has never been LoRA-wrapped, and + is suitable for loading into a plain SeZM model with ``strict=True``. + + Non-destructive: when ``state_dict`` is ``None`` a deep copy of + ``module.state_dict()`` is taken; when the caller provides a + ``state_dict`` it is assumed to already be a detached copy (e.g. the + full-gathered state dict from FSDP2) and is *mutated in place* for + efficiency. + + Parameters + ---------- + module + The LoRA-augmented module tree. Only used for structural + information (LoRA submodule prefixes, ``scaling``, ``weight_m`` + length); its parameters are not modified. + state_dict + Optional pre-collected state dict (e.g. gathered from FSDP2). If + ``None``, ``deepcopy(module.state_dict())`` is used. + prefix + Prefix to prepend to every LoRA submodule name when looking keys + up in ``state_dict``. Use this when the caller has state keyed + under an outer wrapper (for example ``"model.Default."``). + + Returns + ------- + dict + Flat state dict with LoRA adapters folded into base weights. + """ + state = deepcopy(module.state_dict()) if state_dict is None else state_dict + for name, mod in module.named_modules(): + key_prefix = prefix + name + "." if name else prefix + if isinstance(mod, LoRASO3): + a = state.pop(key_prefix + "A_by_l") + b = state.pop(key_prefix + "B_by_l") + state.pop(key_prefix + "lora_scaling", None) + weight_key = key_prefix + "weight" + delta = torch.einsum("lor,lri->lio", b, a) * mod.scaling + state[weight_key] = state[weight_key] + delta + elif isinstance(mod, LoRASO2): + a_m0 = state.pop(key_prefix + "A_m0") + b_m0 = state.pop(key_prefix + "B_m0") + state.pop(key_prefix + "lora_scaling", None) + w_m0_key = key_prefix + "weight_m0" + state[w_m0_key] = ( + state[w_m0_key] + torch.einsum("ri,or->io", a_m0, b_m0) * mod.scaling + ) + for m_idx in range(len(mod.weight_m)): + a_i = state.pop(key_prefix + f"A_m.{m_idx}") + b_i = state.pop(key_prefix + f"B_m.{m_idx}") + w_i_key = key_prefix + f"weight_m.{m_idx}" + state[w_i_key] = ( + state[w_i_key] + torch.einsum("ri,or->io", a_i, b_i) * mod.scaling + ) + return state + + +def strip_lora_from_extra_state(extra_state: dict[str, Any]) -> dict[str, Any]: + """ + Drop any ``lora`` entry from ``_extra_state["model_params"]``. + + Handles both single-task (``model_params`` is the model config) and + multi-task (``model_params["model_dict"][]`` is each branch's + config). Returns a deep-copied dict; the input is not mutated. + """ + out = deepcopy(extra_state) + model_params = out.get("model_params") + if not isinstance(model_params, dict): + return out + model_params.pop("lora", None) + model_dict = model_params.get("model_dict") + if isinstance(model_dict, dict): + for branch_cfg in model_dict.values(): + if isinstance(branch_cfg, dict): + branch_cfg.pop("lora", None) + return out + + +def merge_lora_into_base(model: nn.Module) -> nn.Module: + """ + Destructively replace every ``LoRASO3`` / ``LoRASO2`` with its merged + plain base module. + + After this call the model no longer contains LoRA submodules: the + optimizer, EMA state, and any compiled callables that reference the old + submodules become invalid. Prefer :func:`build_merged_state_dict` for + non-destructive checkpoint export during or after training; this function + is primarily useful in tests and offline scripts. + """ + replacements: list[tuple[nn.Module, str, nn.Module]] = [] + for name, mod in list(model.named_modules()): + if isinstance(mod, (LoRASO3, LoRASO2)): + parent_name, _, attr = name.rpartition(".") + parent = model.get_submodule(parent_name) if parent_name else model + replacements.append((parent, attr, mod.merge_into_base())) + for parent, attr, new_mod in replacements: + _swap_submodule(parent, attr, new_mod) + _clear_sezm_compile_cache(model) + return model diff --git a/deepmd/pt/model/descriptor/sezm_nn/norm.py b/deepmd/pt/model/descriptor/sezm_nn/norm.py new file mode 100644 index 0000000000..453c6af6af --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/norm.py @@ -0,0 +1,669 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Normalization layers for the SeZM descriptor. + +This module defines the packed-layout, reduced-layout, generic, and scalar +RMS normalization layers used throughout SeZM. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + Any, +) + +import torch +import torch.nn as nn + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .indexing import ( + map_degree_idx, +) +from .utils import ( + np_safe, + safe_numpy_to_tensor, +) + + +class RMSNorm(nn.Module): + """ + Generic RMSNorm on tensors with shape `(..., C)`. + + This is the plain channel-wise RMS normalization used for non-equivariant + branches whose last axis stores feature channels. A learnable affine scale is + applied on the channel axis only, while all leading axes are treated as batch + dimensions. + + Parameters + ---------- + channels + Feature dimension of the last axis. + eps + Small epsilon for numerical stability. + dtype + Parameter and computation dtype. Caller should pass compute_dtype (fp32+) + for numerical stability. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + channels: int, + eps: float = 1e-7, + dtype: torch.dtype, + trainable: bool, + ) -> None: + super().__init__() + self.channels = int(channels) + self.dtype = dtype + self.device = env.DEVICE + self.eps = float(eps) + self.register_buffer( + "eps_tensor", + torch.tensor(self.eps, dtype=self.dtype, device=self.device), + persistent=False, + ) + + # adam_ prefix routes this to Adam (no weight decay) in HybridMuon. + self.adam_scale = nn.Parameter( + torch.ones(self.channels, dtype=self.dtype, device=self.device) + ) + + for p in self.parameters(): + p.requires_grad = trainable + + @torch.amp.autocast("cuda", enabled=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input tensor with shape `(..., C)`. + + Returns + ------- + torch.Tensor + Normalized tensor with shape `(..., C)`, same dtype as input. + """ + in_dtype = x.dtype + x = x.to(dtype=self.dtype) + inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps_tensor) + scale = self.adam_scale.view(*([1] * (x.ndim - 1)), self.channels) + x = x * inv_rms * scale + return x.to(dtype=in_dtype) + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "RMSNorm", + "@version": 1, + "config": { + "channels": self.channels, + "eps": self.eps, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "trainable": trainable, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> RMSNorm: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "RMSNorm": + raise ValueError(f"Invalid class for RMSNorm: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj + + +class EquivariantRMSNorm(nn.Module): + """ + Degree-balanced equivariant RMS normalization on packed `(l, m)` layout. + + The scalar slice `l=0` is mean-centered across channels before the shared + RMS is evaluated. All coefficients, including the centered scalar slice, + contribute to the same per-sample and per-focus RMS. Degree balancing + assigns each coefficient from degree `l` the weight + `1 / ((2 * l + 1) * (lmax + 1))`, so each degree contributes equally + regardless of its multiplicity. A learnable per-focus, per-degree scale is + then expanded to all `m` coefficients, and a learnable bias is added only + to the scalar slice. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + channels + Channels per `(l, m)` coefficient in each focus stream. + n_focus + Number of focus streams. Affine parameters are independent per focus. + eps + Small epsilon for numerical stability. + dtype + Parameter and computation dtype. Caller should pass compute_dtype (fp32+) + for numerical stability and handle input/output conversion at boundaries. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + lmax: int, + channels: int, + n_focus: int = 1, + *, + eps: float = 1e-5, + dtype: torch.dtype, + trainable: bool, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.channels = int(channels) + self.n_focus = int(n_focus) + self.dtype = dtype + self.device = env.DEVICE + self.eps = float(eps) + self.register_buffer( + "eps_tensor", + torch.tensor(self.eps, dtype=self.dtype, device=self.device), + persistent=False, + ) + + # === Step 1. Learnable Parameters === + # Store affine scales in degree-major layout (L, F, C). This matches the + # packed output layout after degree expansion + # adam_ prefix routes this to Adam (no weight decay) in HybridMuon. + self.adam_scale = nn.Parameter( + torch.ones( + self.lmax + 1, + self.n_focus, + self.channels, + dtype=self.dtype, + device=self.device, + ) + ) + # Bias only for l=0, independent per focus. + self.bias = nn.Parameter( + torch.zeros( + self.n_focus, self.channels, dtype=self.dtype, device=self.device + ) + ) + + # === Step 2. Index and Weight Buffers === + expand_index = map_degree_idx(self.lmax, device=self.device) + self.register_buffer("expand_index", expand_index, persistent=True) + + # Pre-fuse degree balancing and channel averaging into a single weight: + # w_d = 1 / ((2l+1) * (lmax+1) * C) + # so that + # mean_variance = einsum('ndfc,d->nf', x^2, balance_weight) + # directly computes the shared RMS statistic without allocating an + # intermediate (N, D, F, C) buffer beyond x^2 itself. + weights_list = [] + scale = 1.0 / ((self.lmax + 1) * self.channels) + for l in range(self.lmax + 1): + w = scale / (2 * l + 1) + weights_list.extend([w] * (2 * l + 1)) + balance_weight = torch.tensor( + weights_list, dtype=self.dtype, device=self.device + ) + self.register_buffer("balance_weight", balance_weight, persistent=True) + + for p in self.parameters(): + p.requires_grad = trainable + + @torch.amp.autocast("cuda", enabled=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Features with shape `(N, D, F, C)` where `D = (lmax + 1)^2`. + + Returns + ------- + torch.Tensor + Normalized features with shape `(N, D, F, C)`, same dtype as input. + """ + in_dtype = x.dtype + x = x.to(dtype=self.dtype) + x0 = x[:, :1, :, :] # (N, 1, F, C) + xt = x[:, 1:, :, :] # (N, D-1, F, C) + + # === Step 1. Center the scalar slice === + x0 = x0 - x0.mean(dim=-1, keepdim=True) + + # === Step 2. Compute a shared degree-balanced RMS === + mean_variance = x0.square().sum(dim=(1, 3)) * self.balance_weight[0] + if xt.numel() > 0: + mean_variance = mean_variance + torch.einsum( + "ndfc,d->nf", xt * xt, self.balance_weight[1:] + ) + inv_rms = ( + torch.rsqrt(mean_variance + self.eps_tensor).unsqueeze(1).unsqueeze(-1) + ) + + x0 = x0 * inv_rms + if xt.numel() > 0: + xt = xt * inv_rms + + # === Step 3. Apply per-degree affine parameters === + expanded_scale = torch.index_select( + self.adam_scale, dim=0, index=self.expand_index + ) + expanded_scale = expanded_scale.unsqueeze(0) # (1, D, F, C) + x0 = x0 * expanded_scale[:, :1, :, :] + if xt.numel() > 0: + xt = xt * expanded_scale[:, 1:, :, :] + + # === Step 4. Add scalar bias and restore layout === + bias0 = self.bias.reshape(1, 1, self.n_focus, -1) # (1, 1, F, C) + x0 = x0 + bias0 + + out = x0 if xt.numel() == 0 else torch.cat([x0, xt], dim=1) + out = out.to(dtype=in_dtype) + return out + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "EquivariantRMSNorm", + "@version": 1, + "config": { + "lmax": self.lmax, + "channels": self.channels, + "n_focus": self.n_focus, + "eps": self.eps, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "trainable": trainable, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> EquivariantRMSNorm: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "EquivariantRMSNorm": + raise ValueError(f"Invalid class for EquivariantRMSNorm: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj + + +class ReducedEquivariantRMSNorm(nn.Module): + """ + Degree-balanced equivariant RMS normalization on reduced m-major layout. + + The scalar slice `l=0` is mean-centered across channels before the shared + RMS is evaluated. All retained coefficients, including the centered scalar + slice, contribute to the same per-edge and per-focus RMS. Degree balancing + assigns each retained coefficient from degree `l` the weight + `1 / (n_coeff_l * (lmax + 1))`, where + `n_coeff_l = 2 * min(l, mmax) + 1` is the number of retained coefficients + for that degree in the reduced layout. A learnable per-focus, per-degree + scale is expanded with `degree_index_m`, and a learnable bias is added only + to the scalar slice. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + mmax + Maximum order kept in the truncated layout. + channels + Number of channels per retained coefficient. + degree_index_m + Degree index per coefficient in m-major truncated layout, with shape + `(D_m_trunc,)`. + n_focus + Number of focus streams. + eps + Epsilon for numerical stability. + dtype + Parameter and computation dtype. Caller should pass compute_dtype (fp32+) + for numerical stability. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int, + channels: int, + degree_index_m: torch.Tensor, + n_focus: int = 1, + eps: float = 1e-5, + dtype: torch.dtype, + trainable: bool, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(mmax) + self.channels = int(channels) + self.n_focus = int(n_focus) + self.eps = float(eps) + self.dtype = dtype + self.device = env.DEVICE + self.register_buffer( + "eps_tensor", + torch.tensor(self.eps, dtype=self.dtype, device=self.device), + persistent=False, + ) + + if degree_index_m.dtype != torch.long: + degree_index_m = degree_index_m.to(dtype=torch.long) + self.register_buffer("degree_index_m", degree_index_m, persistent=True) + + # Pre-fuse degree balancing and channel averaging into a single weight: + # w_d = 1 / (n_coeff_l * (lmax+1) * C) + # where n_coeff_l is the number of retained coefficients for degree l in + # the reduced layout. + weights = torch.zeros( + degree_index_m.numel(), dtype=self.dtype, device=self.device + ) + scale = 1.0 / ((self.lmax + 1) * self.channels) + for l in range(self.lmax + 1): + n_coeff_l = 2 * min(l, self.mmax) + 1 + w_l = scale / float(n_coeff_l) + weights[degree_index_m == l] = w_l + if torch.any(weights == 0): + raise ValueError( + "ReducedEquivariantRMSNorm: balance_weight has zeros; degree_index_m may be invalid." + ) + self.register_buffer("balance_weight", weights, persistent=True) + + # adam_ prefix routes this to Adam (no weight decay) in HybridMuon. + self.adam_scale = nn.Parameter( + torch.ones( + self.n_focus, + self.lmax + 1, + self.channels, + dtype=self.dtype, + device=self.device, + ) + ) + self.bias0 = nn.Parameter( + torch.zeros( + self.n_focus, + self.channels, + dtype=self.dtype, + device=self.device, + ) + ) + + for p in self.parameters(): + p.requires_grad = trainable + + @torch.amp.autocast("cuda", enabled=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input tensor with shape (E, F, D_m_trunc, C). + + Returns + ------- + torch.Tensor + Normalized tensor with shape `(E, F, D_m_trunc, C)`, same dtype as + input. + """ + in_dtype = x.dtype + x = x.to(dtype=self.dtype) + x0 = x[:, :, :1, :] # (E, F, 1, C) + xt = x[:, :, 1:, :] # (E, F, D_m_trunc-1, C) + + # === Step 1. Center the scalar slice === + x0 = x0 - x0.mean(dim=-1, keepdim=True) + + # === Step 2. Compute a shared degree-balanced RMS === + mean_variance = x0.square().sum(dim=(2, 3)) * self.balance_weight[0] + if xt.numel() > 0: + mean_variance = mean_variance + torch.einsum( + "efdc,d->ef", xt * xt, self.balance_weight[1:] + ) + inv_rms = ( + torch.rsqrt(mean_variance + self.eps_tensor).unsqueeze(-1).unsqueeze(-1) + ) + + x0 = x0 * inv_rms + if xt.numel() > 0: + xt = xt * inv_rms + + # === Step 3. Apply per-degree affine parameters === + expanded_scale = torch.index_select( + self.adam_scale, dim=1, index=self.degree_index_m + ) + expanded_scale = expanded_scale.unsqueeze(0) # (1, F, D_m_trunc, C) + x0 = x0 * expanded_scale[:, :, :1, :] + if xt.numel() > 0: + xt = xt * expanded_scale[:, :, 1:, :] + + # === Step 4. Add scalar bias and restore layout === + bias0 = self.bias0.reshape(1, self.n_focus, 1, -1) # (1, F, 1, C) + x0 = x0 + bias0 + + out = x0 if xt.numel() == 0 else torch.cat([x0, xt], dim=2) + out = out.to(dtype=in_dtype) + return out + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "ReducedEquivariantRMSNorm", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "channels": self.channels, + "degree_index_m": np_safe(self.degree_index_m), + "n_focus": self.n_focus, + "eps": self.eps, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "trainable": trainable, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> ReducedEquivariantRMSNorm: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "ReducedEquivariantRMSNorm": + raise ValueError(f"Invalid class for ReducedEquivariantRMSNorm: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + degree_index_m = safe_numpy_to_tensor( + config.pop("degree_index_m"), + device=env.DEVICE, + dtype=torch.long, + ) + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + config["degree_index_m"] = degree_index_m + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj + + +class ScalarRMSNorm(nn.Module): + """ + Lightweight per-focus RMSNorm for scalar branches. + + This is the unified scalar norm used by SeZM: + - `n_focus=1` naturally degenerates to the single-stream behavior. + - `n_focus>1` uses independent learnable scales per focus stream. + Bias is intentionally omitted to keep the gate paths minimal. + + Parameters + ---------- + channels + Feature dimension of the last axis. + n_focus + Number of focus streams. + eps + Small epsilon for numerical stability. + dtype + Parameter and computation dtype. Caller should pass compute_dtype (fp32+) + for numerical stability. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + channels: int, + n_focus: int = 1, + eps: float = 1e-7, + dtype: torch.dtype, + trainable: bool, + ) -> None: + super().__init__() + self.channels = int(channels) + self.n_focus = int(n_focus) + self.dtype = dtype + self.device = env.DEVICE + self.eps = float(eps) + self.register_buffer( + "eps_tensor", + torch.tensor(self.eps, dtype=self.dtype, device=self.device), + persistent=False, + ) + + # adam_ prefix routes this to Adam (no weight decay) in HybridMuon. + self.adam_scale = nn.Parameter( + torch.ones( + self.n_focus, + self.channels, + dtype=self.dtype, + device=self.device, + ) + ) + + for p in self.parameters(): + p.requires_grad = trainable + + @torch.amp.autocast("cuda", enabled=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input tensor with shape (B, F, C) or (B, C) when `n_focus=1`. + + Returns + ------- + torch.Tensor + Normalized tensor with the same shape as input and same dtype. + """ + in_dtype = x.dtype + x = x.to(dtype=self.dtype) + + if x.ndim == 2: + inv_rms = torch.rsqrt( + x.square().mean(dim=-1, keepdim=True) + self.eps_tensor + ) + x = x * inv_rms + x = x * self.adam_scale[0] + return x.to(dtype=in_dtype) + + inv_rms = torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + self.eps_tensor) + x = x * inv_rms + x = x * self.adam_scale.unsqueeze(0) + return x.to(dtype=in_dtype) + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "ScalarRMSNorm", + "@version": 1, + "config": { + "channels": self.channels, + "n_focus": self.n_focus, + "eps": self.eps, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "trainable": trainable, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> ScalarRMSNorm: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "ScalarRMSNorm": + raise ValueError(f"Invalid class for ScalarRMSNorm: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj diff --git a/deepmd/pt/model/descriptor/sezm_nn/radial.py b/deepmd/pt/model/descriptor/sezm_nn/radial.py new file mode 100644 index 0000000000..4f4e4d888c --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/radial.py @@ -0,0 +1,621 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Radial building blocks for the SeZM descriptor. + +This module defines the cutoff envelope, inner-distance clamp, radial basis, +and radial multilayer perceptron used by SeZM. +""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + Any, +) + +import torch +import torch.nn as nn +from einops import ( + rearrange, +) + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.mlp import ( + MLPLayer, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .norm import ( + RMSNorm, +) +from .utils import ( + np_safe, + safe_numpy_to_tensor, +) + + +class RadialMLP(nn.Module): + """ + Radial MLP with channel RMSNorm and configurable activation. + + Parameters + ---------- + mlp_layers : list[int] + Layer sizes including input and output dimensions. + E.g., [in_dim, hidden1, hidden2, out_dim]. + activation_function : str + Activation function name (e.g., "silu", "tanh", "gelu"). + dtype : torch.dtype + Floating point dtype for the linear layers. + trainable : bool + Whether the parameters are trainable. + + Architecture + ------------ + Linear → RMSNorm → Activation for all hidden layers, + with the final layer being a plain Linear (no norm, no activation). + + Notes + ----- + All bias terms are disabled (Linear bias=False, RMSNorm bias-free) to + guarantee ``RadialMLP(0) = 0``. This is required because the compile path + pads masked edges with zero ``edge_rbf``; any non-zero bias would leak + spurious features into GIE scatter, causing energy divergence between + compile and non-compile paths. + """ + + def __init__( + self, + mlp_layers: list[int], + *, + activation_function: str = "silu", + dtype: torch.dtype = torch.float32, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + super().__init__() + if len(mlp_layers) < 2: + raise ValueError("`mlp_layers` must have at least 2 elements") + self.mlp_layers = list(mlp_layers) + self.activation_function = str(activation_function) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[self.dtype] + self.trainable = bool(trainable) + + modules: list[nn.Module] = [] + n_layers = len(mlp_layers) + for i in range(n_layers - 1): + linear = MLPLayer( + mlp_layers[i], + mlp_layers[i + 1], + bias=False, + activation_function=None, + precision=self.precision, + seed=child_seed(seed, i), + trainable=trainable, + ) + modules.append(linear) + # Last layer: no RMSNorm/activation + if i < n_layers - 2: + modules.append( + RMSNorm( + channels=mlp_layers[i + 1], + dtype=self.dtype, + trainable=trainable, + ) + ) + modules.append(ActivationFn(self.activation_function)) + + self.net = nn.Sequential(*modules) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape (..., mlp_layers[0]). + + Returns + ------- + torch.Tensor + Output tensor with shape (..., mlp_layers[-1]). + """ + return self.net(x) + + def serialize(self) -> dict[str, Any]: + """Serialize the RadialMLP to a dict.""" + state = self.net.state_dict() + return { + "@class": "RadialMLP", + "@version": 1, + "mlp_layers": self.mlp_layers.copy(), + "activation_function": self.activation_function, + "dtype": RESERVED_PRECISION_DICT[self.dtype], + "trainable": self.trainable, + "@variables": {k: np_safe(v) for k, v in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> RadialMLP: + """Deserialize a RadialMLP from a dict.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "RadialMLP": + raise ValueError(f"Invalid class for RadialMLP: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + variables = data.pop("@variables") + data["dtype"] = PRECISION_DICT[data["dtype"]] + obj = cls(**data) + state = { + k: safe_numpy_to_tensor(v, device=env.DEVICE, dtype=obj.dtype) + for k, v in variables.items() + } + obj.net.load_state_dict(state) + return obj + + +class C3CutoffEnvelope(torch.nn.Module): + """ + C^3-continuous polynomial cutoff envelope function. + + This envelope provides a smooth transition to zero at the cutoff radius, + ensuring continuity of the function value and the first three derivatives. + + Notes + ----- + The envelope function is defined for scaled distance ``x = r / rcut`` as:: + + E(x) = 1 + x^p * (a + b*x + c*x^2 + d*x^3), for x < 1 + E(x) = 0, for x >= 1 + + where the coefficients are chosen to satisfy:: + + E(0) = 1, E(1) = 0 + E'(1) = 0, E''(1) = 0, E'''(1) = 0 + + This ensures C^3 continuity at the cutoff boundary. The coefficients are:: + + a = -(p + 1)(p + 2)(p + 3) / 6 + b = p(p + 2)(p + 3) / 2 + c = -p(p + 1)(p + 3) / 2 + d = p(p + 1)(p + 2) / 6 + + For the default exponent p=5, the coefficients are a=-56, b=140, c=-120, + d=35:: + + E(x) = 1 + x^5 * (-56 + 140*x - 120*x^2 + 35*x^3) + = 1 - 56*x^5 + 140*x^6 - 120*x^7 + 35*x^8 + + Parameters + ---------- + rcut : float + Cutoff radius in Å. + exponent : int, optional + Polynomial exponent (p), must be positive. Default is 5. + + Attributes + ---------- + rcut : float + Cutoff radius in Å. + p : float + Polynomial exponent. + a : float + Quadratic coefficient for x^p term. + b : float + Linear coefficient for x^(p+1) term. + c : float + Quadratic coefficient for x^(p+2) term. + d : float + Cubic coefficient for x^(p+3) term. + """ + + def __init__( + self, + rcut: float, + exponent: int = 5, + *, + dtype: torch.dtype = torch.float32, + ) -> None: + super().__init__() + if rcut <= 0.0: + raise ValueError("`rcut` must be positive") + if exponent <= 0: + raise ValueError("`exponent` must be positive") + self.rcut = float(rcut) + self.p = int(exponent) + self.dtype = dtype + self.device = env.DEVICE + coeff_a = -((self.p + 1) * (self.p + 2) * (self.p + 3)) / 6.0 + coeff_b = (self.p * (self.p + 2) * (self.p + 3)) / 2.0 + coeff_c = -(self.p * (self.p + 1) * (self.p + 3)) / 2.0 + coeff_d = (self.p * (self.p + 1) * (self.p + 2)) / 6.0 + self.register_buffer( + "rcut_tensor", + torch.tensor(self.rcut, dtype=self.dtype, device=self.device), + persistent=False, + ) + self.register_buffer( + "coeff_a", + torch.tensor(coeff_a, dtype=self.dtype, device=self.device), + persistent=False, + ) + self.register_buffer( + "coeff_b", + torch.tensor(coeff_b, dtype=self.dtype, device=self.device), + persistent=False, + ) + self.register_buffer( + "coeff_c", + torch.tensor(coeff_c, dtype=self.dtype, device=self.device), + persistent=False, + ) + self.register_buffer( + "coeff_d", + torch.tensor(coeff_d, dtype=self.dtype, device=self.device), + persistent=False, + ) + + def forward(self, dst: torch.Tensor) -> torch.Tensor: + """Compute the envelope value for given distances.""" + d_scaled = (dst / self.rcut_tensor).clamp(min=0.0, max=1.0) + poly = self.coeff_a + d_scaled * ( + self.coeff_b + d_scaled * (self.coeff_c + d_scaled * self.coeff_d) + ) + env_val = 1 + d_scaled.pow(self.p) * poly + return env_val * ((d_scaled < 1.0).to(dst.dtype)) + + +class InnerClamp(nn.Module): + """ + C3-continuous inner distance clamping for zone bridging. + + Applies a septic Hermite polynomial transition that freezes distances + below ``r_inner`` to the constant ``r_inner``, then smoothly transitions + back to identity at ``r_outer``:: + + r̃(r) = r_inner if r <= r_inner + r̃(r) = r_inner + (r_outer - r_inner) * h(t) if r_inner < r < r_outer + r̃(r) = r if r >= r_outer + + h(t) = 20t^4 - 45t^5 + 36t^6 - 10t^7, t = (r - r_inner) / (r_outer - r_inner) + + Boundary conditions: + ``h(0)=0``, ``h(1)=1``, ``h'(0)=0``, ``h'(1)=1``, + ``h''(0)=0``, ``h''(1)=0``, ``h'''(0)=0``, ``h'''(1)=0``. + This ensures C3 continuity: ``dr̃/dr = 0`` at r_inner (frozen zone) and + ``dr̃/dr = 1`` at r_outer (identity zone), with matched second and third + derivatives at both boundaries. + + Parameters + ---------- + r_inner : float + Freeze radius in Å. Distances below this are clamped to ``r_inner``. + r_outer : float + Outer boundary of the transition zone in Å. Above this, ``r̃ = r``. + + Raises + ------ + ValueError + If ``r_inner >= r_outer`` or either is non-positive. + """ + + def __init__(self, r_inner: float, r_outer: float) -> None: + super().__init__() + if r_inner <= 0 or r_outer <= 0: + raise ValueError("r_inner and r_outer must be positive") + if r_inner >= r_outer: + raise ValueError(f"r_inner ({r_inner}) must be < r_outer ({r_outer})") + self.r_inner = float(r_inner) + self.r_outer = float(r_outer) + + def forward(self, r: torch.Tensor) -> torch.Tensor: + """ + Apply inner distance clamping. + + Parameters + ---------- + r : torch.Tensor + Pair distances with shape (...) or (..., 1) in Å. + + Returns + ------- + torch.Tensor + Clamped distances r̃ with the same shape as input. + """ + t = ((r - self.r_inner) / (self.r_outer - self.r_inner)).clamp(0.0, 1.0) + t2 = t * t + t4 = t2 * t2 + # h(t) = 20t^4 - 45t^5 + 36t^6 - 10t^7 + # Satisfies: + # h(0)=0, h(1)=1 + # h'(0)=0, h'(1)=1 + # h''(0)=0, h''(1)=0 + # h'''(0)=0, h'''(1)=0 + h = t4 * (20.0 + t * (-45.0 + t * (36.0 - 10.0 * t))) + interpolated = self.r_inner + (self.r_outer - self.r_inner) * h + # Identity zone: r >= r_outer returns r directly. + # Both branches have matching first three derivatives at r_outer, + # so torch.where preserves C3 continuity here. + return torch.where(r >= self.r_outer, r, interpolated) + + +class BridgingSwitch(nn.Module): + r""" + C3-continuous switching amplitude for the SeZM bridging zone. + + ``BridgingSwitch`` returns a per-edge scalar amplitude in ``[0, 1]`` + that measures how far an edge sits outside the frozen zone. It is + the elementary piece the Source Freeze Propagation Gate (SFPG) + aggregates into a per-node "non-frozen confidence" via a product + over each source node's outgoing edges:: + + w(r) = 0 if r <= r_inner (frozen) + w(r) = h((r - r_inner) / (r_outer - r_inner)) if r_inner < r < r_outer (transition) + w(r) = 1 if r >= r_outer (normal) + + h(t) = 35 t^4 - 84 t^5 + 70 t^6 - 20 t^7 + + Boundary conditions at ``t=0`` and ``t=1``:: + + h(0) = h'(0) = h''(0) = h'''(0) = 0 + h(1)=1, h'(1) = h''(1) = h'''(1) = 0 + + The vanishing first three derivatives at both endpoints give + ``w \in C^3(\mathbb{R}_{\ge 0})`` with zero slope/curvature at + ``r_inner`` and ``r_outer``, so forces (first derivatives) and the + force derivatives consumed by second-order training stay continuous + across both zone boundaries. + + The surrounding infrastructure (``compute_edge_src_gate``) owns the + per-node product reduction and broadcast; this module only encodes + the scalar amplitude shape. + + Parameters + ---------- + r_inner : float + Inner radius in Å. At or below this distance ``w = 0``. + r_outer : float + Outer radius in Å. At or above this distance ``w = 1``. + + Raises + ------ + ValueError + If ``r_inner <= 0``, ``r_outer <= 0``, or ``r_inner >= r_outer``. + """ + + def __init__(self, r_inner: float, r_outer: float) -> None: + super().__init__() + if r_inner <= 0 or r_outer <= 0: + raise ValueError("r_inner and r_outer must be positive") + if r_inner >= r_outer: + raise ValueError(f"r_inner ({r_inner}) must be < r_outer ({r_outer})") + self.r_inner = float(r_inner) + self.r_outer = float(r_outer) + + def forward(self, r: torch.Tensor) -> torch.Tensor: + """ + Evaluate the C3 switching amplitude. + + Parameters + ---------- + r : torch.Tensor + Pair distances with shape (...) or (..., 1) in Å. + + Returns + ------- + torch.Tensor + Switching amplitudes in ``[0, 1]`` with the same shape as input. + """ + t = ((r - self.r_inner) / (self.r_outer - self.r_inner)).clamp(0.0, 1.0) + t2 = t * t + t4 = t2 * t2 + # h(t) = 35 t^4 - 84 t^5 + 70 t^6 - 20 t^7 (Horner form). + # Degree-7 smootherstep: the unique polynomial of this degree that + # hits ``w(r_inner)=0, w(r_outer)=1`` together with C3 flatness at + # both radii. + return t4 * (35.0 + t * (-84.0 + t * (70.0 - 20.0 * t))) + + +class RadialBasis(nn.Module): + """ + Radial basis with C^3 cutoff envelope. + + The trainable radial parameters are stored in ``adam_freqs`` so HybridMuon + routes them to Adam without weight decay. + + Notes + ----- + The Bessel basis uses PyTorch's sinc function for numerical stability:: + + phi_n(r) = w_n * sinc(w_n * r / π) + + where ``torch.sinc(z) = sin(π*z) / (π*z)``. This is mathematically + equivalent to the standard form ``sin(w_n * r) / r``, but sinc handles + the r->0 limit via Taylor expansion, providing continuous gradients + without explicit epsilon clamping. + + The ``r -> 0`` limit is finite:: + + lim_{r->0} w_n * sinc(w_n * r / π) = w_n + + The initial Bessel frequencies follow a common spacing:: + + w_n = n * π / rcut, for n = 1..n_radial (in 1/Å) + + The C^3 cutoff envelope is multiplied directly into the output to ensure + strict smoothness at ``rcut``. + + Parameters + ---------- + rcut : float + Cutoff radius in Å. + n_radial : int + Number of basis functions. + basis_type : str, optional + Radial basis type. Supported values are ``"bessel"`` and ``"gaussian"``. + dtype : torch.dtype + Floating-point dtype for the radial basis frequencies and outputs. + exponent : int, optional + Exponent for the C^3 cutoff envelope polynomial. Default is 7. + """ + + def __init__( + self, + rcut: float, + basis_type: str = "bessel", + n_radial: int = 10, + dtype: torch.dtype = torch.float32, + exponent: int = 7, + ) -> None: + super().__init__() + self.rcut = float(rcut) + if self.rcut <= 0.0: + raise ValueError("`rcut` must be positive") + self.n_radial = int(n_radial) + if self.n_radial <= 0: + raise ValueError("`n_radial` must be positive") + self.basis_type = str(basis_type).lower() + if self.basis_type not in ("bessel", "gaussian"): + raise ValueError("`basis_type` must be either 'bessel' or 'gaussian'") + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[self.dtype] + self.exponent = int(exponent) + self.register_buffer( + "pi_tensor", + torch.tensor(math.pi, dtype=self.dtype, device=self.device), + persistent=False, + ) + + # Frequencies: n*π/rcut, n=1..n_radial + # Shape: (1, n_radial), stored as trainable nn.Parameter. + if self.basis_type == "bessel": + freqs = torch.arange( + 1, + self.n_radial + 1, + device=self.device, + dtype=self.dtype, + ) * (math.pi / self.rcut) + else: + freqs = torch.linspace( + 0.0, + self.rcut, + self.n_radial, + device=self.device, + dtype=self.dtype, + ) + self.adam_freqs = nn.Parameter( + rearrange(freqs, "n_radial -> 1 n_radial"), requires_grad=True + ) + gaussian_width = self.rcut / max(self.n_radial - 1, 1) + self.register_buffer( + "gaussian_coeff", + torch.tensor( + -0.5 / (gaussian_width * gaussian_width), + dtype=self.dtype, + device=self.device, + ), + persistent=False, + ) + + self.envelope = C3CutoffEnvelope( + rcut=self.rcut, + exponent=self.exponent, + dtype=self.dtype, + ) + + def forward(self, r: torch.Tensor) -> torch.Tensor: + """ + Compute radial basis functions. + + Parameters + ---------- + r : torch.Tensor + Pair distances with shape (N, 1) in Å, where N is the number of pairs. + + Returns + ------- + torch.Tensor + Radial basis multiplied by C^3 cutoff envelope with shape (N, n_rbf). + The output is smoothly truncated to zero at r = rcut. + """ + # === Step 1. Radial basis === + # Shape: (N, 1) * (1, n_radial) -> (N, n_radial) + if self.basis_type == "bessel": + # phi_n(r) = w_n * sinc(w_n * r / π) + x = r * self.adam_freqs # (N, n_rbf) + raw = self.adam_freqs * torch.sinc(x / self.pi_tensor) # (N, n_rbf) + else: + dr = r - self.adam_freqs # (N, n_rbf) + raw = torch.exp(dr * dr * self.gaussian_coeff) # (N, n_rbf) + + # === Step 2. Apply C^3 envelope for smooth cutoff === + envelope = self.envelope(r) # (N, 1) + return raw * envelope + + def serialize(self) -> dict[str, Any]: + """Serialize RadialBasis including trainable frequencies.""" + state = self.state_dict() + return { + "@class": "RadialBasis", + "@version": 1, + "config": { + "rcut": self.rcut, + "basis_type": self.basis_type, + "n_radial": self.n_radial, + "exponent": self.exponent, + "precision": RESERVED_PRECISION_DICT[self.dtype], + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> RadialBasis: + """Deserialize RadialBasis including trainable frequencies.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "RadialBasis": + raise ValueError(f"Invalid class for RadialBasis: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config", data) + variables = data.pop("@variables", None) + precision = config["precision"] + dtype = PRECISION_DICT[precision] + obj = cls( + rcut=float(config["rcut"]), + n_radial=int(config["n_radial"]), + basis_type=str(config.get("basis_type", "bessel")), + exponent=int(config.get("exponent", 7)), + dtype=dtype, + ) + if variables is not None: + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj diff --git a/deepmd/pt/model/descriptor/sezm_nn/so2.py b/deepmd/pt/model/descriptor/sezm_nn/so2.py new file mode 100644 index 0000000000..ab0bd84f05 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/so2.py @@ -0,0 +1,1664 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +SO(2)-equivariant message-passing layers for SeZM. + +This module defines the reduced-layout SO(2) linear operator and the +edge convolution used inside SeZM interaction blocks. +""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + get_generator, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .activation import ( + GatedActivation, + SwiGLUS2Activation, + resolve_s2_grid_resolution, +) +from .attention import ( + segment_envelope_gated_softmax, +) +from .attn_res import ( + DepthAttnRes, +) +from .indexing import ( + build_m_major_index, + build_m_major_l_index, + build_rotate_inv_rescale, + get_so3_dim_of_lmax, + map_degree_idx, + project_D_to_m, + project_Dt_from_m, +) +from .norm import ( + ReducedEquivariantRMSNorm, + ScalarRMSNorm, +) +from .so3 import ( + ChannelLinear, + FocusLinear, + SO3Linear, +) +from .triton import ( + resolve_triton_rotation_mode, + rotate_back_triton, + rotate_to_local_triton, + sezm_triton_enabled, +) +from .utils import ( + ATTN_RES_MODES, + get_promoted_dtype, + init_trunc_normal_fan_in_out, + np_safe, + nvtx_range, + safe_numpy_to_tensor, +) + +if TYPE_CHECKING: + from .edge_cache import ( + EdgeFeatureCache, + ) + + +class SO2Linear(nn.Module): + """ + SO(2)-equivariant linear mixing in the edge-aligned local frame. + + Coefficient layout (m-major, truncated by mmax) + ------------------------------------------------ + The coefficient axis D_m_trunc is ordered by |m| groups:: + + [ m=0: l=0..lmax | m=1: (l,-1) then (l,+1) | ... | m=mmax: ... ] + |___ lmax+1 ____| |_______ 2*(lmax) ________| + + Each |m| group is contiguous, enabling a single block-diagonal matmul. + + Block-diagonal weight structure + ------------------------------- + The full weight matrix W has shape ``(F, D_m_trunc*Cout, D_m_trunc*Cin)`` + and is block-diagonal over |m| groups:: + + W = diag[W_m0, B_m1, B_m2, ..., B_mmax] + + - ``W_m0``: unconstrained ``(num_l*Cout, num_l*Cin)`` block for m=0. + Cross-l mixing is allowed since m=0 coefficients are real scalars. + + - ``B_m`` (|m|>0): SO(2)-constrained 2x2 block coupling (-m, +m) pairs:: + + B_m = [ W_u^T , -W_v^T ] where W_u, W_v are learnable + [ W_v^T , W_u^T ] (num_l*Cin, num_l*Cout) each. + + This structure is the real-valued form of complex multiplication + ``(u + iv)(a + ib) = (ua - vb) + i(va + ub)``, which guarantees + SO(2) equivariance: rotating the input by angle phi around z + rotates the output by the same angle. + + The weight is assembled once per forward (training) or cached (eval) + by ``_build_so2_weight()``, then applied via a single batched matmul + over all focus streams: ``einsum("efi,foi->efo")``. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + mmax + Maximum SO(2) order (|m|) to mix. If None, defaults to ``lmax``. + in_channels + Number of input channels per (l, m) coefficient. + out_channels + Number of output channels per (l, m) coefficient. + n_focus + Number of independent focus streams. Each stream has its own + weight matrices; the batched matmul vectorizes over all streams. + dtype + Parameter dtype. + mlp_bias + Whether to use bias for l=0 (scalar) components. + seed + Random seed for weight initialization. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + in_channels: int, + out_channels: int, + n_focus: int = 1, + dtype: torch.dtype, + mlp_bias: bool = False, + seed: int | list[int] | None, + trainable: bool, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(self.lmax if mmax is None else mmax) + if self.mmax < 0: + raise ValueError("`mmax` must be non-negative") + if self.mmax > self.lmax: + raise ValueError("`mmax` must be <= `lmax`") + self.in_channels = int(in_channels) + self.out_channels = int(out_channels) + self.n_focus = int(n_focus) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + self.mlp_bias = bool(mlp_bias) + + # === Step 1. Build m-major coefficient layout === + # Map each |m| group to contiguous index ranges in the flattened axis. + # Example for lmax=2, mmax=2: + # m=0 : indices [0, 1, 2] (l=0,1,2) + # m=1-: indices [3, 4] (l=1,2 with -m) + # m=1+: indices [5, 6] (l=1,2 with +m) + # m=2-: index [7] (l=2 with -m) + # m=2+: index [8] (l=2 with +m) + # => reduced_dim = 9 + m0_size = self.lmax + 1 + self.register_buffer( + "m0_idx", + torch.arange(m0_size, device=self.device, dtype=torch.long), + persistent=True, + ) + + pos_indices_list: list[torch.Tensor] = [] + neg_indices_list: list[torch.Tensor] = [] + # Each entry: (neg_start, pos_start, num_l) for a fixed |m|. + # These ranges are contiguous in m-major layout. + m_ranges: list[tuple[int, int, int]] = [] + + offset = m0_size + for m in range(1, self.mmax + 1): + num_l = self.lmax - m + 1 + neg_start = offset + pos_start = offset + num_l + neg_idx = torch.arange( + neg_start, neg_start + num_l, device=self.device, dtype=torch.long + ) + pos_idx = torch.arange( + pos_start, pos_start + num_l, device=self.device, dtype=torch.long + ) + neg_indices_list.append(neg_idx) + pos_indices_list.append(pos_idx) + m_ranges.append((neg_start, pos_start, num_l)) + offset += 2 * num_l + + self.reduced_dim = int(offset) + + if len(pos_indices_list) > 0: + self.register_buffer( + "pos_indices", torch.cat(pos_indices_list), persistent=True + ) + self.register_buffer( + "neg_indices", torch.cat(neg_indices_list), persistent=True + ) + self._m_ranges = m_ranges + else: + self.register_buffer( + "pos_indices", + torch.empty(0, device=self.device, dtype=torch.long), + persistent=True, + ) + self.register_buffer( + "neg_indices", + torch.empty(0, device=self.device, dtype=torch.long), + persistent=True, + ) + self._m_ranges = [] + + # === Step 2. Learnable weight parameters === + # weight_m0: folded (num_l*Cin, F*num_l*Cout) storage — (in, out) convention. + # Runtime view: (num_l*Cin, F, num_l*Cout). + # Cross-l mixing is allowed because m=0 coefficients are real. + num_m0 = self.lmax + 1 + num_in_m0 = num_m0 * self.in_channels + num_out_m0 = num_m0 * self.out_channels + self.weight_m0 = nn.Parameter( + torch.empty( + num_in_m0, + self.n_focus * num_out_m0, + device=self.device, + dtype=self.dtype, + ) + ) + weight_m0_view = self.weight_m0.view(num_in_m0, self.n_focus, num_out_m0) + for focus_idx in range(self.n_focus): + init_trunc_normal_fan_in_out( + weight_m0_view[:, focus_idx, :], child_seed(seed, 1000 + focus_idx) + ) + if self.mlp_bias: + self.bias0: nn.Parameter | None = nn.Parameter( + torch.zeros( + self.n_focus * self.out_channels, + device=self.device, + dtype=self.dtype, + ) + ) + else: + self.bias0 = None + + # weight_m[i]: folded (num_l*Cin, F*2*num_l*Cout) storage — (in, out) convention. + # Runtime view: (num_l*Cin, F, 2*num_l*Cout). + # The factor of 2 comes from storing W_u and W_v concatenated along the + # output axis. _build_so2_weight() splits them and fills the 2x2 block. + # Scaling by 1/sqrt(2) compensates for the doubled parameter count. + self.weight_m: nn.ParameterList = nn.ParameterList() + for m in range(1, self.mmax + 1): + num_l = self.lmax - m + 1 + num_in = num_l * self.in_channels + num_out = 2 * num_l * self.out_channels + weight = nn.Parameter( + torch.empty( + num_in, + self.n_focus * num_out, + device=self.device, + dtype=self.dtype, + ) + ) + weight_view = weight.view(num_in, self.n_focus, num_out) + for focus_idx in range(self.n_focus): + init_trunc_normal_fan_in_out( + weight_view[:, focus_idx, :], + child_seed(seed, 2000 + m * 100 + focus_idx), + ) + # Apply scaling for SO(2) equivariance + weight.data.mul_(1.0 / math.sqrt(2.0)) + self.weight_m.append(weight) + + for p in self.parameters(): + p.requires_grad = trainable + + # === Step 3. Precompute flattened slice ranges for _build_so2_weight === + # Each |m|>0 group occupies two sub-blocks (neg, pos) in the flattened + # weight matrix. Pre-computing the row/col ranges avoids repeated + # arithmetic in the hot path. + # Tuple layout: (neg_i0, neg_i1, pos_i0, pos_i1, <- input row ranges + # neg_o0, neg_o1, pos_o0, pos_o1) <- output col ranges + self._m0_in = (self.lmax + 1) * self.in_channels + self._m0_out = (self.lmax + 1) * self.out_channels + self._block_slices: list[tuple[int, int, int, int, int, int, int, int]] = [] + for neg_start, pos_start, num_l in self._m_ranges: + ib = num_l * self.in_channels + ob = num_l * self.out_channels + self._block_slices.append( + ( + neg_start * self.in_channels, + neg_start * self.in_channels + ib, + pos_start * self.in_channels, + pos_start * self.in_channels + ib, + neg_start * self.out_channels, + neg_start * self.out_channels + ob, + pos_start * self.out_channels, + pos_start * self.out_channels + ob, + ) + ) + + # Weight cache: only used in eval + no_grad (inference). + # Invalidated on train() via overridden method below. + self._cached_weight: torch.Tensor | None = None + + def train(self, mode: bool = True) -> SO2Linear: + """Invalidate weight cache when switching to training mode.""" + self._cached_weight = None + return super().train(mode) + + def _apply(self, fn: Any) -> SO2Linear: + """Invalidate weight cache on device or dtype moves.""" + self._cached_weight = None + return super()._apply(fn) + + def _load_from_state_dict( + self, + state_dict: dict[str, torch.Tensor], + prefix: str, + local_metadata: dict[str, Any], + strict: bool, + missing_keys: list[str], + unexpected_keys: list[str], + error_msgs: list[str], + ) -> None: + """Invalidate weight cache before loading new weights.""" + self._cached_weight = None + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + + def _build_so2_weight(self) -> torch.Tensor: + """ + Assemble the per-focus block-diagonal SO(2) weight matrix. + + The flattened weight has shape ``(D_m*Cin, F, D_m*Cout)`` (in, out) + where both axes follow the same m-major coefficient ordering. + Off-diagonal blocks (cross-|m|) are zero, enforcing SO(2) equivariance. + + Returns + ------- + torch.Tensor + Weight with shape (D_m*Cin, F, D_m*Cout). + """ + in_total = self.reduced_dim * self.in_channels + out_total = self.reduced_dim * self.out_channels + weight = self.weight_m0.new_zeros(in_total, self.n_focus, out_total) + num_in_m0 = (self.lmax + 1) * self.in_channels + num_out_m0 = (self.lmax + 1) * self.out_channels + weight_m0 = self.weight_m0.view(num_in_m0, self.n_focus, num_out_m0) + + # m=0 block: (Cin_blk, F, Cout_blk) — (in, out) convention. + weight[: self._m0_in, :, : self._m0_out] = weight_m0 + + # |m|>0 blocks: fill the 2x2 SO(2) coupling structure. + # For each |m|, the learnable param w has shape (in_blk, F, 2*out_blk) + # which is split into W_u and W_v along the output axis. + for m_idx, w in enumerate(self.weight_m): + ni0, ni1, pi0, pi1, no0, no1, po0, po1 = self._block_slices[m_idx] + ib = ni1 - ni0 # in_block size + ob = no1 - no0 # out_block size + w = w.view(ib, self.n_focus, 2 * ob) + w_u = w[:, :, :ob] # (in_blk, F, out_blk) + w_v = w[:, :, ob:] # (in_blk, F, out_blk) + # Fill the 2x2 coupling: + # Row = input (neg/pos), Col = output (neg/pos). + # [ W_u^T, -W_v^T ]^T => row=neg_in: W_u to neg_out, W_v to pos_out + # [ W_v^T, W_u^T ]^T => row=pos_in: -W_v to neg_out, W_u to pos_out + weight[ni0:ni1, :, no0:no1] = w_u # neg_in -> neg_out + weight[ni0:ni1, :, po0:po1] = w_v # neg_in -> pos_out + weight[pi0:pi1, :, no0:no1] = -w_v # pos_in -> neg_out + weight[pi0:pi1, :, po0:po1] = w_u # pos_in -> pos_out + return weight + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input with shape (E, F, D_m_trunc, Cin), where D_m_trunc is the + coefficient dimension of the m-major layout truncated by `mmax`. + + Returns + ------- + torch.Tensor + Output with shape (E, F, D_m_trunc, Cout), where Cout is output channels. + """ + # === Step 1. Flatten coefficient + channel axes for matmul === + # (E, F, D_m, Cin) -> (E, F, D_m*Cin) + n_edge = x.shape[0] + in_dim_total = self.reduced_dim * self.in_channels + x_flat = x.reshape(n_edge, self.n_focus, in_dim_total) + + # === Step 2. Get block-diagonal weight (cached in eval+no_grad) === + if self._cached_weight is not None: + weight = self._cached_weight + else: + weight = self._build_so2_weight() + # Cache only in eval mode with grad disabled (pure inference). + if not self.training and not torch.is_grad_enabled(): + self._cached_weight = weight.detach() + + # === Step 3. Batched matmul over focus streams + reshape back === + # einsum "efi,ifo->efo": (E,F,D_m*Cin) x (D_m*Cin,F,D_m*Cout) -> (E,F,D_m*Cout) + out_flat = torch.einsum("efi,ifo->efo", x_flat, weight) + out = out_flat.reshape( + n_edge, self.n_focus, self.reduced_dim, self.out_channels + ) + + # === Step 4. Bias on l=0 scalar index === + if self.mlp_bias: + bias0 = self.bias0.view(self.n_focus, self.out_channels) + out[:, :, 0, :] = out[:, :, 0, :] + bias0.unsqueeze(0) + return out + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "SO2Linear", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "n_focus": self.n_focus, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "mlp_bias": self.mlp_bias, + "trainable": trainable, + "seed": None, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SO2Linear: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SO2Linear": + raise ValueError(f"Invalid class for SO2Linear: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj + + +class DynamicRadialDegreeMixer(nn.Module): + """ + Edge-conditioned degree mixer in the SO(2) reduced local layout. + + The mixer replaces per-degree scalar radial modulation by an edge-conditioned + degree kernel without channel output mixing: + + degree: + y[e, l_out, m, c] = sum_l_in W[e, l_in, l_out, |m|] x[e, l_in, m, c] + degree_channel: + y[e, l_out, m, c] = sum_l_in W[e, l_in, l_out, |m|, c] x[e, l_in, m, c] + + `mode="degree"` shares W across channels. `mode="degree_channel"` gives each + channel its own W, optionally with a low-rank channel factorization. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + channels: int, + mode: str, + rank: int = 0, + dtype: torch.dtype, + seed: int | list[int] | None, + trainable: bool, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(self.lmax if mmax is None else mmax) + if self.mmax < 0: + raise ValueError("`mmax` must be non-negative") + if self.mmax > self.lmax: + raise ValueError("`mmax` must be <= `lmax`") + self.channels = int(channels) + if self.channels < 1: + raise ValueError("`channels` must be positive") + self.mode = str(mode).lower() + if self.mode not in {"degree", "degree_channel"}: + raise ValueError("`mode` must be one of 'degree' or 'degree_channel'") + self.rank = int(rank) + if self.rank < 0: + raise ValueError("`rank` must be non-negative") + self.dtype = dtype + self.device = env.DEVICE + + # m-major reduced layout: m=0 block followed by (-m, +m) blocks. + self.reduced_dim = (self.lmax + 1) + sum( + 2 * (self.lmax - m + 1) for m in range(1, self.mmax + 1) + ) + self.degree_kernel_size = sum( + (self.lmax - m + 1) ** 2 for m in range(self.mmax + 1) + ) + self.input_dim = (self.lmax + 1) * self.channels + if self.mode == "degree": + self.proj_out_dim = self.degree_kernel_size + elif self.rank > 0: + self.proj_out_dim = self.degree_kernel_size * self.rank + else: + self.proj_out_dim = self.degree_kernel_size * self.channels + + self.weight = nn.Parameter( + torch.empty( + self.input_dim, + self.proj_out_dim, + device=self.device, + dtype=self.dtype, + ) + ) + init_trunc_normal_fan_in_out(self.weight, child_seed(seed, 0)) + + if self.mode == "degree_channel" and self.rank > 0: + self.channel_basis: nn.Parameter | None = nn.Parameter( + torch.empty( + self.rank, + self.channels, + device=self.device, + dtype=self.dtype, + ) + ) + init_trunc_normal_fan_in_out(self.channel_basis, child_seed(seed, 1)) + else: + self.channel_basis = None + + compact_idx, dense_idx = self._build_dense_scatter_indices() + self.register_buffer("kernel_compact_index", compact_idx, persistent=True) + self.register_buffer("kernel_dense_index", dense_idx, persistent=True) + for p in self.parameters(): + p.requires_grad = trainable + + def _build_dense_scatter_indices(self) -> tuple[torch.Tensor, torch.Tensor]: + compact_indices: list[int] = [] + dense_indices: list[int] = [] + compact_offset = 0 + reduced_dim = self.reduced_dim + + def append_block(start_in: int, start_out: int, num_l: int) -> None: + for l_in in range(num_l): + for l_out in range(num_l): + compact_indices.append(compact_offset + l_in * num_l + l_out) + # Store dense kernels in matmul layout (out, in) so forward + # can call bmm/einsum without transposing the degree matrix. + dense_indices.append( + (start_out + l_out) * reduced_dim + start_in + l_in + ) + + # m=0: single real block. + num_l0 = self.lmax + 1 + append_block(0, 0, num_l0) + compact_offset += num_l0 * num_l0 + + # |m|>0: same degree kernel is applied to the negative and positive + # signed-m blocks. No cross signed-m mixing is introduced. + offset = num_l0 + for m in range(1, self.mmax + 1): + num_l = self.lmax - m + 1 + neg_start = offset + pos_start = offset + num_l + append_block(neg_start, neg_start, num_l) + append_block(pos_start, pos_start, num_l) + compact_offset += num_l * num_l + offset += 2 * num_l + + return ( + torch.tensor(compact_indices, device=self.device, dtype=torch.long), + torch.tensor(dense_indices, device=self.device, dtype=torch.long), + ) + + def _project_radial(self, radial_feat: torch.Tensor) -> torch.Tensor: + radial_m0 = radial_feat[:, : self.lmax + 1, :].reshape( + radial_feat.shape[0], self.input_dim + ) + return torch.matmul(radial_m0, self.weight) + + def _scatter_degree_kernel(self, compact: torch.Tensor) -> torch.Tensor: + n_edge = compact.shape[0] + dense = compact.new_zeros(n_edge, self.reduced_dim * self.reduced_dim) + source = compact.index_select(1, self.kernel_compact_index) + dense.index_copy_(1, self.kernel_dense_index, source) + return dense.view(n_edge, self.reduced_dim, self.reduced_dim) + + def _scatter_rank_kernel(self, compact: torch.Tensor) -> torch.Tensor: + n_edge = compact.shape[0] + dense = compact.new_zeros( + n_edge, self.reduced_dim * self.reduced_dim, self.rank + ) + source = compact.index_select(1, self.kernel_compact_index) + dense.index_copy_(1, self.kernel_dense_index, source) + return dense.view(n_edge, self.reduced_dim, self.reduced_dim, self.rank) + + def _scatter_channel_kernel(self, compact: torch.Tensor) -> torch.Tensor: + n_edge = compact.shape[0] + dense = compact.new_zeros( + n_edge, self.reduced_dim * self.reduced_dim, self.channels + ) + source = compact.index_select(1, self.kernel_compact_index) + dense.index_copy_(1, self.kernel_dense_index, source) + return dense.view(n_edge, self.reduced_dim, self.reduced_dim, self.channels) + + def forward(self, x_local: torch.Tensor, radial_feat: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x_local + Local reduced features with shape (E, D_m, C_wide). + radial_feat + Invariant radial/type features with shape (E, D_m, C_wide). + """ + if x_local.shape != radial_feat.shape: + raise ValueError("`x_local` and `radial_feat` must have the same shape") + if x_local.shape[1] != self.reduced_dim or x_local.shape[2] != self.channels: + raise ValueError("Input shape is incompatible with this mixer") + + kernel_flat = self._project_radial(radial_feat) + if self.mode == "degree": + kernel = self._scatter_degree_kernel(kernel_flat) + return torch.bmm(kernel, x_local) + + if self.rank > 0: + compact = kernel_flat.view( + x_local.shape[0], self.degree_kernel_size, self.rank + ) + kernel = self._scatter_rank_kernel(compact) + mixed = torch.einsum("eoir,eic->eorc", kernel, x_local) + channel_basis = self.channel_basis.view(1, 1, self.rank, self.channels) + return (mixed * channel_basis).sum(dim=2) + + compact = kernel_flat.view( + x_local.shape[0], self.degree_kernel_size, self.channels + ) + kernel = self._scatter_channel_kernel(compact) + return torch.einsum("eoic,eic->eoc", kernel, x_local) + + +class SO2Convolution(nn.Module): + """ + SO(2)-equivariant edge convolution with cached geometry and rotations. + + This module consumes node features in packed SO(3) layout `(N, D, C)` and + performs edge message passing in the reduced m-major local layout. The + operation pipeline is: + + 1. `pre_focus_mix`: project node features `(N, D, C)` to the SO(2) hidden width. + 2. rotate global -> local reduced basis with cached `D_to_m`. + 3. radial modulation in reduced layout. + 4. `so2_layers` stacked local mixers: + `inter_norm -> SO2Linear -> non_linearity -> residual(+LayerScale)`. + 5. rotate local -> global with cached `Dt_from_m`. + 6. edge aggregation (plain envelope scatter or envelope-aware grouped + softmax attention with exact envelope-gated competition and + output-side head gate). + 7. `post_focus_mix`: project aggregated hidden messages back to `(N, D, C)`. + + Equivariance is preserved because both `pre_focus_mix` and `post_focus_mix` + only mix the channel axis for each `(l, m)` coefficient and never mix + coefficient indices across `(l, m)`. + + Parameters + ---------- + lmax + Maximum degree. + mmax + Maximum SO(2) order (|m|). If None, defaults to lmax. + channels + Number of channels per (l, m) coefficient. + n_focus + Number of focus streams inside the SO(2) branch. + focus_dim + Hidden width per focus stream inside SO(2). + ``focus_dim=0`` means using ``channels``. + focus_compete + If True, apply cross-focus softmax competition in SO(2) local layout. + Competition logits are constructed only from l=0 scalar channels and the + resulting invariant weights are broadcast to all (l, m) components. + so2_norm + If True, apply intermediate ReducedEquivariantRMSNorm as pre-norm before + each SO(2) mixing layer. The last SO(2) layer always uses Identity. + so2_layers + Number of SO2Linear layers per convolution (default: 1). + so2_attn_res + Depth-wise attention residual mode across the internal SO(2) layer + history. Must be one of ``"none"``, ``"independent"``, or + ``"dependent"``. The same scalar weights are broadcast to the full + reduced equivariant tensor. + layer_scale + If True, apply per-layer learnable LayerScale (per-focus-channel, + init 1e-3) on each SO(2) residual branch. + n_atten_head + Number of attention heads used during aggregation. + - 0: plain envelope-weighted scatter-sum. + - >0: envelope-gated grouped softmax attention with output-side head + gates. Attention uses ``w**2 * exp(logit)`` in the numerator and + ``zeta + sum(w**2 * exp(logit))`` in the denominator. + atten_f_mix + If True, merge the internal focus streams into one attention stream + after rotate-back. Attention heads then split the full hidden width + ``n_focus * focus_dim`` instead of each focus stream independently. + atten_v_proj + If True, apply an explicit degree-aware value projection before + attention aggregation. + atten_o_proj + If True, apply an explicit degree-aware output projection after the + output-side attention gate. + s2_activation + If True, replace each intermediate reduced-layout gate with S2-grid + SwiGLU. Intermediate ``SO2Linear`` layers then output ``2 * focus_dim`` + channels before the activation folds them back to ``focus_dim``. + lebedev_quadrature + If True, use Lebedev quadrature for the S2 projector. + activation_function + Activation function for the gated activation path when + ``s2_activation=False``. + mlp_bias + Whether to use bias in SO2Linear (l=0 bias) and GatedActivation + (gate linear bias). + use_triton + If True, opt into fused Triton SO(2) rotation kernels on supported + CUDA dtypes. The eager projection path remains the default. + radial_so2_mode + Dynamic radial degree mixer mode. ``"none"`` applies elementwise + radial modulation, ``"degree"`` applies a channel-shared dynamic + cross-degree kernel, and ``"degree_channel"`` applies a + per-channel dynamic cross-degree kernel. + radial_so2_rank + Low-rank channel factorization rank for ``radial_so2_mode="degree_channel"``. + ``0`` uses the full per-channel dynamic degree kernel. + eps + Small epsilon for normalization modules. + dtype + Parameter dtype. + seed + Random seed for weight initialization. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + channels: int, + n_focus: int = 1, + focus_dim: int = 0, + focus_compete: bool = True, + so2_norm: bool = False, + so2_layers: int = 4, + so2_attn_res: str = "none", + layer_scale: bool = False, + n_atten_head: int = 1, + atten_f_mix: bool = False, + atten_v_proj: bool = False, + atten_o_proj: bool = False, + s2_activation: bool = False, + lebedev_quadrature: bool = False, + activation_function: str = "silu", + mlp_bias: bool = False, + use_triton: bool = False, + radial_so2_mode: str = "none", + radial_so2_rank: int = 0, + eps: float = 1e-7, + dtype: torch.dtype, + seed: int | list[int] | None, + trainable: bool, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.mmax = int(self.lmax if mmax is None else mmax) + if self.mmax < 0: + raise ValueError("`mmax` must be non-negative") + if self.mmax > self.lmax: + raise ValueError("`mmax` must be <= `lmax`") + self.channels = int(channels) + self.n_focus = int(n_focus) + if self.n_focus < 1: + raise ValueError("`n_focus` must be >= 1") + self.focus_dim = int(focus_dim) + if self.focus_dim < 0: + raise ValueError("`focus_dim` must be >= 0") + self.so2_focus_dim = self.channels if self.focus_dim == 0 else self.focus_dim + self.hidden_channels = int(self.n_focus * self.so2_focus_dim) + self.use_hidden_projection = self.hidden_channels != self.channels + self.focus_compete = bool(focus_compete) + self.focus_softmax_tau = 1.0 + self.focus_label_smoothing = 0.02 + self.so2_norm = bool(so2_norm) + self.so2_layers = int(so2_layers) + if self.so2_layers < 1: + raise ValueError("`so2_layers` must be >= 1") + self.so2_attn_res_mode = str(so2_attn_res).lower() + if self.so2_attn_res_mode not in ATTN_RES_MODES: + raise ValueError( + "`so2_attn_res` must be one of 'none', 'independent', or 'dependent'" + ) + self.use_so2_attn_res = self.so2_attn_res_mode != "none" + self.layer_scale = bool(layer_scale) + self.n_atten_head = int(n_atten_head) + self.atten_f_mix = bool(atten_f_mix) + self.use_atten_v_proj = bool(atten_v_proj) + self.use_atten_o_proj = bool(atten_o_proj) + self.s2_activation = bool(s2_activation) + self.lebedev_quadrature = bool(lebedev_quadrature) + self.s2_grid_method = "lebedev" if self.lebedev_quadrature else "e3nn" + self.s2_grid_resolution = resolve_s2_grid_resolution( + self.lmax, + self.mmax, + method=self.s2_grid_method, + ) + self.activation_function = str(activation_function) + if self.n_atten_head < 0: + raise ValueError("`n_atten_head` must be non-negative") + self.attn_n_focus = ( + 1 if self.atten_f_mix and self.n_atten_head > 0 else self.n_focus + ) + self.attn_focus_dim = ( + self.hidden_channels + if self.atten_f_mix and self.n_atten_head > 0 + else self.so2_focus_dim + ) + if self.n_atten_head > 0 and self.attn_focus_dim % self.n_atten_head != 0: + raise ValueError( + "`n_atten_head` must divide the attention width " + "(`focus_dim` or `n_focus * focus_dim` when `atten_f_mix=True`)" + ) + self.head_dim = ( + None + if self.n_atten_head == 0 + else int(self.attn_focus_dim // self.n_atten_head) + ) + self.mlp_bias = bool(mlp_bias) + self.use_triton = bool(use_triton) + self.radial_so2_mode = str(radial_so2_mode).lower() + if self.radial_so2_mode not in {"none", "degree", "degree_channel"}: + raise ValueError( + "`radial_so2_mode` must be one of 'none', 'degree', or 'degree_channel'" + ) + self.radial_so2_rank = int(radial_so2_rank) + if self.radial_so2_rank < 0: + raise ValueError("`radial_so2_rank` must be non-negative") + self.eps = float(eps) + self.ebed_dim_full = get_so3_dim_of_lmax(self.lmax) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + self.compute_dtype = get_promoted_dtype(self.dtype) + self.use_triton_rotations = self.use_triton and sezm_triton_enabled( + device=self.device, + dtype=self.dtype, + ) + + # === Step 1. Precompute coefficient indices for m-major reduced layout === + coeff_index_m = build_m_major_index(self.lmax, self.mmax, device=self.device) + degree_index_m = build_m_major_l_index(self.lmax, self.mmax, device=self.device) + degree_index_full = map_degree_idx(self.lmax, device=self.device) + rotate_inv_rescale_full = build_rotate_inv_rescale( + lmax=self.lmax, + mmax=self.mmax, + degree_index=degree_index_full, + device=self.device, + dtype=self.dtype, + ) + self.register_buffer("coeff_index_m", coeff_index_m, persistent=True) + self.register_buffer("degree_index_m", degree_index_m, persistent=True) + self.register_buffer( + "rotate_inv_rescale_full", rotate_inv_rescale_full, persistent=True + ) + self.reduced_dim = int(coeff_index_m.numel()) + self.triton_rotation_mode = resolve_triton_rotation_mode( + dim_full=self.ebed_dim_full, + reduced_dim=self.reduced_dim, + ) + + # === Step 2. Split deterministic seeds at the module top-level === + seed_so2_stack = child_seed(seed, 0) + seed_non_linearities = child_seed(seed, 1) + seed_so3_pre = child_seed(seed, 2) + seed_so3_post = child_seed(seed, 3) + seed_gate = child_seed(seed, 4) + seed_depth_attn = child_seed(seed, 5) + seed_radial_hidden = child_seed(seed, 6) + seed_radial_degree = child_seed(seed, 7) + + # === Step 3. Multiple SO2Linear layers === + self.so2_linears = nn.ModuleList( + [ + SO2Linear( + lmax=self.lmax, + mmax=self.mmax, + in_channels=self.so2_focus_dim, + out_channels=( + 2 * self.so2_focus_dim + if self.s2_activation and i < self.so2_layers - 1 + else self.so2_focus_dim + ), + n_focus=self.n_focus, + dtype=self.dtype, + mlp_bias=self.mlp_bias, + seed=child_seed(seed_so2_stack, i), + trainable=trainable, + ) + for i in range(self.so2_layers) + ] + ) + + # === Step 4. Intermediate norms (Optional) === + inter_norms: list[nn.Module] = [] + if self.so2_norm: + for _ in range(max(0, self.so2_layers - 1)): + inter_norms.append( + ReducedEquivariantRMSNorm( + lmax=self.lmax, + mmax=self.mmax, + channels=self.so2_focus_dim, + degree_index_m=self.degree_index_m, + n_focus=self.n_focus, + dtype=self.compute_dtype, + trainable=trainable, + ) + ) + else: + for _ in range(max(0, self.so2_layers - 1)): + inter_norms.append(nn.Identity()) + inter_norms.append(nn.Identity()) + self.so2_inter_norms = nn.ModuleList(inter_norms) + + # === Step 5. Intermediate non-linearity === + non_linearities: list[nn.Module] = [] + for i in range(max(0, self.so2_layers - 1)): + if self.s2_activation: + non_linearities.append( + SwiGLUS2Activation( + lmax=self.lmax, + mmax=self.mmax, + channels=self.so2_focus_dim, + dtype=self.compute_dtype, + n_focus=self.n_focus, + layout="nfdc", + grid_resolution_list=self.s2_grid_resolution, + coefficient_layout="m_major", + grid_method=self.s2_grid_method, + mlp_bias=self.mlp_bias, + trainable=trainable, + seed=child_seed(seed_non_linearities, i), + ) + ) + else: + non_linearities.append( + GatedActivation( + lmax=self.lmax, + mmax=self.mmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + dtype=self.compute_dtype, + activation_function=self.activation_function, + mlp_bias=self.mlp_bias, + layout="nfdc", + trainable=trainable, + seed=child_seed(seed_non_linearities, i), + ) + ) + non_linearities.append(nn.Identity()) + self.non_linearities = nn.ModuleList(non_linearities) + + # === Step 5.5. Optional depth-wise attention residuals across SO(2) layers === + if self.use_so2_attn_res: + self.so2_layer_attn_res: nn.ModuleList | None = nn.ModuleList( + [ + DepthAttnRes( + channels=self.hidden_channels, + input_dependent=self.so2_attn_res_mode == "dependent", + eps=self.eps, + bias=self.mlp_bias, + dtype=self.compute_dtype, + trainable=trainable, + seed=child_seed(seed_depth_attn, i), + ) + for i in range(self.so2_layers) + ] + ) + else: + self.so2_layer_attn_res = None + + # === Step 6. Optional per-layer LayerScale for SO(2) residual branches === + if self.layer_scale: + self.adam_so2_layer_scales = nn.ParameterList( + [ + nn.Parameter( + torch.ones( + self.n_focus, + self.so2_focus_dim, + dtype=self.dtype, + device=self.device, + ) + * 1e-3, + requires_grad=trainable, + ) + for _ in range(self.so2_layers) + ] + ) + else: + self.adam_so2_layer_scales = None + + # === Step 7. Optional attention projections (n_atten_head > 0) === + self.attn_qk_norm: ScalarRMSNorm | None = None + self.attn_q_proj: FocusLinear | None = None + self.attn_k_proj: FocusLinear | None = None + self.attn_focus_mix: SO3Linear | None = None + self.attn_v_proj: SO3Linear | None = None + self.attn_o_proj: SO3Linear | None = None + self.adamw_attn_logit_w: nn.Parameter | None = None + self.adamw_attn_z_bias_raw: nn.Parameter | None = None + self.attn_output_gate_norm: ScalarRMSNorm | None = None + self.adamw_attn_gate_w: nn.Parameter | None = None + if self.n_atten_head > 0: + self.attn_qk_norm = ScalarRMSNorm( + channels=self.attn_focus_dim, + n_focus=self.attn_n_focus, + eps=self.eps, + dtype=self.compute_dtype, + trainable=trainable, + ) + self.attn_q_proj = FocusLinear( + in_channels=self.attn_focus_dim, + out_channels=self.attn_focus_dim, + n_focus=self.attn_n_focus, + dtype=self.compute_dtype, + bias=False, + seed=child_seed(seed_gate, 0), + trainable=trainable, + ) + self.attn_k_proj = FocusLinear( + in_channels=self.attn_focus_dim, + out_channels=self.attn_focus_dim, + n_focus=self.attn_n_focus, + dtype=self.compute_dtype, + bias=False, + seed=child_seed(seed_gate, 1), + trainable=trainable, + ) + if self.atten_f_mix: + self.attn_focus_mix = SO3Linear( + lmax=self.lmax, + in_channels=self.hidden_channels, + out_channels=self.hidden_channels, + n_focus=1, + dtype=self.compute_dtype, + mlp_bias=False, + seed=child_seed(seed_gate, 19), + trainable=trainable, + ) + if self.use_atten_v_proj: + self.attn_v_proj = SO3Linear( + lmax=self.lmax, + in_channels=self.attn_focus_dim, + out_channels=self.attn_focus_dim, + n_focus=self.attn_n_focus, + dtype=self.compute_dtype, + mlp_bias=False, + seed=child_seed(seed_gate, 20), + trainable=trainable, + ) + if self.use_atten_o_proj: + self.attn_o_proj = SO3Linear( + lmax=self.lmax, + in_channels=self.attn_focus_dim, + out_channels=self.attn_focus_dim, + n_focus=self.attn_n_focus, + dtype=self.compute_dtype, + mlp_bias=False, + seed=child_seed(seed_gate, 21), + trainable=trainable, + ) + self.adamw_attn_logit_w = nn.Parameter( + torch.empty( + self.attn_focus_dim, + self.attn_n_focus, + self.n_atten_head, + dtype=self.compute_dtype, + device=self.device, + ), + requires_grad=trainable, + ) + nn.init.normal_( + self.adamw_attn_logit_w, + mean=0.0, + std=0.01, + generator=get_generator(child_seed(seed_gate, 2)), + ) + # softplus(0.5413) ~= 1.0 provides balanced initial competition. + self.adamw_attn_z_bias_raw = nn.Parameter( + torch.full( + (self.attn_n_focus, self.n_atten_head), + 0.5413, + dtype=self.compute_dtype, + device=self.device, + ), + requires_grad=trainable, + ) + self.attn_output_gate_norm = ScalarRMSNorm( + channels=self.attn_focus_dim, + n_focus=self.attn_n_focus, + eps=self.eps, + dtype=self.compute_dtype, + trainable=trainable, + ) + self.adamw_attn_gate_w = nn.Parameter( + torch.empty( + self.attn_focus_dim, + self.attn_n_focus, + self.n_atten_head, + dtype=self.compute_dtype, + device=self.device, + ), + requires_grad=trainable, + ) + nn.init.normal_( + self.adamw_attn_gate_w, + mean=0.0, + std=0.01, + generator=get_generator(child_seed(seed_gate, 3)), + ) + + # === Step 7.5. Optional cross-focus competition === + self.focus_compete_norm: ScalarRMSNorm | None = None + self.adamw_focus_compete_w: nn.Parameter | None = None + self.focus_compete_bias: nn.Parameter | None = None + if self.focus_compete and self.n_focus > 1: + self.focus_compete_norm = ScalarRMSNorm( + channels=self.so2_focus_dim, + n_focus=self.n_focus, + eps=self.eps, + dtype=self.compute_dtype, + trainable=trainable, + ) + self.adamw_focus_compete_w = nn.Parameter( + torch.empty( + self.so2_focus_dim, + self.n_focus, + dtype=self.compute_dtype, + device=self.device, + ), + requires_grad=trainable, + ) + nn.init.normal_( + self.adamw_focus_compete_w, + mean=0.0, + std=0.01, + generator=get_generator(child_seed(seed_gate, 4)), + ) + if self.mlp_bias: + self.focus_compete_bias = nn.Parameter( + torch.zeros( + self.n_focus, + dtype=self.compute_dtype, + device=self.device, + ), + requires_grad=trainable, + ) + + # === Step 8. Optional radial hidden projection === + self.radial_hidden_proj: ChannelLinear | None = None + if self.use_hidden_projection: + self.radial_hidden_proj = ChannelLinear( + in_channels=self.channels, + out_channels=self.hidden_channels, + dtype=self.dtype, + bias=False, + seed=seed_radial_hidden, + trainable=trainable, + ) + self.radial_degree_mixer: DynamicRadialDegreeMixer | None = None + if self.radial_so2_mode != "none": + self.radial_degree_mixer = DynamicRadialDegreeMixer( + lmax=self.lmax, + mmax=self.mmax, + channels=self.hidden_channels, + mode=self.radial_so2_mode, + rank=self.radial_so2_rank, + dtype=self.dtype, + seed=seed_radial_degree, + trainable=trainable, + ) + + # === Step 9. Pre-focus channel mixing === + # This projects the full channel width before the SO(2) focus split. + self.pre_focus_mix = SO3Linear( + lmax=self.lmax, + in_channels=self.channels, + out_channels=self.hidden_channels, + n_focus=1, + dtype=dtype, + mlp_bias=self.mlp_bias, + trainable=trainable, + seed=seed_so3_pre, + ) + + # === Step 10. Post-focus channel mixing === + self.post_focus_mix = SO3Linear( + lmax=self.lmax, + in_channels=self.hidden_channels, + out_channels=self.channels, + n_focus=1, + dtype=dtype, + mlp_bias=self.mlp_bias, + trainable=trainable, + seed=seed_so3_post, + init_std=0.0, + ) + + def forward( + self, + x: torch.Tensor, + edge_cache: EdgeFeatureCache, + radial_feat: torch.Tensor, + ) -> torch.Tensor: + """ + Parameters + ---------- + x + Node features with shape (N, D, C), where D=(lmax+1)^2 is the + SO(3) coefficient dimension. + edge_cache + Precomputed edge cache. Must be compatible with this block's lmax. + radial_feat + Per-edge radial features with shape (E, lmax+1, C), already fused + with edge type features. + + Returns + ------- + torch.Tensor + Message updates with shape (N, D, C). + """ + src, dst = edge_cache.src, edge_cache.dst + n_node = x.shape[0] + n_edge = src.numel() + + # === Step 1. Pre-focus channel mixing on full width === + with nvtx_range("SO2Conv/pre_focus_mix"): + # (N, D, C_wide), C_wide = F * Cf + x = self.pre_focus_mix(x.unsqueeze(2)).squeeze(2) + + # === Step 2. Rotate to edge-aligned local frame === + with nvtx_range("SO2Conv/rotate_to_local"): + D_full = edge_cache.D_full + if self.use_triton_rotations and not self.training: + x_local = rotate_to_local_triton( + x=x, + src=src, + wigner=D_full, + coeff_index=self.coeff_index_m, + dim_full=self.ebed_dim_full, + rotation_mode=self.triton_rotation_mode, + ) # (E, D_m, C_wide) + else: + D_m_prime = project_D_to_m( + D_full=D_full, + coeff_index_m=self.coeff_index_m, + ebed_dim_full=self.ebed_dim_full, + cache=edge_cache.D_to_m_cache, + key_lmax=self.lmax, + key_mmax=self.mmax, + ) + x_src = x.index_select(0, src) # (E, D, C_wide) + x_local = torch.bmm(D_m_prime, x_src) # (E, D_m, C_wide) + + # === Step 3. Select radial/type features for reduced layout === + with nvtx_range("SO2Conv/radial_fuse"): + rad_feat = radial_feat[:, self.degree_index_m, :] # (E, D_m, C) + if self.radial_hidden_proj is not None: + rad_feat = self.radial_hidden_proj(rad_feat) + if self.radial_degree_mixer is None: + x_local.mul_(rad_feat) + else: + x_local = self.radial_degree_mixer(x_local, rad_feat) + rad_feat_l0_focus = rad_feat[:, 0, :].reshape( + n_edge, self.n_focus, self.so2_focus_dim + ) # (E, F, Cf) + + # === Step 4. Convert to SO(2) internal focus layout === + focus_gate_src: torch.Tensor | None = None + with nvtx_range("SO2Conv/reshape_for_so2"): + x_local = x_local.reshape( + n_edge, self.reduced_dim, self.n_focus, self.so2_focus_dim + ).transpose(1, 2) # (E, F, D_m, Cf), strided + if self.focus_compete and self.n_focus > 1: + focus_gate_src = x_local[:, :, 0, :] + + # === Step 5. Multi-layer SO(2) mixing (pre-norm + residual) === + with nvtx_range("SO2Conv/so2_layers"): + + def so2_l0_extractor(v: torch.Tensor) -> torch.Tensor: + """Extract scalar features from SO(2) reduced layout.""" + return v[:, :, 0, :].reshape(v.shape[0], self.hidden_channels) + + def apply_bias_correction( + x_local: torch.Tensor, + so2_linear: SO2Linear, + layer_idx: int, + ) -> None: + if layer_idx != 0 or so2_linear.bias0 is None: + return + bias0 = so2_linear.bias0.view( + self.n_focus, so2_linear.out_channels + ).unsqueeze(0) + if so2_linear.out_channels == self.so2_focus_dim: + radial_factor = rad_feat_l0_focus + elif so2_linear.out_channels == 2 * self.so2_focus_dim: + radial_factor = torch.cat( + [rad_feat_l0_focus, rad_feat_l0_focus], dim=-1 + ) + else: + raise RuntimeError( + "Unexpected SO2Linear output width in bias correction" + ) + bias_correction = bias0 * ( + radial_factor * edge_cache.edge_env.reshape(-1, 1, 1) - 1.0 + ) + x_local[:, :, 0, :].add_(bias_correction) + + if self.use_so2_attn_res: + so2_depth_sources = [x_local] + for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( + zip( + self.so2_linears, + self.so2_inter_norms, + self.non_linearities, + strict=True, + ) + ): + x_local: torch.Tensor = self.so2_layer_attn_res[layer_idx]( + sources=so2_depth_sources, + scalar_extractor=so2_l0_extractor, + current_x=x_local, + ) + residual = x_local + x_local = inter_norm(x_local) + x_local = so2_linear(x_local) + apply_bias_correction(x_local, so2_linear, layer_idx) + + x_local = non_linear(x_local) + + if self.layer_scale: + scale: torch.Tensor = self.adam_so2_layer_scales[ + layer_idx + ].reshape(1, self.n_focus, 1, self.so2_focus_dim) + x_local = residual + scale * x_local + else: + x_local = residual + x_local + so2_depth_sources.append(x_local - residual) + else: + for layer_idx, (so2_linear, inter_norm, non_linear) in enumerate( + zip( + self.so2_linears, + self.so2_inter_norms, + self.non_linearities, + strict=True, + ) + ): + residual = x_local + x_local = inter_norm(x_local) + x_local = so2_linear(x_local) + apply_bias_correction(x_local, so2_linear, layer_idx) + + x_local = non_linear(x_local) + + if self.layer_scale: + scale = self.adam_so2_layer_scales[layer_idx].reshape( + 1, self.n_focus, 1, self.so2_focus_dim + ) + x_local = residual + scale * x_local + else: + x_local = residual + x_local + + # === Step 6. Cross-focus softmax competition === + if self.focus_compete and self.n_focus > 1: + focus_gate_src = focus_gate_src.to(dtype=self.compute_dtype) + focus_logits = torch.einsum( + "efi,if->ef", + self.focus_compete_norm(focus_gate_src), + self.adamw_focus_compete_w, + ) + if self.mlp_bias: + focus_logits = focus_logits + self.focus_compete_bias.unsqueeze(0) + alpha = torch.softmax(focus_logits / self.focus_softmax_tau, dim=1).to( + dtype=x_local.dtype + ) + alpha = alpha * (1.0 - self.focus_label_smoothing) + ( + self.focus_label_smoothing / float(self.n_focus) + ) + x_local = x_local * alpha.unsqueeze(-1).unsqueeze(-1) + + # Restore reduced global layout for inverse rotation + x_local = x_local.transpose(1, 2).contiguous() # (E, D_m, F, Cf) + x_local = x_local.reshape( + n_edge, self.reduced_dim, self.hidden_channels + ) # (E, D_m, C_wide) + + # === Step 7. Rotate back to global frame === + with nvtx_range("SO2Conv/rotate_back"): + Dt_full = edge_cache.Dt_full + if self.use_triton_rotations and not self.training: + x_message = rotate_back_triton( + x_local=x_local, + wigner=Dt_full, + coeff_index=self.coeff_index_m, + dim_full=self.ebed_dim_full, + rotation_mode=self.triton_rotation_mode, + ) # (E, D, C_wide) + else: + Dt_from_m = project_Dt_from_m( + Dt_full=Dt_full, + coeff_index_m=self.coeff_index_m, + ebed_dim_full=self.ebed_dim_full, + cache=edge_cache.Dt_from_m_cache, + key_lmax=self.lmax, + key_mmax=self.mmax, + ) + x_message = torch.bmm(Dt_from_m, x_local) # (E, D, C_wide) + # Reduced layouts keep only 2*mmax+1 orders for l>mmax. Applying the + # inverse-rotation degree rescale after the global lift restores the + # full-basis amplitude expected by the block output contract. + x_message = x_message * self.rotate_inv_rescale_full.view(1, -1, 1) + if self.attn_focus_mix is not None: + x_message = self.attn_focus_mix(x_message.unsqueeze(2)).squeeze(2) + + # === Step 8. Aggregate with optional head-wise gating === + with nvtx_range("SO2Conv/aggregate"): + # Source Freeze Propagation Gate: broadcast the per-edge scalar + # eta[src] to the edge message before destination aggregation. + # ``edge_src_gate`` is ``None`` outside bridging mode, in which + # case this branch disappears and the baseline / attention paths + # run unchanged. + edge_src_gate = edge_cache.edge_src_gate + if self.n_atten_head == 0: + # Baseline path: fused envelope-weighted scatter add -> degree norm. + # Folding edge_src_gate into the scalar envelope keeps the + # op count unchanged. + edge_weight = edge_cache.edge_env # (E, 1) + if edge_src_gate is not None: + edge_weight = edge_weight * edge_src_gate.to( + dtype=edge_weight.dtype + ) + x_message = x_message * edge_weight.unsqueeze(-1) + out = x.new_zeros(x.shape, dtype=self.compute_dtype) + out.index_add_(0, dst, x_message.to(dtype=self.compute_dtype)) + out.mul_(edge_cache.inv_sqrt_deg.to(dtype=self.compute_dtype)) + out = out.to(dtype=self.dtype) # (N, D, C_wide) + else: + # === Step 8.1. Build attention logits from scalar channels === + compute_dtype = self.compute_dtype + x_l0_node = x[:, 0, :].reshape( + n_node, self.attn_n_focus, self.attn_focus_dim + ) # (N, Fa, Ca) + qk_input = self.attn_qk_norm(x_l0_node.to(dtype=compute_dtype)) + q_node = self.attn_q_proj(qk_input) # (N, Fa, Ca) + k_node = self.attn_k_proj(qk_input) # (N, Fa, Ca) + q_edge = q_node.index_select(0, dst).reshape( + n_edge, self.attn_n_focus, self.n_atten_head, self.head_dim + ) # (E, Fa, H, Ch), Ca = H * Ch + k_edge = k_node.index_select(0, src).reshape( + n_edge, self.attn_n_focus, self.n_atten_head, self.head_dim + ) # (E, Fa, H, Ch) + radial_l0 = rad_feat[:, 0, :].reshape( + n_edge, self.attn_n_focus, self.attn_focus_dim + ) # (E, Fa, Ca) + radial_bias = torch.einsum( + "efi,ifo->efo", + radial_l0.to(dtype=compute_dtype), + self.adamw_attn_logit_w, + ) # (E, F, H) + attn_logits: torch.Tensor = (q_edge * k_edge).sum(-1) * ( + self.head_dim**-0.5 + ) + attn_logits = attn_logits + radial_bias + + # === Step 8.2. Destination-wise stable envelope-gated softmax === + # ``src_weight=edge_src_gate`` folds SFPG into both the + # numerator and the denominator of the softmax. A muted + # source (``eta_src = 0``) therefore drops out of the + # destination's attention normalization entirely, which + # is required for the attention path to honor the + # frozen-zone invariance: a post-multiplication on + # ``attn_alpha`` alone would still leave the muted + # source leaking through the shared denominator. + attn_alpha = segment_envelope_gated_softmax( + logits=attn_logits, + edge_env=edge_cache.edge_env.to(dtype=compute_dtype), + dst=dst, + n_nodes=n_node, + z_bias_raw=self.adamw_attn_z_bias_raw, + eps=self.eps, + src_weight=( + None + if edge_src_gate is None + else edge_src_gate.to(dtype=compute_dtype) + ), + ) # (E, F, H) + + # === Step 8.3. Value projection and head-wise aggregation === + value_focus = x_message.reshape( + n_edge, + self.ebed_dim_full, + self.attn_n_focus, + self.attn_focus_dim, + ).to(dtype=compute_dtype) # (E, D, Fa, Ca) + if self.attn_v_proj is not None: + value_focus = self.attn_v_proj(value_focus) + value_heads = value_focus.reshape( + n_edge, + self.ebed_dim_full, + self.attn_n_focus, + self.n_atten_head, + self.head_dim, + ) # (E, D, Fa, H, Ch) + weighted_value = value_heads * attn_alpha.reshape( + n_edge, 1, self.attn_n_focus, self.n_atten_head, 1 + ) + out_heads = torch.zeros( + n_node, + self.ebed_dim_full, + self.attn_n_focus, + self.n_atten_head, + self.head_dim, + device=x.device, + dtype=compute_dtype, + ) # (N, D, Fa, H, Ch) + out_heads.index_add_(0, dst, weighted_value) + + # === Step 8.4. Output-side head gate === + attn_output_gate = torch.sigmoid( + torch.einsum( + "nfi,ifo->nfo", + self.attn_output_gate_norm(x_l0_node.to(dtype=compute_dtype)), + self.adamw_attn_gate_w, + ) + ) # (N, F, H) + out_heads = out_heads * attn_output_gate.reshape( + n_node, 1, self.attn_n_focus, self.n_atten_head, 1 + ) # (N, D, Fa, H, Ch) + + # === Step 8.5. Output projection and merge heads === + out_focus = out_heads.reshape( + n_node, + self.ebed_dim_full, + self.attn_n_focus, + self.attn_focus_dim, + ) # (N, D, Fa, Ca) + if self.attn_o_proj is not None: + out_focus = self.attn_o_proj(out_focus) + out = out_focus.reshape( + n_node, self.ebed_dim_full, self.hidden_channels + ).to(dtype=self.dtype) # (N, D, C_wide) + + # === Step 9. Final channel mixing === + with nvtx_range("SO2Conv/post_focus_mix"): + out = self.post_focus_mix(out.unsqueeze(2)).squeeze(2) + return out # (N, D, C) + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "SO2Convolution", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "channels": self.channels, + "n_focus": self.n_focus, + "focus_dim": self.focus_dim, + "focus_compete": self.focus_compete, + "so2_norm": self.so2_norm, + "so2_layers": self.so2_layers, + "so2_attn_res": self.so2_attn_res_mode, + "layer_scale": self.layer_scale, + "n_atten_head": self.n_atten_head, + "atten_f_mix": self.atten_f_mix, + "atten_v_proj": self.use_atten_v_proj, + "atten_o_proj": self.use_atten_o_proj, + "s2_activation": self.s2_activation, + "lebedev_quadrature": self.lebedev_quadrature, + "activation_function": self.activation_function, + "mlp_bias": self.mlp_bias, + "radial_so2_mode": self.radial_so2_mode, + "radial_so2_rank": self.radial_so2_rank, + "eps": self.eps, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "trainable": trainable, + "seed": None, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SO2Convolution: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SO2Convolution": + raise ValueError(f"Invalid class for SO2Convolution: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj diff --git a/deepmd/pt/model/descriptor/sezm_nn/so3.py b/deepmd/pt/model/descriptor/sezm_nn/so3.py new file mode 100644 index 0000000000..914d516018 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/so3.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +SO(3)-equivariant linear layers for SeZM. + +This module defines the channel-only and focus-aware linear maps used by SeZM +SO(3) feature transformations. +""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + Any, +) + +import torch +import torch.nn as nn + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, + RESERVED_PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + get_generator, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .indexing import ( + get_so3_dim_of_lmax, + map_degree_idx, +) +from .utils import ( + init_trunc_normal_fan_in_out, + np_safe, + safe_numpy_to_tensor, +) + + +class FocusLinear(nn.Module): + """ + Per-focus linear projection on the last feature axis. + + Notes + ----- + Parameters are stored in (in, out) convention to match Muon's rectangular + correction assumption (rows=fan_in, cols=fan_out): + - weight: (in_channels, n_focus * out_channels) + - bias: (n_focus * out_channels,) + + Parameters + ---------- + in_channels + Input feature dimension. + out_channels + Output feature dimension. + n_focus + Number of focus streams. + dtype + Parameter dtype. + bias + Whether to use bias. + trainable + Whether parameters are trainable. + seed + Random seed for initialization. + init_std + If given, use normal(0, init_std) instead of default uniform init. + Useful for gate projections where small initial logits are desired. + """ + + def __init__( + self, + *, + in_channels: int, + out_channels: int, + n_focus: int, + dtype: torch.dtype, + bias: bool = True, + trainable: bool, + seed: int | list[int] | None = None, + init_std: float | None = None, + ) -> None: + super().__init__() + self.in_channels = int(in_channels) + self.out_channels = int(out_channels) + self.n_focus = int(n_focus) + self.dtype = dtype + self.device = env.DEVICE + self.use_bias = bool(bias) + self.weight = nn.Parameter( + torch.empty( + self.in_channels, + self.n_focus * self.out_channels, + device=self.device, + dtype=self.dtype, + ) + ) + gen = get_generator(seed) + if init_std is not None: + nn.init.normal_(self.weight, mean=0.0, std=init_std, generator=gen) + else: + bound = 1.0 / math.sqrt(self.in_channels) + nn.init.uniform_(self.weight, -bound, bound, generator=gen) + if self.use_bias: + self.bias: nn.Parameter | None = nn.Parameter( + torch.zeros( + self.n_focus * self.out_channels, + device=self.device, + dtype=self.dtype, + ) + ) + else: + self.bias = None + for p in self.parameters(): + p.requires_grad = trainable + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input tensor with shape (B, F, Cin). + + Returns + ------- + torch.Tensor + Projected tensor with shape (B, F, Cout). + """ + weight = self.weight.view(self.in_channels, self.n_focus, self.out_channels) + out = torch.einsum("bfi,ifo->bfo", x, weight) + if self.use_bias: + bias = self.bias.view(self.n_focus, self.out_channels) + out = out + bias.unsqueeze(0) + return out + + +class ChannelLinear(nn.Module): + """ + Channel-only linear projection on the last feature axis. + + Notes + ----- + Parameters are stored in (in, out) convention to match Muon's rectangular + correction assumption (rows=fan_in, cols=fan_out): + - weight: (in_channels, out_channels) + - bias: (out_channels,) + + Parameters + ---------- + in_channels + Input feature dimension. + out_channels + Output feature dimension. + dtype + Parameter dtype. + bias + Whether to use bias. + trainable + Whether parameters are trainable. + seed + Random seed for initialization. + init_std + If given, use normal(0, init_std) instead of default uniform init. + Useful for gate projections where small initial logits are desired. + """ + + def __init__( + self, + *, + in_channels: int, + out_channels: int, + dtype: torch.dtype, + bias: bool = True, + trainable: bool, + seed: int | list[int] | None = None, + init_std: float | None = None, + ) -> None: + super().__init__() + self.in_channels = int(in_channels) + self.out_channels = int(out_channels) + self.dtype = dtype + self.device = env.DEVICE + self.use_bias = bool(bias) + self.weight = nn.Parameter( + torch.empty( + self.in_channels, + self.out_channels, + device=self.device, + dtype=self.dtype, + ) + ) + gen = get_generator(seed) + if init_std is not None: + nn.init.normal_(self.weight, mean=0.0, std=init_std, generator=gen) + else: + bound = 1.0 / math.sqrt(self.in_channels) + nn.init.uniform_(self.weight, -bound, bound, generator=gen) + if self.use_bias: + self.bias: nn.Parameter | None = nn.Parameter( + torch.zeros( + self.out_channels, + device=self.device, + dtype=self.dtype, + ) + ) + else: + self.bias = None + for p in self.parameters(): + p.requires_grad = trainable + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input tensor with shape ``(..., C_in)``. + + Returns + ------- + torch.Tensor + Projected tensor with shape ``(..., C_out)``. + """ + out = torch.einsum("...i,io->...o", x, self.weight) + if self.use_bias: + out = out + self.bias + return out + + +class SO3Linear(nn.Module): + """ + Focus-aware degree-wise linear self-interaction. + + This vectorized implementation avoids Python loops by using ``torch.einsum`` + and ``index_select``. The key insight is that weights are shared across all + ``m`` components within each ``l`` block. + + Notes + ----- + - Weight storage: ``(lmax+1, C_in, F*C_out)``. + - Bias storage: ``(F*C_out,)``, only applied to ``l=0`` scalar components. + - Runtime view restores weights to ``(lmax+1, C_in, F, C_out)`` via reshape. + - ``expand_index`` maps each packed ``(l,m)`` position to its ``l`` value. + - Einsum ``ndfi,difo->ndfo`` keeps the whole multi-focus path vectorized. + - In HybridMuon slice mode, each ``(C_in, F*C_out)`` slice gets independent + NS update with stable rectangular scaling. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + in_channels + Number of input channels per (l, m) coefficient. + out_channels + Number of output channels per (l, m) coefficient. + n_focus + Number of focus streams. + dtype + Parameter dtype. + mlp_bias + Whether to use bias for l=0 (scalar) components. + trainable + Whether parameters are trainable. + seed + Random seed for weight initialization. + init_std + If given, use normal(0, init_std) for all weights instead of default + trunc-normal fan-in/fan-out init. Use 0.0 for zero initialization. + """ + + def __init__( + self, + *, + lmax: int, + in_channels: int, + out_channels: int, + n_focus: int = 1, + dtype: torch.dtype, + mlp_bias: bool = False, + trainable: bool, + seed: int | list[int] | None = None, + init_std: float | None = None, + ) -> None: + super().__init__() + self.lmax = int(lmax) + self.in_channels = int(in_channels) + self.out_channels = int(out_channels) + self.n_focus = int(n_focus) + self.dtype = dtype + self.device = env.DEVICE + self.precision = RESERVED_PRECISION_DICT[dtype] + self.ebed_dim = get_so3_dim_of_lmax(self.lmax) + self.mlp_bias = bool(mlp_bias) + + # === Step 1. Per-l weight matrix with focus folded on output axis === + # Storage: (lmax+1, C_in, F*C_out); runtime view: (lmax+1, C_in, F, C_out). + num_l = self.lmax + 1 + self.weight = nn.Parameter( + torch.empty( + num_l, + self.in_channels, + self.n_focus * self.out_channels, + dtype=self.dtype, + device=self.device, + ) + ) + if init_std is not None: + if init_std == 0.0: + nn.init.zeros_(self.weight) + else: + nn.init.normal_( + self.weight, + mean=0.0, + std=init_std, + generator=get_generator(seed), + ) + else: + for l_idx in range(num_l): + init_trunc_normal_fan_in_out( + self.weight[l_idx], + child_seed(seed, 1000 + l_idx), + ) + + # === Step 2. Bias only for l=0 (scalar components) === + if self.mlp_bias: + self.bias: nn.Parameter | None = nn.Parameter( + torch.zeros( + self.n_focus * self.out_channels, + dtype=self.dtype, + device=self.device, + ) + ) + else: + self.bias = None + + # === Step 3. Precompute expand_index for weight lookup === + self.register_buffer( + "expand_index", + map_degree_idx(self.lmax, device=self.device), + persistent=True, + ) + + for p in self.parameters(): + p.requires_grad = trainable + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ---------- + x + Input features with shape (N, D, F, C_in) where D=(lmax+1)^2. + + Returns + ------- + torch.Tensor + Order-wise mixed features with shape (N, D, F, C_out). + """ + # === Step 1. Expand per-l weights to packed coefficient layout === + # (L, Cin, F*Cout) -> (L, Cin, F, Cout) + weight = self.weight.view( + self.lmax + 1, + self.in_channels, + self.n_focus, + self.out_channels, + ) # (L, Cin, F, Cout) + # (L, Cin, F, Cout) -> (D, Cin, F, Cout) + weight_expanded = torch.index_select( + weight, dim=0, index=self.expand_index + ) # (D, Cin, F, Cout) + + # === Step 2. Per-focus, per-degree channel mixing === + out = torch.einsum("ndfi,difo->ndfo", x, weight_expanded) + + # === Step 3. Add l=0 bias === + if self.mlp_bias: + bias = self.bias.view(self.n_focus, self.out_channels) + out[:, 0, :, :] = out[:, 0, :, :] + bias.unsqueeze(0) + + return out + + def serialize(self) -> dict[str, Any]: + trainable = all(p.requires_grad for p in self.parameters()) + state = self.state_dict() + return { + "@class": "SO3Linear", + "@version": 1, + "config": { + "lmax": self.lmax, + "in_channels": self.in_channels, + "out_channels": self.out_channels, + "n_focus": self.n_focus, + "precision": RESERVED_PRECISION_DICT[self.dtype], + "mlp_bias": self.mlp_bias, + "trainable": trainable, + "seed": None, + }, + "@variables": {key: np_safe(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SO3Linear: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SO3Linear": + raise ValueError(f"Invalid class for SO3Linear: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + precision = config.pop("precision") + config["dtype"] = PRECISION_DICT[precision] + obj = cls(**config) + template = obj.state_dict() + state = { + key: safe_numpy_to_tensor( + value, device=template[key].device, dtype=template[key].dtype + ) + for key, value in variables.items() + } + obj.load_state_dict(state) + return obj diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py new file mode 100644 index 0000000000..5c80f7824d --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/__init__.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Public Triton entry points for SeZM SO(2) rotations.""" + +from .autograd import ( + edge_geometry_rbf_triton, + rotate_back_triton, + rotate_to_local_triton, +) +from .constants import ( + SEZM_TRITON_AVAILABLE, + TritonRotationMode, +) +from .dispatch import ( + resolve_triton_rotation_mode, + sezm_triton_enabled, + uses_triton_kernel, +) + +__all__ = [ + "SEZM_TRITON_AVAILABLE", + "TritonRotationMode", + "edge_geometry_rbf_triton", + "resolve_triton_rotation_mode", + "rotate_back_triton", + "rotate_to_local_triton", + "sezm_triton_enabled", + "uses_triton_kernel", +] diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py b/deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py new file mode 100644 index 0000000000..dd1c9bbc06 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/autograd.py @@ -0,0 +1,837 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Autograd and public API for SeZM Triton kernels.""" + +from __future__ import ( + annotations, +) + +from typing import ( + Any, +) + +import torch +from torch import ( + Tensor, +) + +from ..utils import ( + safe_norm, +) +from .constants import ( + SEZM_TRITON_AVAILABLE, + TritonRotationMode, +) +from .dispatch import ( + coerce_rotation_mode, + resolve_triton_rotation_mode, +) + +if SEZM_TRITON_AVAILABLE: + from . import custom_ops as _custom_ops # noqa: F401 + + +def _compute_cutoff_envelope_eager( + *, + r: Tensor, + rcut: float, + a: float, + b: float, + c: float, + d: float, + exponent: int, +) -> Tensor: + """Reference eager evaluation of the C^3 cutoff envelope.""" + x = (r / rcut).clamp(min=0.0, max=1.0) + poly = a + x * (b + x * (c + x * d)) + env = 1.0 + (x ** int(exponent)) * poly + return env * (x < 1.0).to(dtype=r.dtype) + + +def _edge_geometry_rbf_eager( + *, + coord_flat: Tensor, + center_coord_index: Tensor, + neighbor_coord_index: Tensor, + freqs: Tensor, + eps: float, + rcut: float, + edge_env_a: float, + edge_env_b: float, + edge_env_c: float, + edge_env_d: float, + edge_env_exponent: int, + radial_env_a: float, + radial_env_b: float, + radial_env_c: float, + radial_env_d: float, + radial_env_exponent: int, + r_inner: float, + r_outer: float, + has_inner_clamp: bool, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Reference eager implementation of the edge geometry/RBF chain.""" + center_pos = coord_flat.index_select(0, center_coord_index) + neighbor_pos = coord_flat.index_select(0, neighbor_coord_index) + edge_vec = neighbor_pos - center_pos + raw_len = safe_norm(edge_vec, float(eps)) + edge_len = raw_len + if has_inner_clamp: + delta = float(r_outer - r_inner) + t = ((edge_len - float(r_inner)) / delta).clamp(0.0, 1.0) + t2 = t * t + t4 = t2 * t2 + h = t4 * (20.0 + t * (-45.0 + t * (36.0 - 10.0 * t))) + clamped = float(r_inner) + delta * h + edge_len = torch.where(edge_len >= float(r_outer), edge_len, clamped) + scale = edge_len / raw_len + edge_vec = edge_vec * scale + edge_env = _compute_cutoff_envelope_eager( + r=edge_len, + rcut=float(rcut), + a=float(edge_env_a), + b=float(edge_env_b), + c=float(edge_env_c), + d=float(edge_env_d), + exponent=int(edge_env_exponent), + ) + radial_env = _compute_cutoff_envelope_eager( + r=edge_len, + rcut=float(rcut), + a=float(radial_env_a), + b=float(radial_env_b), + c=float(radial_env_c), + d=float(radial_env_d), + exponent=int(radial_env_exponent), + ) + freqs_row = freqs.view(1, -1) + phase = edge_len * freqs_row + edge_rbf = freqs_row * torch.sinc(phase / torch.pi) * radial_env + return edge_vec, edge_len, edge_env, edge_rbf + + +def _extract_envelope_params( + envelope: Any, +) -> tuple[float, float, float, float, float, int]: + """Extract the polynomial envelope parameters from one SeZM module.""" + return ( + float(envelope.rcut), + float(envelope.coeff_a), + float(envelope.coeff_b), + float(envelope.coeff_c), + float(envelope.coeff_d), + int(envelope.p), + ) + + +def _extract_edge_geometry_rbf_constants( + *, + edge_envelope: Any, + radial_basis: Any, + inner_clamp: Any, +) -> tuple[ + float, + float, + float, + float, + float, + int, + float, + float, + float, + float, + int, + float, + float, + bool, +]: + """Extract scalar constants used by the fused geometry/RBF chain.""" + ( + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + edge_env_exponent, + ) = _extract_envelope_params(edge_envelope) + ( + _, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + radial_env_exponent, + ) = _extract_envelope_params(radial_basis.envelope) + if inner_clamp is None: + r_inner = 0.0 + r_outer = 0.0 + has_inner_clamp = False + else: + r_inner = float(inner_clamp.r_inner) + r_outer = float(inner_clamp.r_outer) + has_inner_clamp = True + return ( + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + edge_env_exponent, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + radial_env_exponent, + r_inner, + r_outer, + has_inner_clamp, + ) + + +def _rotate_to_local_eager( + *, + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + """Reference eager implementation for ``D_to_m @ x[src]``.""" + D_to_m = wigner[:, :dim_full, :dim_full].index_select(1, coeff_index) + return torch.bmm(D_to_m, x.index_select(0, src)) + + +def _rotate_back_eager( + *, + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, +) -> Tensor: + """Reference eager implementation for ``Dt_from_m @ x_local``.""" + Dt_from_m = wigner[:, :dim_full, :dim_full].index_select(2, coeff_index) + return torch.bmm(Dt_from_m, x_local) + + +def _resolve_rotation_mode_for_call( + *, + dim_full: int, + coeff_index: Tensor, + rotation_mode: int | TritonRotationMode | None, +) -> TritonRotationMode: + """Resolve the effective dispatch mode for one public API call.""" + if rotation_mode is None: + return resolve_triton_rotation_mode( + dim_full=int(dim_full), + reduced_dim=int(coeff_index.numel()), + ) + return coerce_rotation_mode(rotation_mode) + + +if SEZM_TRITON_AVAILABLE: + + class _RotateToLocalFunction(torch.autograd.Function): + """Autograd wrapper for the fused ``global -> local reduced`` rotation.""" + + @staticmethod + def forward( + ctx: Any, + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + rotation_mode: int, + ) -> Tensor: + reduced_dim = int(coeff_index.numel()) + out = torch.empty( + src.shape[0], + reduced_dim, + x.shape[2], + dtype=x.dtype, + device=x.device, + ) + torch.ops.deepmd._kernel_sezm_rotate_to_local( + x, + src, + wigner, + coeff_index, + out, + dim_full, + rotation_mode, + ) + ctx.save_for_backward(x, src, wigner, coeff_index) + ctx.dim_full = int(dim_full) + ctx.rotation_mode = int(rotation_mode) + return out + + @staticmethod + def backward( + ctx: Any, + grad_out: Tensor, + ) -> tuple[Tensor, None, Tensor, None, None, None]: + x, src, wigner, coeff_index = ctx.saved_tensors + dim_full = int(ctx.dim_full) + rotation_mode = coerce_rotation_mode(int(ctx.rotation_mode)) + grad_out = grad_out.contiguous() + grad_edge = torch.empty( + src.shape[0], + dim_full, + x.shape[2], + dtype=grad_out.dtype, + device=grad_out.device, + ) + torch.ops.deepmd._kernel_sezm_rotate_to_local_bwd_dx( + grad_out, + wigner, + coeff_index, + grad_edge, + dim_full, + int(rotation_mode), + ) + grad_x = torch.zeros_like(x) + grad_x.index_add_(0, src, grad_edge) + + if rotation_mode == TritonRotationMode.GENERIC_TILED: + grad_rows = torch.empty( + src.shape[0], + coeff_index.numel(), + dim_full, + dtype=wigner.dtype, + device=grad_out.device, + ) + torch.ops.deepmd._kernel_sezm_rotate_to_local_bwd_dw( + grad_out, + x, + src, + coeff_index, + grad_rows, + dim_full, + int(rotation_mode), + ) + grad_wigner = torch.zeros_like(wigner) + grad_wigner[:, coeff_index, :dim_full] = grad_rows + else: + grad_wigner = torch.zeros_like(wigner) + torch.ops.deepmd._kernel_sezm_rotate_to_local_bwd_dw( + grad_out, + x, + src, + coeff_index, + grad_wigner, + dim_full, + int(rotation_mode), + ) + return grad_x, None, grad_wigner, None, None, None + + class _RotateBackFunction(torch.autograd.Function): + """Autograd wrapper for the fused ``local reduced -> global`` rotation.""" + + @staticmethod + def forward( + ctx: Any, + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + rotation_mode: int, + ) -> Tensor: + out = torch.empty( + x_local.shape[0], + dim_full, + x_local.shape[2], + dtype=x_local.dtype, + device=x_local.device, + ) + torch.ops.deepmd._kernel_sezm_rotate_back( + x_local, + wigner, + coeff_index, + out, + dim_full, + rotation_mode, + ) + ctx.save_for_backward(x_local, wigner, coeff_index) + ctx.dim_full = int(dim_full) + ctx.rotation_mode = int(rotation_mode) + return out + + @staticmethod + def backward( + ctx: Any, + grad_out: Tensor, + ) -> tuple[Tensor, Tensor, None, None, None]: + x_local, wigner, coeff_index = ctx.saved_tensors + dim_full = int(ctx.dim_full) + rotation_mode = coerce_rotation_mode(int(ctx.rotation_mode)) + grad_out = grad_out.contiguous() + grad_x_local = torch.empty_like(x_local) + torch.ops.deepmd._kernel_sezm_rotate_back_bwd_dx( + grad_out, + wigner, + coeff_index, + grad_x_local, + dim_full, + int(rotation_mode), + ) + + if rotation_mode == TritonRotationMode.GENERIC_TILED: + grad_cols = torch.empty( + x_local.shape[0], + dim_full, + coeff_index.numel(), + dtype=wigner.dtype, + device=grad_out.device, + ) + torch.ops.deepmd._kernel_sezm_rotate_back_bwd_dw( + grad_out, + x_local, + coeff_index, + grad_cols, + dim_full, + int(rotation_mode), + ) + grad_wigner = torch.zeros_like(wigner) + grad_wigner[:, :dim_full, coeff_index] = grad_cols + else: + grad_wigner = torch.zeros_like(wigner) + torch.ops.deepmd._kernel_sezm_rotate_back_bwd_dw( + grad_out, + x_local, + coeff_index, + grad_wigner, + dim_full, + int(rotation_mode), + ) + return grad_x_local, grad_wigner, None, None, None + + class _EdgeGeometryRBFFunction(torch.autograd.Function): + """Autograd wrapper for the fused edge geometry/RBF chain.""" + + @staticmethod + def forward( + ctx: Any, + coord_flat: Tensor, + center_coord_index: Tensor, + neighbor_coord_index: Tensor, + freqs: Tensor, + eps: float, + rcut: float, + edge_env_a: float, + edge_env_b: float, + edge_env_c: float, + edge_env_d: float, + edge_env_exponent: int, + radial_env_a: float, + radial_env_b: float, + radial_env_c: float, + radial_env_d: float, + radial_env_exponent: int, + r_inner: float, + r_outer: float, + has_inner_clamp: bool, + ) -> tuple[Tensor, Tensor, Tensor, Tensor]: + freq_flat = freqs.reshape(-1) + num_edges = int(center_coord_index.shape[0]) + edge_vec = torch.empty( + num_edges, + 3, + dtype=coord_flat.dtype, + device=coord_flat.device, + ) + edge_len = torch.empty( + num_edges, + dtype=coord_flat.dtype, + device=coord_flat.device, + ) + edge_env = torch.empty( + num_edges, + dtype=coord_flat.dtype, + device=coord_flat.device, + ) + edge_rbf = torch.empty( + num_edges, + freq_flat.numel(), + dtype=coord_flat.dtype, + device=coord_flat.device, + ) + torch.ops.deepmd._kernel_sezm_edge_geometry_rbf( + coord_flat, + center_coord_index, + neighbor_coord_index, + freq_flat, + edge_vec, + edge_len, + edge_env, + edge_rbf, + float(eps), + float(rcut), + float(edge_env_a), + float(edge_env_b), + float(edge_env_c), + float(edge_env_d), + int(edge_env_exponent), + float(radial_env_a), + float(radial_env_b), + float(radial_env_c), + float(radial_env_d), + int(radial_env_exponent), + float(r_inner), + float(r_outer), + bool(has_inner_clamp), + ) + ctx.save_for_backward( + coord_flat, + center_coord_index, + neighbor_coord_index, + freqs, + ) + ctx.eps = float(eps) + ctx.rcut = float(rcut) + ctx.edge_env_a = float(edge_env_a) + ctx.edge_env_b = float(edge_env_b) + ctx.edge_env_c = float(edge_env_c) + ctx.edge_env_d = float(edge_env_d) + ctx.edge_env_exponent = int(edge_env_exponent) + ctx.radial_env_a = float(radial_env_a) + ctx.radial_env_b = float(radial_env_b) + ctx.radial_env_c = float(radial_env_c) + ctx.radial_env_d = float(radial_env_d) + ctx.radial_env_exponent = int(radial_env_exponent) + ctx.r_inner = float(r_inner) + ctx.r_outer = float(r_outer) + ctx.has_inner_clamp = bool(has_inner_clamp) + return edge_vec, edge_len.unsqueeze(-1), edge_env.unsqueeze(-1), edge_rbf + + @staticmethod + def backward( + ctx: Any, + grad_edge_vec: Tensor | None, + grad_edge_len: Tensor | None, + grad_edge_env: Tensor | None, + grad_edge_rbf: Tensor | None, + ) -> tuple[ + Tensor, + None, + None, + Tensor, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ]: + coord_flat, center_coord_index, neighbor_coord_index, freqs = ( + ctx.saved_tensors + ) + num_edges = int(center_coord_index.shape[0]) + freq_flat = freqs.reshape(-1) + + if grad_edge_vec is None: + grad_edge_vec = torch.zeros( + num_edges, + 3, + dtype=coord_flat.dtype, + device=coord_flat.device, + ) + else: + grad_edge_vec = grad_edge_vec.contiguous() + if grad_edge_len is None: + grad_edge_len = torch.zeros( + num_edges, + dtype=coord_flat.dtype, + device=coord_flat.device, + ) + else: + grad_edge_len = grad_edge_len.contiguous().squeeze(-1) + if grad_edge_env is None: + grad_edge_env = torch.zeros( + num_edges, + dtype=coord_flat.dtype, + device=coord_flat.device, + ) + else: + grad_edge_env = grad_edge_env.contiguous().squeeze(-1) + if grad_edge_rbf is None: + grad_edge_rbf = torch.zeros( + num_edges, + freq_flat.numel(), + dtype=coord_flat.dtype, + device=coord_flat.device, + ) + else: + grad_edge_rbf = grad_edge_rbf.contiguous() + + grad_r_total = torch.zeros( + num_edges, + dtype=coord_flat.dtype, + device=coord_flat.device, + ) + grad_freq = torch.zeros( + freq_flat.numel(), + dtype=freq_flat.dtype, + device=coord_flat.device, + ) + torch.ops.deepmd._kernel_sezm_edge_geometry_rbf_bwd_accum( + grad_edge_len, + grad_edge_env, + grad_edge_rbf, + coord_flat, + center_coord_index, + neighbor_coord_index, + freq_flat, + grad_r_total, + grad_freq, + float(ctx.eps), + float(ctx.rcut), + float(ctx.edge_env_a), + float(ctx.edge_env_b), + float(ctx.edge_env_c), + float(ctx.edge_env_d), + int(ctx.edge_env_exponent), + float(ctx.radial_env_a), + float(ctx.radial_env_b), + float(ctx.radial_env_c), + float(ctx.radial_env_d), + int(ctx.radial_env_exponent), + float(ctx.r_inner), + float(ctx.r_outer), + bool(ctx.has_inner_clamp), + ) + grad_coord = torch.zeros_like(coord_flat) + torch.ops.deepmd._kernel_sezm_edge_geometry_rbf_bwd_coord( + grad_edge_vec, + grad_r_total, + coord_flat, + center_coord_index, + neighbor_coord_index, + grad_coord, + float(ctx.eps), + float(ctx.r_inner), + float(ctx.r_outer), + bool(ctx.has_inner_clamp), + ) + return ( + grad_coord, + None, + None, + grad_freq.view_as(freqs), + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def rotate_to_local_triton( + x: Tensor, + src: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + rotation_mode: int | TritonRotationMode | None = None, +) -> Tensor: + """ + Apply the fused ``global -> local reduced`` rotation. + + Parameters + ---------- + x + Node features with shape ``(N, D, C)``. + src + Source-node indices with shape ``(E,)``. + wigner + Packed Wigner matrices with shape ``(E, D, D)``. + coeff_index + Reduced-layout row indices with shape ``(D_m,)``. + dim_full + Full packed SO(3) dimension. + rotation_mode + Optional pre-resolved dispatch mode. + + Returns + ------- + Tensor + Rotated reduced-layout edge features with shape ``(E, D_m, C)``. + """ + if not SEZM_TRITON_AVAILABLE: + raise RuntimeError("SeZM Triton kernels are not available in this environment.") + src = src.contiguous() + coeff_index = coeff_index.contiguous() + resolved_mode = _resolve_rotation_mode_for_call( + dim_full=int(dim_full), + coeff_index=coeff_index, + rotation_mode=rotation_mode, + ) + if resolved_mode == TritonRotationMode.EAGER_REFERENCE: + return _rotate_to_local_eager( + x=x, + src=src, + wigner=wigner, + coeff_index=coeff_index, + dim_full=int(dim_full), + ) + return _RotateToLocalFunction.apply( + x, + src, + wigner, + coeff_index, + int(dim_full), + int(resolved_mode), + ) + + +def rotate_back_triton( + x_local: Tensor, + wigner: Tensor, + coeff_index: Tensor, + dim_full: int, + rotation_mode: int | TritonRotationMode | None = None, +) -> Tensor: + """ + Apply the fused ``local reduced -> global`` rotation. + + Parameters + ---------- + x_local + Reduced-layout edge features with shape ``(E, D_m, C)``. + wigner + Packed Wigner matrices with shape ``(E, D, D)``. + coeff_index + Reduced-layout column indices with shape ``(D_m,)``. + dim_full + Full packed SO(3) dimension. + rotation_mode + Optional pre-resolved dispatch mode. + + Returns + ------- + Tensor + Lifted global-layout edge features with shape ``(E, D, C)``. + """ + if not SEZM_TRITON_AVAILABLE: + raise RuntimeError("SeZM Triton kernels are not available in this environment.") + coeff_index = coeff_index.contiguous() + resolved_mode = _resolve_rotation_mode_for_call( + dim_full=int(dim_full), + coeff_index=coeff_index, + rotation_mode=rotation_mode, + ) + if resolved_mode == TritonRotationMode.EAGER_REFERENCE: + return _rotate_back_eager( + x_local=x_local, + wigner=wigner, + coeff_index=coeff_index, + dim_full=int(dim_full), + ) + return _RotateBackFunction.apply( + x_local, + wigner, + coeff_index, + int(dim_full), + int(resolved_mode), + ) + + +def edge_geometry_rbf_triton( + *, + coord_flat: Tensor, + center_coord_index: Tensor, + neighbor_coord_index: Tensor, + edge_envelope: Any, + radial_basis: Any, + eps: float, + inner_clamp: Any, +) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Apply the fused edge geometry/RBF chain with eager fallback.""" + ( + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + edge_env_exponent, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + radial_env_exponent, + r_inner, + r_outer, + has_inner_clamp, + ) = _extract_edge_geometry_rbf_constants( + edge_envelope=edge_envelope, + radial_basis=radial_basis, + inner_clamp=inner_clamp, + ) + center_coord_index = center_coord_index.contiguous() + neighbor_coord_index = neighbor_coord_index.contiguous() + freqs = radial_basis.adam_freqs.contiguous() + if ( + center_coord_index.numel() == 0 + or not SEZM_TRITON_AVAILABLE + or coord_flat.device.type != "cuda" + or coord_flat.dtype not in (torch.float16, torch.bfloat16, torch.float32) + ): + return _edge_geometry_rbf_eager( + coord_flat=coord_flat, + center_coord_index=center_coord_index, + neighbor_coord_index=neighbor_coord_index, + freqs=freqs, + eps=float(eps), + rcut=rcut, + edge_env_a=edge_env_a, + edge_env_b=edge_env_b, + edge_env_c=edge_env_c, + edge_env_d=edge_env_d, + edge_env_exponent=edge_env_exponent, + radial_env_a=radial_env_a, + radial_env_b=radial_env_b, + radial_env_c=radial_env_c, + radial_env_d=radial_env_d, + radial_env_exponent=radial_env_exponent, + r_inner=r_inner, + r_outer=r_outer, + has_inner_clamp=has_inner_clamp, + ) + return _EdgeGeometryRBFFunction.apply( + coord_flat, + center_coord_index, + neighbor_coord_index, + freqs, + float(eps), + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + edge_env_exponent, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + radial_env_exponent, + r_inner, + r_outer, + has_inner_clamp, + ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/constants.py b/deepmd/pt/model/descriptor/sezm_nn/triton/constants.py new file mode 100644 index 0000000000..c2aabb8147 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/constants.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Shared constants and feature flags for SeZM Triton kernels.""" + +from __future__ import ( + annotations, +) + +from enum import ( + IntEnum, +) + +import torch + +_HAS_TORCH_TRITON_OP = hasattr(torch.library, "triton_op") and hasattr( + torch.library, "wrap_triton" +) + +if _HAS_TORCH_TRITON_OP: + try: + import triton # noqa: F401 + except ImportError: + SEZM_TRITON_AVAILABLE = False + else: + SEZM_TRITON_AVAILABLE = True +else: + SEZM_TRITON_AVAILABLE = False + +# Triton dot kernels require K >= 16 on the current CUDA backend. +TRITON_GRID_E_STRIDE = 2048 +TRITON_BLOCK_FULL = 16 +TRITON_BLOCK_REDUCED = 16 +TRITON_BLOCK_CHANNEL = 32 +TRITON_SMALL_BLOCK_CHANNEL = 128 +TRITON_SMALL_FULL_DIM = 16 +TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE = 128 +TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL = 16 + + +class TritonRotationMode(IntEnum): + """Dispatch mode for the SeZM rotation hot path.""" + + GENERIC_TILED = 0 + SMALL_LE1 = 1 + SMALL_L2 = 2 + SMALL_L3 = 3 + EAGER_REFERENCE = 4 diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py b/deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py new file mode 100644 index 0000000000..23b31aa2f5 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/custom_ops.py @@ -0,0 +1,861 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Triton custom-op launchers for SeZM SO(2) rotation kernels. + +This layer only decides how to launch a resolved dispatch mode. Fallback policy +stays in the public autograd API so the launchers remain focused on Triton +grids, kernel families, and argument packing. +""" + +from __future__ import ( + annotations, +) + +import torch # noqa: TC002 + +from .constants import ( + SEZM_TRITON_AVAILABLE, + TRITON_BLOCK_CHANNEL, + TRITON_BLOCK_FULL, + TRITON_BLOCK_REDUCED, + TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, + TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL, + TRITON_GRID_E_STRIDE, + TRITON_SMALL_BLOCK_CHANNEL, + TritonRotationMode, +) +from .dispatch import ( + coerce_rotation_mode, +) + + +def _require_kernel_mode( + rotation_mode: int | TritonRotationMode, +) -> TritonRotationMode: + """Reject eager fallback before entering the Triton launch layer.""" + resolved_mode = coerce_rotation_mode(rotation_mode) + if resolved_mode == TritonRotationMode.EAGER_REFERENCE: + raise ValueError("Eager reference mode must be handled before Triton launch.") + return resolved_mode + + +if SEZM_TRITON_AVAILABLE: + from torch.library import ( + triton_op, + wrap_triton, + ) + + from .kernels_edge_geometry_rbf import ( + edge_geometry_rbf_bwd_accum_kernel, + edge_geometry_rbf_bwd_coord_kernel, + edge_geometry_rbf_forward_kernel, + ) + from .kernels_generic import ( + rotate_back_bwd_dw_kernel, + rotate_back_bwd_dx_kernel, + rotate_back_forward_kernel, + rotate_to_local_bwd_dw_kernel, + rotate_to_local_bwd_dx_kernel, + rotate_to_local_forward_kernel, + ) + from .kernels_small import ( + rotate_back_l1_bwd_dx_kernel, + rotate_back_l1_forward_kernel, + rotate_back_l2_bwd_dx_kernel, + rotate_back_l2_forward_kernel, + rotate_back_l3_bwd_dx_kernel, + rotate_back_l3_forward_kernel, + rotate_back_small_bwd_dw_kernel, + rotate_to_local_l1_bwd_dx_kernel, + rotate_to_local_l1_forward_kernel, + rotate_to_local_l2_bwd_dx_kernel, + rotate_to_local_l2_forward_kernel, + rotate_to_local_l3_bwd_dx_kernel, + rotate_to_local_l3_forward_kernel, + rotate_to_local_small_bwd_dw_kernel, + ) + + _ROTATE_TO_LOCAL_SMALL_FORWARD = { + TritonRotationMode.SMALL_LE1: rotate_to_local_l1_forward_kernel, + TritonRotationMode.SMALL_L2: rotate_to_local_l2_forward_kernel, + TritonRotationMode.SMALL_L3: rotate_to_local_l3_forward_kernel, + } + _ROTATE_TO_LOCAL_SMALL_BWD_DX = { + TritonRotationMode.SMALL_LE1: rotate_to_local_l1_bwd_dx_kernel, + TritonRotationMode.SMALL_L2: rotate_to_local_l2_bwd_dx_kernel, + TritonRotationMode.SMALL_L3: rotate_to_local_l3_bwd_dx_kernel, + } + _ROTATE_BACK_SMALL_FORWARD = { + TritonRotationMode.SMALL_LE1: rotate_back_l1_forward_kernel, + TritonRotationMode.SMALL_L2: rotate_back_l2_forward_kernel, + TritonRotationMode.SMALL_L3: rotate_back_l3_forward_kernel, + } + _ROTATE_BACK_SMALL_BWD_DX = { + TritonRotationMode.SMALL_LE1: rotate_back_l1_bwd_dx_kernel, + TritonRotationMode.SMALL_L2: rotate_back_l2_bwd_dx_kernel, + TritonRotationMode.SMALL_L3: rotate_back_l3_bwd_dx_kernel, + } + + def _small_channel_grid(channels: int) -> tuple[int, int]: + """Return the standard ``(edge, channel)`` grid for small kernels.""" + return ( + TRITON_GRID_E_STRIDE, + (channels + TRITON_SMALL_BLOCK_CHANNEL - 1) // TRITON_SMALL_BLOCK_CHANNEL, + ) + + def _generic_rotate_to_local_forward_grid( + reduced_dim: int, + channels: int, + ) -> tuple[int, int, int]: + """Return the standard forward grid for generic rotate-to-local.""" + return ( + TRITON_GRID_E_STRIDE, + (reduced_dim + TRITON_BLOCK_REDUCED - 1) // TRITON_BLOCK_REDUCED, + (channels + TRITON_BLOCK_CHANNEL - 1) // TRITON_BLOCK_CHANNEL, + ) + + def _generic_rotate_to_local_bwd_dx_grid( + dim_full: int, + channels: int, + ) -> tuple[int, int, int]: + """Return the source-gradient grid for generic rotate-to-local.""" + return ( + TRITON_GRID_E_STRIDE, + (dim_full + TRITON_BLOCK_FULL - 1) // TRITON_BLOCK_FULL, + (channels + TRITON_BLOCK_CHANNEL - 1) // TRITON_BLOCK_CHANNEL, + ) + + def _generic_rotate_to_local_bwd_dw_grid( + reduced_dim: int, + dim_full: int, + ) -> tuple[int, int, int]: + """Return the Wigner-gradient grid for generic rotate-to-local.""" + return ( + TRITON_GRID_E_STRIDE, + (reduced_dim + TRITON_BLOCK_REDUCED - 1) // TRITON_BLOCK_REDUCED, + (dim_full + TRITON_BLOCK_FULL - 1) // TRITON_BLOCK_FULL, + ) + + def _generic_rotate_back_forward_grid( + dim_full: int, + channels: int, + ) -> tuple[int, int, int]: + """Return the standard forward grid for generic rotate-back.""" + return ( + TRITON_GRID_E_STRIDE, + (dim_full + TRITON_BLOCK_FULL - 1) // TRITON_BLOCK_FULL, + (channels + TRITON_BLOCK_CHANNEL - 1) // TRITON_BLOCK_CHANNEL, + ) + + def _generic_rotate_back_bwd_dx_grid( + reduced_dim: int, + channels: int, + ) -> tuple[int, int, int]: + """Return the reduced-gradient grid for generic rotate-back.""" + return ( + TRITON_GRID_E_STRIDE, + (reduced_dim + TRITON_BLOCK_REDUCED - 1) // TRITON_BLOCK_REDUCED, + (channels + TRITON_BLOCK_CHANNEL - 1) // TRITON_BLOCK_CHANNEL, + ) + + def _generic_rotate_back_bwd_dw_grid( + dim_full: int, + reduced_dim: int, + ) -> tuple[int, int, int]: + """Return the Wigner-gradient grid for generic rotate-back.""" + return ( + TRITON_GRID_E_STRIDE, + (dim_full + TRITON_BLOCK_FULL - 1) // TRITON_BLOCK_FULL, + (reduced_dim + TRITON_BLOCK_REDUCED - 1) // TRITON_BLOCK_REDUCED, + ) + + def _edge_geometry_rbf_grid(num_edges: int, n_radial: int) -> tuple[int, int]: + """Return the standard grid for the fused edge geometry/RBF chain.""" + return ( + (num_edges + TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE - 1) + // TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, + (n_radial + TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL - 1) + // TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL, + ) + + def _edge_geometry_rbf_coord_grid(num_edges: int) -> tuple[int]: + """Return the edge-only grid for geometry/RBF coordinate gradients.""" + return ( + (num_edges + TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE - 1) + // TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, + ) + + def _launch_rotate_to_local_small_forward( + *, + rotation_mode: TritonRotationMode, + x: torch.Tensor, + src: torch.Tensor, + wigner: torch.Tensor, + coeff_index: torch.Tensor, + out: torch.Tensor, + dim_full: int, + ) -> None: + """Launch one specialized small-family rotate-to-local forward kernel.""" + reduced_dim = coeff_index.numel() + channels = x.shape[2] + kernel = _ROTATE_TO_LOCAL_SMALL_FORWARD[rotation_mode] + wrap_triton(kernel)[_small_channel_grid(channels)]( + x, + src, + wigner, + coeff_index, + out, + src.shape[0], + reduced_dim, + dim_full, + channels, + x.stride(0), + x.stride(1), + x.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + def _launch_rotate_to_local_small_bwd_dx( + *, + rotation_mode: TritonRotationMode, + grad_out: torch.Tensor, + wigner: torch.Tensor, + coeff_index: torch.Tensor, + grad_edge: torch.Tensor, + dim_full: int, + ) -> None: + """Launch one specialized small-family rotate-to-local dx kernel.""" + reduced_dim = coeff_index.numel() + channels = grad_out.shape[2] + kernel = _ROTATE_TO_LOCAL_SMALL_BWD_DX[rotation_mode] + wrap_triton(kernel)[_small_channel_grid(channels)]( + grad_out, + wigner, + coeff_index, + grad_edge, + grad_out.shape[0], + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + grad_edge.stride(0), + grad_edge.stride(1), + grad_edge.stride(2), + BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + def _launch_rotate_back_small_forward( + *, + rotation_mode: TritonRotationMode, + x_local: torch.Tensor, + wigner: torch.Tensor, + coeff_index: torch.Tensor, + out: torch.Tensor, + dim_full: int, + ) -> None: + """Launch one specialized small-family rotate-back forward kernel.""" + reduced_dim = coeff_index.numel() + channels = x_local.shape[2] + kernel = _ROTATE_BACK_SMALL_FORWARD[rotation_mode] + wrap_triton(kernel)[_small_channel_grid(channels)]( + x_local, + wigner, + coeff_index, + out, + x_local.shape[0], + reduced_dim, + dim_full, + channels, + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + def _launch_rotate_back_small_bwd_dx( + *, + rotation_mode: TritonRotationMode, + grad_out: torch.Tensor, + wigner: torch.Tensor, + coeff_index: torch.Tensor, + grad_x_local: torch.Tensor, + dim_full: int, + ) -> None: + """Launch one specialized small-family rotate-back dx kernel.""" + reduced_dim = coeff_index.numel() + channels = grad_out.shape[2] + kernel = _ROTATE_BACK_SMALL_BWD_DX[rotation_mode] + wrap_triton(kernel)[_small_channel_grid(channels)]( + grad_out, + wigner, + coeff_index, + grad_x_local, + grad_out.shape[0], + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + grad_x_local.stride(0), + grad_x_local.stride(1), + grad_x_local.stride(2), + BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + @triton_op( + "deepmd::_kernel_sezm_rotate_to_local", + mutates_args=("out",), + ) + def _kernel_sezm_rotate_to_local( + x: torch.Tensor, + src: torch.Tensor, + wigner: torch.Tensor, + coeff_index: torch.Tensor, + out: torch.Tensor, + dim_full: int, + rotation_mode: int, + ) -> None: + """Launch the fused Triton forward kernel for ``D_to_m @ x[src]``.""" + mode = _require_kernel_mode(rotation_mode) + reduced_dim = coeff_index.numel() + channels = x.shape[2] + if mode != TritonRotationMode.GENERIC_TILED: + _launch_rotate_to_local_small_forward( + rotation_mode=mode, + x=x, + src=src, + wigner=wigner, + coeff_index=coeff_index, + out=out, + dim_full=dim_full, + ) + return + wrap_triton(rotate_to_local_forward_kernel)[ + _generic_rotate_to_local_forward_grid(reduced_dim, channels) + ]( + x, + src, + wigner, + coeff_index, + out, + src.shape[0], + reduced_dim, + dim_full, + channels, + x.stride(0), + x.stride(1), + x.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_REDUCED=TRITON_BLOCK_REDUCED, + BLOCK_FULL=TRITON_BLOCK_FULL, + BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + @triton_op( + "deepmd::_kernel_sezm_rotate_to_local_bwd_dx", + mutates_args=("grad_edge",), + ) + def _kernel_sezm_rotate_to_local_bwd_dx( + grad_out: torch.Tensor, + wigner: torch.Tensor, + coeff_index: torch.Tensor, + grad_edge: torch.Tensor, + dim_full: int, + rotation_mode: int, + ) -> None: + """Launch the Triton backward kernel for source-feature gradients.""" + mode = _require_kernel_mode(rotation_mode) + reduced_dim = coeff_index.numel() + channels = grad_out.shape[2] + if mode != TritonRotationMode.GENERIC_TILED: + _launch_rotate_to_local_small_bwd_dx( + rotation_mode=mode, + grad_out=grad_out, + wigner=wigner, + coeff_index=coeff_index, + grad_edge=grad_edge, + dim_full=dim_full, + ) + return + wrap_triton(rotate_to_local_bwd_dx_kernel)[ + _generic_rotate_to_local_bwd_dx_grid(dim_full, channels) + ]( + grad_out, + wigner, + coeff_index, + grad_edge, + grad_out.shape[0], + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + grad_edge.stride(0), + grad_edge.stride(1), + grad_edge.stride(2), + BLOCK_REDUCED=TRITON_BLOCK_REDUCED, + BLOCK_FULL=TRITON_BLOCK_FULL, + BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + @triton_op( + "deepmd::_kernel_sezm_rotate_to_local_bwd_dw", + mutates_args=("grad_wigner",), + ) + def _kernel_sezm_rotate_to_local_bwd_dw( + grad_out: torch.Tensor, + x: torch.Tensor, + src: torch.Tensor, + coeff_index: torch.Tensor, + grad_wigner: torch.Tensor, + dim_full: int, + rotation_mode: int, + ) -> None: + """Launch the Triton backward kernel for Wigner gradients.""" + mode = _require_kernel_mode(rotation_mode) + reduced_dim = coeff_index.numel() + channels = grad_out.shape[2] + if mode != TritonRotationMode.GENERIC_TILED: + wrap_triton(rotate_to_local_small_bwd_dw_kernel)[(TRITON_GRID_E_STRIDE,)]( + grad_out, + x, + src, + coeff_index, + grad_wigner, + grad_out.shape[0], + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x.stride(0), + x.stride(1), + x.stride(2), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + return + wrap_triton(rotate_to_local_bwd_dw_kernel)[ + _generic_rotate_to_local_bwd_dw_grid(reduced_dim, dim_full) + ]( + grad_out, + x, + src, + coeff_index, + grad_wigner, + grad_out.shape[0], + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x.stride(0), + x.stride(1), + x.stride(2), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + BLOCK_REDUCED=TRITON_BLOCK_REDUCED, + BLOCK_FULL=TRITON_BLOCK_FULL, + BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + @triton_op( + "deepmd::_kernel_sezm_rotate_back", + mutates_args=("out",), + ) + def _kernel_sezm_rotate_back( + x_local: torch.Tensor, + wigner: torch.Tensor, + coeff_index: torch.Tensor, + out: torch.Tensor, + dim_full: int, + rotation_mode: int, + ) -> None: + """Launch the fused Triton forward kernel for ``Dt_from_m @ x_local``.""" + mode = _require_kernel_mode(rotation_mode) + reduced_dim = coeff_index.numel() + channels = x_local.shape[2] + if mode != TritonRotationMode.GENERIC_TILED: + _launch_rotate_back_small_forward( + rotation_mode=mode, + x_local=x_local, + wigner=wigner, + coeff_index=coeff_index, + out=out, + dim_full=dim_full, + ) + return + wrap_triton(rotate_back_forward_kernel)[ + _generic_rotate_back_forward_grid(dim_full, channels) + ]( + x_local, + wigner, + coeff_index, + out, + x_local.shape[0], + reduced_dim, + dim_full, + channels, + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_REDUCED=TRITON_BLOCK_REDUCED, + BLOCK_FULL=TRITON_BLOCK_FULL, + BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + @triton_op( + "deepmd::_kernel_sezm_rotate_back_bwd_dx", + mutates_args=("grad_x_local",), + ) + def _kernel_sezm_rotate_back_bwd_dx( + grad_out: torch.Tensor, + wigner: torch.Tensor, + coeff_index: torch.Tensor, + grad_x_local: torch.Tensor, + dim_full: int, + rotation_mode: int, + ) -> None: + """Launch the Triton backward kernel for reduced-layout gradients.""" + mode = _require_kernel_mode(rotation_mode) + reduced_dim = coeff_index.numel() + channels = grad_out.shape[2] + if mode != TritonRotationMode.GENERIC_TILED: + _launch_rotate_back_small_bwd_dx( + rotation_mode=mode, + grad_out=grad_out, + wigner=wigner, + coeff_index=coeff_index, + grad_x_local=grad_x_local, + dim_full=dim_full, + ) + return + wrap_triton(rotate_back_bwd_dx_kernel)[ + _generic_rotate_back_bwd_dx_grid(reduced_dim, channels) + ]( + grad_out, + wigner, + coeff_index, + grad_x_local, + grad_out.shape[0], + reduced_dim, + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + wigner.stride(0), + wigner.stride(1), + wigner.stride(2), + grad_x_local.stride(0), + grad_x_local.stride(1), + grad_x_local.stride(2), + BLOCK_REDUCED=TRITON_BLOCK_REDUCED, + BLOCK_FULL=TRITON_BLOCK_FULL, + BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + @triton_op( + "deepmd::_kernel_sezm_rotate_back_bwd_dw", + mutates_args=("grad_wigner",), + ) + def _kernel_sezm_rotate_back_bwd_dw( + grad_out: torch.Tensor, + x_local: torch.Tensor, + coeff_index: torch.Tensor, + grad_wigner: torch.Tensor, + dim_full: int, + rotation_mode: int, + ) -> None: + """Launch the Triton backward kernel for Wigner gradients.""" + mode = _require_kernel_mode(rotation_mode) + reduced_dim = coeff_index.numel() + channels = grad_out.shape[2] + if mode != TritonRotationMode.GENERIC_TILED: + wrap_triton(rotate_back_small_bwd_dw_kernel)[(TRITON_GRID_E_STRIDE,)]( + grad_out, + x_local, + coeff_index, + grad_wigner, + grad_out.shape[0], + x_local.shape[1], + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + BLOCK_CHANNEL=TRITON_SMALL_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + return + wrap_triton(rotate_back_bwd_dw_kernel)[ + _generic_rotate_back_bwd_dw_grid(dim_full, reduced_dim) + ]( + grad_out, + x_local, + grad_wigner, + grad_out.shape[0], + x_local.shape[1], + dim_full, + channels, + grad_out.stride(0), + grad_out.stride(1), + grad_out.stride(2), + x_local.stride(0), + x_local.stride(1), + x_local.stride(2), + grad_wigner.stride(0), + grad_wigner.stride(1), + grad_wigner.stride(2), + BLOCK_REDUCED=TRITON_BLOCK_REDUCED, + BLOCK_FULL=TRITON_BLOCK_FULL, + BLOCK_CHANNEL=TRITON_BLOCK_CHANNEL, + GRID_E_STRIDE=TRITON_GRID_E_STRIDE, + num_warps=1, + ) + + @triton_op( + "deepmd::_kernel_sezm_edge_geometry_rbf", + mutates_args=("edge_vec", "edge_len", "edge_env", "edge_rbf"), + ) + def _kernel_sezm_edge_geometry_rbf( + coord_flat: torch.Tensor, + center_coord_index: torch.Tensor, + neighbor_coord_index: torch.Tensor, + freqs: torch.Tensor, + edge_vec: torch.Tensor, + edge_len: torch.Tensor, + edge_env: torch.Tensor, + edge_rbf: torch.Tensor, + eps: float, + rcut: float, + edge_env_a: float, + edge_env_b: float, + edge_env_c: float, + edge_env_d: float, + edge_env_exponent: int, + radial_env_a: float, + radial_env_b: float, + radial_env_c: float, + radial_env_d: float, + radial_env_exponent: int, + r_inner: float, + r_outer: float, + has_inner_clamp: bool, + ) -> None: + """Launch the fused edge geometry/RBF forward kernel.""" + wrap_triton(edge_geometry_rbf_forward_kernel)[ + _edge_geometry_rbf_grid(center_coord_index.shape[0], freqs.numel()) + ]( + coord_flat, + center_coord_index, + neighbor_coord_index, + freqs, + edge_vec, + edge_len, + edge_env, + edge_rbf, + center_coord_index.shape[0], + freqs.numel(), + coord_flat.stride(0), + coord_flat.stride(1), + edge_vec.stride(0), + edge_vec.stride(1), + edge_rbf.stride(0), + edge_rbf.stride(1), + eps, + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + r_inner, + r_outer, + EDGE_ENV_EXPONENT=int(edge_env_exponent), + RADIAL_ENV_EXPONENT=int(radial_env_exponent), + HAS_INNER_CLAMP=bool(has_inner_clamp), + BLOCK_EDGE=TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, + BLOCK_RADIAL=TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL, + num_warps=4, + ) + + @triton_op( + "deepmd::_kernel_sezm_edge_geometry_rbf_bwd_accum", + mutates_args=("grad_r_total", "grad_freq"), + ) + def _kernel_sezm_edge_geometry_rbf_bwd_accum( + grad_edge_len: torch.Tensor, + grad_edge_env: torch.Tensor, + grad_edge_rbf: torch.Tensor, + coord_flat: torch.Tensor, + center_coord_index: torch.Tensor, + neighbor_coord_index: torch.Tensor, + freqs: torch.Tensor, + grad_r_total: torch.Tensor, + grad_freq: torch.Tensor, + eps: float, + rcut: float, + edge_env_a: float, + edge_env_b: float, + edge_env_c: float, + edge_env_d: float, + edge_env_exponent: int, + radial_env_a: float, + radial_env_b: float, + radial_env_c: float, + radial_env_d: float, + radial_env_exponent: int, + r_inner: float, + r_outer: float, + has_inner_clamp: bool, + ) -> None: + """Launch the fused edge geometry/RBF accumulation kernel.""" + wrap_triton(edge_geometry_rbf_bwd_accum_kernel)[ + _edge_geometry_rbf_grid(center_coord_index.shape[0], freqs.numel()) + ]( + grad_edge_len, + grad_edge_env, + grad_edge_rbf, + coord_flat, + center_coord_index, + neighbor_coord_index, + freqs, + grad_r_total, + grad_freq, + center_coord_index.shape[0], + freqs.numel(), + coord_flat.stride(0), + coord_flat.stride(1), + grad_edge_rbf.stride(0), + grad_edge_rbf.stride(1), + eps, + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + r_inner, + r_outer, + EDGE_ENV_EXPONENT=int(edge_env_exponent), + RADIAL_ENV_EXPONENT=int(radial_env_exponent), + HAS_INNER_CLAMP=bool(has_inner_clamp), + BLOCK_EDGE=TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, + BLOCK_RADIAL=TRITON_EDGE_GEOMETRY_RBF_BLOCK_RADIAL, + num_warps=4, + ) + + @triton_op( + "deepmd::_kernel_sezm_edge_geometry_rbf_bwd_coord", + mutates_args=("grad_coord",), + ) + def _kernel_sezm_edge_geometry_rbf_bwd_coord( + grad_edge_vec: torch.Tensor, + grad_r_total: torch.Tensor, + coord_flat: torch.Tensor, + center_coord_index: torch.Tensor, + neighbor_coord_index: torch.Tensor, + grad_coord: torch.Tensor, + eps: float, + r_inner: float, + r_outer: float, + has_inner_clamp: bool, + ) -> None: + """Launch the fused edge geometry/RBF coordinate backward kernel.""" + wrap_triton(edge_geometry_rbf_bwd_coord_kernel)[ + _edge_geometry_rbf_coord_grid(center_coord_index.shape[0]) + ]( + grad_edge_vec, + grad_r_total, + coord_flat, + center_coord_index, + neighbor_coord_index, + grad_coord, + center_coord_index.shape[0], + coord_flat.stride(0), + coord_flat.stride(1), + grad_edge_vec.stride(0), + grad_edge_vec.stride(1), + grad_coord.stride(0), + grad_coord.stride(1), + eps, + r_inner, + r_outer, + HAS_INNER_CLAMP=bool(has_inner_clamp), + BLOCK_EDGE=TRITON_EDGE_GEOMETRY_RBF_BLOCK_EDGE, + num_warps=4, + ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py b/deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py new file mode 100644 index 0000000000..5c16c6d759 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/dispatch.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Dispatch helpers for SeZM Triton rotation kernels.""" + +from __future__ import ( + annotations, +) + +from typing import ( + Final, +) + +import torch + +from .constants import ( + SEZM_TRITON_AVAILABLE, + TRITON_BLOCK_REDUCED, + TritonRotationMode, +) + +_SMALL_MODE_FROM_DIM: Final[dict[int, TritonRotationMode]] = { + 1: TritonRotationMode.SMALL_LE1, + 4: TritonRotationMode.SMALL_LE1, + 9: TritonRotationMode.SMALL_L2, + 16: TritonRotationMode.SMALL_L3, +} + + +def coerce_rotation_mode( + rotation_mode: int | TritonRotationMode, +) -> TritonRotationMode: + """ + Convert an integer-like dispatch value to ``TritonRotationMode``. + + Parameters + ---------- + rotation_mode + Rotation dispatch value. + + Returns + ------- + TritonRotationMode + Normalized rotation dispatch mode. + """ + if isinstance(rotation_mode, TritonRotationMode): + return rotation_mode + return TritonRotationMode(int(rotation_mode)) + + +def resolve_triton_rotation_mode( + *, + dim_full: int, + reduced_dim: int, +) -> TritonRotationMode: + """ + Resolve the SeZM rotation dispatch mode. + + Parameters + ---------- + dim_full + Full packed SO(3) dimension. + reduced_dim + Truncated m-major coefficient count. + + Returns + ------- + TritonRotationMode + Dispatch mode for the current ``(dim_full, reduced_dim)`` pair. + + Raises + ------ + ValueError + If either dimension is non-positive. + """ + dim_full = int(dim_full) + reduced_dim = int(reduced_dim) + if dim_full <= 0: + raise ValueError("dim_full must be positive") + if reduced_dim <= 0: + raise ValueError("reduced_dim must be positive") + base_mode = _SMALL_MODE_FROM_DIM.get( + dim_full, + TritonRotationMode.GENERIC_TILED, + ) + if ( + base_mode == TritonRotationMode.GENERIC_TILED + and reduced_dim < TRITON_BLOCK_REDUCED + ): + return TritonRotationMode.EAGER_REFERENCE + return base_mode + + +def sezm_triton_enabled( + *, + device: torch.device, + dtype: torch.dtype, +) -> bool: + """ + Return whether SeZM should enable the Triton rotation path. + + Parameters + ---------- + device + Target device for the rotation path. + dtype + Activation dtype for the rotation path. + + Returns + ------- + bool + Whether Triton kernels are available for the given device and dtype. + """ + supported_dtypes = (torch.float16, torch.bfloat16, torch.float32) + return bool( + SEZM_TRITON_AVAILABLE and device.type == "cuda" and dtype in supported_dtypes + ) + + +def uses_triton_kernel( + rotation_mode: int | TritonRotationMode, +) -> bool: + """ + Return whether the dispatch mode launches a Triton kernel. + + Parameters + ---------- + rotation_mode + Rotation dispatch value. + + Returns + ------- + bool + ``True`` when the mode launches a Triton kernel instead of eager fallback. + """ + return coerce_rotation_mode(rotation_mode) != TritonRotationMode.EAGER_REFERENCE diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py b/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py new file mode 100644 index 0000000000..b6235173c6 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_edge_geometry_rbf.py @@ -0,0 +1,550 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN201, ANN202 +"""Triton kernels for the SeZM edge geometry/RBF chain. + +This file implements the standard non-compile path hot segment: + +``coord_gather -> edge_vec -> edge_len -> inner_clamp -> edge_env -> edge_rbf`` + +The kernels intentionally stop before Wigner-D construction so the existing eager +quaternion/Wigner path remains unchanged. +""" + +from __future__ import ( + annotations, +) + +import triton +import triton.language as tl + + +@triton.jit +def _pow_int(x, power: tl.constexpr): + """Raise ``x`` to a small compile-time integer power.""" + out = x * 0.0 + 1.0 + for _ in tl.static_range(power): + out = out * x + return out + + +@triton.jit +def _safe_sinc_no_pi(x): + """Compute ``sin(x) / x`` with a short Taylor branch near zero.""" + x2 = x * x + approx = 1.0 - x2 / 6.0 + (x2 * x2) / 120.0 + regular = tl.sin(x) / x + return tl.where(tl.abs(x) < 1.0e-4, approx, regular) + + +@triton.jit +def _safe_sinc_grad_no_pi(x): + """Compute ``d/dx [sin(x) / x]`` with a short Taylor branch near zero.""" + x2 = x * x + approx = -x / 3.0 + (x * x2) / 30.0 + regular = (x * tl.cos(x) - tl.sin(x)) / x2 + return tl.where(tl.abs(x) < 1.0e-4, approx, regular) + + +@triton.jit +def _compute_cutoff_envelope( + r, + rcut, + a, + b, + c, + d, + exponent: tl.constexpr, +): + """Evaluate the C^3 cutoff envelope on one distance vector.""" + x = tl.maximum(0.0, tl.minimum(r / rcut, 1.0)) + poly = a + x * (b + x * (c + x * d)) + env = 1.0 + _pow_int(x, exponent) * poly + return tl.where(x < 1.0, env, 0.0) + + +@triton.jit +def _compute_cutoff_envelope_grad( + r, + rcut, + a, + b, + c, + d, + exponent: tl.constexpr, +): + """Evaluate ``d envelope / d r`` on one distance vector.""" + x = tl.maximum(0.0, tl.minimum(r / rcut, 1.0)) + poly = a + x * (b + x * (c + x * d)) + poly_grad = b + 2.0 * c * x + 3.0 * d * x * x + if exponent == 1: + leading = poly + else: + leading = float(exponent) * _pow_int(x, exponent - 1) * poly + grad_x = leading + _pow_int(x, exponent) * poly_grad + return tl.where(x < 1.0, grad_x / rcut, 0.0) + + +@triton.jit +def _apply_inner_clamp( + raw_len, + r_inner, + r_outer, +): + """Apply the septic Hermite inner clamp.""" + delta = r_outer - r_inner + t = tl.maximum(0.0, tl.minimum((raw_len - r_inner) / delta, 1.0)) + t2 = t * t + t4 = t2 * t2 + h = t4 * (20.0 + t * (-45.0 + t * (36.0 - 10.0 * t))) + clamped = r_inner + delta * h + return tl.where(raw_len >= r_outer, raw_len, clamped) + + +@triton.jit +def _apply_inner_clamp_grad( + raw_len, + r_inner, + r_outer, +): + """Evaluate ``d clamp / d raw_len`` for the septic Hermite inner clamp.""" + delta = r_outer - r_inner + t = tl.maximum(0.0, tl.minimum((raw_len - r_inner) / delta, 1.0)) + t2 = t * t + t3 = t2 * t + grad = t3 * (80.0 + t * (-225.0 + t * (216.0 - 70.0 * t))) + return tl.where(raw_len >= r_outer, 1.0, grad) + + +@triton.jit +def edge_geometry_rbf_forward_kernel( + coord_ptr, + center_index_ptr, + neighbor_index_ptr, + freq_ptr, + edge_vec_ptr, + edge_len_ptr, + edge_env_ptr, + edge_rbf_ptr, + num_edges, + n_radial, + coord_stride_n, + coord_stride_c, + edge_vec_stride_e, + edge_vec_stride_c, + edge_rbf_stride_e, + edge_rbf_stride_r, + eps, + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + r_inner, + r_outer, + EDGE_ENV_EXPONENT: tl.constexpr, + RADIAL_ENV_EXPONENT: tl.constexpr, + HAS_INNER_CLAMP: tl.constexpr, + BLOCK_EDGE: tl.constexpr, + BLOCK_RADIAL: tl.constexpr, +): + """Compute the fused edge geometry/RBF chain for one edge/radial tile.""" + pid_edge = tl.program_id(0) + pid_radial = tl.program_id(1) + + edge_offsets = pid_edge * BLOCK_EDGE + tl.arange(0, BLOCK_EDGE) + radial_offsets = pid_radial * BLOCK_RADIAL + tl.arange(0, BLOCK_RADIAL) + edge_mask = edge_offsets < num_edges + radial_mask = radial_offsets < n_radial + first_radial_mask = edge_mask & (pid_radial == 0) + + center_index = tl.load(center_index_ptr + edge_offsets, mask=edge_mask, other=0) + neighbor_index = tl.load(neighbor_index_ptr + edge_offsets, mask=edge_mask, other=0) + + center_x = tl.load( + coord_ptr + center_index * coord_stride_n + 0 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + center_y = tl.load( + coord_ptr + center_index * coord_stride_n + 1 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + center_z = tl.load( + coord_ptr + center_index * coord_stride_n + 2 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + neighbor_x = tl.load( + coord_ptr + neighbor_index * coord_stride_n + 0 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + neighbor_y = tl.load( + coord_ptr + neighbor_index * coord_stride_n + 1 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + neighbor_z = tl.load( + coord_ptr + neighbor_index * coord_stride_n + 2 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + + diff_x = neighbor_x - center_x + diff_y = neighbor_y - center_y + diff_z = neighbor_z - center_z + raw_len = tl.sqrt(diff_x * diff_x + diff_y * diff_y + diff_z * diff_z + eps * eps) + + if HAS_INNER_CLAMP: + clamped_len = _apply_inner_clamp(raw_len, r_inner, r_outer) + scale = clamped_len / raw_len + edge_vec_x = diff_x * scale + edge_vec_y = diff_y * scale + edge_vec_z = diff_z * scale + edge_len = clamped_len + else: + edge_vec_x = diff_x + edge_vec_y = diff_y + edge_vec_z = diff_z + edge_len = raw_len + + edge_env = _compute_cutoff_envelope( + edge_len, + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + exponent=EDGE_ENV_EXPONENT, + ) + radial_env = _compute_cutoff_envelope( + edge_len, + rcut, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + exponent=RADIAL_ENV_EXPONENT, + ) + + tl.store( + edge_vec_ptr + edge_offsets * edge_vec_stride_e + 0 * edge_vec_stride_c, + edge_vec_x, + mask=first_radial_mask, + ) + tl.store( + edge_vec_ptr + edge_offsets * edge_vec_stride_e + 1 * edge_vec_stride_c, + edge_vec_y, + mask=first_radial_mask, + ) + tl.store( + edge_vec_ptr + edge_offsets * edge_vec_stride_e + 2 * edge_vec_stride_c, + edge_vec_z, + mask=first_radial_mask, + ) + tl.store(edge_len_ptr + edge_offsets, edge_len, mask=first_radial_mask) + tl.store(edge_env_ptr + edge_offsets, edge_env, mask=first_radial_mask) + + freqs = tl.load(freq_ptr + radial_offsets, mask=radial_mask, other=0.0) + phase = edge_len[:, None] * freqs[None, :] + raw = freqs[None, :] * _safe_sinc_no_pi(phase) + edge_rbf = raw * radial_env[:, None] + tl.store( + edge_rbf_ptr + + edge_offsets[:, None] * edge_rbf_stride_e + + radial_offsets[None, :] * edge_rbf_stride_r, + edge_rbf, + mask=edge_mask[:, None] & radial_mask[None, :], + ) + + +@triton.jit +def edge_geometry_rbf_bwd_accum_kernel( + grad_edge_len_ptr, + grad_edge_env_ptr, + grad_edge_rbf_ptr, + coord_ptr, + center_index_ptr, + neighbor_index_ptr, + freq_ptr, + grad_r_total_ptr, + grad_freq_ptr, + num_edges, + n_radial, + coord_stride_n, + coord_stride_c, + grad_edge_rbf_stride_e, + grad_edge_rbf_stride_r, + eps, + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + r_inner, + r_outer, + EDGE_ENV_EXPONENT: tl.constexpr, + RADIAL_ENV_EXPONENT: tl.constexpr, + HAS_INNER_CLAMP: tl.constexpr, + BLOCK_EDGE: tl.constexpr, + BLOCK_RADIAL: tl.constexpr, +): + """Accumulate scalar distance gradients and frequency gradients.""" + pid_edge = tl.program_id(0) + pid_radial = tl.program_id(1) + + edge_offsets = pid_edge * BLOCK_EDGE + tl.arange(0, BLOCK_EDGE) + radial_offsets = pid_radial * BLOCK_RADIAL + tl.arange(0, BLOCK_RADIAL) + edge_mask = edge_offsets < num_edges + radial_mask = radial_offsets < n_radial + + center_index = tl.load(center_index_ptr + edge_offsets, mask=edge_mask, other=0) + neighbor_index = tl.load(neighbor_index_ptr + edge_offsets, mask=edge_mask, other=0) + + center_x = tl.load( + coord_ptr + center_index * coord_stride_n + 0 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + center_y = tl.load( + coord_ptr + center_index * coord_stride_n + 1 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + center_z = tl.load( + coord_ptr + center_index * coord_stride_n + 2 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + neighbor_x = tl.load( + coord_ptr + neighbor_index * coord_stride_n + 0 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + neighbor_y = tl.load( + coord_ptr + neighbor_index * coord_stride_n + 1 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + neighbor_z = tl.load( + coord_ptr + neighbor_index * coord_stride_n + 2 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + + diff_x = neighbor_x - center_x + diff_y = neighbor_y - center_y + diff_z = neighbor_z - center_z + raw_len = tl.sqrt(diff_x * diff_x + diff_y * diff_y + diff_z * diff_z + eps * eps) + + if HAS_INNER_CLAMP: + edge_len = _apply_inner_clamp(raw_len, r_inner, r_outer) + else: + edge_len = raw_len + + radial_env = _compute_cutoff_envelope( + edge_len, + rcut, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + exponent=RADIAL_ENV_EXPONENT, + ) + radial_env_grad = _compute_cutoff_envelope_grad( + edge_len, + rcut, + radial_env_a, + radial_env_b, + radial_env_c, + radial_env_d, + exponent=RADIAL_ENV_EXPONENT, + ) + + grad_edge_rbf = tl.load( + grad_edge_rbf_ptr + + edge_offsets[:, None] * grad_edge_rbf_stride_e + + radial_offsets[None, :] * grad_edge_rbf_stride_r, + mask=edge_mask[:, None] & radial_mask[None, :], + other=0.0, + ) + freqs = tl.load(freq_ptr + radial_offsets, mask=radial_mask, other=0.0) + phase = edge_len[:, None] * freqs[None, :] + raw = freqs[None, :] * _safe_sinc_no_pi(phase) + raw_grad_r = freqs[None, :] * freqs[None, :] * _safe_sinc_grad_no_pi(phase) + radial_grad_r = raw_grad_r * radial_env[:, None] + raw * radial_env_grad[:, None] + grad_rbf_to_r = tl.sum(grad_edge_rbf * radial_grad_r, axis=1) + tl.atomic_add(grad_r_total_ptr + edge_offsets, grad_rbf_to_r, mask=edge_mask) + + grad_freq = tl.sum(grad_edge_rbf * (radial_env[:, None] * tl.cos(phase)), axis=0) + tl.atomic_add(grad_freq_ptr + radial_offsets, grad_freq, mask=radial_mask) + + if pid_radial == 0: + grad_edge_len = tl.load( + grad_edge_len_ptr + edge_offsets, mask=edge_mask, other=0.0 + ) + grad_edge_env = tl.load( + grad_edge_env_ptr + edge_offsets, mask=edge_mask, other=0.0 + ) + edge_env_grad = _compute_cutoff_envelope_grad( + edge_len, + rcut, + edge_env_a, + edge_env_b, + edge_env_c, + edge_env_d, + exponent=EDGE_ENV_EXPONENT, + ) + base = grad_edge_len + grad_edge_env * edge_env_grad + tl.atomic_add(grad_r_total_ptr + edge_offsets, base, mask=edge_mask) + + +@triton.jit +def edge_geometry_rbf_bwd_coord_kernel( + grad_edge_vec_ptr, + grad_r_total_ptr, + coord_ptr, + center_index_ptr, + neighbor_index_ptr, + grad_coord_ptr, + num_edges, + coord_stride_n, + coord_stride_c, + grad_edge_vec_stride_e, + grad_edge_vec_stride_c, + grad_coord_stride_n, + grad_coord_stride_c, + eps, + r_inner, + r_outer, + HAS_INNER_CLAMP: tl.constexpr, + BLOCK_EDGE: tl.constexpr, +): + """Backpropagate the fused geometry/RBF chain into flat coordinates.""" + pid_edge = tl.program_id(0) + edge_offsets = pid_edge * BLOCK_EDGE + tl.arange(0, BLOCK_EDGE) + edge_mask = edge_offsets < num_edges + + center_index = tl.load(center_index_ptr + edge_offsets, mask=edge_mask, other=0) + neighbor_index = tl.load(neighbor_index_ptr + edge_offsets, mask=edge_mask, other=0) + + center_x = tl.load( + coord_ptr + center_index * coord_stride_n + 0 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + center_y = tl.load( + coord_ptr + center_index * coord_stride_n + 1 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + center_z = tl.load( + coord_ptr + center_index * coord_stride_n + 2 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + neighbor_x = tl.load( + coord_ptr + neighbor_index * coord_stride_n + 0 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + neighbor_y = tl.load( + coord_ptr + neighbor_index * coord_stride_n + 1 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + neighbor_z = tl.load( + coord_ptr + neighbor_index * coord_stride_n + 2 * coord_stride_c, + mask=edge_mask, + other=0.0, + ) + + diff_x = neighbor_x - center_x + diff_y = neighbor_y - center_y + diff_z = neighbor_z - center_z + raw_len = tl.sqrt(diff_x * diff_x + diff_y * diff_y + diff_z * diff_z + eps * eps) + + if HAS_INNER_CLAMP: + edge_len = _apply_inner_clamp(raw_len, r_inner, r_outer) + clamp_grad = _apply_inner_clamp_grad(raw_len, r_inner, r_outer) + scale = edge_len / raw_len + else: + edge_len = raw_len + clamp_grad = raw_len * 0.0 + 1.0 + scale = raw_len * 0.0 + 1.0 + + grad_edge_vec_x = tl.load( + grad_edge_vec_ptr + + edge_offsets * grad_edge_vec_stride_e + + 0 * grad_edge_vec_stride_c, + mask=edge_mask, + other=0.0, + ) + grad_edge_vec_y = tl.load( + grad_edge_vec_ptr + + edge_offsets * grad_edge_vec_stride_e + + 1 * grad_edge_vec_stride_c, + mask=edge_mask, + other=0.0, + ) + grad_edge_vec_z = tl.load( + grad_edge_vec_ptr + + edge_offsets * grad_edge_vec_stride_e + + 2 * grad_edge_vec_stride_c, + mask=edge_mask, + other=0.0, + ) + grad_r_total = tl.load(grad_r_total_ptr + edge_offsets, mask=edge_mask, other=0.0) + + dot_grad_vec = ( + grad_edge_vec_x * diff_x + grad_edge_vec_y * diff_y + grad_edge_vec_z * diff_z + ) + inv_raw_len = 1.0 / raw_len + scalar = grad_r_total * clamp_grad + dot_grad_vec * ( + (clamp_grad * raw_len - edge_len) * inv_raw_len * inv_raw_len + ) + grad_diff_common = scalar * inv_raw_len + grad_diff_x = grad_edge_vec_x * scale + diff_x * grad_diff_common + grad_diff_y = grad_edge_vec_y * scale + diff_y * grad_diff_common + grad_diff_z = grad_edge_vec_z * scale + diff_z * grad_diff_common + + tl.atomic_add( + grad_coord_ptr + neighbor_index * grad_coord_stride_n + 0 * grad_coord_stride_c, + grad_diff_x, + mask=edge_mask, + ) + tl.atomic_add( + grad_coord_ptr + neighbor_index * grad_coord_stride_n + 1 * grad_coord_stride_c, + grad_diff_y, + mask=edge_mask, + ) + tl.atomic_add( + grad_coord_ptr + neighbor_index * grad_coord_stride_n + 2 * grad_coord_stride_c, + grad_diff_z, + mask=edge_mask, + ) + tl.atomic_add( + grad_coord_ptr + center_index * grad_coord_stride_n + 0 * grad_coord_stride_c, + -grad_diff_x, + mask=edge_mask, + ) + tl.atomic_add( + grad_coord_ptr + center_index * grad_coord_stride_n + 1 * grad_coord_stride_c, + -grad_diff_y, + mask=edge_mask, + ) + tl.atomic_add( + grad_coord_ptr + center_index * grad_coord_stride_n + 2 * grad_coord_stride_c, + -grad_diff_z, + mask=edge_mask, + ) diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py b/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py new file mode 100644 index 0000000000..f2a89abda9 --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_generic.py @@ -0,0 +1,555 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN201 +"""Generic tiled Triton kernels for SeZM SO(2) rotation hot paths. + +This file holds the variable-``lmax`` family used once the packed SO(3) block +no longer fits the small specialized kernels. The tile sizes are fixed on +purpose: ``BLOCK_FULL == BLOCK_REDUCED == 16`` keeps every ``tl.dot`` on a CUDA +shape that Triton accepts, and the kernels below explicitly request +``input_precision="ieee"`` so float32 matches eager PyTorch instead of TF32. +""" + +from __future__ import ( + annotations, +) + +import triton +import triton.language as tl + +# Keep both contraction dimensions at 16 so Triton always sees a legal dot tile. + + +@triton.jit +def rotate_to_local_forward_kernel( + x_ptr, + src_ptr, + wigner_ptr, + coeff_index_ptr, + out_ptr, + num_edges, + reduced_dim, + dim_full, + channels, + x_stride_n, + x_stride_d, + x_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + out_stride_e, + out_stride_r, + out_stride_c, + BLOCK_REDUCED: tl.constexpr, + BLOCK_FULL: tl.constexpr, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute fused row-projected Wigner rotation ``D_to_m @ x[src]``.""" + edge_id = tl.program_id(0) + reduced_block_id = tl.program_id(1) + channel_block_id = tl.program_id(2) + + reduced_offsets = reduced_block_id * BLOCK_REDUCED + tl.arange(0, BLOCK_REDUCED) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + reduced_mask = reduced_offsets < reduced_dim + channel_mask = channel_offsets < channels + + while edge_id < num_edges: + src_idx = tl.load(src_ptr + edge_id).to(tl.int64) + coeff_rows = tl.load( + coeff_index_ptr + reduced_offsets, + mask=reduced_mask, + other=0, + ).to(tl.int64) + acc = tl.zeros((BLOCK_REDUCED, BLOCK_CHANNEL), dtype=tl.float32) + + for full_block in range(0, tl.cdiv(dim_full, BLOCK_FULL)): + full_offsets = full_block * BLOCK_FULL + tl.arange(0, BLOCK_FULL) + full_mask = full_offsets < dim_full + wigner_ptrs = ( + wigner_ptr + + edge_id * wigner_stride_e + + coeff_rows[:, None] * wigner_stride_r + + full_offsets[None, :] * wigner_stride_k + ) + x_ptrs = ( + x_ptr + + src_idx * x_stride_n + + full_offsets[:, None] * x_stride_d + + channel_offsets[None, :] * x_stride_c + ) + w_block = tl.load( + wigner_ptrs, + mask=reduced_mask[:, None] & full_mask[None, :], + other=0.0, + ) + x_block = tl.load( + x_ptrs, + mask=full_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + # Match the eager autocast path: rotate in the activation dtype chosen + # by the current AMP context instead of forcing a higher Wigner dtype. + w_block = w_block.to(x_block.dtype) + acc = tl.dot( + w_block, + x_block, + acc, + input_precision="ieee", + ) + + out_ptrs = ( + out_ptr + + edge_id * out_stride_e + + reduced_offsets[:, None] * out_stride_r + + channel_offsets[None, :] * out_stride_c + ) + tl.store( + out_ptrs, + acc, + mask=reduced_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_to_local_bwd_dx_kernel( + grad_out_ptr, + wigner_ptr, + coeff_index_ptr, + grad_edge_ptr, + num_edges, + reduced_dim, + dim_full, + channels, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + grad_edge_stride_e, + grad_edge_stride_d, + grad_edge_stride_c, + BLOCK_REDUCED: tl.constexpr, + BLOCK_FULL: tl.constexpr, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute per-edge source gradients ``D_to_m^T @ grad`` before scatter.""" + edge_id = tl.program_id(0) + full_block_id = tl.program_id(1) + channel_block_id = tl.program_id(2) + + full_offsets = full_block_id * BLOCK_FULL + tl.arange(0, BLOCK_FULL) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + full_mask = full_offsets < dim_full + channel_mask = channel_offsets < channels + + while edge_id < num_edges: + acc = tl.zeros((BLOCK_FULL, BLOCK_CHANNEL), dtype=tl.float32) + + for reduced_block in range(0, tl.cdiv(reduced_dim, BLOCK_REDUCED)): + reduced_offsets = reduced_block * BLOCK_REDUCED + tl.arange( + 0, BLOCK_REDUCED + ) + reduced_mask = reduced_offsets < reduced_dim + coeff_rows = tl.load( + coeff_index_ptr + reduced_offsets, + mask=reduced_mask, + other=0, + ).to(tl.int64) + wigner_ptrs = ( + wigner_ptr + + edge_id * wigner_stride_e + + coeff_rows[:, None] * wigner_stride_r + + full_offsets[None, :] * wigner_stride_k + ) + grad_ptrs = ( + grad_out_ptr + + edge_id * grad_out_stride_e + + reduced_offsets[:, None] * grad_out_stride_r + + channel_offsets[None, :] * grad_out_stride_c + ) + w_block = tl.load( + wigner_ptrs, + mask=reduced_mask[:, None] & full_mask[None, :], + other=0.0, + ) + grad_block = tl.load( + grad_ptrs, + mask=reduced_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + w_block = w_block.to(grad_block.dtype) + acc = tl.dot( + tl.trans(w_block), + grad_block, + acc, + input_precision="ieee", + ) + + grad_edge_ptrs = ( + grad_edge_ptr + + edge_id * grad_edge_stride_e + + full_offsets[:, None] * grad_edge_stride_d + + channel_offsets[None, :] * grad_edge_stride_c + ) + tl.store( + grad_edge_ptrs, + acc, + mask=full_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_to_local_bwd_dw_kernel( + grad_out_ptr, + x_ptr, + src_ptr, + coeff_index_ptr, + grad_rows_ptr, + num_edges, + reduced_dim, + dim_full, + channels, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + x_stride_n, + x_stride_d, + x_stride_c, + grad_rows_stride_e, + grad_rows_stride_r, + grad_rows_stride_d, + BLOCK_REDUCED: tl.constexpr, + BLOCK_FULL: tl.constexpr, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute row-selected Wigner gradients ``grad @ x[src]^T``.""" + edge_id = tl.program_id(0) + reduced_block_id = tl.program_id(1) + full_block_id = tl.program_id(2) + + reduced_offsets = reduced_block_id * BLOCK_REDUCED + tl.arange(0, BLOCK_REDUCED) + full_offsets = full_block_id * BLOCK_FULL + tl.arange(0, BLOCK_FULL) + reduced_mask = reduced_offsets < reduced_dim + full_mask = full_offsets < dim_full + + while edge_id < num_edges: + src_idx = tl.load(src_ptr + edge_id).to(tl.int64) + acc = tl.zeros((BLOCK_REDUCED, BLOCK_FULL), dtype=tl.float32) + + for channel_block in range(0, tl.cdiv(channels, BLOCK_CHANNEL)): + channel_offsets = channel_block * BLOCK_CHANNEL + tl.arange( + 0, BLOCK_CHANNEL + ) + channel_mask = channel_offsets < channels + grad_ptrs = ( + grad_out_ptr + + edge_id * grad_out_stride_e + + reduced_offsets[:, None] * grad_out_stride_r + + channel_offsets[None, :] * grad_out_stride_c + ) + x_ptrs = ( + x_ptr + + src_idx * x_stride_n + + full_offsets[:, None] * x_stride_d + + channel_offsets[None, :] * x_stride_c + ) + grad_block = tl.load( + grad_ptrs, + mask=reduced_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + x_block = tl.load( + x_ptrs, + mask=full_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + acc = tl.dot( + grad_block, + tl.trans(x_block), + acc, + input_precision="ieee", + ) + + grad_rows_ptrs = ( + grad_rows_ptr + + edge_id * grad_rows_stride_e + + reduced_offsets[:, None] * grad_rows_stride_r + + full_offsets[None, :] * grad_rows_stride_d + ) + tl.store( + grad_rows_ptrs, + acc, + mask=reduced_mask[:, None] & full_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_forward_kernel( + x_local_ptr, + wigner_ptr, + coeff_index_ptr, + out_ptr, + num_edges, + reduced_dim, + dim_full, + channels, + x_local_stride_e, + x_local_stride_r, + x_local_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + out_stride_e, + out_stride_d, + out_stride_c, + BLOCK_REDUCED: tl.constexpr, + BLOCK_FULL: tl.constexpr, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute fused inverse rotation ``Dt_from_m @ x_local``.""" + edge_id = tl.program_id(0) + full_block_id = tl.program_id(1) + channel_block_id = tl.program_id(2) + + full_offsets = full_block_id * BLOCK_FULL + tl.arange(0, BLOCK_FULL) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + full_mask = full_offsets < dim_full + channel_mask = channel_offsets < channels + + while edge_id < num_edges: + acc = tl.zeros((BLOCK_FULL, BLOCK_CHANNEL), dtype=tl.float32) + + for reduced_block in range(0, tl.cdiv(reduced_dim, BLOCK_REDUCED)): + reduced_offsets = reduced_block * BLOCK_REDUCED + tl.arange( + 0, BLOCK_REDUCED + ) + reduced_mask = reduced_offsets < reduced_dim + coeff_cols = tl.load( + coeff_index_ptr + reduced_offsets, + mask=reduced_mask, + other=0, + ).to(tl.int64) + wigner_ptrs = ( + wigner_ptr + + edge_id * wigner_stride_e + + full_offsets[:, None] * wigner_stride_r + + coeff_cols[None, :] * wigner_stride_k + ) + x_ptrs = ( + x_local_ptr + + edge_id * x_local_stride_e + + reduced_offsets[:, None] * x_local_stride_r + + channel_offsets[None, :] * x_local_stride_c + ) + w_block = tl.load( + wigner_ptrs, + mask=full_mask[:, None] & reduced_mask[None, :], + other=0.0, + ) + x_block = tl.load( + x_ptrs, + mask=reduced_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + w_block = w_block.to(x_block.dtype) + acc = tl.dot( + w_block, + x_block, + acc, + input_precision="ieee", + ) + + out_ptrs = ( + out_ptr + + edge_id * out_stride_e + + full_offsets[:, None] * out_stride_d + + channel_offsets[None, :] * out_stride_c + ) + tl.store( + out_ptrs, + acc, + mask=full_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_bwd_dx_kernel( + grad_out_ptr, + wigner_ptr, + coeff_index_ptr, + grad_x_ptr, + num_edges, + reduced_dim, + dim_full, + channels, + grad_out_stride_e, + grad_out_stride_d, + grad_out_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + grad_x_stride_e, + grad_x_stride_r, + grad_x_stride_c, + BLOCK_REDUCED: tl.constexpr, + BLOCK_FULL: tl.constexpr, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute reduced-layout gradients ``Dt_from_m^T @ grad``.""" + edge_id = tl.program_id(0) + reduced_block_id = tl.program_id(1) + channel_block_id = tl.program_id(2) + + reduced_offsets = reduced_block_id * BLOCK_REDUCED + tl.arange(0, BLOCK_REDUCED) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + reduced_mask = reduced_offsets < reduced_dim + channel_mask = channel_offsets < channels + + while edge_id < num_edges: + coeff_cols = tl.load( + coeff_index_ptr + reduced_offsets, + mask=reduced_mask, + other=0, + ).to(tl.int64) + acc = tl.zeros((BLOCK_REDUCED, BLOCK_CHANNEL), dtype=tl.float32) + + for full_block in range(0, tl.cdiv(dim_full, BLOCK_FULL)): + full_offsets = full_block * BLOCK_FULL + tl.arange(0, BLOCK_FULL) + full_mask = full_offsets < dim_full + wigner_ptrs = ( + wigner_ptr + + edge_id * wigner_stride_e + + full_offsets[:, None] * wigner_stride_r + + coeff_cols[None, :] * wigner_stride_k + ) + grad_ptrs = ( + grad_out_ptr + + edge_id * grad_out_stride_e + + full_offsets[:, None] * grad_out_stride_d + + channel_offsets[None, :] * grad_out_stride_c + ) + w_block = tl.load( + wigner_ptrs, + mask=full_mask[:, None] & reduced_mask[None, :], + other=0.0, + ) + grad_block = tl.load( + grad_ptrs, + mask=full_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + w_block = w_block.to(grad_block.dtype) + acc = tl.dot( + tl.trans(w_block), + grad_block, + acc, + input_precision="ieee", + ) + + grad_x_ptrs = ( + grad_x_ptr + + edge_id * grad_x_stride_e + + reduced_offsets[:, None] * grad_x_stride_r + + channel_offsets[None, :] * grad_x_stride_c + ) + tl.store( + grad_x_ptrs, + acc, + mask=reduced_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_bwd_dw_kernel( + grad_out_ptr, + x_local_ptr, + grad_cols_ptr, + num_edges, + reduced_dim, + dim_full, + channels, + grad_out_stride_e, + grad_out_stride_d, + grad_out_stride_c, + x_local_stride_e, + x_local_stride_r, + x_local_stride_c, + grad_cols_stride_e, + grad_cols_stride_d, + grad_cols_stride_r, + BLOCK_REDUCED: tl.constexpr, + BLOCK_FULL: tl.constexpr, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute column-selected inverse Wigner gradients ``grad @ x_local^T``.""" + edge_id = tl.program_id(0) + full_block_id = tl.program_id(1) + reduced_block_id = tl.program_id(2) + + full_offsets = full_block_id * BLOCK_FULL + tl.arange(0, BLOCK_FULL) + reduced_offsets = reduced_block_id * BLOCK_REDUCED + tl.arange(0, BLOCK_REDUCED) + full_mask = full_offsets < dim_full + reduced_mask = reduced_offsets < reduced_dim + + while edge_id < num_edges: + acc = tl.zeros((BLOCK_FULL, BLOCK_REDUCED), dtype=tl.float32) + + for channel_block in range(0, tl.cdiv(channels, BLOCK_CHANNEL)): + channel_offsets = channel_block * BLOCK_CHANNEL + tl.arange( + 0, BLOCK_CHANNEL + ) + channel_mask = channel_offsets < channels + grad_ptrs = ( + grad_out_ptr + + edge_id * grad_out_stride_e + + full_offsets[:, None] * grad_out_stride_d + + channel_offsets[None, :] * grad_out_stride_c + ) + x_ptrs = ( + x_local_ptr + + edge_id * x_local_stride_e + + reduced_offsets[:, None] * x_local_stride_r + + channel_offsets[None, :] * x_local_stride_c + ) + grad_block = tl.load( + grad_ptrs, + mask=full_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + x_block = tl.load( + x_ptrs, + mask=reduced_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + acc = tl.dot( + grad_block, + tl.trans(x_block), + acc, + input_precision="ieee", + ) + + grad_cols_ptrs = ( + grad_cols_ptr + + edge_id * grad_cols_stride_e + + full_offsets[:, None] * grad_cols_stride_d + + reduced_offsets[None, :] * grad_cols_stride_r + ) + tl.store( + grad_cols_ptrs, + acc, + mask=full_mask[:, None] & reduced_mask[None, :], + ) + edge_id += GRID_E_STRIDE diff --git a/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py b/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py new file mode 100644 index 0000000000..524acfd72f --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/triton/kernels_small.py @@ -0,0 +1,1317 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +# pyright: reportMissingImports=false +# ruff: noqa: ANN001, ANN201 +"""Specialized small-family Triton kernels for SeZM SO(2) rotations. + +These kernels are the intended fast path for ``lmax <= 3``. They keep one +masked ``16x16`` Wigner tile in registers, so ``lmax=0`` and ``lmax=1`` can +share the same specialized family without paying the loop overhead of the +generic tiled kernels. +""" + +from __future__ import ( + annotations, +) + +import triton +import triton.language as tl + +from .constants import TRITON_SMALL_FULL_DIM as TRITON_SMALL_FULL_DIM_VALUE + +# Small kernels always materialize one padded ``16x16`` block and mask tails. +TRITON_SMALL_FULL_DIM = tl.constexpr(TRITON_SMALL_FULL_DIM_VALUE) + + +@triton.jit +def _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, +) -> tl.tensor: + """Load one padded ``16x16`` Wigner block in l-major order.""" + full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) + full_mask = full_offsets < full_dim + wigner_ptrs = ( + wigner_ptr + + edge_id * wigner_stride_e + + full_offsets[:, None] * wigner_stride_r + + full_offsets[None, :] * wigner_stride_k + ) + return tl.load( + wigner_ptrs, + mask=full_mask[:, None] & full_mask[None, :], + other=0.0, + ) + + +@triton.jit +def _load_full_node_values( + x_ptr, + node_idx, + full_dim, + channel_offsets, + channel_mask, + x_stride_n, + x_stride_d, + x_stride_c, +) -> tl.tensor: + """Load one padded ``16xC`` node feature block in l-major order.""" + full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) + full_mask = full_offsets < full_dim + x_ptrs = ( + x_ptr + + node_idx * x_stride_n + + full_offsets[:, None] * x_stride_d + + channel_offsets[None, :] * x_stride_c + ) + return tl.load( + x_ptrs, + mask=full_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + + +@triton.jit +def _load_reduced_values_with_index( + x_ptr, + coeff_index_ptr, + edge_id, + reduced_dim, + channel_offsets, + channel_mask, + x_stride_e, + x_stride_r, + x_stride_c, +) -> tuple[tl.tensor, tl.tensor, tl.tensor]: + """Load reduced values together with the padded reduced->full row mapping.""" + reduced_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) + reduced_mask = reduced_offsets < reduced_dim + x_ptrs = ( + x_ptr + + edge_id * x_stride_e + + reduced_offsets[:, None] * x_stride_r + + channel_offsets[None, :] * x_stride_c + ) + reduced_values = tl.load( + x_ptrs, + mask=reduced_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + coeff_rows = tl.load( + coeff_index_ptr + reduced_offsets, + mask=reduced_mask, + other=-1, + ).to(tl.int64) + return reduced_values, reduced_mask, coeff_rows + + +@triton.jit +def _scatter_reduced_to_full_matrix( + reduced_values, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL: tl.constexpr, +) -> tl.tensor: + """Scatter a padded reduced block into a padded full l-major block.""" + row_ids = tl.arange(0, TRITON_SMALL_FULL_DIM) + full_values = tl.zeros( + (TRITON_SMALL_FULL_DIM, BLOCK_CHANNEL), + dtype=reduced_values.dtype, + ) + for row in range(TRITON_SMALL_FULL_DIM): + row_mask = (coeff_rows == row)[:, None] & reduced_mask[:, None] + row_value = tl.sum(tl.where(row_mask, reduced_values, 0.0), axis=0).to( + reduced_values.dtype + ) + full_values = tl.where( + row_ids[:, None] == row, + row_value[None, :], + full_values, + ) + return full_values + + +@triton.jit +def _select_reduced_from_full_matrix( + full_values, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL: tl.constexpr, +) -> tl.tensor: + """Select reduced rows from a padded full l-major block.""" + row_ids = tl.arange(0, TRITON_SMALL_FULL_DIM) + reduced_values = tl.zeros( + (TRITON_SMALL_FULL_DIM, BLOCK_CHANNEL), + dtype=full_values.dtype, + ) + for row in range(TRITON_SMALL_FULL_DIM): + row_value = tl.sum( + tl.where(row_ids[:, None] == row, full_values, 0.0), + axis=0, + ).to(full_values.dtype) + reduced_values = tl.where( + (coeff_rows == row)[:, None] & reduced_mask[:, None], + row_value[None, :], + reduced_values, + ) + return reduced_values + + +@triton.jit +def _build_full_matrix_l1( + y0, + y1, + y2, + y3, + BLOCK_CHANNEL: tl.constexpr, +) -> tl.tensor: + """Build a padded full matrix from the ``lmax=1`` row vectors.""" + row_ids = tl.arange(0, TRITON_SMALL_FULL_DIM) + full_values = tl.zeros( + (TRITON_SMALL_FULL_DIM, BLOCK_CHANNEL), + dtype=tl.float32, + ) + full_values = tl.where(row_ids[:, None] == 0, y0[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 1, y1[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 2, y2[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 3, y3[None, :], full_values) + return full_values + + +@triton.jit +def _build_full_matrix_l2( + y0, + y1, + y2, + y3, + y4, + y5, + y6, + y7, + y8, + BLOCK_CHANNEL: tl.constexpr, +) -> tl.tensor: + """Build a padded full matrix from the ``lmax=2`` row vectors.""" + row_ids = tl.arange(0, TRITON_SMALL_FULL_DIM) + full_values = tl.zeros( + (TRITON_SMALL_FULL_DIM, BLOCK_CHANNEL), + dtype=tl.float32, + ) + full_values = tl.where(row_ids[:, None] == 0, y0[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 1, y1[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 2, y2[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 3, y3[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 4, y4[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 5, y5[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 6, y6[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 7, y7[None, :], full_values) + full_values = tl.where(row_ids[:, None] == 8, y8[None, :], full_values) + return full_values + + +@triton.jit +def _matvec_l1( + w_full, + x_full, +) -> tl.tensor: + """Apply the packed ``lmax=1`` block-diagonal Wigner matrix.""" + return tl.dot(w_full.to(x_full.dtype), x_full, input_precision="ieee") + + +@triton.jit +def _matvec_t_l1( + w_full, + x_full, +) -> tl.tensor: + """Apply the transpose of the packed ``lmax=1`` Wigner matrix.""" + return tl.dot( + tl.trans(w_full.to(x_full.dtype)), + x_full, + input_precision="ieee", + ) + + +@triton.jit +def _matvec_l2( + w_full, + x_full, +) -> tl.tensor: + """Apply the packed ``lmax=2`` block-diagonal Wigner matrix.""" + return tl.dot(w_full.to(x_full.dtype), x_full, input_precision="ieee") + + +@triton.jit +def _matvec_t_l2( + w_full, + x_full, +) -> tl.tensor: + """Apply the transpose of the packed ``lmax=2`` Wigner matrix.""" + return tl.dot( + tl.trans(w_full.to(x_full.dtype)), + x_full, + input_precision="ieee", + ) + + +@triton.jit +def rotate_to_local_l1_forward_kernel( + x_ptr, + src_ptr, + wigner_ptr, + coeff_index_ptr, + out_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + x_stride_n, + x_stride_d, + x_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + out_stride_e, + out_stride_r, + out_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Fused ``global -> local reduced`` rotation specialized for ``lmax=1``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + src_idx = tl.load(src_ptr + edge_id).to(tl.int64) + coeff_rows = tl.load( + coeff_index_ptr + tl.arange(0, TRITON_SMALL_FULL_DIM), + mask=tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim, + other=-1, + ).to(tl.int64) + reduced_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim + x_full = _load_full_node_values( + x_ptr, + src_idx, + full_dim, + channel_offsets, + channel_mask, + x_stride_n, + x_stride_d, + x_stride_c, + ) + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(x_full.dtype) + y_full = _matvec_l1(w_full, x_full) + out_values = _select_reduced_from_full_matrix( + y_full, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + out_ptrs = ( + out_ptr + + edge_id * out_stride_e + + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_r + + channel_offsets[None, :] * out_stride_c + ) + tl.store( + out_ptrs, + out_values, + mask=reduced_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_to_local_l2_forward_kernel( + x_ptr, + src_ptr, + wigner_ptr, + coeff_index_ptr, + out_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + x_stride_n, + x_stride_d, + x_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + out_stride_e, + out_stride_r, + out_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Fused ``global -> local reduced`` rotation specialized for ``lmax=2``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + src_idx = tl.load(src_ptr + edge_id).to(tl.int64) + coeff_rows = tl.load( + coeff_index_ptr + tl.arange(0, TRITON_SMALL_FULL_DIM), + mask=tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim, + other=-1, + ).to(tl.int64) + reduced_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim + x_full = _load_full_node_values( + x_ptr, + src_idx, + full_dim, + channel_offsets, + channel_mask, + x_stride_n, + x_stride_d, + x_stride_c, + ) + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(x_full.dtype) + y_full = _matvec_l2(w_full, x_full) + out_values = _select_reduced_from_full_matrix( + y_full, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + out_ptrs = ( + out_ptr + + edge_id * out_stride_e + + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_r + + channel_offsets[None, :] * out_stride_c + ) + tl.store( + out_ptrs, + out_values, + mask=reduced_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_to_local_l3_forward_kernel( + x_ptr, + src_ptr, + wigner_ptr, + coeff_index_ptr, + out_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + x_stride_n, + x_stride_d, + x_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + out_stride_e, + out_stride_r, + out_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Fused ``global -> local reduced`` rotation specialized for ``lmax=3``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + src_idx = tl.load(src_ptr + edge_id).to(tl.int64) + coeff_rows = tl.load( + coeff_index_ptr + tl.arange(0, TRITON_SMALL_FULL_DIM), + mask=tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim, + other=-1, + ).to(tl.int64) + reduced_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < reduced_dim + x_full = _load_full_node_values( + x_ptr, + src_idx, + full_dim, + channel_offsets, + channel_mask, + x_stride_n, + x_stride_d, + x_stride_c, + ) + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(x_full.dtype) + y_full = tl.dot(w_full, x_full, input_precision="ieee") + out_values = _select_reduced_from_full_matrix( + y_full, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + out_ptrs = ( + out_ptr + + edge_id * out_stride_e + + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_r + + channel_offsets[None, :] * out_stride_c + ) + tl.store( + out_ptrs, + out_values, + mask=reduced_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_to_local_l1_bwd_dx_kernel( + grad_out_ptr, + wigner_ptr, + coeff_index_ptr, + grad_edge_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + grad_edge_stride_e, + grad_edge_stride_d, + grad_edge_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute per-edge source gradients specialized for ``lmax=1``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + grad_reduced, reduced_mask, coeff_rows = _load_reduced_values_with_index( + grad_out_ptr, + coeff_index_ptr, + edge_id, + reduced_dim, + channel_offsets, + channel_mask, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + ) + grad_full = _scatter_reduced_to_full_matrix( + grad_reduced, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(grad_full.dtype) + dx_full = _matvec_t_l1(w_full, grad_full) + full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim + grad_edge_ptrs = ( + grad_edge_ptr + + edge_id * grad_edge_stride_e + + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * grad_edge_stride_d + + channel_offsets[None, :] * grad_edge_stride_c + ) + tl.store( + grad_edge_ptrs, + dx_full, + mask=full_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_to_local_l2_bwd_dx_kernel( + grad_out_ptr, + wigner_ptr, + coeff_index_ptr, + grad_edge_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + grad_edge_stride_e, + grad_edge_stride_d, + grad_edge_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute per-edge source gradients specialized for ``lmax=2``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + grad_reduced, reduced_mask, coeff_rows = _load_reduced_values_with_index( + grad_out_ptr, + coeff_index_ptr, + edge_id, + reduced_dim, + channel_offsets, + channel_mask, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + ) + grad_full = _scatter_reduced_to_full_matrix( + grad_reduced, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(grad_full.dtype) + dx_full = _matvec_t_l2(w_full, grad_full) + full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim + grad_edge_ptrs = ( + grad_edge_ptr + + edge_id * grad_edge_stride_e + + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * grad_edge_stride_d + + channel_offsets[None, :] * grad_edge_stride_c + ) + tl.store( + grad_edge_ptrs, + dx_full, + mask=full_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_to_local_l3_bwd_dx_kernel( + grad_out_ptr, + wigner_ptr, + coeff_index_ptr, + grad_edge_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + grad_edge_stride_e, + grad_edge_stride_d, + grad_edge_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute per-edge source gradients specialized for ``lmax=3``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + grad_reduced, reduced_mask, coeff_rows = _load_reduced_values_with_index( + grad_out_ptr, + coeff_index_ptr, + edge_id, + reduced_dim, + channel_offsets, + channel_mask, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + ) + grad_full = _scatter_reduced_to_full_matrix( + grad_reduced, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(grad_full.dtype) + dx_full = tl.dot( + tl.trans(w_full), + grad_full, + input_precision="ieee", + ) + full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim + grad_edge_ptrs = ( + grad_edge_ptr + + edge_id * grad_edge_stride_e + + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * grad_edge_stride_d + + channel_offsets[None, :] * grad_edge_stride_c + ) + tl.store( + grad_edge_ptrs, + dx_full, + mask=full_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_to_local_small_bwd_dw_kernel( + grad_out_ptr, + x_ptr, + src_ptr, + coeff_index_ptr, + grad_wigner_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + x_stride_n, + x_stride_d, + x_stride_c, + grad_wigner_stride_e, + grad_wigner_stride_r, + grad_wigner_stride_k, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute full Wigner gradients for specialized small-l rotate-to-local.""" + edge_id = tl.program_id(0) + channel_offsets = tl.arange(0, BLOCK_CHANNEL) + full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) + while edge_id < num_edges: + coeff_rows = tl.load( + coeff_index_ptr + full_offsets, + mask=full_offsets < reduced_dim, + other=-1, + ).to(tl.int64) + reduced_mask = full_offsets < reduced_dim + src_idx = tl.load(src_ptr + edge_id).to(tl.int64) + grad_w_acc = tl.zeros( + (TRITON_SMALL_FULL_DIM, TRITON_SMALL_FULL_DIM), + dtype=tl.float32, + ) + channel_start = 0 + while channel_start < channels: + block_offsets = channel_start + channel_offsets + channel_mask = block_offsets < channels + grad_reduced, _, _ = _load_reduced_values_with_index( + grad_out_ptr, + coeff_index_ptr, + edge_id, + reduced_dim, + block_offsets, + channel_mask, + grad_out_stride_e, + grad_out_stride_r, + grad_out_stride_c, + ) + grad_full_block = _scatter_reduced_to_full_matrix( + grad_reduced, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + x_full_block = _load_full_node_values( + x_ptr, + src_idx, + full_dim, + block_offsets, + channel_mask, + x_stride_n, + x_stride_d, + x_stride_c, + ) + grad_w_acc += tl.dot( + grad_full_block, + tl.trans(x_full_block.to(grad_full_block.dtype)), + input_precision="ieee", + ) + channel_start += BLOCK_CHANNEL + grad_w_ptrs = ( + grad_wigner_ptr + + edge_id * grad_wigner_stride_e + + full_offsets[:, None] * grad_wigner_stride_r + + full_offsets[None, :] * grad_wigner_stride_k + ) + full_mask = full_offsets < full_dim + tl.store( + grad_w_ptrs, + grad_w_acc, + mask=full_mask[:, None] & full_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_l1_forward_kernel( + x_ptr, + wigner_ptr, + coeff_index_ptr, + out_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + x_stride_e, + x_stride_r, + x_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + out_stride_e, + out_stride_d, + out_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Fused ``local reduced -> global`` rotation specialized for ``lmax=1``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + reduced_values, reduced_mask, coeff_rows = _load_reduced_values_with_index( + x_ptr, + coeff_index_ptr, + edge_id, + reduced_dim, + channel_offsets, + channel_mask, + x_stride_e, + x_stride_r, + x_stride_c, + ) + x_full = _scatter_reduced_to_full_matrix( + reduced_values, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(x_full.dtype) + y_full = _matvec_l1(w_full, x_full) + full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim + out_ptrs = ( + out_ptr + + edge_id * out_stride_e + + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_d + + channel_offsets[None, :] * out_stride_c + ) + tl.store( + out_ptrs, + y_full, + mask=full_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_l2_forward_kernel( + x_ptr, + wigner_ptr, + coeff_index_ptr, + out_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + x_stride_e, + x_stride_r, + x_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + out_stride_e, + out_stride_d, + out_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Fused ``local reduced -> global`` rotation specialized for ``lmax=2``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + reduced_values, reduced_mask, coeff_rows = _load_reduced_values_with_index( + x_ptr, + coeff_index_ptr, + edge_id, + reduced_dim, + channel_offsets, + channel_mask, + x_stride_e, + x_stride_r, + x_stride_c, + ) + x_full = _scatter_reduced_to_full_matrix( + reduced_values, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(x_full.dtype) + y_full = _matvec_l2(w_full, x_full) + full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim + out_ptrs = ( + out_ptr + + edge_id * out_stride_e + + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_d + + channel_offsets[None, :] * out_stride_c + ) + tl.store( + out_ptrs, + y_full, + mask=full_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_l3_forward_kernel( + x_ptr, + wigner_ptr, + coeff_index_ptr, + out_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + x_stride_e, + x_stride_r, + x_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + out_stride_e, + out_stride_d, + out_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Fused ``local reduced -> global`` rotation specialized for ``lmax=3``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + reduced_values, reduced_mask, coeff_rows = _load_reduced_values_with_index( + x_ptr, + coeff_index_ptr, + edge_id, + reduced_dim, + channel_offsets, + channel_mask, + x_stride_e, + x_stride_r, + x_stride_c, + ) + x_full = _scatter_reduced_to_full_matrix( + reduced_values, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(x_full.dtype) + y_full = tl.dot( + w_full.to(x_full.dtype), + x_full, + input_precision="ieee", + ) + full_mask = tl.arange(0, TRITON_SMALL_FULL_DIM) < full_dim + out_ptrs = ( + out_ptr + + edge_id * out_stride_e + + tl.arange(0, TRITON_SMALL_FULL_DIM)[:, None] * out_stride_d + + channel_offsets[None, :] * out_stride_c + ) + tl.store( + out_ptrs, + y_full, + mask=full_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_l1_bwd_dx_kernel( + grad_out_ptr, + wigner_ptr, + coeff_index_ptr, + grad_x_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + grad_out_stride_e, + grad_out_stride_d, + grad_out_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + grad_x_stride_e, + grad_x_stride_r, + grad_x_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute reduced-layout gradients specialized for ``lmax=1``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) + full_mask = full_offsets < full_dim + grad_ptrs = ( + grad_out_ptr + + edge_id * grad_out_stride_e + + full_offsets[:, None] * grad_out_stride_d + + channel_offsets[None, :] * grad_out_stride_c + ) + grad_full = tl.load( + grad_ptrs, + mask=full_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + coeff_rows = tl.load( + coeff_index_ptr + full_offsets, + mask=full_offsets < reduced_dim, + other=-1, + ).to(tl.int64) + reduced_mask = full_offsets < reduced_dim + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(grad_full.dtype) + dx_full = _matvec_t_l1(w_full, grad_full) + grad_reduced = _select_reduced_from_full_matrix( + dx_full, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + grad_x_ptrs = ( + grad_x_ptr + + edge_id * grad_x_stride_e + + full_offsets[:, None] * grad_x_stride_r + + channel_offsets[None, :] * grad_x_stride_c + ) + tl.store( + grad_x_ptrs, + grad_reduced, + mask=reduced_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_l2_bwd_dx_kernel( + grad_out_ptr, + wigner_ptr, + coeff_index_ptr, + grad_x_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + grad_out_stride_e, + grad_out_stride_d, + grad_out_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + grad_x_stride_e, + grad_x_stride_r, + grad_x_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute reduced-layout gradients specialized for ``lmax=2``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) + full_mask = full_offsets < full_dim + grad_ptrs = ( + grad_out_ptr + + edge_id * grad_out_stride_e + + full_offsets[:, None] * grad_out_stride_d + + channel_offsets[None, :] * grad_out_stride_c + ) + grad_full = tl.load( + grad_ptrs, + mask=full_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + coeff_rows = tl.load( + coeff_index_ptr + full_offsets, + mask=full_offsets < reduced_dim, + other=-1, + ).to(tl.int64) + reduced_mask = full_offsets < reduced_dim + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(grad_full.dtype) + dx_full = _matvec_t_l2(w_full, grad_full) + grad_reduced = _select_reduced_from_full_matrix( + dx_full, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + grad_x_ptrs = ( + grad_x_ptr + + edge_id * grad_x_stride_e + + full_offsets[:, None] * grad_x_stride_r + + channel_offsets[None, :] * grad_x_stride_c + ) + tl.store( + grad_x_ptrs, + grad_reduced, + mask=reduced_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_l3_bwd_dx_kernel( + grad_out_ptr, + wigner_ptr, + coeff_index_ptr, + grad_x_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + grad_out_stride_e, + grad_out_stride_d, + grad_out_stride_c, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + grad_x_stride_e, + grad_x_stride_r, + grad_x_stride_c, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute reduced-layout gradients specialized for ``lmax=3``.""" + edge_id = tl.program_id(0) + channel_block_id = tl.program_id(1) + channel_offsets = channel_block_id * BLOCK_CHANNEL + tl.arange(0, BLOCK_CHANNEL) + channel_mask = channel_offsets < channels + while edge_id < num_edges: + full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) + full_mask = full_offsets < full_dim + grad_ptrs = ( + grad_out_ptr + + edge_id * grad_out_stride_e + + full_offsets[:, None] * grad_out_stride_d + + channel_offsets[None, :] * grad_out_stride_c + ) + grad_full = tl.load( + grad_ptrs, + mask=full_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + coeff_rows = tl.load( + coeff_index_ptr + full_offsets, + mask=full_offsets < reduced_dim, + other=-1, + ).to(tl.int64) + reduced_mask = full_offsets < reduced_dim + w_full = _load_full_wigner_matrix( + wigner_ptr, + edge_id, + full_dim, + wigner_stride_e, + wigner_stride_r, + wigner_stride_k, + ).to(grad_full.dtype) + dx_full = tl.dot( + tl.trans(w_full.to(grad_full.dtype)), + grad_full, + input_precision="ieee", + ) + grad_reduced = _select_reduced_from_full_matrix( + dx_full, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + grad_x_ptrs = ( + grad_x_ptr + + edge_id * grad_x_stride_e + + full_offsets[:, None] * grad_x_stride_r + + channel_offsets[None, :] * grad_x_stride_c + ) + tl.store( + grad_x_ptrs, + grad_reduced, + mask=reduced_mask[:, None] & channel_mask[None, :], + ) + edge_id += GRID_E_STRIDE + + +@triton.jit +def rotate_back_small_bwd_dw_kernel( + grad_out_ptr, + x_ptr, + coeff_index_ptr, + grad_wigner_ptr, + num_edges, + reduced_dim, + full_dim, + channels, + grad_out_stride_e, + grad_out_stride_d, + grad_out_stride_c, + x_stride_e, + x_stride_r, + x_stride_c, + grad_wigner_stride_e, + grad_wigner_stride_r, + grad_wigner_stride_k, + BLOCK_CHANNEL: tl.constexpr, + GRID_E_STRIDE: tl.constexpr, +): + """Compute full Wigner gradients for specialized small-l rotate-back.""" + edge_id = tl.program_id(0) + channel_offsets = tl.arange(0, BLOCK_CHANNEL) + full_offsets = tl.arange(0, TRITON_SMALL_FULL_DIM) + while edge_id < num_edges: + coeff_rows = tl.load( + coeff_index_ptr + full_offsets, + mask=full_offsets < reduced_dim, + other=-1, + ).to(tl.int64) + reduced_mask = full_offsets < reduced_dim + grad_w_acc = tl.zeros( + (TRITON_SMALL_FULL_DIM, TRITON_SMALL_FULL_DIM), + dtype=tl.float32, + ) + channel_start = 0 + while channel_start < channels: + block_offsets = channel_start + channel_offsets + channel_mask = block_offsets < channels + full_mask = full_offsets < full_dim + grad_ptrs = ( + grad_out_ptr + + edge_id * grad_out_stride_e + + full_offsets[:, None] * grad_out_stride_d + + block_offsets[None, :] * grad_out_stride_c + ) + grad_full = tl.load( + grad_ptrs, + mask=full_mask[:, None] & channel_mask[None, :], + other=0.0, + ) + reduced_values, _, _ = _load_reduced_values_with_index( + x_ptr, + coeff_index_ptr, + edge_id, + reduced_dim, + block_offsets, + channel_mask, + x_stride_e, + x_stride_r, + x_stride_c, + ) + x_full = _scatter_reduced_to_full_matrix( + reduced_values, + reduced_mask, + coeff_rows, + BLOCK_CHANNEL=BLOCK_CHANNEL, + ) + grad_w_acc += tl.dot( + grad_full, + tl.trans(x_full.to(grad_full.dtype)), + input_precision="ieee", + ) + channel_start += BLOCK_CHANNEL + grad_w_ptrs = ( + grad_wigner_ptr + + edge_id * grad_wigner_stride_e + + full_offsets[:, None] * grad_wigner_stride_r + + full_offsets[None, :] * grad_wigner_stride_k + ) + full_mask = full_offsets < full_dim + tl.store( + grad_w_ptrs, + grad_w_acc, + mask=full_mask[:, None] & full_mask[None, :], + ) + edge_id += GRID_E_STRIDE diff --git a/deepmd/pt/model/descriptor/sezm_nn/utils.py b/deepmd/pt/model/descriptor/sezm_nn/utils.py new file mode 100644 index 0000000000..7b3d347bec --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/utils.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Utility helpers for the SeZM descriptor package. + +This module provides small numerical helpers, dtype conversion utilities, and +profiling helpers shared across the SeZM descriptor implementation. +""" + +from __future__ import ( + annotations, +) + +import math +from contextlib import ( + contextmanager, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np +import torch +import torch.nn as nn + +from deepmd.pt.utils.utils import ( + get_generator, +) + +if TYPE_CHECKING: + from collections.abc import ( + Generator, + ) + +ATTN_RES_MODES = ("none", "independent", "dependent") + + +def init_trunc_normal_fan_in_out( + weight: torch.Tensor, + seed: int | list[int] | None, + scale: float = 1.0, +) -> None: + """Initialize weight with truncated normal distribution. + + Uses Xavier-like variance scaling: std = scale / sqrt(fan_in + fan_out). + Truncation at +/-3*std prevents extreme outliers. + + Parameters + ---------- + weight : torch.Tensor + Weight tensor with shape (out_features, in_features). + seed : int | list[int] | None + Random seed for reproducibility. + scale : float, default=1.0 + Multiplicative scale factor in the standard deviation numerator. + """ + if weight.ndim != 2: + raise ValueError("`weight` must be a 2D tensor") + if scale <= 0: + raise ValueError("`scale` must be positive") + fan_out, fan_in = weight.shape + std = float(scale) / math.sqrt(fan_in + fan_out) + nn.init.trunc_normal_( + weight, + mean=0.0, + std=std, + a=-3.0 * std, + b=3.0 * std, + generator=get_generator(seed), + ) + + +@contextmanager +def nvtx_range(name: str) -> Generator[None, None, None]: + """ + Create an NVTX range when CUDA is available; otherwise, no-op. + + Parameters + ---------- + name + Range name shown in Nsight Systems/Compute. + """ + if torch.cuda.is_available(): + nvtx = torch.cuda.nvtx + if hasattr(nvtx, "range"): + with nvtx.range(name): + yield + return + yield + + +def safe_norm(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: + """ + Compute vector norm with smooth epsilon regularization. + + Uses float32 for computation when input is fp16/bf16. + + Parameters + ---------- + x : torch.Tensor + Input tensor with shape (N, 3), where N is the number of vectors. + eps : float + Lower bound for the norm. + + Returns + ------- + torch.Tensor + Norm with shape (N, 1). + """ + in_dtype = x.dtype + if in_dtype in (torch.float16, torch.bfloat16): + x = x.float() + eps_sq = x.new_tensor(float(eps) * float(eps)) + norm = torch.sqrt(torch.sum(x * x, dim=-1, keepdim=True) + eps_sq) + return norm.to(dtype=in_dtype) + + +def safe_numpy_to_tensor( + data: Any, *, device: torch.device, dtype: torch.dtype +) -> torch.Tensor: + if isinstance(data, torch.Tensor): + return data.to(device=device, dtype=dtype) + if isinstance(data, np.ndarray): + # Handle bfloat16: numpy uses ml_dtypes.bfloat16, which torch.as_tensor + # cannot convert. Convert to float32 first, then cast to target dtype. + if hasattr(data.dtype, "name") and "bfloat16" in data.dtype.name: + data = data.astype(np.float32) + return torch.as_tensor(data, device=device).to(dtype) + return torch.as_tensor(data, device=device, dtype=dtype) + + +def get_promoted_dtype(dtype: torch.dtype) -> torch.dtype: + """ + Get promoted dtype for numerical stability. + + For bf16/fp16, use float32 to ensure numerical stability + in computation and storage compatibility. + """ + if dtype in (torch.float16, torch.bfloat16): + return torch.float32 + return dtype + + +def np_safe( + tensor: torch.Tensor | None, +) -> np.ndarray | None: + """ + Convert tensor to numpy array, promoting low-precision types to fp32. + + For bf16/fp16, converts to fp32 first since NumPy/HDF5 do not natively + support these formats. fp32/fp64 are kept unchanged. + + Parameters + ---------- + tensor + PyTorch tensor to convert. Can be None. + + Returns + ------- + np.ndarray or None + numpy array with at least fp32 precision. + """ + if tensor is None: + return None + if tensor.dtype in (torch.float16, torch.bfloat16): + tensor = tensor.float() + return tensor.detach().cpu().numpy() diff --git a/deepmd/pt/model/descriptor/sezm_nn/wignerd.py b/deepmd/pt/model/descriptor/sezm_nn/wignerd.py new file mode 100644 index 0000000000..ca90d6978c --- /dev/null +++ b/deepmd/pt/model/descriptor/sezm_nn/wignerd.py @@ -0,0 +1,1516 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Quaternion-based Wigner-D and edge-frame utilities for SeZM. + +This module defines the quaternion helpers and Wigner-D evaluator used to +construct edge-aligned SO(3) rotation blocks in SeZM. +""" + +from __future__ import ( + annotations, +) + +import math +from itertools import ( + permutations, +) +from typing import ( + Any, + ClassVar, +) + +import torch +import torch.nn as nn + +from deepmd.pt.utils import ( + env, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .utils import ( + nvtx_range, +) + + +class CaseCoefficients(nn.Module): + """ + Polynomial tables for one magnitude-ordered branch of the quaternion Wigner path. + + The generic Wigner-D evaluation factors each matrix element into: + - a phase term carried by the arguments of ``Ra`` and ``Rb``; + - a real magnitude term evaluated by Horner recursion. + + The magnitude formula has two numerically stable branches, depending on whether + ``|Ra| >= |Rb|`` or the opposite. Each branch stores the branch-specific Horner + coefficients and the powers of ``|Ra|`` / ``|Rb|`` that sit outside the Horner + polynomial. + """ + + def __init__( + self, + *, + coeff: torch.Tensor, + horner: torch.Tensor, + poly_len: torch.Tensor, + ra_exp: torch.Tensor, + rb_exp: torch.Tensor, + sign: torch.Tensor, + ) -> None: + super().__init__() + self.register_buffer("coeff", coeff, persistent=True) + self.register_buffer("horner", horner, persistent=True) + self.register_buffer("poly_len", poly_len, persistent=True) + self.register_buffer("ra_exp", ra_exp, persistent=True) + self.register_buffer("rb_exp", rb_exp, persistent=True) + self.register_buffer("sign", sign, persistent=True) + + +class WignerPolynomialCoefficients(nn.Module): + """ + Precomputed coefficient tables for the generic quaternion Wigner evaluator. + + Only one half of each real block is stored explicitly. The remaining entries are + reconstructed from the exact symmetry + + ``D^l_{-m',-m} = (-1)^(m' - m) * conj(D^l_{m',m})``. + + This keeps the runtime path branch-free with respect to ``(l, m', m)`` while + preserving the exact packed ``(l, m)`` layout used everywhere else in SeZM. + """ + + def __init__( + self, + *, + lmin: int, + lmax: int, + size: int, + max_poly_len: int, + n_primary: int, + n_derived: int, + primary_row: torch.Tensor, + primary_col: torch.Tensor, + case1: CaseCoefficients, + case2: CaseCoefficients, + mp_plus_m: torch.Tensor, + m_minus_mp: torch.Tensor, + diagonal_mask: torch.Tensor, + anti_diagonal_mask: torch.Tensor, + special_2m: torch.Tensor, + anti_diag_sign: torch.Tensor, + derived_row: torch.Tensor, + derived_col: torch.Tensor, + derived_primary_idx: torch.Tensor, + derived_sign: torch.Tensor, + ) -> None: + super().__init__() + self.lmin = int(lmin) + self.lmax = int(lmax) + self.size = int(size) + self.max_poly_len = int(max_poly_len) + self.n_primary = int(n_primary) + self.n_derived = int(n_derived) + + self.register_buffer("primary_row", primary_row, persistent=True) + self.register_buffer("primary_col", primary_col, persistent=True) + self.case1 = case1 + self.case2 = case2 + self.register_buffer("mp_plus_m", mp_plus_m, persistent=True) + self.register_buffer("m_minus_mp", m_minus_mp, persistent=True) + self.register_buffer("diagonal_mask", diagonal_mask, persistent=True) + self.register_buffer("anti_diagonal_mask", anti_diagonal_mask, persistent=True) + self.register_buffer("special_2m", special_2m, persistent=True) + self.register_buffer("anti_diag_sign", anti_diag_sign, persistent=True) + self.register_buffer("derived_row", derived_row, persistent=True) + self.register_buffer("derived_col", derived_col, persistent=True) + self.register_buffer( + "derived_primary_idx", derived_primary_idx, persistent=True + ) + self.register_buffer("derived_sign", derived_sign, persistent=True) + + +class WignerSmallOrderCoefficients(nn.Module): + """ + Precomputed low-order quaternion polynomial kernels in the SeZM packed basis. + + The tensors in this container provide the specialized ``l=2`` and ``l=3,4`` + kernels used by the hybrid Wigner runtime: + - ``C_l2`` stores the degree-4 tensor-contraction coefficients; + - ``C_l3`` / ``C_l4`` store flattened monomial coefficient matrices; + - ``C_combined_l3l4`` lifts the ``l=3`` basis to degree 8 and stacks it with + ``l=4`` so both blocks can be produced by one matrix multiply; + - ``exp_l3`` / ``exp_l4`` store the monomial exponent tables used by the runtime + gather/prod path. + """ + + def __init__( + self, + *, + C_l2: torch.Tensor, + C_l3: torch.Tensor, + C_l4: torch.Tensor, + C_combined_l3l4: torch.Tensor, + exp_l3: torch.Tensor, + exp_l4: torch.Tensor, + ) -> None: + super().__init__() + self.register_buffer("C_l2", C_l2, persistent=True) + self.register_buffer("C_l3", C_l3, persistent=True) + self.register_buffer("C_l4", C_l4, persistent=True) + self.register_buffer("C_combined_l3l4", C_combined_l3l4, persistent=True) + self.register_buffer("exp_l3", exp_l3, persistent=True) + self.register_buffer("exp_l4", exp_l4, persistent=True) + + +def _safe_norm_nd(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: + """Compute an ``L2`` norm with smooth epsilon regularization.""" + in_dtype = x.dtype + if in_dtype in (torch.float16, torch.bfloat16): + x = x.float() + norm = torch.sqrt(torch.sum(x * x, dim=-1, keepdim=True) + eps * eps) + return norm.to(dtype=in_dtype) + + +def quaternion_normalize(q: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: + """Normalize quaternions with a differentiable epsilon floor.""" + return q / _safe_norm_nd(q, eps) + + +def quaternion_multiply(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: + """Hamilton product for batched quaternions in ``(w, x, y, z)`` order.""" + w1, x1, y1, z1 = q1.unbind(dim=-1) + w2, x2, y2, z2 = q2.unbind(dim=-1) + return torch.stack( + [ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, + ], + dim=-1, + ) + + +def quaternion_to_rotation_matrix(q: torch.Tensor) -> torch.Tensor: + """ + Convert unit quaternions to 3x3 rotation matrices. + + The returned matrix is the active rotation represented by ``q``. In SeZM this is + the global->local edge rotation, so multiplying the edge direction by this matrix + sends it to local ``+Z``. + """ + w, x, y, z = q.unbind(dim=-1) + x2 = x * x + y2 = y * y + z2 = z * z + xy = x * y + xz = x * z + yz = y * z + wx = w * x + wy = w * y + wz = w * z + return torch.stack( + [ + torch.stack( + [1.0 - 2.0 * (y2 + z2), 2.0 * (xy - wz), 2.0 * (xz + wy)], + dim=-1, + ), + torch.stack( + [2.0 * (xy + wz), 1.0 - 2.0 * (x2 + z2), 2.0 * (yz - wx)], + dim=-1, + ), + torch.stack( + [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (x2 + y2)], + dim=-1, + ), + ], + dim=-2, + ) + + +def quaternion_z_rotation(gamma: torch.Tensor) -> torch.Tensor: + """ + Create quaternions for a rotation about the local ``+Z`` axis. + + Parameters + ---------- + gamma + Roll angles in radians with shape ``(E,)``. + + Returns + ------- + torch.Tensor + Quaternions with shape ``(E, 4)`` in ``(w, x, y, z)`` order. + """ + half_gamma = 0.5 * gamma + w = torch.cos(half_gamma) + x = torch.zeros_like(gamma) + y = torch.zeros_like(gamma) + z = torch.sin(half_gamma) + return torch.stack([w, x, y, z], dim=-1) + + +def _smooth_step_cinf(x: torch.Tensor) -> torch.Tensor: + """ + Smooth ``C^inf`` step on ``[0, 1]``. + + This function equals exactly 0 and 1 at the endpoints, and transitions with all + derivatives vanishing there. It is used only to blend the two valid quaternion + charts; the geometric constraint itself is still enforced by the charts. + """ + x_clamped = x.clamp(0.0, 1.0) + eps = torch.finfo(x_clamped.dtype).eps + left = torch.exp(-1.0 / torch.clamp(x_clamped, min=eps)) + right = torch.exp(-1.0 / torch.clamp(1.0 - x_clamped, min=eps)) + interior = left / (left + right) + return torch.where( + x_clamped <= 0.0, + torch.zeros_like(x_clamped), + torch.where(x_clamped >= 1.0, torch.ones_like(x_clamped), interior), + ) + + +def quaternion_nlerp( + q0: torch.Tensor, + q1: torch.Tensor, + weight: torch.Tensor, + *, + eps: float = 1e-7, +) -> torch.Tensor: + """ + Normalized linear interpolation on the shortest quaternion arc. + + ``q`` and ``-q`` represent the same spatial rotation. Aligning signs before the + interpolation guarantees that the blended chart stays on the shorter great-circle + segment in ``S^3``. + """ + dot = torch.sum(q0 * q1, dim=-1, keepdim=True) + q1_aligned = torch.where(dot < 0.0, -q1, q1) + blended = (1.0 - weight.unsqueeze(-1)) * q0 + weight.unsqueeze(-1) * q1_aligned + return quaternion_normalize(blended, eps) + + +def _build_edge_quaternion_chart_pos_z( + edge_unit: torch.Tensor, + eps: float, +) -> torch.Tensor: + """Quaternion chart that is exact away from the ``-Z`` pole.""" + x = edge_unit[..., 0] + y = edge_unit[..., 1] + z = edge_unit[..., 2] + q = torch.stack([1.0 + z, y, -x, torch.zeros_like(x)], dim=-1) + return quaternion_normalize(q, eps) + + +def _build_edge_quaternion_chart_neg_z( + edge_unit: torch.Tensor, + eps: float, +) -> torch.Tensor: + """Quaternion chart that is exact away from the ``+Z`` pole.""" + x = edge_unit[..., 0] + y = edge_unit[..., 1] + z = edge_unit[..., 2] + q = torch.stack([-x, torch.zeros_like(x), 1.0 - z, y], dim=-1) + return quaternion_normalize(q, eps) + + +def build_edge_quaternion( + edge_vec: torch.Tensor, + *, + edge_len: torch.Tensor | None = None, + eps: float = 1e-7, +) -> torch.Tensor: + """ + Build stable edge quaternions for the SeZM local ``+Z`` convention. + + The returned quaternion represents the global->local edge rotation, so applying its + rotation matrix to the unit edge direction yields exactly ``(0, 0, 1)``. Two exact + quaternion charts are used: + + - a ``+Z`` chart that is regular everywhere except the antipodal ``-Z`` pole; + - a ``-Z`` chart that is regular everywhere except the antipodal ``+Z`` pole. + + Both charts encode the same edge-aligned local frame. A smooth ``C^inf`` blend in + the overlap region removes the hard pole switch while keeping the represented + rotation on the correct quaternion branch. + + Parameters + ---------- + edge_vec + Edge vectors with shape ``(E, 3)``. + edge_len + Optional edge lengths with shape ``(E, 1)``. When omitted, lengths are + recomputed from ``edge_vec``. + eps + Numerical floor used in vector and quaternion normalization. + + Returns + ------- + torch.Tensor + Unit quaternions with shape ``(E, 4)`` in ``(w, x, y, z)`` order. + """ + if edge_len is None: + edge_len = _safe_norm_nd(edge_vec, eps) + else: + edge_len = torch.sqrt(edge_len * edge_len + eps * eps) + edge_unit = edge_vec / edge_len + q_pos = _build_edge_quaternion_chart_pos_z(edge_unit, eps) + q_neg = _build_edge_quaternion_chart_neg_z(edge_unit, eps) + blend = _smooth_step_cinf(0.5 * (edge_unit[..., 2] + 1.0)) + return quaternion_nlerp(q_neg, q_pos, blend, eps=eps) + + +class WignerDCalculator(nn.Module): + """ + Quaternion-driven Wigner-D blocks for the SeZM packed real spherical basis. + + Input quaternions represent the global->local edge rotation that sends the edge + direction to local ``+Z``. The returned block-diagonal matrix keeps the packed + SeZM real spherical-harmonics layout, so downstream code continues to consume + ``D_full`` and ``Dt_full`` directly. + + Runtime structure: + - ``l=0``: scalar identity block; + - ``l=1``: direct quaternion -> Cartesian rotation -> real l=1 block; + - ``l=2``: dedicated degree-4 quaternion tensor contraction; + - ``l=3,4``: dedicated quaternion monomial kernels; + - ``l>=5``: generic quaternion polynomial path with precomputed coefficient tables. + """ + + _SMALL_ORDER_CACHE_CPU_FP64: ClassVar[dict[str, torch.Tensor] | None] = None + + def __init__( + self, + lmax: int, + *, + eps: float = 1e-7, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.lmax = int(lmax) + if self.lmax < 0: + raise ValueError("`lmax` must be non-negative") + self.dtype = dtype + self.device = env.DEVICE + self.eps = float(eps) + self.dim_full = (self.lmax + 1) ** 2 + self.poly_lmin = 5 + self.poly_offset = self.poly_lmin * self.poly_lmin + + self.register_buffer( + "l1_perm", + torch.tensor([1, 2, 0], dtype=torch.int64, device=self.device), + persistent=True, + ) + l1_sign = torch.tensor([-1.0, -1.0, 1.0], dtype=self.dtype, device=self.device) + self.register_buffer( + "l1_sign_outer", + torch.outer(l1_sign, l1_sign), + persistent=True, + ) + + if self.lmax >= 2: + self.small_order_kernels = self._build_small_order_kernels( + dtype=self.dtype, + device=self.device, + ) + + if self.lmax >= self.poly_lmin: + coeffs = self._precompute_wigner_coefficients( + self.lmax, + dtype=torch.float64, + device=torch.device("cpu"), + lmin=self.poly_lmin, + ) + self.poly_coeffs = coeffs.to(device=self.device) + blocks = self._precompute_real_basis_blocks( + lmin=self.poly_lmin, + lmax=self.lmax, + dtype=torch.float64, + device=torch.device("cpu"), + ) + U_re, U_im, U_re_t, U_im_t = self._assemble_block_diagonal_real_basis( + blocks + ) + self.register_buffer( + "poly_u_re", + U_re.to(device=self.device, dtype=self.dtype), + persistent=True, + ) + self.register_buffer( + "poly_u_im", + U_im.to(device=self.device, dtype=self.dtype), + persistent=True, + ) + self.register_buffer( + "poly_u_re_t", + U_re_t.to(device=self.device, dtype=self.dtype), + persistent=True, + ) + self.register_buffer( + "poly_u_im_t", + U_im_t.to(device=self.device, dtype=self.dtype), + persistent=True, + ) + + def forward( + self, edge_quaternion: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Build packed block-diagonal Wigner-D matrices from edge quaternions. + + Parameters + ---------- + edge_quaternion + Unit quaternions with shape ``(E, 4)`` representing the global->local + edge rotation. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + ``(D_full, Dt_full)`` with shape ``(E, (lmax+1)^2, (lmax+1)^2)``. + """ + edge_quaternion = quaternion_normalize( + edge_quaternion.to(dtype=self.dtype), + eps=self.eps, + ) + n_edge = edge_quaternion.shape[0] + D_full = torch.zeros( + n_edge, + self.dim_full, + self.dim_full, + dtype=edge_quaternion.dtype, + device=edge_quaternion.device, + ) + D_full[:, 0, 0] = 1.0 + + if self.lmax >= 1: + with nvtx_range("WignerD/l1"): + D_full[:, 1:4, 1:4] = self._compute_l1_block(edge_quaternion) + + if self.lmax >= 2: + with nvtx_range("WignerD/l2"): + D_full[:, 4:9, 4:9] = self._compute_l2_block(edge_quaternion) + + if self.lmax >= 3: + if self.lmax >= 4: + with nvtx_range("WignerD/l3l4"): + D_l3, D_l4 = self._compute_l3l4_blocks(edge_quaternion) + D_full[:, 9:16, 9:16] = D_l3 + D_full[:, 16:25, 16:25] = D_l4 + else: + with nvtx_range("WignerD/l3"): + D_full[:, 9:16, 9:16] = self._compute_l3_block(edge_quaternion) + + if self.lmax >= self.poly_lmin: + with nvtx_range("WignerD/polynomial"): + ra_re, ra_im, rb_re, rb_im = self._quaternion_to_ra_rb_real( + edge_quaternion + ) + D_re, D_im = self._wigner_d_matrix_realpair( + ra_re, + ra_im, + rb_re, + rb_im, + self.poly_coeffs, + dtype=self.dtype, + ) + D_poly = self._wigner_d_pair_to_real( + D_re, + D_im, + ( + self.poly_u_re, + self.poly_u_im, + self.poly_u_re_t, + self.poly_u_im_t, + ), + lmax=self.lmax, + lmin=self.poly_lmin, + ) + D_full[:, self.poly_offset :, self.poly_offset :] = D_poly + + Dt_full = D_full.transpose(-1, -2).contiguous() + return D_full, Dt_full + + @classmethod + def _get_small_order_cache_cpu_fp64(cls) -> dict[str, torch.Tensor]: + """Generate the low-order kernel coefficients once per process on CPU fp64.""" + if cls._SMALL_ORDER_CACHE_CPU_FP64 is None: + cls._SMALL_ORDER_CACHE_CPU_FP64 = cls._generate_small_order_cache_cpu_fp64() + return cls._SMALL_ORDER_CACHE_CPU_FP64 + + @classmethod + def _build_small_order_kernels( + cls, + *, + dtype: torch.dtype, + device: torch.device, + ) -> WignerSmallOrderCoefficients: + """Instantiate the specialized ``l=2,3,4`` kernels on the requested device/dtype.""" + cache = cls._get_small_order_cache_cpu_fp64() + return WignerSmallOrderCoefficients( + C_l2=cache["C_l2"].to(device=device, dtype=dtype), + C_l3=cache["C_l3"].to(device=device, dtype=dtype), + C_l4=cache["C_l4"].to(device=device, dtype=dtype), + C_combined_l3l4=cache["C_combined_l3l4"].to(device=device, dtype=dtype), + exp_l3=cache["exp_l3"].to(device=device), + exp_l4=cache["exp_l4"].to(device=device), + ) + + @classmethod + def _generate_small_order_cache_cpu_fp64(cls) -> dict[str, torch.Tensor]: + """ + Generate the low-order kernel coefficients from the generic SeZM reference path. + + The coefficients are exact module constants. They are solved once in fp64 on CPU, + validated against the generic quaternion polynomial evaluator, and then reused by + every `WignerDCalculator` instance. + """ + dtype = torch.float64 + device = torch.device("cpu") + generator = torch.Generator() + generator.manual_seed(20260404) + + q_fit = torch.randn(2048, 4, dtype=dtype, device=device, generator=generator) + q_fit = quaternion_normalize(q_fit, eps=torch.finfo(dtype).eps) + ref_blocks = cls._compute_generic_reference_blocks( + q_fit, lmax=4, dtype=dtype, device=device + ) + + monomials_l2 = cls._generate_monomials(4, 4) + monomials_l3 = cls._generate_monomials(4, 6) + monomials_l4 = cls._generate_monomials(4, 8) + exp_l2 = cls._monomials_to_exponent_tensor(monomials_l2, device=device) + exp_l3 = cls._monomials_to_exponent_tensor(monomials_l3, device=device) + exp_l4 = cls._monomials_to_exponent_tensor(monomials_l4, device=device) + + C_l2_flat = cls._solve_monomial_coefficients( + q_fit, + ref_blocks[2], + exp_l2, + ) + C_l3 = cls._solve_monomial_coefficients(q_fit, ref_blocks[3], exp_l3) + C_l4 = cls._solve_monomial_coefficients(q_fit, ref_blocks[4], exp_l4) + C_l2 = cls._build_l2_contraction_tensor(C_l2_flat, monomials_l2) + C_combined_l3l4 = cls._build_combined_l3l4( + C_l3, C_l4, monomials_l3, monomials_l4 + ) + + q_val = torch.randn(256, 4, dtype=dtype, device=device, generator=generator) + q_val = quaternion_normalize(q_val, eps=torch.finfo(dtype).eps) + ref_val = cls._compute_generic_reference_blocks( + q_val, lmax=4, dtype=dtype, device=device + ) + test_val = cls._evaluate_small_order_blocks( + q_val, + C_l2=C_l2, + C_l3=C_l3, + C_l4=C_l4, + exp_l3=exp_l3, + exp_l4=exp_l4, + ) + thresholds = {2: 1e-10, 3: 1e-10, 4: 1e-10} + for ell in (2, 3, 4): + err = (test_val[ell] - ref_val[ell]).abs().max().item() + if err > thresholds[ell]: + raise RuntimeError( + f"Failed to generate stable SeZM Wigner coefficients for l={ell}: max_err={err}" + ) + + return { + "C_l2": C_l2, + "C_l3": C_l3, + "C_l4": C_l4, + "C_combined_l3l4": C_combined_l3l4, + "exp_l3": exp_l3, + "exp_l4": exp_l4, + } + + @classmethod + def _compute_generic_reference_blocks( + cls, + edge_quaternion: torch.Tensor, + *, + lmax: int, + dtype: torch.dtype, + device: torch.device, + ) -> dict[int, torch.Tensor]: + """Evaluate the generic SeZM polynomial path and extract the ``l=2,3,4`` blocks.""" + coeffs = cls._precompute_wigner_coefficients( + lmax, + dtype=dtype, + device=device, + lmin=2, + ) + blocks = cls._precompute_real_basis_blocks( + lmin=2, + lmax=lmax, + dtype=dtype, + device=device, + ) + ra_re, ra_im, rb_re, rb_im = cls._quaternion_to_ra_rb_real(edge_quaternion) + D_re, D_im = cls._wigner_d_matrix_realpair( + ra_re, + ra_im, + rb_re, + rb_im, + coeffs, + dtype=dtype, + ) + D_ref = cls._wigner_d_pair_to_real( + D_re, + D_im, + blocks, + lmax=lmax, + lmin=2, + ) + return { + 2: D_ref[:, 0:5, 0:5], + 3: D_ref[:, 5:12, 5:12], + 4: D_ref[:, 12:21, 12:21], + } + + @classmethod + def _solve_monomial_coefficients( + cls, + edge_quaternion: torch.Tensor, + D_block: torch.Tensor, + monomial_exponents: torch.Tensor, + ) -> torch.Tensor: + """Solve the flattened monomial coefficient matrix for one low-order block.""" + max_power = int(monomial_exponents.sum(dim=1).max().item()) + powers = cls._precompute_powers(edge_quaternion, max_power) + M = cls._build_monomial_matrix(powers, monomial_exponents) + Y = D_block.reshape(edge_quaternion.shape[0], -1) + return torch.linalg.lstsq(M, Y).solution.transpose(0, 1).contiguous() + + @staticmethod + def _build_l2_contraction_tensor( + C_l2_flat: torch.Tensor, + monomials: list[tuple[int, int, int, int]], + ) -> torch.Tensor: + """Expand degree-4 monomial coefficients into the symmetric einsum tensor form.""" + C_l2 = torch.zeros( + 5, 5, 4, 4, 4, 4, dtype=C_l2_flat.dtype, device=C_l2_flat.device + ) + for flat_idx, coeff_row in enumerate(C_l2_flat): + i = flat_idx // 5 + j = flat_idx % 5 + for coeff, (a, b, c, d) in zip(coeff_row, monomials, strict=True): + if abs(float(coeff)) < 1e-15: + continue + pool = [0] * a + [1] * b + [2] * c + [3] * d + unique_permutations = set(permutations(pool, 4)) + share = coeff / len(unique_permutations) + for p0, p1, p2, p3 in unique_permutations: + C_l2[i, j, p0, p1, p2, p3] = share + return C_l2 + + @classmethod + def _evaluate_small_order_blocks( + cls, + edge_quaternion: torch.Tensor, + *, + C_l2: torch.Tensor, + C_l3: torch.Tensor, + C_l4: torch.Tensor, + exp_l3: torch.Tensor, + exp_l4: torch.Tensor, + ) -> dict[int, torch.Tensor]: + """Evaluate the specialized ``l=2,3,4`` kernels for validation and caching.""" + q2 = edge_quaternion.unsqueeze(-1) * edge_quaternion.unsqueeze(-2) + q4 = q2.unsqueeze(-1).unsqueeze(-1) * q2.unsqueeze(-3).unsqueeze(-3) + D_l2 = torch.einsum("nabcd,ijabcd->nij", q4, C_l2) + + powers6 = cls._precompute_powers(edge_quaternion, 6) + M3 = cls._build_monomial_matrix(powers6, exp_l3) + D_l3 = torch.matmul(M3, C_l3.transpose(0, 1)).view( + edge_quaternion.shape[0], 7, 7 + ) + + powers8 = cls._precompute_powers(edge_quaternion, 8) + M4 = cls._build_monomial_matrix(powers8, exp_l4) + D_l4 = torch.matmul(M4, C_l4.transpose(0, 1)).view( + edge_quaternion.shape[0], 9, 9 + ) + return { + 2: D_l2, + 3: D_l3, + 4: D_l4, + } + + @staticmethod + def _generate_monomials( + n_vars: int, + total_degree: int, + ) -> list[tuple[int, ...]]: + """Generate all monomials of fixed total degree in lexicographic order.""" + monomials: list[tuple[int, ...]] = [] + + def _recurse( + remaining_vars: int, + remaining_degree: int, + current: list[int], + ) -> None: + if remaining_vars == 1: + monomials.append((*current, remaining_degree)) + return + for i in range(remaining_degree + 1): + _recurse(remaining_vars - 1, remaining_degree - i, [*current, i]) + + _recurse(n_vars, total_degree, []) + return monomials + + @staticmethod + def _monomials_to_exponent_tensor( + monomials: list[tuple[int, ...]], + *, + device: torch.device, + ) -> torch.Tensor: + """Convert monomial tuples to an ``int64`` exponent table.""" + return torch.tensor(monomials, dtype=torch.int64, device=device) + + @staticmethod + def _build_combined_l3l4( + C_l3: torch.Tensor, + C_l4: torch.Tensor, + monomials_l3: list[tuple[int, int, int, int]], + monomials_l4: list[tuple[int, int, int, int]], + ) -> torch.Tensor: + """Lift the ``l=3`` basis to degree 8 and stack it with the ``l=4`` basis.""" + mono8_to_idx = {mono: idx for idx, mono in enumerate(monomials_l4)} + C_l3_lifted = torch.zeros( + C_l3.shape[0], + len(monomials_l4), + dtype=C_l3.dtype, + device=C_l3.device, + ) + for j, (a, b, c, d) in enumerate(monomials_l3): + for mono8 in ( + (a + 2, b, c, d), + (a, b + 2, c, d), + (a, b, c + 2, d), + (a, b, c, d + 2), + ): + C_l3_lifted[:, mono8_to_idx[mono8]] += C_l3[:, j] + return torch.cat([C_l3_lifted, C_l4], dim=0) + + @staticmethod + def _precompute_powers( + q: torch.Tensor, + max_power: int, + ) -> torch.Tensor: + """Precompute powers ``q_i^k`` as a dense table with shape ``(4, max_power+1, E)``.""" + components = q.transpose(0, 1) + if max_power == 0: + return torch.ones(4, 1, q.shape[0], dtype=q.dtype, device=q.device) + repeated = components.unsqueeze(1).expand(4, max_power, q.shape[0]) + positive_powers = torch.cumprod(repeated, dim=1) + return torch.cat( + [ + torch.ones(4, 1, q.shape[0], dtype=q.dtype, device=q.device), + positive_powers, + ], + dim=1, + ) + + @staticmethod + def _build_monomial_matrix( + powers: torch.Tensor, + monomial_exponents: torch.Tensor, + ) -> torch.Tensor: + """Assemble the monomial design matrix for one fixed degree by gather/prod.""" + gather_idx = ( + monomial_exponents.transpose(0, 1) + .unsqueeze(-1) + .expand( + 4, + monomial_exponents.shape[0], + powers.shape[-1], + ) + ) + selected = torch.gather(powers, 1, gather_idx) + return selected.prod(dim=0).transpose(0, 1).contiguous() + + def _compute_l1_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: + """Compute the vector block directly from the Cartesian rotation matrix.""" + rot_mat = quaternion_to_rotation_matrix(edge_quaternion) + rot_perm = rot_mat.index_select(-2, self.l1_perm).index_select(-1, self.l1_perm) + return rot_perm * self.l1_sign_outer + + def _compute_l2_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: + """Compute the ``l=2`` block from the degree-4 quaternion contraction.""" + q2 = edge_quaternion.unsqueeze(-1) * edge_quaternion.unsqueeze(-2) + q4 = q2.unsqueeze(-1).unsqueeze(-1) * q2.unsqueeze(-3).unsqueeze(-3) + return torch.einsum( + "nabcd,ijabcd->nij", + q4, + self.small_order_kernels.C_l2, + ) + + def _compute_l3_block(self, edge_quaternion: torch.Tensor) -> torch.Tensor: + """Compute the ``l=3`` block from the dedicated degree-6 monomial kernel.""" + powers = self._precompute_powers(edge_quaternion, 6) + monomials = self._build_monomial_matrix( + powers, + self.small_order_kernels.exp_l3, + ) + D_flat = torch.matmul( + monomials, + self.small_order_kernels.C_l3.transpose(0, 1), + ) + return D_flat.view(edge_quaternion.shape[0], 7, 7) + + def _compute_l3l4_blocks( + self, + edge_quaternion: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute the ``l=3`` and ``l=4`` blocks from one shared degree-8 kernel.""" + powers = self._precompute_powers(edge_quaternion, 8) + monomials = self._build_monomial_matrix( + powers, + self.small_order_kernels.exp_l4, + ) + D_flat = torch.matmul( + monomials, + self.small_order_kernels.C_combined_l3l4.transpose(0, 1), + ) + D_l3 = D_flat[:, :49].view(edge_quaternion.shape[0], 7, 7) + D_l4 = D_flat[:, 49:].view(edge_quaternion.shape[0], 9, 9) + return D_l3, D_l4 + + @staticmethod + def _factorial_table( + n: int, dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + """Return ``[0!, 1!, ..., n!]`` in the requested dtype/device.""" + table = torch.zeros(n + 1, dtype=dtype, device=device) + table[0] = 1.0 + for i in range(1, n + 1): + table[i] = table[i - 1] * i + return table + + @staticmethod + def _binomial(n: int, k: int, factorial: torch.Tensor) -> float: + """Evaluate ``C(n, k)`` from a precomputed factorial table.""" + if k < 0 or k > n: + return 0.0 + return float(factorial[n] / (factorial[k] * factorial[n - k])) + + @staticmethod + def _allocate_case_coeffs( + n_primary: int, + max_poly_len: int, + dtype: torch.dtype, + device: torch.device, + ) -> CaseCoefficients: + """Allocate one branch of Horner tables for the quaternion Wigner evaluator.""" + return CaseCoefficients( + coeff=torch.zeros(n_primary, dtype=dtype, device=device), + horner=torch.zeros(n_primary, max_poly_len, dtype=dtype, device=device), + poly_len=torch.zeros(n_primary, dtype=torch.int64, device=device), + ra_exp=torch.zeros(n_primary, dtype=dtype, device=device), + rb_exp=torch.zeros(n_primary, dtype=dtype, device=device), + sign=torch.zeros(n_primary, dtype=dtype, device=device), + ) + + @staticmethod + def _compute_case_coefficients( + case: CaseCoefficients, + idx: int, + ell: int, + mp: int, + m: int, + sqrt_factor: float, + factorial: torch.Tensor, + *, + is_case1: bool, + ) -> None: + """ + Fill one Horner branch for a fixed ``(ell, mp, m)`` entry. + + The closed-form quaternion Wigner formula is reorganized so that only the ratio + ``-(|Rb|/|Ra|)^2`` or ``-(|Ra|/|Rb|)^2`` enters the Horner chain. This avoids a + large family of per-entry runtime branches and keeps the generic path stable for + every ``ell``. + """ + if is_case1: + rho_min = max(0, mp - m) + rho_max = min(ell + mp, ell - m) + else: + rho_min = max(0, -(mp + m)) + rho_max = min(ell - m, ell - mp) + + if rho_min > rho_max: + return + + if is_case1: + binom1 = WignerDCalculator._binomial(ell + mp, rho_min, factorial) + binom2 = WignerDCalculator._binomial(ell - mp, ell - m - rho_min, factorial) + else: + binom1 = WignerDCalculator._binomial(ell + mp, ell - m - rho_min, factorial) + binom2 = WignerDCalculator._binomial(ell - mp, rho_min, factorial) + case.coeff[idx] = sqrt_factor * binom1 * binom2 + + poly_len = rho_max - rho_min + 1 + case.poly_len[idx] = poly_len + for i, rho in enumerate(range(rho_max, rho_min, -1)): + if is_case1: + n1 = ell + mp - rho + 1 + n2 = ell - m - rho + 1 + d1 = rho + d2 = m - mp + rho + else: + n1 = ell - m - rho + 1 + n2 = ell - mp - rho + 1 + d1 = rho + d2 = mp + m + rho + if d1 != 0 and d2 != 0: + case.horner[idx, i] = (n1 * n2) / (d1 * d2) + + if is_case1: + case.ra_exp[idx] = 2 * ell + mp - m - 2 * rho_min + case.rb_exp[idx] = m - mp + 2 * rho_min + case.sign[idx] = (-1) ** rho_min + else: + case.ra_exp[idx] = mp + m + 2 * rho_min + case.rb_exp[idx] = 2 * ell - mp - m - 2 * rho_min + case.sign[idx] = ((-1) ** (ell - m)) * ((-1) ** rho_min) + + @staticmethod + def _finalize_case_coefficients( + case: CaseCoefficients, + max_poly_len: int, + ) -> None: + """Attach runtime-ready masks and fused coefficients for one Horner branch.""" + step_count = torch.clamp(case.poly_len - 1, min=0) + if max_poly_len > 1: + horner_step_mask = torch.arange( + max_poly_len - 1, + dtype=case.poly_len.dtype, + device=case.poly_len.device, + ).unsqueeze(0) < step_count.unsqueeze(1) + else: + horner_step_mask = torch.zeros( + case.poly_len.shape[0], + 0, + dtype=torch.bool, + device=case.poly_len.device, + ) + case.register_buffer("valid_mask", case.poly_len > 0, persistent=True) + case.register_buffer("horner_step_mask", horner_step_mask, persistent=True) + case.register_buffer("signed_coeff", case.sign * case.coeff, persistent=True) + + @staticmethod + def _vectorized_horner( + ratio: torch.Tensor, + horner_coeffs: torch.Tensor, + horner_step_mask: torch.Tensor, + ) -> torch.Tensor: + """Evaluate many varying-length Horner chains in one batched loop.""" + n_batch = ratio.shape[0] + n_elements = horner_coeffs.shape[0] + result = torch.ones(n_batch, n_elements, dtype=ratio.dtype, device=ratio.device) + if horner_step_mask.shape[1] == 0: + return result + ratio = ratio.unsqueeze(1).expand(n_batch, n_elements) + for i in range(horner_step_mask.shape[1]): + new_result = 1.0 + result * (ratio * horner_coeffs[:, i].unsqueeze(0)) + result = torch.where( + horner_step_mask[:, i].unsqueeze(0), new_result, result + ) + return result + + @staticmethod + def _compute_case_magnitude( + log_ra: torch.Tensor, + log_rb: torch.Tensor, + ratio: torch.Tensor, + case: CaseCoefficients, + ) -> torch.Tensor: + """Compute the real magnitude factor for one stable Horner branch.""" + horner_sum = WignerDCalculator._vectorized_horner( + ratio, + case.horner, + case.horner_step_mask, + ) + ra_powers = torch.exp(torch.outer(log_ra, case.ra_exp)) + rb_powers = torch.exp(torch.outer(log_rb, case.rb_exp)) + magnitude = case.signed_coeff.unsqueeze(0) * ra_powers * rb_powers + return magnitude * horner_sum + + @staticmethod + def _scatter_primary_to_matrix( + result: torch.Tensor, + D: torch.Tensor, + coeffs: WignerPolynomialCoefficients, + ) -> None: + """Scatter the explicitly stored primary entries into the dense block matrix.""" + D[:, coeffs.primary_row, coeffs.primary_col] = result + + @staticmethod + def _build_complex_to_real_sh_block( + ell: int, + *, + dtype: torch.dtype = torch.complex128, + device: torch.device, + ) -> torch.Tensor: + """ + Build the complex-to-real basis transform for one ``ell`` block. + + The packed real basis follows the SeZM convention + ``m = -ell, ..., +ell`` inside each block. This unitary transform defines the + real tesseral basis used by the packed ``D_full`` layout. + """ + size = 2 * ell + 1 + inv_sqrt2 = 1.0 / math.sqrt(2.0) + U = torch.zeros(size, size, dtype=dtype, device=device) + for m in range(-ell, ell + 1): + row = m + ell + if m == 0: + U[row, ell] = 1.0 + elif m > 0: + U[row, m + ell] = inv_sqrt2 + U[row, -m + ell] = ((-1) ** m) * inv_sqrt2 + else: + U[row, -m + ell] = -1j * inv_sqrt2 + U[row, m + ell] = ((-1) ** m) * 1j * inv_sqrt2 + return U + + @staticmethod + def _precompute_real_basis_blocks( + *, + lmin: int, + lmax: int, + dtype: torch.dtype, + device: torch.device, + ) -> list[tuple[torch.Tensor, torch.Tensor]]: + """Precompute complex-to-real basis transforms for ``ell in [lmin, lmax]``.""" + if lmin > lmax: + return [] + complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128 + blocks: list[tuple[torch.Tensor, torch.Tensor]] = [] + for ell in range(lmin, lmax + 1): + U = WignerDCalculator._build_complex_to_real_sh_block( + ell, + dtype=complex_dtype, + device=device, + ) + blocks.append((U.real.to(dtype=dtype), U.imag.to(dtype=dtype))) + return blocks + + @staticmethod + def _assemble_block_diagonal_real_basis( + U_blocks: list[tuple[torch.Tensor, torch.Tensor]], + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Assemble per-``ell`` real-basis blocks into one block-diagonal transform.""" + if not U_blocks: + empty = torch.zeros( + 0, + 0, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + return empty, empty, empty, empty + + size = sum(U_re.shape[0] for U_re, _ in U_blocks) + dtype = U_blocks[0][0].dtype + device = U_blocks[0][0].device + U_re_full = torch.zeros(size, size, dtype=dtype, device=device) + U_im_full = torch.zeros(size, size, dtype=dtype, device=device) + offset = 0 + for U_re, U_im in U_blocks: + block_size = U_re.shape[0] + block_end = offset + block_size + U_re_full[offset:block_end, offset:block_end] = U_re + U_im_full[offset:block_end, offset:block_end] = U_im + offset = block_end + return ( + U_re_full, + U_im_full, + U_re_full.transpose(-1, -2).contiguous(), + U_im_full.transpose(-1, -2).contiguous(), + ) + + @staticmethod + def _quaternion_to_ra_rb_real( + q: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Decompose quaternion components into the Cayley-Klein pair used by the generic path. + + For ``q = (w, x, y, z)`` the SeZM real-basis convention is aligned by + + ``Ra = w - i z`` and ``Rb = y - i x``. + + This pairing matches the packed SeZM real spherical-harmonics ordering used by + the block-diagonal ``D_full`` layout. + """ + w = q[..., 0] + x = q[..., 1] + y = q[..., 2] + z = q[..., 3] + return w, -z, y, -x + + @staticmethod + def _precompute_wigner_coefficients( + lmax: int, + *, + dtype: torch.dtype, + device: torch.device, + lmin: int = 0, + ) -> WignerPolynomialCoefficients: + """ + Precompute the generic quaternion Wigner coefficient tables. + + The runtime path only performs batched Horner evaluation and symmetry scatter. + All factorial ratios, branch exponents, and packed matrix indices are resolved once + here, which keeps the forward path independent of ``ell`` and stable for arbitrary + ``lmax``. + """ + if lmin < 0: + raise ValueError("`lmin` must be non-negative") + if lmax < lmin: + raise ValueError("`lmax` must be >= `lmin`") + + factorial = WignerDCalculator._factorial_table(2 * lmax + 1, dtype, device) + n_total = sum((2 * ell + 1) ** 2 for ell in range(lmin, lmax + 1)) + n_primary = sum( + 1 + for ell in range(lmin, lmax + 1) + for mp in range(-ell, ell + 1) + for m in range(-ell, ell + 1) + if mp + m > 0 or (mp + m == 0 and mp >= 0) + ) + n_derived = n_total - n_primary + max_poly_len = lmax + 1 + size = (lmax + 1) ** 2 - lmin * lmin + + primary_row = torch.zeros(n_primary, dtype=torch.int64, device=device) + primary_col = torch.zeros(n_primary, dtype=torch.int64, device=device) + mp_plus_m = torch.zeros(n_primary, dtype=dtype, device=device) + m_minus_mp = torch.zeros(n_primary, dtype=dtype, device=device) + diagonal_mask = torch.zeros(n_primary, dtype=torch.bool, device=device) + anti_diagonal_mask = torch.zeros(n_primary, dtype=torch.bool, device=device) + special_2m = torch.zeros(n_primary, dtype=dtype, device=device) + anti_diag_sign = torch.zeros(n_primary, dtype=dtype, device=device) + case1 = WignerDCalculator._allocate_case_coeffs( + n_primary, + max_poly_len, + dtype, + device, + ) + case2 = WignerDCalculator._allocate_case_coeffs( + n_primary, + max_poly_len, + dtype, + device, + ) + derived_row = torch.zeros(n_derived, dtype=torch.int64, device=device) + derived_col = torch.zeros(n_derived, dtype=torch.int64, device=device) + derived_primary_idx = torch.zeros(n_derived, dtype=torch.int64, device=device) + derived_sign = torch.zeros(n_derived, dtype=dtype, device=device) + + primary_map: dict[tuple[int, int], int] = {} + primary_idx = 0 + block_start = 0 + for ell in range(lmin, lmax + 1): + block_size = 2 * ell + 1 + for mp_local in range(block_size): + mp = mp_local - ell + for m_local in range(block_size): + m = m_local - ell + row = block_start + mp_local + col = block_start + m_local + is_primary = (mp + m > 0) or (mp + m == 0 and mp >= 0) + if not is_primary: + continue + + primary_map[(row, col)] = primary_idx + primary_row[primary_idx] = row + primary_col[primary_idx] = col + mp_plus_m[primary_idx] = mp + m + m_minus_mp[primary_idx] = m - mp + diagonal_mask[primary_idx] = mp == m + anti_diagonal_mask[primary_idx] = mp == -m + special_2m[primary_idx] = 2 * m + anti_diag_sign[primary_idx] = (-1) ** (ell - m) + + sqrt_factor = math.sqrt( + float(factorial[ell + m] * factorial[ell - m]) + / float(factorial[ell + mp] * factorial[ell - mp]) + ) + WignerDCalculator._compute_case_coefficients( + case1, + primary_idx, + ell, + mp, + m, + sqrt_factor, + factorial, + is_case1=True, + ) + WignerDCalculator._compute_case_coefficients( + case2, + primary_idx, + ell, + mp, + m, + sqrt_factor, + factorial, + is_case1=False, + ) + primary_idx += 1 + block_start += block_size + + derived_idx = 0 + block_start = 0 + for ell in range(lmin, lmax + 1): + block_size = 2 * ell + 1 + for mp_local in range(block_size): + mp = mp_local - ell + for m_local in range(block_size): + m = m_local - ell + row = block_start + mp_local + col = block_start + m_local + is_primary = (mp + m > 0) or (mp + m == 0 and mp >= 0) + if is_primary: + continue + + derived_row[derived_idx] = row + derived_col[derived_idx] = col + derived_primary_idx[derived_idx] = primary_map[ + (block_start + (-mp + ell), block_start + (-m + ell)) + ] + derived_sign[derived_idx] = (-1) ** (mp - m) + derived_idx += 1 + block_start += block_size + + WignerDCalculator._finalize_case_coefficients(case1, max_poly_len) + WignerDCalculator._finalize_case_coefficients(case2, max_poly_len) + + return WignerPolynomialCoefficients( + lmin=lmin, + lmax=lmax, + size=size, + max_poly_len=max_poly_len, + n_primary=n_primary, + n_derived=n_derived, + primary_row=primary_row, + primary_col=primary_col, + case1=case1, + case2=case2, + mp_plus_m=mp_plus_m, + m_minus_mp=m_minus_mp, + diagonal_mask=diagonal_mask, + anti_diagonal_mask=anti_diagonal_mask, + special_2m=special_2m, + anti_diag_sign=anti_diag_sign, + derived_row=derived_row, + derived_col=derived_col, + derived_primary_idx=derived_primary_idx, + derived_sign=derived_sign, + ) + + @staticmethod + def _wigner_d_matrix_realpair( + ra_re: torch.Tensor, + ra_im: torch.Tensor, + rb_re: torch.Tensor, + rb_im: torch.Tensor, + coeffs: WignerPolynomialCoefficients, + *, + dtype: torch.dtype | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Evaluate the complex Wigner blocks in real/imaginary form. + + The runtime path uses only real arithmetic. The complex phase is represented by + two real tensors, while the polynomial and magnitude algebra is evaluated in + ``fp64`` before the result is cast back to the requested output dtype. + """ + n_batch = ra_re.shape[0] + output_dtype = ra_re.dtype if dtype is None else dtype + if coeffs.size == 0: + zeros = torch.zeros(n_batch, 0, 0, dtype=output_dtype, device=ra_re.device) + return zeros, zeros + + ra_re = ra_re.to(torch.float64) + ra_im = ra_im.to(torch.float64) + rb_re = rb_re.to(torch.float64) + rb_im = rb_im.to(torch.float64) + if ( + coeffs.case1.coeff.dtype != torch.float64 + or coeffs.primary_row.device != ra_re.device + ): + coeffs = coeffs.to(device=ra_re.device, dtype=torch.float64) + + dtype = torch.float64 + device = ra_re.device + + eps = torch.finfo(dtype).eps + eps_sq = eps * eps + ra_sq = ra_re * ra_re + ra_im * ra_im + rb_sq = rb_re * rb_re + rb_im * rb_im + ra_small = ra_sq <= eps_sq + rb_small = rb_sq <= eps_sq + ra = torch.sqrt(torch.clamp(ra_sq, min=eps_sq)) + rb = torch.sqrt(torch.clamp(rb_sq, min=eps_sq)) + general_mask = ~ra_small & ~rb_small + use_case1 = (ra >= rb) & general_mask + use_case2 = (ra < rb) & general_mask + + safe_ra_re = torch.where(ra_small, torch.ones_like(ra_re), ra_re) + safe_ra_im = torch.where(ra_small, torch.zeros_like(ra_im), ra_im) + safe_rb_re = torch.where(rb_small, torch.ones_like(rb_re), rb_re) + safe_rb_im = torch.where(rb_small, torch.zeros_like(rb_im), rb_im) + phia = torch.atan2(safe_ra_im, safe_ra_re) + phib = torch.atan2(safe_rb_im, safe_rb_re) + + phase = torch.outer(phia, coeffs.mp_plus_m) + torch.outer( + phib, coeffs.m_minus_mp + ) + exp_phase_re = torch.cos(phase) + exp_phase_im = torch.sin(phase) + + safe_ra = torch.clamp(ra, min=eps) + safe_rb = torch.clamp(rb, min=eps) + log_ra = torch.log(safe_ra) + log_rb = torch.log(safe_rb) + + result_re = torch.zeros(n_batch, coeffs.n_primary, dtype=dtype, device=device) + result_im = torch.zeros_like(result_re) + + anti_rows = ra_small + anti_log_rb = torch.where(anti_rows, log_rb, torch.zeros_like(log_rb)) + anti_phib = torch.where(anti_rows, phib, torch.zeros_like(phib)) + rb_power_mag = torch.exp(torch.outer(anti_log_rb, coeffs.special_2m)) + rb_power_phase = torch.outer(anti_phib, coeffs.special_2m) + anti_re = ( + coeffs.anti_diag_sign.unsqueeze(0) + * rb_power_mag + * torch.cos(rb_power_phase) + ) + anti_im = ( + coeffs.anti_diag_sign.unsqueeze(0) + * rb_power_mag + * torch.sin(rb_power_phase) + ) + anti_mask = ra_small.unsqueeze(1) & coeffs.anti_diagonal_mask.unsqueeze(0) + result_re = torch.where(anti_mask, anti_re, result_re) + result_im = torch.where(anti_mask, anti_im, result_im) + + diag_rows = rb_small & ~ra_small + diag_log_ra = torch.where(diag_rows, log_ra, torch.zeros_like(log_ra)) + diag_phia = torch.where(diag_rows, phia, torch.zeros_like(phia)) + ra_power_mag = torch.exp(torch.outer(diag_log_ra, coeffs.special_2m)) + ra_power_phase = torch.outer(diag_phia, coeffs.special_2m) + diag_re = ra_power_mag * torch.cos(ra_power_phase) + diag_im = ra_power_mag * torch.sin(ra_power_phase) + diag_mask = diag_rows.unsqueeze(1) & coeffs.diagonal_mask.unsqueeze(0) + result_re = torch.where(diag_mask, diag_re, result_re) + result_im = torch.where(diag_mask, diag_im, result_im) + + ratio1 = -(rb * rb) / (safe_ra * safe_ra) + case1_rows = use_case1 + magnitude1 = WignerDCalculator._compute_case_magnitude( + torch.where(case1_rows, log_ra, torch.zeros_like(log_ra)), + torch.where(case1_rows, log_rb, torch.zeros_like(log_rb)), + torch.where(case1_rows, ratio1, torch.zeros_like(ratio1)), + coeffs.case1, + ) + val1_re = magnitude1 * exp_phase_re + val1_im = magnitude1 * exp_phase_im + mask1 = case1_rows.unsqueeze(1) & coeffs.case1.valid_mask.unsqueeze(0) + result_re = torch.where(mask1, val1_re, result_re) + result_im = torch.where(mask1, val1_im, result_im) + + ratio2 = -(ra * ra) / (safe_rb * safe_rb) + case2_rows = use_case2 + magnitude2 = WignerDCalculator._compute_case_magnitude( + torch.where(case2_rows, log_ra, torch.zeros_like(log_ra)), + torch.where(case2_rows, log_rb, torch.zeros_like(log_rb)), + torch.where(case2_rows, ratio2, torch.zeros_like(ratio2)), + coeffs.case2, + ) + val2_re = magnitude2 * exp_phase_re + val2_im = magnitude2 * exp_phase_im + mask2 = case2_rows.unsqueeze(1) & coeffs.case2.valid_mask.unsqueeze(0) + result_re = torch.where(mask2, val2_re, result_re) + result_im = torch.where(mask2, val2_im, result_im) + + D_re = torch.zeros( + n_batch, coeffs.size, coeffs.size, dtype=dtype, device=device + ) + D_im = torch.zeros_like(D_re) + WignerDCalculator._scatter_primary_to_matrix(result_re, D_re, coeffs) + WignerDCalculator._scatter_primary_to_matrix(result_im, D_im, coeffs) + + if coeffs.n_derived > 0: + primary_re = result_re[:, coeffs.derived_primary_idx] + primary_im = result_im[:, coeffs.derived_primary_idx] + derived_sign = coeffs.derived_sign.unsqueeze(0) + derived_re = derived_sign * primary_re + derived_im = -derived_sign * primary_im + D_re[:, coeffs.derived_row, coeffs.derived_col] = derived_re + D_im[:, coeffs.derived_row, coeffs.derived_col] = derived_im + + return D_re.to(output_dtype), D_im.to(output_dtype) + + @staticmethod + def _wigner_d_pair_to_real( + D_re: torch.Tensor, + D_im: torch.Tensor, + U_blocks: list[tuple[torch.Tensor, torch.Tensor]] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + *, + lmax: int, + lmin: int, + ) -> torch.Tensor: + """ + Convert complex Wigner blocks to the current real packed basis. + + Each block applies the SeZM complex-to-real basis transform for its degree. + This preserves the packed ``(l, m)`` contract of ``D_full`` and ``Dt_full``. + """ + n_batch = D_re.shape[0] + if lmin > lmax: + return torch.zeros(n_batch, 0, 0, dtype=D_re.dtype, device=D_re.device) + + if isinstance(U_blocks, list): + U_re, U_im, U_re_t, U_im_t = ( + WignerDCalculator._assemble_block_diagonal_real_basis(U_blocks) + ) + else: + U_re, U_im, U_re_t, U_im_t = U_blocks + + if U_re.dtype != D_re.dtype or U_re.device != D_re.device: + U_re = U_re.to(dtype=D_re.dtype, device=D_re.device) + U_im = U_im.to(dtype=D_re.dtype, device=D_re.device) + U_re_t = U_re_t.to(dtype=D_re.dtype, device=D_re.device) + U_im_t = U_im_t.to(dtype=D_re.dtype, device=D_re.device) + + temp_re = torch.matmul(D_re, U_re_t) + torch.matmul(D_im, U_im_t) + temp_im = torch.matmul(D_im, U_re_t) - torch.matmul(D_re, U_im_t) + return torch.matmul(U_re, temp_re) - torch.matmul(U_im, temp_im) + + def serialize(self) -> dict[str, Any]: + """Serialize WignerDCalculator (lmax and dtype are stored by parent).""" + return { + "@class": "WignerDCalculator", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> WignerDCalculator: + """Deserialize WignerDCalculator - parent handles lmax/dtype reconstruction.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "WignerDCalculator": + raise ValueError(f"Invalid class for WignerDCalculator: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + raise NotImplementedError( + "WignerDCalculator.deserialize should be called by parent with lmax/dtype" + ) diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 24075412db..3aa6b4613c 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -29,6 +29,9 @@ from deepmd.pt.model.task import ( BaseFitting, ) +from deepmd.pt.model.task.sezm_ener import ( + SeZMEnergyFittingNet, +) from deepmd.utils.spin import ( Spin, ) @@ -69,6 +72,12 @@ from .property_model import ( PropertyModel, ) +from .sezm_model import ( + SeZMModel, +) +from .sezm_spin_model import ( + SeZMSpinModel, +) from .spin_model import ( SpinEnergyModel, SpinModel, @@ -102,14 +111,19 @@ def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple: return descriptor, fitting, fitting_net["type"] -def get_spin_model(model_params: dict) -> SpinModel: - model_params = copy.deepcopy(model_params) +def _normalize_spin_use_spin(model_params: dict) -> None: + """Normalize spin.use_spin from type indices to per-type booleans.""" if not model_params["spin"]["use_spin"] or isinstance( model_params["spin"]["use_spin"][0], int ): use_spin = np.full(len(model_params["type_map"]), False, dtype=bool) use_spin[model_params["spin"]["use_spin"]] = True model_params["spin"]["use_spin"] = use_spin.tolist() + + +def get_spin_model(model_params: dict) -> SpinModel: + model_params = copy.deepcopy(model_params) + _normalize_spin_use_spin(model_params) # include virtual spin and placeholder types model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]] spin = Spin( @@ -288,6 +302,152 @@ def get_standard_model(model_params: dict) -> BaseModel: return model +def get_sezm_model(model_params: dict) -> BaseModel: + model_params_old = model_params + model_params = copy.deepcopy(model_params) + model_params.setdefault("descriptor", {}) + model_params.setdefault("fitting_net", {}) + + ntypes = len(model_params["type_map"]) + model_params["descriptor"]["ntypes"] = ntypes + model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) + descriptor_exclude_types = [ + list(pair) for pair in (model_params["descriptor"].get("exclude_types") or []) + ] + if "pair_exclude_types" in model_params: + pair_exclude_types = [ + list(pair) for pair in (model_params["pair_exclude_types"] or []) + ] + if descriptor_exclude_types and descriptor_exclude_types != pair_exclude_types: + raise ValueError( + "SeZM `pair_exclude_types` and `descriptor.exclude_types` must match " + "when both are provided." + ) + else: + pair_exclude_types = descriptor_exclude_types + model_params["pair_exclude_types"] = pair_exclude_types + model_params["descriptor"]["exclude_types"] = copy.deepcopy(pair_exclude_types) + + # === Bridging parameters === + bridging_method = str(model_params.get("bridging_method", "none")).upper() + bridging_r_inner = float(model_params.get("bridging_r_inner", 0.8)) + bridging_r_outer = float(model_params.get("bridging_r_outer", 1.2)) + # Only inject bridging parameters when bridging is enabled. + if bridging_method != "NONE": + model_params["descriptor"]["inner_clamp_r_inner"] = bridging_r_inner + model_params["descriptor"]["inner_clamp_r_outer"] = bridging_r_outer + + descriptor = BaseDescriptor(**model_params["descriptor"]) + + fitting_net = copy.deepcopy(model_params["fitting_net"]) + fitting_net.pop("type", None) + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) + fitting_net["mixed_types"] = descriptor.mixed_types() + fitting_net["dim_descrpt"] = descriptor.get_dim_out() + fitting = SeZMEnergyFittingNet(**fitting_net) + atom_exclude_types = model_params.get("atom_exclude_types", []) + preset_out_bias = model_params.get("preset_out_bias") + preset_out_bias = _convert_preset_out_bias_to_array( + preset_out_bias, model_params["type_map"] + ) + data_stat_protect = model_params.get("data_stat_protect", 1e-2) + use_compile = bool(model_params.get("use_compile", False)) + enable_tf32 = bool(model_params.get("enable_tf32", True)) + + model = SeZMModel( + descriptor=descriptor, + fitting=fitting, + type_map=model_params["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + preset_out_bias=preset_out_bias, + data_stat_protect=data_stat_protect, + use_compile=use_compile, + enable_tf32=enable_tf32, + bridging_method=bridging_method, + bridging_r_inner=bridging_r_inner, + bridging_r_outer=bridging_r_outer, + ) + model.model_def_script = json.dumps(model_params_old) + return model + + +def get_sezm_spin_model(model_params: dict) -> BaseModel: + model_params_old = model_params + model_params = copy.deepcopy(model_params) + model_params.setdefault("descriptor", {}) + model_params.setdefault("fitting_net", {}) + _normalize_spin_use_spin(model_params) + real_sel = model_params["descriptor"].get("sel", 120) + real_sel_list = [int(real_sel)] if isinstance(real_sel, int) else list(real_sel) + real_nsel = int(sum(real_sel_list)) + model_params["descriptor"]["sel"] = [2 * real_nsel + 1] + + spin = Spin( + use_spin=model_params["spin"]["use_spin"], + virtual_scale=model_params["spin"]["virtual_scale"], + ) + model_params["type_map"] += [item + "_spin" for item in model_params["type_map"]] + pair_exclude_types = spin.get_pair_exclude_types( + exclude_types=model_params.get("pair_exclude_types", None) + ) + model_params["pair_exclude_types"] = pair_exclude_types + model_params["descriptor"]["exclude_types"] = pair_exclude_types + atom_exclude_types = spin.get_atom_exclude_types( + exclude_types=model_params.get("atom_exclude_types", None) + ) + model_params["atom_exclude_types"] = atom_exclude_types + + ntypes = len(model_params["type_map"]) + model_params["descriptor"]["ntypes"] = ntypes + model_params["descriptor"]["type_map"] = copy.deepcopy(model_params["type_map"]) + + # === Bridging parameters === + bridging_method = str(model_params.get("bridging_method", "none")).upper() + bridging_r_inner = float(model_params.get("bridging_r_inner", 0.8)) + bridging_r_outer = float(model_params.get("bridging_r_outer", 1.2)) + if bridging_method != "NONE": + model_params["descriptor"]["inner_clamp_r_inner"] = bridging_r_inner + model_params["descriptor"]["inner_clamp_r_outer"] = bridging_r_outer + + descriptor = BaseDescriptor(**model_params["descriptor"]) + + fitting_net = copy.deepcopy(model_params["fitting_net"]) + fitting_net.pop("type", None) + fitting_net["ntypes"] = descriptor.get_ntypes() + fitting_net["type_map"] = copy.deepcopy(model_params["type_map"]) + fitting_net["mixed_types"] = descriptor.mixed_types() + fitting_net["dim_descrpt"] = descriptor.get_dim_out() + fitting = SeZMEnergyFittingNet(**fitting_net) + preset_out_bias = model_params.get("preset_out_bias") + preset_out_bias = _convert_preset_out_bias_to_array( + preset_out_bias, model_params["type_map"] + ) + data_stat_protect = model_params.get("data_stat_protect", 1e-2) + use_compile = bool(model_params.get("use_compile", False)) + enable_tf32 = bool(model_params.get("enable_tf32", True)) + + model = SeZMSpinModel( + descriptor=descriptor, + fitting=fitting, + type_map=model_params["type_map"], + atom_exclude_types=atom_exclude_types, + pair_exclude_types=pair_exclude_types, + preset_out_bias=preset_out_bias, + data_stat_protect=data_stat_protect, + use_compile=use_compile, + enable_tf32=enable_tf32, + bridging_method=bridging_method, + bridging_r_inner=bridging_r_inner, + bridging_r_outer=bridging_r_outer, + real_sel=real_sel_list, + spin=spin, + ) + model.model_def_script = json.dumps(model_params_old) + return model + + def get_model(model_params: dict) -> Any: model_type = model_params.get("type", "standard") if model_type == "standard": @@ -299,6 +459,10 @@ def get_model(model_params: dict) -> Any: return get_standard_model(model_params) elif model_type == "linear_ener": return get_linear_model(model_params) + elif model_type in ("SeZM", "sezm", "dpa4"): + if "spin" in model_params: + return get_sezm_spin_model(model_params) + return get_sezm_model(model_params) else: return BaseModel.get_class_by_type(model_type).get_model(model_params) @@ -313,9 +477,12 @@ def get_model(model_params: dict) -> Any: "FrozenModel", "LinearEnergyModel", "PolarModel", + "SeZMModel", + "SeZMSpinModel", "SpinEnergyModel", "SpinModel", "get_model", + "get_sezm_spin_model", "make_hessian_model", "make_model", ] diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py new file mode 100644 index 0000000000..534d95f084 --- /dev/null +++ b/deepmd/pt/model/model/sezm_model.py @@ -0,0 +1,2706 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""SeZM: Smooth equivariant Zone-bridging Model. + +This module hosts the full ``torch.compile`` + ``make_fx`` pipeline that +runs the SeZM energy (``ener``) path on the GPU. To the authors' +knowledge this is the first public implementation of a compiled, +dynamically shaped machine-learning potential whose *second-order* +derivatives -- required by force-loss training -- travel end-to-end +through Inductor without any eager fallback. The ``dens`` path below +uses a plain ``torch.compile`` wrapper and is not covered by the rest of +this docstring. + +Why force-loss training is hard to compile +========================================== + +An ML potential models atomic energy ``E(x, theta)`` from coordinates +``x`` and parameters ``theta``. Force-loss training minimizes + +:: + + L = alpha * ||E_pred - E_label||^2 + beta * ||f_pred - f_label||^2 + +with ``f_pred = -dE/dx``. The parameter update needs ``dL/dtheta``, +which contains ``d(f_pred)/dtheta = -d^2 E / (dx dtheta)`` -- a full +second-order derivative of the network with respect to one input and +one parameter axis. + +The standard ``torch.compile`` stack (AOT Autograd) captures forward and +first backward; it does *not* natively handle an +``autograd.grad(..., create_graph=True)`` call nested *inside* the +compiled region. So we compose two lower-level tools: + +1. ``make_fx`` traces the compute function *after* the inner + ``autograd.grad`` has been materialised, producing an FX graph whose + forward already contains the first-derivative graph as ordinary ops. +2. ``torch.compile(..., dynamic=True)`` lowers that traced FX graph to + Inductor. Because the graph no longer hides an autograd call, + Inductor's normal backward pipeline can differentiate the whole + thing a second time for the optimizer step. + +Everything else in this file exists to make that composition correct +under dynamic shapes, FSDP/DDP, and the list of PyTorch bugs that +surface along the way. Every non-obvious choice is pinned to a source +comment tagged ``NOTE:``; the numbered catalogue at the bottom of this +docstring explains each tag in depth. + +Pipeline for one training batch +=============================== + +:: + + forward(coord, atype, ...) + |-- input dtype cast + |-- neighbor list built in the extended region + '-- forward_common -- ener branch + |-- extended_coord.detach().requires_grad_(True) (NOTE 9) + |-- should_use_compile()? yes -> + | |-- trace_and_compile() on cache miss + | | |-- make_fx(compute_fn, + | | | tracing_mode="symbolic", + | | | _allow_non_fake_inputs=True, + | | | decomposition_table=) (NOTE 0) + | | | * trace inputs are nf=2 copies (NOTE 1) + | | | * silu_backward is decomposed (NOTE 2) + | | | * traced graph already contains the + | | | first autograd.grad over coords + | | |-- _strip_saved_tensor_detach (train only) (NOTE 3) + | | |-- _rebuild_graph_module (NOTE 4) + | | '-- torch.compile(backend="inductor", + | | dynamic=True, + | | options=) (NOTE 6) + | | stored in compiled_core_compute_cache[key] (NOTE 8) + | '-- compiled_core_compute_cache[key](...) + '-- communicate_extended_output + +Subsequent batches look up the cached callable at the same +``(training, do_atomic_virial, has_coord_corr)`` slot of +``compiled_core_compute_cache``. Each slot is retained independently, so +train <-> eval toggles around every ``disp_freq`` / full-validation checkpoint +reuse the other slot's compile product instead of evicting it (NOTE 7). + +Body of the traced compute +========================== + +``compute_fn`` (defined inside ``trace_and_compile``) wraps +``core_compute`` so that make_fx sees a pure tensor-in / tensor-out +function: + +* ``core_compute`` rebuilds a compact, GPU-friendly edge list from the + padded DeePMD neighbor list (``build_edge_list_from_nlist``), with a + single masked dummy edge appended so the edge tensor is never empty + (NOTE 10). Edge vectors come from ``index_select`` on the extended + coordinate tensor, which keeps the gradient path back to coordinates + explicit and safe under symbolic shapes (NOTE 11). +* The SeZM descriptor consumes the edge list and produces per-atom + features. +* The fitting network predicts per-atom energy; ``apply_out_stat`` adds + the per-type statistics and the atom mask zeroes out padding atoms. +* ``fit_output_to_model_output(..., create_graph=self.training)`` calls + ``autograd.grad`` internally to compute ``force = -dE/dx``. + ``create_graph`` is the single toggle that activates the + second-derivative branch for training and omits it at inference + (NOTE 12). + +Because ``make_fx`` traces *after* that inner ``autograd.grad`` has +executed, the resulting FX graph encodes both the forward and the first +derivative as ordinary ops. Any further ``.backward()`` on the compiled +output therefore just walks an FX-level backward that Inductor is +perfectly capable of lowering. + +The ``NOTE:`` catalogue +======================= + +NOTE 0 -- ``make_fx(tracing_mode="symbolic", _allow_non_fake_inputs=True)`` +-------------------------------------------------------------------------- + +``tracing_mode="symbolic"`` tells the proxy tensor that shapes are +sympy-backed symbols; it is what makes ``dynamic=True`` compile work +later. ``_allow_non_fake_inputs=True`` lets us feed *real* tensors +(not FakeTensors) to the trace. We need real data because the edge +compactor contains data-dependent operations (``torch.nonzero``, +``index_select``) that cannot be executed on FakeTensors; the shapes +become symbolic immediately after the first op, so only the control +flow is decided by concrete values. + +NOTE 1 -- Tracing with ``nf=2`` +------------------------------- + +``make_fx(tracing_mode="symbolic")`` replaces tensor shapes with sympy +symbols at trace time, but the moment a symbolic dim ends up equal to a +concrete dim elsewhere in the same tensor it collapses into a constant. +Concretely: + +* ``nf=1`` triggers PyTorch's 0/1 specialization and bakes ``nf`` into + the graph. +* ``nf=3`` collides with the spatial ``3`` in ``extended_coord`` whose + shape is ``(nf, nall, 3)``. +* ``nf=9`` would collide with the virial dim. + +Any of those collisions forces ``torch.compile(dynamic=True)`` to reject +later batches whose ``nf`` differs from the traced constant. ``nf=2`` +is the smallest batch size free of every known collision; we always +repeat the first frame twice to satisfy this invariant during tracing. + +NOTE 2 -- Decomposing ``silu_backward`` +--------------------------------------- + +PyTorch ships forward and first-order backward for SiLU but *no* +symbolic higher-order derivative. make_fx therefore emits +``aten.silu_backward.default`` opaquely inside the first-derivative +graph. When Inductor later has to differentiate that op again for the +optimizer step, it refuses because silu_backward is not differentiable +in its registered form. We pass an explicit decomposition +``silu_backward -> sigmoid + pointwise mul`` to ``make_fx``; every +pointwise piece then has a well-defined higher derivative of its own. + +NOTE 3 -- Stripping autograd-inserted detach chains +--------------------------------------------------- + +When ``autograd.grad(create_graph=True)`` runs under make_fx, the +autograd engine wraps every saved forward activation in a double-detach +chain, e.g.:: + + tanh -> detach_A -> detach_B -> tanh_backward + +In eager autograd those detaches are informational -- they mark saved +tensors as belonging to a different graph. After tracing, however, +they become ordinary ops inside the FX graph and sever the gradient +path from the force loss back to ``theta``; training then silently +produces zero parameter updates for the second-derivative term. + +``_strip_saved_tensor_detach`` removes them by pure graph topology -- +no op-name matching -- so that user-explicit ``.detach()`` calls +(e.g. cached SO2 weights, activation lookup matrices) survive: + +* *Chain inner*: input is another detach. +* *Dead node*: no downstream users. +* *Chain head*: every user is a detach. + +Any detach that matches none of the three is treated as user intent and +is kept verbatim. Stripping is guarded by ``self.training`` because +eval mode does not set ``create_graph=True``; the chain is never +inserted and removing it would be incorrect. + +NOTE 4 -- Rebuilding the FX graph from scratch +---------------------------------------------- + +``Graph.erase_node`` inside ``_strip_saved_tensor_detach`` unlinks nodes +from the doubly linked list that represents the graph. On several +PyTorch builds (observed on 2.11+cu130) it leaves the C-level +``prev/next`` pointers of *neighbouring* Node objects stale. Dynamo, +when it later re-traces the ``GraphModule`` and walks ``graph.nodes`` +inside ``output_graph.py:_create_proxy`` to read ``nd.meta``, +dereferences one of those stale pointers and segfaults. + +``_rebuild_graph_module`` does a single ``node_copy`` pass into a +freshly allocated ``torch.fx.Graph``. The result is an equivalent graph +whose linked list contains no erased entries, so dynamo can iterate it +safely. We always rebuild -- including in eval -- because a fresh +graph is cheap while a segfault is fatal. + +NOTE 5 -- Disabling ``DDPOptimizer`` +------------------------------------ + +``torch._dynamo.config.optimize_ddp = False`` is set unconditionally at +import time. DDPOptimizer is designed to split a DDP-wrapped model's +graph at bucket boundaries so that gradients can overlap with +all-reduce. But here the compile region is *inside* the DDP-wrapped +model -- it wraps only ``core_compute``. DDPOptimizer assumes it owns +the whole model, splits our inner graph at its internal bucket +heuristic, and the split produces subgraphs whose outputs include +symbolic integers. AOT Autograd then crashes with +``'int' object has no attribute 'meta'`` (pytorch/pytorch#134182). +Disabling the optimizer globally is safe because SeZM always owns its +own compile boundary and the surrounding DDP wrapper operates on the +full model call. + +NOTE 6 -- Inductor / Triton option lockdown +------------------------------------------- + +``torch.compile(backend="inductor", dynamic=True, options=...)`` is +configured with: + +* ``max_autotune=False`` + Autotune regresses on dynamic shapes because each recompile rolls + the search; deterministic kernels compiled once are consistently + faster on our edge-level reductions. +* ``shape_padding=True`` + Pads tensors to SIMD-friendly sizes when symbolic shapes + fluctuate batch-to-batch, eliminating tail-kernel generation cost. +* ``epilogue_fusion=False`` + Two independent reasons to keep it off. (a) Inductor only + enables epilogue fusion when ``max_autotune`` is on, and we + deliberately disable autotune above; leaving the flag on would + pay the scheduling cost without ever activating the fusion. + (b) Fused epilogues occasionally reorder saved tensors in ways + the second backward cannot recover; disabling the fusion keeps + the backward graph shape-stable under make_fx. +* ``triton.cudagraphs=False`` + cudagraphs capture autograd metadata only once. Higher-order + gradients need fresh metadata per call, so cudagraphs would feed + stale autograd state into the second backward. +* ``max_fusion_size`` -- mode-dependent + Caps kernel fusion complexity so Inductor's scheduler does not + time out on the large edge-level reductions inside the + descriptor when nsel is big. Training uses ``64`` (the long- + standing default, observed stable on every training run so far); + inference uses the tighter ``8`` to dodge the Triton lowering + failure described by the next bullet. +* ``triton.persistent_reductions=False`` -- inference only + Inductor's persistent-reduction scheduler fuses a ``sum`` with + *all* neighbouring pointwise ops (``tanh_backward``, ``pow``, + ``exp``, ``mul``, ``select``, ``slice``, ``view`` ...) into one + ``triton_per_fused_...`` kernel. On the graph emitted by + inference (``create_graph=False``, no double-detach stripping, + different fused topology than training) this kernel hits Triton + bug ``PassManager::run failed`` inside ``make_ttgir``. Training + never produces the same fused shape and does not benefit from + disabling the optimisation, so the flag is left on for training + to preserve kernel quality. +* ``triton.mix_order_reduction=False`` + Workaround for PyTorch <=2.11 bugs pytorch/pytorch#174379, + #178080, #179494. All three manifest only under data-dependent + symbolic shapes -- exactly our edge count. + +NOTE 7 -- Multi-slot compile cache key +-------------------------------------- + +The key is ``(training, do_atomic_virial, has_coord_corr)`` because all three +fields alter the traced graph topology: + +* ``self.training`` switches ``create_graph`` in + ``fit_output_to_model_output`` -- it toggles the entire + second-derivative branch on or off. +* ``do_atomic_virial`` adds or removes an extra per-atom virial tensor + in the compute output. +* ``has_coord_corr`` selects the spin-virial correction branch, changing the + compiled callable arity from six tensor inputs to seven. + +No single compiled graph can serve both variants, so the cache is a +``dict[tuple[bool, bool, bool], Callable]`` named +``compiled_core_compute_cache``. A single-slot +cache would have to evict on every flip, which turns the normal +training-loop pattern -- ``train -> eval at every disp_freq -> train`` +and an occasional full validation on top of that -- into +two-recompile-per-disp_freq thrashing (each recompile costs tens of +seconds to minutes on SeZM). With multi-slot caching the first +encounter of each mode pays the compile cost once, and every later +toggle is a dict lookup. + +Enabling compile for eval is an opt-in via ``DP_COMPILE_INFER=1`` +(``should_use_compile`` returns ``_env_use_compile_infer`` when +``self.training`` is ``False``). Once enabled, regular validation, +full validation and EMA full validation all reuse the eval slot. + +NOTE 8 -- Storing the compile cache outside the ``nn.Module`` tree +------------------------------------------------------------------ + +The cache dict is installed via ``object.__setattr__(self, ...)`` at +__init__ time rather than plain ``self.compiled_core_compute_cache = {}``, and +every later mutation writes into that same dict in place. +``nn.Module.__setattr__`` would register any module-looking value as a +submodule; the compiled wrappers held as *values* of this dict carry +duplicated flat views of the trainable parameters, and FSDP2 / DDP +would then shard or synchronise those duplicates and silently corrupt +training. A plain ``dict`` container escapes parameter discovery +entirely because ``nn.Module.__setattr__`` only recognises +``nn.Parameter`` / ``nn.Module`` values, and ``named_parameters`` / +``named_modules`` walk ``self._parameters`` / ``self._modules``, never +arbitrary attributes; ``object.__setattr__`` merely belt-and-braces +this invariant for readers of the constructor. + +NOTE 9 -- Graph restart via ``detach().requires_grad_(True)`` +------------------------------------------------------------- + +Before calling into the traced graph we rebind the extended coordinates +to a fresh leaf tensor: ``detach()`` breaks any upstream autograd graph +carried over from the data pipeline, and ``requires_grad_(True)`` +reinstates a grad-endpoint owned by this forward. The subsequent +``autograd.grad`` in ``fit_output_to_model_output`` therefore computes +``dE/dx`` against a graph of known shape and ownership -- the essential +precondition for make_fx symbolic tracing. + +In eval mode we merely detach; no ``create_graph`` is requested, so the +compiled kernel never has to build a backward graph. + +NOTE 10 -- Tail dummy edge +-------------------------- + +``build_edge_list_from_nlist`` appends exactly one masked edge at the +end of every batch. Real edge compaction happens via +``torch.nonzero(valid_mask)``, whose output length is data-dependent +and can be zero in sparse or single-type systems. make_fx cannot trace +an "if n_edges == 0: skip" branch symbolically; without the dummy it +would fall back to concrete shape specialization and break +``dynamic=True``. The dummy's ``edge_mask`` is ``False`` so it +contributes exactly zero to every downstream sum or gather. + +NOTE 11 -- ``index_select`` for coordinate gradients +---------------------------------------------------- + +Edge geometry is built with ``coord_flat.index_select(0, src)`` instead +of advanced indexing ``coord_flat[src]``. ``index_select`` registers +an explicit backward that routes gradient cleanly back to the original +extended coordinate tensor. Advanced indexing combined with make_fx +symbolic shapes has previously produced silent gradient truncation in +this project -- the second-derivative gradient over coordinates was +effectively zero, with no error raised. + +NOTE 12 -- ``create_graph=self.training`` +----------------------------------------- + +The single toggle that turns force-loss training on. When ``True``, +``autograd.grad`` keeps the graph over the first derivative alive so +the outer optimizer's ``.backward()`` can continue walking it into the +parameters. When ``False`` the double-backward graph is never built, +saving memory during inference. +""" + +from __future__ import ( + annotations, +) + +import logging +import os +import time +from contextlib import ( + contextmanager, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +from einops import ( + rearrange, +) +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +if TYPE_CHECKING: + from collections.abc import Generator + + from jaxtyping import Float, Int + from torch import Tensor + +from deepmd.pt.model.atomic_model.sezm_atomic_model import ( + SeZMAtomicModel, +) +from deepmd.pt.model.descriptor.sezm_nn import ( + nvtx_range, +) +from deepmd.pt.model.model.dp_model import ( + DPModelCommon, +) +from deepmd.pt.model.model.make_model import ( + make_model, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, + fit_output_to_model_output, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +log = logging.getLogger(__name__) + +SeZMModel_ = make_model(SeZMAtomicModel) + +# NOTE: Silence Inductor / Triton autotune dumps before any submodule is +# imported. ``torch.compile`` reads these environment variables exactly +# once at backend initialisation; setting them after the first compile +# would have no effect in the current run. ``setdefault`` preserves any +# explicit user-level override. +os.environ.setdefault("TORCHINDUCTOR_MAX_AUTOTUNE_REPORT_CHOICES_STATS", "0") +os.environ.setdefault("TRITON_PRINT_AUTOTUNING", "0") + +# NOTE: Disable DDPOptimizer graph splitting globally. +# ``compiled_core_compute_cache`` entries / ``compiled_dens_compute`` are inner +# ``torch.compile`` calls sitting *inside* a DDP-wrapped model; +# DDPOptimizer assumes it sees the *whole* model and splits the FX graph +# at DDP bucket boundaries. For an inner submodule that heuristic +# produces subgraphs whose outputs include symbolic integers, which then +# crash aot_autograd with ``'int' object has no attribute 'meta'``. +# See https://github.com/pytorch/pytorch/issues/134182. Turning the +# optimizer off globally is safe because SeZM always owns its own compile +# boundary and the surrounding DDP wrapper operates on the full model +# call. +import torch._dynamo.config as _dynamo_cfg + +_dynamo_cfg.optimize_ddp = False + + +def _parse_optional_env_bool(var_name: str) -> bool | None: + """ + Parse an optional boolean environment variable. + + Parameters + ---------- + var_name + Environment variable name. + + Returns + ------- + bool | None + Parsed boolean value, or ``None`` when the variable is unset. + + Raises + ------ + ValueError + If the environment variable value is not a supported boolean token. + """ + raw_value = os.environ.get(var_name) + if raw_value is None: + return None + normalized = raw_value.strip().lower() + if normalized in {"1", "true", "yes", "on"}: + return True + if normalized in {"0", "false", "no", "off"}: + return False + raise ValueError( + f"{var_name} must be one of 1/0/true/false/yes/no/on/off, got {raw_value!r}" + ) + + +def _strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None: + """Strip ``aten.detach`` nodes that ``make_fx`` inserts for saved tensors. + + When ``make_fx`` decomposes ``autograd.grad(..., create_graph=True)``, + the autograd engine wraps every saved forward activation in a double-detach + chain (e.g. ``tanh -> detach_A -> detach_B -> tanh_backward``). These + detach nodes block the second-order gradient path from the loss back to + model parameters, causing incorrect parameter updates during force-loss + training. + + User-explicit ``.detach()`` calls (e.g. inside ``attach_edge_vec_grad``) + are preserved. The two categories are distinguished by graph topology + alone — no hard-coded op names — using three rules: + + * *Chain inner*: input is another detach node. + * *Dead node*: no downstream users. + * *Chain head*: *all* users are detach nodes. + + Any detach that does **not** match these rules is treated as user-explicit + and left untouched. + """ + _DETACH = torch.ops.aten.detach.default + + def _is_detach(n: torch.fx.Node) -> bool: + return n.op == "call_function" and n.target == _DETACH + + # NOTE: Pass 1 -- classify every detach against the *original* graph. + # If we erased nodes eagerly, later classifications would walk a + # mutated neighbourhood and misjudge the chain-inner / chain-head / + # dead boundaries; the double-detach pattern in particular flips + # class within a single erase. Collecting first, mutating second + # keeps the topology rules well-defined. + to_remove: list[torch.fx.Node] = [] + for node in gm.graph.nodes: + if not _is_detach(node): + continue + input_node = node.args[0] + users = list(node.users.keys()) + is_chain_inner = _is_detach(input_node) + is_dead = len(users) == 0 + is_chain_head = len(users) > 0 and all(_is_detach(u) for u in users) + if is_chain_inner or is_dead or is_chain_head: + to_remove.append(node) + + # NOTE: Pass 2 -- rewire + erase atomically after the full + # classification. ``replace_all_uses_with`` forwards every consumer + # to the detach's input; ``erase_node`` then removes the now-dead + # detach. Doing both back-to-back means the graph never sits in a + # half-consistent state where one user sees the old detach and + # another the rewired source. + for node in to_remove: + node.replace_all_uses_with(node.args[0]) + gm.graph.erase_node(node) + + gm.graph.lint() + gm.recompile() + + +def _rebuild_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Return a fresh ``GraphModule`` whose node linked-list is newly allocated. + + After ``_strip_saved_tensor_detach`` erases nodes via + ``Graph.erase_node()``, the internal doubly-linked list may retain + stale pointers to erased nodes. When ``torch.compile`` later + triggers dynamo re-tracing and iterates ``graph.nodes`` to read + ``nd.meta`` (``output_graph.py:_create_proxy``), accessing these + stale entries causes a segfault. + + Copying every node into a brand-new ``Graph`` builds a clean linked + list from scratch, side-stepping the corruption entirely. + """ + old_graph = gm.graph + new_graph = torch.fx.Graph() + # node_copy needs a mapper from old nodes to their copies in new_graph. + val_map: dict[torch.fx.Node, torch.fx.Node] = {} + for node in old_graph.nodes: + val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) + new_graph.lint() + new_gm = torch.fx.GraphModule(gm, new_graph) + return new_gm + + +@BaseModel.register("SeZM") +@BaseModel.register("sezm") +@BaseModel.register("dpa4") +class SeZMModel(DPModelCommon, SeZMModel_): + """ + SeZM energy model with an optional compiled sparse-edge path. + + By default it uses the traditional DeePMD neighbor list path with ghost atoms + and padded neighbor matrix, compatible with LAMMPS and other MD engines. + When `use_compile=True`, it builds a compact sparse edge list from the + standard neighbor list and traces the local graph with ``make_fx`` for + higher-order force training. Evaluation/inference compile usage is + controlled by the `DP_COMPILE_INFER` environment variable read at model + initialization time. + """ + + model_type = "SeZM" + + def __init__( + self, + *args: Any, + use_compile: bool = False, + enable_tf32: bool = True, + bridging_method: str = "none", + bridging_r_inner: float = 0.8, + bridging_r_outer: float = 1.2, + lora: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + DPModelCommon.__init__(self) + SeZMModel_.__init__(self, *args, **kwargs) + self.redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION + self.use_compile = bool(use_compile) + self.enable_tf32 = bool(enable_tf32) + # LoRA injection happens in Trainer.__init__ after pre-trained state is loaded. + self.lora_config: dict[str, Any] | None = None if lora is None else dict(lora) + self._dens_compiled = False + self._core_compute_pending_compile_t0: float | None = None + self._core_compute_pending_compile_key: tuple[bool, bool, bool] | None = None + self._dens_pending_compile_t0: float | None = None + # Store compiled callables outside the nn.Module tree so that + # FSDP2 / DDP do not shard or sync its duplicated parameters. + # ``compiled_core_compute_cache`` is keyed on + # ``(training, do_atomic_virial, has_coord_corr)`` so every graph + # topology has its own slot; flipping between train and eval for + # validation -- regular, full, or EMA full -- therefore reuses cached + # compile products instead of evicting the other mode. + object.__setattr__(self, "compiled_core_compute_cache", {}) + object.__setattr__(self, "compiled_dens_compute", None) + # Training follows `use_compile`. Evaluation/inference reads + # `DP_COMPILE_INFER` at init time and falls back to eager when unset. + self._env_use_compile_infer: bool | None = _parse_optional_env_bool( + "DP_COMPILE_INFER" + ) + + # === Bridging (optional short-range zone bridging) === + self.bridging_method: str = str(bridging_method).upper() + self.bridging_r_inner = float(bridging_r_inner) + self.bridging_r_outer = float(bridging_r_outer) + self.inter_potential: InterPotential | None = ( + InterPotential(type_map=self.get_type_map(), mode=self.bridging_method) + if self.bridging_method != "NONE" + else None + ) + + # ========================================================================= + # Forward Methods + # ========================================================================= + + def forward( + self, + coord: Float[Tensor, "nf nloc 3"] | Float[Tensor, "nf nloc_x3"], + atype: Int[Tensor, "nf nloc"], + box: Float[Tensor, "nf 9"] | None = None, + fparam: Float[Tensor, "nf ndf"] | None = None, + aparam: Float[Tensor, "nf nloc nda"] | None = None, + do_atomic_virial: bool = False, + force_input: Float[Tensor, "nf nloc 3"] | None = None, + noise_mask: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """ + Forward pass using standard neighbor list. + + Parameters + ---------- + coord + Coordinates with shape (nf, nloc*3) or (nf, nloc, 3) in Å. + atype + Atom types with shape (nf, nloc). + box + Box tensor with shape (nf, 9) in Å, or None. + fparam + Frame parameters with shape (nf, ndf) or None. + aparam + Atomic parameters with shape (nf, nloc, nda) or None. + do_atomic_virial + Whether to compute atomic virial. + force_input + Optional atom-wise force input tensor with shape `(nf, nloc, 3)`. + It stays optional at the public model boundary because validation / + inference and clean `dens` batches may not provide force labels. + noise_mask + Optional corruption mask with shape `(nf, nloc)`. It stays optional + at the public model boundary because validation / inference and + clean `dens` batches may not provide corruption masks. + charge_spin + Frame-level charge and spin conditions with shape `(nf, 2)`. + + Returns + ------- + dict[str, torch.Tensor] + Model predictions including atom_energy, energy, force, virial, + atom_virial, and mask. + """ + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + force_input=force_input, + noise_mask=noise_mask, + charge_spin=charge_spin, + ) + if self.get_fitting_net() is not None: + model_predict: dict[str, torch.Tensor] = {} + + # === Step 1. Energy === + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + + # === Step 2. Force (independent branch) === + if self.do_grad_r("energy"): + model_predict["force"] = rearrange( + model_ret["energy_derv_r"], + "nf nloc 1 three -> nf nloc three", + three=3, + ) + else: + model_predict["force"] = model_ret["dforce"] + + if self.get_active_mode() == "dens": + if "energy_norm" in model_ret: + model_predict["energy_norm"] = model_ret["energy_norm"] + if "atom_energy_norm" in model_ret: + model_predict["atom_energy_norm"] = model_ret["atom_energy_norm"] + if "dforce_norm" in model_ret: + model_predict["force_norm"] = model_ret["dforce_norm"] + if "clean_dforce_norm" in model_ret: + model_predict["clean_force_norm"] = model_ret["clean_dforce_norm"] + if "denoising_dforce_norm" in model_ret: + model_predict["denoising_force_norm"] = model_ret[ + "denoising_dforce_norm" + ] + + # === Step 3. Virial === + if self.do_grad_c("energy"): + model_predict["virial"] = rearrange( + model_ret["energy_derv_c_redu"], "nf 1 nine -> nf nine", nine=9 + ) + if do_atomic_virial: + model_predict["atom_virial"] = rearrange( + model_ret["energy_derv_c"], + "nf nloc 1 nine -> nf nloc nine", + nine=9, + ) + + # === Step 4. Mask === + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + def forward_common( + self, + coord: Float[Tensor, "nf nloc 3"] | Float[Tensor, "nf nloc_x3"], + atype: Int[Tensor, "nf nloc"], + box: Float[Tensor, "nf 9"] | None = None, + fparam: Float[Tensor, "nf ndf"] | None = None, + aparam: Float[Tensor, "nf nloc nda"] | None = None, + do_atomic_virial: bool = False, + force_input: Float[Tensor, "nf nloc 3"] | None = None, + noise_mask: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """ + Return model prediction using standard neighbor list. + + Parameters + ---------- + coord + Coordinates with shape (nf, nloc*3) or (nf, nloc, 3) in Å. + atype + Atom types with shape (nf, nloc). + box + Box tensor with shape (nf, 9) in Å, or None. + fparam + Frame parameters with shape (nf, ndf) or None. + aparam + Atomic parameters with shape (nf, nloc, nda) or None. + do_atomic_virial + Whether to compute atomic virial. + force_input + Optional atom-wise force input tensor with shape `(nf, nloc, 3)`. + It stays optional at the public model boundary because validation / + inference and clean `dens` batches may not provide force labels. + noise_mask + Optional corruption mask with shape `(nf, nloc)`. It stays optional + at the public model boundary because validation / inference and + clean `dens` batches may not provide corruption masks. + charge_spin + Frame-level charge and spin conditions with shape `(nf, 2)`. + + Returns + ------- + dict[str, torch.Tensor] + Model predictions including energy, forces, etc. + """ + with nvtx_range("SeZM/forward_common"): + # === Step 1. Cast inputs to correct dtype === + with nvtx_range("SeZM/input_type_cast"): + cc, bb, fp, ap, input_prec = self._input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + del coord, box, fparam, aparam + nf, nloc = atype.shape[:2] + if cc.ndim == 2: + cc = cc.view(nf, nloc, 3) + + # === Step 2. Build neighbor list === + with nvtx_range("SeZM/build_neighbor_list"): + # extended_coord: (nf, nall, 3), extended_atype: (nf, nall) + # mapping: (nf, nall), nlist: (nf, nloc, nsel) + extended_coord, extended_atype, mapping, nlist = ( + self.build_neighbor_list(cc, atype, bb) + ) + + # === Step 3. Run the shared extended-input path === + return self.forward_common_after_nlist( + extended_coord, + extended_atype, + mapping, + nlist, + atype, + fp, + ap, + input_prec, + do_atomic_virial=do_atomic_virial, + force_input=force_input, + noise_mask=noise_mask, + charge_spin=charge_spin, + ) + + def forward_common_after_nlist( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + mapping: torch.Tensor, + nlist: torch.Tensor, + atype: torch.Tensor, + fp: torch.Tensor | None, + ap: torch.Tensor | None, + input_prec: torch.dtype, + *, + do_atomic_virial: bool = False, + force_input: torch.Tensor | None = None, + noise_mask: torch.Tensor | None = None, + extended_coord_corr: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """ + Run SeZM from already-built extended inputs. + + Parameters + ---------- + extended_coord + Coordinates in extended region with shape (nf, nall, 3). + extended_atype + Atom types in extended region with shape (nf, nall). + mapping + Extended-to-local mapping with shape (nf, nall). + nlist + Neighbor list with shape (nf, nloc, nsel). + atype + Local atom types with shape (nf, nloc). + fp + Cast frame parameters with shape (nf, ndf), or None. + ap + Cast atomic parameters with shape (nf, nloc, nda), or None. + input_prec + Original input precision used for output casting. + do_atomic_virial + Whether to compute per-atom virial. + force_input + Optional atom-wise force input for the ``dens`` path with shape + (nf, nloc, 3). + noise_mask + Optional atom-wise corruption mask for the ``dens`` path with + shape (nf, nloc). + extended_coord_corr + Coordinate correction for virial with shape (nf, nall, 3), or None. + charge_spin + Frame-level charge and spin conditions with shape `(nf, 2)`. + + Returns + ------- + dict[str, torch.Tensor] + Model predictions with the standard SeZM internal keys. + """ + nf, nloc = atype.shape[:2] + charge_spin = self.convert_charge_spin( + charge_spin, + nf=nf, + dtype=extended_coord.dtype, + device=extended_coord.device, + ) + active_mode = self.get_active_mode() + if active_mode == "dens": + # === Step 1. `dens` path (no coordinate gradients needed) === + extended_coord = extended_coord.detach() + force_input, noise_mask = self.canonicalize_dens_inputs( + force_input, + noise_mask, + nf=nf, + nloc=nloc, + dtype=extended_coord.dtype, + device=extended_coord.device, + ) + + if self.should_use_compile(): + fp, ap = self.convert_fp_ap( + fp, + ap, + nf=nf, + nloc=nloc, + dtype=extended_coord.dtype, + device=extended_coord.device, + ) + with self.tf32_precision_ctx(): + if self.compiled_dens_compute is None or not self._dens_compiled: + self.compile_dens() + with nvtx_range("SeZM/core_compute_dens"): + compute_ret = self.compiled_dens_compute( + extended_coord, + extended_atype, + nlist, + mapping, + force_input=force_input, + noise_mask=noise_mask, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + ) + if self._dens_pending_compile_t0 is not None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + log.info( + "SeZM: finished compiling dens path in %.2fs", + time.perf_counter() - self._dens_pending_compile_t0, + ) + self._dens_pending_compile_t0 = None + else: + with nvtx_range("SeZM/core_compute_dens"): + compute_ret = self.core_compute_dens( + extended_coord, + extended_atype, + nlist, + mapping, + force_input=force_input, + noise_mask=noise_mask, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + ) + with nvtx_range("SeZM/post_process"): + model_predict = self.post_process_output_dens( + compute_ret, + atype, + noise_mask=noise_mask, + ) + else: + # === Step 1. `ener` path (edges built inside core_compute) === + # NOTE: Rebind the extended coordinates to a fresh leaf + # tensor before entering either ``core_compute`` or the + # compiled callable. ``detach()`` breaks any upstream + # autograd graph carried by the batch (data pipeline + # artefacts, neighbor-list ops) and + # ``requires_grad_(True)`` reinstates a grad-endpoint + # owned exclusively by this forward. The inner + # ``autograd.grad`` inside ``fit_output_to_model_output`` + # will then compute ``dE/dx`` against a graph of known + # shape and ownership -- the essential precondition for + # symbolic make_fx tracing. In eval without coordinate + # gradients a bare detach is enough. + if self.do_grad_r() or self.do_grad_c(): + extended_coord = extended_coord.detach().requires_grad_(True) + else: + extended_coord = extended_coord.detach() + + if self.should_use_compile(): + fp, ap = self.convert_fp_ap( + fp, + ap, + nf=nf, + nloc=nloc, + dtype=extended_coord.dtype, + device=extended_coord.device, + ) + with self.tf32_precision_ctx(): + has_coord_corr = extended_coord_corr is not None + cache_key = ( + bool(self.training), + bool(do_atomic_virial), + has_coord_corr, + ) + if cache_key not in self.compiled_core_compute_cache: + self.trace_and_compile( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + charge_spin, + do_atomic_virial, + extended_coord_corr=extended_coord_corr, + ) + compiled_core_compute = self.compiled_core_compute_cache[cache_key] + with nvtx_range("SeZM/core_compute"): + if extended_coord_corr is None: + model_predict_lower = compiled_core_compute( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + charge_spin, + ) + else: + model_predict_lower = compiled_core_compute( + extended_coord, + extended_atype, + nlist, + mapping, + fp, + ap, + charge_spin, + extended_coord_corr, + ) + if ( + self._core_compute_pending_compile_t0 is not None + and self._core_compute_pending_compile_key == cache_key + ): + if torch.cuda.is_available(): + torch.cuda.synchronize() + log.info( + "SeZM: finished compiling " + "(mode=%s, atomic_virial=%s, coord_corr=%s) " + "in %.2fs", + "train" if self.training else "eval", + do_atomic_virial, + has_coord_corr, + time.perf_counter() - self._core_compute_pending_compile_t0, + ) + self._core_compute_pending_compile_t0 = None + self._core_compute_pending_compile_key = None + else: + with nvtx_range("SeZM/core_compute"): + model_predict_lower = self.core_compute( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + extended_coord_corr=extended_coord_corr, + ) + + with nvtx_range("SeZM/communicate_output"): + model_predict = communicate_extended_output( + model_predict_lower, + self.model_output_def(), + mapping, + do_atomic_virial=do_atomic_virial, + ) + + # === Step 2. Type cast output === + with nvtx_range("SeZM/output_type_cast"): + model_predict = self._output_type_cast(model_predict, input_prec) + return model_predict + + def core_compute( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + extra_nlist_sort: bool = False, + extended_coord_corr: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """ + Compute SeZM lower outputs from extended inputs. + + Builds compact sparse edges, runs descriptor and fitting evaluation, + applies output masking and the optional analytical pair potential, + then calls ``fit_output_to_model_output`` for force / virial. + + Parameters + ---------- + extended_coord + Coordinates in extended region with shape (nf, nall, 3). + extended_atype + Atom types in extended region with shape (nf, nall). + nlist + Neighbor list with shape (nf, nloc, nsel). + mapping + Extended-to-local mapping with shape (nf, nall), or ``None``. + fparam + Frame parameters with shape (nf, ndf), or ``None``. + aparam + Atomic parameters with shape (nf, nloc, nda), or ``None``. + charge_spin + Frame-level charge and spin conditions with shape `(nf, 2)`. + do_atomic_virial + Whether to compute per-atom virial. + comm_dict + Communication data for parallel inference. Currently unused. + extra_nlist_sort + Whether to forcibly sort the nlist. + extended_coord_corr + Coordinates correction for virial with shape (nf, nall, 3) or ``None``. + + Returns + ------- + dict[str, torch.Tensor] + DeePMD lower-style outputs (energy, energy_redu, energy_derv_r, ...). + """ + del comm_dict + nlist = self.format_nlist( + extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort + ) + _, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + descriptor_model = self.atomic_model.descriptor + + # === Step 1. Build compact sparse edges === + edge_index, edge_vec, edge_mask = self.build_edge_list_from_nlist( + extended_coord=extended_coord, + nlist=nlist, + mapping=mapping, + ) + + # === Step 2. Descriptor forward === + with nvtx_range("SeZM/descriptor"): + descriptor, _ = descriptor_model.forward_with_edges( + extended_coord=extended_coord[:, :nloc, :], + extended_atype=atype, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + charge_spin=charge_spin, + ) + if self.atomic_model.enable_eval_descriptor_hook: + self.atomic_model.eval_descriptor_list.append(descriptor.detach()) + + # === Step 3. Fitting net + output statistics === + with nvtx_range("SeZM/fitting_net"): + fit_ret = self.atomic_model.fitting_net( + descriptor, + atype, + fparam=fparam, + aparam=aparam, + ) + if self.atomic_model.enable_eval_fitting_last_layer_hook: + assert "middle_output" in fit_ret, ( + "eval_fitting_last_layer not supported for this fitting net!" + ) + self.atomic_model.eval_fitting_last_layer_list.append( + fit_ret.pop("middle_output").detach() + ) + with nvtx_range("SeZM/apply_out_stat"): + fit_ret = self.atomic_model.apply_out_stat(fit_ret, atype) + + # === Step 4. Apply atom mask === + ext_atom_mask = self.atomic_model.make_atom_mask(extended_atype) + atom_mask = ext_atom_mask[:, :nloc].to(torch.int32) + if self.atomic_model.atom_excl is not None: + atom_mask *= self.atomic_model.atom_excl(atype) + for key in fit_ret.keys(): + out_shape = fit_ret[key].shape + flat_dim = 1 + for axis_size in out_shape[2:]: + flat_dim *= axis_size + fit_ret[key] = ( + fit_ret[key].reshape([out_shape[0], out_shape[1], flat_dim]) + * atom_mask[:, :, None] + ).view(out_shape) + fit_ret["mask"] = atom_mask + + # === Step 5. Inject analytical pair potential === + if self.inter_potential is not None: + fit_ret["energy"] = fit_ret["energy"] + self.inter_potential( + extended_coord, + extended_atype, + nlist, + nloc, + real_type_count=self._get_inter_potential_real_type_count(), + ) + + # === Step 6. Force / virial via fit_output_to_model_output === + # NOTE: ``create_graph=self.training`` is the single toggle that + # activates force-loss training. Internally this calls + # ``torch.autograd.grad(energy, extended_coord, create_graph=...)`` + # to produce ``force = -dE/dx``. When ``True`` the autograd graph + # over the first derivative is kept alive, so the outer + # optimiser's ``.backward()`` can continue differentiating into + # parameters -- that chain is the full + # ``d^2 E / (dx dtheta)`` second derivative. When ``False`` the + # double-backward graph is never built, saving memory during + # inference. The entire reason this file exists -- make_fx, + # detach stripping, graph rebuild -- is to keep that + # second-derivative chain intact after ``torch.compile`` has + # captured the whole thing. + return fit_output_to_model_output( + fit_ret, + self.atomic_output_def(), + extended_coord, + do_atomic_virial=do_atomic_virial, + create_graph=self.training, + mask=fit_ret["mask"], + extended_coord_corr=extended_coord_corr, + ) + + def core_compute_dens( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + *, + force_input: torch.Tensor, + noise_mask: torch.Tensor, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Compute SeZM ``dens`` energy/direct-force tensors from extended inputs. + + Parameters + ---------- + extended_coord + Extended coordinates with shape (nf, nall, 3). + extended_atype + Extended atom types with shape (nf, nall). + nlist + Neighbor list with shape (nf, nloc, nsel). + mapping + Extended-to-local mapping with shape (nf, nall), or ``None``. + force_input + Atom-wise force input tensor with shape ``(nf, nloc, 3)``. + noise_mask + Atom-wise corruption mask with shape ``(nf, nloc)``. + fparam + Frame parameters with shape ``(nf, ndf)``, or ``None``. + aparam + Atomic parameters with shape ``(nf, nloc, nda)``, or ``None``. + charge_spin + Frame-level charge and spin conditions with shape `(nf, 2)`. + + Returns + ------- + torch.Tensor + Concatenated local tensor with shape ``(nf, nloc, 7)`` and layout + ``[atom_energy_norm | clean_dforce_norm | denoising_dforce_norm]``. + """ + if self.inter_potential is not None: + raise NotImplementedError( + "SeZM `dens` path does not support analytical bridging potentials." + ) + + nlist = self.format_nlist( + extended_coord, + extended_atype, + nlist, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + _, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + descriptor_model = self.atomic_model.descriptor + + # === Step 1. Build compact sparse edges === + edge_index, edge_vec, edge_mask = self.build_edge_list_from_nlist( + extended_coord=extended_coord, + nlist=nlist, + mapping=mapping, + ) + + # === Step 2. Force embedding === + dens_fitting = self.atomic_model.get_dens_fitting_net() + force_embedding = dens_fitting.build_force_embedding( + force_input, + noise_mask=noise_mask, + ) + + # === Step 3. Descriptor forward with force embedding === + with nvtx_range("SeZM/descriptor_dens"): + descriptor, latent = descriptor_model.forward_with_edges( + extended_coord=extended_coord[:, :nloc, :], + extended_atype=atype, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + force_embedding=force_embedding, + charge_spin=charge_spin, + ) + if self.atomic_model.enable_eval_descriptor_hook: + self.atomic_model.eval_descriptor_list.append(descriptor.detach()) + + # === Step 4. Dens fitting net === + with nvtx_range("SeZM/dens_fitting_net"): + fit_ret = dens_fitting( + descriptor, + latent, + atype, + noise_mask=noise_mask, + fparam=fparam, + aparam=aparam, + return_components=True, + ) + if self.atomic_model.enable_eval_fitting_last_layer_hook: + assert "middle_output" in fit_ret, ( + "eval_fitting_last_layer not supported for this fitting net!" + ) + self.atomic_model.eval_fitting_last_layer_list.append( + fit_ret.pop("middle_output").detach() + ) + return torch.cat( + [ + fit_ret["energy"], + fit_ret["clean_dforce"], + fit_ret["denoising_dforce"], + ], + dim=-1, + ) + + @torch.jit.export + def forward_lower( + self, + extended_coord: Float[Tensor, "nf nall_x3"] | Float[Tensor, "nf nall 3"], + extended_atype: Int[Tensor, "nf nall"], + nlist: Int[Tensor, "nf nloc nsel"], + mapping: Int[Tensor, "nf nall"] | None = None, + fparam: Float[Tensor, "nf ndf"] | None = None, + aparam: Float[Tensor, "nf nall nda"] | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """ + Lower-level public forward using the DeePMD lower-interface contract. + + Parameters + ---------- + extended_coord + Extended coordinates with shape (nf, nall*3) or (nf, nall, 3) in Å. + extended_atype + Extended atom types with shape (nf, nall). + nlist + Neighbor list with shape (nf, nloc, nsel). + mapping + Mapping indices with shape (nf, nall), or None. + fparam + Frame parameters with shape (nf, ndf) or None. + aparam + Atomic parameters with shape (nf, nall, nda) or None. + do_atomic_virial + Whether to compute atomic virial. + comm_dict + Communication dict forwarded to `forward_common_lower()`. + charge_spin + Frame-level charge and spin conditions with shape `(nf, 2)`. + + Returns + ------- + dict[str, torch.Tensor] + Lower-interface outputs. + When a fitting net is present, this always includes: + - `atom_energy`: atomic energy on local atoms with shape (nf, nloc, 1) + - `energy`: reduced energy with shape (nf, 1) + It additionally includes: + - `extended_force`: force on extended coordinates with shape (nf, nall, 3) + when `self.do_grad_r("energy")` is true + - `dforce`: fitting-net direct force output when energy is not coordinate differentiable + - `virial`: reduced virial with shape (nf, 9) when `self.do_grad_c("energy")` is true + - `extended_virial`: per-extended-atom virial with shape (nf, nall, 9) + only when both `self.do_grad_c("energy")` and `do_atomic_virial` are true + If no fitting net is present, the raw result of `forward_common_lower()` is returned. + """ + if self.get_active_mode() == "dens": + raise NotImplementedError( + "SeZM `forward_lower` only supports the conservative `ener` mode." + ) + cc_ext, _, fp, ap, input_prec = self._input_type_cast( + extended_coord, fparam=fparam, aparam=aparam + ) + model_ret = self.forward_common_lower( + cc_ext, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + charge_spin=charge_spin, + ) + model_ret = self._output_type_cast(model_ret, input_prec) + if self.get_fitting_net() is not None: + model_predict: dict[str, torch.Tensor] = {} + + # === Step 1. Energy === + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + + # === Step 2. Force (independent branch) === + if self.do_grad_r("energy"): + model_predict["extended_force"] = rearrange( + model_ret["energy_derv_r"], + "nf nall 1 three -> nf nall three", + three=3, + ) + else: + assert model_ret["dforce"] is not None + model_predict["dforce"] = model_ret["dforce"] + + # === Step 3. Virial === + if self.do_grad_c("energy"): + model_predict["virial"] = rearrange( + model_ret["energy_derv_c_redu"], "nf 1 nine -> nf nine", nine=9 + ) + if do_atomic_virial: + model_predict["extended_virial"] = rearrange( + model_ret["energy_derv_c"], + "nf nall 1 nine -> nf nall nine", + nine=9, + ) + else: + model_predict = model_ret + return model_predict + + def forward_common_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + extra_nlist_sort: bool = False, + extended_coord_corr: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Public lower interface with dtype casting around ``core_compute()``.""" + cc_ext, _, fp, ap, input_prec = self._input_type_cast( + extended_coord, fparam=fparam, aparam=aparam + ) + cc_ext = cc_ext.reshape(extended_atype.shape[0], -1, 3) + if extended_coord_corr is not None and extended_coord_corr.ndim == 2: + extended_coord_corr = extended_coord_corr.reshape( + extended_atype.shape[0], -1, 3 + ) + if self.do_grad_r() or self.do_grad_c(): + cc_ext = cc_ext.detach().requires_grad_(True) + nf = extended_atype.shape[0] + charge_spin = self.convert_charge_spin( + charge_spin, + nf=nf, + dtype=cc_ext.dtype, + device=cc_ext.device, + ) + model_predict = self.core_compute( + cc_ext, + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=extra_nlist_sort, + extended_coord_corr=extended_coord_corr, + ) + return self._output_type_cast(model_predict, input_prec) + + # ========================================================================= + # Compile Utilities + # ========================================================================= + + def trace_and_compile( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + fp: torch.Tensor, + ap: torch.Tensor, + charge_spin: torch.Tensor, + do_atomic_virial: bool, + extended_coord_corr: torch.Tensor | None = None, + ) -> None: + """Trace ``core_compute()`` with ``make_fx`` and cache the compiled callable. + + The full flow is: wrap ``core_compute`` in a tensor-only + ``compute_fn`` that also owns the coordinate grad-endpoint, trace + it with ``make_fx(tracing_mode="symbolic")`` so all shape axes + become sympy symbols, strip autograd-inserted detach chains in + training mode, rebuild the FX graph to flush stale linked-list + pointers, and finally hand the clean ``GraphModule`` to + ``torch.compile(backend="inductor", dynamic=True)``. The + compiled callable is stored outside the ``nn.Module`` tree so + FSDP/DDP cannot see or shard its duplicated parameters. + """ + from torch._decomp import ( + get_decompositions, + ) + + mode = "train" if self.training else "eval" + has_coord_corr = extended_coord_corr is not None + log.info( + "SeZM: start tracing and compiling " + "(mode=%s, atomic_virial=%s, coord_corr=%s)", + mode, + do_atomic_virial, + has_coord_corr, + ) + _compile_t0 = time.perf_counter() + + need_coord_grad = self.do_grad_r() or self.do_grad_c() + + def _prepare_coord_for_trace(coord: torch.Tensor) -> torch.Tensor: + """Restart the coordinate autograd graph for the traced compute. + + ``detach()`` severs any upstream graph carried by the trace + inputs and ``requires_grad_(True)`` reinstates a fresh + grad-endpoint owned by this compute. The inner + ``autograd.grad`` inside ``fit_output_to_model_output`` then + differentiates against a graph of known shape and ownership -- + the essential precondition for make_fx symbolic tracing to + capture dE/dx as ordinary FX nodes. In the eval-only branch + a bare detach keeps the traced graph free of backward sections. + """ + if need_coord_grad: + return coord.detach().requires_grad_(True) + else: + return coord.detach() + + if extended_coord_corr is None: + + def compute_fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + fp: torch.Tensor, + ap: torch.Tensor, + charge_spin: torch.Tensor, + ) -> dict[str, torch.Tensor]: + return self.core_compute( + _prepare_coord_for_trace(extended_coord), + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + else: + + def compute_fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + fp: torch.Tensor, + ap: torch.Tensor, + charge_spin: torch.Tensor, + extended_coord_corr: torch.Tensor, + ) -> dict[str, torch.Tensor]: + # NOTE: Spin virial uses a coordinate correction derived from the + # virtual-atom displacement. Keeping it as a tensor input lets the + # compiled graph stay reusable across frames. + return self.core_compute( + _prepare_coord_for_trace(extended_coord), + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + extended_coord_corr=extended_coord_corr, + ) + + # NOTE: Always trace with a fixed batch size that is free of known + # symbolic-shape collisions. + # + # make_fx(tracing_mode="symbolic") replaces shapes with sympy + # symbols, but the moment a symbolic dim ends up equal to a + # *concrete* dim elsewhere in the same tensor it collapses into + # a constant and the graph specialises on that batch size. Known + # reserved dimensions include 1 (specialisation), 2 (charge/spin + # width), 3 (Cartesian coordinates), and 9 (virial tensor). Any + # of those collisions forces + # ``torch.compile(dynamic=True)`` to reject later batches whose + # nf differs from the traced constant. + # + # If a future code change introduces a new explicit dimension of + # this size and compile starts failing with a similar shape + # mismatch, change this constant accordingly. + trace_nf = 5 + coord_for_trace = extended_coord[:1].repeat(trace_nf, 1, 1) + atype_for_trace = extended_atype[:1].repeat(trace_nf, 1) + nlist_for_trace = nlist[:1].repeat(trace_nf, 1, 1) + mapping_for_trace = mapping[:1].repeat(trace_nf, 1) + fp_for_trace = fp[:1].repeat(trace_nf, 1) + ap_for_trace = ap[:1].repeat(trace_nf, 1, 1) + charge_spin_for_trace = charge_spin[:1].repeat(trace_nf, 1) + trace_args = [ + coord_for_trace, + atype_for_trace, + nlist_for_trace, + mapping_for_trace, + fp_for_trace, + ap_for_trace, + charge_spin_for_trace, + ] + if extended_coord_corr is not None: + trace_args.append(extended_coord_corr[:1].repeat(trace_nf, 1, 1)) + + # NOTE: Decompose ``silu_backward`` into primitive ops. + # PyTorch ships forward and first-order backward for SiLU but no + # symbolic higher-order derivative. Without this decomposition + # make_fx would emit ``aten.silu_backward.default`` opaquely + # inside the first-derivative graph; when Inductor later has to + # differentiate that op again for the optimiser step, it refuses + # because silu_backward is not differentiable in its registered + # form. Lowering to ``sigmoid + pointwise mul + ...`` gives + # every pointwise piece a well-defined higher derivative. + decomp_table = get_decompositions([torch.ops.aten.silu_backward.default]) + + # NOTE: ``tracing_mode="symbolic"`` makes every shape a sympy + # symbol so the compiled graph can later accept any + # (nframes, nall, n_edges, ...) at runtime. + # ``_allow_non_fake_inputs=True`` lets us feed real tensors to + # the trace -- the edge compactor contains data-dependent ops + # (``torch.nonzero``, ``index_select``) that cannot execute on + # FakeTensors, so we need concrete values to resolve their + # control flow exactly once; shapes become symbolic immediately + # afterwards. + traced = make_fx( + compute_fn, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + decomposition_table=decomp_table, + )(*trace_args) + + # NOTE: Only strip autograd-inserted detach chains in training + # mode. With ``create_graph=True`` make_fx wraps every saved + # forward activation in a + # ``fwd_op -> detach_A -> detach_B -> bwd_op`` chain. Those + # detaches are informational in eager autograd but become real + # ops after tracing and sever the gradient path from the force + # loss back to theta -- training would silently emit zero + # parameter updates for the second-derivative term. In eval + # mode ``create_graph=False`` so the chain is never inserted + # and stripping would be wrong. + if self.training: + _strip_saved_tensor_detach(traced) + + # NOTE: Rebuild the FX graph from scratch. + # ``Graph.erase_node`` inside ``_strip_saved_tensor_detach`` + # unlinks nodes from the doubly linked list but on some PyTorch + # builds (observed on 2.11+cu130) leaves stale C-level + # prev/next pointers on neighbouring Node objects. Dynamo later + # re-traces the ``GraphModule`` and walks ``graph.nodes`` inside + # ``output_graph.py:_create_proxy`` to read ``nd.meta``; + # dereferencing one of those stale pointers segfaults the + # process. A single ``node_copy`` pass into a freshly allocated + # ``torch.fx.Graph`` builds an equivalent graph with a clean + # linked list. We always rebuild -- even in eval -- because a + # fresh graph is cheap and a segfault is fatal. + traced = _rebuild_graph_module(traced) + + # NOTE: Inductor options are mode-dependent. Training has been + # running cleanly with ``max_fusion_size=64`` for a while, so we + # keep that path untouched to avoid destabilising it. Inference + # (``self.training is False``) has shown a Triton + # ``make_ttgir`` / ``PassManager::run failed`` on the fused + # per-reduction kernel + # ``triton_per_fused_clone_exp_mul_pow_select_slice_sum_tanh_...``; + # the kernel itself is fine, but the *fused* IR is too big / + # too complex for Triton's lowering pipeline on this version. + # So inference: + # * disables ``triton.persistent_reductions`` -- persistent + # reduction is what lets Inductor pull a ``sum`` together + # with all surrounding pointwise ops (including the + # activation-backward pointwise chain) into one + # ``per_fused_...`` kernel; turning it off forces the sum + # to emit its own kernel and stops the pathological fuse. + # * tightens ``max_fusion_size`` from 64 to 8, so even + # non-persistent fusions stay small enough for Triton IR + # generation to succeed. + # Training does not hit this path in practice (different graph + # topology under ``create_graph=True``), so we keep the looser + # options there to preserve kernel quality. + compile_options: dict[str, Any] = { + "max_autotune": False, + "shape_padding": True, + "epilogue_fusion": False, + "triton.cudagraphs": False, + # NOTE: ``mix_order_reduction`` hits multiple bugs under + # data-dependent symbolic shapes on PyTorch <=2.11 + # (pytorch/pytorch#174379, #178080, #179494) -- our edge + # count is exactly that kind of shape. + "triton.mix_order_reduction": False, + } + if self.training: + compile_options["max_fusion_size"] = 64 + else: + compile_options["max_fusion_size"] = 8 + compile_options["triton.persistent_reductions"] = False + try: + from torch._inductor import config as inductor_config + + valid_options = inductor_config.get_config_copy() + compile_options = { + key: value + for key, value in compile_options.items() + if key.replace("-", "_") in valid_options + } + except Exception: + # Older/future PyTorch builds may not expose the config registry. + # In that case keep the curated option set and let torch.compile + # surface any real backend error. + pass + + # NOTE: Store the compiled callable inside the plain-``dict`` + # cache ``compiled_core_compute_cache``. The dict itself was installed + # via ``object.__setattr__`` at __init__ time so that + # ``nn.Module.__setattr__`` never saw any of this; mutating the + # dict in place afterwards keeps the compile wrappers hidden + # from parameter discovery (FSDP2/DDP would otherwise shard or + # synchronise the wrapper's duplicated flat parameter views and + # silently corrupt training). The cache is keyed on + # ``(training, do_atomic_virial, has_coord_corr)`` so that distinct + # graph topologies coexist without evicting each other on every + # ``model.eval()`` / ``model.train()`` switch. + cache_key = (bool(self.training), bool(do_atomic_virial), has_coord_corr) + # NOTE: ``dynamic=True`` emits a single kernel per traced + # shape symbol, so changes in ``nframes``, ``nall`` or edge + # count do not trigger recompiles; and the option dict above + # disables every Inductor/Triton feature that has ever + # interacted badly with ``make_fx`` + double backward in + # this project. + self.compiled_core_compute_cache[cache_key] = torch.compile( + traced, + backend="inductor", + dynamic=True, + options=compile_options, + ) + # torch.compile is lazy; the "finished" log is emitted after the + # first call triggers Inductor lowering (see forward_common). + # ``pending_key`` pairs with ``pending_t0`` so the log is only + # printed once, by the forward that actually triggers lowering + # for *this* cache slot -- other slots may still be pending. + self._core_compute_pending_compile_t0 = _compile_t0 + self._core_compute_pending_compile_key = cache_key + + def compile_dens(self) -> None: + """Compile the direct-force `dens` path.""" + from torch._inductor import config as inductor_config + + log.info("SeZM: start compiling dens path") + _compile_t0 = time.perf_counter() + + inductor_config.max_autotune_report_choices_stats = False + inductor_config.autotune_num_choices_displayed = 0 + + object.__setattr__( + self, + "compiled_dens_compute", + torch.compile( + self.core_compute_dens, + backend="inductor", + dynamic=True, + options={ + "max_autotune": False, + "epilogue_fusion": False, + "triton.cudagraphs": False, + "shape_padding": True, + "max_fusion_size": 64, + }, + ), + ) + self._dens_compiled = True + # torch.compile is lazy; the "finished" log is emitted after the + # first call triggers Inductor lowering (see forward_common). + self._dens_pending_compile_t0 = _compile_t0 + + def should_use_compile(self) -> bool: + """Return whether the current forward should use the compile path.""" + if self.training: + return self.use_compile + return bool(self._env_use_compile_infer) + + # ========================================================================= + # Export Utilities + # ========================================================================= + + def _trace_lower_exportable( + self, + fn: Any, + *sample_inputs: torch.Tensor | None, + ) -> torch.nn.Module: + """Trace a lower-interface closure into an exportable FX graph.""" + from torch._decomp import ( + get_decompositions, + ) + + return make_fx( + fn, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + decomposition_table=get_decompositions( + [torch.ops.aten.silu_backward.default] + ), + )(*sample_inputs) + + def forward_common_lower_exportable( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + charge_spin: torch.Tensor | None = None, + ) -> torch.nn.Module: + """Trace ``forward_common_lower`` into an exportable FX ``GraphModule``. + + ``make_fx`` unfolds the inner ``autograd.grad`` that + ``fit_output_to_model_output`` performs for force and virial, so + the returned module can be handed to :func:`torch.export.export` + directly. ``silu_backward`` is decomposed to primitive ops so + Inductor never sees an opaque higher-order derivative — the same + decomposition the training compile path uses. + + Only the conservative ``ener`` mode is supported: ``dens`` + emits a direct-force tensor that has no ``DeepPotPTExpt`` consumer. + """ + if self.get_active_mode() == "dens": + raise NotImplementedError( + "SeZM export supports only the conservative `ener` path." + ) + + model = self + extra_sort = self.need_sorted_nlist_for_lower() + + def lower_fn( + ext_coord: torch.Tensor, + ext_atype: torch.Tensor, + nlist_: torch.Tensor, + mapping_: torch.Tensor | None, + fparam_: torch.Tensor | None, + aparam_: torch.Tensor | None, + charge_spin_: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + # detach + requires_grad_ must live INSIDE the traced closure: + # LAMMPS feeds a plain fp64 non-leaf tensor, and the exported + # graph needs its own grad endpoint for the inner autograd.grad + # that fit_output_to_model_output performs. + ext_coord = ext_coord.detach().requires_grad_(True) + return model.forward_common_lower( + ext_coord, + ext_atype, + nlist_, + mapping_, + fparam=fparam_, + aparam=aparam_, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=extra_sort, + charge_spin=charge_spin_, + ) + + def fn( + ext_coord: torch.Tensor, + ext_atype: torch.Tensor, + nlist_: torch.Tensor, + mapping_: torch.Tensor | None, + fparam_: torch.Tensor | None, + aparam_: torch.Tensor | None, + *maybe_charge_spin: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + charge_spin_ = maybe_charge_spin[0] if maybe_charge_spin else None + return lower_fn( + ext_coord, + ext_atype, + nlist_, + mapping_, + fparam_, + aparam_, + charge_spin_, + ) + + trace_inputs = (extended_coord, extended_atype, nlist, mapping, fparam, aparam) + if self.get_dim_chg_spin() > 0: + charge_spin = self.convert_charge_spin( + charge_spin, + nf=extended_atype.shape[0], + dtype=extended_coord.dtype, + device=extended_coord.device, + ) + trace_inputs = (*trace_inputs, charge_spin) + + return self._trace_lower_exportable( + fn, + *trace_inputs, + ) + + # ========================================================================= + # Neighbor List Construction + # ========================================================================= + + def build_neighbor_list( + self, + coord: Float[Tensor, "nf nloc 3"] | Float[Tensor, "nf nloc_x3"], + atype: Int[Tensor, "nf nloc"], + box: Float[Tensor, "nf 9"] | None, + ) -> tuple[ + Float[Tensor, "nf nall 3"], + Int[Tensor, "nf nall"], + Int[Tensor, "nf nall"], + Int[Tensor, "nf nloc nsel"], + ]: + """ + Build extended inputs and neighbor list (traditional path). + + Parameters + ---------- + coord + Coordinates with shape (nf, nloc, 3) in Å. + atype + Atom types with shape (nf, nloc). + box + Box tensor with shape (nf, 9) in Å, or None. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + Extended coordinates, extended atom types, neighbor list, and mapping. + """ + return extend_input_and_build_neighbor_list( + coord, + atype, + self.get_rcut(), + self.get_sel(), + mixed_types=True, + box=box, + ) + + def build_edge_list_from_nlist( + self, + *, + extended_coord: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Build a compact edge list from DeePMD padded neighbor list. + + Edge vectors are computed via ``index_select`` on ``extended_coord`` + so they remain differentiable w.r.t. the input coordinates. One + masked dummy edge is always appended to avoid data-dependent empty-edge + branches that ``make_fx`` cannot trace. + + Parameters + ---------- + extended_coord + Extended coordinates with shape (nf, nall, 3). + nlist + DeePMD padded neighbor list with shape (nf, nloc, nsel). + mapping + Extended-to-local mapping with shape (nf, nall), or ``None``. + + Returns + ------- + edge_index + Edge indices with shape (2, E+1) where E is valid edge count. + edge_vec + Edge vectors with shape (E+1, 3). + edge_mask + Boolean mask with shape (E+1,). The trailing element is ``False``. + """ + nf, nloc, nsel = nlist.shape + n_actual = nf * nloc + device = extended_coord.device + nall = extended_coord.shape[1] + descriptor_model = self.atomic_model.descriptor + coord_for_diff = extended_coord.to(dtype=descriptor_model.compute_dtype) + + # === Step 1. Build per-edge geometry via index_select (differentiable) === + # NOTE: Edge vectors come from ``coord_flat.index_select(0, ...)`` + # rather than advanced indexing ``coord_flat[...]``. + # ``index_select`` has an explicit, well-defined backward that + # routes gradient cleanly back to the original extended + # coordinate tensor. Advanced indexing combined with make_fx + # symbolic shapes has previously produced silent gradient + # truncation in this project -- the second-derivative gradient + # over coordinates was effectively zero, with no error raised. + # ``torch.where(valid_flat, neighbor_flat, 0)`` sanitises padded + # ``-1`` entries before indexing so we never hit an out-of-range + # gather; the corresponding edges are filtered out below anyway. + dst_actual = torch.arange( + n_actual, device=device, dtype=torch.long + ).repeat_interleave(nsel) + f_idx = dst_actual // nloc + dst_local = dst_actual % nloc + neighbor_flat = nlist.reshape(-1) + valid_flat = neighbor_flat >= 0 + neighbor_safe = torch.where( + valid_flat, neighbor_flat, torch.zeros_like(neighbor_flat) + ) + coord_flat = coord_for_diff.flatten(0, 1) + dst_ext = f_idx * nall + dst_local + src_ext = f_idx * nall + neighbor_safe.to(dtype=torch.long) + diff = coord_flat.index_select(0, src_ext) - coord_flat.index_select(0, dst_ext) + edge_len2 = torch.sum(diff * diff, dim=-1) + + # === Step 2. Build compact src/dst (local indices) === + if mapping is None: + src_local = neighbor_safe.to(dtype=torch.long) + else: + mapping_flat = mapping.reshape(-1) + src_local = mapping_flat.index_select(0, f_idx * nall + neighbor_safe) + src_actual = f_idx * nloc + src_local.to(dtype=torch.long) + + # Filter: valid nlist entry AND src in [0, nloc) AND non-zero distance. + src_local_valid = (src_local >= 0) & (src_local < nloc) + len_positive = edge_len2 > 1e-10 + edge_mask_actual = valid_flat & src_local_valid & len_positive + + valid_idx = torch.nonzero(edge_mask_actual, as_tuple=False).flatten() + + # === Step 3. Compact edges + append one masked dummy === + # NOTE: Always append exactly one masked dummy edge. + # ``torch.nonzero(edge_mask_actual)`` produces a data-dependent + # number of valid edges, which can be zero on sparse or + # single-type systems. make_fx cannot trace an + # ``if n_edges == 0: skip`` branch symbolically; without the + # dummy it would fall back to concrete shape specialisation and + # break ``torch.compile(dynamic=True)`` for later batches. The + # dummy edge copies entry 0 (any in-range index is fine) and + # carries ``edge_mask=False`` so every downstream sum, gather + # or scatter ignores it. + padded_idx = torch.cat( + [valid_idx, torch.zeros(1, dtype=torch.long, device=device)] + ) + src_sel = src_actual.index_select(0, padded_idx) + dst_sel = dst_actual.index_select(0, padded_idx) + edge_vec_sel = diff.index_select(0, padded_idx) + edge_index = torch.stack([src_sel, dst_sel], dim=0) + edge_mask = torch.cat( + [ + torch.ones(valid_idx.shape[0], dtype=torch.bool, device=device), + torch.zeros(1, dtype=torch.bool, device=device), + ] + ) + return edge_index, edge_vec_sel, edge_mask + + # ========================================================================= + # Input Canonicalization + # ========================================================================= + + def convert_fp_ap( + self, + fp: torch.Tensor | None, + ap: torch.Tensor | None, + nf: int, + nloc: int, + dtype: torch.dtype, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Convert optional fitting inputs to tensor-only compile inputs.""" + dim_fparam = self.get_dim_fparam() + dim_aparam = self.get_dim_aparam() + + # === Step 1. Canonicalize frame parameters === + if dim_fparam == 0: + fp = torch.empty((nf, 0), dtype=dtype, device=device) + elif fp is None: + default_fparam = self.get_default_fparam() + if default_fparam is None: + raise ValueError( + "fparam is required because fitting net dim_fparam > 0" + ) + fp = default_fparam.to(device=device, dtype=dtype).view(1, dim_fparam) + fp = fp.expand(nf, -1) + else: + if fp.numel() != nf * dim_fparam: + raise ValueError( + f"input fparam: cannot reshape {list(fp.shape)} " + f"into ({nf}, {dim_fparam})." + ) + fp = fp.to(device=device, dtype=dtype).view(nf, dim_fparam) + + # === Step 2. Canonicalize atomic parameters === + if dim_aparam == 0: + ap = torch.empty((nf, nloc, 0), dtype=dtype, device=device) + elif ap is None: + if dim_aparam > 0: + raise ValueError( + "aparam is required because fitting net dim_aparam > 0" + ) + else: + if ap.numel() != nf * nloc * dim_aparam: + raise ValueError( + f"input aparam: cannot reshape {list(ap.shape)} " + f"into ({nf}, {nloc}, {dim_aparam})." + ) + ap = ap.to(device=device, dtype=dtype).view(nf, nloc, dim_aparam) + + return fp, ap + + def convert_charge_spin( + self, + charge_spin: torch.Tensor | None, + nf: int, + dtype: torch.dtype, + device: torch.device, + ) -> torch.Tensor: + """ + Canonicalize optional charge/spin conditions for internal compute paths. + + Parameters + ---------- + charge_spin + Optional frame-level charge and spin conditions. + nf + Number of frames. + dtype + Target floating-point dtype. + device + Target device. + + Returns + ------- + torch.Tensor + Tensor with shape `(nf, 2)` when enabled, otherwise `(nf, 0)`. + """ + dim_chg_spin = self.atomic_model.get_dim_chg_spin() + if dim_chg_spin == 0: + return torch.empty((nf, 0), dtype=dtype, device=device) + + if charge_spin is None: + default_chg_spin = self.atomic_model.get_default_chg_spin() + if default_chg_spin is None: + raise ValueError("charge_spin is required for this SeZM model") + charge_spin = default_chg_spin.to(device=device, dtype=dtype).view(1, 2) + else: + charge_spin = charge_spin.to(device=device, dtype=dtype) + + if charge_spin.ndim == 1: + if charge_spin.numel() != dim_chg_spin: + raise ValueError("charge_spin must contain [charge, spin]") + charge_spin = charge_spin.view(1, dim_chg_spin) + elif charge_spin.ndim != 2 or charge_spin.shape[-1] != dim_chg_spin: + raise ValueError("charge_spin must have shape (nf, 2)") + + if charge_spin.shape[0] == 1 and nf != 1: + charge_spin = charge_spin.expand(nf, -1) + elif charge_spin.shape[0] != nf: + raise ValueError("charge_spin first dimension must match nframes") + return charge_spin + + def canonicalize_dens_inputs( + self, + force_input: torch.Tensor | None, + noise_mask: torch.Tensor | None, + nf: int, + nloc: int, + dtype: torch.dtype, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Canonicalize optional public `dens` inputs to concrete tensors. + + Parameters + ---------- + force_input + Optional atom-wise force input tensor. + noise_mask + Optional atom-wise corruption mask. + nf + Number of frames. + nloc + Number of local atoms per frame. + dtype + Target floating-point dtype. + device + Target device. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + Canonicalized force tensor with shape `(nf, nloc, 3)` and mask with + shape `(nf, nloc)`. + + Notes + ----- + `force_input` and `noise_mask` remain optional only at the outer model + API. Internal `dens` compute functions always receive concrete tensors. + """ + if force_input is None: + force_input = torch.zeros((nf, nloc, 3), dtype=dtype, device=device) + else: + if force_input.ndim == 2: + force_input = force_input.view(nf, nloc, 3) + elif force_input.ndim != 3: + raise ValueError( + "`force_input` must have shape (nf, nloc, 3) or (nf, nloc*3)." + ) + force_input = force_input.to(device=device, dtype=dtype) + + if noise_mask is None: + noise_mask = torch.zeros((nf, nloc), dtype=torch.bool, device=device) + else: + if noise_mask.ndim != 2: + raise ValueError("`noise_mask` must have shape (nf, nloc).") + noise_mask = noise_mask.to(device=device, dtype=torch.bool) + + return force_input, noise_mask + + # ========================================================================= + # Output Post-Processing + # ========================================================================= + + def post_process_output_dens( + self, + compute_ret: torch.Tensor, + atype: torch.Tensor, + *, + noise_mask: torch.Tensor, + ) -> dict[str, torch.Tensor]: + """ + Convert the concatenated `dens` output to DeePMD model outputs. + + Parameters + ---------- + compute_ret + Concatenated tensor with shape `(nf, nloc, 7)` or `(1, n_node, 7)`. + atype + Local atom types with shape `(nf, nloc)`. + noise_mask + Corruption mask with shape `(nf, nloc)`. + + Returns + ------- + dict[str, torch.Tensor] + Standard DeePMD model predictions for `dens` mode. + """ + nf, nloc = atype.shape[:2] + n_actual = nf * nloc + dens_ret = { + "energy": compute_ret[:, :n_actual, 0:1].view(nf, nloc, 1), + "clean_dforce": compute_ret[:, :n_actual, 1:4].view(nf, nloc, 3), + "denoising_dforce": compute_ret[:, :n_actual, 4:7].view(nf, nloc, 3), + } + return self.atomic_model.apply_out_stat_dens( + dens_ret, + atype, + noise_mask=noise_mask, + energy_redu_dtype=self.redu_prec, + ) + + # ========================================================================= + # Charge/Spin Condition Metadata + # ========================================================================= + + def has_chg_spin_ebd(self) -> bool: + """Return whether charge/spin condition embedding is enabled.""" + return self.atomic_model.has_chg_spin_ebd() + + def get_dim_chg_spin(self) -> int: + """Return charge/spin condition width.""" + return self.atomic_model.get_dim_chg_spin() + + def has_default_chg_spin(self) -> bool: + """Return whether default charge/spin conditions are configured.""" + return self.atomic_model.has_default_chg_spin() + + def get_default_chg_spin(self) -> torch.Tensor | None: + """Return default charge/spin conditions as a tensor.""" + return self.atomic_model.get_default_chg_spin() + + # ========================================================================= + # Mode Management + # ========================================================================= + + def get_active_mode(self) -> str: + """Return the current SeZM execution mode.""" + return self.atomic_model.get_active_mode() + + def set_active_mode(self, mode: str) -> None: + """ + Switch the active SeZM execution mode. + + Parameters + ---------- + mode + Target mode. Must be `ener` or `dens`. + """ + self.atomic_model.set_active_mode(mode) + + def set_active_mode_from_loss(self, loss_type: str) -> None: + """ + Select the active SeZM path from `loss.type`. + + Parameters + ---------- + loss_type + Loss type name. + """ + normalized = str(loss_type).lower() + if normalized in {"ener", "dens"}: + self.set_active_mode(normalized) + + def reset_head_for_mode(self, mode: str) -> None: + """ + Reinitialize one SeZM fitting head and reset mode-specific compile state. + + Parameters + ---------- + mode + Target mode to reset. + """ + self.atomic_model.reset_head_for_mode(mode) + if mode == "dens": + self._dens_compiled = False + self._dens_pending_compile_t0 = None + object.__setattr__(self, "compiled_dens_compute", None) + else: + self._core_compute_pending_compile_t0 = None + self._core_compute_pending_compile_key = None + # Drop every compile slot so the next forward retraces against the + # reinitialised fitting head. + self.compiled_core_compute_cache.clear() + + # ========================================================================= + # Bridging Helpers + # ========================================================================= + + def _get_inter_potential_real_type_count(self) -> int: + """Return the real-type count used to mask analytical pair potentials.""" + return len(self.get_type_map()) + + # ========================================================================= + # Type and Output Metadata + # ========================================================================= + + def translated_output_def(self) -> dict[str, Any]: + """ + Translate model output definition to a dictionary format. + + Returns + ------- + dict[str, Any] + Dictionary mapping output names to their corresponding output definitions. + """ + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + } + if "dforce" in out_def_data: + output_def["force"] = out_def_data["dforce"] + elif self.do_grad_r("energy"): + output_def["force"] = out_def_data["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = out_def_data["energy_derv_c_redu"].squeeze(-2) + output_def["atom_virial"] = out_def_data["energy_derv_c"].squeeze(-2) + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + + return output_def + + def get_observed_type_list(self) -> list[str]: + """ + Get observed types (elements) of the model during data statistics. + + Returns + ------- + list[str] + A list of the observed types in this model. + """ + type_map = self.get_type_map() + out_bias = self.atomic_model.get_out_bias()[0] + + assert out_bias is not None, "No out_bias found in the model." + assert out_bias.dim() == 2, "The supported out_bias should be a 2D tensor." + assert out_bias.size(0) == len(type_map), ( + "The out_bias shape does not match the type_map length." + ) + bias_mask = ( + torch.gt(torch.abs(out_bias), 1e-6).any(dim=-1).detach().cpu() + ) # 1e-6 for stability + + # TorchScript does not support list comprehension with if clause + result: list[str] = [] + for t, m in zip(type_map, bias_mask.tolist()): + if m: + result.append(t) + return result + + # ========================================================================= + # Serialization + # ========================================================================= + + def serialize(self) -> dict[str, Any]: + """ + Serialize the SeZM model including model-level bridging state. + + Returns + ------- + dict[str, Any] + Serialized SeZM model data. + """ + return { + "@class": "Model", + "@version": 1, + "type": self.model_type, + "atomic_model": self.atomic_model.serialize(), + "bridging_method": self.bridging_method, + "bridging_r_inner": self.bridging_r_inner, + "bridging_r_outer": self.bridging_r_outer, + "lora": self.lora_config, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SeZMModel: + """ + Deserialize the SeZM model including model-level bridging state. + + Parameters + ---------- + data + Serialized SeZM model data. + + Returns + ------- + SeZMModel + Deserialized SeZM model. + """ + data = data.copy() + version = int(data.pop("@version", 1)) + check_version_compatibility(version, 1, 1) + data.pop("@class", None) + data.pop("type", None) + atomic_model = SeZMAtomicModel.deserialize(data.pop("atomic_model")) + return cls(atomic_model_=atomic_model, **data) + + # ========================================================================= + # Context Managers + # ========================================================================= + + @contextmanager + def tf32_precision_ctx(self) -> Generator[None, None, None]: + """Context manager to temporarily set TF32 matmul precision. + + TF32 is only enabled when the model is in training mode; during + inference we force ``highest`` precision because the reduced + mantissa of TF32 can introduce unacceptable errors in force + predictions and downstream MD trajectories. + """ + if not self.should_use_compile() or not torch.cuda.is_available(): + yield + return + prev_precision = torch.get_float32_matmul_precision() + try: + if self.enable_tf32 and self.training: + torch.set_float32_matmul_precision("high") + else: + torch.set_float32_matmul_precision("highest") + yield + finally: + torch.set_float32_matmul_precision(prev_precision) + + +# ============================================================================= +# InterPotential: analytical pair potentials for bridging +# ============================================================================= + +# fmt: off +ELEMENT_TO_Z: dict[str, int] = { + "H": 1, "He": 2, "Li": 3, "Be": 4, "B": 5, "C": 6, "N": 7, "O": 8, + "F": 9, "Ne": 10, "Na": 11, "Mg": 12, "Al": 13, "Si": 14, "P": 15, + "S": 16, "Cl": 17, "Ar": 18, "K": 19, "Ca": 20, "Sc": 21, "Ti": 22, + "V": 23, "Cr": 24, "Mn": 25, "Fe": 26, "Co": 27, "Ni": 28, "Cu": 29, + "Zn": 30, "Ga": 31, "Ge": 32, "As": 33, "Se": 34, "Br": 35, "Kr": 36, + "Rb": 37, "Sr": 38, "Y": 39, "Zr": 40, "Nb": 41, "Mo": 42, "Tc": 43, + "Ru": 44, "Rh": 45, "Pd": 46, "Ag": 47, "Cd": 48, "In": 49, "Sn": 50, + "Sb": 51, "Te": 52, "I": 53, "Xe": 54, "Cs": 55, "Ba": 56, "La": 57, + "Ce": 58, "Pr": 59, "Nd": 60, "Pm": 61, "Sm": 62, "Eu": 63, "Gd": 64, + "Tb": 65, "Dy": 66, "Ho": 67, "Er": 68, "Tm": 69, "Yb": 70, "Lu": 71, + "Hf": 72, "Ta": 73, "W": 74, "Re": 75, "Os": 76, "Ir": 77, "Pt": 78, + "Au": 79, "Hg": 80, "Tl": 81, "Pb": 82, "Bi": 83, "Po": 84, "At": 85, + "Rn": 86, "Fr": 87, "Ra": 88, "Ac": 89, "Th": 90, "Pa": 91, "U": 92, + "Np": 93, "Pu": 94, "Am": 95, "Cm": 96, "Bk": 97, "Cf": 98, "Es": 99, + "Fm": 100, "Md": 101, "No": 102, "Lr": 103, "Rf": 104, "Db": 105, + "Sg": 106, "Bh": 107, "Hs": 108, "Mt": 109, "Ds": 110, "Rg": 111, + "Cn": 112, "Nh": 113, "Fl": 114, "Mc": 115, "Lv": 116, "Ts": 117, + "Og": 118, +} +# fmt: on + +# ZBL screening function coefficients +_ZBL_A_COEFF = (0.18175, 0.50986, 0.28022, 0.028171) +_ZBL_B_COEFF = (3.1998, 0.94229, 0.4029, 0.20162) + +# Physical constants +_KE_EV_A = 14.3996 # Coulomb constant in eV·Å +_A_BOHR = 0.5291772109 # Bohr radius in Å + + +class InterPotential(torch.nn.Module): + """ + Analytical pair potential module for Zone bridging. + + Supports the Ziegler-Biersack-Littmark (ZBL) screened nuclear repulsion + potential. Designed to be extensible to other analytical forms (LJ, Morse, + etc.) through the ``mode`` parameter. + + Each pair (i, j) contributes ``V_ZBL(r_ij) / 2`` to both atom i and atom j, + avoiding double-counting from the symmetric neighbor list. + + Parameters + ---------- + type_map : list[str] + Element symbols (e.g. ``["O", "H"]``). Index in this list corresponds + to the ``atype`` integer values. + mode : str + Potential formula. Currently only ``"zbl"`` is supported. + + Raises + ------ + ValueError + If ``mode`` is not recognized, or if any element in ``type_map`` is + not found in the periodic table. + """ + + def __init__(self, type_map: list[str], mode: str = "zbl") -> None: + super().__init__() + mode = mode.upper() + if mode != "ZBL": + raise ValueError(f"Unknown InterPotential mode: {mode}") + self.mode = mode + + atomic_numbers = [] + for elem in type_map: + z = ELEMENT_TO_Z.get(elem) + if z is None: + raise ValueError(f"Unknown element symbol: {elem}") + atomic_numbers.append(z) + self.register_buffer( + "atomic_numbers", + torch.tensor(atomic_numbers, dtype=torch.float64, device=env.DEVICE), + ) + + def _zbl_pair_energy( + self, + r: torch.Tensor, + zi: torch.Tensor, + zj: torch.Tensor, + ) -> torch.Tensor: + """ + Compute ZBL pair energy for given distances and nuclear charges. + + Parameters + ---------- + r : torch.Tensor + Pair distances with shape (...) in Å. + zi : torch.Tensor + Nuclear charge of atom i with shape (...). + zj : torch.Tensor + Nuclear charge of atom j with shape (...). + + Returns + ------- + torch.Tensor + Pair energies with shape (...) in eV. + """ + a_screen = 0.88534 * _A_BOHR / (zi.pow(0.23) + zj.pow(0.23)) + x = r / a_screen + phi = sum(a * torch.exp(-b * x) for a, b in zip(_ZBL_A_COEFF, _ZBL_B_COEFF)) + return _KE_EV_A * zi * zj / r * phi + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + nloc: int, + real_type_count: int | None = None, + ) -> torch.Tensor: + """ + Compute per-atom pair energy from the standard neighbor list path. + + Parameters + ---------- + extended_coord + Coordinates in extended region with shape (nf, nall, 3) in Å. + extended_atype + Atom types in extended region with shape (nf, nall). + nlist + Neighbor list with shape (nf, nloc, nsel). + nloc : int + Number of local atoms. + real_type_count + Number of real atom types. Types with index greater than or equal to + this value are virtual spin types and are masked out of the + analytical potential. If omitted, all configured types are real. + + Returns + ------- + torch.Tensor + Per-atom pair energy with shape (nf, nloc, 1) in eV. + """ + if real_type_count is None: + real_type_count = int(self.atomic_numbers.numel()) + nf = extended_coord.shape[0] + coord64 = extended_coord.to(dtype=torch.float64) + atype_for_z = extended_atype.clamp(min=0) + atype_for_z = torch.where( + atype_for_z >= real_type_count, + atype_for_z - real_type_count, + atype_for_z, + ) + z_all = self.atomic_numbers[atype_for_z] # (nf, nall) + + # === Step 1. Gather neighbor coordinates and types === + nsel = nlist.shape[2] + nlist_clamp = nlist.clamp(min=0) # (nf, nloc, nsel) + nei_coord = torch.gather( + coord64, 1, nlist_clamp.unsqueeze(-1).expand(-1, -1, -1, 3).view(nf, -1, 3) + ).view(nf, nloc, nsel, 3) + atom_coord = coord64[:, :nloc].unsqueeze(2) # (nf, nloc, 1, 3) + diff = nei_coord - atom_coord # (nf, nloc, nsel, 3) + r = diff.norm(dim=-1).clamp(min=1e-10) # (nf, nloc, nsel) + + zi = z_all[:, :nloc].unsqueeze(2).expand_as(r) # (nf, nloc, nsel) + zj_idx = nlist_clamp + zj = torch.gather(z_all, 1, zj_idx.view(nf, -1)).view(nf, nloc, nsel) + + # === Step 2. Compute pair energies === + pair_e = self._zbl_pair_energy(r, zi, zj) # (nf, nloc, nsel) + + # Mask padding entries (nlist == -1) + valid = (nlist >= 0).to(dtype=pair_e.dtype) + center_is_real = (extended_atype[:, :nloc] < real_type_count).unsqueeze(2) + neighbor_atype = torch.gather(extended_atype, 1, nlist_clamp.view(nf, -1)).view( + nf, nloc, nsel + ) + neighbor_is_real = neighbor_atype < real_type_count + valid = valid * (center_is_real & neighbor_is_real).to(dtype=pair_e.dtype) + pair_e = pair_e * valid + + # Half contribution to avoid double-counting + atom_pair_energy = (pair_e * 0.5).sum(dim=-1, keepdim=True) # (nf, nloc, 1) + return atom_pair_energy.to(dtype=extended_coord.dtype) + + def forward_from_edges( + self, + edge_vec: torch.Tensor, + edge_index: torch.Tensor, + atype_flat: torch.Tensor, + edge_mask: torch.Tensor, + n_node: int, + ) -> torch.Tensor: + """ + Compute per-atom pair energy from the compile-path edge list. + + Parameters + ---------- + edge_vec + Edge vectors with shape (E, 3) in Å. + edge_index + Edge source/destination indices with shape (2, E). + atype_flat + Flat atom types with shape (N,). + edge_mask + Boolean mask with shape (E,). True means valid edge. + n_node : int + Number of flattened local nodes. + + Returns + ------- + torch.Tensor + Per-atom pair energy with shape (1, N, 1) in eV. + """ + src = edge_index[0].to(dtype=torch.long) + dst = edge_index[1].to(dtype=torch.long) + + r = edge_vec.to(dtype=torch.float64).norm(dim=-1).clamp(min=1e-10) # (E,) + z_all = self.atomic_numbers[atype_flat.clamp(min=0)] # (N,) + zi = z_all[src] # (E,) + zj = z_all[dst] # (E,) + + pair_e = self._zbl_pair_energy(r, zi, zj) # (E,) + pair_e = pair_e * edge_mask.to(dtype=pair_e.dtype) + + # Half contribution to each destination atom + atom_energy = torch.zeros(n_node, dtype=pair_e.dtype, device=pair_e.device) + atom_energy.index_add_(0, dst, pair_e * 0.5) + + return atom_energy.to(dtype=edge_vec.dtype).view(1, n_node, 1) diff --git a/deepmd/pt/model/model/sezm_spin_model.py b/deepmd/pt/model/model/sezm_spin_model.py new file mode 100644 index 0000000000..472a0581cf --- /dev/null +++ b/deepmd/pt/model/model/sezm_spin_model.py @@ -0,0 +1,695 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Spin-enabled SeZM energy model.""" + +import functools +from collections.abc import ( + Callable, +) +from copy import ( + deepcopy, +) +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel import ( + ModelOutputDef, +) +from deepmd.pt.model.atomic_model.sezm_atomic_model import ( + SeZMAtomicModel, +) +from deepmd.pt.model.descriptor.sezm_nn import ( + nvtx_range, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) +from deepmd.pt.model.model.sezm_model import ( + InterPotential, + SeZMModel, +) +from deepmd.pt.model.model.spin_model import ( + SpinModel, + _lookup_type_values, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.pt.utils.utils import ( + to_torch_tensor, +) +from deepmd.utils.path import ( + DPPath, +) +from deepmd.utils.spin import ( + Spin, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +@BaseModel.register("sezm_spin") +class SeZMSpinModel(SeZMModel): + """SeZM energy model with virtual spin atoms. + + Parameters + ---------- + spin + Spin metadata describing magnetic real types and virtual displacement + scales. + *args + Positional arguments forwarded to :class:`SeZMModel`. + **kwargs + Keyword arguments forwarded to :class:`SeZMModel`. + """ + + model_type = "sezm_spin" + + def __init__( + self, + *args: Any, + spin: Spin, + real_sel: list[int], + **kwargs: Any, + ) -> None: + # Delay InterPotential construction until ntypes_real is available. + bridging_method = str(kwargs.pop("bridging_method", "none")).upper() + kwargs["bridging_method"] = "none" + + super().__init__(*args, **kwargs) + self.spin = spin + self.ntypes_real = self.spin.ntypes_real + self.real_sel = [int(sel) for sel in real_sel] + self.register_buffer( + "virtual_scale_mask", + to_torch_tensor(self.spin.get_virtual_scale_mask()), + persistent=False, + ) + self.register_buffer( + "spin_mask", + to_torch_tensor(self.spin.get_spin_mask()), + persistent=False, + ) + + self.bridging_method = bridging_method + self.inter_potential = ( + InterPotential(type_map=self.get_type_map(), mode=self.bridging_method) + if self.bridging_method != "NONE" + else None + ) + + # ========================================================================= + # Forward Methods + # ========================================================================= + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + spin: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Return spin-aware SeZM predictions with public output keys.""" + model_ret = self.forward_common( + coord, + atype, + spin, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + charge_spin=charge_spin, + ) + model_predict: dict[str, torch.Tensor] = { + "atom_energy": model_ret["energy"], + "energy": model_ret["energy_redu"], + "mask_mag": model_ret["mask_mag"], + } + if self.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) + return model_predict + + def forward_common( + self, + coord: torch.Tensor, + atype: torch.Tensor, + spin: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Return spin-aware SeZM predictions with internal output keys.""" + with nvtx_range("SeZMSpin/forward_common"): + cc, bb, fp, ap, input_prec = self._input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + del coord, box, fparam, aparam + nf, nloc = atype.shape[:2] + if cc.ndim == 2: + cc = cc.view(nf, nloc, 3) + spin = spin.to(dtype=cc.dtype, device=cc.device).reshape(nf, nloc, 3) + + extended_coord, extended_atype, mapping, nlist = self.build_neighbor_list( + cc, atype, bb + ) + extended_spin = torch.gather( + spin, + 1, + mapping.unsqueeze(-1).expand(-1, -1, 3), + ) + ( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping_updated, + extended_coord_corr, + ) = self.process_spin_input_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping=mapping, + ) + if ap is not None: + ap = self.expand_aparam(ap, nloc * 2) + model_ret = self.forward_common_after_nlist( + extended_coord_updated, + extended_atype_updated, + mapping_updated, + nlist_updated, + extended_atype_updated[:, : nloc * 2], + fp, + ap, + input_prec, + do_atomic_virial=do_atomic_virial, + extended_coord_corr=extended_coord_corr, + charge_spin=charge_spin, + ) + return self._split_spin_common_output(model_ret, atype, nloc) + + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Return spin-aware SeZM lower-interface predictions.""" + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + charge_spin=charge_spin, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + model_predict: dict[str, torch.Tensor] = { + "atom_energy": model_ret["energy"], + "energy": model_ret["energy_redu"], + "extended_mask_mag": model_ret["mask_mag"], + } + if self.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["extended_force_mag"] = model_ret[ + "energy_derv_r_mag" + ].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -2 + ) + return model_predict + + def forward_common_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + extra_nlist_sort: bool = False, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Return spin-aware lower-interface predictions with internal keys.""" + _, nloc = nlist.shape[:2] + ( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping_updated, + extended_coord_corr, + ) = self.process_spin_input_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping=mapping, + ) + if aparam is not None: + aparam = self.expand_aparam(aparam, nloc * 2) + model_ret = super().forward_common_lower( + extended_coord_updated, + extended_atype_updated, + nlist_updated, + mapping=mapping_updated, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=extra_nlist_sort, + extended_coord_corr=extended_coord_corr, + charge_spin=charge_spin, + ) + return self._split_spin_lower_output(model_ret, extended_atype, nloc) + + def forward_common_lower_exportable( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + charge_spin: torch.Tensor | None = None, + ) -> torch.nn.Module: + """Trace the spin lower interface into an exportable FX graph.""" + extra_sort = self.need_sorted_nlist_for_lower() + + def lower_fn( + ext_coord: torch.Tensor, + ext_atype: torch.Tensor, + ext_spin: torch.Tensor, + nlist_: torch.Tensor, + mapping_: torch.Tensor | None, + fparam_: torch.Tensor | None, + aparam_: torch.Tensor | None, + charge_spin_: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + ext_coord = ext_coord.detach().requires_grad_(True) + return self.forward_common_lower( + ext_coord, + ext_atype, + ext_spin, + nlist_, + mapping_, + fparam=fparam_, + aparam=aparam_, + do_atomic_virial=do_atomic_virial, + extra_nlist_sort=extra_sort, + charge_spin=charge_spin_, + ) + + def fn( + ext_coord: torch.Tensor, + ext_atype: torch.Tensor, + ext_spin: torch.Tensor, + nlist_: torch.Tensor, + mapping_: torch.Tensor | None, + fparam_: torch.Tensor | None, + aparam_: torch.Tensor | None, + *maybe_charge_spin: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + charge_spin_ = maybe_charge_spin[0] if maybe_charge_spin else None + return lower_fn( + ext_coord, + ext_atype, + ext_spin, + nlist_, + mapping_, + fparam_, + aparam_, + charge_spin_, + ) + + trace_inputs = ( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam, + aparam, + ) + if self.get_dim_chg_spin() > 0: + charge_spin = self.convert_charge_spin( + charge_spin, + nf=extended_atype.shape[0], + dtype=extended_coord.dtype, + device=extended_coord.device, + ) + trace_inputs = (*trace_inputs, charge_spin) + + return self._trace_lower_exportable( + fn, + *trace_inputs, + ) + + # ========================================================================= + # Statistics and Mode Methods + # ========================================================================= + + def compute_or_load_stat( + self, + sampled_func: Callable[[], list[dict[str, Any]]], + stat_file_path: DPPath | None = None, + preset_observed_type: list[str] | None = None, + ) -> None: + """Compute or load statistics with virtual spin atoms included.""" + super().compute_or_load_stat( + self._get_spin_sampled_func(sampled_func), + stat_file_path, + preset_observed_type=preset_observed_type, + ) + + def change_out_bias( + self, + merged: Callable[[], list[dict[str, Any]]] | list[dict[str, Any]], + bias_adjust_mode: str = "change-by-statistic", + ) -> None: + """Change output bias using spin-expanded sampled data.""" + spin_sampled_func = self._get_spin_sampled_func( + merged if callable(merged) else lambda: merged + ) + super().change_out_bias( + spin_sampled_func, + bias_adjust_mode=bias_adjust_mode, + ) + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat: Any = None + ) -> None: + """Change real type map and rebuild corresponding virtual spin types.""" + type_map_with_spin = type_map + [item + "_spin" for item in type_map] + super().change_type_map(type_map_with_spin, model_with_new_type_stat) + self.ntypes_real = len(type_map) + + def set_active_mode(self, mode: str) -> None: + """Switch mode, allowing only the conservative energy path.""" + normalized = str(mode).lower() + if normalized != "ener": + raise NotImplementedError("SeZM spin supports only the `ener` path.") + super().set_active_mode(normalized) + + def set_active_mode_from_loss(self, loss_type: str) -> None: + """Select execution mode from loss type.""" + normalized = str(loss_type).lower() + if normalized == "dens": + raise NotImplementedError("SeZM spin supports only the `ener` path.") + if normalized in {"ener", "ener_spin"}: + self.set_active_mode("ener") + + # ========================================================================= + # Output Definitions and Metadata + # ========================================================================= + + def has_spin(self) -> bool: + """Return whether this model consumes spin input.""" + return True + + def get_type_map(self) -> list[str]: + """Return the real atom type map.""" + return super().get_type_map()[: self.ntypes_real] + + def get_ntypes(self) -> int: + """Return the number of real atom types.""" + return len(self.get_type_map()) + + def get_sel(self) -> list[int]: + """Return the public real-atom neighbor selection.""" + return self.real_sel + + def get_nsel(self) -> int: + """Return the public real-atom total neighbor count.""" + return int(sum(self.real_sel)) + + def get_nnei(self) -> int: + """Return the public real-atom total neighbor count.""" + return int(sum(self.real_sel)) + + def get_observed_type_list(self) -> list[str]: + """Return observed real types according to the output bias.""" + type_map = self.get_type_map() + out_bias = self.atomic_model.get_out_bias()[0] + assert out_bias is not None, "No out_bias found in the model." + assert out_bias.dim() == 2, "The supported out_bias should be a 2D tensor." + assert out_bias.size(0) >= self.ntypes_real, ( + "The out_bias shape is smaller than the number of real types." + ) + bias_mask = ( + torch.gt(torch.abs(out_bias[: self.ntypes_real]), 1e-6).any(dim=-1).cpu() + ) + result: list[str] = [] + for t, m in zip(type_map, bias_mask.tolist()): + if m: + result.append(t) + return result + + def model_output_def(self) -> ModelOutputDef: + """Return the spin-aware model output definition.""" + var_name = self._get_output_var_name() + atomic_output_def = self.atomic_output_def() + atomic_output_def[var_name].magnetic = True + return ModelOutputDef(atomic_output_def) + + def translated_output_def(self) -> dict[str, Any]: + """Translate internal output definitions to public spin keys.""" + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + "mask_mag": out_def_data["mask_mag"], + } + if self.do_grad_r("energy"): + output_def["force"] = deepcopy(out_def_data["energy_derv_r"]) + output_def["force"].squeeze(-2) + output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"]) + output_def["force_mag"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) + output_def["atom_virial"].squeeze(-2) + return output_def + + # ========================================================================= + # Serialization + # ========================================================================= + + def serialize(self) -> dict[str, Any]: + """Serialize the SeZM spin model.""" + data = super().serialize() + data["type"] = self.model_type + data["spin"] = self.spin.serialize() + data["real_sel"] = self.real_sel + return data + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> "SeZMSpinModel": + """Deserialize a SeZM spin model.""" + data = data.copy() + version = int(data.pop("@version", 1)) + check_version_compatibility(version, 1, 1) + data.pop("@class", None) + data.pop("type", None) + spin = Spin.deserialize(data.pop("spin")) + real_sel = data.pop("real_sel") + atomic_model = SeZMAtomicModel.deserialize(data.pop("atomic_model")) + return cls(atomic_model_=atomic_model, spin=spin, real_sel=real_sel, **data) + + # ========================================================================= + # Small Utilities + # ========================================================================= + + def build_neighbor_list( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Build the real-atom neighbor list before spin expansion.""" + return extend_input_and_build_neighbor_list( + coord, + atype, + self.get_rcut(), + self.real_sel, + mixed_types=True, + box=box, + ) + + def format_nlist( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + extra_nlist_sort: bool = False, + ) -> torch.Tensor: + """Format spin-expanded nlist to the internal descriptor capacity.""" + del extended_atype + return self._format_nlist( + extended_coord, + nlist, + sum(self.atomic_model.get_sel()), + extra_nlist_sort=extra_nlist_sort, + ) + + def _get_inter_potential_real_type_count(self) -> int: + """Return the number of real types for real-only ZBL masking.""" + return self.ntypes_real + + def _get_output_var_name(self) -> str: + """Return the primary atomic output variable name.""" + return "energy" + + def _get_spin_sampled_func( + self, sampled_func: Callable[[], list[dict[str, Any]]] + ) -> Callable[[], list[dict[str, Any]]]: + """Wrap a data sampler so statistics see real and virtual atoms.""" + + @functools.lru_cache + def spin_sampled_func() -> list[dict[str, Any]]: + sampled = sampled_func() + spin_sampled = [] + for sys in sampled: + coord_updated, atype_updated, _ = self.process_spin_input( + sys["coord"], sys["atype"], sys["spin"] + ) + tmp_dict = { + "coord": coord_updated, + "atype": atype_updated, + } + if "aparam" in sys: + tmp_dict["aparam"] = self.expand_aparam( + sys["aparam"], atype_updated.shape[1] + ) + if "natoms" in sys: + natoms = sys["natoms"] + tmp_dict["natoms"] = torch.cat( + [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1 + ) + for item_key in sys.keys(): + if item_key not in [ + "coord", + "atype", + "spin", + "natoms", + "aparam", + ]: + tmp_dict[item_key] = sys[item_key] + spin_sampled.append(tmp_dict) + return spin_sampled + + return self.atomic_model._make_wrapped_sampler(spin_sampled_func) + + def _ensure_mask_mag( + self, + model_ret: dict[str, torch.Tensor], + atype: torch.Tensor, + ) -> None: + """Ensure the magnetic atom mask exists in ``model_ret``.""" + if "mask_mag" in model_ret: + return + nframes, nloc = atype.shape[:2] + atomic_mask = _lookup_type_values(self.virtual_scale_mask, atype).reshape( + [nframes, nloc, 1] + ) + model_ret["mask_mag"] = atomic_mask > 0.0 + + def _split_spin_common_output( + self, + model_ret: dict[str, torch.Tensor], + atype: torch.Tensor, + nloc: int, + ) -> dict[str, torch.Tensor]: + """Split full-interface SeZM outputs into real and magnetic parts.""" + var_name = self._get_output_var_name() + model_ret[var_name] = torch.split(model_ret[var_name], [nloc, nloc], dim=1)[0] + if self.do_grad_r(var_name) and model_ret.get(f"{var_name}_derv_r") is not None: + ( + model_ret[f"{var_name}_derv_r"], + model_ret[f"{var_name}_derv_r_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output(atype, model_ret[f"{var_name}_derv_r"]) + if self.do_grad_c(var_name) and model_ret.get(f"{var_name}_derv_c") is not None: + ( + model_ret[f"{var_name}_derv_c"], + model_ret[f"{var_name}_derv_c_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output( + atype, + model_ret[f"{var_name}_derv_c"], + add_mag=True, + virtual_scale=False, + ) + self._ensure_mask_mag(model_ret, atype) + return model_ret + + def _split_spin_lower_output( + self, + model_ret: dict[str, torch.Tensor], + extended_atype: torch.Tensor, + nloc: int, + ) -> dict[str, torch.Tensor]: + """Split lower-interface SeZM outputs into real and magnetic parts.""" + var_name = self._get_output_var_name() + model_ret[var_name] = torch.split(model_ret[var_name], [nloc, nloc], dim=1)[0] + if self.do_grad_r(var_name) and model_ret.get(f"{var_name}_derv_r") is not None: + ( + model_ret[f"{var_name}_derv_r"], + model_ret[f"{var_name}_derv_r_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output_lower( + extended_atype, model_ret[f"{var_name}_derv_r"], nloc + ) + if self.do_grad_c(var_name) and model_ret.get(f"{var_name}_derv_c") is not None: + ( + model_ret[f"{var_name}_derv_c"], + model_ret[f"{var_name}_derv_c_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output_lower( + extended_atype, + model_ret[f"{var_name}_derv_c"], + nloc, + add_mag=True, + virtual_scale=False, + ) + self._ensure_mask_mag(model_ret, extended_atype) + return model_ret + + process_spin_input = SpinModel.process_spin_input + process_spin_input_lower = SpinModel.process_spin_input_lower + process_spin_output = SpinModel.process_spin_output + process_spin_output_lower = SpinModel.process_spin_output_lower + extend_nlist = staticmethod(SpinModel.extend_nlist) + expand_aparam = staticmethod(SpinModel.expand_aparam) diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index 91c6e2ea71..fe5b9505db 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -36,6 +36,19 @@ ) +def _lookup_type_values(values: torch.Tensor, atype: torch.Tensor) -> torch.Tensor: + """ + Gather one scalar value per atom type. + + ``values[atype]`` is semantically equivalent, but AOTInductor may lower + that advanced-indexing form to a CUDA ``index.Tensor`` shim even for a CPU + ``.pt2`` package. ``index_select`` keeps the exported spin graph device + stable while preserving the same lookup semantics. + """ + flat_atype = atype.reshape(-1).to(dtype=torch.long) + return torch.index_select(values.to(atype.device), 0, flat_atype).view(atype.shape) + + class SpinModel(torch.nn.Module): """A spin model wrapper, with spin input preprocess and output split.""" @@ -70,9 +83,10 @@ def process_spin_input( spin = spin.reshape(nframes, nloc, 3) atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1) # spin_dist = s_i * \mu_i - spin_dist = spin * (self.virtual_scale_mask.to(atype.device))[atype].reshape( - [nframes, nloc, 1] - ) + spin_dist = spin * _lookup_type_values( + self.virtual_scale_mask, + atype, + ).reshape([nframes, nloc, 1]) virtual_coord = coord + spin_dist coord_spin = torch.concat([coord, virtual_coord], dim=-2) # for spin virial corr @@ -115,9 +129,10 @@ def process_spin_input_lower( """ nframes, nall = extended_coord.shape[:2] nloc = nlist.shape[1] - extended_spin_dist = extended_spin * ( - self.virtual_scale_mask.to(extended_atype.device) - )[extended_atype].reshape([nframes, nall, 1]) + extended_spin_dist = extended_spin * _lookup_type_values( + self.virtual_scale_mask, + extended_atype, + ).reshape([nframes, nall, 1]) virtual_extended_coord = extended_coord + extended_spin_dist virtual_extended_atype = extended_atype + self.ntypes_real extended_coord_updated = concat_switch_virtual( @@ -165,7 +180,9 @@ def process_spin_output( virtual_scale_mask = self.virtual_scale_mask.to(atype.device) else: virtual_scale_mask = self.spin_mask.to(atype.device) - atomic_mask = virtual_scale_mask[atype].reshape([nframes, nloc, 1]) + atomic_mask = _lookup_type_values(virtual_scale_mask, atype).reshape( + [nframes, nloc, 1] + ) out_real, out_mag = torch.split(out_tensor, [nloc, nloc], dim=1) if add_mag: out_real = out_real + out_mag @@ -198,7 +215,10 @@ def process_spin_output_lower( virtual_scale_mask = self.virtual_scale_mask.to(extended_atype.device) else: virtual_scale_mask = self.spin_mask.to(extended_atype.device) - atomic_mask = virtual_scale_mask[extended_atype].reshape([nframes, nall, 1]) + atomic_mask = _lookup_type_values( + virtual_scale_mask, + extended_atype, + ).reshape([nframes, nall, 1]) extended_out_real = torch.cat( [ extended_out_tensor[:, :nloc], diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 02f7611429..13ea438f4f 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -294,6 +294,77 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: FittingNet = make_fitting_network(EmbeddingNet, MLP, MLPLayer) +class GLULayer(nn.Module): + """ + A GLU block for MLPs: Linear -> split -> value * act(gate). + + Parameters + ---------- + num_in + Input dimension. + num_out + Output dimension. + activation_function + Activation function applied to the gate branch. + precision + Numerical precision. + bias + Whether to use bias in the linear layer. + seed + Random seed for weight initialization. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + num_in: int, + num_out: int, + activation_function: str, + precision: str, + seed: int | list[int] | None, + trainable: bool, + bias: bool = True, + ) -> None: + super().__init__() + self.num_in = int(num_in) + self.num_out = int(num_out) + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + + self.linear = MLPLayer( + num_in=self.num_in, + num_out=2 * self.num_out, + bias=bias, + use_timestep=False, + activation_function=None, + resnet=False, + precision=self.precision, + seed=seed, + trainable=trainable, + ) + self.activation = ActivationFn(self.activation_function) + + def forward(self, xx: torch.Tensor) -> torch.Tensor: + """ + Apply GLU transformation. + + Parameters + ---------- + xx + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor. + """ + yy = self.linear(xx) + val, gate = yy.chunk(2, dim=-1) + return val * self.activation(gate) + + class NetworkCollection(DPNetworkCollection, nn.Module): """PyTorch implementation of NetworkCollection.""" diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 37ffec2725..a1b1173a0c 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -24,6 +24,9 @@ from .property import ( PropertyFittingNet, ) +from .sezm_ener import ( + SeZMEnergyFittingNet, +) from .type_predict import ( TypePredictNet, ) @@ -38,5 +41,6 @@ "Fitting", "PolarFittingNet", "PropertyFittingNet", + "SeZMEnergyFittingNet", "TypePredictNet", ] diff --git a/deepmd/pt/model/task/sezm_ener.py b/deepmd/pt/model/task/sezm_ener.py new file mode 100644 index 0000000000..c6af12fb5f --- /dev/null +++ b/deepmd/pt/model/task/sezm_ener.py @@ -0,0 +1,786 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""SeZM GLU energy fitting networks.""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + Any, + ClassVar, +) + +import torch + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.mlp import ( + GLULayer, + MLPLayer, +) +from deepmd.pt.model.task.fitting import ( + Fitting, + GeneralFitting, +) +from deepmd.pt.model.task.invar_fitting import ( + InvarFitting, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + DEVICE, + PRECISION_DICT, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + + +class CaseFiLMConditioner(torch.nn.Module): + """ + Case-conditioned FiLM generator for SeZM fitting features. + + Parameters + ---------- + dim_case_embd + Case one-hot width. + dim_descrpt + Descriptor output width. + target_dims + Feature widths of all FiLM modulation targets. + activation_function + Activation used by the case MLP hidden layer. + precision + Numerical precision. + seed + Random seed. + trainable + Whether parameters are trainable. + """ + + def __init__( + self, + *, + dim_case_embd: int, + dim_descrpt: int, + target_dims: list[int], + activation_function: str, + precision: str, + seed: int | list[int] | None, + trainable: bool, + ) -> None: + super().__init__() + self.dim_case_embd = int(dim_case_embd) + self.dim_descrpt = int(dim_descrpt) + self.target_dims = [int(dim) for dim in target_dims] + self.activation_function = str(activation_function) + self.precision = str(precision) + self.prec = PRECISION_DICT[self.precision] + self.code_dim = 4 * self.dim_descrpt + hidden_dim = int(32 * math.ceil((4.0 * float(self.dim_case_embd)) / 32.0)) + + self.case_layer1 = MLPLayer( + self.dim_case_embd, + hidden_dim, + bias=False, + use_timestep=False, + activation_function=self.activation_function, + resnet=False, + precision=self.precision, + seed=child_seed(seed, 0), + trainable=trainable, + ) + self.case_layer2 = MLPLayer( + hidden_dim, + self.code_dim, + bias=False, + use_timestep=False, + activation_function=None, + resnet=False, + precision=self.precision, + seed=child_seed(seed, 1), + trainable=trainable, + ) + self.projectors = torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.zeros( + self.code_dim, + 2 * target_dim, + dtype=self.prec, + device=DEVICE, + ) + ) + for target_dim in self.target_dims + ] + ) + strength_init = math.log(0.01) + self.adam_case_film_scale_strength_log = torch.nn.Parameter( + torch.full( + (len(self.target_dims),), + strength_init, + dtype=self.prec, + device=DEVICE, + ) + ) + self.adam_case_film_shift_strength_log = torch.nn.Parameter( + torch.full( + (len(self.target_dims),), + strength_init, + dtype=self.prec, + device=DEVICE, + ) + ) + + for param in self.parameters(): + param.requires_grad = trainable + + def encode(self, case_embd: torch.Tensor) -> torch.Tensor: + """ + Encode a compact case one-hot vector. + + Parameters + ---------- + case_embd + Case one-hot vector with shape (K,) or (1, K). + + Returns + ------- + torch.Tensor + Case code with shape (1, 4*dim_descrpt). + """ + code = case_embd.reshape(1, self.dim_case_embd) + return self.case_layer2(self.case_layer1(code)) + + def apply( + self, + xx: torch.Tensor, + case_code: torch.Tensor, + target_idx: int, + ) -> torch.Tensor: + """ + Apply one FiLM target to a feature tensor. + + Parameters + ---------- + xx + Feature tensor with shape (..., target_dim). + case_code + Encoded case tensor with shape (1, 4*dim_descrpt). + target_idx + Index of the target to modulate. + + Returns + ------- + torch.Tensor + Modulated feature tensor with the same shape as ``xx``. + """ + film = torch.matmul(case_code, self.projectors[target_idx]) + gamma, beta = film.chunk(2, dim=-1) + view_shape = [1 for _ in range(xx.ndim - 1)] + [xx.shape[-1]] + gamma = gamma.reshape(view_shape) + beta = beta.reshape(view_shape) + scale_strength = torch.exp(self.adam_case_film_scale_strength_log[target_idx]) + shift_strength = torch.exp(self.adam_case_film_shift_strength_log[target_idx]) + return xx * (1.0 + scale_strength * torch.tanh(gamma)) + ( + shift_strength * torch.tanh(beta) + ) + + +class GLUFittingNet(torch.nn.Module): + """ + GLU-based fitting network for SeZM. + + Parameters + ---------- + in_dim + Input dimension. + out_dim + Output dimension. + neuron + Hidden layer sizes. Empty list means direct linear projection. + activation_function + Activation function used for GLU gating. + resnet_dt + Reserved for compatibility; not used in GLU layers. + precision + Numerical precision. + bias_out + Whether the output layer uses bias. + seed + Random seed. + trainable + Whether parameters are trainable. + descriptor_dim + Descriptor feature width. Used by case FiLM to avoid modulating + frame/atomic parameters. + dim_case_embd + Case one-hot width. + case_film_embd + Whether to use case FiLM instead of input concatenation. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + neuron: list[int] | None = None, + activation_function: str = "silu", + resnet_dt: bool = False, + precision: str = DEFAULT_PRECISION, + bias_out: bool = False, + seed: int | list[int] | None = None, + trainable: bool | list[bool] = True, + descriptor_dim: int | None = None, + dim_case_embd: int = 0, + case_film_embd: bool = False, + ) -> None: + super().__init__() + if neuron is None: + neuron = [] + if isinstance(trainable, list): + trainable = all(trainable) + self.in_dim = int(in_dim) + self.out_dim = int(out_dim) + self.neuron = [int(nn_dim) for nn_dim in neuron] + self.activation_function = activation_function + self.resnet_dt = bool(resnet_dt) + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.bias_out = bool(bias_out) + self.descriptor_dim = ( + self.in_dim if descriptor_dim is None else int(descriptor_dim) + ) + self.dim_case_embd = int(dim_case_embd) + self.case_film_embd = bool(case_film_embd and self.dim_case_embd > 0) + + # === Step 1. Build GLU hidden layers === + hidden_layers = [] + dim_in = self.in_dim + for layer_idx, hidden_dim in enumerate(self.neuron): + hidden_layers.append( + GLULayer( + dim_in, + hidden_dim, + activation_function=self.activation_function, + precision=self.precision, + seed=child_seed(seed, layer_idx), + trainable=trainable, + ) + ) + dim_in = hidden_dim + self.hidden_layers = torch.nn.ModuleList(hidden_layers) + + # === Step 2. Build optional case FiLM conditioner === + if self.case_film_embd: + self.case_film = CaseFiLMConditioner( + dim_case_embd=self.dim_case_embd, + dim_descrpt=self.descriptor_dim, + target_dims=[self.descriptor_dim, *self.neuron], + activation_function=self.activation_function, + precision=self.precision, + seed=child_seed(seed, len(self.neuron)), + trainable=trainable, + ) + else: + self.case_film = None + + # === Step 3. Build output projection === + self.output_layer = MLPLayer( + num_in=dim_in, + num_out=self.out_dim, + bias=self.bias_out, + use_timestep=False, + activation_function=None, + resnet=False, + precision=self.precision, + seed=child_seed(seed, len(self.neuron) + int(self.case_film_embd)), + trainable=trainable, + ) + + for param in self.parameters(): + param.requires_grad = trainable + + def _apply_input_film( + self, + xx: torch.Tensor, + case_code: torch.Tensor, + ) -> torch.Tensor: + """Apply FiLM only to the descriptor slice of the fitting input.""" + descrpt = self.case_film.apply(xx[..., : self.descriptor_dim], case_code, 0) + if self.descriptor_dim == self.in_dim: + return descrpt + return torch.cat([descrpt, xx[..., self.descriptor_dim :]], dim=-1) + + def forward( + self, + xx: torch.Tensor, + case_embd: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Forward pass for the GLU fitting net. + + Parameters + ---------- + xx + Input tensor. + case_embd + Optional compact case one-hot vector with shape (K,). + + Returns + ------- + torch.Tensor + Output tensor. + """ + if self.case_film_embd: + case_code = self.case_film.encode(case_embd) + xx = self._apply_input_film(xx, case_code) + for layer_idx, layer in enumerate(self.hidden_layers): + xx = layer(xx) + xx = self.case_film.apply(xx, case_code, layer_idx + 1) + else: + for layer in self.hidden_layers: + xx = layer(xx) + return self.output_layer(xx) + + def call_until_last( + self, + xx: torch.Tensor, + case_embd: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Return activations before the output projection. + + Parameters + ---------- + xx + Input tensor. + case_embd + Optional compact case one-hot vector with shape (K,). + + Returns + ------- + torch.Tensor + Hidden activations, or input if no hidden layers exist. + """ + if self.case_film_embd: + case_code = self.case_film.encode(case_embd) + xx = self._apply_input_film(xx, case_code) + for layer_idx, layer in enumerate(self.hidden_layers): + xx = layer(xx) + xx = self.case_film.apply(xx, case_code, layer_idx + 1) + return xx + for layer in self.hidden_layers: + xx = layer(xx) + return xx + + def serialize(self) -> dict[str, Any]: + """Serialize the network to a dict.""" + state = self.state_dict() + return { + "@class": "GLUFittingNet", + "@version": 1, + "in_dim": self.in_dim, + "out_dim": self.out_dim, + "neuron": self.neuron.copy(), + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "precision": self.precision, + "bias_out": self.bias_out, + "descriptor_dim": self.descriptor_dim, + "dim_case_embd": self.dim_case_embd, + "case_film_embd": self.case_film_embd, + "@variables": {key: to_numpy_array(value) for key, value in state.items()}, + } + + @classmethod + def deserialize(cls, data: dict) -> GLUFittingNet: + """Deserialize the network from a dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + variables = data.pop("@variables", {}) + obj = cls(**data) + state = {key: to_torch_tensor(value) for key, value in variables.items()} + obj.load_state_dict(state) + return obj + + +class SeZMNetworkCollection(torch.nn.Module): + """ + Network collection for SeZM fitting networks. + + Parameters + ---------- + ndim + The number of type dimensions. + ntypes + Number of atom types. + network_type + The network type name. Only "sezm_fitting_network" is supported. + networks + The networks to initialize with. + """ + + NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = { + "sezm_fitting_network": GLUFittingNet, + } + + def __init__( + self, + ndim: int, + ntypes: int, + network_type: str = "sezm_fitting_network", + networks: list[GLUFittingNet | dict | None] | None = None, + ) -> None: + super().__init__() + self.ndim = int(ndim) + self.ntypes = int(ntypes) + if network_type not in self.NETWORK_TYPE_MAP: + raise ValueError(f"Unknown network_type: {network_type}") + self.network_type = self.NETWORK_TYPE_MAP[network_type] + if networks is None: + networks = [] + + total = self.ntypes**self.ndim + self._networks: list[GLUFittingNet | None] = [None for _ in range(total)] + for idx, network in enumerate(networks): + self[idx] = network + if any(net is None for net in self._networks): + raise RuntimeError("SeZMNetworkCollection is incomplete.") + self.networks = torch.nn.ModuleList(self._networks) + + def _convert_key(self, key: int | tuple | str) -> int: + if isinstance(key, int): + idx = key + else: + if isinstance(key, tuple): + pass + elif isinstance(key, str): + key = tuple([int(tt) for tt in key.split("_")[1:]]) + else: + raise TypeError(key) + assert isinstance(key, tuple) + assert len(key) == self.ndim + idx = sum([tt * self.ntypes**ii for ii, tt in enumerate(key)]) + return idx + + def __getitem__(self, key: int | tuple | str) -> GLUFittingNet: + idx = self._convert_key(key) + nn = self._networks[idx] + assert nn is not None + return nn + + def __setitem__(self, key: int | tuple | str, value: GLUFittingNet | dict) -> None: + if isinstance(value, self.network_type): + network = value + elif isinstance(value, dict): + network = self.network_type.deserialize(value) + else: + raise TypeError(value) + idx = self._convert_key(key) + self._networks[idx] = network + + def serialize(self) -> dict[str, Any]: + """Serialize the networks to a dict.""" + network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()} + return { + "@class": "NetworkCollection", + "@version": 1, + "ndim": self.ndim, + "ntypes": self.ntypes, + "network_type": network_type_map_inv[self.network_type], + "networks": [ + nn.serialize() if nn is not None else None for nn in self._networks + ], + } + + @classmethod + def deserialize(cls, data: dict) -> SeZMNetworkCollection: + """Deserialize the networks from a dict.""" + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + return cls(**data) + + +def _resolve_auto_neuron( + neuron: list[int] | None, + *, + dim_descrpt: int, + numb_fparam: int, + numb_aparam: int, + dim_case_embd: int, + case_film_embd: bool, + use_aparam_as_mask: bool, +) -> list[int]: + """Resolve SeZM fitting hidden widths, using 0 as the auto-width marker.""" + resolved_neuron = [0] if neuron is None else [int(width) for width in neuron] + if any(width < 0 for width in resolved_neuron): + raise ValueError("`fitting_net.neuron` entries must be >= 0") + if 0 not in resolved_neuron: + return resolved_neuron + case_dim = 0 if case_film_embd else int(dim_case_embd) + dim_in = ( + int(dim_descrpt) + + int(numb_fparam) + + (0 if use_aparam_as_mask else int(numb_aparam)) + + case_dim + ) + resolved_width = int(32 * math.ceil((8.0 * float(dim_in) / 3.0) / 32.0)) + return [resolved_width if width == 0 else width for width in resolved_neuron] + + +@Fitting.register("dpa4_ener") +@Fitting.register("sezm_ener") +class SeZMEnergyFittingNet(InvarFitting): + """ + SeZM energy fitting with GLU hidden layers. + + This uses the same configuration keys as the standard energy fitting + but replaces hidden MLP layers with GLU blocks. + """ + + def __init__( + self, + ntypes: int, + dim_descrpt: int, + neuron: list[int] | None = None, + bias_atom_e: torch.Tensor | None = None, + resnet_dt: bool = False, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + case_film_embd: bool = False, + activation_function: str = "silu", + bias_out: bool = False, + precision: str = "float32", + mixed_types: bool = True, + seed: int | list[int] | None = None, + type_map: list[str] | None = None, + default_fparam: list | None = None, + **kwargs: Any, + ) -> None: + neuron = _resolve_auto_neuron( + neuron, + dim_descrpt=dim_descrpt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, + case_film_embd=case_film_embd, + use_aparam_as_mask=bool(kwargs.get("use_aparam_as_mask", False)), + ) + super().__init__( + "energy", + ntypes, + dim_descrpt, + 1, + neuron=neuron, + bias_atom_e=bias_atom_e, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, + activation_function=activation_function, + precision=precision, + mixed_types=mixed_types, + seed=seed, + type_map=type_map, + default_fparam=default_fparam, + **kwargs, + ) + self.bias_out = bool(bias_out) + self.case_film_embd = bool(case_film_embd and self.dim_case_embd > 0) + self._build_glu_fitting_layers() + + def _build_glu_fitting_layers(self) -> None: + # === Step 1. Derive input/output dimensions === + case_dim = 0 if self.case_film_embd else self.dim_case_embd + in_dim = ( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + + case_dim + ) + net_dim_out = self._net_out_dim() + n_networks = self.ntypes if not self.mixed_types else 1 + + # === Step 2. Build GLU fitting networks === + self.filter_layers = SeZMNetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="sezm_fitting_network", + networks=[ + GLUFittingNet( + in_dim, + net_dim_out, + self.neuron, + activation_function=self.activation_function, + resnet_dt=self.resnet_dt, + precision=self.precision, + bias_out=self.bias_out, + seed=child_seed(self.seed, idx), + trainable=self.trainable, + descriptor_dim=self.dim_descrpt, + dim_case_embd=self.dim_case_embd, + case_film_embd=self.case_film_embd, + ) + for idx in range(n_networks) + ], + ) + for param in self.parameters(): + param.requires_grad = self.trainable + + def _forward_common( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Run the SeZM fitting path with optional case FiLM.""" + if not self.case_film_embd: + return super()._forward_common( + descriptor, + atype, + gr, + g2, + h2, + fparam, + aparam, + ) + return self._forward_case_film(descriptor, atype, fparam, aparam) + + def _forward_case_film( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """ + Forward path for SeZM case FiLM. + + Parameters + ---------- + descriptor + Descriptor tensor with shape (nf, nloc, dim_descrpt). + atype + Atom types with shape (nf, nloc). + fparam + Frame parameters with shape (nf, numb_fparam). + aparam + Atomic parameters with shape (nf, nloc, numb_aparam). + + Returns + ------- + dict[str, torch.Tensor] + Per-atom fitting outputs. + """ + xx = descriptor.to(self.prec) + nf, nloc, nd = xx.shape + if self.numb_fparam > 0 and fparam is None: + assert self.default_fparam_tensor is not None + fparam = torch.tile(self.default_fparam_tensor.unsqueeze(0), [nf, 1]) + fparam = fparam.to(self.prec) if fparam is not None else None + aparam = aparam.to(self.prec) if aparam is not None else None + + if self.remove_vaccum_contribution is not None: + xx_zeros = torch.zeros_like(xx) + else: + xx_zeros = None + net_dim_out = self._net_out_dim() + + if nd != self.dim_descrpt: + raise ValueError( + f"get an input descriptor of dim {nd}," + f"which is not consistent with {self.dim_descrpt}." + ) + + if self.numb_fparam > 0: + assert fparam is not None, "fparam should not be None" + assert self.fparam_avg is not None + assert self.fparam_inv_std is not None + if fparam.numel() != nf * self.numb_fparam: + raise ValueError( + f"input fparam: cannot reshape {list(fparam.shape)} " + f"into ({nf}, {self.numb_fparam})." + ) + fparam = fparam.view([nf, self.numb_fparam]) + nb, _ = fparam.shape + t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) + t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb) + fparam = (fparam - t_fparam_avg) * t_fparam_inv_std + fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1]) + xx = torch.cat([xx, fparam], dim=-1) + if xx_zeros is not None: + xx_zeros = torch.cat([xx_zeros, fparam], dim=-1) + + if self.numb_aparam > 0 and not self.use_aparam_as_mask: + assert aparam is not None, "aparam should not be None" + assert self.aparam_avg is not None + assert self.aparam_inv_std is not None + if aparam.numel() % (nf * self.numb_aparam) != 0: + raise ValueError( + f"input aparam: cannot reshape {list(aparam.shape)} " + f"into ({nf}, nloc, {self.numb_aparam})." + ) + aparam = aparam.view([nf, -1, self.numb_aparam]) + nb, nloc, _ = aparam.shape + t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) + t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc) + aparam = (aparam - t_aparam_avg) * t_aparam_inv_std + xx = torch.cat([xx, aparam], dim=-1) + if xx_zeros is not None: + xx_zeros = torch.cat([xx_zeros, aparam], dim=-1) + + assert self.case_embd is not None + outs = torch.zeros( + (nf, nloc, net_dim_out), + dtype=self.prec, + device=descriptor.device, + ) + results = {} + + fitting = self.filter_layers.networks[0] + atom_property = fitting(xx, self.case_embd) + if self.eval_return_middle_output: + results["middle_output"] = fitting.call_until_last(xx, self.case_embd) + if xx_zeros is not None: + atom_property -= fitting(xx_zeros, self.case_embd) + outs = outs + atom_property + self.bias_atom_e[atype].to(self.prec) + + mask = self.emask(atype).to(torch.bool) + outs = torch.where(mask[:, :, None], outs, 0.0) + results.update({self.var_name: outs}) + return results + + @classmethod + def deserialize(cls, data: dict) -> GeneralFitting: + data = data.copy() + variables = data.pop("@variables") + nets = data.pop("nets") + check_version_compatibility(data.pop("@version", 1), 4, 1) + data.pop("var_name") + data.pop("dim_out") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = to_torch_tensor(variables[kk]) + obj.filter_layers = SeZMNetworkCollection.deserialize(nets) + return obj + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + return { + **super().serialize(), + "type": "sezm_ener", + "case_film_embd": self.case_film_embd, + } diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 012cdb3a65..a2f9320974 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -3,6 +3,7 @@ import functools import json import logging +import os import time from collections.abc import ( Callable, @@ -39,6 +40,7 @@ ) from deepmd.pt.loss import ( DenoiseLoss, + DeNSLoss, DOSLoss, EnergyHessianStdLoss, EnergySpinLoss, @@ -47,10 +49,18 @@ TaskLoss, TensorLoss, ) +from deepmd.pt.model.descriptor.sezm_nn import ( + apply_lora_to_sezm, + build_merged_state_dict, + strip_lora_from_extra_state, +) from deepmd.pt.model.model import ( get_model, get_zbl_model, ) +from deepmd.pt.model.model.sezm_model import ( + SeZMModel, +) from deepmd.pt.optimizer import ( AdaMuonOptimizer, HybridMuonOptimizer, @@ -63,6 +73,9 @@ get_ema_checkpoint_prefix, get_ema_validation_log_path, ) +from deepmd.pt.train.utils import ( + clip_grad_norm_with_stable_fallback, +) from deepmd.pt.train.validation import ( FullValidator, resolve_full_validation_start_step, @@ -167,6 +180,19 @@ def __init__( model_params = config["model"] training_params = config["training"] optimizer_params = config.get("optimizer", {}) + + # NOTE: Translate ``validating.compiled_infer`` (input.json opt-in) + # into the ``DP_COMPILE_INFER`` environment variable *before* any + # model is constructed below. SeZMModel samples this env var + # exactly once inside its __init__ (see ``_env_use_compile_infer`` + # in ``deepmd/pt/model/model/sezm_model.py``) and uses the cached + # value to decide whether eval / full-validation forwards take + # the compile path. Setting it later would be silently ignored + # for the rest of the run. ``setdefault`` preserves any explicit + # shell-level override so a user who manually exported + # ``DP_COMPILE_INFER`` (either direction) stays in control. + if bool((config.get("validating") or {}).get("compiled_infer", False)): + os.environ.setdefault("DP_COMPILE_INFER", "1") self.multi_task = "model_dict" in model_params self.finetune_links = finetune_links self.finetune_update_stat = False @@ -428,6 +454,8 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: resuming=resuming, _loss_params=loss_param_tmp, ) + # SeZM specific process for DeNS training + prepare_model_for_loss(self.model, loss_param_tmp) # Loss if not self.multi_task: @@ -452,9 +480,26 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: # add data requirement for labels data_requirement = self.loss.label_requirement data_requirement += get_additional_data_requirement(self.model) + min_pair_dist = float( + training_params.get("training_data", {}).get("min_pair_dist", 0.0) + ) + if min_pair_dist > 0.0: + data_requirement.append( + DataRequirementItem( + "min_pair_dist", + ndof=1, + atomic=False, + must=False, + high_prec=False, + default=min_pair_dist, + ) + ) training_data.add_data_requirement(data_requirement) if validation_data is not None: - validation_data.add_data_requirement(data_requirement) + validation_data.add_data_requirement( + self.loss.label_requirement + + get_additional_data_requirement(self.model) + ) # Preload and apply modifiers to all data before computing statistics training_data.preload_and_modify_all_data_torch() if validation_data is not None: @@ -510,9 +555,26 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: data_requirement += get_additional_data_requirement( self.model[model_key] ) + min_pair_dist = float( + training_params.get("training_data", {}).get("min_pair_dist", 0.0) + ) + if min_pair_dist > 0.0: + data_requirement.append( + DataRequirementItem( + "min_pair_dist", + ndof=1, + atomic=False, + must=False, + high_prec=False, + default=min_pair_dist, + ) + ) training_data[model_key].add_data_requirement(data_requirement) if validation_data[model_key] is not None: - validation_data[model_key].add_data_requirement(data_requirement) + validation_data[model_key].add_data_requirement( + self.loss[model_key].label_requirement + + get_additional_data_requirement(self.model[model_key]) + ) # Preload and apply modifiers to all data before computing statistics training_data[model_key].preload_and_modify_all_data_torch() if validation_data[model_key] is not None: @@ -661,6 +723,11 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0) self.lr_schedule = get_lr(config["learning_rate"]) + # Minimum pairwise distance for filtering unphysical frames during training + self.min_pair_dist = training_params.get("training_data", {}).get( + "min_pair_dist", 0.0 + ) + # JIT if JIT: self.model = torch.jit.script(self.model) @@ -809,6 +876,9 @@ def collect_single_finetune_params( "_extra_state" ] + # Always use current model_params so newly added fields + # (e.g. bridging_method) are persisted in checkpoints. + state_dict["_extra_state"] = self.wrapper.state_dict()["_extra_state"] self.wrapper.load_state_dict(state_dict) # change bias for fine-tuning @@ -877,6 +947,32 @@ def single_model_finetune( data_stat_protect=_data_stat_protect[0], ) + # LoRA injection (single-task only; argcheck rejects multi-task). + self._lora_enabled = False + if not self.multi_task: + _lora_cfg = model_params.get("lora") + if _lora_cfg is not None: + # "Default" is the fixed key ModelWrapper assigns to the sole + # single-task model (see wrapper.py); finetune `--model-branch` + # has already selected pretrained weights for this slot. + _branch_model = self.wrapper.model["Default"] + if not isinstance(_branch_model, SeZMModel): + log.warning( + "[LoRA] skipping: model is not SeZMModel; " + "LoRA fine-tuning is only supported for SeZM." + ) + else: + apply_lora_to_sezm( + _branch_model, + rank=int(_lora_cfg["rank"]), + alpha=_lora_cfg.get("alpha"), + ) + self._lora_enabled = True + log.info( + f"[LoRA] injected: rank={_lora_cfg['rank']}, " + f"alpha={_lora_cfg.get('alpha', _lora_cfg['rank'])}" + ) + if self.is_distributed: torch.cuda.set_device(LOCAL_RANK) if self.zero_stage >= 2: @@ -1284,6 +1380,10 @@ def step(_step_id: int, task_key: str = "Default") -> None: input_dict, label_dict, log_dict = self.get_data( is_train=True, task_key=task_key ) + # All frames filtered by min_pair_dist (single-GPU only; + # DDP path in get_data() always keeps at least one frame) + if not input_dict: + return if SAMPLER_RECORD: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) @@ -1316,29 +1416,12 @@ def step(_step_id: int, task_key: str = "Default") -> None: for name, p in self.wrapper.named_parameters() if p.grad is not None ] - # FSDP2 sharded DTensor gradients don't support error_if_nonfinite; use manual isfinite check instead. - total_norm = torch.nn.utils.clip_grad_norm_( + total_norm = clip_grad_norm_with_stable_fallback( self.wrapper.parameters(), self.gradient_max_norm, + use_stable_fallback=self.zero_stage < 2, + named_parameters=self.wrapper.named_parameters, ) - if not torch.isfinite(total_norm): - bad_params = [] - for name, p in self.wrapper.named_parameters(): - if p.grad is not None: - grad_norm = p.grad.data.norm() - if not torch.isfinite(grad_norm): - bad_params.append( - f" {name}: grad_norm={grad_norm}, shape={list(p.shape)}" - ) - detail = ( - "\n".join(bad_params) - if bad_params - else " (all individual grads finite, overflow in norm reduction)" - ) - raise RuntimeError( - f"Non-finite gradient norm: {total_norm}\n" - f"Parameters with non-finite gradients:\n{detail}" - ) with torch.device(DEVICE): self.optimizer.step() self.scheduler.step() @@ -1427,6 +1510,8 @@ def fake_model() -> dict: self.train_loss_accu[item] = 0.0 for item in more_loss: if "l2_" not in item: + if item not in self.train_loss_accu: + self.train_loss_accu[item] = 0.0 self.train_loss_accu[item] += more_loss[item] else: # Accumulate loss for multi-task @@ -1652,14 +1737,22 @@ def log_loss_valid(_task_key: str = "Default") -> dict: step_id=_step_id, display_step=display_step_id, lr=cur_lr, - save_checkpoint=self.save_model, + save_checkpoint=( + self.save_model_merged + if self._lora_enabled + else self.save_model + ), ) if self.ema_full_validator is not None: self.ema_full_validator.run( step_id=_step_id, display_step=display_step_id, lr=cur_lr, - save_checkpoint=self.save_ema_model, + save_checkpoint=( + self.save_ema_model_merged + if self._lora_enabled + else self.save_ema_model + ), ) if ( @@ -2005,6 +2098,76 @@ def save_ema_model( include_optimizer=False, ) + def save_model_merged( + self, + save_path: str | Path, + lr: float = 0.0, + step: int = 0, + *, + ckpt_prefix: str | None = None, + max_ckpt_keep: int | None = None, + use_ema_weights: bool = False, + ) -> None: + """Save a plain SeZM checkpoint with LoRA adapters folded into base weights. + + Behaviour relative to :meth:`save_model`: + + - state_dict: every ``A_by_l`` / ``B_by_l`` / ``A_m0`` / ``B_m0`` / + ``A_m.*`` / ``B_m.*`` key is removed; the corresponding ``weight`` / + ``weight_m0`` / ``weight_m.*`` tensors absorb ``ΔW = BA·scaling``. + - ``_extra_state.model_params``: the ``lora`` entry is stripped (both + single-task and multi-task layouts) so the resulting checkpoint + loads as plain SeZM without re-triggering LoRA injection. + - optimizer state is **not** saved. Optimizer moments are keyed on + LoRA parameters that no longer exist in the merged layout, so + resuming training from a merged checkpoint is not supported. + - EMA state is **not** saved (this is a deployment snapshot). + - The live ``self.wrapper`` / ``optimizer`` / ``model_ema`` are + untouched; the fold happens on a detached copy of the state dict. + + Intended use: validator-driven best-topk checkpoint saves for LoRA + fine-tune runs. For plain (non-LoRA) runs the result is bit-level + identical to a regular :meth:`save_model` output minus optimizer + and EMA state. + """ + module = self._get_inner_module() + module.train_infos["lr"] = float(lr) + module.train_infos["step"] = step + model_state, _ = self._collect_checkpoint_states( + use_ema_weights=use_ema_weights, + include_optimizer=False, + ) + merged_state = build_merged_state_dict(module, state_dict=model_state) + if "_extra_state" in merged_state: + merged_state["_extra_state"] = strip_lora_from_extra_state( + merged_state["_extra_state"] + ) + self._write_checkpoint( + Path(save_path), + {"model": merged_state}, + ckpt_prefix=self.save_ckpt if ckpt_prefix is None else ckpt_prefix, + max_ckpt_keep=( + self.max_ckpt_keep if max_ckpt_keep is None else max_ckpt_keep + ), + ) + + def save_ema_model_merged( + self, save_path: str | Path, lr: float = 0.0, step: int = 0 + ) -> None: + """EMA-weight variant of :meth:`save_model_merged`.""" + if self.model_ema is None: + raise ValueError( + "EMA checkpoint saving requires `training.enable_ema=true`." + ) + self.save_model_merged( + save_path, + lr=lr, + step=step, + ckpt_prefix=self.ema_save_ckpt, + max_ckpt_keep=self.ema_ckpt_keep, + use_ema_weights=True, + ) + def get_data( self, is_train: bool = True, task_key: str = "Default" ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: @@ -2017,6 +2180,27 @@ def get_data( if iterator is None: return {}, {}, {} batch_data = next(iterator) + # === Filter frames with atoms too close (training only) === + if is_train and self.min_pair_dist > 0.0 and "min_pair_dist" in batch_data: + min_dists = batch_data["min_pair_dist"] + if isinstance(min_dists, torch.Tensor): + valid_mask = min_dists.squeeze(-1) >= self.min_pair_dist + n_total = valid_mask.shape[0] + n_valid = int(valid_mask.sum().item()) + if n_valid == 0: + # Under distributed training (DDP/FSDP), every rank must + # participate in backward() to avoid collective communication + # deadlock. Keep one frame as a fallback instead of + # skipping the entire batch. + if dist.is_available() and dist.is_initialized(): + valid_mask[0] = True + n_valid = 1 + else: + return {}, {}, {} + if n_valid < n_total: + for key, val in batch_data.items(): + if isinstance(val, torch.Tensor) and val.shape[0] == n_total: + batch_data[key] = val[valid_mask] for key in batch_data.keys(): if key == "sid" or key == "fid" or key == "box" or "find_" in key: continue @@ -2164,6 +2348,23 @@ def whether_hessian(loss_params: dict[str, Any]) -> bool: return loss_type == "ener" and loss_params.get("start_pref_h", 0.0) > 0.0 +def prepare_model_for_loss( + model: Any, + loss_params: dict[str, Any] | None, +) -> None: + """Align model execution mode with the configured training loss.""" + if loss_params is None: + return + if isinstance(model, dict): + for model_key, sub_model in model.items(): + sub_loss = loss_params.get(model_key) + if sub_loss is not None: + prepare_model_for_loss(sub_model, sub_loss) + return + if hasattr(model, "set_active_mode_from_loss"): + model.set_active_mode_from_loss(loss_params.get("type", "ener")) + + def get_loss( loss_params: dict[str, Any], start_lr: float, _ntypes: int, _model: Any ) -> TaskLoss: @@ -2174,6 +2375,9 @@ def get_loss( elif loss_type == "ener": loss_params["starter_learning_rate"] = start_lr return EnergyStdLoss(**loss_params) + elif loss_type == "dens": + loss_params["starter_learning_rate"] = start_lr + return DeNSLoss(**loss_params) elif loss_type == "dos": loss_params["starter_learning_rate"] = start_lr loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size diff --git a/deepmd/pt/train/utils.py b/deepmd/pt/train/utils.py new file mode 100644 index 0000000000..2cbf536ac2 --- /dev/null +++ b/deepmd/pt/train/utils.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Training utility functions.""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + TYPE_CHECKING, +) + +import torch + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Iterable, + ) + + +def clip_grad_norm_with_stable_fallback( + parameters: Iterable[torch.nn.Parameter], + max_norm: float, + use_stable_fallback: bool = True, + named_parameters: Callable[[], Iterable[tuple[str, torch.nn.Parameter]]] + | None = None, +) -> torch.Tensor: + """Clip gradients, falling back to a scaled norm if the global norm overflows. + + The normal path returns PyTorch's native norm tensor. The stable fallback + returns a float64 scalar tensor on the first gradient device so very large + finite norms do not collapse back to inf when reported. + """ + params = [p for p in parameters if p.grad is not None] + if not params: + return torch.tensor(0.0, dtype=torch.float64, device="cpu") + + if not use_stable_fallback: + total_norm = torch.nn.utils.clip_grad_norm_( + params, + max_norm, + ) + if not torch.isfinite(total_norm): + raise_nonfinite_gradient_norm( + collect_named_grads(params, named_parameters), total_norm + ) + return total_norm + + try: + return torch.nn.utils.clip_grad_norm_( + params, + max_norm, + error_if_nonfinite=True, + ) + except RuntimeError as err: + message = str(err).lower() + if "non-finite" not in message and "nonfinite" not in message: + raise + return stable_clip_grad_norm( + collect_named_grads(params, named_parameters), max_norm + ) + + +def collect_named_grads( + parameters: list[torch.nn.Parameter], + named_parameters: Callable[[], Iterable[tuple[str, torch.nn.Parameter]]] | None, +) -> list[tuple[str, torch.nn.Parameter]]: + if named_parameters is None: + return [(f"param_{idx}", param) for idx, param in enumerate(parameters)] + return [ + (name, param) for name, param in named_parameters() if param.grad is not None + ] + + +def raise_nonfinite_gradient_norm( + named_parameters: list[tuple[str, torch.nn.Parameter]], + total_norm: torch.Tensor, +) -> None: + bad_params = [] + for name, param in named_parameters: + grad_norm = param.grad.data.norm() + if not torch.isfinite(grad_norm): + bad_params.append( + f" {name}: grad_norm={grad_norm}, shape={list(param.shape)}" + ) + detail = ( + "\n".join(bad_params) + if bad_params + else " (all individual grads finite, overflow in norm reduction)" + ) + raise RuntimeError( + f"Non-finite gradient norm: {total_norm}\n" + f"Parameters with non-finite gradients:\n{detail}" + ) + + +def stable_clip_grad_norm( + named_parameters: list[tuple[str, torch.nn.Parameter]], + max_norm: float, +) -> torch.Tensor: + """Clip finite gradients with a scaled L2 norm to avoid overflow.""" + bad_params = [] + scale = 0.0 + first_device = named_parameters[0][1].grad.device + + # === Step 1. Find the largest finite gradient magnitude === + for name, param in named_parameters: + grad = param.grad.detach() + values = grad.coalesce().values() if grad.is_sparse else grad + if not bool(torch.isfinite(values).all().item()): + grad_norm = grad.norm() + bad_params.append( + f" {name}: grad_norm={grad_norm}, shape={list(param.shape)}" + ) + continue + if values.numel() > 0: + scale = max(scale, float(values.abs().max().item())) + + if bad_params: + detail = "\n".join(bad_params) + raise RuntimeError( + "Non-finite gradient norm: non-finite\n" + f"Parameters with non-finite gradients:\n{detail}" + ) + if scale == 0.0: + return torch.zeros((), dtype=torch.float64, device=first_device) + + # === Step 2. Accumulate squared gradients after scaling by max magnitude === + scaled_ssq = 0.0 + for _, param in named_parameters: + grad = param.grad.detach() + values = grad.coalesce().values() if grad.is_sparse else grad + scaled = values.to(torch.float64) / scale + scaled_ssq += float(torch.sum(scaled * scaled).item()) + + total_norm = scale * math.sqrt(scaled_ssq) + if not math.isfinite(total_norm): + raise RuntimeError( + f"Non-finite gradient norm: {total_norm}\n" + "Parameters with non-finite gradients:\n" + " (all individual grads finite, overflow in stable norm reduction)" + ) + + clip_coef = float(max_norm) / (total_norm + 1e-6) + if clip_coef < 1.0: + for _, param in named_parameters: + param.grad.detach().mul_(clip_coef) + + return torch.tensor(total_norm, dtype=torch.float64, device=first_device) diff --git a/deepmd/pt/train/validation.py b/deepmd/pt/train/validation.py index ace3fb244d..f8874d8b19 100644 --- a/deepmd/pt/train/validation.py +++ b/deepmd/pt/train/validation.py @@ -473,7 +473,8 @@ def _evaluate_system( atom_types=test_data["type"], box=test_data["box"] if data_system.pbc else None, fparam=test_data["fparam"] - if bool(test_data.get("find_fparam", 0.0)) + if self.model.get_dim_fparam() > 0 + and bool(test_data.get("find_fparam", 0.0)) else None, aparam=test_data["aparam"] if self.model.get_dim_aparam() > 0 else None, include_virial=include_virial, diff --git a/deepmd/pt/utils/multi_task.py b/deepmd/pt/utils/multi_task.py index d30efd30ae..d99ac704c7 100644 --- a/deepmd/pt/utils/multi_task.py +++ b/deepmd/pt/utils/multi_task.py @@ -14,6 +14,26 @@ ) +def _cascade_top_level_defaults(model_config: dict[str, Any]) -> None: + """In-place: lower model-wide ``model.`` entries into each branch. + + Any key at the top of ``model`` other than ``model_dict`` / ``shared_dict`` + is ``setdefault``-copied into every ``model_dict`` entry (explicit branch + values win) and then removed from the top level so the multi-task + argcheck, which only accepts ``model_dict`` / ``shared_dict`` there, + does not reject it as an unknown field. + """ + _RESERVED_TOP_LEVEL = ("model_dict", "shared_dict") + top_level_defaults = { + k: deepcopy(v) for k, v in model_config.items() if k not in _RESERVED_TOP_LEVEL + } + for branch in model_config["model_dict"].values(): + for k, v in top_level_defaults.items(): + branch.setdefault(k, deepcopy(v)) + for k in top_level_defaults: + model_config.pop(k, None) + + def preprocess_shared_params( model_config: dict[str, Any], ) -> tuple[dict[str, Any], dict[str, Any]]: @@ -93,9 +113,14 @@ def preprocess_shared_params( ] } } - + Any key placed directly under ``model`` other than ``model_dict`` / + ``shared_dict`` is lowered into every branch via ``_cascade_top_level_defaults`` + (explicit branch values win), so model-wide switches can be written + once at the top level. """ assert "model_dict" in model_config, "only multi-task model can use this method!" + _cascade_top_level_defaults(model_config) + supported_types = ["type_map", "descriptor", "fitting_net"] shared_dict = model_config.get("shared_dict", {}) shared_links = {} diff --git a/deepmd/pt/utils/serialization.py b/deepmd/pt/utils/serialization.py index e54ec9c76d..82274796e8 100644 --- a/deepmd/pt/utils/serialization.py +++ b/deepmd/pt/utils/serialization.py @@ -79,6 +79,12 @@ def deserialize_to_file(model_file: str, data: dict) -> None: ) model = SpinEnergyModel.deserialize(model_data) + elif model_data.get("type") == "sezm_spin": + from deepmd.pt.model.model.sezm_spin_model import ( + SeZMSpinModel, + ) + + model = SeZMSpinModel.deserialize(model_data) else: model = BaseModel.deserialize(model_data) # JIT will happy in this way... diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index f2fe908297..ba3ada987b 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -406,13 +406,20 @@ def _load_pt(self, model_file: str, head: str | None = None) -> None: # `_collect_metadata` writes into metadata.json. self.metadata = { "type_map": model.get_type_map(), + "ntypes": model.get_descriptor().get_ntypes(), "rcut": model.get_rcut(), "sel": model.get_sel(), "dim_fparam": model.get_dim_fparam(), "dim_aparam": model.get_dim_aparam(), + "dim_chg_spin": model.get_dim_chg_spin() + if hasattr(model, "get_dim_chg_spin") + else 0, "mixed_types": model.mixed_types(), "has_default_fparam": model.has_default_fparam(), "default_fparam": model.get_default_fparam(), + "default_chg_spin": model.get_default_chg_spin() + if hasattr(model, "get_default_chg_spin") + else None, "is_spin": self._is_spin, } if self._is_spin: @@ -478,7 +485,9 @@ def get_rcut(self) -> float: def get_ntypes(self) -> int: """Get the number of atom types of this model.""" - return len(self._type_map) + if self._type_map: + return len(self._type_map) + return int(self.metadata.get("ntypes", 0)) def get_type_map(self) -> list[str]: """Get the type map (element name of the atom types) of this model.""" @@ -496,6 +505,34 @@ def get_dim_aparam(self) -> int: return self._dpmodel.get_dim_aparam() return int(self.metadata["dim_aparam"]) + def get_dim_chg_spin(self) -> int: + """Get the width of charge/spin condition inputs.""" + if self._dpmodel is not None and hasattr(self._dpmodel, "get_dim_chg_spin"): + return self._dpmodel.get_dim_chg_spin() + return int(self.metadata.get("dim_chg_spin", 0)) + + def _make_charge_spin_input(self, nframes: int) -> torch.Tensor | None: + """Build the fixed charge/spin tensor used by exported SeZM models.""" + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + dim_chg_spin = self.get_dim_chg_spin() + if dim_chg_spin == 0: + return None + default_chg_spin = self.metadata.get("default_chg_spin") + if default_chg_spin is None: + raise ValueError( + "charge_spin is required for this model, but exported " + "pt_expt inference does not expose a runtime charge_spin argument." + ) + return ( + torch.tensor(default_chg_spin, dtype=torch.float64, device=DEVICE) + .view(1, dim_chg_spin) + .expand(nframes, -1) + .contiguous() + ) + @property def model_type(self) -> type["DeepEvalWrapper"]: """The evaluator of the model type.""" @@ -1038,20 +1075,27 @@ def _eval_model( nframes, natoms, ) = self._prepare_inputs(coords, cells, atom_types, fparam, aparam) + charge_spin_t = self._make_charge_spin_input(nframes) # Call the model (forward_common_lower interface, internal keys) + model_inputs = ( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + ) + if charge_spin_t is not None: + model_inputs = (*model_inputs, charge_spin_t) if self._is_pt2: # AOTInductor's __call__ unflattens output using stored out_spec, # returning a dict just like the .pte module. # It also filters non-tensor args automatically, matching the # export-time signature where None args were excluded. - model_ret = self._pt2_runner( - ext_coord_t, ext_atype_t, nlist_t, mapping_t, fparam_t, aparam_t - ) + model_ret = self._pt2_runner(*model_inputs) else: - model_ret = self.exported_module( - ext_coord_t, ext_atype_t, nlist_t, mapping_t, fparam_t, aparam_t - ) + model_ret = self.exported_module(*model_inputs) # Apply communicate_extended_output to map extended atoms → local atoms do_atomic_virial = any( @@ -1182,27 +1226,24 @@ def _eval_model_spin( else: aparam_t = None - # Call the model with spin (7 args) + charge_spin_t = self._make_charge_spin_input(nframes) + + # Call the model with spin. + model_inputs = ( + ext_coord_t, + ext_atype_t, + ext_spin_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + ) + if charge_spin_t is not None: + model_inputs = (*model_inputs, charge_spin_t) if self._is_pt2: - model_ret = self._pt2_runner( - ext_coord_t, - ext_atype_t, - ext_spin_t, - nlist_t, - mapping_t, - fparam_t, - aparam_t, - ) + model_ret = self._pt2_runner(*model_inputs) else: - model_ret = self.exported_module( - ext_coord_t, - ext_atype_t, - ext_spin_t, - nlist_t, - mapping_t, - fparam_t, - aparam_t, - ) + model_ret = self.exported_module(*model_inputs) # Apply communicate_extended_output to map extended atoms → local atoms do_atomic_virial = any( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0364b24695..7800503ff3 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -52,7 +52,10 @@ doc_se_atten = "Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Attention mechanism will be used by this descriptor." doc_se_atten_v2 = "Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Attention mechanism with new modifications will be used by this descriptor." doc_se_a_mask = "Used by the smooth edition of Deep Potential. It can accept a variable number of atoms in a frame (Non-PBC system). *aparam* are required as an indicator matrix for the real/virtual sign of input atoms." -doc_hybrid = "Concatenates a list of descriptors into a new descriptor." +doc_hybrid = "Concatenate of a list of descriptors as a new descriptor." +doc_se_zm = ( + "DPA4 descriptor (SeZM implementation): Smooth equivariant Zone-bridging Model." +) # fitting doc_ener = "Fit an energy model (potential energy surface)." doc_dos = "Fit a density of states model. The total density of states / site-projected density of states labels should be provided by `dos.npy` or `atom_dos.npy` in each data system. The file has a number of frames (rows) and a number of energy-grid columns (multiplied by the number of atoms in `atom_dos.npy`). See `loss` parameter." @@ -340,6 +343,451 @@ def descrpt_se_a_args() -> list[Argument]: ] +@descrpt_args_plugin.register( + "dpa4", + alias=["SeZM", "sezm"], + doc=doc_only_pt_supported + doc_se_zm, +) +def descrpt_se_zm_args() -> list[Argument]: + # Follows exact order of docstring in sezm.py DescrptSeZM class + doc_sel = 'The maximum number of neighbors. It can be:\n\n\ + - `int`: the total maximum number of neighbors within `rcut` (all types combined)\n\n\ + - `list[int]`: sel[i] specifies the maximum number of type-i neighbors within `rcut`\n\n\ + - `str`: Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors with in the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally the number is wrapped up to 4 divisible. The option "auto" is equivalent to "auto:1.1".' + doc_rcut = "The cut-off radius." + doc_env_exp = ( + "C^2 cutoff envelope exponents `[rbf_env_exp, edge_env_exp]`. " + "`rbf_env_exp` controls radial basis function envelope decay; " + "`edge_env_exp` controls message passing edge weight envelope decay. " + "Larger values give weaker suppression." + ) + doc_channels = "Total channels per (l,m) coefficient." + doc_basis_type = "Radial basis type. Supported values are `bessel` and `gaussian`." + doc_n_radial = "Number of radial basis functions." + doc_radial_mlp = "Hidden layer sizes for radial networks. An output layer of size (l_schedule[0]+1)*channels will be automatically appended. Use 0 as a placeholder to be replaced by channels." + doc_use_env_seed = ( + "If True, apply environment matrix initial embedding as FiLM conditioning on " + "l=0 features using 4D [s, s*r_hat] representation. Internal dimensions are " + "derived from channels: embed_dim=min(channels, 128), " + "axis_dim=min(4 if embed_dim < 64 else 8, embed_dim-1), " + "type_dim=clamp(channels//4, 8, 32), " + "rbf_out_dim=max(32, embed_dim-2*type_dim), " + "hidden_dim=min(256, max(2*embed_dim, rbf_out_dim+2*type_dim))." + ) + doc_random_gamma = ( + "If True, apply a random roll about the edge-aligned local +Z axis before " + "building Wigner-D blocks. The roll is sampled independently per edge and " + "per forward call." + ) + doc_lmax = "Maximum degree, only used when `l_schedule` is None." + doc_l_schedule = "Pyramid schedule of lmax per block, e.g. [3, 3, 2]. Must be non-increasing. If set, lmax and n_blocks will be ignored." + doc_mmax = "Maximum SO(2) order (|m|), only used when `m_schedule` is None. If None, defaults to the per-block lmax." + doc_m_schedule = "Schedule of mmax per block. Must satisfy `m_schedule[i] <= l_schedule[i]`. If set, `mmax` will be ignored." + doc_n_blocks = "Number of blocks (only used when `l_schedule` is None)." + doc_block_attn_res = ( + "Descriptor-level block attention residual mode over block history " + "`[x0, b1, b2, ...]`, where each block summary is the sum of the SO(2) " + "unit output and all FFN unit outputs inside one interaction block. " + "`independent` uses learned query vectors, while `dependent` derives " + "queries from the current SeZM state before the SO(2) unit, before " + "each FFN unit, and before the final block aggregation. Must be one of " + "`none`, `independent`, or `dependent`. Cannot be enabled together " + "with `full_attn_res`." + ) + doc_so2_norm = ( + "If True, apply intermediate ReducedEquivariantRMSNorm between SO(2) mixing layers. " + "When False (default), no normalization is applied between layers." + ) + doc_so2_layers = "Number of SO(2) mixing layers per block." + doc_so2_attn_res = ( + "Depth-wise attention residual mode across the internal SO(2) layer " + "history inside each interaction block. Must be one of `none`, " + "`independent`, or `dependent`." + ) + attn_res_modes = {"none", "independent", "dependent"} + radial_so2_modes = {"none", "degree", "degree_channel"} + doc_radial_so2_mode = ( + "Dynamic radial degree mixer mode inside SO(2) convolution. " + "`none` applies elementwise radial modulation. " + "`degree` uses an edge-conditioned cross-degree kernel " + "`W[l_in,l_out,|m|](r)` shared by all channels. " + "`degree_channel` uses `W[l_in,l_out,|m|,c](r)`, optionally low-rank " + "when `radial_so2_rank > 0`." + ) + doc_radial_so2_rank = ( + "Low-rank channel factorization rank for `radial_so2_mode=degree_channel`. " + "`0` uses the full per-channel dynamic degree kernel." + ) + doc_n_focus = ( + "Number of parallel focus streams used only inside the SO(2) convolution." + ) + doc_focus_dim = "Hidden width per focus stream inside the SO(2) convolution. `0` means using `channels`." + doc_n_atten_head = ( + "Number of attention heads when aggregating messages in SO(2) " + "convolution. 0 applies a plain envelope-weighted scatter-sum. When >0, " + "the attention width must be divisible by `n_atten_head`, and envelope-gated " + "grouped softmax attention with output-side head gate is applied. Attention uses " + "`w**2 * exp(logit)` in the numerator and " + "`zeta + sum(w**2 * exp(logit))` in the denominator." + ) + doc_atten_f_mix = ( + "If True, merge all SO(2) focus streams into one attention stream after " + "rotate-back. Attention heads split `n_focus * focus_dim` (or " + "`n_focus * channels` when `focus_dim=0`) instead of each focus stream " + "independently. The default False preserves per-focus attention." + ) + doc_atten_v_proj = ( + "If True, apply an explicit degree-aware value projection inside SO(2) " + "attention. The default False keeps the raw rotated message as the " + "attention value." + ) + doc_atten_o_proj = ( + "If True, apply an explicit degree-aware output projection after the " + "SO(2) attention output gate. The default False keeps the legacy output " + "path without this projection." + ) + doc_ffn_neurons = ( + "Hidden width for block FFNs and the final scalar output FFN. " + "`>0` uses the same explicit width for both. " + "`0` lets each path resolve its own width from `channels`: " + "`4 * channels` without GLU, `(8 / 3) * channels` with GLU, " + "then round up to a multiple of 32." + ) + doc_ffn_blocks = "Number of FFN sublayers per interaction block." + doc_sandwich_norm = ( + "Pre/post-norm switches for residual branches. Use [so2_pre, so2_post, ffn_pre, ffn_post] to " + "enable pre-norm before and post-norm after SO(2) and FFN operations." + ) + doc_mlp_bias = ( + "Whether to use bias in equivariant layers. When False, removes bias from:\n" + "- SO3Linear: l=0 bias\n" + "- SO2Linear: l=0 bias\n" + "- GatedActivation: gate linear bias\n" + "- DepthAttnRes: input-dependent query projection\n" + "- EnvironmentInitialEmbedding MLPs: rbf_proj_layer1/2 and g_layer1/2\n" + "Attention projections in SO2Convolution " + "(attn_radial_bias_proj, attn_output_gate_proj) are always bias-free." + ) + doc_layer_scale = ( + "If True, apply learnable LayerScale (init 1e-3) on residual branches: " + "SO(2) branch uses per-focus-channel scales " + "(shape `(n_focus, focus_dim)`) on each SO(2) mixing layer, " + "and FFN branch uses per-channel scales (shape `(channels,)`) on each " + "FFN residual branch." + ) + doc_full_attn_res = ( + "Descriptor-level full attention residual mode over the unit history " + "`[x0, so2_0, ffn_0_0, ffn_0_1, ..., so2_1, ffn_1_0, ffn_1_1, ...]`. " + "`independent` uses learned query vectors, while `dependent` derives " + "the query from the current SeZM state before the SO(2) unit, before " + "each FFN unit, and before the final aggregation. Must be one of " + "`none`, `independent`, or `dependent`. Cannot be enabled together " + "with `block_attn_res`." + ) + doc_s2_activation = ( + "Two booleans `[so2_enabled, ffn_enabled]`. " + "`so2_enabled=true` makes the SO(2) gated activation path use " + '`activation_function="silu"`. ' + "`ffn_enabled=true` makes the block-internal FFN path use " + '`activation_function="silu"` and `glu_activation=true`. ' + "S2-grid resolutions are resolved automatically per block. The e3nn " + "SO(2) grid is `[2 * mmax + 4, ceil_even(3 * lmax + 2)]`, and the " + "e3nn FFN grid is lifted to `[max(R_phi, R_theta), max(R_phi, R_theta)]`. " + "Lebedev branches use the smallest packaged rule with precision at " + "least `3 * lmax`. " + "The final scalar output FFN is unchanged." + ) + doc_lebedev_quadrature = ( + "Either one boolean applied to both S2 branches, or two booleans " + "`[so2_enabled, ffn_enabled]` aligned with `s2_activation`. If a branch " + "is enabled here, its S2 projector uses packaged Lebedev quadrature " + "rules instead of the e3nn product grid. The default keeps the existing " + "e3nn behavior." + ) + doc_grid_ffn = ( + "If True, use the optional grid-MLP structure for the block-internal " + "equivariant FFN. This does not change the final `l=0` output head." + ) + doc_activation_function = ( + f"Base activation function for helper MLPs, the SO(2) gated activation " + f"path, and the final scalar output FFN. Supported activation functions " + f"are {list_to_doc(ACTIVATION_FN_DICT.keys())}. " + f'It is overridden to `"silu"` only on paths whose `s2_activation` ' + f"switch is enabled." + ) + doc_glu_activation = ( + "Base GLU switch for FFN (e.g., silu -> swiglu, gelu -> geglu). " + "The block-internal FFN overrides this to `true` when `s2_activation[1]=true`, " + "while the final scalar output FFN keeps the user-provided value." + ) + doc_use_amp = ( + "If True, use automatic mixed precision (AMP) with bfloat16 on CUDA. " + "This does not provide accelerations under fp32 precision but will decrease " + "the memory usage, while preserving model accuracy." + ) + doc_add_chg_spin_ebd = ( + "Whether to add frame-level charge and spin conditions to the descriptor " + "type embedding." + ) + doc_default_chg_spin = ( + "Default frame-level charge and spin conditions `[charge, spin]`. " + "If set, this value is used when charge_spin data are not provided." + ) + + doc_exclude_types = ( + "The excluded pairs of types which have no interaction with each other. " + "For example, `[[0, 1]]` means no interaction between type 0 and type 1. " + "When the SeZM descriptor is used inside a full SeZM model config, prefer " + "the model-level `pair_exclude_types`; if both fields are provided, they " + "must match." + ) + doc_precision = f"The precision of the descriptor parameters, supported options are {list_to_doc(PRECISION_DICT.keys())}." + doc_eps = "Small epsilon for numerical stability in division and normalization." + doc_trainable = "If the parameters in the descriptor are trainable." + doc_seed = "Random seed for parameter initialization." + return [ + Argument( + "sel", [int, list[int], str], optional=True, default="auto", doc=doc_sel + ), + Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut), + Argument( + "env_exp", + list[int], + optional=True, + default=[7, 5], + doc=doc_env_exp, + ), + Argument("channels", int, optional=True, default=64, doc=doc_channels), + Argument( + "basis_type", str, optional=True, default="bessel", doc=doc_basis_type + ), + Argument("n_radial", int, optional=True, default=16, doc=doc_n_radial), + Argument( + "radial_mlp", + list[int], + optional=True, + default=[0], + doc=doc_radial_mlp, + ), + Argument( + "use_env_seed", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + doc_use_env_seed, + ), + Argument( + "random_gamma", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + doc_random_gamma, + ), + Argument("lmax", int, optional=True, default=3, doc=doc_lmax), + Argument( + "l_schedule", list[int], optional=True, default=None, doc=doc_l_schedule + ), + Argument( + "mmax", + [int, None], + optional=True, + default=1, + doc=doc_mmax, + ), + Argument( + "m_schedule", list[int], optional=True, default=None, doc=doc_m_schedule + ), + Argument("n_blocks", int, optional=True, default=3, doc=doc_n_blocks), + Argument("so2_norm", bool, optional=True, default=False, doc=doc_so2_norm), + Argument("so2_layers", int, optional=True, default=4, doc=doc_so2_layers), + Argument( + "so2_attn_res", + str, + optional=True, + default="none", + extra_check=lambda x: x in attn_res_modes, + extra_check_errmsg="must be one of 'none', 'independent', or 'dependent'", + doc=doc_only_pt_supported + doc_so2_attn_res, + ), + Argument( + "radial_so2_mode", + str, + optional=True, + default="degree_channel", + extra_check=lambda x: x in radial_so2_modes, + extra_check_errmsg="must be one of 'none', 'degree', or 'degree_channel'", + doc=doc_only_pt_supported + doc_radial_so2_mode, + ), + Argument( + "radial_so2_rank", + int, + optional=True, + default=1, + extra_check=lambda x: x >= 0, + extra_check_errmsg="must be non-negative", + doc=doc_only_pt_supported + doc_radial_so2_rank, + ), + Argument("n_focus", int, optional=True, default=1, doc=doc_n_focus), + Argument( + "focus_dim", + int, + optional=True, + default=0, + extra_check=lambda x: x >= 0, + extra_check_errmsg="must be >= 0", + doc=doc_focus_dim, + ), + Argument("n_atten_head", int, optional=True, default=1, doc=doc_n_atten_head), + Argument( + "atten_f_mix", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_atten_f_mix, + ), + Argument( + "atten_v_proj", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_atten_v_proj, + ), + Argument( + "atten_o_proj", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_atten_o_proj, + ), + Argument( + "ffn_neurons", + int, + optional=True, + default=0, + extra_check=lambda x: x >= 0, + extra_check_errmsg="must be >= 0", + doc=doc_ffn_neurons, + ), + Argument( + "grid_mlp", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_grid_ffn, + ), + Argument( + "ffn_blocks", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_ffn_blocks, + ), + Argument( + "sandwich_norm", + list[bool], + optional=True, + default=[False, True, True, False], + doc=doc_only_pt_supported + doc_sandwich_norm, + ), + Argument( + "mlp_bias", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_mlp_bias, + ), + Argument( + "layer_scale", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_layer_scale, + ), + Argument( + "full_attn_res", + str, + optional=True, + default="none", + extra_check=lambda x: x in attn_res_modes, + extra_check_errmsg="must be one of 'none', 'independent', or 'dependent'", + doc=doc_only_pt_supported + doc_full_attn_res, + ), + Argument( + "block_attn_res", + str, + optional=True, + default="none", + extra_check=lambda x: x in attn_res_modes, + extra_check_errmsg="must be one of 'none', 'independent', or 'dependent'", + doc=doc_only_pt_supported + doc_block_attn_res, + ), + Argument( + "s2_activation", + list[bool], + optional=True, + default=[False, True], + extra_check=lambda x: len(x) == 2, + extra_check_errmsg="must be a list of two booleans: [so2_activation, ffn_activation]", + doc=doc_only_pt_supported + doc_s2_activation, + ), + Argument( + "lebedev_quadrature", + [bool, list[bool]], + optional=True, + default=True, + extra_check=lambda x: isinstance(x, bool) or len(x) == 2, + extra_check_errmsg="must be a boolean or a list of two booleans: [so2_quadrature, ffn_quadrature]", + doc=doc_only_pt_supported + doc_lebedev_quadrature, + ), + Argument( + "activation_function", + str, + optional=True, + default="silu", + doc=doc_activation_function, + ), + Argument( + "glu_activation", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + doc_glu_activation, + ), + Argument("use_amp", bool, optional=True, default=True, doc=doc_use_amp), + Argument( + "add_chg_spin_ebd", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_add_chg_spin_ebd, + ), + Argument( + "default_chg_spin", + list[float], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_default_chg_spin, + ), + Argument( + "exclude_types", + list[list[int]], + optional=True, + default=[], + doc=doc_exclude_types, + ), + Argument("precision", str, optional=True, default="float32", doc=doc_precision), + Argument( + "eps", + float, + optional=True, + default=1e-7, + doc=doc_only_pt_supported + doc_eps, + ), + Argument("trainable", bool, optional=True, default=True, doc=doc_trainable), + Argument("seed", [int, None], optional=True, default=None, doc=doc_seed), + ] + + @descrpt_args_plugin.register( "se_e3", alias=["se_at", "se_a_3be", "se_t"], doc=doc_se_e3 ) @@ -987,10 +1435,7 @@ def dpa2_repinit_args() -> list[Argument]: f"When `type_one_side` is False, the input is `input_t = concat([tebd_j, tebd_i])`. {doc_only_pt_supported} When `type_one_side` is True, the input is `input_t = tebd_j`. " "The output is `out_ij = embedding_t(input_t) * embedding_s(r_ij) + embedding_s(r_ij)` for the pair-wise representation of atom i with neighbor j." ) - doc_set_davg_zero = ( - "Set the normalization average to zero. " - "This option should be set when `atom_ener` in the energy fitting is used." - ) + doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used." doc_activation_function = f"The activation function in the embedding net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())}." doc_type_one_side = r"If true, the embedding network parameters vary by types of neighbor atoms only, so there will be $N_\text{types}$ sets of embedding network parameters. Otherwise, the embedding network parameters vary by types of centric atoms and types of neighbor atoms, so there will be $N_\text{types}^2$ sets of embedding network parameters." doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection.' @@ -1160,15 +1605,8 @@ def dpa2_repformer_args() -> list[Argument]: doc_update_residual = ( "When update using residual mode, the initial std of residual vector weights." ) - doc_update_residual_init = ( - "When update using residual mode, " - "the initialization mode of residual vector weights." - "Supported modes are: ['norm', 'const']." - ) - doc_set_davg_zero = ( - "Set the normalization average to zero. " - "This option should be set when `atom_ener` in the energy fitting is used." - ) + doc_update_residual_init = "When update using residual mode, the initialization mode of residual vector weights.Supported modes are: ['norm', 'const']." + doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used." doc_trainable_ln = ( "Whether to use trainable shift and scale weights in layer normalization." ) @@ -1869,6 +2307,105 @@ def fitting_ener() -> list[Argument]: ] +@fitting_args_plugin.register("dpa4_ener", alias=["sezm_ener"], doc=doc_ener) +def fitting_sezm_ener() -> list[Argument]: + doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." + doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_default_fparam = "The default frame parameter. If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net." + doc_dim_case_embd = "The dimension of the case embedding embedding. When training or fine-tuning a multitask model with case embedding embeddings, this number should be set to the number of model branches." + doc_neuron = "The number of neurons in each hidden layer of the fitting net. Use 0 as an auto-width placeholder resolved from the descriptor width." + doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' + doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' + doc_trainable = f"Whether the parameters in the fitting net are trainable. This option can be\n\n\ +- bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\ +- list of bool{doc_only_tf_supported}: Specifies if each layer is trainable. Since the fitting net is composed of hidden layers followed by an output layer, the length of this list should be equal to len(`neuron`)+1." + doc_rcond = "The condition number used to determine the initial energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details." + doc_seed = "Random seed for parameter initialization of the fitting net" + doc_atom_ener = "Specify the atomic energy in vacuum for each type" + doc_layer_name = ( + "The name of the each layer. The length of this list should be equal to n_neuron + 1. " + "If two layers, either in the same fitting or different fittings, " + "have the same name, they will share the same neural network parameters. " + "The shape of these layers should be the same. " + "If null is given for a layer, parameters will not be shared." + ) + doc_use_aparam_as_mask = ( + "Whether to use the aparam as a mask in input." + "If True, the aparam will not be used in fitting net for embedding." + "When descrpt is se_a_mask, the aparam will be used as a mask to indicate the input atom is real/virtual. And use_aparam_as_mask should be set to True." + ) + doc_case_film_embd = "Whether to use case FiLM conditioning for DPA4/SeZM shared fitting. When enabled, the case embedding is used to modulate fitting features instead of being concatenated to the fitting input." + return [ + Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam), + Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam), + Argument( + "default_fparam", + list[float], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_default_fparam, + ), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), + Argument( + "neuron", + list[int], + optional=True, + default=[0], + alias=["n_neuron"], + doc=doc_neuron, + ), + Argument( + "activation_function", + str, + optional=True, + default="silu", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="float32", doc=doc_precision), + Argument("resnet_dt", bool, optional=True, default=False, doc=doc_resnet_dt), + Argument( + "trainable", + [list[bool], bool], + optional=True, + default=True, + doc=doc_trainable, + ), + Argument( + "rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond + ), + Argument("seed", [int, None], optional=True, default=None, doc=doc_seed), + Argument( + "atom_ener", + list[float | None], + optional=True, + default=[], + doc=doc_atom_ener, + ), + Argument("layer_name", list[str], optional=True, doc=doc_layer_name), + Argument( + "use_aparam_as_mask", + bool, + optional=True, + default=False, + doc=doc_use_aparam_as_mask, + ), + Argument( + "case_film_embd", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_case_film_embd, + ), + ] + + @fitting_args_plugin.register("dos", doc=doc_dos) def fitting_dos() -> list[Argument]: doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." @@ -2435,6 +2972,145 @@ def standard_model_args() -> Argument: return ca +@model_args_plugin.register( + "dpa4", + alias=["SeZM", "sezm"], +) +def sezm_model_args() -> Argument: + doc_descrpt = "The descriptor of atomic environment. User-provided (DPA4 / SeZM is recommended)." + doc_fitting = "The fitting of physical properties. The `type` field is ignored; DPA4 uses the dpa4_ener GLU energy fitting." + doc_model_branch_alias = ( + "List of aliases for this model branch. " + "Multiple aliases can be defined, and any alias can reference this branch throughout the model usage. " + "Used only in multitask models." + ) + doc_info = ( + "Dictionary of metadata for this model branch. " + "Store arbitrary key-value pairs with branch-specific information. " + "Used only in multitask models." + ) + doc_use_compile = ( + "Experimental feature. If True, use compact sparse edges together with " + "symbolic make_fx and torch.compile in the DPA4 / SeZM model. " + "Requires PyTorch >= 2.11. NVIDIA GPUs require CUDA >= 12.6. " + "Apple Silicon Macs are also supported. Tested with Python 3.13." + ) + doc_enable_tf32 = "If True, enable TF32 matmul precision when use_compile=True." + + ca = Argument( + "dpa4", + dict, + [ + Argument( + "descriptor", dict, [], [descrpt_variant_type_args()], doc=doc_descrpt + ), + Argument( + "fitting_net", + dict, + [], + [fitting_variant_type_args()], + doc=doc_fitting, + ), + Argument( + "use_compile", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_use_compile, + ), + Argument( + "enable_tf32", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + doc_enable_tf32, + ), + Argument( + "model_branch_alias", + list[str], + optional=True, + default=[], + doc=doc_only_pt_supported + doc_model_branch_alias, + ), + Argument( + "info", + dict, + optional=True, + default={}, + doc=doc_only_pt_supported + doc_info, + ), + Argument( + "bridging_method", + str, + optional=True, + default="None", + doc="Bridging method for short-range repulsion. Currently supports 'ZBL'. " + "Case-insensitive. Set to 'None' to disable.", + ), + Argument( + "bridging_r_inner", + float, + optional=True, + default=0.8, + doc="Inner clamping radius in Å. Distances below this are frozen for the ML model. " + "Only used when bridging_method is set. " + "When using ZBL bridging, set training_data.min_pair_dist to the same value " + "so that frames with atoms closer than r_inner are skipped during training.", + ), + Argument( + "bridging_r_outer", + float, + optional=True, + default=1.2, + doc="Outer clamping radius in Å. The transition zone [bridging_r_inner, bridging_r_outer] " + "uses a C3-continuous septic Hermite polynomial. Only used when bridging_method is set.", + ), + Argument( + "lora", + dict, + [ + Argument( + "rank", + int, + doc="LoRA rank; adapters are injected on every SO3Linear and SO2Linear.", + ), + Argument( + "alpha", + float, + optional=True, + default=None, + doc="LoRA scaling numerator; effective scaling is alpha / rank. " + "When omitted, alpha defaults to rank (scaling = 1.0).", + ), + ], + optional=True, + default=None, + doc=doc_only_pt_supported + + "Low-rank adaptation for fine-tuning. Single-task only; " + "setting this in a multi-task input (top-level or per-branch) " + "raises an error in `preprocess_shared_params` because " + "`share_params` links descriptor modules across branches to " + "the same object, which would collapse per-branch LoRA into " + "one shared adapter. " + "When set, backbone SO3Linear and " + "SO2Linear weights are frozen and low-rank A/B adapters are injected " + "alongside them (the adapters share the base shape family so HybridMuon's " + "slice route applies identically). fitting_net, env_seed_embedding, " + "radial_embedding, and small parameters (norm scales, LayerScale, FiLM " + "strength, attention projections, bias terms) stay fully trainable; type " + "embeddings, radial frequencies, and GatedActivation gate projections are " + "frozen. mid-train latest checkpoints include LoRA parameters for resume; " + "best checkpoints from full validation are saved with LoRA deltas folded " + "into base weights, producing plain DPA4 / SeZM checkpoints suitable for " + "deployment.", + ), + ], + alias=["SeZM", "sezm"], + doc="DPA4 model scaffold with fixed SeZM descriptor and fitting types.", + ) + return ca + + @hybrid_model_args_plugin.register("pairwise_dprc") def pairwise_dprc() -> Argument: qm_model_args = model_args(exclude_hybrid=True) @@ -2672,9 +3348,7 @@ def _check_wsd_args(data: dict[str, Any]) -> bool: "linear", ): raise ValueError( - "decay_type must be one of " - f"{('inverse_linear', 'cosine', 'linear')}. " - f"Got decay_type={decay_type}." + f"decay_type must be one of {('inverse_linear', 'cosine', 'linear')}. Got decay_type={decay_type}." ) return True @@ -2744,10 +3418,7 @@ def learning_rate_wsd() -> list[Argument]: "The remaining post-warmup steps are used as the stable phase. " "Default is 0.1." ) - doc_decay_type = ( - "The decay rule used in the decay phase. " - "Supported values are `inverse_linear` (default), `cosine`, and `linear`." - ) + doc_decay_type = "The decay rule used in the decay phase. Supported values are `inverse_linear` (default), `cosine`, and `linear`." return [ Argument( "decay_phase_ratio", @@ -2782,14 +3453,8 @@ def learning_rate_args(fold_subdoc: bool = False) -> Argument: doc_scale_by_worker = "When parallel training or batch size scaled, how to alter learning rate. Valid values are `linear`(default), `sqrt` or `none`." doc_lr = "The definition of learning rate" doc_start_lr = "The learning rate at the start of the training (after warmup)." - doc_stop_lr = ( - "The desired learning rate at the end of training. " - "Mutually exclusive with stop_lr_ratio." - ) - doc_stop_lr_ratio = ( - "The ratio of stop_lr to start_lr. stop_lr = start_lr * stop_lr_ratio. " - "Mutually exclusive with stop_lr." - ) + doc_stop_lr = "The desired learning rate at the end of training. Mutually exclusive with stop_lr_ratio." + doc_stop_lr_ratio = "The ratio of stop_lr to start_lr. stop_lr = start_lr * stop_lr_ratio. Mutually exclusive with stop_lr." doc_warmup_steps = ( "The number of steps for learning rate warmup. " "During warmup, the learning rate increases linearly from " @@ -3434,6 +4099,107 @@ def loss_ener() -> list[Argument]: ] +@loss_args_plugin.register("dens") +def loss_dens() -> list[Argument]: + doc_start_pref_e = start_pref("energy", abbr="e") + doc_limit_pref_e = limit_pref("energy") + doc_start_pref_f = start_pref("force", abbr="f") + doc_limit_pref_f = limit_pref("force") + doc_loss_func = ( + "Loss function type for energy and mixed direct-force / denoising supervision. " + "Options: 'mse' (Mean Squared Error, component-wise force loss) or " + "'mae' (Mean Absolute Error, default). In `dens` mode, `f_use_norm` is " + "not exposed: `mae` always uses per-atom force-vector L2 norms, while " + "`mse` always uses component-wise squared errors." + ) + doc_dens_prob = ( + "Probability of switching one batch to the denoising-enhanced training path. " + "When not selected, the `dens` head is still trained on clean direct forces." + ) + doc_dens_fixed_noise_std = ( + "Whether to use a fixed Gaussian noise standard deviation. " + "Only the fixed-noise path is supported in the initial SeZM `dens` integration." + ) + doc_dens_std = "Standard deviation of the Gaussian coordinate corruption used in the denoising path." + doc_dens_corrupt_ratio = ( + "Fraction of atoms corrupted within a denoising batch. " + "If omitted, all atoms in the batch are corrupted." + ) + doc_dens_denoising_pos_coefficient = "Loss multiplier applied to corrupted atoms whose target is the injected noise vector." + return [ + Argument( + "start_pref_e", + [float, int], + optional=True, + default=0.02, + doc=doc_start_pref_e, + ), + Argument( + "limit_pref_e", + [float, int], + optional=True, + default=1.00, + doc=doc_limit_pref_e, + ), + Argument( + "start_pref_f", + [float, int], + optional=True, + default=1000, + doc=doc_start_pref_f, + ), + Argument( + "limit_pref_f", + [float, int], + optional=True, + default=1.00, + doc=doc_limit_pref_f, + ), + Argument( + "loss_func", + str, + optional=True, + default="mae", + doc=doc_loss_func, + ), + Argument( + "dens_prob", + [float, int], + optional=True, + default=0.5, + doc=doc_dens_prob, + ), + Argument( + "dens_fixed_noise_std", + bool, + optional=True, + default=True, + doc=doc_dens_fixed_noise_std, + ), + Argument( + "dens_std", + [float, int], + optional=True, + default=0.025, + doc=doc_dens_std, + ), + Argument( + "dens_corrupt_ratio", + [float, int, None], + optional=True, + default=0.5, + doc=doc_dens_corrupt_ratio, + ), + Argument( + "dens_denoising_pos_coefficient", + [float, int], + optional=True, + default=10.0, + doc=doc_dens_denoising_pos_coefficient, + ), + ] + + @loss_args_plugin.register("ener_spin") def loss_ener_spin() -> list[Argument]: doc_start_pref_e = start_pref("energy") @@ -3708,7 +4474,7 @@ def loss_tensor() -> list[Argument]: def loss_variant_type_args() -> Variant: - doc_loss = "The type of the loss. When the fitting type is `ener`, the loss type should be set to `ener` or left unset. When the fitting type is `dipole` or `polar`, the loss type should be set to `tensor`." + doc_loss = "The type of the loss. When the fitting type is `ener`, the loss type should be set to `ener`, `dens` (Only DPA4 / SeZM supported), or left unset. When the fitting type is `dipole` or `polar`, the loss type should be set to `tensor`." return Variant( "type", @@ -3720,7 +4486,7 @@ def loss_variant_type_args() -> Variant: def loss_args() -> list[Argument]: - doc_loss = "The definition of loss function. The loss type should be set to `tensor`, `ener` or left unset." + doc_loss = "The definition of loss function. The loss type should be set to `tensor`, `ener`, `dens` or left unset." ca = Argument( "loss", dict, [], [loss_variant_type_args()], optional=True, doc=doc_loss ) @@ -3754,10 +4520,18 @@ def training_data_args() -> list[ - "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\ - "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\ - "prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;..." : the list of systems is divided into blocks. A block is specified by `stt_idx:end_idx:weight`, where `stt_idx` is the starting index of the system, `end_idx` is then ending (not including) index of the system, the probabilities of the systems in this block sums up to `weight`, and the relatively probabilities within this block is proportional to the number of batches in the system.' - doc_sys_probs = ( - "A list of float if specified. " - "Should be of the same length as `systems`, " - "specifying the probability of each system." + doc_sys_probs = "A list of float if specified. Should be of the same length as `systems`, specifying the probability of each system." + doc_min_pair_dist = ( + "Minimum pairwise atomic distance threshold in Å. " + "Frames containing any atom pair closer than this distance are excluded " + "from loss computation, as DFT labels for near-collision configurations " + "are often unreliable. Set to 0 to disable (default). " + "Under distributed training (DDP/FSDP), if ALL frames in a batch are " + "filtered out on a given rank, one frame is retained to ensure every " + "rank participates in collective communication (backward all-reduce). " + "Note: enabling this adds an O(N²) distance check per frame in the " + "DataLoader workers (CPU-side), which may slow down training for large " + "systems. To avoid the overhead, consider pre-cleaning the dataset instead." ) args = [ @@ -3796,6 +4570,13 @@ def training_data_args() -> list[ doc=doc_sys_probs, alias=["sys_weights"], ), + Argument( + "min_pair_dist", + float, + optional=True, + default=0.0, + doc=doc_only_pt_supported + doc_min_pair_dist, + ), ] doc_training_data = "Configurations of training data." @@ -3834,11 +4615,7 @@ def validation_data_args() -> list[ - "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\ - "prob_sys_size" : the probability of a system is proportional to the number of batches in the system\n\n\ - "prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;..." : the list of systems is divided into blocks. A block is specified by `stt_idx:end_idx:weight`, where `stt_idx` is the starting index of the system, `end_idx` is then ending (not including) index of the system, the probabilities of the systems in this block sums up to `weight`, and the relatively probabilities within this block is proportional to the number of batches in the system.' - doc_sys_probs = ( - "A list of float if specified. " - "Should be of the same length as `systems`, " - "specifying the probability of each system." - ) + doc_sys_probs = "A list of float if specified. Should be of the same length as `systems`, specifying the probability of each system." doc_numb_btch = "An integer that specifies the number of batches to be sampled for each validation period." args = [ @@ -3889,10 +4666,7 @@ def validation_data_args() -> list[ ), ] - doc_validation_data = ( - "Configurations of validation data. Similar to that of training data, " - "except that a `numb_btch` argument may be configured" - ) + doc_validation_data = "Configurations of validation data. Similar to that of training data, except that a `numb_btch` argument may be configured" return Argument( "validation_data", dict, @@ -4324,10 +5098,7 @@ def validating_args() -> Argument: doc_validation_freq = ( "The frequency, in training steps, of running the full validation pass." ) - doc_save_best = ( - "Whether to save an extra checkpoint when the selected full validation " - "metric reaches a new best value." - ) + doc_save_best = "Whether to save an extra checkpoint when the selected full validation metric reaches a new best value." doc_ema_full_validation = ( "Whether to additionally run the same full validation flow on the " "EMA-smoothed model when `validating.full_validation=true`. This reuses " @@ -4348,10 +5119,7 @@ def validating_args() -> Argument: "`E` and `V` are per-atom metrics; `F` uses component-wise force errors, " "matching `dp test`. The corresponding loss prefactors must not both be 0." ) - doc_full_val_file = ( - "The file for writing full validation results only. This file is " - "independent from `training.disp_file`." - ) + doc_full_val_file = "The file for writing full validation results only. This file is independent from `training.disp_file`." doc_full_val_start = ( "The starting point of full validation. `0` means the feature is active " "from the beginning and will trigger at every `validation_freq` steps. " @@ -4359,6 +5127,16 @@ def validating_args() -> Argument: "`1` disables the feature. A value larger than `1` is interpreted as the " "starting step after integer conversion." ) + doc_compiled_infer = ( + "Whether to route eval-time forwards (including full validation) " + "through the DPA4 / SeZM `torch.compile` path instead of eager. When `true`, " + "this flag is translated into `DP_COMPILE_INFER=1` at trainer " + "startup before any model is constructed, which is the env var SeZM " + "samples inside `SeZMModel.__init__`. A manually exported " + "`DP_COMPILE_INFER` takes precedence over this option. Only " + "meaningful when `model.use_compile=true`; has no effect on models " + "that do not implement the SeZM-style eval compile path." + ) args = [ Argument( "full_validation", @@ -4427,6 +5205,13 @@ def validating_args() -> Argument: extra_check=lambda x: x >= 0, extra_check_errmsg="must be greater than or equal to 0", ), + Argument( + "compiled_infer", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_compiled_infer, + ), ] return Argument( "validating", @@ -4461,14 +5246,12 @@ def validate_full_validation_config( if not is_valid_full_validation_metric(metric): valid_metrics = ", ".join(item.upper() for item in FULL_VALIDATION_METRIC_PREFS) raise ValueError( - "validating.validation_metric must be one of " - f"{valid_metrics}, got {metric!r}." + f"validating.validation_metric must be one of {valid_metrics}, got {metric!r}." ) if multi_task: raise ValueError( - "validating.full_validation only supports single-task energy " - "training; multi-task training is not supported." + "validating.full_validation only supports single-task energy training; multi-task training is not supported." ) loss_params = data.get("loss", {}) @@ -4486,8 +5269,7 @@ def validate_full_validation_config( if not training_params.get("validation_data"): raise ValueError( - "full validation requires `training.validation_data`. It is only " - "supported for single-task energy training." + "full validation requires `training.validation_data`. It is only supported for single-task energy training." ) zero_stage = int(training_params.get("zero_stage", 0)) @@ -4624,6 +5406,27 @@ def gen_json_schema(multi_task: bool = False) -> str: return json.dumps(generate_json_schema(arg)) +def validate_no_multitask_lora(data: dict[str, Any], multi_task: bool = False) -> None: + """Reject ``lora`` in multi-task configs. + + In multi-task training `share_params` aliases descriptor modules across + branches to the same Python object, so a per-branch LoRA injection would + silently collapse into one global adapter. Catch this at config time + rather than letting a confusing shared-adapter model slip through. + """ + if not multi_task: + return + model_dict = (data.get("model") or {}).get("model_dict") or {} + for branch_key, branch in model_dict.items(): + if branch.get("lora") is not None: + raise ValueError( + f"`lora` is only supported in single-task training; found in " + f"branch '{branch_key}' (or inherited from the top-level " + f"`model` cascade). Remove the `lora` entry, or switch to a " + f"single-task input." + ) + + def normalize( data: dict[str, Any], multi_task: bool = False, *, check: bool = True ) -> dict[str, Any]: @@ -4633,6 +5436,7 @@ def normalize( if check: base.check_value(data, strict=True) validate_full_validation_config(data, multi_task=multi_task) + validate_no_multitask_lora(data, multi_task=multi_task) return data diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 3fbb9f636f..38db47809d 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -502,6 +502,26 @@ def get_single_frame(self, index: int, num_worker: int) -> dict: frame_data["fid"] = index + # === Compute min_pair_dist on-the-fly in DataLoader worker === + if "min_pair_dist" in self.data_dict: + from deepmd.dpmodel.utils.dist_check import ( + compute_min_pair_dist_single, + ) + + frame_data["find_min_pair_dist"] = np.float32(1.0) + min_pair_dist = float(self.data_dict["min_pair_dist"].get("default", 0.0)) + frame_data["min_pair_dist"] = np.array( + [ + compute_min_pair_dist_single( + frame_data["coord"], + frame_data.get("box"), + frame_data["type"], + stop_below=min_pair_dist, + ) + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + if self.modifier is not None: with ThreadPoolExecutor(max_workers=num_worker) as executor: # Apply modifier if it exists diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 05e9ae60dc..9d13cb4699 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -690,6 +690,7 @@ def print_summary( nbatches: list[int], sys_probs: list[float], pbc: list[bool], + e_max: list[int] | None = None, ) -> None: """Print summary of systems. @@ -711,6 +712,8 @@ def print_summary( The probabilities pbc : list of bool The periodic boundary conditions + e_max : list of int, optional + The maximal number of valid edges per frame for each system. """ # width 65 sys_width = 42 @@ -718,25 +721,53 @@ def print_summary( f"---Summary of DataSystem: {name.capitalize():13s}-----------------------------------------------" ) log.info("Found %d System(s):", nsystems) - log.info( - "%s %6s %6s %6s %9s %3s", - _format_name_length("system", sys_width), - "natoms", - "bch_sz", - "n_bch", - "prob", - "pbc", - ) - for ii in range(nsystems): + use_e_max = e_max is not None and len(e_max) == nsystems + if use_e_max: + emax_width = max(5, len(str(max(e_max)))) + log.info( + "%s %-6s %-*s %-6s %-6s %-9s %-3s", + _format_name_length("system", sys_width), + "natoms", + emax_width, + "e_max", + "bch_sz", + "n_bch", + "prob", + "pbc", + ) + else: log.info( - "%s %6d %6d %6d %9.3e %3s", - _format_name_length(system_dirs[ii], sys_width), - natoms[ii], - batch_size[ii], - nbatches[ii], - sys_probs[ii], - "T" if pbc[ii] else "F", + "%s %-6s %-6s %-6s %-9s %-3s", + _format_name_length("system", sys_width), + "natoms", + "bch_sz", + "n_bch", + "prob", + "pbc", ) + for ii in range(nsystems): + if use_e_max: + log.info( + "%s %6d %*d %6d %6d %9.3e %3s", + _format_name_length(system_dirs[ii], sys_width), + natoms[ii], + emax_width, + e_max[ii], + batch_size[ii], + nbatches[ii], + sys_probs[ii], + "T" if pbc[ii] else "F", + ) + else: + log.info( + "%s %6d %6d %6d %9.3e %3s", + _format_name_length(system_dirs[ii], sys_width), + natoms[ii], + batch_size[ii], + nbatches[ii], + sys_probs[ii], + "T" if pbc[ii] else "F", + ) log.info( "--------------------------------------------------------------------------------------" ) diff --git a/doc/model/dpa4.md b/doc/model/dpa4.md new file mode 100644 index 0000000000..8ca5cc1123 --- /dev/null +++ b/doc/model/dpa4.md @@ -0,0 +1,418 @@ +# Descriptor DPA4 {{ pytorch_icon }} + +:::{note} +**Supported backends**: PyTorch {{ pytorch_icon }} +::: + +DPA4 is the DPA-series implementation of SeZM, the Smooth Equivariant +Zone-bridging Model. For new input files, set `model.type: "dpa4"` and +`descriptor.type: "dpa4"`. + +Training example: `examples/water/dpa4/input.json`. + +## Overview + +DPA4 is an SO(3)-equivariant message-passing model for conservative +interatomic potentials. It predicts atomic energies and obtains forces +and virials by differentiating the energy, following the same +conservative formulation used by standard DeePMD energy models: + +```math +\mathbf{F}_i = -\frac{\partial E}{\partial \mathbf{r}_i}. +``` + +The model retains vector and higher-order angular information during +descriptor construction. Only the final descriptor passed to the fitting +network is scalar. This separates the geometric representation from the +energy mapping: equivariant layers encode local geometry, and the fitting +network maps the resulting scalar features to atomic energies. + +## Descriptor construction + +For each frame, DPA4 first builds a local neighbor graph within cutoff +radius `rcut`. Each edge stores the displacement vector, a smooth cutoff +weight, radial basis features, and a rotation from the global coordinate +frame to an edge-aligned local frame. + +One DPA4 interaction block consists of the following operations: + +1. Gather source-atom equivariant features on each edge. +1. Rotate them into the edge-local frame. +1. Apply SO(2)-equivariant convolution on the retained angular orders. +1. Rotate messages back to the global frame. +1. Aggregate messages at destination atoms with smooth envelope weights or + attention weights. +1. Update atom features with an equivariant feed-forward block. + +After the last block, DPA4 keeps the `l = 0` scalar channels: + +```math +\mathcal{D}_i = \mathrm{Scalar}\left(\mathbf{h}_i^{(L)}\right), +``` + +where $\mathbf{h}_i^{(L)}$ is the final equivariant feature of atom `i`. + +## Angular representation + +DPA4 stores intermediate features as SO(3)-equivariant coefficients. A +feature block with maximum degree `lmax` contains all degrees +`l = 0, ..., lmax`, and each degree has `2l + 1` angular components. + +DPA4 avoids the most expensive part of a full SO(3) operation by working +in a local frame on each edge. In that frame, rotations around the edge +axis become SO(2) operations. The descriptor retains only orders +`|m| <= mmax` inside the SO(2) convolution, reducing angular cost while +preserving the required rotation behavior. + +Two schedules control the angular width: + +- `l_schedule` sets the SO(3) degree used by each block. A schedule such as + `[3, 3, 2]` uses higher degrees in early blocks and truncates them in + later blocks. +- `mmax` or `m_schedule` sets how many SO(2) orders are retained in the + edge-local convolution. + +The angular schedule is one of the primary accuracy-cost controls in +DPA4. Larger angular spaces can represent more complex local chemistry, +but the cost grows quickly with `lmax`. For many systems, a +non-increasing `l_schedule` provides a practical compromise. + +## Radial basis and smooth cutoff + +Every edge uses a radial basis multiplied by a smooth envelope. The +default basis is Bessel-like, and a Gaussian basis is also available. The +cutoff envelope is constructed so that its value and first three +derivatives vanish at `rcut`. This smoothness is important for molecular +dynamics because nonsmooth descriptor cutoffs would be inherited by force +derivatives. + +DPA4 uses two envelope exponents through `env_exp`: + +- the first exponent controls the radial basis envelope, +- the second controls message-passing edge weights. + +Increasing the exponent keeps the envelope closer to one for more of the +cutoff range before it drops near `rcut`. + +## Attention and focus streams + +DPA4 can aggregate edge messages either by envelope-weighted scatter or by +attention. When attention is enabled, the cutoff envelope also +participates in the softmax normalization. Edges near the cutoff therefore +fade out in both the numerator and the denominator, avoiding nonsmooth +contributions from the normalization term. + +The SO(2) convolution can also use multiple focus streams. These streams +process the same edge geometry in parallel and are then combined through +scalar weights. This design is not a sparse mixture of experts: all focus +streams are evaluated before soft reweighting. The additional capacity +helps the convolution distinguish different local patterns while +preserving equivariance. + +## Environment-seeded initial features + +When `use_env_seed` is enabled, DPA4 builds an initial scalar signal from a +DeePMD-style local environment matrix. The matrix uses radial information +and normalized directions, then produces FiLM-like scale and shift values +for the first scalar features. + +This provides a simple geometric prior before the equivariant +message-passing blocks. It can be especially useful when the number of +blocks is small. + +## Zone bridging and ZBL + +DPA4 includes an optional short-range bridge for analytical repulsion. The +typical use case is ZBL: + +```math +E_i = E_i^\mathrm{DPA4} + E_i^\mathrm{ZBL}. +``` + +The purpose of zone bridging is to combine the analytical short-range +repulsion with the learned model while preventing uncontrolled learned +forces in the same protected region. + +Zone bridging has two pieces: + +1. Distances below `bridging_r_inner` are clamped before they enter the + descriptor. Between `bridging_r_inner` and `bridging_r_outer`, a smooth + polynomial transitions back to the true distance. +1. A source gate suppresses message propagation from atoms involved in + frozen short-range pairs. This blocks multi-hop leakage, where a third + atom could otherwise carry information about the frozen pair back into + the learned energy. + +This gives a controlled decomposition in the protected region: + +```math +E_\mathrm{total}(r) = E_\mathrm{ZBL}(r) + E_\mathrm{model}(\tilde r), +``` + +where $r$ is the true distance and $\tilde r$ is the clamped distance seen +by the descriptor. + +Enable zone bridging with: + +```json +{ + "model": { + "bridging_method": "zbl", + "bridging_r_inner": 0.8, + "bridging_r_outer": 1.2 + } +} +``` + +## Fitting network + +DPA4 uses `dpa4_ener` as the energy fitting network name in input files. +It is a GLU-based fitting network that maps scalar descriptors to atomic +energies. + +The fitting network uses the same common keys as DeePMD's standard energy +fitting network: + +- `neuron` +- `activation_function` +- `precision` +- `seed` +- `numb_fparam` +- `numb_aparam` + +The hidden layers use GLU-style transformations. If `neuron` is `[0]`, +the fitting network uses a direct projection from descriptor channels to +atomic energy. This compact setting is useful for small examples and smoke +tests. + +For shared-fitting multitask training, DPA4 supports case embeddings. With +`case_film_embd: true`, the case vector modulates the fitting network +instead of being concatenated directly to the descriptor. This keeps the +descriptor case-independent while allowing the energy map to depend on the +task branch. + +## Configuration + +The minimal structure is: + +```json +{ + "model": { + "type": "dpa4", + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "channels": 64, + "n_radial": 16, + "lmax": 3, + "mmax": 1, + "n_blocks": 3, + "precision": "float32" + }, + "fitting_net": { + "type": "dpa4_ener", + "neuron": [ + 0 + ], + "precision": "float32" + } + } +} +``` + +### Common descriptor parameters + +| Parameter | Default | Meaning | +| -------------- | ----------- | ----------------------------------------------------------------------------- | +| `sel` | Required | Maximum selected neighbors. It may be an integer, a per-type list, or `auto`. | +| `rcut` | `6.0` | Neighbor cutoff radius. | +| `env_exp` | `[7, 5]` | Envelope exponents for radial basis and message weights. | +| `channels` | `64` | Feature width per angular coefficient. | +| `basis_type` | `"bessel"` | Radial basis family. `"gaussian"` is also supported. | +| `n_radial` | `16` | Number of radial basis functions. | +| `radial_mlp` | `[0]` | Hidden sizes for the radial network. Use `0` as a placeholder for `channels`. | +| `lmax` | `3` | Maximum SO(3) degree when `l_schedule` is not set. | +| `l_schedule` | `None` | Per-block degree schedule. Non-increasing schedules reduce later-block cost. | +| `mmax` | `1` | Maximum SO(2) order when `m_schedule` is not set. | +| `m_schedule` | `None` | Per-block SO(2) order schedule. | +| `n_blocks` | `3` | Number of blocks when `l_schedule` is not set. | +| `n_focus` | `1` | Number of focus streams inside SO(2) convolution. | +| `n_atten_head` | `1` | Number of attention heads. Set to `0` for plain scatter aggregation. | +| `so2_layers` | `4` | Number of SO2Linear layers inside one SO(2) convolution. | +| `ffn_neurons` | `0` | Hidden width of the equivariant FFN. `0` enables automatic width selection. | +| `precision` | `"float32"` | Working precision of descriptor blocks. | + +### Common model parameters + +| Parameter | Default | Meaning | +| -------------------------- | -------- | ------------------------------------------------- | +| `model.type` | Required | Use `"dpa4"`. | +| `model.use_compile` | `false` | Enable the PyTorch `torch.compile` training path. | +| `model.enable_tf32` | `true` | Allow TF32 matmul when compile is used. | +| `model.bridging_method` | `"none"` | Use `"zbl"` to enable ZBL zone bridging. | +| `model.bridging_r_inner` | `0.8` | Inner radius of the bridging window. | +| `model.bridging_r_outer` | `1.2` | Outer radius of the bridging window. | +| `model.pair_exclude_types` | `[]` | Type pairs excluded from descriptor edges. | +| `model.lora` | `null` | Optional LoRA fine-tuning configuration. | + +## Training modes + +The recommended training objective is the standard conservative energy +loss: + +```json +{ + "loss": { + "type": "ener" + } +} +``` + +In this mode, the model predicts energies, and forces are computed by +autograd. + +DPA4 also has an experimental direct-force denoising mode selected by: + +```json +{ + "loss": { + "type": "dens" + } +} +``` + +Use `dens` only when the direct-force denoising head is required. It is +not the default training path. + +## Spin + +DPA4 supports the DeePMD-kit spin convention in the PyTorch backend. Keep +the DPA4 type string and add the standard `model.spin` block: + +```json +{ + "model": { + "type": "dpa4", + "type_map": [ + "Ni", + "O" + ], + "spin": { + "use_spin": [ + true, + false + ], + "virtual_scale": [ + 0.314 + ] + }, + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0 + } + } +} +``` + +The spin path supports the conservative `ener_spin` loss. The direct-force +denoising mode is not used together with spin. + +## `torch.compile` + +DPA4 can train through an experimental `torch.compile` path: + +```json +{ + "model": { + "use_compile": true + } +} +``` + +This path is useful for force-loss training because the model first +differentiates energy to obtain forces and then differentiates the force +loss with respect to model parameters. The training graph therefore +contains second-order coordinate derivatives. DPA4 traces this graph before +passing it to Inductor. + +This is an experimental feature. It requires PyTorch >= 2.11. On NVIDIA +GPUs, CUDA must be >= 12.6. Apple Silicon Macs are also supported. It has +been tested with Python 3.13. + +For evaluation-time compile during validation, set: + +```json +{ + "validating": { + "compiled_infer": true + } +} +``` + +You can also set `DP_COMPILE_INFER=1` in the environment before training. + +## LoRA fine-tuning + +DPA4 supports LoRA adapters on its SO(3) and SO(2) linear layers. A typical +input block is: + +```json +{ + "model": { + "type": "dpa4", + "descriptor": { + "type": "dpa4" + }, + "lora": { + "rank": 16, + "alpha": 16.0 + } + } +} +``` + +Then fine-tune from a checkpoint: + +```bash +dp --pt train lora_ft.json --finetune pretrained.pt +``` + +See `examples/water/dpa4/lora_ft.json` for a complete example. + +## Export + +DPA4 checkpoints use the PyTorch `.pt2` export path. Run the standard +freeze command: + +```bash +dp --pt freeze -c model.ckpt -o frozen_model +``` + +The PyTorch backend detects DPA4 and writes `frozen_model.pt2`. Use this +file with LAMMPS: + +```lammps +pair_style deepmd frozen_model.pt2 +pair_coeff * * O H +``` + +A small LAMMPS example is in `examples/water/dpa4/lmp/`. + +## Data format + +DPA4 uses the [standard DeePMD-kit data format](../data/system.md). Keep +the `type_map` order consistent across the dataset, input file, and any +downstream `pair_coeff` mapping. + +## Limitations + +- DPA4 is currently implemented for the PyTorch backend. +- Model compression is not supported. +- Export uses `.pt2`; the ordinary TorchScript freeze path is not used for + DPA4 checkpoints. diff --git a/doc/model/index.rst b/doc/model/index.rst index a173732bbc..8bd5aada64 100644 --- a/doc/model/index.rst +++ b/doc/model/index.rst @@ -11,6 +11,7 @@ Model train-se-atten dpa2 dpa3 + dpa4 train-hybrid sel train-energy diff --git a/examples/water/dpa4/README.md b/examples/water/dpa4/README.md new file mode 100644 index 0000000000..9da48db452 --- /dev/null +++ b/examples/water/dpa4/README.md @@ -0,0 +1,13 @@ +# Input for DPA4 / SeZM: Smooth equivariant Zone-bridging Model (PyTorch) + +This directory stores a minimal configuration for training DPA4 on the water +example dataset. `model.type: dpa4` and `descriptor.type: dpa4` are the +preferred DPA-series names; `SeZM` and `sezm` are equivalent compatibility +aliases for the same PyTorch implementation. + +Run: + +```bash +cd examples/water/dpa4 +dp --pt train input.json +``` diff --git a/examples/water/dpa4/input-spin.json b/examples/water/dpa4/input-spin.json new file mode 100644 index 0000000000..db126be9d2 --- /dev/null +++ b/examples/water/dpa4/input-spin.json @@ -0,0 +1,160 @@ +{ + "model": { + "type": "dpa4", + "type_map": [ + "Ni", + "O" + ], + "spin": { + "use_spin": [ + true, + false + ], + "virtual_scale": [ + 0.314 + ] + }, + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "env_exp": [ + 7, + 5 + ], + "channels": 64, + "n_radial": 16, + "radial_mlp": [ + 0 + ], + "use_env_seed": true, + "random_gamma": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 3, + "so2_layers": 4, + "so2_norm": false, + "so2_attn_res": "none", + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 1, + "focus_dim": 0, + "n_atten_head": 1, + "atten_f_mix": false, + "atten_v_proj": false, + "atten_o_proj": false, + "ffn_neurons": 0, + "grid_mlp": false, + "ffn_blocks": 1, + "sandwich_norm": [ + false, + true, + true, + false + ], + "mlp_bias": false, + "layer_scale": false, + "full_attn_res": "none", + "block_attn_res": "none", + "s2_activation": [ + false, + true + ], + "lebedev_quadrature": true, + "activation_function": "silu", + "glu_activation": true, + "use_amp": true, + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "fitting_net": { + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "use_compile": false, + "enable_tf32": true, + "_comment": "that's all" + }, + "learning_rate": { + "type": "wsd", + "start_lr": 4.5e-4, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss": { + "type": "ener_spin", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_fr": 1000, + "limit_pref_fr": 1, + "start_pref_fm": 1000, + "limit_pref_fm": 1, + "_comment": " that's all" + }, + "optimizer": { + "type": "HybridMuon", + "muon_mode": "slice", + "magma_muon": true, + "lr_adjust": 0.0, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa4_spin.hdf5", + "training_data": { + "systems": [ + "../../spin/data_reformat/data_0", + "../../spin/data_reformat/data_1" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../../spin/data_reformat/data_2" + ], + "batch_size": 1, + "numb_btch": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "gradient_max_norm": 5.0, + "save_freq": 100, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 100, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 0, + "seed": 7, + "_comment": "that's all" + }, + "validating": { + "full_validation": false, + "ema_full_validation": false, + "validation_freq": 100, + "save_best": false, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "compiled_infer": false, + "_comment": "full validation currently rejects spin-energy training" + } +} diff --git a/examples/water/dpa4/input-zbl.json b/examples/water/dpa4/input-zbl.json new file mode 100644 index 0000000000..ee07c81901 --- /dev/null +++ b/examples/water/dpa4/input-zbl.json @@ -0,0 +1,158 @@ +{ + "model": { + "type": "dpa4", + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "env_exp": [ + 7, + 5 + ], + "channels": 64, + "n_radial": 16, + "radial_mlp": [ + 0 + ], + "use_env_seed": true, + "random_gamma": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 3, + "so2_layers": 4, + "so2_norm": false, + "so2_attn_res": "none", + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 1, + "focus_dim": 0, + "n_atten_head": 1, + "atten_f_mix": false, + "atten_v_proj": false, + "atten_o_proj": false, + "ffn_neurons": 0, + "grid_mlp": false, + "ffn_blocks": 1, + "sandwich_norm": [ + false, + true, + true, + false + ], + "mlp_bias": false, + "layer_scale": false, + "full_attn_res": "none", + "block_attn_res": "none", + "s2_activation": [ + false, + true + ], + "lebedev_quadrature": true, + "activation_function": "silu", + "glu_activation": true, + "use_amp": true, + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "fitting_net": { + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "use_compile": false, + "enable_tf32": true, + "bridging_method": "zbl", + "bridging_r_inner": 0.8, + "bridging_r_outer": 1.2, + "_comment": "that's all" + }, + "learning_rate": { + "type": "wsd", + "start_lr": 4.5e-4, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 5, + "limit_pref_v": 5 + }, + "optimizer": { + "type": "HybridMuon", + "muon_mode": "slice", + "magma_muon": true, + "lr_adjust": 0.0, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa4.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "min_pair_dist": 0.8, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "gradient_max_norm": 5.0, + "save_freq": 100, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 100, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 0, + "seed": 7, + "_comment": "that's all" + }, + "validating": { + "full_validation": false, + "ema_full_validation": false, + "validation_freq": 100, + "save_best": true, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.5, + "_comment": "that's all" + } +} diff --git a/examples/water/dpa4/input.json b/examples/water/dpa4/input.json new file mode 100644 index 0000000000..e6b48275ad --- /dev/null +++ b/examples/water/dpa4/input.json @@ -0,0 +1,156 @@ +{ + "model": { + "type": "dpa4", + "type_map": [ + "O", + "H" + ], + "pair_exclude_types": [], + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "env_exp": [ + 7, + 5 + ], + "channels": 64, + "n_radial": 16, + "radial_mlp": [ + 0 + ], + "use_env_seed": true, + "random_gamma": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 3, + "so2_layers": 4, + "so2_norm": false, + "so2_attn_res": "none", + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 1, + "focus_dim": 0, + "n_atten_head": 1, + "atten_f_mix": false, + "atten_v_proj": false, + "atten_o_proj": false, + "ffn_neurons": 0, + "grid_mlp": false, + "ffn_blocks": 1, + "sandwich_norm": [ + false, + true, + true, + false + ], + "mlp_bias": false, + "layer_scale": false, + "full_attn_res": "none", + "block_attn_res": "none", + "s2_activation": [ + false, + true + ], + "lebedev_quadrature": true, + "activation_function": "silu", + "glu_activation": true, + "use_amp": true, + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "fitting_net": { + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "use_compile": false, + "enable_tf32": true, + "_comment": "that's all" + }, + "learning_rate": { + "type": "wsd", + "start_lr": 4.5e-4, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 5, + "limit_pref_v": 5 + }, + "optimizer": { + "type": "HybridMuon", + "muon_mode": "slice", + "magma_muon": true, + "lr_adjust": 0.0, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa4.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "gradient_max_norm": 5.0, + "save_freq": 100, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 100, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 0, + "seed": 7, + "_comment": "that's all" + }, + "validating": { + "full_validation": true, + "ema_full_validation": true, + "validation_freq": 100, + "save_best": true, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 100, + "compiled_infer": false, + "_comment": "that's all" + } +} diff --git a/examples/water/dpa4/input_dens.json b/examples/water/dpa4/input_dens.json new file mode 100644 index 0000000000..0f46c203da --- /dev/null +++ b/examples/water/dpa4/input_dens.json @@ -0,0 +1,157 @@ +{ + "model": { + "type": "dpa4", + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "env_exp": [ + 7, + 5 + ], + "channels": 64, + "n_radial": 16, + "radial_mlp": [ + 0 + ], + "use_env_seed": true, + "random_gamma": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 3, + "so2_layers": 4, + "so2_norm": false, + "so2_attn_res": "none", + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 1, + "focus_dim": 0, + "n_atten_head": 1, + "atten_f_mix": false, + "atten_v_proj": false, + "atten_o_proj": false, + "ffn_neurons": 0, + "grid_mlp": false, + "ffn_blocks": 1, + "sandwich_norm": [ + false, + true, + true, + false + ], + "mlp_bias": false, + "layer_scale": false, + "full_attn_res": "none", + "block_attn_res": "none", + "s2_activation": [ + false, + true + ], + "lebedev_quadrature": true, + "activation_function": "silu", + "glu_activation": true, + "use_amp": true, + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "fitting_net": { + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "use_compile": false, + "enable_tf32": true, + "_comment": "that's all" + }, + "learning_rate": { + "type": "wsd", + "start_lr": 5e-4, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss": { + "type": "dens", + "loss_func": "mae", + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "dens_prob": 0.5, + "dens_std": 0.025, + "dens_corrupt_ratio": 0.5, + "dens_denoising_pos_coefficient": 10.0 + }, + "optimizer": { + "type": "HybridMuon", + "muon_mode": "slice", + "magma_muon": true, + "lr_adjust": 0.0, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa4.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "min_pair_dist": 1.0, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "gradient_max_norm": 5.0, + "save_freq": 100, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 100, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 0, + "seed": 7, + "_comment": "that's all" + }, + "validating": { + "full_validation": false, + "ema_full_validation": false, + "validation_freq": 100, + "save_best": true, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + "_comment": "that's all" + }, + "_comment": "that's all" +} diff --git a/examples/water/dpa4/input_multitask.json b/examples/water/dpa4/input_multitask.json new file mode 100644 index 0000000000..640ea93cca --- /dev/null +++ b/examples/water/dpa4/input_multitask.json @@ -0,0 +1,216 @@ +{ + "_comment": "DPA4 / SeZM multi-task example with a shared descriptor and per-task fitting nets.", + "model": { + "use_compile": false, + "enable_tf32": true, + "shared_dict": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "env_exp": [ + 7, + 5 + ], + "channels": 64, + "n_radial": 16, + "radial_mlp": [ + 0 + ], + "use_env_seed": true, + "random_gamma": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 3, + "so2_layers": 4, + "so2_norm": false, + "so2_attn_res": "none", + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 1, + "focus_dim": 0, + "n_atten_head": 1, + "atten_f_mix": false, + "atten_v_proj": false, + "atten_o_proj": false, + "ffn_neurons": 0, + "grid_mlp": false, + "ffn_blocks": 1, + "sandwich_norm": [ + false, + true, + true, + false + ], + "mlp_bias": false, + "layer_scale": false, + "full_attn_res": "none", + "block_attn_res": "none", + "s2_activation": [ + false, + true + ], + "lebedev_quadrature": true, + "activation_function": "silu", + "glu_activation": true, + "use_amp": true, + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "_comment": "that's all" + }, + "model_dict": { + "water_1": { + "type": "dpa4", + "type_map": "type_map", + "descriptor": "descriptor", + "fitting_net": { + "type": "dpa4_ener", + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "model_branch_alias": [ + "Default", + "Water" + ], + "info": { + "description": "Water branch with shared DPA4 / SeZM descriptor and an independent fitting net" + }, + "_comment": "that's all" + }, + "water_2": { + "type": "dpa4", + "type_map": "type_map", + "descriptor": "descriptor", + "fitting_net": { + "type": "dpa4_ener", + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "model_branch_alias": [ + "Water2" + ], + "info": { + "description": "Second water branch with shared DPA4 / SeZM descriptor and an independent fitting net" + }, + "_comment": "that's all" + } + } + }, + "learning_rate": { + "type": "wsd", + "start_lr": 4.5e-4, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss_dict": { + "water_1": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 5, + "limit_pref_v": 5 + }, + "water_2": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 5, + "limit_pref_v": 5 + } + }, + "optimizer": { + "type": "HybridMuon", + "muon_mode": "slice", + "magma_muon": true, + "lr_adjust": 0.0, + "weight_decay": 0.001 + }, + "training": { + "model_prob": { + "water_1": 0.5, + "water_2": 0.5 + }, + "data_dict": { + "water_1": { + "stat_file": "./dpa4_water_1.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 1, + "_comment": "that's all" + } + }, + "water_2": { + "stat_file": "./dpa4_water_2.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + } + } + }, + "numb_steps": 1000000, + "gradient_max_norm": 5.0, + "save_freq": 100, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 100, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 0, + "seed": 7, + "_comment": "that's all" + } +} diff --git a/examples/water/dpa4/input_multitask_sharefit-zbl.json b/examples/water/dpa4/input_multitask_sharefit-zbl.json new file mode 100644 index 0000000000..9893e8aaa7 --- /dev/null +++ b/examples/water/dpa4/input_multitask_sharefit-zbl.json @@ -0,0 +1,214 @@ +{ + "_comment": "DPA4 / SeZM multi-task example with shared descriptor, shared case-embedded fitting net, and ZBL bridging.", + "model": { + "use_compile": false, + "enable_tf32": true, + "bridging_method": "zbl", + "bridging_r_inner": 0.8, + "bridging_r_outer": 1.2, + "shared_dict": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "env_exp": [ + 7, + 5 + ], + "channels": 64, + "n_radial": 16, + "radial_mlp": [ + 0 + ], + "use_env_seed": true, + "random_gamma": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 3, + "so2_layers": 4, + "so2_norm": false, + "so2_attn_res": "none", + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 1, + "focus_dim": 0, + "n_atten_head": 1, + "atten_f_mix": false, + "atten_v_proj": false, + "atten_o_proj": false, + "ffn_neurons": 0, + "grid_mlp": false, + "ffn_blocks": 1, + "sandwich_norm": [ + false, + true, + true, + false + ], + "mlp_bias": false, + "layer_scale": false, + "full_attn_res": "none", + "block_attn_res": "none", + "s2_activation": [ + false, + true + ], + "lebedev_quadrature": true, + "activation_function": "silu", + "glu_activation": true, + "use_amp": true, + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "shared_fit_with_id": { + "type": "dpa4_ener", + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "dim_case_embd": 2, + "seed": 42, + "_comment": "that's all" + }, + "_comment": "that's all" + }, + "model_dict": { + "water_1": { + "type": "dpa4", + "type_map": "type_map", + "descriptor": "descriptor", + "fitting_net": "shared_fit_with_id", + "model_branch_alias": [ + "Default", + "Water" + ], + "info": { + "description": "Water branch with shared DPA4 / SeZM descriptor, case-embedded shared fitting net, and ZBL bridging" + }, + "_comment": "that's all" + }, + "water_2": { + "type": "dpa4", + "type_map": "type_map", + "descriptor": "descriptor", + "fitting_net": "shared_fit_with_id", + "model_branch_alias": [ + "Water2" + ], + "info": { + "description": "Second water branch with shared DPA4 / SeZM descriptor, case-embedded shared fitting net, and ZBL bridging" + }, + "_comment": "that's all" + } + } + }, + "learning_rate": { + "type": "wsd", + "start_lr": 4.5e-4, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss_dict": { + "water_1": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 5, + "limit_pref_v": 5 + }, + "water_2": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 5, + "limit_pref_v": 5 + } + }, + "optimizer": { + "type": "HybridMuon", + "muon_mode": "slice", + "magma_muon": true, + "lr_adjust": 0.0, + "weight_decay": 0.001 + }, + "training": { + "model_prob": { + "water_1": 0.5, + "water_2": 0.5 + }, + "data_dict": { + "water_1": { + "stat_file": "./dpa4_water_1.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "min_pair_dist": 0.8, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 1, + "_comment": "that's all" + } + }, + "water_2": { + "stat_file": "./dpa4_water_2.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "min_pair_dist": 0.8, + "_comment": "that's all" + } + } + }, + "numb_steps": 1000000, + "gradient_max_norm": 5.0, + "save_freq": 100, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 100, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 0, + "seed": 7, + "_comment": "that's all" + } +} diff --git a/examples/water/dpa4/input_multitask_sharefit.json b/examples/water/dpa4/input_multitask_sharefit.json new file mode 100644 index 0000000000..6a9b433e30 --- /dev/null +++ b/examples/water/dpa4/input_multitask_sharefit.json @@ -0,0 +1,210 @@ +{ + "_comment": "DPA4 / SeZM multi-task example with a shared descriptor AND shared fitting net (case-embedded).", + "model": { + "use_compile": false, + "enable_tf32": true, + "shared_dict": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "env_exp": [ + 7, + 5 + ], + "channels": 64, + "n_radial": 16, + "radial_mlp": [ + 0 + ], + "use_env_seed": true, + "random_gamma": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 3, + "so2_layers": 4, + "so2_norm": false, + "so2_attn_res": "none", + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 1, + "focus_dim": 0, + "n_atten_head": 1, + "atten_f_mix": false, + "atten_v_proj": false, + "atten_o_proj": false, + "ffn_neurons": 0, + "grid_mlp": false, + "ffn_blocks": 1, + "sandwich_norm": [ + false, + true, + true, + false + ], + "mlp_bias": false, + "layer_scale": false, + "full_attn_res": "none", + "block_attn_res": "none", + "s2_activation": [ + false, + true + ], + "lebedev_quadrature": true, + "activation_function": "silu", + "glu_activation": true, + "use_amp": true, + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "shared_fit_with_id": { + "type": "dpa4_ener", + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "dim_case_embd": 2, + "case_film_embd": true, + "seed": 42, + "_comment": "that's all" + }, + "_comment": "that's all" + }, + "model_dict": { + "water_1": { + "type": "dpa4", + "type_map": "type_map", + "descriptor": "descriptor", + "fitting_net": "shared_fit_with_id", + "model_branch_alias": [ + "Default", + "Water" + ], + "info": { + "description": "Water branch with shared DPA4 / SeZM descriptor and case-embedded shared fitting net" + }, + "_comment": "that's all" + }, + "water_2": { + "type": "dpa4", + "type_map": "type_map", + "descriptor": "descriptor", + "fitting_net": "shared_fit_with_id", + "model_branch_alias": [ + "Water2" + ], + "info": { + "description": "Second water branch with shared DPA4 / SeZM descriptor and case-embedded shared fitting net" + }, + "_comment": "that's all" + } + } + }, + "learning_rate": { + "type": "wsd", + "start_lr": 4.5e-4, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss_dict": { + "water_1": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 5, + "limit_pref_v": 5 + }, + "water_2": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 5, + "limit_pref_v": 5 + } + }, + "optimizer": { + "type": "HybridMuon", + "muon_mode": "slice", + "magma_muon": true, + "lr_adjust": 0.0, + "weight_decay": 0.001 + }, + "training": { + "model_prob": { + "water_1": 0.5, + "water_2": 0.5 + }, + "data_dict": { + "water_1": { + "stat_file": "./dpa4_water_1.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 1, + "_comment": "that's all" + } + }, + "water_2": { + "stat_file": "./dpa4_water_2.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + } + } + }, + "numb_steps": 1000000, + "gradient_max_norm": 5.0, + "save_freq": 100, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 100, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 0, + "seed": 7, + "_comment": "that's all" + } +} diff --git a/examples/water/dpa4/lmp/README.md b/examples/water/dpa4/lmp/README.md new file mode 100644 index 0000000000..1c28221e34 --- /dev/null +++ b/examples/water/dpa4/lmp/README.md @@ -0,0 +1,67 @@ +# LAMMPS example for DPA4 / SeZM + +This directory contains a minimal end-to-end pipeline for running a +DPA4 model in LAMMPS via `pair_style deepmd`. DPA4 and SeZM refer to the +same PyTorch implementation; DPA4 is the DPA-series user-facing name. + +## Files + +| File | Description | +| --------------- | -------------------------------------------------------------------------------------------------------------------------------- | +| `input.json` | Training configuration: tiny DPA4 / SeZM (`channels=16`, two blocks, fp32), 500 Adam steps on `examples/water/data/data_{0..3}`. | +| `pretrained.pt` | Shipped checkpoint for the LAMMPS smoke test. | +| `in.lammps` | 20-step NVT run at 330 K on 192 water molecules. | +| `water.lmp` | LAMMPS data file (192-atom liquid water cell). | + +The frozen `.pt2` archive is not included because AOTInductor packages +are target-specific: they depend on the host's CPU/GPU, GPU compute +capability, and libtorch version. Freeze locally before running. + +## Usage + +Optionally retrain: + +```bash +dp --pt train input.json --skip-neighbor-stat +``` + +Freeze the checkpoint (the pt backend detects DPA4 / SeZM and writes a +`.pt2` archive automatically): + +```bash +dp --pt freeze -c model.ckpt.pt -o frozen_model +``` + +To use the shipped smoke-test checkpoint instead of retraining, replace +`model.ckpt.pt` with `pretrained.pt`. + +Run the MD: + +```bash +lmp -in in.lammps +``` + +Expected LAMMPS output: + +``` +load model from: frozen_model.pt2 to gpu 0 + rcut in model: 6 + ntypes in model: 2 +Step PotEng KinEng TotEng Temp + 0 -29941.035 8.147 -29932.89 330.00 + 10 -29940.605 7.771 -29932.83 314.76 + 20 -29940.399 7.564 -29932.83 306.39 +``` + +## Notes + +- `pair_coeff * * O H` pins LAMMPS atom types 1 and 2 to `type_map` + entries `"O"` and `"H"` respectively. When the element names are + omitted, the mapping falls back to the `type_map` order stored in + the `.pt2` metadata. +- `atom_modify map yes` keeps the ghost / periodic-image to local-atom + mapping explicit for `.pt2` graph inference. Single-rank LAMMPS runs + can also synthesize this mapping from atom tags when no atom map exists. +- The 500-step `pretrained.pt` is intended as a smoke test, not a + physically accurate water potential. Retrain with a longer schedule + for production. diff --git a/examples/water/dpa4/lmp/in.lammps b/examples/water/dpa4/lmp/in.lammps new file mode 100644 index 0000000000..10ff43a689 --- /dev/null +++ b/examples/water/dpa4/lmp/in.lammps @@ -0,0 +1,38 @@ +# DPA4 / SeZM bulk water: short MD smoke run backed by an AOTInductor .pt2 archive. +# +# The .pt2 file is produced by +# dp --pt freeze -c -o frozen_model +# against a DPA4 / SeZM training checkpoint. The freeze CLI detects DPA4 / SeZM +# automatically, rewrites the output suffix to .pt2 and emits a package +# whose I/O is fp64 — matching the DeepPotPTExpt C++ contract. + +units metal +boundary p p p +atom_style atomic +atom_modify map yes + +neighbor 2.0 bin +neigh_modify every 10 delay 0 check no + +read_data water.lmp +mass 1 16 +mass 2 2 + +# AOTInductor models are dispatched by suffix: DeepPotPTExpt (the PyTorch +# "exportable" backend) owns every .pt2 file via source/api_cc/src/DeepPotPTExpt.cc. +# No plugin, no extra flags are required. +pair_style deepmd frozen_model.pt2 +# The training type_map ordering is (O, H); element names on pair_coeff +# map LAMMPS atom types 1/2 to those entries explicitly. +pair_coeff * * O H + +velocity all create 330.0 23456789 + +fix 1 all nvt temp 330.0 330.0 0.5 +timestep 0.0005 +thermo_style custom step pe ke etotal temp press vol +thermo 10 +dump 1 all custom 10 water.dump id type x y z + +# Short smoke run; scale up for production MD. +run 20 diff --git a/examples/water/dpa4/lmp/input.json b/examples/water/dpa4/lmp/input.json new file mode 100644 index 0000000000..9e611fb706 --- /dev/null +++ b/examples/water/dpa4/lmp/input.json @@ -0,0 +1,133 @@ +{ + "_comment": "Tiny DPA4 / SeZM water demo. Trained for ~500 steps; the resulting checkpoint is shipped with this example solely so newcomers have something to freeze -> run in LAMMPS out of the box. It is NOT a physically accurate force field.", + "model": { + "type": "dpa4", + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "env_exp": [ + 7, + 5 + ], + "channels": 16, + "n_radial": 6, + "radial_mlp": [ + 0 + ], + "use_env_seed": true, + "random_gamma": true, + "l_schedule": [ + 1, + 1 + ], + "mmax": 1, + "so2_norm": false, + "so2_layers": 2, + "so2_attn_res": "none", + "n_focus": 1, + "focus_dim": 0, + "n_atten_head": 1, + "ffn_neurons": 0, + "grid_mlp": false, + "ffn_blocks": 1, + "sandwich_norm": [ + true, + false, + true, + false + ], + "mlp_bias": false, + "layer_scale": false, + "full_attn_res": "none", + "block_attn_res": "none", + "s2_activation": [ + false, + true + ], + "activation_function": "silu", + "glu_activation": true, + "use_amp": false, + "precision": "float32", + "seed": 42 + }, + "fitting_net": { + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "seed": 42 + }, + "use_compile": false, + "enable_tf32": false, + "bridging_method": "none", + "bridging_r_inner": 0.8, + "bridging_r_outer": 1.2 + }, + "learning_rate": { + "type": "wsd", + "start_lr": 5e-4, + "stop_lr": 1e-6, + "warmup_steps": 50, + "warmup_start_factor": 0.2, + "decay_phase_ratio": 0.65, + "decay_type": "cosine" + }, + "loss": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 0, + "limit_pref_v": 0 + }, + "optimizer": { + "type": "HybridMuon", + "muon_mode": "slice", + "enable_gram": true, + "magma_muon": true, + "lr_adjust": 0.0, + "weight_decay": 0.001 + }, + "training": { + "training_data": { + "systems": [ + "../../data/data_0", + "../../data/data_1", + "../../data/data_2" + ], + "batch_size": 1 + }, + "validation_data": { + "systems": [ + "../../data/data_3" + ], + "batch_size": 1, + "numb_btch": 1 + }, + "numb_steps": 500, + "gradient_max_norm": 5.0, + "save_freq": 500, + "max_ckpt_keep": 1, + "enable_ema": false, + "disp_file": "lcurve.out", + "disp_freq": 100, + "disp_avg": true, + "disp_training": true, + "time_training": false, + "seed": 7 + }, + "validating": { + "full_validation": false, + "ema_full_validation": false, + "validation_freq": 100 + } +} diff --git a/examples/water/dpa4/lmp/pretrained.pt b/examples/water/dpa4/lmp/pretrained.pt new file mode 100644 index 0000000000..6b15c4f6c0 Binary files /dev/null and b/examples/water/dpa4/lmp/pretrained.pt differ diff --git a/examples/water/dpa4/lmp/water.lmp b/examples/water/dpa4/lmp/water.lmp new file mode 100644 index 0000000000..8bd1def35a --- /dev/null +++ b/examples/water/dpa4/lmp/water.lmp @@ -0,0 +1,202 @@ +# LAMMPS data +192 atoms +2 atom types +0.0 12.44470 xlo xhi +0.0 12.44470 ylo yhi +0.0 12.44470 zlo zhi +0.0 0.0 0.0 xy xz yz + +Atoms # metal + +1 1 11.83 2.56 2.18 +2 1 5.36 6 1.81 +3 1 12.17 7.34 2.71 +4 1 3.38 4.28 4.84 +5 1 4.54 11.73 11.9 +6 1 7.34 0.18 2.18 +7 1 5.73 9.9 10.31 +8 1 5.32 11.19 1.83 +9 1 9.39 11.14 1.1 +10 1 12.06 0.64 8.92 +11 1 5.01 7.72 11.74 +12 1 3.64 3.19 1.08 +13 1 11.88 10.9 4.7 +14 1 3.88 8.05 2.51 +15 1 9.82 9.16 10.18 +16 1 11 5.31 5.17 +17 1 0.97 4.12 0.84 +18 1 1.24 5.99 4.76 +19 1 10.1 9.09 3.32 +20 1 8.39 5.72 10.39 +21 1 8.06 4.72 1.3 +22 1 3.23 3.34 8.26 +23 1 5.02 6.21 7.29 +24 1 3.94 10.55 3.95 +25 1 6.25 4.56 4.21 +26 1 12.18 2.35 11.11 +27 1 6.76 8.57 4.07 +28 1 5.87 5.84 9.79 +29 1 2.18 6.91 6.99 +30 1 10.89 7.69 6.78 +31 1 8.03 4.28 6.57 +32 1 5.29 1.92 11.56 +33 1 1.53 12.16 3.45 +34 1 2.04 5.49 9.71 +35 1 6.66 9.05 1.21 +36 1 11.75 6.18 11.01 +37 1 6.11 3.07 7.96 +38 1 8.22 1.38 9.43 +39 1 1.31 2.02 7.09 +40 1 9.48 2.11 5.92 +41 1 9.33 3.34 10.97 +42 1 1.31 8.49 0.81 +43 1 11.25 3.48 7.51 +44 1 11.7 0.33 0.6 +45 1 7.63 2.04 0.28 +46 1 8.86 9.52 7.4 +47 1 2.01 11.54 6.89 +48 1 8.02 11.27 4.66 +49 1 5.82 9.18 7.71 +50 1 8.11 11.19 11.22 +51 1 5.94 0.84 5.91 +52 1 1.88 11.66 0.64 +53 1 4.66 11.28 6.73 +54 1 10.32 6.26 0.89 +55 1 2.01 11.07 10.33 +56 1 5.36 2.06 3.14 +57 1 8.56 7.59 12.17 +58 1 10.51 5.81 8.66 +59 1 2.37 6.23 12.1 +60 1 0.3 8.77 10.05 +61 1 9.47 12.13 7.57 +62 1 1.41 2.32 4.17 +63 1 9.69 3.69 3.56 +64 1 0.75 9.22 7.39 +65 2 11.09 2.87 2.74 +66 2 5.51 5.51 2.6 +67 2 0.24 6.8 3.29 +68 2 4.28 4.19 4.46 +69 2 3.63 11.56 11.76 +70 2 6.53 12.07 1.99 +71 2 5.37 9.14 10.85 +72 2 5.74 10.31 1.64 +73 2 8.71 11.07 0.41 +74 2 0.27 1.04 8.27 +75 2 5.46 7.08 11.15 +76 2 3.84 4.17 1.06 +77 2 11.55 10.19 4.15 +78 2 3 8.17 2.14 +79 2 9.53 10.04 10.56 +80 2 11.18 4.84 6.05 +81 2 0.68 3.76 12.44 +82 2 0.43 5.51 4.88 +83 2 9.51 9.48 3.94 +84 2 9.08 5.89 9.68 +85 2 7.92 3.83 0.8 +86 2 4.23 3.22 8.21 +87 2 5.27 5.9 8.19 +88 2 4.44 10.9 3.14 +89 2 6.93 4.56 4.89 +90 2 11.88 1.61 11.75 +91 2 6.79 8.54 3.09 +92 2 5.62 4.99 10.13 +93 2 3.09 6.66 6.98 +94 2 11.73 8.22 6.86 +95 2 8.4 3.45 6.18 +96 2 4.64 2.37 12.16 +97 2 0.74 11.69 3.86 +98 2 1.1 5.3 9.72 +99 2 7.54 9.11 0.83 +100 2 11.36 6.2 11.9 +101 2 5.85 2.21 7.53 +102 2 8.81 2.01 9.89 +103 2 2.17 2.43 7.46 +104 2 10.1 2.53 6.55 +105 2 9.07 4.29 10.91 +106 2 0.64 8.13 1.45 +107 2 11.12 4.2 8.14 +108 2 11.03 12.03 0.58 +109 2 6.83 1.84 12.19 +110 2 9.13 10.45 7.25 +111 2 1.76 12.44 6.8 +112 2 7.37 11.88 5.08 +113 2 5.59 8.29 7.44 +114 2 8.14 12.12 10.86 +115 2 5.48 0.06 6.32 +116 2 0.99 11.99 0.41 +117 2 4.54 10.95 5.79 +118 2 10.94 6.29 1.65 +119 2 2.52 10.55 9.76 +120 2 4.74 2.41 2.42 +121 2 8.84 8.27 11.56 +122 2 10.89 6.06 9.58 +123 2 2.15 5.99 11.18 +124 2 0.49 9.06 9.13 +125 2 8.69 12.34 8.15 +126 2 1.76 1.44 4.07 +127 2 10.05 4.44 4.1 +128 2 1.44 8.49 7.32 +129 2 12.25 3.32 1.68 +130 2 6.27 6.22 1.56 +131 2 11.5 7.85 3.14 +132 2 2.96 3.44 4.61 +133 2 5.04 11.17 11.3 +134 2 7.6 12.39 3.03 +135 2 5.94 9.68 9.37 +136 2 4.93 11.46 0.96 +137 2 8.92 11.68 1.69 +138 2 12.07 1.36 9.54 +139 2 4.17 7.28 11.8 +140 2 2.67 3.09 1.05 +141 2 11.89 10.59 5.56 +142 2 4.27 7.23 2.16 +143 2 10.8 9.22 10.18 +144 2 10.68 6.2 5.44 +145 2 1.51 4.9 0.53 +146 2 1.94 5.36 4.61 +147 2 10.02 9.64 2.47 +148 2 8.41 6.43 11.04 +149 2 8.58 4.29 2.01 +150 2 3.12 4.28 8.41 +151 2 5.12 5.54 6.6 +152 2 3.97 9.58 3.89 +153 2 6.17 3.63 3.9 +154 2 11.4 2.92 10.99 +155 2 5.96 8.02 4.26 +156 2 6.81 5.9 9.83 +157 2 1.83 6.59 6.16 +158 2 10.19 8.36 6.88 +159 2 8.74 4.67 7.12 +160 2 4.84 1.08 11.38 +161 2 2.38 11.71 3.7 +162 2 2.02 6.1 8.92 +163 2 6.05 8.48 0.61 +164 2 12.12 7.05 10.81 +165 2 6.81 3.46 7.31 +166 2 7.52 1.99 9.1 +167 2 1.25 2.26 6.18 +168 2 9.31 1.19 6.35 +169 2 8.84 2.9 11.69 +170 2 1.39 9.44 0.85 +171 2 12.13 3.22 7.5 +172 2 11.38 0.95 1.34 +173 2 7.43 1.82 1.24 +174 2 9.15 9.43 8.28 +175 2 2.97 11.66 6.88 +176 2 7.5 10.43 4.47 +177 2 6.69 9.4 7.47 +178 2 7.29 10.77 10.93 +179 2 5.45 1.07 5.12 +180 2 1.92 11.57 1.62 +181 2 5.05 10.55 7.26 +182 2 9.73 5.52 1.11 +183 2 1.73 11.85 9.82 +184 2 6.02 1.57 2.67 +185 2 9.34 7.2 0.21 +186 2 10.71 6.56 7.97 +187 2 1.86 7.03 12.37 +188 2 1.08 9.16 10.59 +189 2 10.19 12.39 8.15 +190 2 0.92 2.56 3.28 +191 2 9.34 3.01 4.17 +192 2 1.11 10.09 7.21 diff --git a/examples/water/dpa4/lora_ft.json b/examples/water/dpa4/lora_ft.json new file mode 100644 index 0000000000..dd30dcaf3f --- /dev/null +++ b/examples/water/dpa4/lora_ft.json @@ -0,0 +1,161 @@ +{ + "model": { + "type": "dpa4", + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa4", + "sel": 120, + "rcut": 6.0, + "env_exp": [ + 7, + 5 + ], + "channels": 64, + "n_radial": 16, + "radial_mlp": [ + 0 + ], + "use_env_seed": true, + "random_gamma": true, + "lmax": 3, + "mmax": 1, + "n_blocks": 3, + "so2_layers": 4, + "so2_norm": false, + "so2_attn_res": "none", + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_focus": 1, + "focus_dim": 0, + "n_atten_head": 1, + "atten_f_mix": false, + "atten_v_proj": false, + "atten_o_proj": false, + "ffn_neurons": 0, + "grid_mlp": false, + "ffn_blocks": 1, + "sandwich_norm": [ + false, + true, + true, + false + ], + "mlp_bias": false, + "layer_scale": false, + "full_attn_res": "none", + "block_attn_res": "none", + "s2_activation": [ + false, + true + ], + "lebedev_quadrature": true, + "activation_function": "silu", + "glu_activation": true, + "use_amp": true, + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "fitting_net": { + "neuron": [ + 0 + ], + "activation_function": "silu", + "precision": "float32", + "seed": 42, + "_comment": "that's all" + }, + "use_compile": false, + "bridging_method": "none", + "bridging_r_inner": 0.8, + "bridging_r_outer": 1.2, + "lora": { + "rank": 16, + "alpha": 16.0 + }, + "enable_tf32": true, + "_comment": "that's all" + }, + "learning_rate": { + "type": "cosine", + "start_lr": 0.0005, + "stop_lr": 1e-06, + "warmup_steps": 100, + "warmup_start_factor": 0.2 + }, + "loss": { + "type": "ener", + "loss_func": "mae", + "f_use_norm": true, + "start_pref_e": 20, + "limit_pref_e": 20, + "start_pref_f": 20, + "limit_pref_f": 20, + "start_pref_v": 5, + "limit_pref_v": 5 + }, + "optimizer": { + "type": "HybridMuon", + "muon_mode": "slice", + "magma_muon": true, + "lr_adjust": 0.0, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa4.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "min_pair_dist": 1.0, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "numb_btch": 1, + "_comment": "that's all" + }, + "numb_steps": 1000000, + "gradient_max_norm": 5.0, + "save_freq": 100, + "max_ckpt_keep": 3, + "enable_ema": true, + "ema_decay": 0.999, + "ema_ckpt_keep": 3, + "disp_file": "lcurve.out", + "disp_freq": 100, + "disp_avg": true, + "disp_training": true, + "time_training": true, + "tensorboard": false, + "enable_profiler": false, + "tensorboard_freq": 1000, + "tensorboard_log_dir": "tb_log", + "profiling": false, + "profiling_file": "timeline.json", + "zero_stage": 0, + "seed": 7, + "_comment": "that's all" + }, + "validating": { + "full_validation": true, + "ema_full_validation": false, + "validation_freq": 100, + "save_best": true, + "max_best_ckpt": 1, + "validation_metric": "E:MAE", + "full_val_file": "val.log", + "full_val_start": 0.0, + "_comment": "that's all" + }, + "_comment": "that's all" +} diff --git a/pyproject.toml b/pyproject.toml index 6c55e504b4..5788f53d08 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ 'h5py', "h5py>=3.6.0,!=3.11.0; platform_system=='Linux' and platform_machine=='aarch64'", 'wcmatch', + "einops", 'packaging', 'ml_dtypes', 'mendeleev', @@ -399,6 +400,7 @@ select = [ ignore = [ "ANN401", # Allow Any due to too many violations "E501", # line too long + "F722", # syntax error in type annotation for jaxtyping "F841", # local variable is assigned to but never used "RUF059", # unused-unpacked-variable "E741", # ambiguous variable name diff --git a/source/api_cc/include/DeepPotPTExpt.h b/source/api_cc/include/DeepPotPTExpt.h index 3559702f6a..bb4377919c 100644 --- a/source/api_cc/include/DeepPotPTExpt.h +++ b/source/api_cc/include/DeepPotPTExpt.h @@ -206,9 +206,11 @@ class DeepPotPTExpt : public DeepPotBackend { int ntypes; int dfparam; int daparam; + int dim_chg_spin; bool aparam_nall; bool has_default_fparam_; std::vector default_fparam_; + std::vector default_chg_spin_; double rcut; int gpu_id; bool gpu_enabled; diff --git a/source/api_cc/include/DeepSpinPTExpt.h b/source/api_cc/include/DeepSpinPTExpt.h index cc1304c69e..3294a8f6e7 100644 --- a/source/api_cc/include/DeepSpinPTExpt.h +++ b/source/api_cc/include/DeepSpinPTExpt.h @@ -179,9 +179,11 @@ class DeepSpinPTExpt : public DeepSpinBackend { int ntypes_spin; int dfparam; int daparam; + int dim_chg_spin; bool aparam_nall; bool has_default_fparam_; std::vector default_fparam_; + std::vector default_chg_spin_; std::vector use_spin_; double rcut; int gpu_id; diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index 910c2f6f7a..cfe9fb316e 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -20,6 +20,7 @@ #include "neighbor_list.h" using deepmd::ptexpt::parse_json; +using deepmd::ptexpt::read_default_chg_spin; using deepmd::ptexpt::read_zip_entry; using namespace deepmd; @@ -98,9 +99,14 @@ void DeepPotPTExpt::init(const std::string& model, auto metadata = parse_json(metadata_json); rcut = metadata["rcut"].as_double(); - ntypes = static_cast(metadata["type_map"].as_array().size()); + ntypes = metadata.obj_val.count("ntypes") + ? metadata["ntypes"].as_int() + : static_cast(metadata["type_map"].as_array().size()); dfparam = metadata["dim_fparam"].as_int(); daparam = metadata["dim_aparam"].as_int(); + dim_chg_spin = metadata.obj_val.count("dim_chg_spin") + ? metadata["dim_chg_spin"].as_int() + : 0; aparam_nall = false; // pt_expt models use nloc for aparam if (metadata.obj_val.count("has_default_fparam")) { has_default_fparam_ = metadata["has_default_fparam"].as_bool(); @@ -126,6 +132,7 @@ void DeepPotPTExpt::init(const std::string& model, << std::endl; } } + default_chg_spin_ = read_default_chg_spin(metadata, dim_chg_spin); if (metadata.obj_val.count("do_atomic_virial")) { do_atomic_virial = metadata["do_atomic_virial"].as_bool(); @@ -233,6 +240,13 @@ std::vector DeepPotPTExpt::run_model( if (daparam > 0) { inputs.push_back(aparam); } + if (dim_chg_spin > 0) { + auto charge_spin = torch::tensor(default_chg_spin_, coord.options()) + .view({1, dim_chg_spin}) + .expand({coord.size(0), dim_chg_spin}) + .contiguous(); + inputs.push_back(charge_spin); + } return loader->run(inputs); } @@ -267,6 +281,13 @@ std::vector DeepPotPTExpt::run_model_with_comm( if (daparam > 0) { inputs.push_back(aparam); } + if (dim_chg_spin > 0) { + auto charge_spin = torch::tensor(default_chg_spin_, coord.options()) + .view({1, dim_chg_spin}) + .expand({coord.size(0), dim_chg_spin}) + .contiguous(); + inputs.push_back(charge_spin); + } for (const auto& t : comm_tensors) { inputs.push_back(t); } diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index 2ac4369f5f..7ab7a60b61 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -20,6 +20,7 @@ #include "neighbor_list.h" using deepmd::ptexpt::parse_json; +using deepmd::ptexpt::read_default_chg_spin; using deepmd::ptexpt::read_zip_entry; using namespace deepmd; @@ -96,9 +97,14 @@ void DeepSpinPTExpt::init(const std::string& model, auto metadata = parse_json(metadata_json); rcut = metadata["rcut"].as_double(); - ntypes = static_cast(metadata["type_map"].as_array().size()); + ntypes = metadata.obj_val.count("ntypes") + ? metadata["ntypes"].as_int() + : static_cast(metadata["type_map"].as_array().size()); dfparam = metadata["dim_fparam"].as_int(); daparam = metadata["dim_aparam"].as_int(); + dim_chg_spin = metadata.obj_val.count("dim_chg_spin") + ? metadata["dim_chg_spin"].as_int() + : 0; aparam_nall = false; // Spin-specific metadata @@ -137,6 +143,7 @@ void DeepSpinPTExpt::init(const std::string& model, << std::endl; } } + default_chg_spin_ = read_default_chg_spin(metadata, dim_chg_spin); if (metadata.obj_val.count("do_atomic_virial")) { do_atomic_virial = metadata["do_atomic_virial"].as_bool(); @@ -238,6 +245,13 @@ std::vector DeepSpinPTExpt::run_model( if (daparam > 0) { inputs.push_back(aparam); } + if (dim_chg_spin > 0) { + auto charge_spin = torch::tensor(default_chg_spin_, coord.options()) + .view({1, dim_chg_spin}) + .expand({coord.size(0), dim_chg_spin}) + .contiguous(); + inputs.push_back(charge_spin); + } return loader->run(inputs); } @@ -271,6 +285,13 @@ std::vector DeepSpinPTExpt::run_model_with_comm( if (daparam > 0) { inputs.push_back(aparam); } + if (dim_chg_spin > 0) { + auto charge_spin = torch::tensor(default_chg_spin_, coord.options()) + .view({1, dim_chg_spin}) + .expand({coord.size(0), dim_chg_spin}) + .contiguous(); + inputs.push_back(charge_spin); + } for (const auto& t : comm_tensors) { inputs.push_back(t); } diff --git a/source/api_cc/src/commonPTExpt.h b/source/api_cc/src/commonPTExpt.h index 2d5d773b02..b2adbd63d1 100644 --- a/source/api_cc/src/commonPTExpt.h +++ b/source/api_cc/src/commonPTExpt.h @@ -257,6 +257,29 @@ inline JsonValue parse_json(const std::string& s) { return parser.parse(); } +inline std::vector read_default_chg_spin(const JsonValue& metadata, + const int dim_chg_spin) { + std::vector default_chg_spin; + if (dim_chg_spin <= 0) { + return default_chg_spin; + } + if (!metadata.obj_val.count("default_chg_spin")) { + throw deepmd::deepmd_exception( + "Model requires charge/spin conditions but default_chg_spin is " + "missing from metadata."); + } + for (const auto& v : metadata["default_chg_spin"].as_array()) { + default_chg_spin.push_back(v.as_double()); + } + if (static_cast(default_chg_spin.size()) != dim_chg_spin) { + throw deepmd::deepmd_exception("default_chg_spin length (" + + std::to_string(default_chg_spin.size()) + + ") does not match dim_chg_spin (" + + std::to_string(dim_chg_spin) + ")."); + } + return default_chg_spin; +} + // ============================================================================ // ZIP archive reader — reads a file from a ZIP archive. // ============================================================================ diff --git a/source/api_cc/tests/test_deeppot_ptexpt.cc b/source/api_cc/tests/test_deeppot_ptexpt.cc index 201724a725..4f90555839 100644 --- a/source/api_cc/tests/test_deeppot_ptexpt.cc +++ b/source/api_cc/tests/test_deeppot_ptexpt.cc @@ -11,6 +11,9 @@ #include "DeepPot.h" #include "DeepPotPTExpt.h" +#if defined(BUILD_PYTORCH) +#include "../src/commonPTExpt.h" +#endif #include "neighbor_list.h" #include "test_utils.h" @@ -94,6 +97,35 @@ deepmd::DeepPot TestInferDeepPotAPtExpt::dp; TYPED_TEST_SUITE(TestInferDeepPotAPtExpt, ValueTypes); +#if defined(BUILD_PYTORCH) +TEST(TestPtExptMetadata, default_chg_spin_is_optional_when_dim_is_zero) { + auto metadata = deepmd::ptexpt::parse_json("{}"); + auto value = deepmd::ptexpt::read_default_chg_spin(metadata, 0); + EXPECT_TRUE(value.empty()); +} + +TEST(TestPtExptMetadata, default_chg_spin_is_read_when_required) { + auto metadata = + deepmd::ptexpt::parse_json(R"({"default_chg_spin": [0.0, 1.0]})"); + auto value = deepmd::ptexpt::read_default_chg_spin(metadata, 2); + ASSERT_EQ(value.size(), 2); + EXPECT_DOUBLE_EQ(value[0], 0.0); + EXPECT_DOUBLE_EQ(value[1], 1.0); +} + +TEST(TestPtExptMetadata, default_chg_spin_missing_throws) { + auto metadata = deepmd::ptexpt::parse_json("{}"); + EXPECT_THROW(deepmd::ptexpt::read_default_chg_spin(metadata, 2), + deepmd::deepmd_exception); +} + +TEST(TestPtExptMetadata, default_chg_spin_length_mismatch_throws) { + auto metadata = deepmd::ptexpt::parse_json(R"({"default_chg_spin": [0.0]})"); + EXPECT_THROW(deepmd::ptexpt::read_default_chg_spin(metadata, 2), + deepmd::deepmd_exception); +} +#endif + TYPED_TEST(TestInferDeepPotAPtExpt, cpu_build_nlist) { using VALUETYPE = TypeParam; std::vector& coord = this->coord; diff --git a/source/lmp/pair_deepmd.cpp b/source/lmp/pair_deepmd.cpp index 9ea2021570..70c1bc4cae 100644 --- a/source/lmp/pair_deepmd.cpp +++ b/source/lmp/pair_deepmd.cpp @@ -190,9 +190,25 @@ void PairDeepMD::compute(int eflag, int vflag) { // mapping (for DPA-2 JAX) std::vector mapping_vec(nall, -1); - if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { - for (size_t ii = 0; ii < nall; ++ii) { - mapping_vec[ii] = atom->map(atom->tag[ii]); + bool has_mapping = false; + if (comm->nprocs == 1) { + if (atom->map_style != Atom::MAP_NONE) { + for (size_t ii = 0; ii < nall; ++ii) { + mapping_vec[ii] = atom->map(atom->tag[ii]); + } + has_mapping = true; + } else if (atom->tag_enable) { + std::map local_index_by_tag; + for (int ii = 0; ii < nlocal; ++ii) { + local_index_by_tag[atom->tag[ii]] = ii; + } + for (int ii = 0; ii < nall; ++ii) { + auto it = local_index_by_tag.find(atom->tag[ii]); + if (it != local_index_by_tag.end()) { + mapping_vec[ii] = it->second; + } + } + has_mapping = true; } } @@ -239,7 +255,7 @@ void PairDeepMD::compute(int eflag, int vflag) { commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc, commdata_->recvproc, &world); lmp_list.set_mask(NEIGHMASK); - if (comm->nprocs == 1 && atom->map_style != Atom::MAP_NONE) { + if (has_mapping) { lmp_list.set_mapping(mapping_vec.data()); } deepmd_compat::InputNlist extend_lmp_list; diff --git a/source/lmp/pair_deepspin.cpp b/source/lmp/pair_deepspin.cpp index eddcb2eef4..5c3bd6b4ad 100644 --- a/source/lmp/pair_deepspin.cpp +++ b/source/lmp/pair_deepspin.cpp @@ -201,6 +201,29 @@ void PairDeepSpin::compute(int eflag, int vflag) { } } + std::vector mapping_vec(nall, -1); + bool has_mapping = false; + if (comm->nprocs == 1) { + if (atom->map_style != Atom::MAP_NONE) { + for (size_t ii = 0; ii < nall; ++ii) { + mapping_vec[ii] = atom->map(atom->tag[ii]); + } + has_mapping = true; + } else if (atom->tag_enable) { + std::map local_index_by_tag; + for (int ii = 0; ii < nlocal; ++ii) { + local_index_by_tag[atom->tag[ii]] = ii; + } + for (int ii = 0; ii < nall; ++ii) { + auto it = local_index_by_tag.find(atom->tag[ii]); + if (it != local_index_by_tag.end()) { + mapping_vec[ii] = it->second; + } + } + has_mapping = true; + } + } + if (do_compute_aparam) { make_aparam_from_compute(daparam); } else if (aparam.size() > 0) { @@ -244,6 +267,9 @@ void PairDeepSpin::compute(int eflag, int vflag) { commdata_->firstrecv, commdata_->sendlist, commdata_->sendproc, commdata_->recvproc, &world); lmp_list.set_mask(NEIGHMASK); + if (has_mapping) { + lmp_list.set_mapping(mapping_vec.data()); + } if (single_model || multi_models_no_mod_devi) { // cvflag_atom is the right flag for the cvatom matrix if (!(eflag_atom || cvflag_atom)) { diff --git a/source/tests/common/dpmodel/test_dist_check.py b/source/tests/common/dpmodel/test_dist_check.py new file mode 100644 index 0000000000..f64034e59d --- /dev/null +++ b/source/tests/common/dpmodel/test_dist_check.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for min_pair_dist frame filtering.""" + +import unittest + +import numpy as np + +from deepmd.dpmodel.utils.dist_check import ( + compute_min_pair_dist_single, +) + + +class TestComputeMinPairDistSingle(unittest.TestCase): + """Test minimum pairwise distance computation.""" + + def test_three_atoms_no_pbc(self) -> None: + """Three atoms, closest pair is 0.3 Å.""" + coord = np.array( + [ + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.3, + 0.0, + 0.0, + ] + ) + atype = np.array([0, 0, 1]) + dist = compute_min_pair_dist_single(coord, box=None, atype=atype) + np.testing.assert_almost_equal(dist, 0.3) + + def test_pbc_minimum_image(self) -> None: + """Two atoms near opposite edges of a 10 Å cubic box. + + Real-space distance is 9.0 Å, but minimum image distance is 1.0 Å. + """ + coord = np.array([0.5, 5.0, 5.0, 9.5, 5.0, 5.0]) + box = np.array([10.0, 0, 0, 0, 10.0, 0, 0, 0, 10.0]) + atype = np.array([0, 0]) + dist = compute_min_pair_dist_single(coord, box=box, atype=atype) + np.testing.assert_almost_equal(dist, 1.0) + + def test_pbc_triclinic(self) -> None: + """Triclinic box with atoms near boundary.""" + # Triclinic box: a=(10,0,0), b=(2,10,0), c=(0,0,10) + box = np.array([10.0, 0, 0, 2.0, 10.0, 0, 0, 0, 10.0]) + coord = np.array([0.2, 0.0, 0.0, 9.8, 0.0, 0.0]) + atype = np.array([0, 0]) + dist = compute_min_pair_dist_single(coord, box=box, atype=atype) + np.testing.assert_almost_equal(dist, 0.4, decimal=5) + + def test_virtual_atoms_excluded(self) -> None: + """Virtual atoms (type < 0) should be excluded.""" + coord = np.array( + [ + 0.0, + 0.0, + 0.0, + 0.1, + 0.0, + 0.0, + 2.0, + 0.0, + 0.0, + ] + ) + atype = np.array([0, -1, 1]) + dist = compute_min_pair_dist_single(coord, box=None, atype=atype) + np.testing.assert_almost_equal(dist, 2.0) + + def test_single_real_atom(self) -> None: + """Only one real atom returns inf.""" + coord = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0]) + atype = np.array([0, -1]) + dist = compute_min_pair_dist_single(coord, box=None, atype=atype) + self.assertEqual(dist, float("inf")) + + def test_all_virtual(self) -> None: + """All virtual atoms return inf.""" + coord = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0]) + atype = np.array([-1, -1]) + dist = compute_min_pair_dist_single(coord, box=None, atype=atype) + self.assertEqual(dist, float("inf")) + + def test_coord_shape_2d(self) -> None: + """Accept (natoms, 3) shaped coord.""" + coord = np.array([[0.0, 0.0, 0.0], [0.8, 0.0, 0.0]]) + atype = np.array([0, 1]) + dist = compute_min_pair_dist_single(coord, box=None, atype=atype) + np.testing.assert_almost_equal(dist, 0.8) + + def test_stop_below_triggers_early_exit(self) -> None: + """A pair below stop_below should still return the correct minimum.""" + coord = np.array([0.0, 0.0, 0.0, 0.05, 0.0, 0.0, 10.0, 0.0, 0.0]) + atype = np.array([0, 0, 0]) + dist = compute_min_pair_dist_single( + coord, box=None, atype=atype, stop_below=0.1 + ) + np.testing.assert_almost_equal(dist, 0.05) + + def test_stop_below_not_triggered(self) -> None: + """If all pairs are above stop_below, the true minimum is returned.""" + coord = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0]) + atype = np.array([0, 0, 0]) + dist = compute_min_pair_dist_single( + coord, box=None, atype=atype, stop_below=0.5 + ) + np.testing.assert_almost_equal(dist, 1.0) + + def test_multi_block_iteration(self) -> None: + """>512 atoms exercises multiple row blocks.""" + rng = np.random.default_rng(42) + nloc = 600 + coord = rng.uniform(0.0, 100.0, (nloc, 3)) + atype = np.zeros(nloc, dtype=np.int64) + diff = coord[:, np.newaxis, :] - coord[np.newaxis, :, :] + dist = np.sqrt(np.sum(diff * diff, axis=-1)) + np.fill_diagonal(dist, np.inf) + ref = dist.min() + + actual = compute_min_pair_dist_single(coord, box=None, atype=atype) + np.testing.assert_almost_equal(actual, ref, decimal=10) + + def test_coincident_atoms_zero(self) -> None: + """Coincident real atoms should return exactly zero.""" + coord = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]) + atype = np.array([0, 0, 0]) + dist = compute_min_pair_dist_single(coord, box=None, atype=atype) + self.assertEqual(dist, 0.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/common/dpmodel/test_lmdb_data.py b/source/tests/common/dpmodel/test_lmdb_data.py index 656c7cff0b..0838cca0ab 100644 --- a/source/tests/common/dpmodel/test_lmdb_data.py +++ b/source/tests/common/dpmodel/test_lmdb_data.py @@ -21,6 +21,9 @@ is_lmdb, make_neighbor_stat_data, ) +from deepmd.utils.data import ( + DataRequirementItem, +) # ============================================================ # LMDB creation helpers @@ -324,6 +327,26 @@ def test_lmdb_test_data(self): self.assertEqual(result["find_energy"], 1.0) self.assertEqual(result["find_force"], 1.0) + def test_min_pair_dist_requirement_computed(self): + path = _create_grid_lmdb(f"{self._tmpdir.name}/grid_min_pair.lmdb", nframes=1) + reader = LmdbDataReader(path, ["TYPE"], batch_size=1) + reader.add_data_requirement( + [ + DataRequirementItem( + "min_pair_dist", + ndof=1, + atomic=False, + must=False, + high_prec=False, + ) + ] + ) + + frame = reader[0] + + self.assertEqual(frame["find_min_pair_dist"], np.float32(1.0)) + np.testing.assert_allclose(frame["min_pair_dist"], np.array([1.0])) + # ============================================================ # Mixed nloc tests diff --git a/source/tests/common/dpmodel/test_nlist.py b/source/tests/common/dpmodel/test_nlist.py index baed11a961..7f1a28e080 100644 --- a/source/tests/common/dpmodel/test_nlist.py +++ b/source/tests/common/dpmodel/test_nlist.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest +from importlib.util import ( + find_spec, +) import numpy as np @@ -300,3 +303,21 @@ def test_extend_coord(self) -> None: rtol=self.prec, atol=self.prec, ) + + @unittest.skipIf(find_spec("jax") is None, "JAX is not installed") + def test_extend_coord_jax_matches_numpy(self) -> None: + import jax.numpy as jnp + + ecoord_np, eatype_np, mapping_np = extend_coord_with_ghosts( + self.coord, self.atype, self.cell, self.rcut + ) + ecoord_jax, eatype_jax, mapping_jax = extend_coord_with_ghosts( + jnp.asarray(self.coord), + jnp.asarray(self.atype), + jnp.asarray(self.cell), + self.rcut, + ) + + np.testing.assert_allclose(np.asarray(ecoord_jax), ecoord_np, atol=1e-6) + np.testing.assert_array_equal(np.asarray(eatype_jax), eatype_np) + np.testing.assert_array_equal(np.asarray(mapping_jax), mapping_np) diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index 5f7cd323ee..1647494403 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -60,6 +60,10 @@ p_examples / "water" / "dpa2" / "input_torch_compressible.json", p_examples / "water" / "dpa3" / "input_torch.json", p_examples / "water" / "dpa3" / "input_torch_dynamic.json", + p_examples / "water" / "dpa4" / "input.json", + p_examples / "water" / "dpa4" / "input-zbl.json", + p_examples / "water" / "dpa4" / "input-spin.json", + p_examples / "water" / "dpa4" / "lmp" / "input.json", p_examples / "property" / "train" / "input_torch.json", p_examples / "water" / "se_e3_tebd" / "input_torch.json", p_examples / "hessian" / "single_task" / "input.json", @@ -72,6 +76,9 @@ p_examples / "water_multi_task" / "pytorch_example" / "input_torch_sharefit.json", p_examples / "water_multi_task" / "pytorch_example" / "input_torch_with_alias.json", p_examples / "hessian" / "multi_task" / "input.json", + p_examples / "water" / "dpa4" / "input_multitask.json", + p_examples / "water" / "dpa4" / "input_multitask_sharefit.json", + p_examples / "water" / "dpa4" / "input_multitask_sharefit-zbl.json", p_examples / "water_multi_task" / "pytorch_example" diff --git a/source/tests/pt/model/test_descriptor_sezm.py b/source/tests/pt/model/test_descriptor_sezm.py new file mode 100644 index 0000000000..fc77011094 --- /dev/null +++ b/source/tests/pt/model/test_descriptor_sezm.py @@ -0,0 +1,1740 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import math +import unittest + +import torch + +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt.model.descriptor.sezm import ( + DescrptSeZM, +) +from deepmd.pt.model.descriptor.sezm_nn import ( + DynamicRadialDegreeMixer, + ForceEmbedding, + InnerClamp, + SeZMDirectForceHead, + SO2Linear, + WignerDCalculator, + build_edge_quaternion, + build_m_major_l_index, + quaternion_multiply, + quaternion_to_rotation_matrix, +) +from deepmd.pt.model.model import ( + get_sezm_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + PRECISION_DICT, +) + + +def _random_quaternion( + n_batch: int, + *, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """Sample normalized quaternions in ``(w, x, y, z)`` order.""" + sample_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype + q = torch.randn(n_batch, 4, device=device, dtype=sample_dtype) + q = q / torch.sqrt( + torch.sum(q * q, dim=-1, keepdim=True).clamp_min(torch.finfo(sample_dtype).eps) + ) + return q.to(dtype=dtype) + + +def _tiny_two_atom_system( + device: torch.device, + dtype: torch.dtype, + *, + padded: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Create a minimal two-atom system for descriptor tests.""" + coord = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + dtype=dtype, + device=device, + ).view(1, -1, 3) + atype = torch.tensor([[0, 1]], dtype=torch.int32, device=device) + nlist_values = [[[1, -1], [0, -1]]] if padded else [[[1, 1], [0, 0]]] + nlist = torch.tensor(nlist_values, dtype=torch.int64, device=device) + return coord, atype, nlist + + +def _descriptor_kwargs(**overrides) -> dict: + """Build a compact SeZM descriptor config for tests.""" + kwargs = { + "rcut": 3.0, + "sel": [1, 1], + "ntypes": 2, + "l_schedule": [1, 0], + "channels": 4, + "n_radial": 3, + "radial_mlp": [6], + "ffn_neurons": 8, + "ffn_blocks": 1, + "random_gamma": False, + "n_atten_head": 0, + "mlp_bias": True, + "precision": "float32", + "trainable": True, + } + kwargs.update(overrides) + return kwargs + + +def _attention_descriptor_kwargs( + *, + precision: str = "float32", + seed: int | None = None, + **overrides, +) -> dict: + """Build a richer attention-enabled SeZM descriptor config for tests.""" + kwargs = _descriptor_kwargs( + l_schedule=[1, 1, 0], + channels=8, + n_focus=2, + focus_dim=0, + n_radial=4, + radial_mlp=[8], + so2_layers=2, + full_attn_res="dependent", + so2_attn_res="dependent", + ffn_neurons=16, + ffn_blocks=2, + layer_scale=False, + precision=precision, + seed=seed, + ) + kwargs.update(overrides) + return kwargs + + +def _forward_tols(dtype: torch.dtype) -> tuple[float, float]: + """Return output comparison tolerances for one dtype.""" + if dtype == torch.float64: + return 1e-10, 1e-10 + if dtype == torch.float32: + return 5e-5, 5e-5 + return 5e-3, 5e-3 + + +def _parameter_tols(dtype: torch.dtype) -> tuple[float, float]: + """Return parameter comparison tolerances for one dtype.""" + if dtype == torch.float64: + return 1e-10, 1e-10 + if dtype == torch.float32: + return 1e-6, 1e-6 + return 1e-3, 1e-3 + + +class _SeZMTestCase(unittest.TestCase): + """Base test case with the shared device setup.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + +class TestDescrptSeZM(_SeZMTestCase): + """Test the SeZM descriptor.""" + + def _assert_forward_backward_smoke(self, **model_kwargs) -> DescrptSeZM: + """Run a compact forward/backward smoke test and return the model.""" + coord, atype, nlist = _tiny_two_atom_system(self.device, dtype=torch.float32) + extended_coord = coord.reshape(1, -1).detach().requires_grad_(True) + model = DescrptSeZM(**model_kwargs) + desc, *_ = model(extended_coord, atype, nlist, mapping=None, comm_dict=None) + self.assertEqual(desc.shape, (1, 2, model_kwargs["channels"])) + self.assertEqual(desc.dtype, env.GLOBAL_PT_FLOAT_PRECISION) + desc.sum().backward() + self.assertIsNotNone(extended_coord.grad) + self.assertTrue(torch.all(torch.isfinite(extended_coord.grad))) + return model + + def test_dpa4_alias_constructs_descriptor(self) -> None: + """DPA4 should be the primary user-facing alias for the SeZM descriptor.""" + model = BaseDescriptor(type="dpa4", **_descriptor_kwargs()) + + self.assertIsInstance(model, DescrptSeZM) + + def test_dpa4_alias_deserializes_descriptor(self) -> None: + """Serialized descriptor payloads should accept the DPA4 type string.""" + data = DescrptSeZM(**_descriptor_kwargs(seed=123)).serialize() + data["type"] = "dpa4" + + restored = BaseDescriptor.deserialize(data) + + self.assertIsInstance(restored, DescrptSeZM) + + def test_forward_with_descriptor_variants(self) -> None: + """Test forward/backward smoke paths for compact descriptor variants.""" + cases = { + "focus_dim_zero": _descriptor_kwargs( + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + ), + "focus_dim_explicit": _descriptor_kwargs( + channels=4, + n_focus=2, + focus_dim=3, + so2_layers=2, + ), + "focus_dim_zero_s2": _descriptor_kwargs( + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + s2_activation=[False, True], + ), + "focus_dim_zero_s2_lebedev": _descriptor_kwargs( + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + s2_activation=[False, True], + lebedev_quadrature=[False, True], + ), + "gaussian_basis": _descriptor_kwargs( + channels=4, + basis_type="gaussian", + ), + "radial_so2_degree": _descriptor_kwargs( + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + radial_so2_mode="degree", + ), + "radial_so2_degree_channel_rank2": _descriptor_kwargs( + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + radial_so2_mode="degree_channel", + radial_so2_rank=2, + ), + } + for name, model_kwargs in cases.items(): + with self.subTest(mode=name): + self._assert_forward_backward_smoke(**model_kwargs) + + def test_forward_with_attention_variants(self) -> None: + """Test forward/backward smoke paths for attention-based variants.""" + cases = { + "full_attention": _attention_descriptor_kwargs( + precision="float32", + seed=123, + ), + "block_attention": _attention_descriptor_kwargs( + precision="float32", + seed=123, + full_attn_res="none", + block_attn_res="dependent", + ), + "full_attention_s2": _attention_descriptor_kwargs( + precision="float32", + seed=123, + s2_activation=[False, True], + ), + "mixed_so2_attention": _attention_descriptor_kwargs( + precision="float32", + seed=123, + n_atten_head=4, + atten_f_mix=True, + ), + "standard_single_head_attention": _attention_descriptor_kwargs( + precision="float32", + seed=123, + n_atten_head=1, + atten_v_proj=True, + atten_o_proj=True, + ), + } + for name, model_kwargs in cases.items(): + with self.subTest(mode=name): + self._assert_forward_backward_smoke(**model_kwargs) + + def test_forward_backward_second_order_fixed_edges(self) -> None: + """Test fixed-shape edge path matches nlist for fwd/bwd/2nd order.""" + dtype = torch.float32 + coord = torch.tensor( + [[0.1, 0.2, 0.3], [1.1, 0.7, 0.2]], + dtype=dtype, + device=self.device, + ).view(1, -1, 3) + atype = torch.tensor([[0, 1]], dtype=torch.int32, device=self.device) + nlist = torch.tensor([[[1, 1], [0, 0]]], dtype=torch.int64, device=self.device) + extended_coord = coord.reshape(1, -1).detach().requires_grad_(True) + + model = DescrptSeZM( + **_attention_descriptor_kwargs( + precision="float32", + channels=4, + n_focus=1, + n_radial=3, + radial_mlp=[6], + ffn_neurons=8, + ) + ) + + desc_nlist, _, _, _, sw_nlist = model( + extended_coord, atype, nlist, mapping=None, comm_dict=None + ) + + # Fixed-shape edge list for n_node=2, nsel=2 + edge_index = torch.tensor( + [[1, 0, 0, 0], [0, 0, 1, 1]], + dtype=torch.long, + device=self.device, + ) + coord_view = extended_coord.view(1, 2, 3) + valid_nlist = nlist >= 0 + gather_index = torch.where(valid_nlist, nlist, torch.zeros_like(nlist)) + index = gather_index.view(1, 4, 1).expand(-1, -1, 3) + nei_pos = torch.gather(coord_view, 1, index).view(1, 2, 2, 3) + atom_pos = coord_view[:, :2].unsqueeze(2) + diff = nei_pos - atom_pos + edge_vec = diff.reshape(4, 3) + edge_mask = torch.tensor([1, 1, 1, 1], dtype=torch.bool, device=self.device) + + desc_edge, _, _, _, sw_edge = model( + extended_coord, + atype, + nlist, + mapping=None, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + comm_dict=None, + ) + + torch.testing.assert_close(desc_nlist, desc_edge, atol=1e-6, rtol=1e-6) + + loss_nlist = desc_nlist.sum() + loss_edge = desc_edge.sum() + + (grad_nlist,) = torch.autograd.grad( + loss_nlist, extended_coord, create_graph=True + ) + (grad_edge,) = torch.autograd.grad(loss_edge, extended_coord, create_graph=True) + torch.testing.assert_close(grad_nlist, grad_edge, atol=1e-6, rtol=1e-6) + + (grad2_nlist,) = torch.autograd.grad( + grad_nlist.sum(), extended_coord, create_graph=False + ) + (grad2_edge,) = torch.autograd.grad( + grad_edge.sum(), extended_coord, create_graph=False + ) + torch.testing.assert_close(grad2_nlist, grad2_edge, atol=1e-6, rtol=1e-6) + + def test_force_embedding_and_vector_head_roundtrip_l1_basis(self) -> None: + """Force encoding basis should match the vector-head Cartesian decode.""" + force_input = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [0.0, 0.0, -3.0], + ], + dtype=torch.float64, + device=self.device, + ) + encoder = ForceEmbedding( + lmax=1, + channels=1, + precision="float64", + mlp_bias=False, + trainable=True, + seed=123, + ).to(self.device) + decoder = SeZMDirectForceHead( + lmax=1, + channels=1, + precision="float64", + mlp_bias=False, + trainable=True, + seed=456, + ).to(self.device) + + with torch.no_grad(): + encoder.proj.weight.zero_() + encoder.proj.weight[1, 0, 0] = 1.0 + if encoder.proj.bias is not None: + encoder.proj.bias.zero_() + decoder.proj.weight.zero_() + decoder.proj.weight[1, 0, 0] = math.sqrt(3.0) + if decoder.proj.bias is not None: + decoder.proj.bias.zero_() + + latent = encoder(force_input) + decoded = decoder(latent) + torch.testing.assert_close(decoded, force_input, atol=1e-10, rtol=1e-10) + + def test_backward_gradient(self) -> None: + """Test backward gradient through coordinates.""" + for prec in ["float64", "float32", "bfloat16"]: + dtype = PRECISION_DICT[prec] + coord, atype, nlist = _tiny_two_atom_system(self.device, dtype=dtype) + extended_coord = coord.reshape(1, -1).detach().requires_grad_(True) + model = DescrptSeZM( + **_descriptor_kwargs( + ffn_blocks=2, + layer_scale=False, + precision=prec, + ) + ) + desc, *_ = model(extended_coord, atype, nlist, mapping=None, comm_dict=None) + loss = desc.sum() + loss.backward() + self.assertIsNotNone(extended_coord.grad) + self.assertTrue(torch.all(torch.isfinite(extended_coord.grad))) + + def test_serialization_deserialization(self) -> None: + """Test serialization and deserialization preserves model state.""" + cases = { + "focus_dim_zero": _attention_descriptor_kwargs( + precision="float32", + focus_dim=0, + ), + "focus_dim_explicit": _descriptor_kwargs( + precision="float32", + channels=4, + n_focus=2, + focus_dim=3, + so2_layers=2, + n_radial=3, + radial_mlp=[6], + ffn_neurons=8, + ), + "focus_dim_zero_s2": _descriptor_kwargs( + precision="float32", + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + n_radial=3, + radial_mlp=[6], + ffn_neurons=8, + s2_activation=[False, True], + ), + "focus_dim_zero_s2_lebedev": _descriptor_kwargs( + precision="float32", + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + n_radial=3, + radial_mlp=[6], + ffn_neurons=8, + s2_activation=[False, True], + lebedev_quadrature=[False, True], + ), + "radial_so2_degree": _descriptor_kwargs( + precision="float32", + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + n_radial=3, + radial_mlp=[6], + ffn_neurons=8, + radial_so2_mode="degree", + ), + "radial_so2_degree_channel_rank2": _descriptor_kwargs( + precision="float32", + channels=4, + n_focus=2, + focus_dim=0, + so2_layers=2, + n_radial=3, + radial_mlp=[6], + ffn_neurons=8, + radial_so2_mode="degree_channel", + radial_so2_rank=2, + ), + } + dtype = PRECISION_DICT["float32"] + for case_name, model_kwargs in cases.items(): + coord, atype, nlist = _tiny_two_atom_system(self.device, dtype=dtype) + extended_coord = coord.reshape(1, -1) + with self.subTest(mode=case_name): + model = DescrptSeZM(**model_kwargs) + + desc1, _, _, _, sw1 = model(extended_coord, atype, nlist) + data = model.serialize() + model_restored = DescrptSeZM.deserialize(data) + desc2, _, _, _, sw2 = model_restored(extended_coord, atype, nlist) + atol, rtol = _forward_tols(dtype) + + torch.testing.assert_close( + desc1, + desc2, + atol=atol, + rtol=rtol, + msg="Descriptor mismatch after deserialization", + ) + torch.testing.assert_close( + sw1, + sw2, + atol=atol, + rtol=rtol, + msg="Smooth weight mismatch after deserialization", + ) + + def test_charge_spin_sparse_edge_conditioning(self) -> None: + """Charge/spin conditions should affect the sparse-edge descriptor path.""" + coord, atype, nlist = _tiny_two_atom_system(self.device, dtype=torch.float32) + extended_coord = coord.reshape(1, -1) + edge_index = torch.tensor( + [[1, 0], [0, 1]], dtype=torch.long, device=self.device + ) + edge_vec = torch.tensor( + [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]], + dtype=torch.float32, + device=self.device, + ) + edge_mask = torch.ones(2, dtype=torch.bool, device=self.device) + model = DescrptSeZM( + **_descriptor_kwargs( + add_chg_spin_ebd=True, + default_chg_spin=[0.0, 1.0], + seed=123, + ) + ) + + desc_default, *_ = model(extended_coord, atype, nlist) + desc_explicit, *_ = model( + extended_coord, + atype, + nlist, + charge_spin=torch.tensor([[0.0, 1.0]], device=self.device), + ) + desc_ref, _ = model.forward_with_edges( + extended_coord=coord, + extended_atype=atype, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + charge_spin=torch.tensor([[0.0, 1.0]], device=self.device), + ) + desc_shifted, _ = model.forward_with_edges( + extended_coord=coord, + extended_atype=atype, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + charge_spin=torch.tensor([[1.0, 1.0]], device=self.device), + ) + restored = DescrptSeZM.deserialize(model.serialize()) + desc_restored, _ = restored.forward_with_edges( + extended_coord=coord, + extended_atype=atype, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + charge_spin=torch.tensor([[0.0, 1.0]], device=self.device), + ) + + self.assertTrue(restored.add_chg_spin_ebd) + self.assertEqual(restored.get_default_chg_spin(), [0.0, 1.0]) + torch.testing.assert_close(desc_default, desc_explicit, atol=1e-6, rtol=1e-6) + self.assertFalse(torch.allclose(desc_ref, desc_shifted)) + torch.testing.assert_close(desc_ref, desc_restored, atol=1e-6, rtol=1e-6) + + def test_plain_descriptor_deserializes_without_condition_config(self) -> None: + """Plain descriptors should not depend on charge/spin condition fields.""" + coord, atype, nlist = _tiny_two_atom_system(self.device, dtype=torch.float32) + extended_coord = coord.reshape(1, -1) + model = DescrptSeZM(**_descriptor_kwargs(seed=123)) + self.assertTrue( + all("charge_spin_embedding" not in key for key in model.state_dict()) + ) + data = model.serialize() + data["config"].pop("add_chg_spin_ebd", None) + data["config"].pop("default_chg_spin", None) + + restored = DescrptSeZM.deserialize(data) + desc_ref, *_ = model(extended_coord, atype, nlist) + desc_new, *_ = restored(extended_coord, atype, nlist) + + self.assertFalse(restored.add_chg_spin_ebd) + self.assertTrue( + all("charge_spin_embedding" not in key for key in restored.state_dict()) + ) + torch.testing.assert_close(desc_ref, desc_new, atol=1e-6, rtol=1e-6) + + def test_seed_reproducibility(self) -> None: + """Test that fixed seed produces identical model initialization.""" + for prec in ["float64", "float32", "bfloat16"]: + dtype = PRECISION_DICT[prec] + seed = 12345 + + model_kwargs = _attention_descriptor_kwargs(precision=prec, seed=seed) + model1 = DescrptSeZM(**model_kwargs) + model2 = DescrptSeZM(**model_kwargs) + param_atol, param_rtol = _parameter_tols(dtype) + + for (n1, p1), (n2, p2) in zip( + model1.named_parameters(), model2.named_parameters(), strict=False + ): + self.assertEqual(n1, n2, msg="Parameter name mismatch") + torch.testing.assert_close( + p1, + p2, + atol=param_atol, + rtol=param_rtol, + msg=f"Parameter {n1} differs between models with same seed", + ) + + coord, atype, nlist = _tiny_two_atom_system(self.device, dtype=dtype) + extended_coord = coord.reshape(1, -1) + + desc1, _, _, _, sw1 = model1(extended_coord, atype, nlist) + desc2, _, _, _, sw2 = model2(extended_coord, atype, nlist) + forward_atol, forward_rtol = _forward_tols(dtype) + + torch.testing.assert_close( + desc1, + desc2, + atol=forward_atol, + rtol=forward_rtol, + msg="Forward output differs for models with same seed", + ) + torch.testing.assert_close( + sw1, + sw2, + atol=forward_atol, + rtol=forward_rtol, + msg="Smooth weight differs for models with same seed", + ) + + +class TestBuildEdgeQuaternion(_SeZMTestCase): + """Test the stable edge-quaternion chart used by SeZM.""" + + def setUp(self) -> None: + super().setUp() + torch.manual_seed(0) + + def _get_tols(self, dtype: torch.dtype) -> tuple[float, float]: + if dtype == torch.float64: + return 1e-10, 1e-10 + if dtype == torch.float32: + return 1e-4, 1e-4 + return 5e-3, 5e-3 + + def _safe_norm(self, x: torch.Tensor) -> torch.Tensor: + eps = torch.finfo(x.dtype).eps + return torch.sqrt(torch.sum(x * x, dim=-1, keepdim=True).clamp(min=eps)) + + def _assert_quaternion_invariants( + self, edge_quat: torch.Tensor, edge_vec: torch.Tensor + ) -> None: + atol, rtol = self._get_tols(edge_vec.dtype) + rot_mat = quaternion_to_rotation_matrix(edge_quat) + n_edge = int(edge_vec.shape[0]) + eye = torch.eye(3, device=self.device, dtype=edge_vec.dtype).expand( + n_edge, 3, 3 + ) + torch.testing.assert_close( + rot_mat @ rot_mat.transpose(-1, -2), + eye, + atol=atol, + rtol=rtol, + ) + + edge_unit = edge_vec / self._safe_norm(edge_vec) + ez = torch.tensor( + [0.0, 0.0, 1.0], device=self.device, dtype=edge_vec.dtype + ).expand(n_edge, 3) + rotated = (rot_mat @ edge_unit.unsqueeze(-1)).squeeze(-1) + torch.testing.assert_close(rotated, ez, atol=atol, rtol=rtol) + + det_mat = rot_mat.float() if rot_mat.dtype == torch.bfloat16 else rot_mat + det = torch.linalg.det(det_mat) + torch.testing.assert_close( + det, + torch.ones_like(det), + atol=atol, + rtol=rtol, + ) + + def test_invariants_random_edges(self) -> None: + for dtype in [torch.float64, torch.float32]: + edge_vec = torch.randn(512, 3, device=self.device, dtype=dtype) + edge_quat = build_edge_quaternion(edge_vec) + self._assert_quaternion_invariants(edge_quat, edge_vec) + + def test_invariants_near_poles(self) -> None: + for dtype in [torch.float64, torch.float32]: + delta = torch.tensor( + [-1.0e-3, -1.0e-4, 0.0, 1.0e-4, 1.0e-3], + device=self.device, + dtype=dtype, + ) + for sign in [1.0, -1.0]: + edge_vec = torch.stack( + [delta, torch.zeros_like(delta), torch.full_like(delta, sign)], + dim=-1, + ) + edge_quat = build_edge_quaternion(edge_vec) + self._assert_quaternion_invariants(edge_quat, edge_vec) + + +class TestWignerDCalculator(_SeZMTestCase): + """Test the quaternion-driven Wigner-D calculator.""" + + def setUp(self) -> None: + super().setUp() + self.batch = 8 + torch.manual_seed(0) + + def _get_tols(self, dtype: torch.dtype) -> tuple[float, float]: + if dtype == torch.float64: + return 1e-10, 1e-10 + if dtype == torch.float32: + return 5e-5, 5e-5 + return 5e-3, 5e-3 + + def _extract_l_block( + self, + D_full: torch.Tensor, + l: int, + ) -> torch.Tensor: + """Extract the l-block from D_full.""" + s, e = l * l, (l + 1) * (l + 1) + return D_full[:, s:e, s:e] + + def test_orthogonality(self) -> None: + """Test D @ D^T = I for random quaternions.""" + for dtype, lmax in itertools.product([torch.float64, torch.float32], [1, 3, 6]): + atol, rtol = self._get_tols(dtype) + wigner = WignerDCalculator(lmax=lmax, dtype=dtype) + edge_quat = _random_quaternion(self.batch, device=self.device, dtype=dtype) + D_full, Dt_full = wigner(edge_quat) + + for l in range(lmax + 1): + dim = 2 * l + 1 + eye = torch.eye(dim, device=self.device, dtype=dtype).expand( + self.batch, dim, dim + ) + D_l = self._extract_l_block(D_full, l) + Dt_l = self._extract_l_block(Dt_full, l) + torch.testing.assert_close( + D_l @ Dt_l, + eye, + atol=atol, + rtol=rtol, + msg=( + f"Orthogonality failed for WignerDCalculator, dtype={dtype}, lmax={lmax}, l={l}" + ), + ) + + def test_group_property(self) -> None: + """Test group property in quaternion composition order.""" + for dtype, lmax in itertools.product([torch.float64, torch.float32], [1, 3, 6]): + atol = 1e-10 if dtype == torch.float64 else 5e-4 + rtol = 1e-10 if dtype == torch.float64 else 5e-4 + wigner = WignerDCalculator(lmax=lmax, dtype=dtype) + + q1 = _random_quaternion(self.batch, device=self.device, dtype=dtype) + q2 = _random_quaternion(self.batch, device=self.device, dtype=dtype) + q12 = quaternion_multiply(q1, q2) + + D1_full, _ = wigner(q1) + D2_full, _ = wigner(q2) + D12_full, _ = wigner(q12) + + for l in range(lmax + 1): + D1_l = self._extract_l_block(D1_full, l) + D2_l = self._extract_l_block(D2_full, l) + D12_l = self._extract_l_block(D12_full, l) + torch.testing.assert_close( + D12_l, + D1_l @ D2_l, + atol=atol, + rtol=rtol, + msg=( + f"Group property failed for WignerDCalculator, dtype={dtype}, lmax={lmax}, l={l}" + ), + ) + + def test_l1_matches_vector_representation(self) -> None: + """Test that the l=1 block matches the Cartesian vector representation.""" + for dtype in [torch.float64, torch.float32]: + atol, rtol = self._get_tols(dtype) + S = torch.tensor( + [[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], [1.0, 0.0, 0.0]], + device=self.device, + dtype=dtype, + ) + S_batch = S.unsqueeze(0).expand(self.batch, 3, 3) + + edge_quat = _random_quaternion(self.batch, device=self.device, dtype=dtype) + rot = quaternion_to_rotation_matrix(edge_quat) + wigner = WignerDCalculator(lmax=1, dtype=dtype) + D_full, Dt_full = wigner(edge_quat) + D1 = self._extract_l_block(D_full, 1) + Dt1 = self._extract_l_block(Dt_full, 1) + + expected = S_batch @ rot @ S_batch.transpose(-1, -2) + torch.testing.assert_close( + D1, + expected, + atol=atol, + rtol=rtol, + msg=f"l=1 block mismatch for WignerDCalculator, dtype={dtype}", + ) + torch.testing.assert_close( + Dt1, + expected.transpose(-1, -2), + atol=atol, + rtol=rtol, + msg=f"l=1 transpose block mismatch for WignerDCalculator, dtype={dtype}", + ) + + def test_pole_path_gradient_matches_finite_difference(self) -> None: + """Check one pole-crossing Wigner probe against finite differences.""" + for dtype in [torch.float64, torch.float32]: + wigner = WignerDCalculator(lmax=6, dtype=dtype) + atol = 5.0e-8 if dtype == torch.float64 else 2.0e-6 + rtol = 1.0e-6 if dtype == torch.float64 else 2.0e-4 + for sign in [1.0, -1.0]: + delta = torch.linspace( + -0.02, + 0.02, + 257, + device=self.device, + dtype=dtype, + requires_grad=True, + ) + edge_vec = torch.stack( + [delta, torch.zeros_like(delta), torch.full_like(delta, sign)], + dim=-1, + ) + edge_quat = build_edge_quaternion(edge_vec) + D_full, _ = wigner(edge_quat) + probe = D_full[:, 5, 7] + D_full[:, 17, 19] + grad = torch.autograd.grad(probe.sum(), delta)[0] + delta_detached = delta.detach() + probe_detached = probe.detach() + numerical_grad = (probe_detached[2:] - probe_detached[:-2]) / ( + 2.0 * (delta_detached[1] - delta_detached[0]) + ) + torch.testing.assert_close( + grad[1:-1].detach(), + numerical_grad, + atol=atol, + rtol=rtol, + msg=( + f"Pole-path Wigner gradient mismatch for dtype={dtype}, sign={sign}" + ), + ) + + def test_y_crossing_overlap_has_no_large_wigner_jump(self) -> None: + """Check chart-overlap continuity for a path that crosses y=0.""" + for dtype in [torch.float64, torch.float32]: + wigner = WignerDCalculator(lmax=4, dtype=dtype) + max_allowed = 1.0e-2 if dtype == torch.float64 else 1.5e-2 + y_vals = torch.tensor( + [-1.0e-3, -5.0e-4, -1.0e-4, 0.0, 1.0e-4, 5.0e-4, 1.0e-3], + device=self.device, + dtype=dtype, + ) + edge_vec = torch.stack( + [ + torch.full_like(y_vals, 0.35), + y_vals, + torch.full_like(y_vals, 0.25), + ], + dim=-1, + ) + edge_quat = build_edge_quaternion(edge_vec) + D_full, _ = wigner(edge_quat) + step = (D_full[1:] - D_full[:-1]).abs().amax(dim=(1, 2)) + self.assertLess( + step.max().item(), + max_allowed, + msg=f"Large Wigner jump across y=0 for dtype={dtype}", + ) + + +class TestSO2LinearEquivariance(_SeZMTestCase): + """Test SO2Linear z-rotation equivariance: SO2Linear(Z @ x) = Z @ SO2Linear(x).""" + + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + + def _get_tols(self, dtype: torch.dtype) -> tuple[float, float]: + if dtype == torch.float64: + return 1e-10, 1e-10 + if dtype == torch.float32: + return 1e-5, 1e-5 + # bf16 has only 7-bit mantissa; use looser tolerance for equivariance tests. + return 2e-2, 2e-2 + + def _build_m_major_z_rotation( + self, angles: torch.Tensor, lmax: int, mmax: int + ) -> torch.Tensor: + """ + Build block z-rotation matrix for the m-major truncated layout. + + Parameters + ---------- + angles + Rotation angles with shape (batch,). + lmax + Maximum degree. + mmax + Maximum order (|m|). Must satisfy mmax <= lmax. + + Returns + ------- + torch.Tensor + Z matrix with shape (batch, dim_red, dim_red). + """ + batch = angles.shape[0] + m0_size = lmax + 1 + dim_red = m0_size + for m in range(1, mmax + 1): + num_l = lmax - m + 1 + dim_red += 2 * num_l + + Z = angles.new_zeros(batch, dim_red, dim_red) + eye0 = torch.eye(m0_size, device=self.device, dtype=angles.dtype).expand( + batch, m0_size, m0_size + ) + Z[:, :m0_size, :m0_size] = eye0 + + offset = m0_size + for m in range(1, mmax + 1): + num_l = lmax - m + 1 + eye = torch.eye(num_l, device=self.device, dtype=angles.dtype).expand( + batch, num_l, num_l + ) + cos_m = torch.cos(m * angles).view(batch, 1, 1) + sin_m = torch.sin(m * angles).view(batch, 1, 1) + + # In m-major layout, each m group is stored as [neg(l), pos(l)] with two halves. + # Rotation is [[cos I, -sin I], [sin I, cos I]] for the (neg, pos) pair. + Z[:, offset : offset + num_l, offset : offset + num_l] = cos_m * eye + Z[ + :, + offset : offset + num_l, + offset + num_l : offset + 2 * num_l, + ] = -sin_m * eye + Z[ + :, + offset + num_l : offset + 2 * num_l, + offset : offset + num_l, + ] = sin_m * eye + Z[ + :, + offset + num_l : offset + 2 * num_l, + offset + num_l : offset + 2 * num_l, + ] = cos_m * eye + offset += 2 * num_l + + return Z + + def test_equivariance_random_angles(self) -> None: + """Test SO2Linear(Z @ x) = Z @ SO2Linear(x) for random z-rotations.""" + for dtype, lmax, mmax in itertools.product( + [torch.float64, torch.float32, torch.bfloat16], + [1, 2, 3], + [1, 2, 3], + ): + if mmax > lmax: + continue + atol, rtol = self._get_tols(dtype) + batch = 16 + channels_in = 8 + channels_out = 12 + + so2_linear = SO2Linear( + lmax=lmax, + mmax=mmax, + in_channels=channels_in, + out_channels=channels_out, + dtype=dtype, + seed=123, + trainable=True, + ) + + dim_red = so2_linear.reduced_dim + x = torch.randn( + batch, 1, dim_red, channels_in, device=self.device, dtype=dtype + ) + + angles = torch.rand(batch, device=self.device, dtype=dtype) * 2 * 3.14159 + Z = self._build_m_major_z_rotation(angles, lmax, mmax) + + x_rotated = torch.einsum("bij,bfjc->bfic", Z, x) + lhs = so2_linear(x_rotated) + rhs = torch.einsum("bij,bfjc->bfic", Z, so2_linear(x)) + + torch.testing.assert_close( + lhs, + rhs, + atol=atol, + rtol=rtol, + msg=f"SO2Linear equivariance failed for dtype={dtype}, lmax={lmax}, mmax={mmax}", + ) + + def test_dynamic_radial_degree_mixer_equivariance(self) -> None: + """Dynamic radial degree mixer should commute with local z-rotations.""" + for mode, rank in ( + ("degree", 0), + ("degree_channel", 0), + ("degree_channel", 2), + ): + with self.subTest(mode=mode, rank=rank): + dtype = torch.float32 + lmax = 2 + mmax = 2 + batch = 6 + channels = 5 + mixer = DynamicRadialDegreeMixer( + lmax=lmax, + mmax=mmax, + channels=channels, + mode=mode, + rank=rank, + dtype=dtype, + seed=123, + trainable=True, + ) + degree_index_m = build_m_major_l_index(lmax, mmax, device=self.device) + dim_red = int(degree_index_m.numel()) + x = torch.randn( + batch, dim_red, channels, device=self.device, dtype=dtype + ) + radial_base = torch.randn( + batch, lmax + 1, channels, device=self.device, dtype=dtype + ) + radial_feat = radial_base[:, degree_index_m, :] + angles = ( + torch.rand(batch, device=self.device, dtype=dtype) * 2 * math.pi + ) + Z = self._build_m_major_z_rotation(angles, lmax, mmax) + + x_rotated = torch.einsum("bij,bjc->bic", Z, x) + lhs = mixer(x_rotated, radial_feat) + rhs = torch.einsum("bij,bjc->bic", Z, mixer(x, radial_feat)) + + torch.testing.assert_close(lhs, rhs, atol=1e-5, rtol=1e-5) + + +class TestInnerClamp(_SeZMTestCase): + """Test InnerClamp C3-continuous septic Hermite clamping.""" + + def setUp(self) -> None: + super().setUp() + self.r_inner = 1.0 + self.r_outer = 1.5 + self.clamp = InnerClamp(self.r_inner, self.r_outer) + + def test_monotonicity(self) -> None: + """Test that r̃ is monotonically non-decreasing.""" + r = torch.linspace(0.0, 3.0, 1000, dtype=torch.float64, device=self.device) + out = self.clamp(r) + diff = out[1:] - out[:-1] + self.assertTrue((diff >= -1e-14).all(), "InnerClamp is not monotonic") + + def test_frozen_zone_zero_gradient(self) -> None: + """Test that dr̃/dr = 0 for r < r_inner (frozen zone).""" + r = torch.tensor( + [0.3, 0.5, 0.8, 0.99], + dtype=torch.float64, + device=self.device, + requires_grad=True, + ) + out = self.clamp(r) + grads = torch.autograd.grad(out.sum(), r)[0] + torch.testing.assert_close( + grads, + torch.zeros_like(grads), + atol=1e-12, + rtol=0, + msg="Gradient should be zero in the frozen zone", + ) + + def test_identity_zone_unit_gradient(self) -> None: + """Test that dr̃/dr = 1 for r > r_outer (identity zone).""" + r = torch.tensor( + [1.6, 2.0, 3.0, 5.0], + dtype=torch.float64, + device=self.device, + requires_grad=True, + ) + out = self.clamp(r) + grads = torch.autograd.grad(out.sum(), r)[0] + torch.testing.assert_close( + grads, + torch.ones_like(grads), + atol=1e-12, + rtol=0, + msg="Gradient should be 1 in the identity zone", + ) + + def test_c3_continuity_at_boundaries(self) -> None: + """Test C3 continuity at r_inner and r_outer via autograd derivatives.""" + eps = 1e-6 + for boundary in [self.r_inner, self.r_outer]: + r = torch.tensor( + [boundary - eps, boundary, boundary + eps], + dtype=torch.float64, + device=self.device, + requires_grad=True, + ) + out = self.clamp(r) + + # First derivative via autograd + grads = torch.autograd.grad(out.sum(), r, create_graph=True)[0] + # dr̃/dr should be continuous (left ≈ center ≈ right) + self.assertAlmostEqual( + grads[0].item(), + grads[1].item(), + places=4, + msg=f"First derivative discontinuous at {boundary}", + ) + self.assertAlmostEqual( + grads[1].item(), + grads[2].item(), + places=4, + msg=f"First derivative discontinuous at {boundary}", + ) + + # Second derivative via autograd + grads2 = torch.autograd.grad(grads.sum(), r, create_graph=True)[0] + self.assertAlmostEqual( + grads2[0].item(), + grads2[1].item(), + places=3, + msg=f"Second derivative discontinuous at {boundary}", + ) + self.assertAlmostEqual( + grads2[1].item(), + grads2[2].item(), + places=3, + msg=f"Second derivative discontinuous at {boundary}", + ) + + # Third derivative via autograd + grads3 = torch.autograd.grad(grads2.sum(), r)[0] + self.assertAlmostEqual( + grads3[0].item(), + grads3[1].item(), + places=2, + msg=f"Third derivative discontinuous at {boundary}", + ) + self.assertAlmostEqual( + grads3[1].item(), + grads3[2].item(), + places=2, + msg=f"Third derivative discontinuous at {boundary}", + ) + + def test_invalid_params(self) -> None: + """Test that invalid parameters raise ValueError.""" + with self.assertRaises(ValueError): + InnerClamp(1.5, 1.0) + with self.assertRaises(ValueError): + InnerClamp(-1.0, 1.0) + with self.assertRaises(ValueError): + InnerClamp(1.0, 1.0) + + +class TestDescriptorEnergyCurveSmoothness(_SeZMTestCase): + """Test PES smoothness from scaled symmetric eight-atom probes.""" + + RANDOM_WEIGHT_BASE_SEED = 184 + RANDOM_WEIGHT_STD = 0.1 + N_DISPLACEMENT_POINTS = 201 + MAX_DISPLACEMENT = 0.1 + RCUT_NEAR_DISTANCE = 4.95 + BRIDGING_R_INNER = 0.8 + BRIDGING_R_OUTER = 1.2 + ENERGY_SPAN_MARGIN = 1.0e-7 + FIRST_DERIVATIVE_MARGIN = 5.0e-7 + SECOND_DERIVATIVE_MARGIN = 5.0e-4 + EXTREMUM_MARGIN = 1.0e-9 + + def setUp(self) -> None: + super().setUp() + self.dtype = torch.float64 + self.symmetry_frac_coord = torch.tensor( + [ + [0.0, 0.0, 0.0], + [0.0, 0.5, 0.5], + [0.5, 0.0, 0.5], + [0.5, 0.5, 0.0], + [0.5, 0.5, 0.5], + [0.5, 0.0, 0.0], + [0.0, 0.5, 0.0], + [0.0, 0.0, 0.5], + ], + dtype=self.dtype, + device=self.device, + ).view(1, 8, 3) + self.symmetry_atype = torch.tensor( + [[0, 0, 0, 0, 1, 1, 1, 1]], + dtype=torch.int32, + device=self.device, + ) + + def _build_model_params( + self, + n_atten_head: int, + *, + use_amp: bool, + n_focus: int = 1, + bridging_method: str = "none", + bridging_r_inner: float = 0.8, + bridging_r_outer: float = 1.2, + ) -> dict: + """Build the SeZM probe model used for one PES scan.""" + params = { + "type": "SeZM", + "type_map": ["Na", "Cl"], + "descriptor": { + "type": "SeZM", + "sel": [16, 16], + "rcut": 5.0, + "channels": 16, + "n_focus": n_focus, + "focus_dim": 0, + "focus_compete": True, + "n_radial": 6, + "radial_mlp": [16], + "use_env_seed": True, + "l_schedule": [1, 0], + "mmax": 1, + "so2_norm": False, + "so2_layers": 1, + "n_atten_head": n_atten_head, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 16, + "ffn_blocks": 1, + "mlp_bias": True, + "layer_scale": False, + "use_amp": use_amp, + "activation_function": "silu", + "glu_activation": True, + "precision": "float64", + "seed": 7, + }, + "fitting_net": { + "neuron": [], + "activation_function": "silu", + "precision": "float64", + "seed": 7, + }, + } + if bridging_method.lower() != "none": + params["bridging_method"] = bridging_method + params["bridging_r_inner"] = bridging_r_inner + params["bridging_r_outer"] = bridging_r_outer + return params + + def _build_scaled_symmetric_structure( + self, + nearest_distance: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Scale one symmetric eight-atom template to a target nearest-neighbor distance.""" + lattice = 2.0 * nearest_distance + coord = self.symmetry_frac_coord * lattice + box = torch.tensor( + [[lattice, 0.0, 0.0, 0.0, lattice, 0.0, 0.0, 0.0, lattice]], + dtype=self.dtype, + device=self.device, + ) + return coord, box + + def _stable_parameter_seed(self, name: str) -> int: + """Return an order-independent seed derived from one parameter name.""" + seed_value = self.RANDOM_WEIGHT_BASE_SEED + for idx, char in enumerate(name): + seed_value += (idx + 1) * ord(char) + return seed_value + + def _build_random_weight_model( + self, + n_atten_head: int, + *, + use_amp: bool, + n_focus: int = 1, + bridging_method: str = "none", + bridging_r_inner: float = 0.8, + bridging_r_outer: float = 1.2, + ) -> torch.nn.Module: + """Build a probe model and randomize all floating-point parameters.""" + model = get_sezm_model( + self._build_model_params( + n_atten_head, + use_amp=use_amp, + n_focus=n_focus, + bridging_method=bridging_method, + bridging_r_inner=bridging_r_inner, + bridging_r_outer=bridging_r_outer, + ) + ).to( + device=self.device, + dtype=self.dtype, + ) + randomized = 0 + + # Assign deterministic random weights by parameter name. + with torch.no_grad(): + for name, param in model.named_parameters(): + if not param.is_floating_point(): + continue + generator = torch.Generator(device=self.device) + generator.manual_seed(self._stable_parameter_seed(name)) + param.normal_( + mean=0.0, + std=self.RANDOM_WEIGHT_STD, + generator=generator, + ) + randomized += 1 + + self.assertGreater( + randomized, 0, "No floating-point parameters were randomized" + ) + model.eval() + return model + + def _scan_total_energy_curve( + self, + model: torch.nn.Module, + *, + nearest_distance: float, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Scan total energies while displacing atom 0 along x.""" + displacements = torch.linspace( + -self.MAX_DISPLACEMENT, + self.MAX_DISPLACEMENT, + self.N_DISPLACEMENT_POINTS, + dtype=self.dtype, + device=self.device, + ) + coord0, box = self._build_scaled_symmetric_structure(nearest_distance) + coord = coord0.repeat(displacements.shape[0], 1, 1) + coord[:, 0, 0] += displacements + result = model( + coord, + self.symmetry_atype.expand(displacements.shape[0], -1), + box=box.expand(displacements.shape[0], -1), + ) + return displacements, result["energy"][:, 0].detach() + + def _collect_curve_statistics( + self, + energies: torch.Tensor, + displacements: torch.Tensor, + ) -> dict[str, float | str]: + """Compute derivative and extremum statistics for one energy curve.""" + step = displacements[1] - displacements[0] + first = (energies[2:] - energies[:-2]) / (2.0 * step) + second = (energies[2:] - 2.0 * energies[1:-1] + energies[:-2]) / (step * step) + center_idx = energies.shape[0] // 2 + deriv_center_idx = first.shape[0] // 2 + left_first = first[:deriv_center_idx] + right_first = first[deriv_center_idx + 1 :] + center_energy = energies[center_idx] + other_energies = torch.cat((energies[:center_idx], energies[center_idx + 1 :])) + + bowl_up = bool( + torch.all(left_first < -self.FIRST_DERIVATIVE_MARGIN) + and torch.all(right_first > self.FIRST_DERIVATIVE_MARGIN) + and torch.all(second > self.SECOND_DERIVATIVE_MARGIN) + and torch.all(center_energy <= other_energies + self.EXTREMUM_MARGIN) + ) + bowl_down = bool( + torch.all(left_first > self.FIRST_DERIVATIVE_MARGIN) + and torch.all(right_first < -self.FIRST_DERIVATIVE_MARGIN) + and torch.all(second < -self.SECOND_DERIVATIVE_MARGIN) + and torch.all(center_energy >= other_energies - self.EXTREMUM_MARGIN) + ) + + if bowl_up: + curve_kind = "minimum" + elif bowl_down: + curve_kind = "maximum" + else: + curve_kind = "invalid" + + return { + "curve_kind": curve_kind, + "energy_span": float((energies.max() - energies.min()).abs()), + "left_abs_min": float(left_first.abs().min()), + "right_abs_min": float(right_first.abs().min()), + "curvature_abs_min": float(second.abs().min()), + } + + def _assert_curve_has_usable_signal( + self, + stats: dict[str, float | str], + *, + label: str, + n_atten_head: int, + ) -> None: + """Check that the scanned curve has enough signal above numerical noise.""" + self.assertGreater( + stats["energy_span"], + self.ENERGY_SPAN_MARGIN, + f"{label} energy curve became nearly flat for n_atten_head={n_atten_head}: {stats}", + ) + self.assertGreater( + stats["left_abs_min"], + self.FIRST_DERIVATIVE_MARGIN, + f"{label} left-branch slope is too small for n_atten_head={n_atten_head}: {stats}", + ) + self.assertGreater( + stats["right_abs_min"], + self.FIRST_DERIVATIVE_MARGIN, + f"{label} right-branch slope is too small for n_atten_head={n_atten_head}: {stats}", + ) + self.assertGreater( + stats["curvature_abs_min"], + self.SECOND_DERIVATIVE_MARGIN, + f"{label} curvature is too small for n_atten_head={n_atten_head}: {stats}", + ) + + def _assert_cutoff_near_energy_curve_is_smooth( + self, + n_atten_head: int, + *, + use_amp: bool, + n_focus: int, + ) -> None: + """Check that the non-bridged near-cutoff probe keeps one smooth extremum.""" + model = self._build_random_weight_model( + n_atten_head, + use_amp=use_amp, + n_focus=n_focus, + ) + displacements, energies = self._scan_total_energy_curve( + model, + nearest_distance=self.RCUT_NEAR_DISTANCE, + ) + self.assertTrue(torch.isfinite(energies).all().item()) + + stats = self._collect_curve_statistics(energies, displacements) + self._assert_curve_has_usable_signal( + stats, + label=f"Near-cutoff (use_amp={use_amp}, n_focus={n_focus})", + n_atten_head=n_atten_head, + ) + self.assertIn( + stats["curve_kind"], + {"minimum", "maximum"}, + ( + "Near-cutoff energy curve is not a single smooth bowl " + f"for n_atten_head={n_atten_head}, use_amp={use_amp}, n_focus={n_focus}: {stats}" + ), + ) + + def _assert_bridged_boundary_energy_curve_is_smooth( + self, + n_atten_head: int, + *, + use_amp: bool, + n_focus: int, + nearest_distance: float, + boundary_label: str, + ) -> None: + """Check that one bridged boundary probe keeps one smooth minimum.""" + model = self._build_random_weight_model( + n_atten_head, + use_amp=use_amp, + n_focus=n_focus, + bridging_method="ZBL", + bridging_r_inner=self.BRIDGING_R_INNER, + bridging_r_outer=self.BRIDGING_R_OUTER, + ) + displacements, energies = self._scan_total_energy_curve( + model, + nearest_distance=nearest_distance, + ) + self.assertTrue(torch.isfinite(energies).all().item()) + + stats = self._collect_curve_statistics(energies, displacements) + self._assert_curve_has_usable_signal( + stats, + label=f"Bridged {boundary_label} (use_amp={use_amp}, n_focus={n_focus})", + n_atten_head=n_atten_head, + ) + self.assertEqual( + stats["curve_kind"], + "minimum", + ( + f"Bridged {boundary_label} probe should form one symmetric repulsive bowl " + f"for n_atten_head={n_atten_head}, use_amp={use_amp}, n_focus={n_focus}: {stats}" + ), + ) + + def test_scaled_cutoff_near_energy_curve_is_smooth_across_attention_modes( + self, + ) -> None: + """Check the non-bridged near-cutoff PES shape across attention and AMP modes.""" + for use_amp in (False, True): + for n_atten_head in (0, 1, 2): + for n_focus in (1, 2): + with self.subTest( + n_atten_head=n_atten_head, use_amp=use_amp, n_focus=n_focus + ): + self._assert_cutoff_near_energy_curve_is_smooth( + n_atten_head, + use_amp=use_amp, + n_focus=n_focus, + ) + + def test_scaled_bridging_inner_energy_curve_is_smooth_across_attention_modes( + self, + ) -> None: + """Check the bridged near-r_inner PES shape across attention and AMP modes.""" + for use_amp in (False, True): + for n_atten_head in (0, 1, 2): + for n_focus in (1, 2): + with self.subTest( + n_atten_head=n_atten_head, use_amp=use_amp, n_focus=n_focus + ): + self._assert_bridged_boundary_energy_curve_is_smooth( + n_atten_head, + use_amp=use_amp, + n_focus=n_focus, + nearest_distance=self.BRIDGING_R_INNER, + boundary_label="r_inner", + ) + + def test_scaled_bridging_outer_energy_curve_is_smooth_across_attention_modes( + self, + ) -> None: + """Check the bridged near-r_outer PES shape across attention and AMP modes.""" + for use_amp in (False, True): + for n_atten_head in (0, 1, 2): + for n_focus in (1, 2): + with self.subTest( + n_atten_head=n_atten_head, use_amp=use_amp, n_focus=n_focus + ): + self._assert_bridged_boundary_energy_curve_is_smooth( + n_atten_head, + use_amp=use_amp, + n_focus=n_focus, + nearest_distance=self.BRIDGING_R_OUTER, + boundary_label="r_outer", + ) + + +class TestSourceFreezePropagationGate(TestDescriptorEnergyCurveSmoothness): + """ + Sharper correctness probes for the Source Freeze Propagation Gate. + + ``TestDescriptorEnergyCurveSmoothness`` already validates that the + bridged PES is a single smooth repulsive bowl near ``r_inner`` and + ``r_outer``; that family exercises the combined ``InnerClamp`` + + ``BridgingSwitch`` machinery on a realistic 8-atom layout and is + sensitive to discontinuities. + + The tests below add a tighter, analytical guardrail that + ``InnerClamp`` alone cannot satisfy once the model grows past one + interaction block. Setup: three atoms ``(A, B, C)`` with ``A`` at + the origin, ``B`` sliding rigidly on a sphere of radius + ``r_AB < r_inner`` around ``A`` and ``C`` anchored in the normal + zone. Under this motion + + * ``r_AB`` is constant, so ``V_ZBL(r_AB)`` is constant and the + analytical half-energy assigned to ``A`` does not move; + * the direction ``r_hat_AB`` varies, and so do ``r_BC`` and + ``r_hat_BC`` — these are exactly the channels through which the + direction and multi-hop leaks manifest in the user's analysis. + + The guarantee SFPG delivers is therefore very sharp: **the atomic + energy of every node that is not ``B`` itself must stay strictly + constant**. ``E_B`` is expected to vary because ``B`` still has a + genuine chemical bond with ``C``. + """ + + NEAR_DISTANCE = 0.6 # < BRIDGING_R_INNER, stays frozen for all samples + N_SPHERE_POINTS = 16 + INVARIANCE_TOLERANCE = 1.0e-10 + + def _build_three_atom_box( + self, + *, + near_distance: float, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Build a 3-atom open-box probe ``(A, B, C)`` with frozen pair ``(A, B)``. + + ``A`` is anchored at the origin, ``B`` is placed inside the frozen + sphere at distance ``near_distance`` along the initial direction, + and ``C`` sits well outside the bridging window so ``(A, C)`` and + ``(B, C)`` are ordinary GNN edges. + """ + coord = torch.tensor( + [ + [0.0, 0.0, 0.0], + [near_distance, 0.0, 0.0], + [2.4, 0.0, 0.0], + ], + dtype=self.dtype, + device=self.device, + ).unsqueeze(0) + atype = torch.tensor([[0, 1, 0]], dtype=torch.int32, device=self.device) + # Large enough cubic cell to keep the system aperiodic under rcut=5. + box = torch.tensor( + [[10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0]], + dtype=self.dtype, + device=self.device, + ) + return coord, atype, box + + def _sphere_positions( + self, + n_points: int, + radius: float, + ) -> torch.Tensor: + """Deterministic Fibonacci-like sphere sampling around the origin.""" + indices = torch.arange(n_points, dtype=self.dtype, device=self.device) + phi = math.pi * (3.0 - math.sqrt(5.0)) # golden-angle step + z = 1.0 - 2.0 * (indices + 0.5) / float(n_points) + rho = torch.sqrt(torch.clamp(1.0 - z * z, min=0.0)) + theta = phi * indices + positions = torch.stack( + [ + radius * rho * torch.cos(theta), + radius * rho * torch.sin(theta), + radius * z, + ], + dim=-1, + ) + return positions # (n_points, 3) + + def _evaluate_frozen_sphere_atom_energies( + self, + model: torch.nn.Module, + ) -> torch.Tensor: + """Evaluate per-atom energies while ``B`` rigidly slides on the frozen sphere. + + The per-edge local-``+Z`` gauge roll is turned off so every edge uses + the deterministic canonical edge quaternion. This removes the only + non-determinism that would otherwise mask the ``< 1e-10`` invariance + SFPG is supposed to deliver in fp64; the gauge roll is pure + training-time augmentation and does not change the physical content. + """ + descriptor = model.atomic_model.descriptor + previous_random_gamma = descriptor.random_gamma + descriptor.random_gamma = False + try: + coord0, atype, box = self._build_three_atom_box( + near_distance=self.NEAR_DISTANCE + ) + directions = self._sphere_positions( + self.N_SPHERE_POINTS, radius=self.NEAR_DISTANCE + ) + coord = coord0.repeat(directions.shape[0], 1, 1) + coord[:, 1, :] = directions # rotate B around A, radius fixed + result = model( + coord, + atype.expand(directions.shape[0], -1), + box=box.expand(directions.shape[0], -1), + ) + finally: + descriptor.random_gamma = previous_random_gamma + return result["atom_energy"].detach() # (n_samples, 3, 1) + + def _atomic_energy_span( + self, + atom_energies: torch.Tensor, + atom_index: int, + ) -> float: + """Range of one atom's energy across the sphere trajectory.""" + values = atom_energies[:, atom_index, 0] + return float((values.max() - values.min()).abs()) + + def test_sfpg_preserves_atomic_energy_of_frozen_partner(self) -> None: + """ + With SFPG active, the atomic energy of the frozen-partner atom + ``A`` must be strictly constant along the sphere trajectory. + + ``A``'s atomic energy decomposes as ``E_A_GNN + V_ZBL(r_AB)/2 + + V_ZBL(r_AC)/2``. Both ZBL half-terms are constant because + ``r_AB`` is fixed on the frozen sphere and ``r_AC`` does not + change either, so any residual variation must come from the + GNN branch. ``E_C`` is deliberately not probed here because + ``V_ZBL(r_BC)/2`` that sits on the ``C`` side is of order + several eV at ``r_BC ≈ r_outer`` and moves with ``B`` even + though the GNN piece of ``E_C`` remains strictly invariant. + Writing the test on ``E_A`` keeps the assertion clean and + purely diagnostic of SFPG rather than of ZBL geometry. + """ + model = self._build_random_weight_model( + n_atten_head=2, + use_amp=False, + n_focus=2, + bridging_method="ZBL", + bridging_r_inner=self.BRIDGING_R_INNER, + bridging_r_outer=self.BRIDGING_R_OUTER, + ) + atom_energies = self._evaluate_frozen_sphere_atom_energies(model) + self.assertTrue(torch.isfinite(atom_energies).all().item()) + + span_a = self._atomic_energy_span(atom_energies, atom_index=0) + self.assertLess( + span_a, + self.INVARIANCE_TOLERANCE, + ( + "SFPG must freeze the atomic energy of ``A`` under rigid " + f"motion of its frozen partner; observed span {span_a:.3e} " + f"> {self.INVARIANCE_TOLERANCE:.3e}." + ), + ) + + def test_sfpg_leaks_reopen_when_gate_is_disabled(self) -> None: + """ + Ablation: clearing ``bridging_switch`` must re-expose the leak. + + This test is the contrapositive of the previous one: if SFPG is + the only mechanism that closes the direction / multi-hop leak, + disabling it must produce a measurably non-constant ``E_A`` on + the same frozen-sphere trajectory. Keeping this ablation in the + suite pins down which component owns the invariance, so any + future regression that silently disables SFPG is caught + immediately. + """ + model = self._build_random_weight_model( + n_atten_head=2, + use_amp=False, + n_focus=2, + bridging_method="ZBL", + bridging_r_inner=self.BRIDGING_R_INNER, + bridging_r_outer=self.BRIDGING_R_OUTER, + ) + # Drop only the per-source gate: InnerClamp and ZBL stay active. + model.atomic_model.descriptor.bridging_switch = None + atom_energies = self._evaluate_frozen_sphere_atom_energies(model) + + span_a = self._atomic_energy_span(atom_energies, atom_index=0) + self.assertGreater( + span_a, + self.INVARIANCE_TOLERANCE, + ( + "Clearing SFPG should re-expose the direction / multi-hop " + f"leak on ``E_A``; observed span {span_a:.3e} is not above " + f"the invariance tolerance {self.INVARIANCE_TOLERANCE:.3e}." + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py b/source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py new file mode 100644 index 0000000000..580898a81b --- /dev/null +++ b/source/tests/pt/model/test_descriptor_sezm_s2_equivariance.py @@ -0,0 +1,385 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +from deepmd.pt.model.descriptor.sezm_nn import ( + S2GridProjector, + SwiGLUS2Activation, + WignerDCalculator, + build_m_major_index, + quaternion_z_rotation, + resolve_s2_grid_resolution, +) +from deepmd.pt.utils import ( + env, +) + + +def _random_quaternion( + n_batch: int, + *, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + """Sample normalized quaternions in ``(w, x, y, z)`` order.""" + q = torch.randn(n_batch, 4, device=device, dtype=dtype) + return q / torch.sqrt(torch.sum(q * q, dim=-1, keepdim=True)) + + +def _rotate_ndfc(x: torch.Tensor, d_matrix: torch.Tensor) -> torch.Tensor: + """Rotate coefficient-layout tensors with shape ``(N, D, F, C)``.""" + return torch.einsum("nij,njfc->nifc", d_matrix, x) + + +def _rotate_nfdc(x: torch.Tensor, d_matrix: torch.Tensor) -> torch.Tensor: + """Rotate coefficient-layout tensors with shape ``(N, F, D, C)``.""" + return torch.einsum("nij,nfjc->nfic", d_matrix, x) + + +def _max_abs_equivariance_error(lhs: torch.Tensor, rhs: torch.Tensor) -> float: + """Compute the maximum absolute equivariance error.""" + return float(torch.max(torch.abs(lhs - rhs)).item()) + + +class TestS2GridProjector(unittest.TestCase): + """Test S2 projection invariants.""" + + def setUp(self) -> None: + self.device = env.DEVICE + torch.manual_seed(0) + + def test_lebedev_roundtrip_preserves_bandlimited_coefficients(self) -> None: + """Lebedev quadrature should reconstruct coefficients up to lmax.""" + projector = S2GridProjector( + lmax=3, + dtype=torch.float64, + grid_resolution_list=None, + coefficient_layout="packed", + grid_method="lebedev", + ) + x = torch.randn( + 5, projector.coeff_dim, 3, device=self.device, dtype=torch.float64 + ) + y = projector.from_grid(projector.to_grid(x)) + torch.testing.assert_close(y, x, atol=1e-12, rtol=1e-12) + + +class TestSwiGLUS2Equivariance(unittest.TestCase): + """Test default-grid equivariance of full-m and truncated SwiGLU-S2 activations.""" + + def setUp(self) -> None: + self.device = env.DEVICE + torch.manual_seed(0) + + def test_default_full_m_grid_counts_keep_s2_activation_equivariant(self) -> None: + """Default full-m S2 activation grids should keep SO(3) equivariance.""" + # Each case is (lmax, full_m_grid, fp64_tol, fp32_tol). + # e3nn full_m_grid is [R_phi, R_theta] after the square-grid lift. + # Lebedev full_m_grid is [precision, n_points]. + cases_by_method = { + "e3nn": [ + (2, [8, 8], 4.20e-7, 5.00e-6), # local: fp64=3.62e-7, fp32=4.77e-7 + (3, [12, 12], 8.10e-7, 5.00e-6), # local: fp64=7.04e-7, fp32=6.86e-7 + (4, [14, 14], 9.20e-7, 5.00e-6), # local: fp64=7.97e-7, fp32=1.55e-6 + (5, [18, 18], 1.70e-6, 5.00e-6), # local: fp64=1.48e-6, fp32=1.49e-6 + (6, [20, 20], 4.80e-6, 5.00e-6), # local: fp64=4.14e-6, fp32=2.27e-6 + (7, [24, 24], 3.70e-6, 5.00e-6), # local: fp64=3.19e-6, fp32=2.03e-6 + ], + "lebedev": [ + (2, [7, 26], 1.00e-12, 5.00e-6), # local: fp64=2.31e-14, fp32=2.38e-7 + (3, [9, 38], 1.00e-12, 5.00e-6), # local: fp64=3.58e-14, fp32=3.58e-7 + (4, [13, 74], 1.00e-12, 5.00e-6), # local: fp64=5.82e-14, fp32=6.56e-7 + (5, [15, 86], 1.00e-12, 5.00e-6), # local: fp64=3.22e-14, fp32=6.56e-7 + (6, [19, 146], 1.00e-12, 5.00e-6), # local: fp64=7.99e-14, fp32=8.35e-7 + (7, [21, 170], 1.00e-12, 5.00e-6), # local: fp64=6.86e-14, fp32=8.79e-7 + ], + } + dtype_cases = [ + (torch.float64, 0), + (torch.float32, 1), + ] + n_batch = 3 + n_focus = 1 + channels = 2 + + for dtype, tolerance_index in dtype_cases: + for method, cases in cases_by_method.items(): + for lmax, expected_full_m_grid, *tolerances in cases: + with self.subTest( + method=method, + dtype=dtype, + lmax=lmax, + grid=expected_full_m_grid, + ): + self._assert_default_full_m_s2_activation_equivariance( + grid_method=method, + lmax=lmax, + expected_full_m_grid=expected_full_m_grid, + n_batch=n_batch, + n_focus=n_focus, + channels=channels, + dtype=dtype, + tolerance=tolerances[tolerance_index], + ) + + def _assert_default_full_m_s2_activation_equivariance( + self, + *, + grid_method: str, + lmax: int, + expected_full_m_grid: list[int], + n_batch: int, + n_focus: int, + channels: int, + dtype: torch.dtype, + tolerance: float, + ) -> None: + """Assert full-m S2 activation equivariance for one method/dtype/lmax case.""" + torch.manual_seed(1234 + lmax) + default_grid = resolve_s2_grid_resolution( + lmax, + lmax, + method=grid_method, + ) + full_m_grid = ( + [max(default_grid), max(default_grid)] + if grid_method == "e3nn" + else default_grid + ) + self.assertEqual(full_m_grid, expected_full_m_grid) + + activation = SwiGLUS2Activation( + lmax=lmax, + channels=channels, + dtype=dtype, + n_focus=n_focus, + layout="ndfc", + grid_resolution_list=full_m_grid, + coefficient_layout="packed", + grid_method=grid_method, + mlp_bias=False, + trainable=False, + seed=17 + lmax, + ) + self.assertEqual(activation.grid_resolution_list, expected_full_m_grid) + + x = torch.randn( + n_batch, + (lmax + 1) ** 2, + n_focus, + 2 * channels, + device=self.device, + dtype=dtype, + ) + quat = _random_quaternion(n_batch, device=self.device, dtype=dtype) + d_matrix, _ = WignerDCalculator(lmax=lmax, dtype=dtype)(quat) + + y_rotated_input = activation(_rotate_ndfc(x, d_matrix)) + y_then_rotated = _rotate_ndfc(activation(x), d_matrix) + max_error = _max_abs_equivariance_error( + y_rotated_input, + y_then_rotated, + ) + + self.assertLessEqual(max_error, tolerance) + + def test_default_mmax_truncated_grid_counts_keep_s2_activation_z_equivariant( + self, + ) -> None: + """Default mmax-truncated S2 activation grids should keep z-equivariance.""" + # Each case is (lmax, mmax, truncated_grid, fp64_tol, fp32_tol). + # e3nn truncated_grid is [R_phi, R_theta] used by the m-major path. + # Lebedev truncated_grid is [precision, n_points]. + cases_by_method = { + "e3nn": { + 1: [ + (2, [6, 8], 2.80e-7, 5.00e-6), # local: fp64=2.36e-7, fp32=3.58e-7 + (3, [6, 12], 1.50e-7, 5.00e-6), # local: fp64=1.22e-7, fp32=5.96e-7 + (4, [6, 14], 1.33e-6, 5.00e-6), # local: fp64=1.12e-6, fp32=9.54e-7 + (5, [6, 18], 1.30e-7, 5.00e-6), # local: fp64=1.10e-7, fp32=1.43e-6 + (6, [6, 20], 9.00e-7, 5.00e-6), # local: fp64=7.64e-7, fp32=1.91e-6 + (7, [6, 24], 2.60e-7, 5.00e-6), # local: fp64=2.17e-7, fp32=1.91e-6 + ], + 2: [ + (2, [8, 8], 4.70e-7, 5.00e-6), # local: fp64=4.01e-7, fp32=8.34e-7 + (3, [8, 12], 7.00e-7, 5.00e-6), # local: fp64=5.99e-7, fp32=8.34e-7 + (4, [8, 14], 7.00e-7, 5.00e-6), # local: fp64=6.02e-7, fp32=1.67e-6 + (5, [8, 18], 1.40e-6, 5.00e-6), # local: fp64=1.19e-6, fp32=1.55e-6 + (6, [8, 20], 1.55e-6, 5.00e-6), # local: fp64=1.33e-6, fp32=2.15e-6 + (7, [8, 24], 1.65e-6, 5.00e-6), # local: fp64=1.41e-6, fp32=2.62e-6 + ], + }, + "lebedev": { + 1: [ + ( + 2, + [7, 26], + 1.00e-12, + 5.00e-6, + ), # local: fp64=2.31e-14, fp32=2.38e-7 + ( + 3, + [9, 38], + 1.00e-12, + 5.00e-6, + ), # local: fp64=3.55e-14, fp32=2.98e-7 + ( + 4, + [13, 74], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.04e-13, fp32=9.54e-7 + ( + 5, + [15, 86], + 1.00e-12, + 5.00e-6, + ), # local: fp64=9.34e-14, fp32=7.15e-7 + ( + 6, + [19, 146], + 1.00e-12, + 5.00e-6, + ), # local: fp64=8.56e-14, fp32=2.15e-6 + ( + 7, + [21, 170], + 1.00e-12, + 5.00e-6, + ), # local: fp64=2.08e-13, fp32=3.34e-6 + ], + 2: [ + ( + 2, + [7, 26], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.50e-14, fp32=2.38e-7 + ( + 3, + [9, 38], + 1.00e-12, + 5.00e-6, + ), # local: fp64=5.71e-14, fp32=3.58e-7 + ( + 4, + [13, 74], + 1.00e-12, + 5.00e-6, + ), # local: fp64=9.15e-14, fp32=5.96e-7 + ( + 5, + [15, 86], + 1.00e-12, + 5.00e-6, + ), # local: fp64=7.83e-14, fp32=4.77e-7 + ( + 6, + [19, 146], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.29e-13, fp32=9.54e-7 + ( + 7, + [21, 170], + 1.00e-12, + 5.00e-6, + ), # local: fp64=1.56e-13, fp32=1.43e-6 + ], + }, + } + dtype_cases = [ + (torch.float64, 0), + (torch.float32, 1), + ] + n_batch = 3 + n_focus = 2 + channels = 2 + + for dtype, tolerance_index in dtype_cases: + for method, cases_by_mmax in cases_by_method.items(): + for mmax, cases in cases_by_mmax.items(): + for lmax, expected_truncated_grid, *tolerances in cases: + with self.subTest( + method=method, + dtype=dtype, + lmax=lmax, + mmax=mmax, + grid=expected_truncated_grid, + ): + self._assert_default_mmax_truncated_grid_z_equivariance( + grid_method=method, + lmax=lmax, + mmax=mmax, + expected_truncated_grid=expected_truncated_grid, + n_batch=n_batch, + n_focus=n_focus, + channels=channels, + dtype=dtype, + tolerance=tolerances[tolerance_index], + ) + + def _assert_default_mmax_truncated_grid_z_equivariance( + self, + *, + grid_method: str, + lmax: int, + mmax: int, + expected_truncated_grid: list[int], + n_batch: int, + n_focus: int, + channels: int, + dtype: torch.dtype, + tolerance: float, + ) -> None: + """Assert mmax-truncated S2 activation z-equivariance for one case.""" + torch.manual_seed(2234 + lmax + 100 * mmax) + truncated_grid = resolve_s2_grid_resolution( + lmax, + mmax, + method=grid_method, + ) + self.assertEqual(truncated_grid, expected_truncated_grid) + + activation = SwiGLUS2Activation( + lmax=lmax, + mmax=mmax, + channels=channels, + dtype=dtype, + n_focus=n_focus, + layout="nfdc", + grid_resolution_list=truncated_grid, + coefficient_layout="m_major", + grid_method=grid_method, + mlp_bias=False, + trainable=False, + seed=27 + lmax + 100 * mmax, + ) + self.assertEqual(activation.grid_resolution_list, expected_truncated_grid) + + coeff_index = build_m_major_index(lmax, mmax, device=self.device) + x = torch.randn( + n_batch, + n_focus, + int(coeff_index.numel()), + 2 * channels, + device=self.device, + dtype=dtype, + ) + gamma = torch.randn(n_batch, device=self.device, dtype=dtype) + quaternion = quaternion_z_rotation(gamma) + d_matrix, _ = WignerDCalculator(lmax=lmax, dtype=dtype)(quaternion) + d_matrix_reduced = d_matrix.index_select(1, coeff_index).index_select( + 2, + coeff_index, + ) + + y_rotated_input = activation(_rotate_nfdc(x, d_matrix_reduced)) + y_then_rotated = _rotate_nfdc(activation(x), d_matrix_reduced) + max_error = _max_abs_equivariance_error( + y_rotated_input, + y_then_rotated, + ) + + self.assertLessEqual(max_error, tolerance) diff --git a/source/tests/pt/model/test_descriptor_sezm_triton.py b/source/tests/pt/model/test_descriptor_sezm_triton.py new file mode 100644 index 0000000000..a0e3d44483 --- /dev/null +++ b/source/tests/pt/model/test_descriptor_sezm_triton.py @@ -0,0 +1,960 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +from deepmd.pt.model.descriptor.sezm_nn import ( + C3CutoffEnvelope, + InnerClamp, + RadialBasis, + build_m_major_index, + project_D_to_m, + project_Dt_from_m, +) +from deepmd.pt.model.descriptor.sezm_nn.triton import ( + SEZM_TRITON_AVAILABLE, + TritonRotationMode, + edge_geometry_rbf_triton, + resolve_triton_rotation_mode, + rotate_back_triton, + rotate_to_local_triton, +) + +TRITON_CUDA_AVAILABLE = SEZM_TRITON_AVAILABLE and torch.cuda.is_available() + + +class TestSeZMTritonDispatch(unittest.TestCase): + """Validate the SeZM Triton dispatch policy.""" + + def test_resolve_rotation_mode_covers_small_generic_and_fallback(self) -> None: + """Dispatch policy should cover small kernels, generic kernels, and fallback.""" + self.assertEqual( + resolve_triton_rotation_mode(dim_full=1, reduced_dim=1), + TritonRotationMode.SMALL_LE1, + ) + self.assertEqual( + resolve_triton_rotation_mode(dim_full=4, reduced_dim=4), + TritonRotationMode.SMALL_LE1, + ) + self.assertEqual( + resolve_triton_rotation_mode(dim_full=9, reduced_dim=7), + TritonRotationMode.SMALL_L2, + ) + self.assertEqual( + resolve_triton_rotation_mode(dim_full=16, reduced_dim=10), + TritonRotationMode.SMALL_L3, + ) + self.assertEqual( + resolve_triton_rotation_mode(dim_full=25, reduced_dim=15), + TritonRotationMode.EAGER_REFERENCE, + ) + self.assertEqual( + resolve_triton_rotation_mode(dim_full=25, reduced_dim=16), + TritonRotationMode.GENERIC_TILED, + ) + + +@unittest.skipUnless( + TRITON_CUDA_AVAILABLE, + "SeZM Triton rotation tests require CUDA and Triton.", +) +class TestSeZMTritonEdgeGeometryRBF(unittest.TestCase): + """Validate the Triton edge geometry/RBF chain against eager reference.""" + + def _eager_reference( + self, + *, + coord_flat: torch.Tensor, + center_idx: torch.Tensor, + neighbor_idx: torch.Tensor, + edge_envelope: C3CutoffEnvelope, + radial_basis: RadialBasis, + inner_clamp: InnerClamp | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the eager reference geometry/RBF chain.""" + center_pos = coord_flat.index_select(0, center_idx) + neighbor_pos = coord_flat.index_select(0, neighbor_idx) + edge_vec = neighbor_pos - center_pos + edge_len = torch.sqrt( + torch.sum(edge_vec * edge_vec, dim=-1, keepdim=True) + 1.0e-14 + ) + if inner_clamp is not None: + clamped = inner_clamp(edge_len) + edge_vec = edge_vec * (clamped / edge_len) + edge_len = clamped + edge_env = edge_envelope(edge_len) + edge_rbf = radial_basis(edge_len) + return edge_vec, edge_len, edge_env, edge_rbf + + def test_edge_geometry_rbf_matches_reference_forward_backward(self) -> None: + """Compare fused geometry/RBF chain with eager gather/clamp/envelope/rbf.""" + device = torch.device("cuda") + dtype = torch.float32 + coord_ref = torch.randn( + 12, + 3, + device=device, + dtype=dtype, + requires_grad=True, + ) + coord_triton = coord_ref.detach().clone().requires_grad_(True) + center_idx = torch.randint(0, 12, (9,), device=device, dtype=torch.long) + neighbor_idx = torch.randint(0, 12, (9,), device=device, dtype=torch.long) + edge_envelope = C3CutoffEnvelope(rcut=6.0, exponent=5).to(device) + radial_ref = RadialBasis(rcut=6.0, n_radial=6, dtype=dtype, exponent=7).to( + device + ) + radial_triton = RadialBasis(rcut=6.0, n_radial=6, dtype=dtype, exponent=7).to( + device + ) + radial_triton.load_state_dict(radial_ref.state_dict()) + + out_ref = self._eager_reference( + coord_flat=coord_ref, + center_idx=center_idx, + neighbor_idx=neighbor_idx, + edge_envelope=edge_envelope, + radial_basis=radial_ref, + inner_clamp=None, + ) + out_triton = edge_geometry_rbf_triton( + coord_flat=coord_triton, + center_coord_index=center_idx, + neighbor_coord_index=neighbor_idx, + edge_envelope=edge_envelope, + radial_basis=radial_triton, + eps=1.0e-7, + inner_clamp=None, + ) + for ref, tri in zip(out_ref, out_triton, strict=True): + torch.testing.assert_close(tri, ref, atol=1.0e-5, rtol=1.0e-5) + + grad_out = tuple(torch.randn_like(ref) for ref in out_ref) + grad_coord_ref, grad_freq_ref = torch.autograd.grad( + out_ref, + (coord_ref, radial_ref.adam_freqs), + grad_outputs=grad_out, + ) + grad_coord_triton, grad_freq_triton = torch.autograd.grad( + out_triton, + (coord_triton, radial_triton.adam_freqs), + grad_outputs=grad_out, + ) + torch.testing.assert_close( + grad_coord_triton, + grad_coord_ref, + atol=2.0e-5, + rtol=2.0e-5, + ) + torch.testing.assert_close( + grad_freq_triton, + grad_freq_ref, + atol=2.0e-5, + rtol=2.0e-5, + ) + + def test_edge_geometry_rbf_matches_reference_with_inner_clamp(self) -> None: + """Compare the clamped Triton path with eager reference.""" + device = torch.device("cuda") + dtype = torch.float32 + coord_ref = torch.randn( + 10, + 3, + device=device, + dtype=dtype, + requires_grad=True, + ) + coord_triton = coord_ref.detach().clone().requires_grad_(True) + center_idx = torch.randint(0, 10, (7,), device=device, dtype=torch.long) + neighbor_idx = torch.randint(0, 10, (7,), device=device, dtype=torch.long) + edge_envelope = C3CutoffEnvelope(rcut=6.0, exponent=5).to(device) + radial_ref = RadialBasis(rcut=6.0, n_radial=4, dtype=dtype, exponent=7).to( + device + ) + radial_triton = RadialBasis(rcut=6.0, n_radial=4, dtype=dtype, exponent=7).to( + device + ) + radial_triton.load_state_dict(radial_ref.state_dict()) + inner_clamp = InnerClamp(0.9, 1.3).to(device) + + out_ref = self._eager_reference( + coord_flat=coord_ref, + center_idx=center_idx, + neighbor_idx=neighbor_idx, + edge_envelope=edge_envelope, + radial_basis=radial_ref, + inner_clamp=inner_clamp, + ) + out_triton = edge_geometry_rbf_triton( + coord_flat=coord_triton, + center_coord_index=center_idx, + neighbor_coord_index=neighbor_idx, + edge_envelope=edge_envelope, + radial_basis=radial_triton, + eps=1.0e-7, + inner_clamp=inner_clamp, + ) + for ref, tri in zip(out_ref, out_triton, strict=True): + torch.testing.assert_close(tri, ref, atol=2.0e-5, rtol=2.0e-5) + + loss_ref = sum(x.square().sum() for x in out_ref) + loss_triton = sum(x.square().sum() for x in out_triton) + grad_coord_ref, grad_freq_ref = torch.autograd.grad( + loss_ref, + (coord_ref, radial_ref.adam_freqs), + ) + grad_coord_triton, grad_freq_triton = torch.autograd.grad( + loss_triton, + (coord_triton, radial_triton.adam_freqs), + ) + torch.testing.assert_close( + grad_coord_triton, + grad_coord_ref, + atol=3.0e-5, + rtol=3.0e-5, + ) + torch.testing.assert_close( + grad_freq_triton, + grad_freq_ref, + atol=3.0e-5, + rtol=3.0e-5, + ) + + +@unittest.skipUnless( + TRITON_CUDA_AVAILABLE, + "SeZM Triton rotation tests require CUDA and Triton.", +) +class TestSeZMTritonSO2(unittest.TestCase): + """Validate Triton SO(2) rotation kernels against the eager reference path.""" + + def _require_cuda_bfloat16(self) -> None: + """Skip the mixed-precision Triton tests when CUDA bf16 is unavailable.""" + if not torch.cuda.is_bf16_supported(): + self.skipTest("CUDA bfloat16 is required for mixed-precision Triton tests.") + + def test_rotate_to_local_matches_reference_forward_backward(self) -> None: + """Compare fused Triton rotate-to-local with projected eager matmul.""" + device = torch.device("cuda") + dtype = torch.float32 + n_node = 7 + n_edge = 11 + channels = 8 + for lmax, mmax in ((2, 1), (3, 1)): + dim_full = (lmax + 1) ** 2 + coeff_index = build_m_major_index(lmax, mmax, device=device) + src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) + x_ref = torch.randn( + n_node, + dim_full, + channels, + device=device, + dtype=dtype, + requires_grad=True, + ) + wigner_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=dtype, + requires_grad=True, + ) + x_triton = x_ref.detach().clone().requires_grad_(True) + wigner_triton = wigner_ref.detach().clone().requires_grad_(True) + + out_ref = torch.bmm( + project_D_to_m( + D_full=wigner_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=lmax, + key_mmax=mmax, + ), + x_ref.index_select(0, src), + ) + out_triton = rotate_to_local_triton( + x=x_triton, + src=src, + wigner=wigner_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close(out_triton, out_ref, atol=1.0e-5, rtol=1.0e-5) + + grad_out = torch.randn_like(out_ref) + grad_x_ref, grad_wigner_ref = torch.autograd.grad( + out_ref, + (x_ref, wigner_ref), + grad_outputs=grad_out, + ) + grad_x_triton, grad_wigner_triton = torch.autograd.grad( + out_triton, + (x_triton, wigner_triton), + grad_outputs=grad_out, + ) + torch.testing.assert_close( + grad_x_triton, + grad_x_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + torch.testing.assert_close( + grad_wigner_triton, + grad_wigner_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + + def test_rotate_back_matches_reference_forward_backward(self) -> None: + """Compare fused Triton rotate-back with projected eager matmul.""" + device = torch.device("cuda") + dtype = torch.float32 + n_edge = 11 + channels = 8 + for lmax, mmax in ((2, 1), (3, 1)): + dim_full = (lmax + 1) ** 2 + coeff_index = build_m_major_index(lmax, mmax, device=device) + reduced_dim = int(coeff_index.numel()) + x_local_ref = torch.randn( + n_edge, + reduced_dim, + channels, + device=device, + dtype=dtype, + requires_grad=True, + ) + wigner_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=dtype, + requires_grad=True, + ) + x_local_triton = x_local_ref.detach().clone().requires_grad_(True) + wigner_triton = wigner_ref.detach().clone().requires_grad_(True) + + out_ref = torch.bmm( + project_Dt_from_m( + Dt_full=wigner_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=lmax, + key_mmax=mmax, + ), + x_local_ref, + ) + out_triton = rotate_back_triton( + x_local=x_local_triton, + wigner=wigner_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close(out_triton, out_ref, atol=1.0e-5, rtol=1.0e-5) + + grad_out = torch.randn_like(out_ref) + grad_x_ref, grad_wigner_ref = torch.autograd.grad( + out_ref, + (x_local_ref, wigner_ref), + grad_outputs=grad_out, + ) + grad_x_triton, grad_wigner_triton = torch.autograd.grad( + out_triton, + (x_local_triton, wigner_triton), + grad_outputs=grad_out, + ) + torch.testing.assert_close( + grad_x_triton, + grad_x_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + torch.testing.assert_close( + grad_wigner_triton, + grad_wigner_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + + def test_rotate_to_local_matches_mixed_precision_reference(self) -> None: + """Compare Triton rotate-to-local with bf16 activations and fp32 Wigner.""" + self._require_cuda_bfloat16() + device = torch.device("cuda") + x_dtype = torch.bfloat16 + wigner_dtype = torch.float32 + n_node = 7 + n_edge = 11 + channels = 8 + for lmax, mmax in ((2, 1), (3, 1)): + dim_full = (lmax + 1) ** 2 + coeff_index = build_m_major_index(lmax, mmax, device=device) + src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) + x_ref = torch.randn( + n_node, + dim_full, + channels, + device=device, + dtype=x_dtype, + requires_grad=True, + ) + wigner_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=wigner_dtype, + requires_grad=True, + ) + x_triton = x_ref.detach().clone().requires_grad_(True) + wigner_triton = wigner_ref.detach().clone().requires_grad_(True) + + out_ref = torch.bmm( + project_D_to_m( + D_full=wigner_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=lmax, + key_mmax=mmax, + ).to(dtype=x_dtype), + x_ref.index_select(0, src), + ) + out_triton = rotate_to_local_triton( + x=x_triton, + src=src, + wigner=wigner_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close(out_triton, out_ref, atol=3.0e-2, rtol=3.0e-2) + + grad_out = torch.randn_like(out_ref) + grad_x_ref, grad_wigner_ref = torch.autograd.grad( + out_ref, + (x_ref, wigner_ref), + grad_outputs=grad_out, + ) + grad_x_triton, grad_wigner_triton = torch.autograd.grad( + out_triton, + (x_triton, wigner_triton), + grad_outputs=grad_out, + ) + torch.testing.assert_close( + grad_x_triton, + grad_x_ref, + atol=3.0e-2, + rtol=3.0e-2, + ) + torch.testing.assert_close( + grad_wigner_triton, + grad_wigner_ref, + atol=3.0e-2, + rtol=3.0e-2, + ) + + def test_rotate_back_matches_mixed_precision_reference(self) -> None: + """Compare Triton rotate-back with bf16 activations and fp32 Wigner.""" + self._require_cuda_bfloat16() + device = torch.device("cuda") + x_dtype = torch.bfloat16 + wigner_dtype = torch.float32 + n_edge = 11 + channels = 8 + for lmax, mmax in ((2, 1), (3, 1)): + dim_full = (lmax + 1) ** 2 + coeff_index = build_m_major_index(lmax, mmax, device=device) + reduced_dim = int(coeff_index.numel()) + x_local_ref = torch.randn( + n_edge, + reduced_dim, + channels, + device=device, + dtype=x_dtype, + requires_grad=True, + ) + wigner_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=wigner_dtype, + requires_grad=True, + ) + x_local_triton = x_local_ref.detach().clone().requires_grad_(True) + wigner_triton = wigner_ref.detach().clone().requires_grad_(True) + + out_ref = torch.bmm( + project_Dt_from_m( + Dt_full=wigner_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=lmax, + key_mmax=mmax, + ).to(dtype=x_dtype), + x_local_ref, + ) + out_triton = rotate_back_triton( + x_local=x_local_triton, + wigner=wigner_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close(out_triton, out_ref, atol=3.0e-2, rtol=3.0e-2) + + grad_out = torch.randn_like(out_ref) + grad_x_ref, grad_wigner_ref = torch.autograd.grad( + out_ref, + (x_local_ref, wigner_ref), + grad_outputs=grad_out, + ) + grad_x_triton, grad_wigner_triton = torch.autograd.grad( + out_triton, + (x_local_triton, wigner_triton), + grad_outputs=grad_out, + ) + torch.testing.assert_close( + grad_x_triton, + grad_x_ref, + atol=3.0e-2, + rtol=3.0e-2, + ) + torch.testing.assert_close( + grad_wigner_triton, + grad_wigner_ref, + atol=3.0e-2, + rtol=3.0e-2, + ) + + def test_rotate_to_local_matches_bfloat16_autocast_semantics(self) -> None: + """Use the activation dtype selected by AMP for Triton rotate-to-local.""" + self._require_cuda_bfloat16() + device = torch.device("cuda") + act_dtype = torch.bfloat16 + wigner_dtype = torch.float32 + n_node = 7 + n_edge = 11 + dim_full = 16 + channels = 8 + coeff_index = build_m_major_index(3, 1, device=device) + src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) + x_ref = torch.randn( + n_node, + dim_full, + channels, + device=device, + dtype=act_dtype, + requires_grad=True, + ) + wigner_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=wigner_dtype, + requires_grad=True, + ) + x_triton = x_ref.detach().clone().requires_grad_(True) + wigner_triton = wigner_ref.detach().clone().requires_grad_(True) + + D_m_prime = project_D_to_m( + D_full=wigner_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=3, + key_mmax=1, + ).to(dtype=act_dtype) + out_ref = torch.bmm(D_m_prime, x_ref.index_select(0, src)) + out_triton = rotate_to_local_triton( + x=x_triton, + src=src, + wigner=wigner_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close(out_triton, out_ref, atol=5.0e-2, rtol=5.0e-2) + + grad_out = torch.randn_like(out_ref) + grad_x_ref, grad_wigner_ref = torch.autograd.grad( + out_ref, + (x_ref, wigner_ref), + grad_outputs=grad_out, + ) + grad_x_triton, grad_wigner_triton = torch.autograd.grad( + out_triton, + (x_triton, wigner_triton), + grad_outputs=grad_out, + ) + torch.testing.assert_close( + grad_x_triton, + grad_x_ref, + atol=5.0e-2, + rtol=5.0e-2, + ) + torch.testing.assert_close( + grad_wigner_triton, + grad_wigner_ref, + atol=5.0e-2, + rtol=5.0e-2, + ) + + def test_rotate_back_matches_bfloat16_autocast_semantics(self) -> None: + """Use the activation dtype selected by AMP for Triton rotate-back.""" + self._require_cuda_bfloat16() + device = torch.device("cuda") + act_dtype = torch.bfloat16 + wigner_dtype = torch.float32 + n_edge = 11 + dim_full = 16 + channels = 8 + coeff_index = build_m_major_index(3, 1, device=device) + reduced_dim = int(coeff_index.numel()) + x_local_ref = torch.randn( + n_edge, + reduced_dim, + channels, + device=device, + dtype=act_dtype, + requires_grad=True, + ) + wigner_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=wigner_dtype, + requires_grad=True, + ) + x_local_triton = x_local_ref.detach().clone().requires_grad_(True) + wigner_triton = wigner_ref.detach().clone().requires_grad_(True) + + Dt_from_m = project_Dt_from_m( + Dt_full=wigner_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=3, + key_mmax=1, + ).to(dtype=act_dtype) + out_ref = torch.bmm(Dt_from_m, x_local_ref) + out_triton = rotate_back_triton( + x_local=x_local_triton, + wigner=wigner_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close(out_triton, out_ref, atol=5.0e-2, rtol=5.0e-2) + + grad_out = torch.randn_like(out_ref) + grad_x_ref, grad_wigner_ref = torch.autograd.grad( + out_ref, + (x_local_ref, wigner_ref), + grad_outputs=grad_out, + ) + grad_x_triton, grad_wigner_triton = torch.autograd.grad( + out_triton, + (x_local_triton, wigner_triton), + grad_outputs=grad_out, + ) + torch.testing.assert_close( + grad_x_triton, + grad_x_ref, + atol=5.0e-2, + rtol=5.0e-2, + ) + torch.testing.assert_close( + grad_wigner_triton, + grad_wigner_ref, + atol=5.0e-2, + rtol=5.0e-2, + ) + + def test_generic_small_k_falls_back_to_reference_forward_backward(self) -> None: + """Fallback to eager bmm when generic Triton tiles would have K < 16.""" + device = torch.device("cuda") + dtype = torch.float32 + lmax, mmax = 4, 0 + dim_full = (lmax + 1) ** 2 + n_node = 7 + n_edge = 11 + channels = 8 + coeff_index = build_m_major_index(lmax, mmax, device=device) + self.assertLess(int(coeff_index.numel()), 16) + + src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) + x_ref = torch.randn( + n_node, + dim_full, + channels, + device=device, + dtype=dtype, + requires_grad=True, + ) + wigner_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=dtype, + requires_grad=True, + ) + x_triton = x_ref.detach().clone().requires_grad_(True) + wigner_triton = wigner_ref.detach().clone().requires_grad_(True) + + out_ref = torch.bmm( + project_D_to_m( + D_full=wigner_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=lmax, + key_mmax=mmax, + ), + x_ref.index_select(0, src), + ) + out_triton = rotate_to_local_triton( + x=x_triton, + src=src, + wigner=wigner_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close(out_triton, out_ref, atol=1.0e-5, rtol=1.0e-5) + + grad_out = torch.randn_like(out_ref) + grad_x_ref, grad_wigner_ref = torch.autograd.grad( + out_ref, + (x_ref, wigner_ref), + grad_outputs=grad_out, + ) + grad_x_triton, grad_wigner_triton = torch.autograd.grad( + out_triton, + (x_triton, wigner_triton), + grad_outputs=grad_out, + ) + torch.testing.assert_close( + grad_x_triton, + grad_x_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + torch.testing.assert_close( + grad_wigner_triton, + grad_wigner_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + + x_local_ref = torch.randn( + n_edge, + int(coeff_index.numel()), + channels, + device=device, + dtype=dtype, + requires_grad=True, + ) + wigner_back_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=dtype, + requires_grad=True, + ) + x_local_triton = x_local_ref.detach().clone().requires_grad_(True) + wigner_back_triton = wigner_back_ref.detach().clone().requires_grad_(True) + + out_back_ref = torch.bmm( + project_Dt_from_m( + Dt_full=wigner_back_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=lmax, + key_mmax=mmax, + ), + x_local_ref, + ) + out_back_triton = rotate_back_triton( + x_local=x_local_triton, + wigner=wigner_back_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close( + out_back_triton, + out_back_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + + grad_back = torch.randn_like(out_back_ref) + grad_x_local_ref, grad_wigner_back_ref = torch.autograd.grad( + out_back_ref, + (x_local_ref, wigner_back_ref), + grad_outputs=grad_back, + ) + grad_x_local_triton, grad_wigner_back_triton = torch.autograd.grad( + out_back_triton, + (x_local_triton, wigner_back_triton), + grad_outputs=grad_back, + ) + torch.testing.assert_close( + grad_x_local_triton, + grad_x_local_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + torch.testing.assert_close( + grad_wigner_back_triton, + grad_wigner_back_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + + def test_generic_large_k_matches_reference_forward_backward(self) -> None: + """Exercise the true generic Triton path when reduced_dim >= 16.""" + device = torch.device("cuda") + dtype = torch.float32 + n_node = 7 + n_edge = 11 + channels = 8 + for lmax, mmax in ((4, 2), (4, 4), (5, 2)): + dim_full = (lmax + 1) ** 2 + coeff_index = build_m_major_index(lmax, mmax, device=device) + self.assertGreaterEqual(int(coeff_index.numel()), 16) + + src = torch.randint(0, n_node, (n_edge,), device=device, dtype=torch.long) + x_ref = torch.randn( + n_node, + dim_full, + channels, + device=device, + dtype=dtype, + requires_grad=True, + ) + wigner_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=dtype, + requires_grad=True, + ) + x_triton = x_ref.detach().clone().requires_grad_(True) + wigner_triton = wigner_ref.detach().clone().requires_grad_(True) + + out_ref = torch.bmm( + project_D_to_m( + D_full=wigner_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=lmax, + key_mmax=mmax, + ), + x_ref.index_select(0, src), + ) + out_triton = rotate_to_local_triton( + x=x_triton, + src=src, + wigner=wigner_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close(out_triton, out_ref, atol=1.0e-5, rtol=1.0e-5) + + grad_out = torch.randn_like(out_ref) + grad_x_ref, grad_wigner_ref = torch.autograd.grad( + out_ref, + (x_ref, wigner_ref), + grad_outputs=grad_out, + ) + grad_x_triton, grad_wigner_triton = torch.autograd.grad( + out_triton, + (x_triton, wigner_triton), + grad_outputs=grad_out, + ) + torch.testing.assert_close( + grad_x_triton, + grad_x_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + torch.testing.assert_close( + grad_wigner_triton, + grad_wigner_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + + x_local_ref = torch.randn( + n_edge, + int(coeff_index.numel()), + channels, + device=device, + dtype=dtype, + requires_grad=True, + ) + wigner_back_ref = torch.randn( + n_edge, + dim_full, + dim_full, + device=device, + dtype=dtype, + requires_grad=True, + ) + x_local_triton = x_local_ref.detach().clone().requires_grad_(True) + wigner_back_triton = wigner_back_ref.detach().clone().requires_grad_(True) + + out_back_ref = torch.bmm( + project_Dt_from_m( + Dt_full=wigner_back_ref, + coeff_index_m=coeff_index, + ebed_dim_full=dim_full, + cache=None, + key_lmax=lmax, + key_mmax=mmax, + ), + x_local_ref, + ) + out_back_triton = rotate_back_triton( + x_local=x_local_triton, + wigner=wigner_back_triton, + coeff_index=coeff_index, + dim_full=dim_full, + ) + torch.testing.assert_close( + out_back_triton, + out_back_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + + grad_back = torch.randn_like(out_back_ref) + grad_x_local_ref, grad_wigner_back_ref = torch.autograd.grad( + out_back_ref, + (x_local_ref, wigner_back_ref), + grad_outputs=grad_back, + ) + grad_x_local_triton, grad_wigner_back_triton = torch.autograd.grad( + out_back_triton, + (x_local_triton, wigner_back_triton), + grad_outputs=grad_back, + ) + torch.testing.assert_close( + grad_x_local_triton, + grad_x_local_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) + torch.testing.assert_close( + grad_wigner_back_triton, + grad_wigner_back_ref, + atol=1.0e-5, + rtol=1.0e-5, + ) diff --git a/source/tests/pt/model/test_sezm_export.py b/source/tests/pt/model/test_sezm_export.py new file mode 100644 index 0000000000..d896cda6fe --- /dev/null +++ b/source/tests/pt/model/test_sezm_export.py @@ -0,0 +1,734 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for SeZM's AOTInductor ``.pt2`` freeze pipeline. + +Layout mirrors ``source/tests/pt_expt/model/test_export_pipeline.py``: +a tiny fp64 SeZM model is built on the fly, so the tests are fully +self-contained and have no external-artefact dependency. +""" + +from __future__ import ( + annotations, +) + +import contextlib +import copy +import json +import tempfile +import unittest +import zipfile +from pathlib import ( + Path, +) +from typing import ( + TYPE_CHECKING, +) +from unittest import ( + mock, +) + +import numpy as np +import torch + +from deepmd.pt.entrypoints.freeze_pt2 import ( + _build_dynamic_shapes, + _collect_metadata, + _make_sample_inputs, + _resolve_nframes, + freeze_sezm_to_pt2, + is_sezm_checkpoint, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) + +if TYPE_CHECKING: + from collections.abc import ( + Iterator, + ) + +# Tracing and numerical parity always run on CPU — see module docstring +# of deepmd/pt/entrypoints/freeze_pt2.py for why. +_CPU = torch.device("cpu") + +_REQUIRED_OUTPUT_KEYS = { + "energy", + "energy_redu", + "energy_derv_r", + "energy_derv_c", + "energy_derv_c_redu", +} + + +def _tiny_sezm_model_params() -> dict: + """Minimal fp64 SeZM config for self-contained export tests. + + ``precision="float64"`` is what unlocks the ``rtol=1e-10, atol=1e-10`` + parity pt_expt enforces; fp32 accumulation alone drifts in the 1e-6 + range. All other knobs are tuned to keep ``make_fx`` tracing time + in the low-single-digit seconds. + """ + return { + "type": "SeZM", + "type_map": ["A", "B"], + "descriptor": { + "type": "SeZM", + "sel": [2, 2], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": True, + "l_schedule": [1, 0], + "mmax": 1, + "so2_norm": False, + "so2_layers": 1, + "n_atten_head": 1, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 8, + "ffn_blocks": 1, + "s2_activation": [False, True], + "mlp_bias": False, + "layer_scale": False, + "use_amp": False, + "activation_function": "silu", + "glu_activation": True, + "precision": "float64", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float64", + "seed": 7, + }, + "use_compile": False, + } + + +def _tiny_sezm_spin_model_params() -> dict: + """Minimal fp64 SeZM spin config for freeze routing tests.""" + params = copy.deepcopy(_tiny_sezm_model_params()) + params["type_map"] = ["O", "H"] + params["spin"] = { + "use_spin": [True, False], + "virtual_scale": 0.2, + } + return params + + +def _build_tiny_sezm_model() -> torch.nn.Module: + """Fresh tiny SeZM model on CPU, in eval mode.""" + model = get_model(_tiny_sezm_model_params()) + model.eval() + model.to(_CPU) + return model + + +def _write_tiny_sezm_checkpoint(tmp_path: Path, params: dict) -> Path: + """Serialise a tiny SeZM model to a ``.pt`` in the trainer's layout. + + ``ModelWrapper`` populates ``state_dict["_extra_state"]`` from its + ``get_extra_state`` hook, which is exactly the shape + :func:`freeze_sezm_to_pt2` expects. + """ + model = get_model(params) + model.eval() + model.to(_CPU) + wrapper = ModelWrapper(model, model_params=copy.deepcopy(params)) + ckpt_path = tmp_path / "tiny_sezm.pt" + torch.save({"model": wrapper.state_dict()}, ckpt_path) + return ckpt_path + + +def _make_sample(model: torch.nn.Module, *, nloc: int, start: int) -> tuple: + """Build a forward_common_lower sample on CPU via the freeze helper.""" + _, sample = _resolve_nframes(model, nloc=nloc, device=_CPU, start=start) + return sample + + +@contextlib.contextmanager +def _clear_default_device() -> Iterator[None]: + """Clear the pt-test ``cuda:9999999`` sentinel default device. + + ``source/tests/pt/__init__.py`` sets the default device to an + invalid ``"cuda:9999999"`` so that tests relying on implicit + placement fail loudly. The AOTI / export pipeline in PyTorch 2.11 + allocates unnamed tensors (e.g. inside ``PhiloxStateTracker``) + without an explicit device and would trip the guard. Matches the + pattern used by ``pt_expt/test_change_bias.py``. + """ + saved = torch.get_default_device() + torch.set_default_device(None) + try: + yield + finally: + torch.set_default_device(saved) + + +def _eager_forward( + model: torch.nn.Module, + sample_inputs: tuple, +) -> dict[str, torch.Tensor]: + """Mirror the trace closure: fresh leaf coord + ``requires_grad=True``.""" + ext_coord, ext_atype, nlist, mapping, fparam, aparam = sample_inputs + eager_coord = ext_coord.detach().clone().requires_grad_(True) + return model.forward_common_lower( + eager_coord, + ext_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + extra_nlist_sort=model.need_sorted_nlist_for_lower(), + ) + + +class TestSeZMExportPipeline(unittest.TestCase): + """Bitwise trace / export / ``.pte`` round-trip parity (``rtol=1e-10``). + + The ExportedProgram is a pure FX graph (no Inductor codegen), so + it must reproduce the eager result exactly. Drift here implies a + bug in ``forward_common_lower_exportable`` or the dynamic-shape + spec, not in AOTI. The pipeline is built once per class because + ``make_fx`` and ``.pte`` round-trip dominate wall time. + """ + + @classmethod + def setUpClass(cls) -> None: + with _clear_default_device(): + cls.model = _build_tiny_sezm_model() + cls.sample_inputs = _make_sample(cls.model, nloc=7, start=2) + cls.traced, cls.loaded, cls._pte_tmp = cls._build_pipeline( + cls.model, cls.sample_inputs + ) + + @classmethod + def tearDownClass(cls) -> None: + cls._pte_tmp.close() + + def setUp(self) -> None: + self._device_ctx = _clear_default_device() + self._device_ctx.__enter__() + + def tearDown(self) -> None: + self._device_ctx.__exit__(None, None, None) + + @staticmethod + def _build_pipeline( + model: torch.nn.Module, + sample_inputs: tuple, + ) -> tuple[ + torch.fx.GraphModule, + torch.nn.Module, + tempfile._TemporaryFileWrapper, + ]: + traced = model.forward_common_lower_exportable( + *sample_inputs, + do_atomic_virial=True, + ) + exported = torch.export.export( + traced, + sample_inputs, + dynamic_shapes=_build_dynamic_shapes(sample_inputs), + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + # Keep the tempfile alive for the class lifetime so the loaded + # module can lazily reference its backing bytes. + pte_tmp = tempfile.NamedTemporaryFile(suffix=".pte", delete=True) + torch.export.save(exported, pte_tmp.name) + loaded = torch.export.load(pte_tmp.name).module() + return traced, loaded, pte_tmp + + def _assert_dict_allclose( + self, + ref: dict[str, torch.Tensor], + test_dict: dict[str, torch.Tensor] | object, + *, + context: str, + ) -> None: + test_pairs = ( + list(test_dict.items()) + if hasattr(test_dict, "items") + else list(zip(ref.keys(), test_dict, strict=True)) + ) + for key, test_val in test_pairs: + self.assertIn(key, ref, msg=f"{context}: unexpected output key {key!r}") + ref_val = ref[key] + self.assertEqual( + tuple(ref_val.shape), + tuple(test_val.shape), + msg=( + f"{context} ({key}): shape mismatch " + f"ref={tuple(ref_val.shape)} vs test={tuple(test_val.shape)}" + ), + ) + np.testing.assert_allclose( + ref_val.detach().cpu().numpy(), + test_val.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"{context}: {key}", + ) + + def test_traced_matches_eager_same_shape(self) -> None: + eager_out = _eager_forward(self.model, self.sample_inputs) + traced_out = self.traced(*self.sample_inputs) + self._assert_dict_allclose( + eager_out, traced_out, context="traced vs eager (trace shape)" + ) + + def test_loaded_pte_matches_eager_same_shape(self) -> None: + eager_out = _eager_forward(self.model, self.sample_inputs) + loaded_out = self.loaded(*self.sample_inputs) + self._assert_dict_allclose( + eager_out, loaded_out, context="loaded (.pte) vs eager (trace shape)" + ) + + def test_loaded_pte_matches_eager_different_shape(self) -> None: + # start=3 retargets the nframes symbol away from the trace + # value of 2; nloc=11 exercises the nloc symbol. + infer_inputs = _make_sample(self.model, nloc=11, start=3) + eager_out = _eager_forward(self.model, infer_inputs) + loaded_out = self.loaded(*infer_inputs) + self._assert_dict_allclose( + eager_out, loaded_out, context="loaded (.pte) vs eager (infer shape)" + ) + + +class _FrozenPt2Fixture: + """Shared setUp/tearDown: freeze a tiny SeZM checkpoint to ``.pt2`` once. + + AOTInductor compilation costs a few seconds; classes that share this + fixture avoid paying that cost twice. ``cls.ckpt_path`` / ``cls.out_path`` + / ``cls.params`` are populated and live for the lifetime of the class. + """ + + params: dict + ckpt_path: Path + out_path: Path + + @classmethod + def setUpClass(cls) -> None: + cls._tmpdir = tempfile.TemporaryDirectory() + tmp_root = Path(cls._tmpdir.name) + cls.params = _tiny_sezm_model_params() + with _clear_default_device(): + cls.ckpt_path = _write_tiny_sezm_checkpoint(tmp_root, cls.params) + cls.out_path = tmp_root / "frozen_sezm.pt2" + freeze_sezm_to_pt2(str(cls.ckpt_path), str(cls.out_path), device=_CPU) + + @classmethod + def tearDownClass(cls) -> None: + cls._tmpdir.cleanup() + + def setUp(self) -> None: + self._device_ctx = _clear_default_device() + self._device_ctx.__enter__() + + def tearDown(self) -> None: + self._device_ctx.__exit__(None, None, None) + + +class TestSeZMExportArchive(_FrozenPt2Fixture, unittest.TestCase): + """AOTI ``.pt2`` archive structure + load-and-run smoke. + + Numerical parity of the compiled ``.pt2`` is covered by the + pipeline class through the ``.pte`` round-trip; here we only + verify the archive layout and the C++ consumer contract. + """ + + def test_detector_recognises_sezm(self) -> None: + self.assertTrue(is_sezm_checkpoint(str(self.ckpt_path))) + + def test_archive_metadata(self) -> None: + """ZIP layout + metadata fields match the DeepPotPTExpt contract.""" + self.assertTrue(zipfile.is_zipfile(str(self.out_path))) + with zipfile.ZipFile(str(self.out_path), "r") as zf: + names = zf.namelist() + self.assertIn("model/extra/metadata.json", names) + self.assertIn("model/extra/model_def_script.json", names) + metadata = json.loads(zf.read("model/extra/metadata.json").decode("utf-8")) + mds = json.loads( + zf.read("model/extra/model_def_script.json").decode("utf-8") + ) + + for key in ( + "type_map", + "ntypes", + "rcut", + "sel", + "dim_fparam", + "dim_aparam", + "dim_chg_spin", + "mixed_types", + "has_default_fparam", + "default_chg_spin", + "output_keys", + "fitting_output_defs", + "sel_type", + "is_spin", + ): + self.assertIn(key, metadata) + + self.assertEqual(metadata["type_map"], self.params["type_map"]) + self.assertEqual(metadata["ntypes"], len(self.params["type_map"])) + self.assertEqual(metadata["rcut"], self.params["descriptor"]["rcut"]) + self.assertEqual(list(metadata["sel"]), list(self.params["descriptor"]["sel"])) + self.assertTrue(metadata["mixed_types"]) + self.assertFalse(metadata["is_spin"]) + self.assertEqual(metadata["dim_fparam"], 0) + self.assertEqual(metadata["dim_aparam"], 0) + self.assertEqual(metadata["dim_chg_spin"], 0) + self.assertIsNone(metadata["default_chg_spin"]) + # sel_type must agree with the eager SeZM model — this is the + # field DeepEval._init_from_metadata reads when no model.json is + # present. DPA4 / SeZM's dpa4_ener fitting head enumerates every type, + # so the list is non-empty in general. + probe = _build_tiny_sezm_model() + self.assertEqual(list(metadata["sel_type"]), list(probe.get_sel_type())) + self.assertTrue(_REQUIRED_OUTPUT_KEYS.issubset(set(metadata["output_keys"]))) + + # model_def_script preserves the training params verbatim. + self.assertEqual(str(mds.get("type", "")).lower(), "sezm") + self.assertEqual(mds.get("use_compile"), self.params["use_compile"]) + + def test_aoti_load_and_run_returns_finite_outputs(self) -> None: + from torch._inductor import ( + aoti_load_package, + ) + + loader = aoti_load_package(str(self.out_path)) + probe = _build_tiny_sezm_model() + sample_inputs = _make_sample(probe, nloc=5, start=2) + outs = loader(*sample_inputs) + + # AOTICompiledModel returns an immutable_dict on PyTorch ≥2.11 + # and a flat tuple on older versions; normalise both. + with zipfile.ZipFile(str(self.out_path), "r") as zf: + output_keys = json.loads( + zf.read("model/extra/metadata.json").decode("utf-8") + )["output_keys"] + if hasattr(outs, "items"): + out_map = dict(outs.items()) + self.assertEqual(list(out_map.keys()), output_keys) + else: + self.assertEqual(len(outs), len(output_keys)) + out_map = dict(zip(output_keys, outs, strict=True)) + + for key in ("energy_redu", "energy_derv_r", "energy_derv_c_redu"): + self.assertIn(key, out_map) + self.assertTrue(torch.isfinite(out_map[key]).all().item()) + + +class TestSeZMViaDeepPot(_FrozenPt2Fixture, unittest.TestCase): + """Integration through the standard :class:`deepmd.infer.DeepPot` entry. + + Locks in the contract that makes ``dp test -m frozen.pt2`` and the + deepmd ASE calculator work on a SeZM-produced archive. Everything + here goes through the public backend-agnostic API — + :class:`DeepPot` dispatches ``.pt2`` to + :class:`deepmd.pt_expt.infer.deep_eval.DeepEval`, which since the + metadata-only patch no longer needs ``extra/model.json``. + + Numerical tolerance is looser than the ``.pte`` pipeline tests + because AOTInductor fuses pointwise / reduction kernels and the + fused accumulation order differs from eager; the intent here is + contract parity, not bitwise parity. + """ + + RTOL = 1e-5 + ATOL = 1e-7 + + @classmethod + def setUpClass(cls) -> None: + # The ``.pt2`` archive is compiled on CPU by the fixture; AOTI + # packages are device-locked, so ``pt_expt.DeepEval``'s input + # preparation must also place tensors on CPU — otherwise + # ``_pt2_runner(...)`` segfaults on dtype/device mismatch. + # ``_prepare_inputs`` does a function-local + # ``from deepmd.pt_expt.utils.env import DEVICE``, so patching + # the module attribute is enough (no rebinding required). + import deepmd.pt_expt.utils.env as _pt_expt_env + + cls._orig_pt_expt_device = _pt_expt_env.DEVICE + _pt_expt_env.DEVICE = _CPU + + super().setUpClass() + + # Late import: building the deepmd Backend registry is cheap, but + # doing it at collection time conflicts with the conftest + # default-device sentinel used elsewhere in this package. + from deepmd.infer import ( + DeepPot, + ) + + cls.dp = DeepPot(str(cls.out_path)) + + # A deterministic bulk sample; coord is centred in a cubic box + # well inside the periodic image, and the atype distribution + # exercises both type-0 and type-1 slots of sel=[2, 2]. + rng = np.random.default_rng(2026) + cls.natoms = 5 + cls.atype = np.array([0, 1, 0, 1, 0], dtype=np.int32) + box_edge = cls.params["descriptor"]["rcut"] * 3.0 + cls.coord = ( + rng.random((1, cls.natoms, 3), dtype=np.float64) * box_edge * 0.4 + + box_edge * 0.3 + ) + cls.cell = (np.eye(3, dtype=np.float64) * box_edge).reshape(1, 9) + + @classmethod + def tearDownClass(cls) -> None: + import deepmd.pt_expt.utils.env as _pt_expt_env + + _pt_expt_env.DEVICE = cls._orig_pt_expt_device + super().tearDownClass() + + def _eager_energy_force_virial(self) -> tuple[np.ndarray, ...]: + """Run the eager SeZMModel forward and return arrays shaped like DeepPot.""" + model = _build_tiny_sezm_model() + wrapper = ModelWrapper(model, model_params=copy.deepcopy(self.params)) + raw = torch.load(self.ckpt_path, map_location=_CPU, weights_only=False) + wrapper.load_state_dict(raw["model"]) + model.eval() + + coord_t = torch.tensor(self.coord, dtype=torch.float64).requires_grad_(True) + atype_t = torch.tensor(self.atype, dtype=torch.int64).unsqueeze(0) + box_t = torch.tensor(self.cell, dtype=torch.float64) + out = model.forward(coord_t, atype_t, box_t, do_atomic_virial=True) + return ( + out["energy"].detach().cpu().numpy(), + out["force"].detach().cpu().numpy(), + out["virial"].detach().cpu().numpy(), + out["atom_energy"].detach().cpu().numpy(), + ) + + # ---- metadata accessors ---------------------------------------- + + def test_deeppot_metadata_accessors(self) -> None: + dp = self.dp + self.assertEqual(list(dp.deep_eval.get_type_map()), self.params["type_map"]) + self.assertEqual(dp.deep_eval.get_ntypes(), len(self.params["type_map"])) + self.assertAlmostEqual( + dp.deep_eval.get_rcut(), self.params["descriptor"]["rcut"] + ) + self.assertEqual(dp.deep_eval.get_dim_fparam(), 0) + self.assertEqual(dp.deep_eval.get_dim_aparam(), 0) + # get_sel_type() must agree with the eager model; SeZM's + # ``dpa4_ener`` fitting head selects every type by enumerating them, + # so the concrete value is ``list(range(ntypes))`` rather than ``[]`` + # — both are valid DeepPot conventions for "all types selected". + eager = _build_tiny_sezm_model() + self.assertEqual(list(dp.deep_eval.get_sel_type()), list(eager.get_sel_type())) + self.assertFalse(dp.deep_eval.get_has_spin()) + + def test_deeppot_is_metadata_only(self) -> None: + """SeZM's .pt2 omits model.json, so the loader must take the fallback.""" + self.assertIsNone(self.dp.deep_eval._dpmodel) + + # ---- numeric parity against eager ------------------------------- + + def test_deeppot_eval_matches_eager(self) -> None: + e_ref, f_ref, v_ref, atom_e_ref = self._eager_energy_force_virial() + e, f, v = self.dp.eval(self.coord, self.cell, self.atype, atomic=False)[:3] + np.testing.assert_allclose( + e, + e_ref.reshape(e.shape), + rtol=self.RTOL, + atol=self.ATOL, + err_msg="energy mismatch (DeepPot vs eager)", + ) + np.testing.assert_allclose( + f, + f_ref.reshape(f.shape), + rtol=self.RTOL, + atol=self.ATOL, + err_msg="force mismatch (DeepPot vs eager)", + ) + np.testing.assert_allclose( + v, + v_ref.reshape(v.shape), + rtol=self.RTOL, + atol=self.ATOL, + err_msg="virial mismatch (DeepPot vs eager)", + ) + + def test_deeppot_eval_atomic_matches_eager(self) -> None: + """``atomic=True`` additionally returns atom_energy and atom_virial.""" + e_ref, _, _, atom_e_ref = self._eager_energy_force_virial() + out = self.dp.eval(self.coord, self.cell, self.atype, atomic=True) + e, _, _, atom_e, _ = out + np.testing.assert_allclose( + e, + e_ref.reshape(e.shape), + rtol=self.RTOL, + atol=self.ATOL, + err_msg="energy mismatch (atomic path)", + ) + np.testing.assert_allclose( + atom_e, + atom_e_ref.reshape(atom_e.shape), + rtol=self.RTOL, + atol=self.ATOL, + err_msg="atom_energy mismatch (atomic path)", + ) + + +class TestSeZMFreezeGuards(unittest.TestCase): + """Error paths: detector rejections and CLI-level ``NotImplementedError``s.""" + + def test_metadata_records_ntypes_when_type_map_is_empty(self) -> None: + """Metadata-only loaders need ntypes even when no type names are exported.""" + model = _build_tiny_sezm_model() + with mock.patch.object(model, "get_type_map", return_value=[]): + metadata = _collect_metadata(model, ["energy"]) + + self.assertEqual(metadata["type_map"], []) + self.assertEqual(metadata["ntypes"], model.get_descriptor().get_ntypes()) + + def test_charge_spin_export_sample_has_runtime_input_slot(self) -> None: + """Charge/spin-conditioned exports should not bake defaults into the graph.""" + params = _tiny_sezm_model_params() + params["descriptor"]["add_chg_spin_ebd"] = True + params["descriptor"]["default_chg_spin"] = [0.0, 1.0] + model = get_model(params).to(_CPU).eval() + + sample_inputs = _make_sample_inputs(model, nframes=5, nloc=7, device=_CPU) + metadata = _collect_metadata(model, ["energy"]) + dynamic_shapes = _build_dynamic_shapes(sample_inputs) + + self.assertEqual(len(sample_inputs), 7) + self.assertEqual(sample_inputs[-1].shape, (5, 2)) + self.assertEqual(len(dynamic_shapes), len(sample_inputs)) + self.assertEqual(metadata["dim_chg_spin"], 2) + self.assertEqual(metadata["default_chg_spin"], [0.0, 1.0]) + + def test_is_sezm_checkpoint_rejects_non_sezm(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + ckpt_path = Path(tmp) / "ener.pt" + torch.save( + {"model": {"_extra_state": {"model_params": {"type": "ener"}}}}, + ckpt_path, + ) + self.assertFalse(is_sezm_checkpoint(str(ckpt_path))) + + def test_is_sezm_checkpoint_accepts_dpa4_alias(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + ckpt_path = Path(tmp) / "dpa4.pt" + torch.save( + {"model": {"_extra_state": {"model_params": {"type": "dpa4"}}}}, + ckpt_path, + ) + self.assertTrue(is_sezm_checkpoint(str(ckpt_path))) + + def test_freeze_rejects_head_selection(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + ckpt_path = Path(tmp) / "fake.pt" + torch.save( + {"model": {"_extra_state": {"model_params": {"type": "SeZM"}}}}, + ckpt_path, + ) + out = Path(tmp) / "out.pt2" + with self.assertRaises(NotImplementedError): + freeze_sezm_to_pt2(str(ckpt_path), str(out), head="branch") + + def test_freeze_requires_head_for_multi_task(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + ckpt_path = Path(tmp) / "multi.pt" + torch.save( + { + "model": { + "_extra_state": { + "model_params": { + "type": "SeZM", + "model_dict": {"branch": {}}, + } + } + } + }, + ckpt_path, + ) + out = Path(tmp) / "out.pt2" + with self.assertRaises(ValueError): + freeze_sezm_to_pt2(str(ckpt_path), str(out)) + + def test_freeze_accepts_multi_task_dpa4_head(self) -> None: + """Multitask DPA4 checkpoints should export the selected branch.""" + + def fake_compile(_exported: torch.export.ExportedProgram, package_path: str): + with zipfile.ZipFile(package_path, "w") as zf: + zf.writestr("model/data.pkl", b"") + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + branch_params = _tiny_sezm_model_params() + branch_params["type"] = "dpa4" + branch_params["descriptor"]["type"] = "dpa4" + params = {"model_dict": {"Domains_Alloy": branch_params}} + model = { + "Domains_Alloy": get_model(copy.deepcopy(branch_params)).to(_CPU).eval() + } + wrapper = ModelWrapper(model, model_params=copy.deepcopy(params)) + ckpt_path = tmp_path / "multi_dpa4.pt" + torch.save({"model": wrapper.state_dict()}, ckpt_path) + out = tmp_path / "multi_dpa4.pt2" + + self.assertTrue(is_sezm_checkpoint(str(ckpt_path))) + with mock.patch( + "torch._inductor.aoti_compile_and_package", + side_effect=fake_compile, + ): + freeze_sezm_to_pt2( + str(ckpt_path), str(out), device=_CPU, head="Domains_Alloy" + ) + + with zipfile.ZipFile(str(out), "r") as zf: + model_def = json.loads( + zf.read("model/extra/model_def_script.json").decode("utf-8") + ) + + self.assertEqual(model_def["type"], "dpa4") + + def test_freeze_accepts_spin_checkpoint_metadata(self) -> None: + """SeZM spin checkpoints should export a spin-compatible pt2 contract.""" + + def fake_compile(_exported: torch.export.ExportedProgram, package_path: str): + with zipfile.ZipFile(package_path, "w") as zf: + zf.writestr("model/data.pkl", b"") + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + params = _tiny_sezm_spin_model_params() + ckpt_path = _write_tiny_sezm_checkpoint(tmp_path, params) + out = tmp_path / "spin.pt2" + + with mock.patch( + "torch._inductor.aoti_compile_and_package", + side_effect=fake_compile, + ): + freeze_sezm_to_pt2(str(ckpt_path), str(out), device=_CPU) + + with zipfile.ZipFile(str(out), "r") as zf: + metadata = json.loads( + zf.read("model/extra/metadata.json").decode("utf-8") + ) + + self.assertTrue(metadata["is_spin"]) + self.assertEqual(metadata["type_map"], params["type_map"]) + self.assertEqual(metadata["ntypes"], len(params["type_map"])) + self.assertEqual(metadata["dim_chg_spin"], 0) + self.assertIsNone(metadata["default_chg_spin"]) + self.assertEqual(metadata["use_spin"], params["spin"]["use_spin"]) + self.assertEqual(metadata["ntypes_spin"], 1) + self.assertIn("energy_derv_r_mag", metadata["output_keys"]) + self.assertIn("energy_derv_c_redu", metadata["output_keys"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_sezm_model.py b/source/tests/pt/model/test_sezm_model.py new file mode 100644 index 0000000000..9d4095c50e --- /dev/null +++ b/source/tests/pt/model/test_sezm_model.py @@ -0,0 +1,1987 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import math +import os +import tempfile +import unittest +import warnings +from pathlib import ( + Path, +) +from unittest import ( + mock, +) + +import h5py +import numpy as np +import torch + +from deepmd.pt.loss import ( + DeNSLoss, + EnergyStdLoss, +) +from deepmd.pt.model.descriptor.sezm_nn import ( + GatedActivation, + LoRASO2, + LoRASO3, + SO2Linear, + SO3Linear, + apply_lora_to_sezm, + build_edge_cache, + build_edge_cache_from_edges, + build_merged_state_dict, +) +from deepmd.pt.model.model import ( + get_model, + get_sezm_model, +) +from deepmd.pt.model.model.sezm_model import ( + InterPotential, + SeZMModel, +) +from deepmd.pt.train.training import ( + prepare_model_for_loss, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.utils.path import ( + DPPath, +) + +warnings.filterwarnings( + # Keep the compile-test warning summary focused on strict-tolerance drift. + # PyTorch's AOTAutograd cache emits an internal Python 3.14 deprecation + # warning that is unrelated to SeZM numerical correctness. + "ignore", + category=DeprecationWarning, + module=r"torch\._functorch\._aot_autograd\.autograd_cache", +) + + +def _assert_close_with_strict_warning( + actual: torch.Tensor, + expected: torch.Tensor, + *, + strict_atol: float = 1.0e-6, + strict_rtol: float = 1.0e-6, + atol: float, + rtol: float, + msg: str, +) -> None: + """Warn on strict compile drift, fail only outside relaxed tolerance.""" + try: + torch.testing.assert_close( + actual, + expected, + atol=strict_atol, + rtol=strict_rtol, + msg=msg, + ) + except AssertionError as err: + warnings.warn( + f"{msg} exceeds strict tolerance " + f"(atol={strict_atol:g}, rtol={strict_rtol:g}) but is checked " + f"against relaxed tolerance (atol={atol:g}, rtol={rtol:g}): {err}", + RuntimeWarning, + stacklevel=2, + ) + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol, msg=msg) + + +def _build_m_major_z_rotation( + angles: torch.Tensor, lmax: int, mmax: int, device: torch.device +) -> torch.Tensor: + """Build the m-major block z-rotation matrix used by SO(2) equivariance tests. + + Given per-sample rotation ``angles`` and the ``(lmax, mmax)`` truncation, + return a tensor with shape ``(batch, dim_red, dim_red)`` where ``dim_red`` + is the truncated coefficient dimension of the m-major layout. + """ + batch = angles.shape[0] + m0_size = lmax + 1 + dim_red = m0_size + for m in range(1, mmax + 1): + num_l = lmax - m + 1 + dim_red += 2 * num_l + + Z = angles.new_zeros(batch, dim_red, dim_red) + eye0 = torch.eye(m0_size, device=device, dtype=angles.dtype).expand( + batch, m0_size, m0_size + ) + Z[:, :m0_size, :m0_size] = eye0 + + offset = m0_size + for m in range(1, mmax + 1): + num_l = lmax - m + 1 + eye = torch.eye(num_l, device=device, dtype=angles.dtype).expand( + batch, num_l, num_l + ) + cos_m = torch.cos(m * angles).view(batch, 1, 1) + sin_m = torch.sin(m * angles).view(batch, 1, 1) + + # Each m group stores the coefficients as [neg(l), pos(l)]; the rotation + # is [[cos I, -sin I], [sin I, cos I]] for the (neg, pos) pair. + Z[:, offset : offset + num_l, offset : offset + num_l] = cos_m * eye + Z[ + :, + offset : offset + num_l, + offset + num_l : offset + 2 * num_l, + ] = -sin_m * eye + Z[ + :, + offset + num_l : offset + 2 * num_l, + offset : offset + num_l, + ] = sin_m * eye + Z[ + :, + offset + num_l : offset + 2 * num_l, + offset + num_l : offset + 2 * num_l, + ] = cos_m * eye + offset += 2 * num_l + + return Z + + +def _build_lora_sezm_model_params(**overrides) -> dict: + """Minimal SeZMModel config suitable for LoRA injection tests. + + Uses ``s2_activation=[False, False]`` so the model keeps a + ``GatedActivation`` (the override-freeze policy is easier to exercise) and + sets ``use_compile=False`` by default; set ``use_compile=True`` via + ``overrides`` to exercise the compile path. + """ + params = { + "type": "SeZM", + "type_map": ["A", "B"], + "descriptor": { + "type": "SeZM", + "sel": [2, 2], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": True, + "l_schedule": [1, 0], + "mmax": 1, + "so2_norm": False, + "so2_layers": 1, + "n_atten_head": 1, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 8, + "ffn_blocks": 1, + "s2_activation": [False, False], + "mlp_bias": True, + "layer_scale": True, + "use_amp": False, + "activation_function": "silu", + "glu_activation": True, + "precision": "float32", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float32", + "seed": 7, + }, + "use_compile": False, + } + params.update(overrides) + return params + + +class TestSeZMDPA4Alias(unittest.TestCase): + """Test the DPA4 user-facing aliases for the SeZM model scaffold.""" + + def test_get_model_accepts_dpa4_alias(self) -> None: + """DPA4 model and descriptor type strings should build SeZMModel.""" + params = _build_lora_sezm_model_params(type="dpa4") + params["descriptor"]["type"] = "dpa4" + + model = get_model(params) + + self.assertIsInstance(model, SeZMModel) + self.assertEqual( + model.serialize()["atomic_model"]["fitting"]["type"], + "sezm_ener", + ) + + +class TestSeZMModelCompile(unittest.TestCase): + """Test SeZM model compile path consistency.""" + + def setUp(self) -> None: + self.device = env.DEVICE + torch.manual_seed(2024) + + @staticmethod + def _randomize_params(model: torch.nn.Module, seed: int = 1234) -> None: + """Fill all parameters with small random values. + + Zero-initialized parameters mask second-order gradient bugs because + many multiplicative paths collapse to zero. + """ + torch.manual_seed(seed) + with torch.no_grad(): + for p in model.parameters(): + p.copy_(torch.randn_like(p) * 0.1) + + def _build_model_params(self, *, use_compile: bool) -> dict: + return { + "type": "SeZM", + "type_map": ["A", "B"], + "descriptor": { + "type": "SeZM", + "sel": [2, 2], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": True, + "l_schedule": [1, 0], + "mmax": 1, + "so2_norm": False, + "so2_layers": 1, + "n_atten_head": 1, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 8, + "ffn_blocks": 1, + "s2_activation": [False, True], + "mlp_bias": False, + "layer_scale": False, + "use_amp": False, + "activation_function": "silu", + "glu_activation": True, + "precision": "float32", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float32", + "seed": 7, + }, + "use_compile": use_compile, + } + + def _make_tiny_frame( + self, + nframe: int = 1, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """Build deterministic tiny frames with force and virial labels. + + Parameters + ---------- + nframe + Number of frames to build. + + Returns + ------- + coord : torch.Tensor + Coordinates with shape (nframe, nloc, 3). + atype : torch.Tensor + Atom types with shape (nframe, nloc). + box : torch.Tensor + Box tensor with shape (nframe, 9). + energy : torch.Tensor + Energy with shape (nframe, 1). + force : torch.Tensor + Forces with shape (nframe, nloc, 3). + virial : torch.Tensor + Virial tensor with shape (nframe, 9). + """ + if nframe <= 0: + raise ValueError("nframe must be positive") + + frame_shift = torch.arange( + nframe, device=self.device, dtype=torch.float32 + ).view(nframe, 1, 1) + coord = ( + torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [1.1, 0.3, 0.0], + [0.2, 1.5, 0.4], + [1.7, 1.2, 0.2], + [2.3, 0.1, 1.0], + [0.8, 2.2, 1.1], + [2.6, 1.8, 1.5], + ], + ], + device=self.device, + dtype=torch.float32, + ) + + 0.05 * frame_shift + ) + atype = torch.tensor( + [[0, 1, 0, 1, 0, 1, 0]], device=self.device, dtype=torch.int32 + ).repeat(nframe, 1) + box = torch.tensor( + [[8.0, 0.0, 0.0, 0.0, 8.0, 0.0, 0.0, 0.0, 8.0]], + device=self.device, + dtype=torch.float32, + ).repeat(nframe, 1) + energy = torch.tensor( + [[0.25]], device=self.device, dtype=torch.float32 + ) + 0.01 * frame_shift.view(nframe, 1) + force = ( + torch.tensor( + [ + [ + [0.2, -0.1, 0.0], + [-0.3, 0.4, 0.1], + [0.1, -0.3, -0.1], + [0.0, 0.2, -0.2], + [-0.2, -0.1, 0.3], + [0.3, 0.0, -0.1], + [-0.1, -0.1, 0.0], + ], + ], + device=self.device, + dtype=torch.float32, + ) + + 0.02 * frame_shift + ) + virial = torch.tensor( + [[0.3, 0.01, -0.02, 0.01, -0.2, 0.04, -0.02, 0.04, 0.1]], + device=self.device, + dtype=torch.float32, + ) + 0.03 * frame_shift.view(nframe, 1) + return coord, atype, box, energy, force, virial + + def _train_steps( + self, + model: torch.nn.Module, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor, + energy: torch.Tensor, + force: torch.Tensor, + virial: torch.Tensor | None = None, + steps: int = 3, + ) -> dict[str, torch.Tensor]: + optimizer = torch.optim.SGD(model.parameters(), lr=1.0e-7) + for _ in range(steps): + optimizer.zero_grad(set_to_none=True) + out = model(coord, atype, box=box) + loss_energy = torch.mean( + (out["energy"] - energy.to(out["energy"].dtype)) ** 2 + ) + loss_force = torch.mean((out["force"] - force.to(out["force"].dtype)) ** 2) + loss = loss_energy + loss_force + if virial is not None and "virial" in out: + loss_virial = torch.mean( + (out["virial"] - virial.to(out["virial"].dtype)) ** 2 + ) + loss = loss + loss_virial + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + return { + name: param.detach().clone() for name, param in model.named_parameters() + } + + def test_compile_cache_slots_and_eval_shape_change(self) -> None: + """Compile cache slots should coexist while eval handles batch-size growth.""" + coord_1, atype_1, box_1, _, _, _ = self._make_tiny_frame() + coord_2, atype_2, box_2, _, _, _ = self._make_tiny_frame(nframe=2) + + # === Step 1. Build paired models with shared random weights === + model_dyn = get_sezm_model(self._build_model_params(use_compile=False)) + self._randomize_params(model_dyn) + with mock.patch.dict(os.environ, {"DP_COMPILE_INFER": "1"}, clear=False): + model_cmp = get_sezm_model(self._build_model_params(use_compile=True)) + model_cmp.load_state_dict(model_dyn.state_dict()) + + train_key = (True, False, False) + eval_key = (False, False, False) + + # === Step 2. Train-mode forward fills the training slot. === + model_cmp.train() + model_cmp(coord_1, atype_1, box=box_1) + self.assertIn(train_key, model_cmp.compiled_core_compute_cache) + self.assertNotIn(eval_key, model_cmp.compiled_core_compute_cache) + callable_train_first = model_cmp.compiled_core_compute_cache[train_key] + + # === Step 3. First eval call adds the eval slot without evicting train. === + model_dyn.eval() + model_cmp.eval() + out_dyn_1 = model_dyn(coord_1, atype_1, box=box_1) + out_cmp_1 = model_cmp(coord_1, atype_1, box=box_1) + self.assertIn(train_key, model_cmp.compiled_core_compute_cache) + self.assertIn(eval_key, model_cmp.compiled_core_compute_cache) + self.assertIs( + model_cmp.compiled_core_compute_cache[train_key], callable_train_first + ) + callable_eval_first = model_cmp.compiled_core_compute_cache[eval_key] + self.assertIsNot(callable_train_first, callable_eval_first) + _assert_close_with_strict_warning( + out_dyn_1["energy"], + out_cmp_1["energy"], + atol=1.0e-6, + rtol=1.0e-6, + msg="eval energy mismatch on first compiled call", + ) + _assert_close_with_strict_warning( + out_dyn_1["force"], + out_cmp_1["force"], + atol=1.0e-6, + rtol=1.0e-6, + msg="eval force mismatch on first compiled call", + ) + _assert_close_with_strict_warning( + out_dyn_1["virial"], + out_cmp_1["virial"], + atol=1.0e-5, + rtol=1.0e-5, + msg="eval virial mismatch on first compiled call", + ) + + # === Step 4. Reuse the traced eval graph on a larger batch. === + out_dyn_2 = model_dyn(coord_2, atype_2, box=box_2) + out_cmp_2 = model_cmp(coord_2, atype_2, box=box_2) + self.assertEqual(out_dyn_2["energy"].shape, (2, 1)) + self.assertEqual(out_cmp_2["energy"].shape, (2, 1)) + self.assertIs( + model_cmp.compiled_core_compute_cache[eval_key], callable_eval_first + ) + _assert_close_with_strict_warning( + out_dyn_2["energy"], + out_cmp_2["energy"], + atol=1.0e-6, + rtol=1.0e-6, + msg="eval energy mismatch after batch-size growth", + ) + _assert_close_with_strict_warning( + out_dyn_2["force"], + out_cmp_2["force"], + atol=1.0e-6, + rtol=1.0e-6, + msg="eval force mismatch after batch-size growth", + ) + _assert_close_with_strict_warning( + out_dyn_2["virial"], + out_cmp_2["virial"], + atol=1.0e-5, + rtol=1.0e-5, + msg="eval virial mismatch after batch-size growth", + ) + + # === Step 5. Flip back to train and reuse the existing training slot. === + model_cmp.train() + model_cmp(coord_1, atype_1, box=box_1) + self.assertIs( + model_cmp.compiled_core_compute_cache[train_key], callable_train_first + ) + self.assertIs( + model_cmp.compiled_core_compute_cache[eval_key], callable_eval_first + ) + + def test_charge_spin_condition_matches_compile(self) -> None: + """Charge/spin conditions should work through the compiled energy path.""" + coord, atype, box, _, _, _ = self._make_tiny_frame() + params = self._build_model_params(use_compile=False) + params["descriptor"]["add_chg_spin_ebd"] = True + params["descriptor"]["default_chg_spin"] = [0.0, 1.0] + + model_dyn = get_sezm_model(params) + self._randomize_params(model_dyn) + params_cmp = self._build_model_params(use_compile=True) + params_cmp["descriptor"]["add_chg_spin_ebd"] = True + params_cmp["descriptor"]["default_chg_spin"] = [0.0, 1.0] + with mock.patch.dict(os.environ, {"DP_COMPILE_INFER": "1"}, clear=False): + model_cmp = get_sezm_model(params_cmp) + model_cmp.load_state_dict(model_dyn.state_dict()) + model_dyn.eval() + model_cmp.eval() + + charge_spin = torch.tensor( + [[0.0, 1.0]], dtype=torch.float32, device=self.device + ) + out_dyn = model_dyn(coord, atype, box=box, charge_spin=charge_spin) + out_cmp = model_cmp(coord, atype, box=box, charge_spin=charge_spin) + out_default = model_cmp(coord, atype, box=box) + out_shifted = model_cmp( + coord, + atype, + box=box, + charge_spin=torch.tensor( + [[1.0, 1.0]], dtype=torch.float32, device=self.device + ), + ) + + _assert_close_with_strict_warning( + out_dyn["energy"], + out_cmp["energy"], + atol=1.0e-6, + rtol=1.0e-6, + msg="charge/spin energy mismatch", + ) + _assert_close_with_strict_warning( + out_dyn["force"], + out_cmp["force"], + atol=1.0e-6, + rtol=1.0e-6, + msg="charge/spin force mismatch", + ) + _assert_close_with_strict_warning( + out_dyn["virial"], + out_cmp["virial"], + atol=1.0e-5, + rtol=1.0e-5, + msg="charge/spin virial mismatch", + ) + _assert_close_with_strict_warning( + out_default["energy"], + out_cmp["energy"], + atol=1.0e-6, + rtol=1.0e-6, + msg="default charge/spin energy mismatch", + ) + self.assertFalse( + torch.allclose(out_shifted["atom_energy"], out_cmp["atom_energy"]) + ) + + def test_fixed_edge_geometry_matches_standard_cache(self) -> None: + """Sparse edge geometry should match the standard descriptor cache.""" + coord, atype, box, _, _, _ = self._make_tiny_frame() + model = get_sezm_model(self._build_model_params(use_compile=False)) + model.train() + descriptor = model.atomic_model.descriptor + + cc, bb, fp, ap, _ = model._input_type_cast( + coord, box=box, fparam=None, aparam=None + ) + del fp, ap + if cc.ndim == 2: + cc = cc.view(coord.shape[0], atype.shape[1], 3) + extended_coord, extended_atype, mapping, nlist = model.build_neighbor_list( + cc, atype, bb + ) + atype_loc = extended_atype[:, : nlist.shape[1]] + type_ebed = descriptor.type_embedding(atype_loc).reshape( + -1, descriptor.channels + ) + pair_keep_mask = torch.ones_like( + nlist, dtype=torch.bool, device=extended_coord.device + ) + + cache_std = build_edge_cache( + type_ebed=type_ebed, + extended_coord=extended_coord.to(descriptor.compute_dtype), + nlist=nlist, + mapping=mapping, + pair_keep_mask=pair_keep_mask, + eps=descriptor.eps, + edge_envelope=descriptor.edge_envelope, + radial_basis=descriptor.radial_basis, + n_radial=descriptor.radial_basis.n_radial, + random_gamma=False, + wigner_calc=descriptor.wigner_calc, + use_geometry_rbf_triton=False, + ) + + edge_index, edge_vec, edge_mask = model.build_edge_list_from_nlist( + extended_coord=extended_coord, + nlist=nlist, + mapping=mapping, + ) + cache_sparse = build_edge_cache_from_edges( + type_ebed=type_ebed, + atype_flat=atype_loc.reshape(-1), + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + compute_dtype=descriptor.compute_dtype, + eps=descriptor.eps, + inner_clamp=descriptor.inner_clamp, + bridging_switch=descriptor.bridging_switch, + edge_envelope=descriptor.edge_envelope, + radial_basis=descriptor.radial_basis, + has_exclude_types=False, + edge_type_keep_mask=descriptor._edge_type_keep_mask, + random_gamma=False, + wigner_calc=descriptor.wigner_calc, + ) + + # build_edge_list_from_nlist appends one masked dummy edge; + # compare only the real edges (all except the trailing dummy). + n_real = cache_std.src.shape[0] + self.assertTrue(torch.equal(cache_std.src, cache_sparse.src[:n_real])) + self.assertTrue(torch.equal(cache_std.dst, cache_sparse.dst[:n_real])) + torch.testing.assert_close(cache_std.edge_vec, cache_sparse.edge_vec[:n_real]) + torch.testing.assert_close(cache_std.edge_rbf, cache_sparse.edge_rbf[:n_real]) + torch.testing.assert_close(cache_std.edge_env, cache_sparse.edge_env[:n_real]) + torch.testing.assert_close(cache_std.D_full, cache_sparse.D_full[:n_real]) + torch.testing.assert_close(cache_std.Dt_full, cache_sparse.Dt_full[:n_real]) + + def test_eval_compile_policy(self) -> None: + """Eval should stay eager by default and compile only with env override.""" + model = get_sezm_model(self._build_model_params(use_compile=True)) + self.assertTrue(model.use_compile) + + model.train() + self.assertTrue(model.should_use_compile()) + + model.eval() + self.assertFalse(model.should_use_compile()) + + with mock.patch.dict(os.environ, {"DP_COMPILE_INFER": "1"}, clear=False): + model_eval = get_sezm_model(self._build_model_params(use_compile=True)) + model_eval.eval() + self.assertTrue(model_eval.should_use_compile()) + + def test_forward_backward_double_backward_matches_compile(self) -> None: + """ + Check forward, backward, double backward, and short training consistency. + + Forward: energy/force outputs should match. + Backward: d(energy)/d(params) should match. + Double backward: d(force_loss)/d(params) should match. + Training: three SGD steps and a larger follow-up batch should still match. + """ + coord, atype, box, energy, force, virial = self._make_tiny_frame() + coord_2, atype_2, box_2, _, _, _ = self._make_tiny_frame(nframe=2) + + # === Step 1. Build paired models with shared random weights === + model_dyn = get_sezm_model(self._build_model_params(use_compile=False)) + self._randomize_params(model_dyn) + model_cmp = get_sezm_model(self._build_model_params(use_compile=True)) + model_cmp.load_state_dict(model_dyn.state_dict()) + model_dyn.train() + model_cmp.train() + + # === Step 2. Forward output consistency === + out_dyn = model_dyn(coord, atype, box=box) + out_cmp = model_cmp(coord, atype, box=box) + _assert_close_with_strict_warning( + out_dyn["energy"], + out_cmp["energy"], + atol=1.0e-6, + rtol=1.0e-6, + msg="train energy mismatch on first compiled call", + ) + _assert_close_with_strict_warning( + out_dyn["force"], + out_cmp["force"], + atol=1.0e-6, + rtol=1.0e-6, + msg="train force mismatch on first compiled call", + ) + + # === Step 3. Backward on energy === + model_dyn.zero_grad(set_to_none=True) + model_cmp.zero_grad(set_to_none=True) + loss_dyn = out_dyn["energy"].sum() + loss_cmp = out_cmp["energy"].sum() + loss_dyn.backward() + loss_cmp.backward() + grads_dyn = { + name: ( + torch.zeros_like(param) if param.grad is None else param.grad.detach() + ) + for name, param in model_dyn.named_parameters() + } + grads_cmp = { + name: ( + torch.zeros_like(param) if param.grad is None else param.grad.detach() + ) + for name, param in model_cmp.named_parameters() + } + # Inductor Triton kernels use different reduction order vs eager, + # so float32 gradients can differ by ~1e-3 on GPU. + grad_atol = 1.0e-5 if self.device == torch.device("cpu") else 2.0e-3 + grad_rtol = 1.0e-5 if self.device == torch.device("cpu") else 1.0e-4 + self.assertEqual(set(grads_dyn.keys()), set(grads_cmp.keys())) + for name in grads_dyn.keys(): + _assert_close_with_strict_warning( + grads_dyn[name], + grads_cmp[name], + atol=grad_atol, + rtol=grad_rtol, + msg=f"energy-grad mismatch at {name}", + ) + + # === Step 5. Reuse the compiled training graph for three optimizer steps === + params_dyn = self._train_steps( + model_dyn, coord, atype, box, energy, force, virial + ) + params_cmp = self._train_steps( + model_cmp, coord, atype, box, energy, force, virial + ) + self.assertEqual(set(params_dyn.keys()), set(params_cmp.keys())) + for name in params_dyn.keys(): + _assert_close_with_strict_warning( + params_dyn[name], + params_cmp[name], + strict_atol=1.0e-7, + strict_rtol=1.0e-7, + atol=1.0e-7, + rtol=1.0e-7, + msg=f"trained parameter mismatch at {name}", + ) + + # === Step 6. The traced training graph should also handle a larger batch === + out_dyn = model_dyn(coord_2, atype_2, box=box_2) + out_cmp = model_cmp(coord_2, atype_2, box=box_2) + self.assertEqual(out_dyn["energy"].shape, (2, 1)) + self.assertEqual(out_cmp["energy"].shape, (2, 1)) + _assert_close_with_strict_warning( + out_dyn["energy"], + out_cmp["energy"], + atol=1.0e-6, + rtol=1.0e-6, + msg="train energy mismatch after batch-size growth", + ) + + # === Step 4. Double backward via force loss === + model_dyn.zero_grad(set_to_none=True) + model_cmp.zero_grad(set_to_none=True) + out_dyn = model_dyn(coord, atype, box=box) + out_cmp = model_cmp(coord, atype, box=box) + loss_dyn = torch.sum(out_dyn["force"] * out_dyn["force"]) + loss_cmp = torch.sum(out_cmp["force"] * out_cmp["force"]) + loss_dyn.backward() + loss_cmp.backward() + grads_dyn = { + name: ( + torch.zeros_like(param) if param.grad is None else param.grad.detach() + ) + for name, param in model_dyn.named_parameters() + } + grads_cmp = { + name: ( + torch.zeros_like(param) if param.grad is None else param.grad.detach() + ) + for name, param in model_cmp.named_parameters() + } + self.assertEqual(set(grads_dyn.keys()), set(grads_cmp.keys())) + for name in grads_dyn.keys(): + _assert_close_with_strict_warning( + grads_dyn[name], + grads_cmp[name], + atol=grad_atol, + rtol=grad_rtol, + msg=f"force-grad mismatch at {name}", + ) + + def _assert_multitask_compile_matches_eager( + self, + *, + case_film_embd: bool, + ) -> None: + """ + Multi-task + compile: two SeZM branches sharing descriptor and + fitting (with per-task case embedding) should each compile correctly + and produce outputs matching their eager counterparts. + """ + from deepmd.pt.train.training import ( + get_model_for_wrapper, + prepare_model_for_loss, + ) + from deepmd.pt.train.wrapper import ( + ModelWrapper, + ) + from deepmd.pt.utils.multi_task import ( + preprocess_shared_params, + ) + + # === Step 1. Build a multi-task model config with shared descriptor + # + shared fitting (case_embd=2) seeded from the compile fixture. === + def _make_mt_cfg(use_compile: bool) -> dict: + single = self._build_model_params(use_compile=use_compile) + fitting_shared = dict(single["fitting_net"]) + fitting_shared["type"] = "dpa4_ener" + fitting_shared["dim_case_embd"] = 2 + fitting_shared["case_film_embd"] = case_film_embd + return { + "use_compile": use_compile, + "shared_dict": { + "type_map": single["type_map"], + "descriptor": single["descriptor"], + "shared_fit": fitting_shared, + }, + "model_dict": { + "water_1": { + "type": "SeZM", + "type_map": "type_map", + "descriptor": "descriptor", + "fitting_net": "shared_fit", + }, + "water_2": { + "type": "SeZM", + "type_map": "type_map", + "descriptor": "descriptor", + "fitting_net": "shared_fit", + }, + }, + } + + def _build_wrapper(use_compile: bool) -> ModelWrapper: + mt_cfg = _make_mt_cfg(use_compile) + # ``preprocess_shared_params`` cascades the top-level + # ``use_compile`` into every branch before unrolling the + # shared_dict, mirroring the real training flow. + mt_cfg, shared_links = preprocess_shared_params(mt_cfg) + loss_params = { + "water_1": {"type": "ener"}, + "water_2": {"type": "ener"}, + } + models = get_model_for_wrapper(mt_cfg, _loss_params=loss_params) + prepare_model_for_loss(models, loss_params) + wrapper = ModelWrapper(models) + wrapper.share_params(shared_links, {"water_1": 0.5, "water_2": 0.5}) + return wrapper + + wrapper_eager = _build_wrapper(use_compile=False) + self._randomize_params(wrapper_eager) + wrapper_cmp = _build_wrapper(use_compile=True) + # Mirror eager weights so the only remaining difference between the + # two wrappers is the compile path. + wrapper_cmp.load_state_dict(wrapper_eager.state_dict()) + + # Sanity: descriptor and fitting parameters are shared across branches + # inside each wrapper. + for w in (wrapper_eager, wrapper_cmp): + d1 = w.model["water_1"].get_descriptor() + d2 = w.model["water_2"].get_descriptor() + self.assertEqual( + next(d1.parameters()).data_ptr(), + next(d2.parameters()).data_ptr(), + ) + f1 = w.model["water_1"].atomic_model.fitting_net + f2 = w.model["water_2"].atomic_model.fitting_net + self.assertEqual( + next(f1.filter_layers.parameters()).data_ptr(), + next(f2.filter_layers.parameters()).data_ptr(), + ) + # Per-task case embeddings remain distinct. + self.assertFalse(torch.equal(f1.case_embd, f2.case_embd)) + expected_in_dim = f1.dim_descrpt + (0 if case_film_embd else 2) + self.assertEqual(f1.filter_layers.networks[0].in_dim, expected_in_dim) + self.assertEqual(f1.case_film_embd, case_film_embd) + + # === Step 2. Run compile + eager forward on each branch. === + coord, atype, box, _, _, _ = self._make_tiny_frame() + for branch in ("water_1", "water_2"): + m_eager = wrapper_eager.model[branch] + m_cmp = wrapper_cmp.model[branch] + m_eager.train() + m_cmp.train() + out_e = m_eager(coord, atype, box=box) + out_c = m_cmp(coord, atype, box=box) + _assert_close_with_strict_warning( + out_e["energy"], + out_c["energy"], + atol=1.0e-6, + rtol=1.0e-6, + msg=f"multitask energy mismatch at {branch}", + ) + _assert_close_with_strict_warning( + out_e["force"], + out_c["force"], + atol=1.0e-6, + rtol=1.0e-6, + msg=f"multitask force mismatch at {branch}", + ) + + # === Step 3. Each compiled branch owns its own compile cache; the + # shared descriptor weights must not collapse them into one. + # Step 2 ran every branch in training mode with the default + # ``do_atomic_virial=False`` and no coordinate correction, so each + # per-branch cache dict + # should hold exactly that one slot, and the compiled callables + # at that slot must be distinct across branches. === + cache1 = wrapper_cmp.model["water_1"].compiled_core_compute_cache + cache2 = wrapper_cmp.model["water_2"].compiled_core_compute_cache + self.assertIsNot(cache1, cache2) + train_key = (True, False, False) + self.assertIn(train_key, cache1) + self.assertIn(train_key, cache2) + c1 = cache1[train_key] + c2 = cache2[train_key] + self.assertIsNotNone(c1) + self.assertIsNotNone(c2) + self.assertIsNot(c1, c2) + + # === Step 4. Per-task case embedding must differentiate outputs. === + out_e1 = wrapper_eager.model["water_1"](coord, atype, box=box) + out_e2 = wrapper_eager.model["water_2"](coord, atype, box=box) + self.assertFalse( + torch.allclose(out_e1["energy"], out_e2["energy"], atol=1.0e-8) + ) + + def test_multitask_compile_matches_eager(self) -> None: + """Legacy case embedding concatenation should match through compile.""" + self._assert_multitask_compile_matches_eager(case_film_embd=False) + + +class TestInterPotential(unittest.TestCase): + """Test InterPotential ZBL analytical pair potential.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_zbl_known_value_OO(self) -> None: + """Test ZBL energy for O-O pair at known distance against reference.""" + pot = InterPotential(type_map=["O", "H"], mode="ZBL").to(self.device) + + import math + + z_o = 8.0 + a_bohr = 0.5291772109 + ke = 14.3996 + a_screen = 0.88534 * a_bohr / (z_o**0.23 + z_o**0.23) + r = 1.0 + x = r / a_screen + phi = ( + 0.18175 * math.exp(-3.1998 * x) + + 0.50986 * math.exp(-0.94229 * x) + + 0.28022 * math.exp(-0.4029 * x) + + 0.028171 * math.exp(-0.20162 * x) + ) + expected = ke * z_o * z_o / r * phi + + extended_coord = torch.tensor( + [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], + dtype=torch.float64, + device=self.device, + ) + extended_atype = torch.tensor([[0, 0]], dtype=torch.int64, device=self.device) + nlist = torch.tensor([[[1], [0]]], dtype=torch.int64, device=self.device) + + pair_e = pot(extended_coord, extended_atype, nlist, nloc=2) + total_e = pair_e.sum().item() + self.assertAlmostEqual(total_e, expected, places=5) + + def test_zbl_known_value_OH(self) -> None: + """Test ZBL energy for O-H pair at known distance.""" + pot = InterPotential(type_map=["O", "H"], mode="ZBL").to(self.device) + import math + + z_o, z_h = 8.0, 1.0 + a_bohr = 0.5291772109 + ke = 14.3996 + a_screen = 0.88534 * a_bohr / (z_o**0.23 + z_h**0.23) + r = 0.8 + x = r / a_screen + phi = ( + 0.18175 * math.exp(-3.1998 * x) + + 0.50986 * math.exp(-0.94229 * x) + + 0.28022 * math.exp(-0.4029 * x) + + 0.028171 * math.exp(-0.20162 * x) + ) + expected = ke * z_o * z_h / r * phi + + extended_coord = torch.tensor( + [[[0.0, 0.0, 0.0], [0.8, 0.0, 0.0]]], + dtype=torch.float64, + device=self.device, + ) + extended_atype = torch.tensor([[0, 1]], dtype=torch.int64, device=self.device) + nlist = torch.tensor([[[1], [0]]], dtype=torch.int64, device=self.device) + + pair_e = pot(extended_coord, extended_atype, nlist, nloc=2) + total_e = pair_e.sum().item() + self.assertAlmostEqual(total_e, expected, places=5) + + def test_zbl_gradient_exists(self) -> None: + """Test that ZBL potential produces valid gradients for force computation.""" + pot = InterPotential(type_map=["O", "H"], mode="ZBL").to(self.device) + + extended_coord = torch.tensor( + [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], + dtype=torch.float64, + device=self.device, + requires_grad=True, + ) + extended_atype = torch.tensor([[0, 1]], dtype=torch.int64, device=self.device) + nlist = torch.tensor([[[1], [0]]], dtype=torch.int64, device=self.device) + + pair_e = pot(extended_coord, extended_atype, nlist, nloc=2) + pair_e.sum().backward() + self.assertIsNotNone(extended_coord.grad) + self.assertTrue(torch.isfinite(extended_coord.grad).all()) + + def test_unknown_element_raises(self) -> None: + """Test that unknown element raises ValueError.""" + with self.assertRaises(ValueError): + InterPotential(type_map=["O", "Xx"]) + + def test_forward_from_edges(self) -> None: + """Test the compile-path edge-based ZBL computation.""" + pot = InterPotential(type_map=["O", "H"], mode="ZBL").to(self.device) + + edge_vec = torch.tensor( + [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]], + dtype=torch.float64, + device=self.device, + ) + edge_index = torch.tensor( + [[1, 0], [0, 1]], dtype=torch.long, device=self.device + ) + atype_flat = torch.tensor([0, 1], dtype=torch.long, device=self.device) + edge_mask = torch.tensor([True, True], device=self.device) + + result = pot.forward_from_edges(edge_vec, edge_index, atype_flat, edge_mask, 2) + self.assertEqual(result.shape, (1, 2, 1)) + self.assertTrue(torch.isfinite(result).all()) + + extended_coord = torch.tensor( + [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]], + dtype=torch.float64, + device=self.device, + ) + extended_atype = torch.tensor([[0, 1]], dtype=torch.int64, device=self.device) + nlist = torch.tensor([[[1], [0]]], dtype=torch.int64, device=self.device) + pair_e_nlist = pot(extended_coord, extended_atype, nlist, nloc=2) + torch.testing.assert_close( + result.sum(), pair_e_nlist.sum().to(result.dtype), atol=1e-8, rtol=1e-8 + ) + + +class TestSeZMModelBridging(unittest.TestCase): + """Test SeZM model with ZBL bridging enabled.""" + + def setUp(self) -> None: + self.device = env.DEVICE + torch.manual_seed(2024) + + def _build_model_params(self, *, bridging_method: str = "none") -> dict: + return { + "type": "SeZM", + "type_map": ["O", "H"], + "descriptor": { + "type": "SeZM", + "sel": [2, 2], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "focus_compete": False, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": False, + "l_schedule": [1, 0], + "mmax": 1, + "so2_norm": False, + "so2_layers": 1, + "n_atten_head": 0, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 8, + "ffn_blocks": 1, + "mlp_bias": True, + "layer_scale": True, + "use_amp": False, + "activation_function": "silu", + "glu_activation": True, + "precision": "float32", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float32", + "seed": 7, + }, + "use_compile": False, + "bridging_method": bridging_method, + "bridging_r_inner": 0.8, + "bridging_r_outer": 1.2, + } + + def test_bridging_none_unchanged(self) -> None: + """Test that bridging_method='none' produces no inter_potential.""" + model = get_sezm_model(self._build_model_params(bridging_method="none")) + self.assertIsNone(model.inter_potential) + self.assertEqual(model.bridging_method, "NONE") + + def test_bridging_zbl_creates_potential(self) -> None: + """Test that bridging_method='ZBL' creates InterPotential and InnerClamp.""" + model = get_sezm_model(self._build_model_params(bridging_method="ZBL")) + self.assertIsNotNone(model.inter_potential) + self.assertEqual(model.bridging_method, "ZBL") + self.assertIsNotNone(model.atomic_model.descriptor.inner_clamp) + + def test_zbl_adds_energy(self) -> None: + """Test that ZBL bridging adds energy to the model output.""" + model_plain = get_sezm_model(self._build_model_params(bridging_method="none")) + model_zbl = get_sezm_model(self._build_model_params(bridging_method="ZBL")) + + sd = model_plain.state_dict() + model_zbl.load_state_dict(sd, strict=False) + + coord = torch.tensor( + [[[0.0, 0.0, 0.0], [0.8, 0.0, 0.0], [0.0, 2.0, 0.0]]], + dtype=torch.float32, + device=self.device, + ) + atype = torch.tensor([[0, 1, 0]], dtype=torch.int32, device=self.device) + box = torch.tensor( + [[10.0, 0, 0, 0, 10.0, 0, 0, 0, 10.0]], + dtype=torch.float32, + device=self.device, + ) + + model_plain.eval() + model_zbl.eval() + + out_plain = model_plain(coord, atype, box=box) + out_zbl = model_zbl(coord, atype, box=box) + + energy_diff = (out_zbl["energy"] - out_plain["energy"]).item() + self.assertGreater( + energy_diff, + 0.0, + "ZBL bridging should add positive (repulsive) energy", + ) + + +class TestSeZMModelModes(unittest.TestCase): + """Targeted regression tests for SeZM `ener` / `dens` mode routing.""" + + def setUp(self) -> None: + self.device = env.DEVICE + torch.manual_seed(2024) + + def _build_model_params( + self, + *, + use_compile: bool = False, + bridging_method: str = "none", + ) -> dict: + return { + "type": "SeZM", + "type_map": ["O", "H"], + "descriptor": { + "type": "SeZM", + "sel": [2, 2], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "focus_compete": False, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": False, + "l_schedule": [1, 1], + "mmax": 1, + "so2_norm": False, + "so2_layers": 1, + "n_atten_head": 0, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 8, + "ffn_blocks": 1, + "mlp_bias": True, + "layer_scale": False, + "use_amp": False, + "activation_function": "silu", + "glu_activation": True, + "precision": "float32", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float32", + "seed": 7, + }, + "use_compile": use_compile, + "bridging_method": bridging_method, + } + + def _tiny_system( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + coord = torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [1.1, 0.2, 0.0], + [0.2, 1.0, 0.3], + ] + ], + device=self.device, + dtype=torch.float32, + ) + atype = torch.tensor([[0, 1, 0]], device=self.device, dtype=torch.int32) + box = torch.tensor( + [[6.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 6.0]], + device=self.device, + dtype=torch.float32, + ) + force = torch.tensor( + [ + [ + [0.2, -0.1, 0.0], + [-0.3, 0.4, 0.1], + [0.1, 0.2, -0.2], + ] + ], + device=self.device, + dtype=torch.float32, + ) + noise_mask = torch.tensor( + [[True, False, True]], + device=self.device, + dtype=torch.bool, + ) + return coord, atype, box, force, noise_mask + + def _dens_stat_samples(self) -> list[dict[str, torch.Tensor | np.float32]]: + """Build a tiny SeZM `dens` statistics set with force labels.""" + return [ + { + "atype": torch.tensor( + [[0, 1]], + device=self.device, + dtype=torch.int32, + ), + "natoms": torch.tensor( + [[2, 2, 1, 1]], + device=self.device, + dtype=torch.int32, + ), + "energy": torch.tensor( + [[10.0]], + device=self.device, + dtype=torch.float32, + ), + "force": torch.tensor( + [[[1.0, 2.0, 3.0], [2.0, 4.0, 6.0]]], + device=self.device, + dtype=torch.float32, + ), + "find_energy": np.float32(1.0), + "find_force": np.float32(1.0), + }, + { + "atype": torch.tensor( + [[0, 0]], + device=self.device, + dtype=torch.int32, + ), + "natoms": torch.tensor( + [[2, 2, 2, 0]], + device=self.device, + dtype=torch.int32, + ), + "energy": torch.tensor( + [[8.0]], + device=self.device, + dtype=torch.float32, + ), + "force": torch.tensor( + [[[5.0, 6.0, 7.0], [5.0, 6.0, 7.0]]], + device=self.device, + dtype=torch.float32, + ), + "find_energy": np.float32(1.0), + "find_force": np.float32(1.0), + }, + { + "atype": torch.tensor( + [[1, 1]], + device=self.device, + dtype=torch.int32, + ), + "natoms": torch.tensor( + [[2, 2, 0, 2]], + device=self.device, + dtype=torch.int32, + ), + "energy": torch.tensor( + [[12.0]], + device=self.device, + dtype=torch.float32, + ), + "force": torch.tensor( + [[[8.0, 10.0, 12.0], [8.0, 10.0, 12.0]]], + device=self.device, + dtype=torch.float32, + ), + "find_energy": np.float32(1.0), + "find_force": np.float32(1.0), + }, + ] + + def _expected_dens_force_rmsd( + self, + sampled: list[dict[str, torch.Tensor | np.float32]], + ) -> float: + """Compute the expected global direct-force RMSD.""" + force_square_sum = 0.0 + force_atom_count = 0 + for sample in sampled: + force = sample["force"].detach().cpu().numpy() + force_square_sum += float(np.square(force).sum()) + force_atom_count += int(force.shape[0] * force.shape[1]) + return float(np.sqrt(force_square_sum / force_atom_count)) + + def test_training_setup_routes_mode_without_rebuilding_energy_head(self) -> None: + """Training setup should route SeZM mode without rebuilding the energy head.""" + model = get_sezm_model(self._build_model_params(use_compile=False)) + energy_param_before = ( + next(model.atomic_model.fitting_net.parameters()).detach().clone() + ) + prepare_model_for_loss(model, {"type": "dens"}) + self.assertEqual(model.get_active_mode(), "dens") + self.assertIsNotNone(model.atomic_model.dens_fitting_net) + prepare_model_for_loss(model, {"type": "ener"}) + coord, atype, box, _, _ = self._tiny_system() + loss_module = EnergyStdLoss( + starter_learning_rate=1.0e-3, + start_pref_e=1.0, + limit_pref_e=1.0, + ) + _, loss, _ = loss_module( + { + "coord": coord, + "atype": atype, + "box": box, + }, + model, + { + "energy": torch.zeros((1, 1), device=self.device, dtype=torch.float32), + "find_energy": 1.0, + }, + natoms=atype.shape[1], + learning_rate=1.0e-3, + ) + energy_param_after = next(model.atomic_model.fitting_net.parameters()).detach() + torch.testing.assert_close(energy_param_after, energy_param_before) + self.assertEqual(model.get_active_mode(), "ener") + self.assertTrue(torch.isfinite(loss)) + + def test_checkpoint_loading_handles_optional_dens_head(self) -> None: + """Checkpoint loading should respect whether `dens` weights exist.""" + params = self._build_model_params(use_compile=False) + model = get_sezm_model(params) + state_without_dens = { + key: value + for key, value in model.state_dict().items() + if "dens_fitting_net" not in key + } + fresh_model = get_sezm_model(params) + self.assertIsNone(fresh_model.atomic_model.dens_fitting_net) + fresh_model.load_state_dict(state_without_dens, strict=True) + self.assertIsNone(fresh_model.atomic_model.dens_fitting_net) + self.assertEqual(fresh_model.get_active_mode(), "ener") + coord, atype, box, _, _ = self._tiny_system() + out = fresh_model(coord, atype, box=box) + self.assertIn("energy", out) + self.assertIn("force", out) + model = get_sezm_model(self._build_model_params(use_compile=False)) + model.set_active_mode("dens") + dens_state = model.state_dict() + fresh_model = get_sezm_model(self._build_model_params(use_compile=False)) + self.assertIsNone(fresh_model.atomic_model.dens_fitting_net) + fresh_model.load_state_dict(dens_state, strict=True) + self.assertIsNotNone(fresh_model.atomic_model.dens_fitting_net) + self.assertEqual(fresh_model.get_active_mode(), "dens") + + def test_dens_forward_returns_direct_force_outputs(self) -> None: + """`dens` mode should expose direct-force outputs without virial branches.""" + model = get_sezm_model(self._build_model_params(use_compile=False)) + model.set_active_mode("dens") + coord, atype, box, force, noise_mask = self._tiny_system() + out = model( + coord, + atype, + box=box, + force_input=force, + noise_mask=noise_mask, + ) + self.assertIn("energy", out) + self.assertIn("atom_energy", out) + self.assertIn("force", out) + self.assertNotIn("virial", out) + self.assertEqual(out["force"].shape, force.shape) + + def test_dens_loss_forward_smoke(self) -> None: + """`DeNSLoss` should build noisy inputs and return a finite training loss.""" + model = get_sezm_model(self._build_model_params(use_compile=False)) + prepare_model_for_loss(model, {"type": "dens"}) + loss_module = DeNSLoss( + starter_learning_rate=1.0e-3, + start_pref_e=1.0, + limit_pref_e=1.0, + start_pref_f=1.0, + limit_pref_f=1.0, + dens_prob=1.0, + dens_std=0.025, + dens_corrupt_ratio=0.5, + dens_denoising_pos_coefficient=10.0, + loss_func="mae", + ) + coord, atype, box, force, _ = self._tiny_system() + label = { + "energy": torch.zeros((1, 1), device=self.device, dtype=torch.float32), + "force": force, + "find_energy": 1.0, + "find_force": 1.0, + } + model_pred, loss, more_loss = loss_module( + { + "coord": coord, + "atype": atype, + "box": box, + }, + model, + label, + natoms=atype.shape[1], + learning_rate=1.0e-3, + ) + self.assertEqual(model.get_active_mode(), "dens") + self.assertIn("force", model_pred) + self.assertTrue(torch.isfinite(loss)) + + def test_dens_stat_roundtrip(self) -> None: + """`dens` statistics should roundtrip the global direct-force RMSD.""" + sampled = self._dens_stat_samples() + expected_force_rmsd = self._expected_dens_force_rmsd(sampled) + + model = get_sezm_model(self._build_model_params(use_compile=False)) + prepare_model_for_loss(model, {"type": "dens"}) + + with tempfile.TemporaryDirectory() as tmpdir: + h5file = Path(tmpdir) / "sezm_stat.hdf5" + with h5py.File(h5file, "w"): + pass + + stat_path = DPPath(str(h5file), "a") + try: + model.atomic_model.compute_or_load_stat( + lambda: sampled, + stat_file_path=stat_path, + ) + self.assertAlmostEqual( + model.atomic_model.dens_force_rmsd.item(), + expected_force_rmsd, + places=7, + ) + self.assertEqual(model.get_active_mode(), "dens") + + stored_force_rmsd = (stat_path / "O H" / "rmsd_dforce").load_numpy() + self.assertAlmostEqual( + float(np.asarray(stored_force_rmsd).reshape(-1)[0]), + expected_force_rmsd, + places=7, + ) + + fresh_model = get_sezm_model( + self._build_model_params(use_compile=False) + ) + prepare_model_for_loss(fresh_model, {"type": "dens"}) + + def raise_error() -> None: + raise RuntimeError("statistics should be restored from file") + + fresh_model.atomic_model.compute_or_load_stat( + raise_error, + stat_file_path=stat_path, + ) + self.assertAlmostEqual( + fresh_model.atomic_model.dens_force_rmsd.item(), + expected_force_rmsd, + places=7, + ) + self.assertEqual(fresh_model.get_active_mode(), "dens") + finally: + stat_path.root.close() + + +# ============================================================================= +# LoRA fine-tune tests +# ============================================================================= + + +class _LoRATestCase(unittest.TestCase): + """Shared device / seeding base for LoRA tests.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + +class TestLoRASO3Adapter(_LoRATestCase): + """Unit tests for :class:`LoRASO3`.""" + + def setUp(self) -> None: + super().setUp() + torch.manual_seed(17) + + def _build_base_and_lora( + self, + *, + rank: int = 4, + lmax: int = 2, + in_channels: int = 4, + out_channels: int = 5, + n_focus: int = 1, + mlp_bias: bool = False, + dtype: torch.dtype = torch.float32, + ) -> tuple[SO3Linear, LoRASO3]: + base = SO3Linear( + lmax=lmax, + in_channels=in_channels, + out_channels=out_channels, + n_focus=n_focus, + dtype=dtype, + mlp_bias=mlp_bias, + trainable=True, + seed=101, + ) + lora = LoRASO3(base, rank=rank, alpha=float(rank)) + return base, lora + + def _random_input(self, lora: LoRASO3) -> torch.Tensor: + n_dim = (lora.lmax + 1) ** 2 + return torch.randn( + 3, + n_dim, + lora.n_focus, + lora.in_channels, + device=self.device, + dtype=lora.dtype, + ) + + def test_merge_into_base_matches_forward(self) -> None: + """Numerical parity between LoRASO3 forward and its merged base.""" + _, lora = self._build_base_and_lora() + torch.nn.init.normal_(lora.B_by_l, std=0.05) + x = self._random_input(lora) + out_lora = lora(x) + merged = lora.merge_into_base() + out_merged = merged(x) + torch.testing.assert_close(out_lora, out_merged, atol=1e-6, rtol=1e-5) + self.assertIs(type(merged), SO3Linear) + + +class TestLoRASO2Adapter(_LoRATestCase): + """Unit tests for :class:`LoRASO2`.""" + + def setUp(self) -> None: + super().setUp() + torch.manual_seed(23) + + def _build_base_and_lora( + self, + *, + rank: int = 4, + lmax: int = 2, + mmax: int = 2, + in_channels: int = 4, + out_channels: int = 5, + n_focus: int = 1, + mlp_bias: bool = False, + dtype: torch.dtype = torch.float32, + ) -> tuple[SO2Linear, LoRASO2]: + base = SO2Linear( + lmax=lmax, + mmax=mmax, + in_channels=in_channels, + out_channels=out_channels, + n_focus=n_focus, + dtype=dtype, + mlp_bias=mlp_bias, + seed=202, + trainable=True, + ) + lora = LoRASO2(base, rank=rank, alpha=float(rank)) + return base, lora + + def _random_input(self, lora: LoRASO2) -> torch.Tensor: + return torch.randn( + 3, + lora.n_focus, + lora.reduced_dim, + lora.in_channels, + device=self.device, + dtype=lora.dtype, + ) + + def _randomize_lora_B(self, lora: LoRASO2) -> None: + torch.nn.init.normal_(lora.B_m0, std=0.05) + for b in lora.B_m: + torch.nn.init.normal_(b, std=0.05) + + def test_merge_into_base_matches_forward(self) -> None: + """Numerical parity between LoRASO2 forward and its merged base.""" + _, lora = self._build_base_and_lora() + self._randomize_lora_B(lora) + x = self._random_input(lora) + out_lora = lora(x) + merged = lora.merge_into_base() + out_merged = merged(x) + torch.testing.assert_close(out_lora, out_merged, atol=1e-6, rtol=1e-5) + self.assertIs(type(merged), SO2Linear) + + def test_z_rotation_equivariance(self) -> None: + """Rotating x by the m-major z-block rotation commutes with LoRASO2 forward.""" + lmax, mmax = 2, 1 + _, lora = self._build_base_and_lora( + rank=3, lmax=lmax, mmax=mmax, in_channels=6, out_channels=4, n_focus=1 + ) + self._randomize_lora_B(lora) + batch = 8 + dtype = lora.dtype + x = torch.randn( + batch, + lora.n_focus, + lora.reduced_dim, + lora.in_channels, + device=self.device, + dtype=dtype, + ) + angles = torch.rand(batch, device=self.device, dtype=dtype) * 2 * math.pi + z_mat = _build_m_major_z_rotation(angles, lmax, mmax, self.device) + x_rot = torch.einsum("bij,bfjc->bfic", z_mat, x) + lhs = lora(x_rot) + rhs = torch.einsum("bij,bfjc->bfic", z_mat, lora(x)) + torch.testing.assert_close(lhs, rhs, atol=1e-5, rtol=1e-5) + + +class TestApplyLoRAToSeZM(_LoRATestCase): + """Tests for the full SeZM LoRA injection policy.""" + + def setUp(self) -> None: + super().setUp() + torch.manual_seed(31) + self.model = get_sezm_model(_build_lora_sezm_model_params()) + apply_lora_to_sezm(self.model, rank=4, alpha=4.0) + + def test_so3_and_so2_are_subclassed(self) -> None: + """Every SO3Linear / SO2Linear submodule is now a LoRA subclass.""" + n_lora_so3 = 0 + n_lora_so2 = 0 + for mod in self.model.modules(): + if type(mod) is SO3Linear: + self.fail("Found a bare SO3Linear; apply_lora_to_sezm missed it.") + if type(mod) is SO2Linear: + self.fail("Found a bare SO2Linear; apply_lora_to_sezm missed it.") + if isinstance(mod, LoRASO3): + n_lora_so3 += 1 + elif isinstance(mod, LoRASO2): + n_lora_so2 += 1 + self.assertGreater(n_lora_so3, 0) + self.assertGreater(n_lora_so2, 0) + + def test_lora_base_weights_are_frozen(self) -> None: + """Base weight matrices inside every LoRA wrapper stay frozen. + + Bias-like parameters (``bias`` / ``bias0``) remain trainable by the + leaf-name rule "any leaf containing 'bias' is unfrozen"; this test + asserts the large weight matrices specifically. + """ + for mod in self.model.modules(): + if isinstance(mod, LoRASO3): + self.assertFalse(mod.weight.requires_grad) + elif isinstance(mod, LoRASO2): + self.assertFalse(mod.weight_m0.requires_grad) + for w in mod.weight_m: + self.assertFalse(w.requires_grad) + + def test_lora_adapter_params_are_trainable(self) -> None: + """LoRA A/B parameters are trainable everywhere.""" + for mod in self.model.modules(): + if isinstance(mod, LoRASO3): + self.assertTrue(mod.A_by_l.requires_grad) + self.assertTrue(mod.B_by_l.requires_grad) + elif isinstance(mod, LoRASO2): + self.assertTrue(mod.A_m0.requires_grad) + self.assertTrue(mod.B_m0.requires_grad) + for a, b in zip(mod.A_m, mod.B_m, strict=True): + self.assertTrue(a.requires_grad) + self.assertTrue(b.requires_grad) + + def test_full_unfreezes(self) -> None: + """fitting_net / radial_embedding / env_seed_embedding are fully trainable.""" + fitting = self.model.atomic_model.fitting_net + self.assertIsNotNone(fitting) + for p in fitting.parameters(): + self.assertTrue(p.requires_grad) + radial = self.model.atomic_model.descriptor.radial_embedding + for p in radial.parameters(): + self.assertTrue(p.requires_grad) + env_seed = self.model.atomic_model.descriptor.env_seed_embedding + self.assertIsNotNone(env_seed) + # All non-type-embed params inside env_seed must be trainable. + for name, p in env_seed.named_parameters(): + if name.endswith("adam_type_embedding"): + continue + self.assertTrue( + p.requires_grad, msg=f"env_seed param {name} should be trainable" + ) + + def test_override_freezes_type_embed_and_radial_freqs(self) -> None: + """``adam_type_embedding`` and ``adam_freqs`` stay frozen.""" + frozen_leaves = {"adam_type_embedding", "adam_freqs"} + hit = dict.fromkeys(frozen_leaves, 0) + for name, p in self.model.named_parameters(): + leaf = name.rsplit(".", 1)[-1] + if leaf in frozen_leaves: + self.assertFalse( + p.requires_grad, + msg=f"{name} should stay frozen after LoRA injection", + ) + hit[leaf] += 1 + for leaf, count in hit.items(): + self.assertGreater( + count, + 0, + msg=f"No parameter with leaf {leaf} found; test coverage gap.", + ) + + def test_override_freezes_gated_activation(self) -> None: + """Every parameter inside a GatedActivation is frozen.""" + found = False + for mod in self.model.modules(): + if isinstance(mod, GatedActivation): + for p in mod.parameters(): + self.assertFalse(p.requires_grad) + found = True + self.assertTrue(found, msg="Expected at least one GatedActivation in SeZM.") + + +class TestBuildMergedStateDict(_LoRATestCase): + """Tests for the non-destructive merged-state-dict helper.""" + + def setUp(self) -> None: + super().setUp() + torch.manual_seed(41) + self.model = get_sezm_model(_build_lora_sezm_model_params()) + apply_lora_to_sezm(self.model, rank=4, alpha=4.0) + # Randomize every B so LoRA delta is non-trivial. + for mod in self.model.modules(): + if isinstance(mod, LoRASO3): + torch.nn.init.normal_(mod.B_by_l, std=0.05) + elif isinstance(mod, LoRASO2): + torch.nn.init.normal_(mod.B_m0, std=0.05) + for b in mod.B_m: + torch.nn.init.normal_(b, std=0.05) + + def test_keys_match_plain_sezm(self) -> None: + """Merged state dict has the same keys as a never-LoRA'ed sibling model.""" + plain_model = get_sezm_model(_build_lora_sezm_model_params()) + plain_keys = set(plain_model.state_dict().keys()) + merged = build_merged_state_dict(self.model) + merged_keys = set(merged.keys()) + self.assertEqual(merged_keys, plain_keys) + # Explicitly assert that no LoRA-only key survived. + for key in merged_keys: + leaf = key.rsplit(".", 1)[-1] + self.assertNotIn( + leaf, + {"A_by_l", "B_by_l", "A_m0", "B_m0"}, + msg=f"LoRA-only leaf {leaf} should not appear in merged state", + ) + parts = key.split(".") + i = len(parts) - 1 + while i > 0 and parts[i].isdigit(): + i -= 1 + self.assertNotIn(parts[i], {"A_m", "B_m"}) + + def test_weight_values_include_delta(self) -> None: + """Every LoRA weight key in the merged state equals ``W + ΔW``.""" + merged = build_merged_state_dict(self.model) + # Keys live under `atomic_model.descriptor....` inside SeZMModel; helper + # walks self.model.named_modules() so prefix is "" at the top. + for name, mod in self.model.named_modules(): + prefix = name + "." if name else "" + if isinstance(mod, LoRASO3): + expected = ( + mod.weight.detach() + + torch.einsum("lor,lri->lio", mod.B_by_l, mod.A_by_l).detach() + * mod.scaling + ) + torch.testing.assert_close( + merged[prefix + "weight"], expected, atol=1e-6, rtol=1e-5 + ) + elif isinstance(mod, LoRASO2): + expected_m0 = ( + mod.weight_m0.detach() + + torch.einsum("ri,or->io", mod.A_m0, mod.B_m0).detach() + * mod.scaling + ) + torch.testing.assert_close( + merged[prefix + "weight_m0"], expected_m0, atol=1e-6, rtol=1e-5 + ) + for m_idx, w in enumerate(mod.weight_m): + expected_m = ( + w.detach() + + torch.einsum( + "ri,or->io", mod.A_m[m_idx], mod.B_m[m_idx] + ).detach() + * mod.scaling + ) + torch.testing.assert_close( + merged[prefix + f"weight_m.{m_idx}"], + expected_m, + atol=1e-6, + rtol=1e-5, + ) + + +class TestSeZMModelLoRACompile(unittest.TestCase): + """LoRA + ``torch.compile`` end-to-end consistency test. + + Runs the SeZM ``ener`` path with ``use_compile=True`` against the eager + reference on the same LoRA-injected model (randomized ``B`` so the LoRA + delta is non-trivial) and checks forward / first-order / second-order + consistency, mirroring the style of + :meth:`TestSeZMModelCompile.test_forward_backward_double_backward_matches_compile`. + """ + + def setUp(self) -> None: + self.device = env.DEVICE + torch.manual_seed(2024) + + def _tiny_system(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build a compact two-frame, three-atom system for LoRA compile tests.""" + coord = torch.tensor( + [ + [[0.0, 0.0, 0.0], [1.1, 0.3, 0.0], [0.2, 1.5, 0.4]], + [[0.1, 0.2, 0.3], [0.9, 1.0, 0.1], [2.0, 0.5, 1.2]], + ], + dtype=torch.float32, + device=self.device, + ) + atype = torch.tensor( + [[0, 1, 0], [1, 0, 1]], dtype=torch.int32, device=self.device + ) + box = torch.tensor( + [ + [10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0], + [10.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 10.0], + ], + dtype=torch.float32, + device=self.device, + ) + return coord, atype, box + + @staticmethod + def _build_matched_lora_models() -> tuple[SeZMModel, SeZMModel]: + """Build eager + compile SeZM twins that share LoRA-augmented weights.""" + params_eager = _build_lora_sezm_model_params(use_compile=False) + model_eager = get_sezm_model(params_eager) + apply_lora_to_sezm(model_eager, rank=2, alpha=4.0) + # Randomize every LoRA B so the LoRA delta is non-trivial across both + # branches; randomize A similarly so the low-rank term has full rank. + for mod in model_eager.modules(): + if isinstance(mod, LoRASO3): + torch.nn.init.normal_(mod.A_by_l, std=0.05) + torch.nn.init.normal_(mod.B_by_l, std=0.05) + elif isinstance(mod, LoRASO2): + torch.nn.init.normal_(mod.A_m0, std=0.05) + torch.nn.init.normal_(mod.B_m0, std=0.05) + for a, b in zip(mod.A_m, mod.B_m, strict=True): + torch.nn.init.normal_(a, std=0.05) + torch.nn.init.normal_(b, std=0.05) + + params_compile = _build_lora_sezm_model_params(use_compile=True) + model_compile = get_sezm_model(params_compile) + apply_lora_to_sezm(model_compile, rank=2, alpha=4.0) + # After injection both models share the same named-parameter layout; + # copying the eager state_dict also copies the randomized LoRA A/B. + model_compile.load_state_dict(model_eager.state_dict()) + return model_eager, model_compile + + def test_forward_and_backward_match_eager(self) -> None: + """Forward / first-order / second-order outputs agree with eager.""" + coord, atype, box = self._tiny_system() + model_eager, model_compile = self._build_matched_lora_models() + model_eager.train() + model_compile.train() + + # === Forward === + out_eager = model_eager(coord, atype, box=box) + out_compile = model_compile(coord, atype, box=box) + _assert_close_with_strict_warning( + out_eager["energy"], + out_compile["energy"], + atol=1.0e-6, + rtol=1.0e-6, + msg="LoRA energy mismatch", + ) + _assert_close_with_strict_warning( + out_eager["force"], + out_compile["force"], + atol=2.0e-4, + rtol=1.0e-5, + msg="LoRA force mismatch", + ) + + # === First-order backward (d energy / d params) === + model_eager.zero_grad(set_to_none=True) + model_compile.zero_grad(set_to_none=True) + out_eager["energy"].sum().backward() + out_compile["energy"].sum().backward() + grad_atol = 1.0e-5 if self.device == torch.device("cpu") else 2.0e-3 + grad_rtol = 1.0e-5 if self.device == torch.device("cpu") else 1.0e-4 + force_grad_atol = 1.0e-2 + force_grad_rtol = 1.0e-4 + grads_eager = { + name: ( + torch.zeros_like(param) + if param.grad is None + else param.grad.detach().clone() + ) + for name, param in model_eager.named_parameters() + } + grads_compile = { + name: ( + torch.zeros_like(param) + if param.grad is None + else param.grad.detach().clone() + ) + for name, param in model_compile.named_parameters() + } + self.assertEqual(set(grads_eager.keys()), set(grads_compile.keys())) + for name in grads_eager.keys(): + _assert_close_with_strict_warning( + grads_eager[name], + grads_compile[name], + atol=grad_atol, + rtol=grad_rtol, + msg=f"energy-grad mismatch at {name}", + ) + + # === Second-order backward via force loss (d force^2 / d params) === + model_eager.zero_grad(set_to_none=True) + model_compile.zero_grad(set_to_none=True) + out_eager = model_eager(coord, atype, box=box) + out_compile = model_compile(coord, atype, box=box) + torch.sum(out_eager["force"] * out_eager["force"]).backward() + torch.sum(out_compile["force"] * out_compile["force"]).backward() + grads_eager_2 = { + name: ( + torch.zeros_like(param) + if param.grad is None + else param.grad.detach().clone() + ) + for name, param in model_eager.named_parameters() + } + grads_compile_2 = { + name: ( + torch.zeros_like(param) + if param.grad is None + else param.grad.detach().clone() + ) + for name, param in model_compile.named_parameters() + } + for name in grads_eager_2.keys(): + _assert_close_with_strict_warning( + grads_eager_2[name], + grads_compile_2[name], + atol=force_grad_atol, + rtol=force_grad_rtol, + msg=f"force-grad-sq mismatch at {name}", + ) diff --git a/source/tests/pt/model/test_sezm_spin_model.py b/source/tests/pt/model/test_sezm_spin_model.py new file mode 100644 index 0000000000..f293d65bef --- /dev/null +++ b/source/tests/pt/model/test_sezm_spin_model.py @@ -0,0 +1,412 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import json +import os +import tempfile +import unittest +import warnings +from unittest import ( + mock, +) + +import torch + +from deepmd.pt.loss import ( + EnergySpinLoss, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.model.model.sezm_spin_model import ( + SeZMSpinModel, +) +from deepmd.pt.train.training import ( + prepare_model_for_loss, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) +from deepmd.pt.utils.serialization import ( + deserialize_to_file, +) + +warnings.filterwarnings( + # Keep the compile-test warning summary focused on strict-tolerance drift. + # PyTorch's AOTAutograd cache emits an internal Python 3.14 deprecation + # warning that is unrelated to SeZM numerical correctness. + "ignore", + category=DeprecationWarning, + module=r"torch\._functorch\._aot_autograd\.autograd_cache", +) + + +def _assert_close_with_strict_warning( + actual: torch.Tensor, + expected: torch.Tensor, + *, + strict_atol: float = 1.0e-6, + strict_rtol: float = 1.0e-6, + atol: float, + rtol: float, + msg: str, +) -> None: + """Warn on strict compile drift, fail only outside relaxed tolerance.""" + try: + torch.testing.assert_close( + actual, + expected, + atol=strict_atol, + rtol=strict_rtol, + msg=msg, + ) + except AssertionError as err: + warnings.warn( + f"{msg} exceeds strict tolerance " + f"(atol={strict_atol:g}, rtol={strict_rtol:g}) but is checked " + f"against relaxed tolerance (atol={atol:g}, rtol={rtol:g}): {err}", + RuntimeWarning, + stacklevel=2, + ) + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol, msg=msg) + + +def reduce_tensor( + extended_tensor: torch.Tensor, + mapping: torch.Tensor, + nloc: int, +) -> torch.Tensor: + """Reduce an extended tensor back to local atoms.""" + nframes = extended_tensor.shape[0] + ext_dims = extended_tensor.shape[2:] + reduced_tensor = torch.zeros( + [nframes, nloc, *ext_dims], + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + mldims = list(mapping.shape) + mapping = mapping.view(mldims + [1] * len(ext_dims)).expand( + [-1] * len(mldims) + list(ext_dims) + ) + return torch.scatter_reduce( + reduced_tensor, + 1, + index=mapping, + src=extended_tensor, + reduce="sum", + ) + + +class TestSeZMSpinModel(unittest.TestCase): + """Test spin support for the SeZM PyTorch model.""" + + def setUp(self) -> None: + self.device = env.DEVICE + torch.manual_seed(2024) + self.coord = torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + ] + ], + dtype=torch.float64, + device=self.device, + ) + self.atype = torch.tensor([[0, 1, 0]], dtype=torch.long, device=self.device) + self.spin = torch.tensor( + [ + [ + [0.20, 0.10, 0.00], + [0.30, 0.00, 0.10], + [0.10, 0.20, 0.10], + ] + ], + dtype=torch.float64, + device=self.device, + ) + self.box = torch.eye(3, dtype=torch.float64, device=self.device).reshape(1, 9) + self.box = self.box * 6.0 + + def _build_model_params( + self, + *, + use_compile: bool = False, + bridging_method: str = "none", + ) -> dict: + """Build a minimal deterministic SeZM spin model config.""" + return { + "type": "SeZM", + "type_map": ["O", "H"], + "spin": { + "use_spin": [True, False], + "virtual_scale": 0.2, + }, + "descriptor": { + "type": "SeZM", + "sel": [2, 2], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": True, + "random_gamma": False, + "l_schedule": [1, 0], + "mmax": 1, + "so2_norm": False, + "so2_layers": 1, + "n_atten_head": 1, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 8, + "ffn_blocks": 1, + "s2_activation": [False, True], + "mlp_bias": False, + "layer_scale": False, + "use_amp": False, + "activation_function": "silu", + "glu_activation": True, + "precision": "float32", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float32", + "seed": 7, + }, + "bridging_method": bridging_method, + "bridging_r_inner": 0.8, + "bridging_r_outer": 1.2, + "use_compile": use_compile, + } + + def test_factory_shapes_and_masks(self) -> None: + """Factory should build SeZMSpinModel with public real-type metadata.""" + model = get_model(self._build_model_params()).to(self.device) + + self.assertIsInstance(model, SeZMSpinModel) + self.assertTrue(model.has_spin()) + self.assertEqual(model.get_type_map(), ["O", "H"]) + self.assertEqual(model.get_sel(), [2, 2]) + + out = model(self.coord, self.atype, spin=self.spin, box=self.box) + + self.assertEqual(out["energy"].shape, (1, 1)) + self.assertEqual(out["atom_energy"].shape, (1, 3, 1)) + self.assertEqual(out["force"].shape, (1, 3, 3)) + self.assertEqual(out["force_mag"].shape, (1, 3, 3)) + torch.testing.assert_close( + out["mask_mag"], + torch.tensor( + [[[True], [False], [True]]], + dtype=torch.bool, + device=self.device, + ), + ) + + def test_forward_lower_matches_forward(self) -> None: + """Lower spin interface should match the standard spin forward path.""" + model = get_model(self._build_model_params()).to(self.device) + out = model(self.coord, self.atype, spin=self.spin, box=self.box) + extended_coord, extended_atype, mapping, nlist = ( + extend_input_and_build_neighbor_list( + self.coord, + self.atype, + model.get_rcut(), + model.get_sel(), + mixed_types=model.mixed_types(), + box=self.box, + ) + ) + extended_spin = torch.gather( + self.spin, + 1, + mapping.unsqueeze(-1).expand(-1, -1, 3), + ) + + out_lower = model.forward_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping=mapping, + ) + + torch.testing.assert_close(out_lower["energy"], out["energy"]) + torch.testing.assert_close(out_lower["atom_energy"], out["atom_energy"]) + reduced_force = reduce_tensor(out_lower["extended_force"], mapping, nloc=3) + reduced_force_mag = reduce_tensor( + out_lower["extended_force_mag"], mapping, nloc=3 + ) + torch.testing.assert_close(reduced_force, out["force"]) + torch.testing.assert_close(reduced_force_mag, out["force_mag"]) + + def test_serialize_deserialize_consistency(self) -> None: + """Serialized SeZMSpinModel should restore the same predictions.""" + model = get_model(self._build_model_params()).to(self.device) + restored = SeZMSpinModel.deserialize(model.serialize()).to(self.device) + + out = model(self.coord, self.atype, spin=self.spin, box=self.box) + restored_out = restored(self.coord, self.atype, spin=self.spin, box=self.box) + + self.assertEqual(restored.get_type_map(), ["O", "H"]) + self.assertEqual(restored.get_sel(), [2, 2]) + for key, value in out.items(): + torch.testing.assert_close(restored_out[key], value) + + def test_deserialize_to_file_uses_spin_model(self) -> None: + """File deserialization should route sezm_spin through SeZMSpinModel.""" + model = get_model(self._build_model_params()).to(self.device) + data = { + "model": model.serialize(), + "model_def_script": self._build_model_params(), + "@variables": {}, + } + + with ( + tempfile.TemporaryDirectory() as tmpdir, + mock.patch( + "deepmd.pt.utils.serialization.torch.jit.script", + side_effect=lambda model: model, + ), + mock.patch("deepmd.pt.utils.serialization.torch.jit.save") as save_mock, + ): + deserialize_to_file(f"{tmpdir}/model.pth", data) + + saved_model = save_mock.call_args.args[0] + self.assertIsInstance(saved_model, SeZMSpinModel) + self.assertEqual( + saved_model.model_def_script, + json.dumps(data["model_def_script"]), + ) + + def test_energy_spin_loss_consumes_force_mag(self) -> None: + """EnergySpinLoss should consume force and magnetic-force predictions.""" + model = get_model(self._build_model_params()).to(self.device) + loss = EnergySpinLoss( + start_pref_e=1.0, + limit_pref_e=1.0, + start_pref_fr=1.0, + limit_pref_fr=1.0, + start_pref_fm=1.0, + limit_pref_fm=1.0, + ) + input_dict = { + "coord": self.coord, + "atype": self.atype, + "spin": self.spin, + "box": self.box, + } + label = { + "energy": torch.zeros((1, 1), dtype=torch.float64, device=self.device), + "force": torch.zeros((1, 3, 3), dtype=torch.float64, device=self.device), + "force_mag": torch.zeros( + (1, 3, 3), dtype=torch.float64, device=self.device + ), + "find_energy": torch.tensor(1.0, device=self.device), + "find_force": torch.tensor(1.0, device=self.device), + "find_force_mag": torch.tensor(1.0, device=self.device), + } + + model_pred, loss_value, more_loss = loss( + input_dict, + model, + label, + natoms=3, + learning_rate=1.0, + ) + + self.assertIn("force_mag", model_pred) + self.assertIn("rmse_fm", more_loss) + self.assertTrue(torch.isfinite(loss_value)) + + def test_dens_mode_is_rejected(self) -> None: + """SeZM spin permanently rejects the dens path.""" + model = get_model(self._build_model_params()).to(self.device) + + with self.assertRaises(NotImplementedError): + prepare_model_for_loss(model, {"type": "dens"}) + + def test_bridging_masks_virtual_pairs(self) -> None: + """ZBL bridging should ignore virtual spin types without indexing them.""" + model = get_model(self._build_model_params(bridging_method="ZBL")).to( + self.device + ) + self.assertIsNotNone(model.inter_potential) + + coord = torch.tensor( + [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.0, 0.0]]], + dtype=torch.float64, + device=self.device, + ) + atype_with_virtual = torch.tensor( + [[0, 1, 2]], dtype=torch.long, device=self.device + ) + nlist_real_and_virtual = torch.tensor( + [[[1, 2], [0, 2], [0, 1]]], dtype=torch.long, device=self.device + ) + nlist_real_only = torch.tensor( + [[[1, -1], [0, -1], [-1, -1]]], dtype=torch.long, device=self.device + ) + + energy_with_virtual = model.inter_potential( + coord, + atype_with_virtual, + nlist_real_and_virtual, + nloc=3, + real_type_count=2, + ) + energy_real_only = model.inter_potential( + coord, + atype_with_virtual, + nlist_real_only, + nloc=3, + real_type_count=2, + ) + + torch.testing.assert_close(energy_with_virtual, energy_real_only) + + def test_compile_matches_eager(self) -> None: + """Compiled SeZM spin path should match eager predictions.""" + eager = get_model(self._build_model_params(use_compile=False)).to(self.device) + with mock.patch.dict(os.environ, {"DP_COMPILE_INFER": "1"}, clear=False): + compiled = get_model(self._build_model_params(use_compile=True)).to( + self.device + ) + compiled.load_state_dict(copy.deepcopy(eager.state_dict())) + eager.eval() + compiled.eval() + + out_eager = eager(self.coord, self.atype, spin=self.spin, box=self.box) + out_compiled = compiled(self.coord, self.atype, spin=self.spin, box=self.box) + + self.assertIn((False, False, True), compiled.compiled_core_compute_cache) + _assert_close_with_strict_warning( + out_compiled["energy"], + out_eager["energy"], + atol=1.0e-6, + rtol=1.0e-6, + msg="spin compile energy mismatch", + ) + _assert_close_with_strict_warning( + out_compiled["force"], + out_eager["force"], + atol=1.0e-6, + rtol=1.0e-6, + msg="spin compile force mismatch", + ) + _assert_close_with_strict_warning( + out_compiled["force_mag"], + out_eager["force_mag"], + atol=1.0e-6, + rtol=1.0e-6, + msg="spin compile magnetic force mismatch", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/requirements.txt b/source/tests/pt/requirements.txt index 74abad719e..c662f1ed52 100644 --- a/source/tests/pt/requirements.txt +++ b/source/tests/pt/requirements.txt @@ -4,3 +4,4 @@ dpdata ase coverage pytest +e3nn diff --git a/source/tests/pt/test_train_utils.py b/source/tests/pt/test_train_utils.py new file mode 100644 index 0000000000..f279726de5 --- /dev/null +++ b/source/tests/pt/test_train_utils.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from unittest.mock import ( + patch, +) + +import torch + +from deepmd.pt.train.utils import ( + clip_grad_norm_with_stable_fallback, +) + + +class TestStableGradClip(unittest.TestCase): + def test_fsdp_path_finite_grads(self) -> None: + p = torch.nn.Parameter(torch.zeros(1, device="cpu")) + p.grad = torch.tensor([2.0], device="cpu") + norm = clip_grad_norm_with_stable_fallback( + [p], + max_norm=1.0, + use_stable_fallback=False, + named_parameters=lambda: [("p", p)], + ) + + self.assertTrue(torch.isfinite(norm)) + self.assertAlmostEqual(p.grad.item(), 1.0, places=5) + + def test_fsdp_path_nonfinite_raises(self) -> None: + p = torch.nn.Parameter(torch.zeros(1, device="cpu")) + p.grad = torch.tensor([float("nan")], device="cpu") + + with self.assertRaisesRegex(RuntimeError, "p:"): + clip_grad_norm_with_stable_fallback( + [p], + max_norm=1.0, + use_stable_fallback=False, + named_parameters=lambda: [("p", p)], + ) + + def test_stable_fallback_nan_individual_grad_raises(self) -> None: + p0 = torch.nn.Parameter(torch.zeros(1, device="cpu")) + p1 = torch.nn.Parameter(torch.zeros(1, device="cpu")) + p0.grad = torch.tensor([float("nan")], device="cpu") + p1.grad = torch.tensor([1.0], device="cpu") + + with self.assertRaisesRegex(RuntimeError, "p0:"): + clip_grad_norm_with_stable_fallback( + [p0, p1], + max_norm=1.0, + named_parameters=lambda: [("p0", p0), ("p1", p1)], + ) + + def test_healthy_path_no_overflow(self) -> None: + p = torch.nn.Parameter(torch.zeros(1, device="cpu")) + p.grad = torch.tensor([0.5], device="cpu") + norm = clip_grad_norm_with_stable_fallback( + [p], + max_norm=1.0, + named_parameters=lambda: [("p", p)], + ) + + self.assertTrue(torch.isfinite(norm)) + self.assertAlmostEqual(p.grad.item(), 0.5, places=5) + + def test_empty_parameters(self) -> None: + norm = clip_grad_norm_with_stable_fallback([], max_norm=1.0) + + self.assertEqual(norm.item(), 0.0) + + def test_fallback_clips_large_finite_gradients(self) -> None: + p0, p1 = self._make_large_grad_parameters() + + with patch( + "torch.nn.utils.clip_grad_norm_", + side_effect=RuntimeError("non-finite total norm"), + ): + total_norm = clip_grad_norm_with_stable_fallback( + [p0, p1], + max_norm=3.0, + named_parameters=lambda: [("p0", p0), ("p1", p1)], + ) + + self._check_clipped_norm(total_norm, p0, p1) + + def test_real_overflow_path_uses_stable_fallback(self) -> None: + p0, p1 = self._make_large_grad_parameters() + + total_norm = clip_grad_norm_with_stable_fallback( + [p0, p1], + max_norm=3.0, + named_parameters=lambda: [("p0", p0), ("p1", p1)], + ) + + self._check_clipped_norm(total_norm, p0, p1) + + def _make_large_grad_parameters( + self, + ) -> tuple[torch.nn.Parameter, torch.nn.Parameter]: + p0 = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32, device="cpu")) + p1 = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32, device="cpu")) + p0.grad = torch.tensor([torch.finfo(torch.float32).max / 2], device="cpu") + p1.grad = torch.tensor([torch.finfo(torch.float32).max / 2], device="cpu") + return p0, p1 + + def _check_clipped_norm( + self, + total_norm: torch.Tensor, + p0: torch.nn.Parameter, + p1: torch.nn.Parameter, + ) -> None: + clipped_norm = torch.linalg.vector_norm( + torch.stack([p0.grad.double().norm(), p1.grad.double().norm()]) + ) + self.assertTrue(torch.isfinite(total_norm)) + self.assertEqual(total_norm.dtype, torch.float64) + self.assertAlmostEqual(clipped_norm.item(), 3.0, places=5) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index c4e58c0368..5a10b559ce 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -6,6 +6,7 @@ import signal import tempfile import unittest +import warnings from collections.abc import ( Callable, ) @@ -38,6 +39,7 @@ get_finetune_rules, ) from deepmd.pt.utils.multi_task import ( + _cascade_top_level_defaults, preprocess_shared_params, ) from deepmd.utils.argcheck import ( @@ -1009,6 +1011,81 @@ def test_full_validation_rejects_multitask(self) -> None: normalize(config, multi_task=True) +class TestMultiTaskUtils(unittest.TestCase): + def test_cascade_top_level_defaults(self) -> None: + cfg = {"foo": 1, "model_dict": {"a": {}, "b": {"foo": 2}}} + _cascade_top_level_defaults(cfg) + + self.assertEqual(cfg["model_dict"]["a"]["foo"], 1) + self.assertEqual(cfg["model_dict"]["b"]["foo"], 2) + self.assertNotIn("foo", cfg) + + def test_cascade_keeps_reserved_top_level_keys(self) -> None: + cfg = {"shared_dict": {"x": 1}, "model_dict": {"a": {}}} + _cascade_top_level_defaults(cfg) + + self.assertIn("shared_dict", cfg) + self.assertNotIn("shared_dict", cfg["model_dict"]["a"]) + + def test_cascade_deepcopy_independence(self) -> None: + cfg = {"foo": [1, 2], "model_dict": {"a": {}, "b": {}}} + _cascade_top_level_defaults(cfg) + cfg["model_dict"]["a"]["foo"].append(99) + + self.assertEqual(cfg["model_dict"]["b"]["foo"], [1, 2]) + + +class TestSkippedTrainingBatch(unittest.TestCase): + def setUp(self) -> None: + self._cwd = os.getcwd() + self._tmpdir = tempfile.TemporaryDirectory() + os.chdir(self._tmpdir.name) + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config = convert_optimizer_v31_to_v32(self.config, warning=False) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.config["training"]["numb_steps"] = 2 + self.config["training"]["save_freq"] = 2 + self.config["training"]["disp_training"] = False + self.config["validating"] = { + "full_validation": False, + "ema_full_validation": False, + } + + def tearDown(self) -> None: + os.chdir(self._cwd) + self._tmpdir.cleanup() + + def test_skipped_batch_does_not_advance_scheduler(self) -> None: + trainer = get_trainer(deepcopy(self.config)) + original_get_data = trainer.get_data + skipped = {"done": False} + + def get_data( + is_train: bool = True, task_key: str = "Default" + ) -> tuple[dict, dict, dict]: + if is_train and not skipped["done"]: + skipped["done"] = True + return {}, {}, {} + return original_get_data(is_train=is_train, task_key=task_key) + + trainer.get_data = get_data + with warnings.catch_warnings(): + warnings.filterwarnings( + "error", + message=r"Detected call of `lr_scheduler\.step\(\)` before `optimizer\.step\(\)`.*", + category=UserWarning, + ) + trainer.run() + + self.assertTrue(skipped["done"]) + self.assertEqual(trainer.scheduler.last_epoch, 1) + + class TestEMATraining(unittest.TestCase): def setUp(self) -> None: import deepmd.pt.train.training as training_module