Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 102 additions & 7 deletions deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,13 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None:
decomposition_table=decomp_table,
)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)

# make_fx has captured the graph; input tensors are no longer needed.
del ext_coord, ext_atype, nlist, mapping
if fparam is not None:
del fparam
if aparam is not None:
del aparam

# make_fx inserts aten.detach.default for saved tensors used in the
# decomposed autograd.grad backward ops. These detach nodes break
# second-order gradient flow (d(force)/d(params) for force training).
Expand Down Expand Up @@ -316,12 +323,30 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None:
if compile_opts:
inductor_options.update(compile_opts)

return torch.compile(
compiled = torch.compile(
traced_lower,
backend="inductor",
dynamic=True,
options=inductor_options,
)
# Keep the traced FX graph alive as long as the compiled callable.
# _remove_detach_nodes makes saved activations alias the graph's symbolic
# tensors; if the FX graph is GC'd, its SymInt shape objects lose their
# Python references while C++ view metadata still holds raw pointers to
# them — causing apply_view_meta_sequence to read garbage (crash at
# random training steps, earlier under higher GC pressure from many tasks).
# Use object.__setattr__ to bypass nn.Module.__setattr__: traced_lower is
# an nn.Module, and normal assignment would register it as a submodule of
# compiled (also an nn.Module), creating a cycle in the module tree that
# causes RecursionError in trainer.wrapper.train().
object.__setattr__(compiled, "_traced_lower_ref", traced_lower)
del traced_lower
model_uses_cuda = any(param.is_cuda for param in model.parameters()) or any(
buffer.is_cuda for buffer in model.buffers()
)
if model_uses_cuda and torch.cuda.is_available() and torch.cuda.is_initialized():
torch.cuda.empty_cache()
return compiled


class _CompiledModel(torch.nn.Module):
Expand Down Expand Up @@ -993,13 +1018,52 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None:
compile_opts,
)

# torch.compile is lazy: inductor only compiles on the first
# call. In DDP multi-task training, different ranks may first
# hit a task at different training steps, so one rank can block
# inside inductor for minutes while others spin in AllReduce —
# causing an NCCL timeout. Warmup here, while sample inputs
# still exist, forces eager compilation before training starts.
#
# Match _CompiledModel.forward which sets requires_grad_(True) on
# ext_coord: Dynamo's guard includes requires_grad, so a mismatch
# causes every task's first training call to miss the warmup cache.
ext_coord = ext_coord.detach().requires_grad_(True)
_warmup_out = compiled_lower(
ext_coord, ext_atype, nlist_t, mapping, fparam, aparam
)
del _warmup_out
if DEVICE.type == "cuda" and torch.cuda.is_initialized():
torch.cuda.synchronize()

wrapper_mod.model[task_key] = _CompiledModel(model, compiled_lower)

# Release all intermediate tensors built for this task so they don't
# accumulate across tasks in multi-task scenarios.
del ext_coord, ext_atype, mapping, nlist_t
del coord, atype, coord_3d, coord_norm
if box is not None:
del box, box_flat
if fparam is not None:
del fparam
if aparam is not None:
del aparam
del inp, _
if DEVICE.type == "cuda" and torch.cuda.is_initialized():
torch.cuda.empty_cache()

log.info(
"Model compiled (task=%s, tracing_mode=symbolic, "
"dynamic=True, backend=inductor).",
task_key,
)

# All tasks compiled on this rank — wait for all ranks before
# training starts so no rank enters the training loop while another
# is still blocked in inductor compilation.
if self.is_distributed:
dist.barrier()

# ------------------------------------------------------------------
# Data helpers
# ------------------------------------------------------------------
Expand Down Expand Up @@ -1176,7 +1240,9 @@ def run(self) -> None:
if self.rank == 0:
if not self.multi_task:
train_results = {
k: v for k, v in more_loss.items() if "l2_" not in k
k: (v.item() if isinstance(v, torch.Tensor) else v)
for k, v in more_loss.items()
if "l2_" not in k
}

# validation
Expand All @@ -1197,7 +1263,13 @@ def run(self) -> None:
for k, v in _vmore.items():
if "l2_" not in k:
valid_results[k] = (
valid_results.get(k, 0.0) + v * natoms
valid_results.get(k, 0.0)
+ (
v.item()
if isinstance(v, torch.Tensor)
else v
)
* natoms
)
if sum_natoms > 0:
valid_results = {
Expand All @@ -1210,23 +1282,38 @@ def run(self) -> None:

# current task already has loss
train_results[task_key] = {
k: v for k, v in more_loss.items() if "l2_" not in k
k: (v.item() if isinstance(v, torch.Tensor) else v)
for k, v in more_loss.items()
if "l2_" not in k
}

# compute loss for other tasks
for _key in self.model_keys:
if _key != task_key:
self.optimizer.zero_grad()
self.optimizer.zero_grad(set_to_none=True)
_inp, _lab = self.get_data(is_train=True, task_key=_key)
_, _loss, _more = self._unwrapped(
**_inp,
cur_lr=cur_lr_sched,
label=_lab,
task_key=_key,
)
# Use .item() so the backward graph (and its
# saved activations) can be freed immediately.
# Display passes never call loss.backward(), so
# without this the computation graphs for all
# tasks accumulate simultaneously in GPU memory.
train_results[_key] = {
k: v for k, v in _more.items() if "l2_" not in k
k: (v.item() if isinstance(v, torch.Tensor) else v)
for k, v in _more.items()
if "l2_" not in k
}
del _loss, _more, _inp, _lab
if (
torch.cuda.is_available()
and torch.cuda.is_initialized()
):
torch.cuda.empty_cache()

# validation for each task
_vdata = self.validation_data[_key]
Expand All @@ -1249,7 +1336,15 @@ def run(self) -> None:
_sum_natoms += natoms
for k, v in _vmore.items():
if "l2_" not in k:
_vres[k] = _vres.get(k, 0.0) + v * natoms
_vres[k] = (
_vres.get(k, 0.0)
+ (
v.item()
if isinstance(v, torch.Tensor)
else v
)
* natoms
)
if _sum_natoms > 0:
_vres = {
k: v / _sum_natoms for k, v in _vres.items()
Expand Down
Loading