Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "3rdparty/nccl"]
path = 3rdparty/nccl
url = https://github.com/NVIDIA/nccl.git
branch = v2.30u1
1 change: 1 addition & 0 deletions 3rdparty/nccl
Submodule nccl added at 6a9bc9
41 changes: 39 additions & 2 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,50 @@ def setup_jax_extension(

setup_mpi_flags(include_dirs, cxx_flags)

# NCCL EP is on by default. Set NVTE_BUILD_WITH_NCCL_EP=0 to skip it.
build_with_nccl_ep = bool(int(os.getenv("NVTE_BUILD_WITH_NCCL_EP", "1")))
libraries = []
submod_lib_dir = None
submod_nccl_inc = None
if build_with_nccl_ep:
cxx_flags.append("-DNVTE_WITH_NCCL_EP")
# Headers + libs come from the in-tree 3rdparty/nccl submodule build
# (auto-produced by setup.py).
libraries = ["nccl", "nccl_ep"]
# NCCL EP requires SM>=90 (Hopper+).
archs_env = os.getenv("NVTE_CUDA_ARCHS", "")
for a in archs_env.split(";"):
a_num = "".join(c for c in a if c.isdigit())
if a_num and int(a_num) < 90:
raise RuntimeError(
f"NCCL EP requires CUDA arch >= 90 (Hopper or newer); got '{a}' in"
" NVTE_CUDA_ARCHS."
)
Comment thread
phu0ngng marked this conversation as resolved.
submod_root = (common_header_files / ".." / "3rdparty" / "nccl").resolve()
submod_ep_inc = submod_root / "contrib" / "nccl_ep" / "include"
if not (submod_ep_inc / "nccl_ep.h").exists():
raise RuntimeError(
f"NCCL EP header not found at {submod_ep_inc}/nccl_ep.h. "
"Run `git submodule update --init --recursive` to checkout 3rdparty/nccl."
)
Comment thread
phu0ngng marked this conversation as resolved.
include_dirs.append(submod_ep_inc)
submod_lib_dir = submod_root / "build" / "lib"
submod_nccl_inc = submod_root / "build" / "include"

# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension

return Pybind11Extension(
ext = Pybind11Extension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args=cxx_flags,
libraries=["nccl"],
libraries=libraries,
)
if submod_lib_dir is not None:
ext.library_dirs.append(str(submod_lib_dir))
ext.runtime_library_dirs.append(str(submod_lib_dir))
# Prefer submodule's nccl.h when present (matches the C++ side).
if (submod_nccl_inc / "nccl.h").exists():
ext.include_dirs.insert(0, str(submod_nccl_inc))
return ext
Loading
Loading