Expert Parallelism: common C API + NCCL EP backend #3034

Open
phu0ngng wants to merge 1 commit into NVIDIA:main from phu0ngng:phuong/ep-2-commwindow

@phu0ngng
Collaborator

Summary

First PR in the TE Expert Parallelism (EP) series. Lands the common C API and NCCL EP backend that later framework PRs (PyTorch, JAX) build on. No Python bindings yet — common-lib foundation plus build wiring only. Build/load works on any arch; SM and NCCL version gates fire at runtime.

Every network-bound payload tensor takes an optional NVTECommWindow. When the window is provided, the backend uses NCCL EP's symmetric-memory zero-copy path, which skips the D2D Memcpy from the user buffers to the Symmetric Staging Buffers.

Implementation

Public C API (transformer_engine/common/include/transformer_engine/{ep.h,comm_window.h})

Types: NVTEEpGroupConfig, NVTEEpLayerConfig, NVTEEpHandle, NVTECommWindow (side-band {ncclWindow_t window, size_t offset}; NCCL peer handles are not carried on NVTETensor).

Lifecycle (host-only, eager):

void     nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config);
void     nvte_ep_shutdown(void);

uint64_t nvte_ep_register_layer(NVTEEpLayerConfig layer_config, size_t* handle_mem_size);
  • nvte_ep_initialize — borrow an external ncclComm_t for the EP sub-group and init the singleton backend.

  • nvte_ep_shutdown — tear down the backend; idempotent; does not destroy ep_comm.

  • nvte_ep_register_layer — reserve a handle_id for a layer config and report the handle_mem buffer size the caller must allocate. The pair {id, mem} becomes the per-step NVTEEpHandle.

Per-step (allocation-free, CUDA-graph capturable)

void nvte_ep_prepare(NVTEEpHandle handle, NVTETensor topk_idx, NVTETensor token_counts,
                     size_t dispatch_output_per_expert_alignment, cudaStream_t stream);

void nvte_ep_dispatch(NVTEEpHandle handle, NVTETensor topk_idx,
                      NVTETensor tokens, NVTECommWindow tokens_win,
                      NVTETensor topk_weights, NVTECommWindow topk_weights_win,
                      NVTETensor recv_tokens, NVTECommWindow recv_tokens_win,
                      NVTETensor recv_topk_weights,  NVTECommWindow recv_topk_weights_win,
                      cudaStream_t stream);

void nvte_ep_combine(NVTEEpHandle handle, NVTETensor expert_out, NVTECommWindow expert_out_win,
                     NVTETensor result, cudaStream_t stream);

void nvte_ep_dispatch_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win,
                          NVTETensor g_recv_topk_weights, NVTECommWindow g_recv_topk_weights_win,
                          NVTETensor grad_tokens, NVTETensor grad_topk_weights, cudaStream_t stream);

void nvte_ep_combine_bwd(NVTEEpHandle handle, NVTETensor grad, NVTECommWindow grad_win,
                         NVTETensor grad_expert_out, NVTECommWindow grad_expert_out_win,
                         cudaStream_t stream);
  • nvte_ep_prepare — all-gather the routing map and write the result to handle.mem.
  • nvte_ep_dispatch — scatter tokens and routing weights from source ranks to expert ranks. tokens, topk_weights, recv_tokens, and recv_topk_weights each accept an optional symmetric-memory window for zero-copy.
  • nvte_ep_combine — scatter-sum expert outputs back to source ranks (unweighted; caller pre-multiplies by recv_topk_weights). expert_out accepts a window.
  • nvte_ep_dispatch_bwd — backward of dispatch; routes token and weight grads back to source. grad and g_recv_topk_weights accept windows; the gathered outputs (grad_tokens, grad_topk_weights) do not.
  • nvte_ep_combine_bwd — backward of combine; grad and grad_expert_out accept windows. Padded slots in grad_expert_out are zeroed.
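
The "unweighted combine" contract above can be illustrated with a single-process toy model. This is only a sketch of the data flow, not the TE API: `toy_dispatch`, `toy_combine`, and `run_moe` are hypothetical helpers, tokens are plain floats, and no communication takes place.

```python
# Toy, single-process model of the dispatch/combine contract.
# All names here are hypothetical; none of them exist in Transformer Engine.

def toy_dispatch(tokens, topk_idx, topk_weights):
    """Send each token, with its routing weight, to every expert it selected.

    Returns (source_token, expert, value, weight) records, standing in for
    recv_tokens / recv_topk_weights on the expert side.
    """
    received = []
    for src, (value, experts, weights) in enumerate(zip(tokens, topk_idx, topk_weights)):
        for expert, weight in zip(experts, weights):
            received.append((src, expert, value, weight))
    return received


def toy_combine(expert_out, num_tokens):
    """Unweighted scatter-sum of expert outputs back to their source tokens."""
    result = [0.0] * num_tokens
    for src, out in expert_out:
        result[src] += out
    return result


def run_moe(tokens, topk_idx, topk_weights, expert_fns):
    received = toy_dispatch(tokens, topk_idx, topk_weights)
    # The caller applies the routing weights before combine,
    # because combine itself only sums.
    expert_out = [(src, expert_fns[e](v) * w) for src, e, v, w in received]
    return toy_combine(expert_out, len(tokens))
```

Skipping the multiplication by `w` in `run_moe` would return the plain sum of the selected experts' outputs, which is why the description places the weighting on the caller.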

Backend + build

  • NCCL EP backend (transformer_engine/common/ep/): EPBackend singleton, HT-mode dispatch/combine over NCCL EP (libnccl_ep.so), group/layer registration. Internal helper make_payload_tensor() builds the per-call ncclEpTensor_t: when the caller's NVTECommWindow.window != nullptr it sets win_hdl + win_offset (zero-copy); otherwise it sets data from nvte_tensor_data(t) (HBM fallback).
  • Runtime gates (in EPBackend::initialize): SM>=90 (via cudaDeviceGetAttribute), NCCL>=2.30.4 (via ncclGetVersion), CUDA multicast/NVLS support.
  • Stub path: when NVTE_WITH_NCCL_EP=OFF, ep/ep_api_stub.cpp provides throwing nvte_ep_* stubs so framework bindings link unconditionally; failure surfaces at first nvte_ep_initialize.
  • Build wiring
    • setup.py builds libnccl_ep.so from 3rdparty/nccl by default and auto-disables NCCL EP when no requested CUDA arch is >= 90. Explicit NVTE_BUILD_WITH_NCCL_EP=1 with all archs < 90 is treated as a user error; set NVTE_BUILD_WITH_NCCL_EP=0 to opt out.
    • NCCL_HOME resolved dynamically: explicit env → /opt/nvidia/nccl, /usr/local/nccl, /usr → ldconfig -p fallback.
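
As a rough illustration of the runtime gates listed above, the sketch below models the checks in Python. It assumes the integer version encoding used by the `NCCL_VERSION` macro in nccl.h (major*10000 + minor*100 + patch for releases after 2.8), which is also the form `ncclGetVersion` reports. `nccl_version_code` and `failed_gates` are hypothetical names; the actual checks live in C++ in `EPBackend::initialize`.

```python
# Sketch of the three runtime gates; not the backend's actual code.

def nccl_version_code(major, minor, patch):
    """Encode a version the way the NCCL_VERSION macro does (assumed encoding)."""
    if major <= 2 and minor <= 8:
        return major * 1000 + minor * 100 + patch
    return major * 10000 + minor * 100 + patch


MIN_SM = 90                                  # Hopper or newer
MIN_NCCL_CODE = nccl_version_code(2, 30, 4)  # NCCL >= 2.30.4


def failed_gates(sm, nccl_code, has_multicast):
    """Return a human-readable list of the gates that fail (empty if all pass)."""
    failures = []
    if sm < MIN_SM:
        failures.append(f"SM {sm} < {MIN_SM}")
    if nccl_code < MIN_NCCL_CODE:
        failures.append(f"NCCL version code {nccl_code} < {MIN_NCCL_CODE}")
    if not has_multicast:
        failures.append("CUDA multicast/NVLS not supported")
    return failures
```

Collecting every failure, rather than stopping at the first one, is a choice made for this sketch; the PR text does not say which behavior the backend uses.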

Testing

  • C++ distributed tests under tests/cpp_distributed/.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng requested a review from ptrendx as a code owner May 22, 2026 02:42
@greptile-apps
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR introduces the Expert Parallelism (EP) foundation for TransformerEngine: a public C API (ep.h, comm_window.h), an NCCL EP backend singleton (EPBackend), throwing stubs for non-EP builds, build wiring in setup.py and CMakeLists.txt, and a distributed C++ test suite.

  • Backend (ep_backend.cpp): register_layer initializes ncclEpHandleConfig_t with {} plus a manual size field while open_handle uses NCCL_EP_HANDLE_CONFIG_INIT; if the macro also sets a version field the two calls will disagree on the required buffer size. validate_config also omits the max_recv_tokens_per_rank > 0 check that the code's own comment flags as required.
  • Build (setup.py): The arch-detection loop skips non-digit tokens such as "native", silently disabling NCCL EP on a Hopper machine when NVTE_CUDA_ARCHS=native. _discover_nccl_home falls through to system probes without warning when NCCL_HOME points to a path without nccl.h.
  • API and stubs: Clean delegation in ep_api.cpp; stubs correctly no-op nvte_ep_shutdown and throw descriptive errors everywhere else.

Confidence Score: 3/5

The new EP layer is not yet wired to any Python framework, so no production training path is affected today, but the handle-memory sizing inconsistency means the first framework integration could allocate an undersized buffer and corrupt memory inside ncclEpInitHandle.

Two issues in ep_backend.cpp warrant attention before the next framework PR lands: register_layer uses a manually zero-initialised ncclEpHandleConfig_t instead of the NCCL_EP_HANDLE_CONFIG_INIT macro used everywhere else, almost certainly omitting the version field and making ncclEpHandleMemSize return a buffer size that does not match what ncclEpInitHandle will actually write; and validate_config skips the max_recv_tokens_per_rank > 0 check that an inline comment already flags as mandatory.

transformer_engine/common/ep/ep_backend.cpp (handle config init and missing validation) and setup.py (arch detection and NCCL_HOME handling)

Important Files Changed

Filename | Overview
transformer_engine/common/ep/ep_backend.cpp | Core NCCL EP backend. ncclEpHandleConfig_t in register_layer uses {} + manual size instead of NCCL_EP_HANDLE_CONFIG_INIT, likely omitting version and risking buffer size mismatch. validate_config missing max_recv_tokens_per_rank > 0 check.
setup.py | Build wiring for NCCL EP submodule. _discover_nccl_home silently ignores invalid NCCL_HOME; arch parsing incorrectly disables NCCL EP when NVTE_CUDA_ARCHS=native on Hopper hardware.
transformer_engine/common/ep/ep_api.cpp | Thin C API delegation layer with null guards on required pointer parameters. No issues found.
transformer_engine/common/ep/ep_api_stub.cpp | Stub implementation that throws descriptive errors when NCCL EP is not built. nvte_ep_shutdown is a correct no-op. No issues found.
transformer_engine/common/include/transformer_engine/ep.h | New public C API header with well-documented struct types and per-step operations. No issues found.
transformer_engine/common/include/transformer_engine/comm_window.h | Minimal header defining NVTECommWindow. Clean and correct. No issues found.
transformer_engine/common/CMakeLists.txt | CMake wiring for NCCL EP: finds headers, libs, sets rpath, conditionally compiles backend vs stubs. No issues found.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant C_API as nvte_ep_* (ep_api.cpp)
    participant Backend as EPBackend singleton
    participant NCCL_EP as libnccl_ep.so

    Caller->>C_API: nvte_ep_initialize(ep_comm, group_config)
    C_API->>Backend: EPBackend::initialize(ncclComm_t, config)
    Backend->>Backend: validate_config
    Backend->>NCCL_EP: ncclEpCreateGroup

    Caller->>C_API: nvte_ep_register_layer(layer_config, handle_mem_size)
    C_API->>Backend: register_layer
    Backend->>NCCL_EP: ncclEpHandleMemSize
    Backend-->>Caller: handle_id

    loop Per training step
        Caller->>C_API: nvte_ep_prepare(handle, topk_idx, stream)
        C_API->>Backend: prepare
        Backend->>NCCL_EP: ncclEpInitHandle + ncclEpUpdateHandle

        Caller->>C_API: nvte_ep_dispatch(handle, tokens, stream)
        C_API->>Backend: dispatch
        Backend->>NCCL_EP: ncclEpInitHandle + ncclEpDispatch

        Caller->>C_API: nvte_ep_combine(handle, expert_out, stream)
        C_API->>Backend: combine
        Backend->>NCCL_EP: ncclEpInitHandle + ncclEpCombine
    end

    Caller->>C_API: nvte_ep_shutdown()
    C_API->>Backend: EPBackend::shutdown()
    Backend->>NCCL_EP: ncclEpGroupDestroy

Reviews (1): Last reviewed commit: "Expert Parallelism: common C API + NCCL ..."

Comment on lines +284 to +289
ncclEpHandleConfig_t hcfg{};
hcfg.size = static_cast<unsigned int>(sizeof(hcfg));
hcfg.dispatch_output_per_expert_alignment = layer_config.dispatch_output_per_expert_alignment;
size_t hm_size = 0;
NVTE_CHECK_NCCL(ncclEpHandleMemSize(ep_group_, NCCL_EP_LAYOUT_EXPERT_MAJOR, &hcfg, &hm_size,
layer_config.top_k));

P1 Inconsistent ncclEpHandleConfig_t initialization may produce wrong buffer size

register_layer initializes the config with {} and manually sets only hcfg.size, while open_handle uses NCCL_EP_HANDLE_CONFIG_INIT (which also sets version and likely other fields to their expected defaults). ncclEpHandleMemSize and ncclEpInitHandle can disagree on the required buffer size when the version field is 0 instead of NCCL_EP_API_VERSION, causing the caller to allocate an undersized handle_mem buffer and leading to an out-of-bounds write in ncclEpInitHandle.

Comment on lines +114 to +115
NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ",
config.max_num_sms);

P1 max_recv_tokens_per_rank is not validated in validate_config

The comment at line 243 explicitly notes "Must be > 0; NCCL EP errors out on 0", but validate_config never enforces this. A zero value would cause ncclEpCreateGroup to fail with a cryptic NCCL internal error instead of the clear TE diagnostic that all other config fields get.

Suggested change
- NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ",
-            config.max_num_sms);
+ NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ",
+            config.max_num_sms);
+ NVTE_CHECK(config.max_recv_tokens_per_rank > 0,
+            "max_recv_tokens_per_rank must be positive, got ",
+            config.max_recv_tokens_per_rank);

Comment thread setup.py
Comment on lines +162 to +164
env_home = os.environ.get("NCCL_HOME")
if env_home and (Path(env_home) / "include" / "nccl.h").exists():
    return env_home

P2 NCCL_HOME set to a wrong path is silently ignored

If a user sets NCCL_HOME to an incorrect prefix that doesn't contain include/nccl.h, the function falls through to the system probe list without any warning. The function should warn when NCCL_HOME is set but doesn't resolve to a valid NCCL install.

Suggested change
- env_home = os.environ.get("NCCL_HOME")
- if env_home and (Path(env_home) / "include" / "nccl.h").exists():
-     return env_home
+ env_home = os.environ.get("NCCL_HOME")
+ if env_home:
+     if (Path(env_home) / "include" / "nccl.h").exists():
+         return env_home
+     print(
+         f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but "
+         f"'{env_home}/include/nccl.h' was not found; falling back to system probes."
+     )
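
A self-contained sketch of the suggested behavior follows, for trying the warning path outside of setup.py. `discover_nccl_home` is a hypothetical stand-in for `_discover_nccl_home`. The default probe list is assumed from the PR description, and the `ldconfig -p` fallback is left out.

```python
import os
import sys
from pathlib import Path

# Assumed probe order, taken from the PR description; ldconfig fallback omitted.
DEFAULT_PROBES = ("/opt/nvidia/nccl", "/usr/local/nccl", "/usr")


def discover_nccl_home(environ=None, probes=DEFAULT_PROBES):
    """Return the first prefix containing include/nccl.h, or None if none does."""
    environ = os.environ if environ is None else environ
    env_home = environ.get("NCCL_HOME")
    if env_home:
        if (Path(env_home) / "include" / "nccl.h").exists():
            return env_home
        # Warn instead of silently ignoring a bad NCCL_HOME.
        print(
            f"[NCCL EP] WARNING: NCCL_HOME='{env_home}' is set but "
            f"'{env_home}/include/nccl.h' was not found; falling back to system probes.",
            file=sys.stderr,
        )
    for prefix in probes:
        if (Path(prefix) / "include" / "nccl.h").exists():
            return prefix
    return None
```

Passing `environ` and `probes` as parameters is a choice made for this sketch so the function can be exercised against a temporary directory; the function in the PR reads `os.environ` directly.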

Comment thread setup.py
Comment on lines +93 to +97
has_hopper_or_newer = any(
    int(a.strip().rstrip("af")) >= 90
    for a in str(archs or "").split(";")
    if a.strip().rstrip("af").isdigit()
)

P2 NVTE_CUDA_ARCHS=native silently disables NCCL EP on valid Hopper hardware

The arch parsing rejects any token that is not isdigit() after stripping a/f suffixes. The CMake keyword "native" is silently skipped, so has_hopper_or_newer stays False and NCCL EP is auto-disabled even on a Hopper machine.

Suggested change
- has_hopper_or_newer = any(
-     int(a.strip().rstrip("af")) >= 90
-     for a in str(archs or "").split(";")
-     if a.strip().rstrip("af").isdigit()
- )
+ arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()]
+ has_native = any(t.lower() == "native" for t in arch_tokens)
+ has_hopper_or_newer = has_native or any(
+     int(t.rstrip("af")) >= 90
+     for t in arch_tokens
+     if t.rstrip("af").isdigit()
+ )
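
To compare the two parsers side by side, the sketch below wraps each in a function. The function names are hypothetical; the parsing logic is copied from the snippets in this comment.

```python
def has_hopper_or_newer_original(archs):
    """Parsing as in the PR: tokens that are not digits (plus a/f suffix) are skipped."""
    return any(
        int(a.strip().rstrip("af")) >= 90
        for a in str(archs or "").split(";")
        if a.strip().rstrip("af").isdigit()
    )


def has_hopper_or_newer_suggested(archs):
    """Suggested parsing: the keyword 'native' counts as Hopper-or-newer."""
    arch_tokens = [a.strip() for a in str(archs or "").split(";") if a.strip()]
    has_native = any(t.lower() == "native" for t in arch_tokens)
    return has_native or any(
        int(t.rstrip("af")) >= 90
        for t in arch_tokens
        if t.rstrip("af").isdigit()
    )


if __name__ == "__main__":
    for spec in ("native", "80;90a", "70;80", "100f", ""):
        print(repr(spec), has_hopper_or_newer_original(spec), has_hopper_or_newer_suggested(spec))
```

One caveat with the suggested version: it treats "native" as Hopper-or-newer without inspecting the local GPU, so on a pre-Hopper build machine it would keep NCCL EP enabled. Whether that trade-off is acceptable is left to the PR author.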
