[PyTorch] Expert Parallelism: PyTorch wrapper + autograd ops with symm-mem zero-copy#3035
[PyTorch] Expert Parallelism: PyTorch wrapper + autograd ops with symm-mem zero-copy#3035phu0ngng wants to merge 2 commits into
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ributed tests/example Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Greptile SummaryThis PR adds the PyTorch-level binding for Expert Parallelism (EP): a public Python API (
Confidence Score: 3/5Two correctness bugs in the Python autograd layer and C++ bootstrap path need fixes before this lands in a production training run. The _EpDispatch.backward fallback allocates a zero-gradient tensor shaped (max_tokens_per_rank, H) when the correct shape is (recv_capacity_per_rank, H); any training path where the upstream gradient of recv_tokens is None will hit either a runtime NVTE_CHECK or silent wrong-sized communication. Separately, ep_initialize in the C++ extension creates an NCCL communicator and then calls nvte_ep_initialize; if the latter throws, the communicator is never stored and can never be destroyed. transformer_engine/pytorch/ep.py (_EpDispatch.backward zero-grad shape) and transformer_engine/pytorch/csrc/extensions/ep.cpp (ep_initialize NCCL comm lifetime) require the most attention before merge. Important Files Changed
Reviews (1): Last reviewed commit: "ep: PyTorch wrapper, autograd ops, symm-..." | Re-trigger Greptile |
| if g_recv_tokens is None: | ||
| g_recv_tokens = torch.zeros_like(grad_tokens) |
There was a problem hiding this comment.
Wrong fallback shape for
g_recv_tokens zero gradient
grad_tokens has shape (max_tokens_per_rank, H), but g_recv_tokens should have shape (recv_capacity_per_rank, H). These two dimensions are intentionally different (recv capacity is typically much larger). When g_recv_tokens is None — which happens whenever recv_tokens is detached or the loss graph does not flow back through it — the zero-gradient tensor passed to ep_dispatch_bwd has the wrong first dimension. The C++ layer then computes recv_pr = grad.numel() / H = max_tokens_per_rank, causing the NVTE_CHECK g_recv_topk_weights.numel() == recv_pr to fail (since recv_topk_weights has size recv_capacity_per_rank), or silently producing wrong results if the two values happen to coincide.
The correct shape for the zero-gradient fallback can be derived from ctx.recv_topk_weights (already stashed on ctx), whose length equals recv_capacity_per_rank: use torch.zeros(recv_w_tmpl.shape[0], grad_tokens.shape[-1], dtype=grad_tokens.dtype, device=grad_tokens.device) instead of torch.zeros_like(grad_tokens).
| /*max_recv_tokens_per_rank=*/static_cast<int>(max_recv_tokens_per_rank), | ||
| /*hidden_dim=*/static_cast<int>(hidden_dim), | ||
| /*max_num_sms=*/static_cast<int>(max_num_sms), | ||
| /*allow_handle_mem_reloc=*/allow_handle_mem_reloc ? 1 : 0, | ||
| }; | ||
|
|
||
| // Copy bytes into a typed ncclUniqueId so the ABI is unambiguous when | ||
| // passing it by value to ncclCommInitRank. | ||
| ncclUniqueId uid{}; | ||
| std::memcpy(uid.internal, unique_id_bytes.data(), kEpUniqueIdSize); | ||
| ncclComm_t ep_comm = nullptr; |
There was a problem hiding this comment.
NCCL communicator leaked on
nvte_ep_initialize failure
ep_comm is created by ncclCommInitRank and then passed to nvte_ep_initialize. If nvte_ep_initialize throws (e.g., the NCCL version or compute-capability check inside EPBackend::initialize fails), control jumps out before g_ep_nccl_comm = ep_comm is reached. The communicator is never stored and can never be destroyed, leaking the NCCL resource. Additionally, validate_config is invoked inside nvte_ep_initialize rather than before ncclCommInitRank, so validation failures surface only after the collective initialization has already succeeded on all participating ranks.
The simplest fix is to assign g_ep_nccl_comm = ep_comm immediately after the two NVTE_CHECKs and before calling nvte_ep_initialize, then rely on ep_finalize to destroy it on any subsequent error. Alternatively, wrap ep_comm in a RAII guard that calls ncclCommDestroy on scope exit unless it is explicitly released.
| @contextlib.contextmanager | ||
| def _zero_copy_scope(enabled: bool): | ||
| """Toggles whether per-step ops apply the symm-mem NCCL window annotation.""" | ||
| if enabled: | ||
| yield | ||
| return | ||
| tex.ep_set_zero_copy(False) | ||
| try: | ||
| yield | ||
| finally: | ||
| tex.ep_set_zero_copy(True) |
There was a problem hiding this comment.
_zero_copy_scope does not save/restore the previous flag value
When enabled=False, the manager unconditionally sets g_zero_copy_enabled=False on entry and g_zero_copy_enabled=True on exit. If two callers both use zero_copy=False concurrently (e.g., pipeline-parallel microbatches dispatched from separate Python threads) or if the context is nested, the inner scope's finally block prematurely re-enables zero-copy while the outer scope is still active. The outer scope's finally then sets True again, but between the inner finally and the outer finally the C++ layer sees True unexpectedly.
The fix is to capture the previous value before writing and restore it unconditionally: save old = tex.ep_get_zero_copy() (adding a corresponding getter), then tex.ep_set_zero_copy(old) in the finally block. At minimum, document the single-caller-at-a-time assumption prominently so pipeline-parallel users know to serialize.
| return ncclFloat32; | ||
| case kNVTEFloat16: | ||
| return ncclFloat16; | ||
| case kNVTEBFloat16: | ||
| return ncclBfloat16; | ||
| case kNVTEInt32: | ||
| return ncclInt32; | ||
| case kNVTEInt64: | ||
| return ncclInt64; | ||
| case kNVTEByte: | ||
| return ncclUint8; | ||
| case kNVTEFloat8E4M3: | ||
| return ncclFloat8e4m3; | ||
| case kNVTEFloat8E5M2: | ||
| return ncclFloat8e5m2; | ||
| default: | ||
| NVTE_ERROR("Unsupported NVTEDType for NCCL EP conversion: ", static_cast<int>(dtype)); | ||
| } | ||
| return ncclFloat32; // unreachable | ||
| } | ||
|
|
There was a problem hiding this comment.
Missing validation for
max_recv_tokens_per_rank > 0
validate_config checks max_tokens_per_rank, hidden_dim, num_experts, and ep_size, but omits max_recv_tokens_per_rank. The inline comment on the cfg.max_recv_tokens_per_rank assignment even notes "Must be > 0; NCCL EP errors out on 0", yet there is no NVTE_CHECK to surface a clear message. Passing 0 from Python would silently reach ncclEpCreateGroup and produce an opaque NCCL error. Add NVTE_CHECK(config.max_recv_tokens_per_rank > 0, ...) alongside the other positive-value checks.
Summary
Second PR in the TE Expert Parallelism (EP) series. Adds the PyTorch binding on top of the common C API (#3034): exposes EP dispatch/combine as
torch.librarycustom ops with autograd, and plumbs NCCL symmetric-memory windows through for the zero-copy path.Payload tensors allocated via
te.pytorch.symm_mem_alloctake the one-sided zero-copy path; anything else silently falls back to staged-copy, so the API is drop-in compatible with any allocator.Implementation
Public Python API (
transformer_engine/pytorch/ep.py)ep_bootstrap/ep_finalize— one-time per-process init + teardown (also auto-registered viaatexit). Rank 0 mints anncclUniqueId, broadcasts it onep_group, backend opens its ownncclComm_t. Requires
ep_group.size() >= 4.symm_mem_alloc(shape, dtype, ep_group)— allocate a per-rank tensor backed by NCCL symmetric memory, already rendezvoused onep_group.EpHandle— per-layer routing state; reuse across steps.ep_prepare/ep_dispatch/ep_combine— per-step ops; both dispatch and combine are autograd-aware and registered astorch.library.custom_op, so they compose withtorch.compilefullgraph capture andCUDA graphs.
C++ bindings (
transformer_engine/pytorch/csrc/extensions/ep.cpp)py::bytesforncclUniqueId, primitives for config) — no c10d ABI on the boundary.maybe_make_window()looks up each payload tensor'sNCCLSymmetricMemorywindow and returns anNVTECommWindowto the backend; non-symm-mem tensors get{nullptr, 0}and the backend picks staged-copy automatically.
tokens,recv_tokens,expert_out,grad) aren't symm-mem-backed. Routing-weight tensors stay silent (nice-to-have, not required). Suppress withNVTE_EP_SILENCE _NONSYMM_WARN=1.Build
build_tools/pytorch.pypropagates-DNVTE_WITH_NCCL_EPto the PyTorch extension. When NCCL EP is off, the extension still loads —nvte_ep_*come from the common stub and throw on first call.Testing
tests/pytorch/distributed/run_ep.py— 17-testunittestsuite:preparecorrectness, dispatch/combine identity (uniform + non-uniform), 3D input, VJPs,top_k=1all-to-one, alignment edge cases, CUDA graph capture (eager + zero-copy),
torch.compilefullgraph, bf16 autocast (eager + autograd), zero-copy autograd combine, symm-mem fallback, gradient checkpointing.tests/pytorch/distributed/run_test_ep.sh. Verified on 8×H200:Ran 17 tests in 19.8s … OKon every rank.examples/pytorch/ep/ep_moe.py— minimal end-to-end MoE forward+backward driver.Type of change
Checklist: