[Common] NVTETensor peer-handle annotation + nccl_comm backend #3017
phu0ngng wants to merge 3 commits into
…ewton_schulz on NVTE_WITH_CUSOLVERMP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Greptile Summary
This PR introduces a generic peer-handle annotation on […]
Confidence Score: 4/5
Mostly safe to merge, but […]
The new comm-handle machinery is well-structured and null-safe, and the Newton-Schulz build gate is correctly applied across all three affected files. The one concrete defect is that transformer_engine/common/common.h […]
Important Files Changed
// behind NVTE_WITH_CUSOLVERMP in the common CMakeLists. Without the gate, the
// pytorch ext glob would pick this file up and produce undefined symbols
// (nvte_cusolvermp_ctx_*). Keep this gate aligned with common/.
#ifdef NVTE_WITH_CUSOLVERMP
Incomplete guard — pybind.cpp and extensions.h still reference ungated symbols
The #ifdef NVTE_WITH_CUSOLVERMP here makes newton_schulz.cpp compile to an empty TU when the flag is unset, but pybind.cpp (lines 592–601) and extensions.h (lines 634–639) still reference transformer_engine::pytorch::cusolvermp_ctx_create, cusolvermp_ctx_destroy, and newton_schulz unconditionally. Those symbols will be undefined at link time, producing the same linker failure the PR aims to fix. The guard needs to be mirrored in both pybind.cpp and extensions.h.
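A mismatch like this can be caught before link time with a small source check. The sketch below is a hypothetical helper and not part of the PR: it scans C++ source text line by line and reports references to the three symbols named in this comment that sit outside an `#ifdef NVTE_WITH_CUSOLVERMP` block. The function name `ungated_references` and the purely line-based matching are assumptions for illustration; it does not skip comments or string literals.

```python
import re

GATE = "NVTE_WITH_CUSOLVERMP"
# Symbols the review comment lists as referenced unconditionally.
GATED_SYMBOLS = ("cusolvermp_ctx_create", "cusolvermp_ctx_destroy", "newton_schulz")


def ungated_references(source, gate=GATE, symbols=GATED_SYMBOLS):
    """Return (line_number, symbol) pairs found outside `#ifdef <gate>` blocks."""
    symbol_re = re.compile(r"\b(" + "|".join(map(re.escape, symbols)) + r")\b")
    # Matches `#ifdef GATE` and `#if defined(GATE)`, but not `#ifndef GATE`.
    gate_re = re.compile(r"#if(def\s+|\s+defined\s*\(?\s*)" + re.escape(gate) + r"\b")
    stack = []  # one entry per open conditional; True while inside the gated branch
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        stripped = line.strip()
        if stripped.startswith("#if"):
            stack.append(bool(gate_re.match(stripped)))
        elif stripped.startswith(("#else", "#elif")):
            if stack:
                stack[-1] = False  # the alternative branch is not gated
        elif stripped.startswith("#endif"):
            if stack:
                stack.pop()
        elif not any(stack):
            hits.extend((lineno, m.group(1)) for m in symbol_re.finditer(line))
    return hits
```

Run over pybind.cpp and extensions.h, an empty result would indicate that every reference is behind the same gate as newton_schulz.cpp.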
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…SOLVERMP Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
a573c9a to 4bd2f5b
# Mirror the cuSOLVERMp gate. newton_schulz.cpp is conditionally compiled
# in the common lib; the pytorch ext glob pulls the same source so it must
# see the same define, otherwise the pybind layer refers to undefined
# cusolvermp_ctx_* symbols.
if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))):
    cxx_flags.append("-DNVTE_WITH_CUSOLVERMP")
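The gate in this hunk can be read as a small pure function, which makes its edge cases easier to see. The sketch below is illustrative only: the helper names `cusolvermp_enabled` and `extension_cxx_flags` and the injectable `env` mapping are assumptions, not code from build_tools/pytorch.py. It keeps the same `bool(int(...))` parsing as the diff, so an unset variable or "0" disables the define, while a non-integer value such as "true" raises ValueError.

```python
import os


def cusolvermp_enabled(env=None):
    """Same parsing as the diff: unset or "0" disables, any nonzero integer enables."""
    env = os.environ if env is None else env
    return bool(int(env.get("NVTE_WITH_CUSOLVERMP", "0")))


def extension_cxx_flags(base_flags, env=None):
    """Return a copy of base_flags, adding the define only when the gate is on."""
    flags = list(base_flags)
    if cusolvermp_enabled(env):
        flags.append("-DNVTE_WITH_CUSOLVERMP")
    return flags
```

Passing the environment as an argument is a choice made here for testability; the hunk itself reads the process environment directly through os.getenv.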
What is the error you are seeing without that? At least from a cursory look at those sources, I don't see how the pytorch files would be affected here. Are the comments implying that the PyTorch compilation is somehow taking the common files and compiling them again?
OK, a general question: this looks basically incompatible with everything else in NVTETensor, so why do we want to use NVTETensor at all? If you just need a C struct that holds the pointer and offset, then we should make that separate. It is hard to review this without seeing what the concrete implementation would actually use this mechanism for.

Closing this PR in favor of having a separate data structure for the NCCL Window instead of attaching it to the NVTETensor.
Description
- […] NVTETensor so consumers can issue one-sided RMA against a tensor's storage without owning the underlying resource.
- […] nccl_comm.h).
- […] newton_schulz on NVTE_WITH_CUSOLVERMP to unblock builds without cuSOLVERMp.
Type of change
Changes
- common/include/transformer_engine/comm_handle.h: payload-agnostic C API: kind tag (NVTEPeerHandleKind), nvte_tensor_peer_handle_kind, nvte_tensor_detach_peer_handle. Reserves slots for future backends (NVSHMEM, CUDA-IPC, UCX).
- common/include/transformer_engine/nccl_comm.h: NCCL-specific setter/getter: nvte_tensor_attach_nccl_window(t, window, offset) / nvte_tensor_nccl_window(t, &window, &offset). Window is borrowed; tensor never owns it.
- common/common.h: extend Tensor with peer_handle_kind / peer_handle_data / peer_handle_offset fields (ignored by paths that don't issue cross-rank ops).
- common/comm_handle.cpp: implementations.
- common/CMakeLists.txt: build the new comm_handle.cpp.
- pytorch/csrc/extensions/newton_schulz.cpp + build_tools/pytorch.py: compile/include newton_schulz only when NVTE_WITH_CUSOLVERMP is set.
- […] 3rdparty/cudnn-frontend and minor doc tweak under examples/pytorch/comm_gemm_overlap/README.md.
Testing
- […] phuong/ep-2-common-cpp), which exercise the NCCL-window attach/detach path end-to-end.
- […] NVTE_CUDA_ARCHS="90;100" NVTE_FRAMEWORK=jax pip install --no-build-isolation -e .).
Checklist:
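The attach/query/detach semantics listed under Changes can be modeled in a few lines. This is a toy Python model under stated assumptions, not the C API itself: the enum values, the getter returning None for a tensor without an NCCL window, and the reset-to-zero behavior on detach are all guesses made for illustration. Only the three field names and the "borrowed, never owned" rule come from the description.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional


class PeerHandleKind(Enum):
    # Numeric values are hypothetical; the PR only names the kind tag.
    NONE = 0
    NCCL_WINDOW = 1


@dataclass
class Tensor:
    # Field names follow the PR description of common/common.h.
    peer_handle_kind: PeerHandleKind = PeerHandleKind.NONE
    peer_handle_data: Optional[Any] = None
    peer_handle_offset: int = 0


def attach_nccl_window(t, window, offset):
    """Annotate the tensor with a borrowed window; the caller keeps ownership."""
    t.peer_handle_kind = PeerHandleKind.NCCL_WINDOW
    t.peer_handle_data = window
    t.peer_handle_offset = offset


def nccl_window(t):
    """Return (window, offset), or None when no NCCL window is attached (assumed)."""
    if t.peer_handle_kind is not PeerHandleKind.NCCL_WINDOW:
        return None
    return t.peer_handle_data, t.peer_handle_offset


def detach_peer_handle(t):
    """Clear the annotation only; the window itself is left untouched."""
    t.peer_handle_kind = PeerHandleKind.NONE
    t.peer_handle_data = None
    t.peer_handle_offset = 0
```

In this model, detaching never releases the window, which mirrors the stated rule that the tensor never owns it; whoever registered the window remains responsible for its lifetime.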