Skip to content

[JAX] Expert Parallelism: JAX primitives + VJPs#3036

Open
phu0ngng wants to merge 3 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax
Open

[JAX] Expert Parallelism: JAX primitives + VJPs#3036
phu0ngng wants to merge 3 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-jax

Conversation

@phu0ngng
Copy link
Copy Markdown
Collaborator

@phu0ngng phu0ngng commented May 22, 2026

Summary

Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch/ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.

Implementation

Public Python API (transformer_engine/jax/ep.py)

from transformer_engine.jax.ep import (
    EpHandle,        # opaque (id, handle_mem) pair from ep_prepare
    ep_bootstrap,    # one-shot per-process: init NCCL comm + nvte_ep_initialize
    ep_dispatch,     # custom_vjp-wrapped dispatch 
    ep_combine,      # custom_vjp-wrapped combine

ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd — routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.

XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)

Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries — EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler — each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.

Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)

Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).

Sharding (transformer_engine/jax/sharding.py, +12 lines)

Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.

Build wiring (build_tools/jax.py, +41 lines)

Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.

Tests & example

  • tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
  • tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
  • examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR lands the JAX Expert Parallelism bindings: XLA FFI handlers over the nvte_ep_* C API, jax.custom_vjp-wrapped ep_dispatch/ep_combine with mesh-aware sharding rules, NCCL EP submodule build wiring, a 690-line multi-process test suite, and an end-to-end MoE example.

  • transformer_engine/jax/csrc/extensions/ep.cpp: Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries (EpPrepare, EpDispatch, EpCombine, and their *Bwd counterparts); all marked FFI_CudaGraph_Traits. EpDispatchFFI always wraps topk_weights as DType::kFloat32 without validating the actual element type — non-float32 inputs would be silently reinterpreted, corrupting routing.
  • transformer_engine/jax/ep.py: Bootstrap and custom_vjp wrappers. The ctypes fallback for NCCL UID generation hardcodes libnccl.so.2; environments that ship only libnccl.so would hit an unhandled OSError on root ranks, blocking non-root ranks in the subsequent process_allgather.
  • transformer_engine/jax/cpp_extensions/ep.py: Standard TE primitive plumbing with sharding partition rules and Shardy rules for all five primitives.

Confidence Score: 3/5

Two bugs affecting core bootstrap and routing correctness need fixes before this is safe to merge into a production branch.

The EpDispatchFFI handler silently reinterprets any non-float32 topk_weights buffer as float32, corrupting expert routing with no runtime error in mixed-precision training scenarios. Separately, the ctypes fallback in ep_bootstrap hardcodes libnccl.so.2, which raises an unhandled OSError on root ranks when that specific SONAME is absent, leaving non-root ranks hung in the all-gather barrier. Both issues affect paths exercised in real deployments but not caught by the existing test suite.

transformer_engine/jax/csrc/extensions/ep.cpp (missing topk_weights dtype guard in EpDispatchFFI) and transformer_engine/jax/ep.py (hardcoded libnccl.so.2 in bootstrap ctypes fallback).

Important Files Changed

Filename Overview
transformer_engine/jax/ep.py Public EP bootstrap + custom_vjp wrappers for dispatch/combine. Hardcoded libnccl.so.2 in ctypes fallback can raise unhandled OSError on root ranks, hanging non-root ranks in the process_allgather barrier.
transformer_engine/jax/csrc/extensions/ep.cpp XLA FFI handlers for EP primitives. EpDispatchFFI unconditionally wraps topk_weights as float32 without validating the input element type, enabling silent data corruption for non-float32 routing weights.
transformer_engine/jax/cpp_extensions/ep.py JAX primitive plumbing (abstract_eval, lowering, partition, shardy rules) for all five EP ops. The ep_prepare callsite-based handle_id cache (sys._getframe) is documented in previous review threads; primitive logic itself is correct.
transformer_engine/common/ep/ep_backend.cpp EPBackend singleton wrapping ncclEpDispatch/ncclEpCombine. Lifecycle uses a mutex-protected singleton with RAII cleanup; NCCL version and SM capability guards are present.
build_tools/jax.py Threads NCCL EP linkage into the JAX extension. Arch guard only raises when an explicit sub-90 arch is listed but not when NVTE_CUDA_ARCHS is unset, inconsistent with setup.py; flagged in prior review threads.
setup.py Adds NCCL EP submodule build logic with arch detection, auto-disable for pre-Hopper targets, and symlink mirroring of system NCCL headers into the submodule build tree.
transformer_engine/jax/sharding.py Adds ep_resource field to MeshResource and ep_axis_size() helper; minimal, correct additions.
transformer_engine/common/ep/ep_api.cpp Thin C API delegation layer to EPBackend; straightforward and consistent with the header contract.
tests/jax/test_multi_process_ep.py 690-line multi-process test covering bootstrap, shape contracts, identity routing, custom_vjp correctness, and HLO inspection. Tests use float32 weights exclusively, so the missing dtype-validation bug would not be caught here.

Sequence Diagram

sequenceDiagram
    participant User as User code
    participant EP as ep.py (custom_vjp)
    participant Prim as cpp_extensions/ep.py (Primitives)
    participant FFI as csrc/extensions/ep.cpp (XLA FFI)
    participant BE as ep_backend.cpp (EPBackend)
    participant NCCL as ncclEpDispatch/ncclEpCombine

    Note over User,NCCL: Bootstrap (once per process)
    User->>EP: ep_bootstrap(world_size, rank, ep_size, ...)
    EP->>FFI: set_ep_bootstrap_params(uid_bytes, ...)
    FFI->>NCCL: ncclCommInitRank (collective barrier)
    FFI->>BE: nvte_ep_initialize(comm, group_cfg)

    Note over User,NCCL: Forward pass
    User->>EP: ep_dispatch(topk_idx, tokens, topk_weights, ...)
    EP->>Prim: ep_prepare(topk_idx) to (token_counts, EpHandle)
    Prim->>FFI: EpPrepareFFI to nvte_ep_prepare
    FFI->>NCCL: ncclEpDispatch routing metadata
    EP->>Prim: ep_dispatch_fwd(handle, tokens, ...) to recv_tokens
    Prim->>FFI: EpDispatchFFI to nvte_ep_dispatch
    FFI->>NCCL: ncclEpDispatch token scatter

    User->>EP: ep_combine(handle, expert_out, ...) to result
    EP->>Prim: ep_combine_fwd(handle, weighted_expert_out)
    Prim->>FFI: EpCombineFFI to nvte_ep_combine
    FFI->>NCCL: ncclEpCombine token gather

    Note over User,NCCL: Backward pass (custom_vjp)
    EP->>Prim: ep_dispatch_bwd(handle, g_recv_tokens, ...)
    Prim->>FFI: EpDispatchBwdFFI to nvte_ep_dispatch_bwd
    FFI->>NCCL: ncclEpCombine (reverse scatter)

    EP->>Prim: ep_combine_bwd(handle, g_result, ...)
    Prim->>FFI: EpCombineBwdFFI to nvte_ep_combine_bwd
    FFI->>NCCL: ncclEpDispatch (reverse gather)
Loading

Comments Outside Diff (1)

  1. transformer_engine/jax/csrc/extensions/ep.cpp, line 1330-1331 (link)

    P1 topk_weights dtype not validated — silent data corruption for non-float32 inputs

    EpDispatchFFI unconditionally wraps topk_weights as DType::kFloat32 without checking the actual element type. If a caller passes bfloat16 or float16 weights (common in mixed-precision MoE training), the C++ layer reinterprets those bytes as float32, producing completely wrong routing weight values. There is no C++ NVTE_CHECK on the dtype and no Python-side guard either: the abstract eval explicitly deletes topk_weights_aval before any shape/dtype inspection can happen. The routing would be silently corrupted without any error or warning.

    Add a dtype check before constructing the wrapper, and mirror it in EpDispatchPrimitive.abstract with assert topk_weights_aval.dtype == jnp.float32.

Reviews (4): Last reviewed commit: "JAX EP: tie NCCL comm lifetime to JAX ex..." | Re-trigger Greptile

Comment thread build_tools/jax.py
Comment thread build_tools/jax.py
Comment thread transformer_engine/jax/cpp_extensions/ep.py
}

private:
EpCommManager() = default;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use stateful FFI calls we could tie to EP communicator to the lifetime of the jax computation rather than the process.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool to learn! I will update it.

Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
auto topk_dims = topk_idx.dimensions();
NVTE_CHECK(topk_dims.size() >= 2,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we return FFI InvalidArgument instead of a NVTE_CHECK for these inputs?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.

@phu0ngng phu0ngng requested a review from tdophung May 22, 2026 15:51
@phu0ngng
Copy link
Copy Markdown
Collaborator Author

I would appreciate your help to review this PR @tdophung @jberchtold-nvidia!
Please focus on the changes in the JAX side, as the TE/Common ones will be discussed in #3034

Comment thread examples/jax/ep/ep_moe.py Outdated
kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:])

@jax.jit
def step(idx, toks, w, lk):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does lk stand for?

Comment thread tests/jax/multi_process_launch_ep.sh Outdated
Comment thread transformer_engine/jax/cpp_extensions/ep.py Outdated
leading = _ep_leading_dims(is_outer)
recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype)
recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32)
workspace_aval = jax.core.ShapedArray(topk_idx_aval.shape, jnp.int64)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above about int64

Comment thread examples/jax/ep/ep_moe.py
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment thread transformer_engine/jax/ep.py Outdated
Comment thread transformer_engine/jax/ep.py Outdated
…s, MoE example)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment on lines +81 to +82
assert ret == 0, f"ncclGetUniqueId failed with code {ret}"
uid_bytes = bytes(uid_arr)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 assert disabled by -O in ctypes UID path

assert ret == 0 is silently elided when Python runs under the -O optimisation flag (common in production or Numba/Conda environments). If ncclGetUniqueId fails, uid_bytes would be all zeros; the all-gather propagates those zeros to every rank in the EP group, causing ncclCommInitRank to either produce mismatched communicators or hang indefinitely with no diagnostic message.

Suggested change
assert ret == 0, f"ncclGetUniqueId failed with code {ret}"
uid_bytes = bytes(uid_arr)
ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p))
if ret != 0:
raise RuntimeError(f"ncclGetUniqueId failed with code {ret}")

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants