-
Notifications
You must be signed in to change notification settings - Fork 726
[Common] NVTETensor peer-handle annotation + nccl_comm backend #3017
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| #include "transformer_engine/comm_handle.h" | ||
|
|
||
| #include "common.h" | ||
| #include "transformer_engine/nccl_comm.h" | ||
| #include "util/logging.h" | ||
|
|
||
| using transformer_engine::convertNVTETensor; | ||
|
|
||
| // Report which peer-handle kind is recorded on the tensor behind ``t``. | ||
| // When convertNVTETensor yields nullptr the answer is NVTE_PEER_HANDLE_NONE. | ||
| NVTEPeerHandleKind nvte_tensor_peer_handle_kind(const NVTETensor t) { | ||
| const auto* impl = convertNVTETensor(t); | ||
| if (impl == nullptr) { | ||
| return NVTE_PEER_HANDLE_NONE; | ||
| } | ||
| return impl->peer_handle_kind; | ||
| } | ||
|
|
||
| // Remove the peer-handle annotation from ``t``. A handle for which | ||
| // convertNVTETensor yields nullptr is ignored rather than reported. | ||
| void nvte_tensor_detach_peer_handle(NVTETensor t) { | ||
| auto* impl = convertNVTETensor(t); | ||
| if (impl != nullptr) { | ||
| // Kind, data pointer and offset are always reset together. | ||
| impl->peer_handle_kind = NVTE_PEER_HANDLE_NONE; | ||
| impl->peer_handle_data = nullptr; | ||
| impl->peer_handle_offset = 0; | ||
| } | ||
| } | ||
|
|
||
| // Annotate ``t`` with an NCCL window pointer and an offset into that window. | ||
| // The pointer is stored as-is; nothing here registers or releases the window. | ||
| // A handle for which convertNVTETensor yields nullptr fails the NVTE_CHECK. | ||
| // Passing ``window == nullptr`` clears the annotation instead of attaching. | ||
| void nvte_tensor_attach_nccl_window(NVTETensor t, void* window, uint64_t offset) { | ||
| auto* tensor = convertNVTETensor(t); | ||
| NVTE_CHECK(tensor != nullptr, "nvte_tensor_attach_nccl_window: invalid NVTETensor handle"); | ||
| if (window == nullptr) { | ||
| // Reuse the shared detach path so the cleared state is defined in one place. | ||
| nvte_tensor_detach_peer_handle(t); | ||
| return; | ||
| } | ||
| tensor->peer_handle_kind = NVTE_PEER_HANDLE_NCCL_WINDOW; | ||
| tensor->peer_handle_data = window; | ||
| tensor->peer_handle_offset = offset; | ||
| } | ||
|
|
||
| // Read back the NCCL window annotation of ``t``. Each non-NULL out-pointer is | ||
| // written; the values are (nullptr, 0) unless the tensor's peer-handle kind is | ||
| // NVTE_PEER_HANDLE_NCCL_WINDOW. | ||
| void nvte_tensor_nccl_window(const NVTETensor t, void** window, uint64_t* offset) { | ||
| const auto* impl = convertNVTETensor(t); | ||
| void* found_window = nullptr; | ||
| uint64_t found_offset = 0; | ||
| if (impl != nullptr && impl->peer_handle_kind == NVTE_PEER_HANDLE_NCCL_WINDOW) { | ||
| found_window = impl->peer_handle_data; | ||
| found_offset = impl->peer_handle_offset; | ||
| } | ||
| if (window != nullptr) { | ||
| *window = found_window; | ||
| } | ||
| if (offset != nullptr) { | ||
| *offset = found_offset; | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| /*! \file comm_handle.h | ||
| * \brief Generic peer-handle annotation on NVTETensor for one-sided RMA. | ||
| * | ||
| * The annotation is borrowed; the tensor never owns the underlying resource. | ||
| * Per-backend setters/getters live in dedicated headers (e.g. ``nccl_comm.h``). | ||
| */ | ||
|
|
||
| #ifndef TRANSFORMER_ENGINE_COMM_HANDLE_H_ | ||
| #define TRANSFORMER_ENGINE_COMM_HANDLE_H_ | ||
|
|
||
| #include "transformer_engine.h" | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| /*! \brief Comm backend that owns a tensor's peer handle. */ | ||
| typedef enum { | ||
| NVTE_PEER_HANDLE_NONE = 0, /*!< No peer handle is attached. */ | ||
| NVTE_PEER_HANDLE_NCCL_WINDOW = 1, /*!< Handle is an NCCL window (see ``nccl_comm.h``). */ | ||
| } NVTEPeerHandleKind; | ||
|
|
||
| /*! \brief Peer-handle kind attached to ``t``. | ||
| * | ||
| * \param[in] t Tensor to query. | ||
| * | ||
| * \return Kind of the attached peer handle, or NVTE_PEER_HANDLE_NONE when none is attached. | ||
| */ | ||
| NVTEPeerHandleKind nvte_tensor_peer_handle_kind(const NVTETensor t); | ||
|
|
||
| /*! \brief Clear any peer handle attached to ``t``. | ||
| * | ||
| * Only the annotation on the tensor is reset; the underlying resource is | ||
| * borrowed and is not released by this call. | ||
| * | ||
| * \param[in,out] t Tensor whose peer-handle annotation is cleared. | ||
| */ | ||
| void nvte_tensor_detach_peer_handle(NVTETensor t); | ||
|
|
||
| #ifdef __cplusplus | ||
| } | ||
| #endif | ||
|
|
||
| #endif // TRANSFORMER_ENGINE_COMM_HANDLE_H_ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| /************************************************************************* | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| /*! \file nccl_comm.h | ||
| * \brief Attach a registered NCCL symmetric-memory window to an NVTETensor. | ||
| * | ||
| * The window is caller-owned and must outlive the tensor; ``attach`` does | ||
| * not register or rendezvous it. | ||
| */ | ||
|
|
||
| #ifndef TRANSFORMER_ENGINE_NCCL_COMM_H_ | ||
| #define TRANSFORMER_ENGINE_NCCL_COMM_H_ | ||
|
|
||
| #include "comm_handle.h" | ||
| #include "transformer_engine.h" | ||
|
|
||
| #ifdef __cplusplus | ||
| extern "C" { | ||
| #endif | ||
|
|
||
| /*! \brief Attach an NCCL window + byte offset to ``t``. Pass ``window=NULL`` to detach. | ||
| * | ||
| * Detaching resets the peer-handle kind to NVTE_PEER_HANDLE_NONE and the | ||
| * stored offset to 0. An invalid ``t`` is rejected by an NVTE_CHECK. | ||
| * | ||
| * \param[in,out] t Tensor to annotate. | ||
| * \param[in] window Opaque ncclWindow_t (caller-owned), or NULL to clear. | ||
| * \param[in] offset Byte offset into the window where this tensor starts. | ||
| */ | ||
| void nvte_tensor_attach_nccl_window(NVTETensor t, void* window, uint64_t offset); | ||
|
|
||
| /*! \brief Read the NCCL window + offset attached to ``t``; yields (NULL, 0) when unset. | ||
| * Either out-pointer may be NULL to skip that field. | ||
| * | ||
| * \param[in] t Tensor to query. | ||
| * \param[out] window Receives the attached window pointer, or NULL when no NCCL window is attached. | ||
| * \param[out] offset Receives the attached byte offset, or 0 when no NCCL window is attached. | ||
| */ | ||
| void nvte_tensor_nccl_window(const NVTETensor t, void** window, uint64_t* offset); | ||
|
|
||
| #ifdef __cplusplus | ||
| } | ||
| #endif | ||
|
|
||
| #endif // TRANSFORMER_ENGINE_NCCL_COMM_H_ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,12 @@ | |
| * See LICENSE for license information. | ||
| ************************************************************************/ | ||
|
|
||
| // Conditionally compiled: the common Newton-Schulz/cuSOLVERMp impl is gated | ||
| // behind NVTE_WITH_CUSOLVERMP in the common CMakeLists. Without the gate, the | ||
| // pytorch ext glob would pick this file up and produce undefined symbols | ||
| // (nvte_cusolvermp_ctx_*). Keep this gate aligned with common/. | ||
| #ifdef NVTE_WITH_CUSOLVERMP | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The |
||
|
|
||
| #include "transformer_engine/newton_schulz.h" | ||
|
|
||
| #include "../extensions.h" | ||
|
|
@@ -38,3 +44,5 @@ void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t | |
| } | ||
|
|
||
| } // namespace transformer_engine::pytorch | ||
|
|
||
| #endif // NVTE_WITH_CUSOLVERMP | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the error you are seeing without that? At least from a cursory look at those sources, I don't see how the PyTorch files would be affected here. Is the comment implying that the PyTorch compilation is somehow taking the common files and compiling them again?