Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def setup_pytorch_extension(

setup_mpi_flags(include_dirs, cxx_flags)

# Mirror the cuSOLVERMp gate. newton_schulz.cpp is conditionally compiled
# in the common lib; the pytorch ext glob pulls the same source so it must
# see the same define, otherwise the pybind layer refers to undefined
# cusolvermp_ctx_* symbols.
if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))):
cxx_flags.append("-DNVTE_WITH_CUSOLVERMP")

Comment on lines +79 to +85
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

? What is the error you are seeing without that? At least from a cursory look at those sources, I don't see how the pytorch files would be affected here. Is the comment implying that the PyTorch compilation is somehow taking the common files and compiling them again?

library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ set(transformer_engine_cuda_sources)
set(transformer_engine_cuda_arch_specific_sources)

list(APPEND transformer_engine_cpp_sources
comm_handle.cpp
cudnn_utils.cpp
transformer_engine.cpp
fused_attn/fused_attn.cpp
Expand Down
48 changes: 48 additions & 0 deletions transformer_engine/common/comm_handle.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "transformer_engine/comm_handle.h"

#include "common.h"
#include "transformer_engine/nccl_comm.h"
#include "util/logging.h"

using transformer_engine::convertNVTETensor;

/*! Report which peer-handle backend is recorded on ``t``.
 *
 *  Yields NVTE_PEER_HANDLE_NONE when convertNVTETensor() returns nullptr
 *  for the handle, otherwise the tensor's stored ``peer_handle_kind``.
 */
NVTEPeerHandleKind nvte_tensor_peer_handle_kind(const NVTETensor t) {
  const auto* impl = convertNVTETensor(t);
  if (impl == nullptr) {
    return NVTE_PEER_HANDLE_NONE;
  }
  return impl->peer_handle_kind;
}

/*! Reset the peer-handle annotation on ``t`` to its "nothing attached" state.
 *
 *  Does nothing when convertNVTETensor() returns nullptr for the handle.
 *  Only the annotation fields are reset; nothing is freed here.
 */
void nvte_tensor_detach_peer_handle(NVTETensor t) {
  auto* impl = convertNVTETensor(t);
  if (impl != nullptr) {
    impl->peer_handle_offset = 0;
    impl->peer_handle_data = nullptr;
    impl->peer_handle_kind = NVTE_PEER_HANDLE_NONE;
  }
}

/*! Record an NCCL window and byte offset on ``t``, or clear the annotation.
 *
 *  A nullptr ``window`` is treated as a detach request: the kind becomes
 *  NVTE_PEER_HANDLE_NONE and the stored offset is forced to 0 regardless of
 *  the ``offset`` argument. The handle must convert to a tensor, otherwise
 *  the NVTE_CHECK below fails.
 */
void nvte_tensor_attach_nccl_window(NVTETensor t, void* window, uint64_t offset) {
  auto* tensor = convertNVTETensor(t);
  NVTE_CHECK(tensor != nullptr, "nvte_tensor_attach_nccl_window: invalid NVTETensor handle");

  const bool attaching = (window != nullptr);
  tensor->peer_handle_kind = attaching ? NVTE_PEER_HANDLE_NCCL_WINDOW : NVTE_PEER_HANDLE_NONE;
  tensor->peer_handle_data = window;  // nullptr when detaching
  tensor->peer_handle_offset = attaching ? offset : 0;
}

/*! Fetch the NCCL window and byte offset recorded on ``t``.
 *
 *  The reported pair is (nullptr, 0) unless the handle converts to a tensor
 *  whose ``peer_handle_kind`` is NVTE_PEER_HANDLE_NCCL_WINDOW. Each
 *  out-pointer is written only when it is non-null.
 */
void nvte_tensor_nccl_window(const NVTETensor t, void** window, uint64_t* offset) {
  const auto* impl = convertNVTETensor(t);

  void* found_window = nullptr;
  uint64_t found_offset = 0;
  if (impl != nullptr && impl->peer_handle_kind == NVTE_PEER_HANDLE_NCCL_WINDOW) {
    found_window = impl->peer_handle_data;
    found_offset = impl->peer_handle_offset;
  }

  if (window != nullptr) {
    *window = found_window;
  }
  if (offset != nullptr) {
    *offset = found_offset;
  }
}
8 changes: 8 additions & 0 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ static_assert(NVTE_BUILD_NUM_PHILOX_ROUNDS > 0,
#endif

#include <cuda_runtime_api.h>
#include <transformer_engine/comm_handle.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
Expand Down Expand Up @@ -179,6 +180,13 @@ struct Tensor {
*/
bool row_scaled_nvfp4 = false;

/*! \brief Optional borrowed peer handle for one-sided RMA against this tensor.
* ``peer_handle_kind`` selects the backend owning ``peer_handle_data``;
* the caller keeps the resource valid for the tensor's lifetime. */
NVTEPeerHandleKind peer_handle_kind = NVTE_PEER_HANDLE_NONE;
void *peer_handle_data = nullptr;
uint64_t peer_handle_offset = 0;

/*! Map from NVTETensorParam to parameter sizes */
static constexpr size_t attr_sizes[] = {
sizeof(NVTEBasicTensor), // kNVTERowwiseData
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

/*! \file comm_handle.h
* \brief Generic peer-handle annotation on NVTETensor for one-sided RMA.
*
* The annotation is borrowed; the tensor never owns the underlying resource.
* Per-backend setters/getters live in dedicated headers (e.g. ``nccl_comm.h``).
*/

#ifndef TRANSFORMER_ENGINE_COMM_HANDLE_H_
#define TRANSFORMER_ENGINE_COMM_HANDLE_H_

#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Comm backend that owns a tensor's peer handle. */
typedef enum {
/*! No peer handle is attached. */
NVTE_PEER_HANDLE_NONE = 0,
/*! The attached handle is an NCCL window (see ``nccl_comm.h``). */
NVTE_PEER_HANDLE_NCCL_WINDOW = 1,
} NVTEPeerHandleKind;

/*! \brief Peer-handle kind attached to ``t``.
 *
 * \param[in] t Tensor to query.
 *
 * \return Kind of the attached peer handle; NVTE_PEER_HANDLE_NONE when
 *         nothing is attached or ``t`` does not resolve to a tensor.
 */
NVTEPeerHandleKind nvte_tensor_peer_handle_kind(const NVTETensor t);

/*! \brief Clear any peer handle attached to ``t``.
 *
 * Only the annotation on the tensor is reset; the handle is borrowed, so the
 * underlying resource is left untouched. No-op when ``t`` does not resolve to
 * a tensor.
 *
 * \param[in,out] t Tensor whose peer-handle annotation is cleared.
 */
void nvte_tensor_detach_peer_handle(NVTETensor t);

#ifdef __cplusplus
}
#endif

#endif // TRANSFORMER_ENGINE_COMM_HANDLE_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

/*! \file nccl_comm.h
* \brief Attach a registered NCCL symmetric-memory window to an NVTETensor.
*
* The window is caller-owned and must outlive the tensor; ``attach`` does
* not register or rendezvous it.
*/

#ifndef TRANSFORMER_ENGINE_NCCL_COMM_H_
#define TRANSFORMER_ENGINE_NCCL_COMM_H_

#include "comm_handle.h"
#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! \brief Attach an NCCL window + byte offset to ``t``. Pass ``window=NULL`` to detach.
 *
 * Fails an internal check when ``t`` is not a valid NVTETensor handle.
 *
 * \param[in,out] t Tensor to annotate.
 * \param[in] window Opaque ncclWindow_t (caller-owned), or NULL to clear.
 * \param[in] offset Byte offset into the window where this tensor starts.
 *                   Ignored (stored as 0) when ``window`` is NULL.
 */
void nvte_tensor_attach_nccl_window(NVTETensor t, void* window, uint64_t offset);

/*! \brief Read the NCCL window + offset attached to ``t``; yields (NULL, 0) when unset.
 * Either out-pointer may be NULL to skip that field.
 *
 * "Unset" covers a tensor with no peer handle, a peer handle of a kind other
 * than NVTE_PEER_HANDLE_NCCL_WINDOW, and a ``t`` that does not resolve to a
 * tensor.
 *
 * \param[in] t Tensor to query.
 * \param[out] window Receives the attached window, or NULL. May be NULL.
 * \param[out] offset Receives the attached byte offset, or 0. May be NULL.
 */
void nvte_tensor_nccl_window(const NVTETensor t, void** window, uint64_t* offset);

#ifdef __cplusplus
}
#endif

#endif // TRANSFORMER_ENGINE_NCCL_COMM_H_
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -631,12 +631,14 @@ void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at:
* Newton-Schulz (cuSolverMp)
**************************************************************************************************/

#ifdef NVTE_WITH_CUSOLVERMP
int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank);

void cusolvermp_ctx_destroy(int64_t ctx_ptr);

void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations,
std::vector<float> coefficients);
#endif // NVTE_WITH_CUSOLVERMP

} // namespace transformer_engine::pytorch

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
* See LICENSE for license information.
************************************************************************/

// Conditionally compiled: the common Newton-Schulz/cuSOLVERMp impl is gated
// behind NVTE_WITH_CUSOLVERMP in the common CMakeLists. Without the gate, the
// pytorch ext glob would pick this file up and produce undefined symbols
// (nvte_cusolvermp_ctx_*). Keep this gate aligned with common/.
#ifdef NVTE_WITH_CUSOLVERMP
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Incomplete guard — pybind.cpp and extensions.h still reference ungated symbols

The #ifdef NVTE_WITH_CUSOLVERMP here makes newton_schulz.cpp compile to an empty TU when the flag is unset, but pybind.cpp (lines 592–601) and extensions.h (lines 634–639) still reference transformer_engine::pytorch::cusolvermp_ctx_create, cusolvermp_ctx_destroy, and newton_schulz unconditionally. Those symbols will be undefined at link time, producing the same linker failure the PR aims to fix. The guard needs to be mirrored in both pybind.cpp and extensions.h.


#include "transformer_engine/newton_schulz.h"

#include "../extensions.h"
Expand Down Expand Up @@ -38,3 +44,5 @@ void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t
}

} // namespace transformer_engine::pytorch

#endif // NVTE_WITH_CUSOLVERMP
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&transformer_engine::pytorch::multi_tensor_compute_scale_inv_e8m0_cuda,
"Fused compute E8M0 scale_inv from amax", py::call_guard<py::gil_scoped_release>());

#ifdef NVTE_WITH_CUSOLVERMP
// Newton-Schulz (cuSolverMp)
m.def("cusolvermp_ctx_create", &transformer_engine::pytorch::cusolvermp_ctx_create,
"Create cuSolverMp context for Newton-Schulz", py::arg("nccl_comm_ptr"), py::arg("nranks"),
Expand All @@ -599,6 +600,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Newton-Schulz matrix orthogonalization", py::arg("ctx_ptr"), py::arg("m"), py::arg("n"),
py::arg("x"), py::arg("num_iterations"), py::arg("coefficients"),
py::call_guard<py::gil_scoped_release>());
#endif // NVTE_WITH_CUSOLVERMP

// Comm+GEMM Overlap
m.def("bulk_overlap_ag_with_external_gemm",
Expand Down
Loading