[JAX] Expert Parallelism: JAX primitives + VJPs #3036
Greptile Summary
This PR lands the JAX Expert Parallelism bindings: XLA FFI handlers over the
Confidence Score: 3/5
Two bugs affecting core bootstrap and routing correctness need fixes before this is safe to merge into a production branch. The
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User as User code
participant EP as ep.py (custom_vjp)
participant Prim as cpp_extensions/ep.py (Primitives)
participant FFI as csrc/extensions/ep.cpp (XLA FFI)
participant BE as ep_backend.cpp (EPBackend)
participant NCCL as ncclEpDispatch/ncclEpCombine
Note over User,NCCL: Bootstrap (once per process)
User->>EP: ep_bootstrap(world_size, rank, ep_size, ...)
EP->>FFI: set_ep_bootstrap_params(uid_bytes, ...)
FFI->>NCCL: ncclCommInitRank (collective barrier)
FFI->>BE: nvte_ep_initialize(comm, group_cfg)
Note over User,NCCL: Forward pass
User->>EP: ep_dispatch(topk_idx, tokens, topk_weights, ...)
EP->>Prim: ep_prepare(topk_idx) to (token_counts, EpHandle)
Prim->>FFI: EpPrepareFFI to nvte_ep_prepare
FFI->>NCCL: ncclEpDispatch routing metadata
EP->>Prim: ep_dispatch_fwd(handle, tokens, ...) to recv_tokens
Prim->>FFI: EpDispatchFFI to nvte_ep_dispatch
FFI->>NCCL: ncclEpDispatch token scatter
User->>EP: ep_combine(handle, expert_out, ...) to result
EP->>Prim: ep_combine_fwd(handle, weighted_expert_out)
Prim->>FFI: EpCombineFFI to nvte_ep_combine
FFI->>NCCL: ncclEpCombine token gather
Note over User,NCCL: Backward pass (custom_vjp)
EP->>Prim: ep_dispatch_bwd(handle, g_recv_tokens, ...)
Prim->>FFI: EpDispatchBwdFFI to nvte_ep_dispatch_bwd
FFI->>NCCL: ncclEpCombine (reverse scatter)
EP->>Prim: ep_combine_bwd(handle, g_result, ...)
Prim->>FFI: EpCombineBwdFFI to nvte_ep_combine_bwd
FFI->>NCCL: ncclEpDispatch (reverse gather)
  }

 private:
  EpCommManager() = default;
If we use stateful FFI calls, we could tie the EP communicator to the lifetime of the JAX computation rather than the process.
Cool to learn! I will update it.
Error_Type EpPrepareFFI(cudaStream_t stream, Buffer_Type topk_idx, Result_Type token_counts,
                        Result_Type handle_mem, Result_Type workspace, EpPrepareConfig config) {
  auto topk_dims = topk_idx.dimensions();
  NVTE_CHECK(topk_dims.size() >= 2,
nit: can we return an FFI InvalidArgument error instead of using NVTE_CHECK for these inputs?
This is probably a good idea. I suggest we make another follow-up MR to do so for all the FFIs.
I would appreciate your help to review this PR @tdophung @jberchtold-nvidia!
kernels = kernels.reshape(ep_size, NLE, *kernels.shape[1:])


@jax.jit
def step(idx, toks, w, lk):
What does lk stand for?
leading = _ep_leading_dims(is_outer)
recv_tokens_aval = jax.core.ShapedArray(leading + (recv_pr, hidden_dim), tok_dtype)
recv_topk_weights_aval = jax.core.ShapedArray(leading + (recv_pr,), jnp.float32)
workspace_aval = jax.core.ShapedArray(topk_idx_aval.shape, jnp.int64)
Same comment as above about int64
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…s, MoE example) Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
assert ret == 0, f"ncclGetUniqueId failed with code {ret}"
uid_bytes = bytes(uid_arr)
assert disabled by -O in ctypes UID path
assert ret == 0 is silently elided when Python runs under the -O optimisation flag (common in production or Numba/Conda environments). If ncclGetUniqueId fails, uid_bytes would be all zeros; the all-gather propagates those zeros to every rank in the EP group, causing ncclCommInitRank to either produce mismatched communicators or hang indefinitely with no diagnostic message.
assert ret == 0, f"ncclGetUniqueId failed with code {ret}"
uid_bytes = bytes(uid_arr)
ret = libnccl.ncclGetUniqueId(ctypes.cast(uid_arr, ctypes.c_void_p))
if ret != 0:
    raise RuntimeError(f"ncclGetUniqueId failed with code {ret}")
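The failure mode described in this comment can be reproduced without NCCL. The sketch below is illustrative only: `SNIPPET`, `runs_to_end`, and `check_status` are hypothetical names, and the hard-coded `ret = 1` stands in for a non-zero status returned by a C call. It compiles the same snippet at two optimization levels to show that the assert disappears at level 1 (the equivalent of `python -O`), while an explicit `raise` is unaffected.

```python
# Stand-in for the ctypes path: a failing status followed by an assert.
SNIPPET = '''
ret = 1  # pretend the C call returned a non-zero status
assert ret == 0, f"call failed with code {ret}"
reached_end = True
'''


def runs_to_end(optimize):
    """Run SNIPPET compiled at `optimize`; True means the failing assert was skipped."""
    namespace = {}
    # optimize=0 keeps asserts; optimize=1 strips them, like `python -O`.
    code = compile(SNIPPET, "<snippet>", "exec", optimize=optimize)
    try:
        exec(code, namespace)
    except AssertionError:
        return False
    return namespace.get("reached_end", False)


def check_status(ret):
    """Explicit check that survives every optimization level."""
    if ret != 0:
        raise RuntimeError(f"call failed with code {ret}")


asserts_kept = runs_to_end(0)      # the assert fires, so the snippet stops early
asserts_stripped = runs_to_end(1)  # the assert is gone, so the snippet runs on
```

With the assert stripped, execution continues past the failed call, which is the situation the comment warns about for `uid_bytes`.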
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Summary
Third PR in the TE Expert Parallelism (EP) series, built on top of #3034. Lands the JAX bindings: an XLA FFI layer over the nvte_ep_* C API, a Python wrapper with custom_vjp for autograd, mesh-aware sharding rules, a multi-process test suite, and an end-to-end MoE example. NCCL ncclEpDispatch / ncclEpCombine are exposed as XLA primitives and work with CUDA-graph capture.
Implementation
Public Python API (transformer_engine/jax/ep.py)
ep_dispatch / ep_combine are jax.custom_vjp functions: forward is the FFI primitive, backward calls the matching nvte_ep_*_bwd FFI primitive directly (no ep_prepare in the bwd; routing state is already cached in handle.mem). Note that ep_dispatch also calls ep_prepare in the forward path, which all-gathers and preprocesses routing maps.
XLA FFI bindings (transformer_engine/jax/csrc/extensions/ep.cpp)
Five XLA_FFI_DEFINE_HANDLER_SYMBOL entries (EpPrepareHandler, EpDispatchHandler, EpCombineHandler, EpDispatchBwdHandler, EpCombineBwdHandler), each calling the corresponding nvte_ep_* C entry point. All marked FFI_CudaGraph_Traits so they capture cleanly. handle_id is a static FFI attribute baked at jit trace time.
Primitives + Python layer (transformer_engine/jax/cpp_extensions/ep.py, +951 lines)
Standard TE primitive plumbing: abstract_eval (shape/dtype inference), lowering, impl, outer_primitive registration, and partitioning rules so the EP collective is treated as a single sharded op by XLA (no spurious resharding around it).
Sharding (transformer_engine/jax/sharding.py, +12 lines)
Adds the EP mesh axis to the global mesh resource set so downstream sharding rules can reference it.
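To make the custom_vjp structure described under Public Python API more concrete, here is a minimal single-process sketch. It is not the PR's code: `toy_dispatch` is a hypothetical stand-in that models dispatch as a row gather by a permutation `perm`, with no NCCL, FFI, or EpHandle involved. It only illustrates the pattern of saving routing state in the forward pass and reusing it in a hand-written backward pass.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def toy_dispatch(tokens, perm):
    """Hypothetical stand-in: gather token rows into the order given by perm."""
    return tokens[perm]


def _toy_dispatch_fwd(tokens, perm):
    # Save the routing (perm) as the residual so the backward pass can reuse
    # it, loosely analogous to the cached routing state mentioned above.
    return tokens[perm], perm


def _toy_dispatch_bwd(perm, g_recv):
    # The transpose of a gather is a scatter-add: each gradient row goes back
    # to the slot its token came from. This assumes perm is a permutation, so
    # the input and output shapes match. perm is an integer input, so its
    # cotangent is None.
    g_tokens = jnp.zeros_like(g_recv).at[perm].add(g_recv)
    return g_tokens, None


toy_dispatch.defvjp(_toy_dispatch_fwd, _toy_dispatch_bwd)

tokens = jnp.arange(8.0).reshape(4, 2)
perm = jnp.array([2, 0, 3, 1])
weights = jnp.array([[1.0], [2.0], [3.0], [4.0]])


def loss_custom(toks):
    return jnp.sum(weights * toy_dispatch(toks, perm))


def loss_reference(toks):
    # Same computation with JAX's built-in autodiff, used as a cross-check.
    return jnp.sum(weights * toks[perm])


g_custom = jax.grad(loss_custom)(tokens)
g_reference = jax.grad(loss_reference)(tokens)
```

Comparing `g_custom` against `g_reference` is the same kind of check the PR's tests describe for fwd+bwd correctness, here reduced to a toy case.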
Build wiring (build_tools/jax.py, +41 lines)
Threads NCCL EP linkage through the JAX transformer_engine_jax extension. No new top-level build flags; rides on the parent PR's NVTE_BUILD_WITH_NCCL_EP.
Tests & example
tests/jax/test_multi_process_ep.py (+690 lines): 13 tests covering bootstrap, ep_prepare shape/handle contracts, primitive-level dispatch/combine identity (uniform + skewed routing), custom_vjp fwd+bwd correctness, and HLO inspection (must not insert XLA collectives outside the EP FFI).
tests/jax/multi_process_launch_ep.sh: 4-rank launcher; sets XLA_FLAGS to keep XLA command-buffer capture off for the EP FFI sequence (NCCL EP graph-destroy interaction).
examples/jax/ep/ep_moe.py (+394 lines) + run_test_ep.sh: end-to-end MoE with EP, dp=ep=2 mesh, includes a ref-comparison --check that verifies fwd+bwd vs a single-process reference.
Type of change
Checklist: