
[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API)#2443

Open: denera wants to merge 81 commits into NVIDIA:main from denera:common/tp-overlap-cublasmp


Collaborator

@denera denera commented Dec 2, 2025

Description

This PR adds support for the NVTE cuBLASMp bindings in the Comm+GEMM overlap API.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera self-assigned this Dec 2, 2025
@denera denera force-pushed the common/tp-overlap-cublasmp branch 2 times, most recently from 908bbc2 to 69cf235 Compare December 2, 2025 20:12
Comment thread transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -17,6 +18,12 @@

#define NVTE_COMM_OVERLAP_MAX_STREAMS 3

/* \brief Check if TE is built with cuBlasMp.
Member


nit: cuBLASMp

@@ -526,6 +514,11 @@ class CommOverlapHelper : torch::CustomClassHolder {
ExtComm comm);

void ub_barrier(ExtComm comm);

int64_t get_nccl_comm_ptr(std::string comm_name) {
NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");
Member


This error message could be more descriptive - e.g. something like "chosen backend for the communication-computation overlap (cuBLASMp) requires NCCL communicator, but the passed ProcessGroup uses a different backend."
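A minimal sketch of what a more descriptive check along these lines might look like, written in Python for brevity. `require_nccl_backend` and its arguments are hypothetical names used only for illustration; they are not part of Transformer Engine, and the check under review is the C++ `NVTE_CHECK` quoted above.

```python
def require_nccl_backend(pg_backend: str, overlap_backend: str = "cuBLASMp") -> None:
    """Raise a descriptive error when the process group is not NCCL-backed.

    Hypothetical helper: the name, arguments, and message are illustrative only.
    """
    if pg_backend.lower() != "nccl":
        raise ValueError(
            f"The chosen comm+GEMM overlap backend ({overlap_backend}) requires an "
            f"NCCL communicator, but the passed ProcessGroup uses the "
            f"'{pg_backend}' backend."
        )
```

The message names both the overlap backend that imposed the requirement and the backend that was actually passed, so the caller can tell which side to change.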

@denera denera force-pushed the common/tp-overlap-cublasmp branch from 4596411 to b4ad546 Compare December 16, 2025 19:04
@denera denera marked this pull request as ready for review December 16, 2025 22:58
Contributor

greptile-apps Bot commented Dec 16, 2025

Greptile Summary

This PR adds cuBLASMp as a second backend for the Comm+GEMM overlap API across PyTorch and JAX, alongside the existing Userbuffers backend. It introduces new CommOverlapHelper NCCL communicator initialization, cuBLASMp-flavoured constructors for CommOverlap/CommOverlapP2P, a with_cublasmp parameter in initialize_ub(), and framework-level adjustments for bulk-overlap fallback, output-buffer routing, wgrad re-gathering, and input rank restoration.

  • C++ core: CommOverlapCore gains a second constructor accepting a pre-built ncclComm_t; CommOverlapHelper now creates dedicated NCCL world/intra communicators when NVTE_WITH_CUBLASMP is set; a cublasmp_capture_warmup helper pre-registers cuBLASMp workspaces outside any CUDA-graph capture window.
  • PyTorch modules: initialize_ub() clears bulk/external overlap methods under cuBLASMp, Linear/LayerNormLinear/LayerNormMLP override bulk flags at construction time and re-gather grad_output in the backward pass since cuBLASMp AG+GEMM does not preserve the gathered tensor.
  • JAX: collective_gemm_bootstrap gains a use_cublasmp flag; CgemmConfig stores and hashes the flag so plan-cache lookups are backend-aware.
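The last bullet's point about backend-aware plan-cache lookups can be sketched in a few lines. This is an illustration only: `PlanKey` and `PlanCache` are hypothetical stand-ins, not the real `CgemmConfig` (which lives in the C++ extension), and the fields other than `use_cublasmp` are assumed for the example.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PlanKey:
    """Hypothetical stand-in for a plan-cache key; not the real CgemmConfig."""
    num_ranks: int
    tp_size: int
    use_cublasmp: bool  # including the backend flag keeps the two backends apart


class PlanCache:
    """Builds a plan once per distinct key and reuses it afterwards."""

    def __init__(self):
        self._plans = {}

    def get(self, key, build):
        # A frozen dataclass hashes over all of its fields, so use_cublasmp
        # takes part in the lookup.
        if key not in self._plans:
            self._plans[key] = build(key)
        return self._plans[key]


cache = PlanCache()
ub_plan = cache.get(PlanKey(8, 4, False), lambda k: "userbuffers-plan")
mp_plan = cache.get(PlanKey(8, 4, True), lambda k: "cublasmp-plan")
```

If the flag were left out of the key, the second lookup would return the plan built for the first backend, which is the kind of cache collision the summary describes as fixed.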

Confidence Score: 2/5

Multiple unresolved issues from prior review rounds remain open, including a CUDA device-memory leak in the warmup path, missing NCCL dependency guards in the public header, an unguarded cublasmp.h include that breaks non-cuBLASMp CI, uninitialized all_outputs2 in a specific test path, and a removed pgid that breaks NCCL-ID-file isolation on shared machines.

The architectural direction is sound and several bugs from earlier rounds have been fixed (ncclCommInitRank replacing ncclCommInitAll, plan-id cache now hashing use_cublasmp, cuBLASMp bulk-overlap guard now raising ValueError). However, a cluster of concrete defects reported in previous reviews remain present: the warmup cudaMalloc leak on exceptions, unconditional nccl.h in the public header, unconditional cublasmp.h in the C++ test, missing all_outputs2 assignment in the non-FP8 cuBLASMp test path, and the NCCL-ID-file race after pgid removal.

transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (warmup leak), transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h (unconditional nccl.h), tests/cpp_distributed/test_comm_gemm.cu (unconditional cublasmp.h + vacuous BF16 tolerance), tests/pytorch/distributed/run_gemm_with_overlap.py (unset all_outputs2), transformer_engine/jax/csrc/extensions/cgemm_helper.cpp (NCCL ID file isolation)
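The warmup leak mentioned above is an instance of a general pattern: a raw allocation followed by calls that may throw, with the free only on the success path. The sketch below shows the scoped-cleanup idea in Python with a stand-in allocator. `FakeDevice`, `scoped_buffer`, and `warmup` are hypothetical names, no CUDA is involved, and this is not the PR's code; in C++ the analogous fix would typically be an RAII wrapper around the device pointer.

```python
from contextlib import contextmanager


class FakeDevice:
    """Stand-in allocator that only tracks live allocations (no CUDA involved)."""

    def __init__(self):
        self.live = set()
        self._next = 0

    def malloc(self, nbytes):
        self._next += 1
        self.live.add(self._next)
        return self._next

    def free(self, handle):
        self.live.discard(handle)


@contextmanager
def scoped_buffer(device, nbytes):
    """Allocate a buffer and release it when the block exits, even on error."""
    handle = device.malloc(nbytes)
    try:
        yield handle
    finally:
        device.free(handle)  # runs on normal exit and when the body raises


def warmup(device, fail=False):
    """Hypothetical warmup that may fail partway through."""
    with scoped_buffer(device, 1 << 20) as buf:
        if fail:
            raise RuntimeError("simulated failure during warmup")
        return buf
```

With the cleanup tied to scope exit, a failure in the middle of the warmup no longer leaves the allocation outstanding.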

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp | New NCCL communicator init via broadcast+ncclCommInitRank, cuBLASMp constructors, and warmup function. The warmup still uses raw cudaMalloc with no RAII, leaking device memory if any intervening call throws. |
| transformer_engine/pytorch/csrc/extensions/pybind.cpp | Adds use_cublasmp and comm_type to CommOverlap/CommOverlapP2P bindings. num_splits default changed 3→4 and two new params inserted before it, silently breaking positional callers. |
| transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h | Adds cuBLASMp constructor overloads. nccl.h and comm_gemm.h included unconditionally, forcing NCCL/cuBLASMp dependency on all includers even without NVTE_WITH_CUBLASMP. |
| transformer_engine/pytorch/module/base.py | Adds _ub_initialized/_ub_with_cublasmp globals and with_cublasmp to initialize_ub(). Bulk/external methods cleared for cuBLASMp. Logic correct overall. |
| transformer_engine/pytorch/module/layernorm_linear.py | Adds cuBLASMp output routing, ln_out aliasing guard, and grad_output re-gather for wgrad on the cuBLASMp backward path. |
| transformer_engine/pytorch/module/layernorm_mlp.py | Adds cuBLASMp output routing for fc2/fc1_dgrad, ln_out aliasing guard, fc2_wgrad re-gather. any([...]) correctly wrapped in a list. |
| transformer_engine/pytorch/module/linear.py | Adds cuBLASMp output routing, wgrad re-gather, and out.view() to restore logical tensor rank after cuBLASMp 2-D output. |
| tests/cpp_distributed/test_comm_gemm.cu | Adds MPI-based AG/RS/AR reference. Unconditional <cublasmp.h> breaks non-cuBLASMp builds. BF16 GemmAr tolerance raised 600x making the check nearly vacuous. |
| transformer_engine/jax/csrc/extensions/cgemm_helper.cpp | Adds use_cublasmp to CgemmConfig and plan_id hash (fixing cache collision). pgid removed from NCCL ID file path, breaking isolation on shared machines. |
| tests/pytorch/distributed/run_gemm_with_overlap.py | Adds --use-cublasmp, refactors reference computation to per-rank GEMMs. all_outputs2 remains unassigned for non-FP8 cuBLASMp paths. |
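
The pybind.cpp entry notes that inserting parameters ahead of `num_splits` silently breaks positional callers. The effect can be reproduced in plain Python. The three functions below are simplified, hypothetical signatures, not the actual bindings: only the parameter names `use_cublasmp`, `comm_type`, and `num_splits` and the 3→4 default come from the summary, and `buffer_shape` is assumed for the example.

```python
def make_overlap_old(buffer_shape, num_splits=3):
    """Signature before the change (simplified, hypothetical)."""
    return {"num_splits": num_splits}


def make_overlap_new(buffer_shape, use_cublasmp=False, comm_type=None, num_splits=4):
    """Two parameters inserted ahead of num_splits (simplified, hypothetical)."""
    return {"use_cublasmp": use_cublasmp, "comm_type": comm_type,
            "num_splits": num_splits}


def make_overlap_kwonly(buffer_shape, *, use_cublasmp=False, comm_type=None,
                        num_splits=4):
    """Keyword-only variant: a stale positional call fails loudly instead."""
    return {"use_cublasmp": use_cublasmp, "comm_type": comm_type,
            "num_splits": num_splits}


# A caller written against the old signature passes num_splits positionally.
old = make_overlap_old((1024, 1024), 8)
new = make_overlap_new((1024, 1024), 8)  # 8 now binds to use_cublasmp
```

Here `new["num_splits"]` falls back to the default of 4 while the caller's 8 lands in `use_cublasmp`, with no error raised. Making the optional parameters keyword-only turns the same call into a `TypeError`; pybind11 offers `py::kw_only()` for this purpose, though whether that fits this binding is a decision for the PR.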

Sequence Diagram

sequenceDiagram
    participant P as Python initialize_ub
    participant H as CommOverlapHelper
    participant NCCL as NCCL
    participant CO as CommOverlap
    participant CMP as cuBLASMp
    P->>H: CommOverlapHelper(world_pg, intra_pg)
    H->>NCCL: ncclGetUniqueId rank 0
    H->>NCCL: ncclCommInitRank nccl_world
    H->>NCCL: ncclCommInitRank nccl_intra
    H-->>P: helper with shared NCCL comms
    P->>CO: CommOverlap cuBLASMp ctor
    CO->>H: get_nccl_comm intra
    CO->>CMP: nvte_comm_gemm_ctx_create
    CO->>CMP: cublasmp_capture_warmup
    CMP-->>CO: workspaces registered
    P->>CO: cublasmp_ag_gemm forward
    CMP-->>P: local 2D output
    P->>P: out.view restore rank
    P->>CO: cublasmp_gemm_rs backward
    CMP-->>P: dX reduce-scattered
    P->>P: gather dY for wgrad
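
The "out.view restore rank" step in the diagram amounts to shape arithmetic: the cuBLASMp path hands back a 2-D output, and the framework restores the input's leading dimensions. A minimal sketch under stated assumptions: `restore_output_shape` is a hypothetical helper, the last dimension is taken to be the feature dimension, and the first dimension is inferred from the row count so that an all-gathered output also fits. It is not the code in linear.py.

```python
from math import prod


def restore_output_shape(inp_shape, out_2d_shape):
    """Map a 2-D GEMM output shape back to the rank of the input.

    Hypothetical helper: keeps the input's middle dims, takes the feature
    dim from the output, and infers the first dim from the row count.
    """
    rows, out_features = out_2d_shape
    inner = prod(inp_shape[1:-1])  # dims between the first and the feature dim
    if rows % inner != 0:
        raise ValueError(
            f"{rows} output rows cannot be split over inner dims {inp_shape[1:-1]}"
        )
    return (rows // inner, *inp_shape[1:-1], out_features)


# Local input is sequence-sharded: (512, 2, 1024) on each of 4 ranks.
# After AG+GEMM the 2-D output has 4 * 512 * 2 = 4096 rows.
full_shape = restore_output_shape((512, 2, 1024), (4096, 4096))
```

In PyTorch the resulting tuple would be what gets passed to a call such as `out.view(full_shape)`.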

Reviews (38): Last reviewed commit: "Merge remote-tracking branch 'upstream/m..."

greptile-apps[bot]

This comment was marked as outdated.

@denera denera force-pushed the common/tp-overlap-cublasmp branch from 147036f to c5471f8 Compare December 17, 2025 02:15
denera and others added 6 commits December 17, 2025 02:16
…rk extensions

Signed-off-by: Alp Dener <adener@nvidia.com>
…entirely

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from c5471f8 to d79bf21 Compare December 17, 2025 02:16
greptile-apps[bot]

This comment was marked as outdated.

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 364b416 to ee517d3 Compare December 17, 2025 02:50
greptile-apps[bot]

This comment was marked as outdated.

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 5cb8204 to 51b64fb Compare December 17, 2025 03:36
Comment thread transformer_engine/common/CMakeLists.txt Outdated
Comment thread transformer_engine/jax/cpp_extensions/gemm.py Outdated
denera and others added 12 commits May 20, 2026 21:54
…d in the warmup, and fixed unguarded bulk overlap cublasmp backend check

Signed-off-by: Alp Dener <adener@nvidia.com>
…ap tests now passing

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Collaborator Author

denera commented May 22, 2026

/te-ci pytorch

@denera denera requested a review from timmoon10 as a code owner May 27, 2026 05:14
Collaborator Author

denera commented May 27, 2026

/te-ci

@denera denera requested a review from ptrendx May 27, 2026 15:06
denera and others added 3 commits May 27, 2026 20:07
…d non-deterministic failures, removed XLA_FLAGS modifications for TE/JAX tests

Signed-off-by: Alp Dener <adener@nvidia.com>
Collaborator Author

denera commented May 28, 2026

/te-ci L1

Comment on lines +131 to +142
void cublasmp_ag_gemm(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
                      TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                      bool grad, bool accumulate, cudaStream_t stream_main);

void cublasmp_gemm_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
                      TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                      bool grad, bool accumulate, cudaStream_t stream_main);

void cublasmp_gemm_ar(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
                      TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                      bool grad, bool accumulate, cudaStream_t stream_main);

Member


Why do we need those functions if we already have the nvte_* calls? In general, this is supposed to be a C API, so we should not add and rely on even more C++ things in it.

Collaborator Author


These exist only because the Userbuffers backend is all C++ and it's just keeping everything together.

I agree with you about the C API issue, but I think that's going to need to be a separate refactor PR entirely.


7 participants