[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API) #2443
denera wants to merge 81 commits into
Force-pushed from 908bbc2 to 69cf235
@@ -17,6 +18,12 @@

#define NVTE_COMM_OVERLAP_MAX_STREAMS 3

/* \brief Check if TE is built with cuBlasMp.
@@ -526,6 +514,11 @@ class CommOverlapHelper : torch::CustomClassHolder {
                    ExtComm comm);

  void ub_barrier(ExtComm comm);

  int64_t get_nccl_comm_ptr(std::string comm_name) {
    NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");
This error message could be more descriptive, e.g. something like "The chosen backend for communication-computation overlap (cuBLASMp) requires an NCCL communicator, but the passed ProcessGroup uses a different backend."
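As a rough illustration of the suggestion above, the check could name both the overlap backend and the backend of the offending process group. This is only a sketch: `check_nccl_backend` and its `pg_backend` parameter are hypothetical stand-ins, and a plain exception is used here in place of `NVTE_CHECK`.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the helper's backend check (not the PR's code).
// pg_backend is an assumed string naming the ProcessGroup backend, e.g. "gloo".
inline void check_nccl_backend(bool backend_is_nccl, const std::string &pg_backend) {
  if (!backend_is_nccl) {
    // State what was required and what was actually passed.
    throw std::runtime_error(
        "The chosen backend for communication-computation overlap (cuBLASMp) "
        "requires an NCCL communicator, but the passed ProcessGroup uses the '" +
        pg_backend + "' backend.");
  }
}
```

Including the observed backend name in the message lets a user see the mismatch without attaching a debugger.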
Force-pushed from 4596411 to b4ad546
Greptile Summary
This PR adds cuBLASMp as a second backend for the Comm+GEMM overlap API across PyTorch and JAX, alongside the existing Userbuffers backend. It introduces new
Confidence Score: 2/5
Multiple unresolved issues from prior review rounds remain open, including a CUDA device-memory leak in the warmup path, missing NCCL dependency guards in the public header, an unguarded cublasmp.h include that breaks non-cuBLASMp CI, uninitialized all_outputs2 in a specific test path, and a removed pgid that breaks NCCL-ID-file isolation on shared machines.
The architectural direction is sound and several bugs from earlier rounds have been fixed (ncclCommInitRank replacing ncclCommInitAll, plan-id cache now hashing use_cublasmp, cuBLASMp bulk-overlap guard now raising ValueError). However, a cluster of concrete defects reported in previous reviews remain present: the warmup cudaMalloc leak on exceptions, unconditional nccl.h in the public header, unconditional cublasmp.h in the C++ test, missing all_outputs2 assignment in the non-FP8 cuBLASMp test path, and the NCCL-ID-file race after pgid removal.
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (warmup leak)
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h (unconditional nccl.h)
tests/cpp_distributed/test_comm_gemm.cu (unconditional cublasmp.h + vacuous BF16 tolerance)
tests/pytorch/distributed/run_gemm_with_overlap.py (unset all_outputs2)
transformer_engine/jax/csrc/extensions/cgemm_helper.cpp (NCCL ID file isolation)
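One common way to avoid the kind of leak described above (an allocation that is never freed when the warmup throws) is to tie the buffer to an RAII owner. The sketch below is not the PR's code: it assumes hypothetical stand-ins (`fake_device_alloc`, `FakeDeviceFree`, `warmup`) backed by host `malloc`/`free` so it runs without a GPU; real code would call `cudaMalloc`/`cudaFree` in the same places.

```cpp
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <stdexcept>

// Counter used only to observe whether every allocation was released.
inline int &live_allocations() {
  static int n = 0;
  return n;
}

// Stand-in for cudaMalloc (assumption: host memory, so no GPU is needed).
inline void *fake_device_alloc(std::size_t bytes) {
  ++live_allocations();
  return std::malloc(bytes);
}

// Stand-in for cudaFree, used as the unique_ptr deleter.
struct FakeDeviceFree {
  void operator()(void *p) const {
    --live_allocations();
    std::free(p);
  }
};

using DeviceBuffer = std::unique_ptr<void, FakeDeviceFree>;

// Hypothetical warmup routine: the scratch buffer is released on every exit
// path, including when the body throws.
inline void warmup(std::size_t bytes, bool fail) {
  DeviceBuffer scratch(fake_device_alloc(bytes));
  if (fail) throw std::runtime_error("warmup GEMM failed");
  // ... a warmup GEMM would use scratch.get() here ...
}
```

Because the deleter runs during stack unwinding, no explicit `try`/`catch` around the warmup body is needed to free the buffer.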
Sequence Diagram
sequenceDiagram
participant P as Python initialize_ub
participant H as CommOverlapHelper
participant NCCL as NCCL
participant CO as CommOverlap
participant CMP as cuBLASMp
P->>H: CommOverlapHelper(world_pg, intra_pg)
H->>NCCL: ncclGetUniqueId rank 0
H->>NCCL: ncclCommInitRank nccl_world
H->>NCCL: ncclCommInitRank nccl_intra
H-->>P: helper with shared NCCL comms
P->>CO: CommOverlap cuBLASMp ctor
CO->>H: get_nccl_comm intra
CO->>CMP: nvte_comm_gemm_ctx_create
CO->>CMP: cublasmp_capture_warmup
CMP-->>CO: workspaces registered
P->>CO: cublasmp_ag_gemm forward
CMP-->>P: local 2D output
P->>P: out.view restore rank
P->>CO: cublasmp_gemm_rs backward
CMP-->>P: dX reduce-scattered
P->>P: gather dY for wgrad
Reviews (38): Last reviewed commit: "Merge remote-tracking branch 'upstream/m..."
Force-pushed from 147036f to c5471f8
…rk extensions Signed-off-by: Alp Dener <adener@nvidia.com>
…entirely Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Force-pushed from c5471f8 to d79bf21
Force-pushed from 364b416 to ee517d3
Force-pushed from 5cb8204 to 51b64fb
…d in the warmup, and fixed unguarded bulk overlap cublasmp backend check Signed-off-by: Alp Dener <adener@nvidia.com>
…ap tests now passing Signed-off-by: Alp Dener <adener@nvidia.com>
…ormerEngine into common/tp-overlap-cublasmp
/te-ci pytorch
/te-ci
…d non-deterministic failures, removed XLA_FLAGS modifications for TE/JAX tests Signed-off-by: Alp Dener <adener@nvidia.com>
/te-ci L1
  void cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
                        TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                        bool grad, bool accumulate, cudaStream_t stream_main);

  void cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
                        TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                        bool grad, bool accumulate, cudaStream_t stream_main);

  void cublasmp_gemm_ar(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
                        TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                        bool grad, bool accumulate, cudaStream_t stream_main);
Why do we need those functions if we already have the nvte_* calls? In general this is supposed to be a C API, so we should not add more C++ to it or come to rely on it.
These exist only because the Userbuffers backend is all C++ and it's just keeping everything together.
I agree with you about the C API issue, but I think that's going to need to be a separate refactor PR entirely.
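For readers unfamiliar with the C API concern raised in this thread, the usual pattern is to keep the C++ class private and expose only an opaque handle through `extern "C"` functions. The following is a generic sketch of that pattern, not the refactor being proposed here; `OverlapImpl`, `DemoOverlapCtx`, and the `demo_overlap_*` functions are all hypothetical names.

```cpp
// Hypothetical C++ implementation class, standing in for a backend object.
class OverlapImpl {
 public:
  explicit OverlapImpl(int tp_size) : tp_size_(tp_size) {}
  int tp_size() const { return tp_size_; }

 private:
  int tp_size_;
};

extern "C" {

// Opaque to callers: only pointers to this type cross the API boundary.
typedef struct DemoOverlapCtx DemoOverlapCtx;

// Creates the C++ object and hands back an opaque handle.
// A production version would also catch exceptions so none escape into C.
DemoOverlapCtx *demo_overlap_create(int tp_size) {
  return reinterpret_cast<DemoOverlapCtx *>(new OverlapImpl(tp_size));
}

// Forwards a query to the hidden C++ object.
int demo_overlap_tp_size(const DemoOverlapCtx *ctx) {
  return reinterpret_cast<const OverlapImpl *>(ctx)->tp_size();
}

// Destroys the C++ object behind the handle.
void demo_overlap_destroy(DemoOverlapCtx *ctx) {
  delete reinterpret_cast<OverlapImpl *>(ctx);
}

}  // extern "C"
```

With this shape, callers never see C++ types such as `TensorWrapper` in the public header, which is the property the reviewer is asking to preserve.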
…tstrapping cuBLASMp Signed-off-by: Alp Dener <adener@nvidia.com>
Description
This PR adds support for the NVTE cuBLASMp bindings in the Comm+GEMM overlap API.
Type of change
Checklist: