Skip to content
Open
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,13 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None:
decomposition_table=decomp_table,
)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)

# make_fx has already captured the graph, so drop the local references to
# the tracing inputs; the tensors are freed once no other reference remains.
del ext_coord, ext_atype, nlist, mapping
if fparam is not None:
del fparam
if aparam is not None:
del aparam

# make_fx inserts aten.detach.default for saved tensors used in the
# decomposed autograd.grad backward ops. These detach nodes break
# second-order gradient flow (d(force)/d(params) for force training).
Expand Down Expand Up @@ -316,12 +323,19 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None:
if compile_opts:
inductor_options.update(compile_opts)

return torch.compile(
compiled = torch.compile(
traced_lower,
backend="inductor",
dynamic=True,
options=inductor_options,
)
del traced_lower
model_uses_cuda = any(param.is_cuda for param in model.parameters()) or any(
buffer.is_cuda for buffer in model.buffers()
)
if model_uses_cuda and torch.cuda.is_available() and torch.cuda.is_initialized():
torch.cuda.empty_cache()
return compiled


class _CompiledModel(torch.nn.Module):
Expand Down Expand Up @@ -994,6 +1008,21 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None:
)

wrapper_mod.model[task_key] = _CompiledModel(model, compiled_lower)

# Drop this task's references to the intermediate tensors so they do not
# stay alive across tasks in multi-task scenarios; empty_cache() below can
# only return cached blocks that are no longer in use.
del ext_coord, ext_atype, mapping, nlist_t
del coord, atype, coord_3d, coord_norm
if box is not None:
del box, box_flat
if fparam is not None:
del fparam
if aparam is not None:
del aparam
del inp
Comment thread
anyangml marked this conversation as resolved.
Outdated
if DEVICE.type == "cuda" and torch.cuda.is_initialized():
torch.cuda.empty_cache()

log.info(
"Model compiled (task=%s, tracing_mode=symbolic, "
"dynamic=True, backend=inductor).",
Expand Down
Loading