Changes from 48 commits
81 commits
177b2ec
cuBlasMp backend logic added to TE/common with connections to framewo…
denera Dec 2, 2025
7d46b0b
added use_cublasmp flags to CollectiveGemm bootstrapping to avoid UB …
denera Dec 2, 2025
6d4a141
added cuBLASMp backend option to JAX unit tests for CollectiveGEMM
denera Dec 16, 2025
35d0f19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
dd8eaf3
added pytorch unit tests for comm+GEMM overlap with cuBLASMp backend
denera Dec 16, 2025
d79bf21
greptile fixes
denera Dec 17, 2025
ee517d3
linting
denera Dec 17, 2025
51b64fb
function argument call order fixes
denera Dec 17, 2025
9be771c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
4cec043
JAX collective GEMM modified to inherit cublasmp usage from global bo…
denera Jan 16, 2026
898cf30
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Jan 16, 2026
422a654
typos and style fixes
pre-commit-ci[bot] Jan 16, 2026
6e42235
documentation and build fixes
denera Jan 27, 2026
626dd1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2026
d44cfc4
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Mar 13, 2026
e341a8b
fixed default SM margin option and JAX cgemm test runner cleanup
denera Mar 13, 2026
6942d20
cublasmp running with TE/PyTorch
denera Mar 16, 2026
bef5c7e
cublasmp working with TE/JAX
denera Mar 16, 2026
81d6383
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Mar 16, 2026
6c6cc4d
cublasmp working with TE/JAX (JAX container is missing cuBLASMp insta…
denera Mar 16, 2026
9ed2adf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2026
ca913b9
added arch suffixes for CUBLASMP lib lookup in CMAKE
denera Mar 16, 2026
c55626d
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera Mar 16, 2026
f863ba8
fixed TE/JAX collective gemm test runner
denera Mar 16, 2026
5a8c7ae
TE/JAX CGEMM test runner script fix
denera Mar 17, 2026
5b9df92
fixed the cublasmp option in the pytest runners
denera Mar 17, 2026
775df95
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Mar 20, 2026
441472a
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Apr 7, 2026
3df11fc
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Apr 17, 2026
58f1e68
cuBLASMp passing tests with TE/PyTorch
denera Apr 21, 2026
f05f849
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Apr 21, 2026
f95f229
updated cuBLASMp C++ tests to also test local chunks instead of globa…
denera Apr 21, 2026
f84e8f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2026
c67c183
cuBLASmp C++ tests switched to NCCL comms for reference results, now …
denera Apr 22, 2026
e9c79a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2026
1b8fb1e
[JAX] Fix cuBLASMp collective GEMM tests and document XLA command buf…
denera Apr 24, 2026
caa741e
changed cuBLASMp call sizing to use flat first/last dims
denera May 1, 2026
9cca8a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 1, 2026
ff4187c
cuBLASMp backend passing tests with both PyT and JAX, CUDA graph comp…
denera May 11, 2026
218257f
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 12, 2026
c2af15b
fixed JAX cublasmp bootstrapping TP rank argument, fixed PyTorch Comm…
denera May 12, 2026
f75d98e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2026
c208d83
C++ tests restored to working order, TE/PyTorch layer failures diagno…
denera May 15, 2026
5bd8ff9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2026
509c12e
fixed linting issues, corrected Hopper/Blackwell FP8 GEMM layout hand…
denera May 15, 2026
a51bd3b
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 15, 2026
b0bbe6d
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 15, 2026
04c52ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2026
4ea7334
updated TE/JAX CollectiveGemm tests to use normal distributions with …
denera May 18, 2026
6d6c7b2
added cuda stream sync to CollectiveGemm XLA custom op prepare stage …
denera May 18, 2026
f4740ea
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 18, 2026
ee80f69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 18, 2026
85292f3
fixed TE/PyTorch cublasmp backend flag, warmup workspace now cleaned …
denera May 19, 2026
0cdbd6a
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 19, 2026
80b0a71
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 19, 2026
cc25997
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2026
0b4ecba
handling ncclComm_t via shared pointers to make sure they don't get p…
denera May 19, 2026
f753353
dummy warmup cuBLASMp GEMM buffers are locally allocated and destroye…
denera May 20, 2026
cf54c14
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2026
deb0890
fixed bulk-overlap fallback for cuBLASMP backend, all comm+GEMM overl…
denera May 22, 2026
f959f34
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 22, 2026
89f5d8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
67521f7
test skip condition when TE is NOT built with cuBLASMp
denera May 22, 2026
6d5ca20
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 22, 2026
8bcdaff
enforcing initialize_ub() call before module construction
denera May 22, 2026
e90498d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
a77c914
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 22, 2026
8d254af
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 22, 2026
a5c9117
fixed UB initializer flag
denera May 22, 2026
cd3ad03
cublasmp backend support in comm+GEMM overlap extended to fusible ops
denera May 27, 2026
5b413cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
4e54318
added non-multicast algo fallback for cuBLASMp
denera May 27, 2026
d29895b
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 27, 2026
15c44e4
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 27, 2026
cf07453
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
32402f7
disabling fused attention in TE/PyTorch comm+GEMM layers test to avoi…
denera May 27, 2026
0421e2b
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 27, 2026
24169fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
0206400
removed requirement for PyTorch PG to be on the NCCL backend when boo…
denera May 29, 2026
1f2710b
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 29, 2026
94e90b9
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 29, 2026
10 changes: 10 additions & 0 deletions build_tools/pytorch.py
@@ -6,6 +6,7 @@

import os
from pathlib import Path
from importlib import metadata

import setuptools

@@ -88,6 +89,15 @@ def setup_pytorch_extension(
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")

if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))):
# Creating a cuBlasMp context requires direct access to the underlying NCCL
# communicator in a tensor-parallel process group. The header for ProcessGroupNCCL
# needs this CPP directive to be included properly.
cxx_flags.append("-DNVTE_WITH_CUBLASMP")
torch_lib_path = metadata.distribution("torch").locate_file("torch/lib")
library_dirs.append(torch_lib_path)
libraries.append("torch_cuda")
Member

Is this NCCL communicator literally stored by cuBLASMp or is it copied somehow? I'm worried about the case where the process group gets destroyed and the NCCL communicator that was used in it also gets destroyed from underneath us.

Collaborator Author

Actually we don't use NCCL comm from PyTorch anymore. We create our own NCCL comm that spans the same devices as the PyTorch PG with the same TP ranks, so there's no risk of the NCCL comm disappearing if the PyT PG gets destroyed.

The comment you quoted here is left over from an earlier iteration, when we did try to extract the NCCL comm from the PyT PG. I will remove it to avoid confusion. Because of this change, I believe we may also no longer need to link directly against libtorch_cuda.so; I will double-check that and remove it as well.


# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
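
The `NVTE_WITH_CUBLASMP` switch in the hunk above is read with `bool(int(os.getenv(...)))`. A minimal sketch of that parsing follows; the helper name `env_flag` is hypothetical and not part of the build script.

```python
import os


def env_flag(name: str, default: int = 0) -> bool:
    """Hypothetical helper mirroring bool(int(os.getenv(name, default)))."""
    # os.getenv returns a string when the variable is set, else the default.
    return bool(int(os.getenv(name, default)))


os.environ.pop("NVTE_WITH_CUBLASMP", None)
unset_value = env_flag("NVTE_WITH_CUBLASMP")  # unset: falls back to 0

os.environ["NVTE_WITH_CUBLASMP"] = "1"
enabled_value = env_flag("NVTE_WITH_CUBLASMP")

os.environ["NVTE_WITH_CUBLASMP"] = "0"
disabled_value = env_flag("NVTE_WITH_CUBLASMP")

# Leave the environment as we found it.
os.environ.pop("NVTE_WITH_CUBLASMP", None)
```

With this style of parsing, only integer strings are accepted: a value such as `true` would raise `ValueError` in `int()` rather than enable the feature.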
42 changes: 40 additions & 2 deletions examples/jax/collective_gemm/common.py
@@ -4,11 +4,14 @@
"""Shared functions for the collective GEMM tests"""

import argparse
import glob
import os

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils
from jax.experimental.multihost_utils import sync_global_devices
Comment thread
ptrendx marked this conversation as resolved.

from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap

@@ -56,9 +59,9 @@ def assert_allclose(actual, desired, rtol=None, atol=None, dtype=None, **kwargs)
tols["atol"] = atol

 if not isinstance(actual, float):
-    actual = actual.astype(jnp.float32)
+    actual = np.asarray(actual, dtype=np.float32)
 if not isinstance(desired, float):
-    desired = desired.astype(jnp.float32)
+    desired = np.asarray(desired, dtype=np.float32)

np.testing.assert_allclose(actual, desired, **tols, **kwargs)

@@ -96,6 +99,20 @@ def _initialize_distributed(args):

assert args.num_devices_per_process == 1, "Only single process single GPU is supported!"

# cuBLASMp issues NCCL collectives on its own communication stream
# inside the GEMM custom call. Add COLLECTIVES so XLA captures those
# ops alongside the custom call instead of invalidating the capture.
# Lower the min-graph-size to 1 so single-matmul modules also get
# captured -- otherwise small test cases skip the captured path.
# Userbuffers does not need either flag.
if args.use_cublasmp:
xla_flags = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (
xla_flags
+ " --xla_gpu_enable_command_buffer=+COLLECTIVES"
Collaborator

Here, we need to make sure that CUSTOM_CALL is captured by default.
Would it be better to add a check for that?

Collaborator Author

CUSTOM_CALL is already part of the defaults for command buffers, so I don't believe we need to add it ourselves here.

The issue is that cuBLASMp has NCCL calls in it for device-initiated communication, and XLA exempts NCCL collectives from graph capture, so when NCCL collectives appear inside a CUSTOM_CALL chunk, we end up with a CUDA_ERROR_STREAM_CAPTURE_INVALIDATED error.

Adding COLLECTIVES to the command buffers is the only way I'm aware of that gets around this issue. I don't know of any XLA API or flag that selectively enables command buffers for NCCL collectives only inside a CUSTOM_CALL.

Collaborator

But then don't you need to mark this custom op with "collective"?
Something like this tensorflow/tensorflow@104721c

Collaborator Author

@denera, May 19, 2026

Commit message for this says:

Currently all FFIs are treated as opaque kernels which are further treated as compute kernels in LHS.

...

It introduces a annotation to specific compute_on="gpu_stream:collective", then it can be scheduled in the same way as native collectives and stream assignment will give it the collective stream which is of high priority and wont overlap with other collectives.

This doesn't sound like what we want for CollectiveGemm. It is not a custom collective. It's a compute op that has some internal communication (strictly P2P in TE/JAX, not collective) with specific order-of-operations dependencies on the compute. We want XLA to treat it opaquely, and we want it to be invoked on the low-priority compute stream that's used for the GEMM chunks, while the internal high-priority collective streams in Userbuffers or cuBLASMp are used for the overlapped communication (XLA does not know about these internal streams and doesn't need to).

+ " --xla_gpu_graph_min_graph_size=1"
)
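
The flag handling above can be sketched as a small helper. The two flags are the ones from the diff; the helper name `append_xla_flags` and the placeholder flag are hypothetical. Unlike the diff, which appends unconditionally, this sketch skips flags that are already present, on the assumption that the setup code may run more than once in a process.

```python
# Flags taken from the diff above; everything else here is illustrative.
CUBLASMP_XLA_FLAGS = (
    "--xla_gpu_enable_command_buffer=+COLLECTIVES",
    "--xla_gpu_graph_min_graph_size=1",
)


def append_xla_flags(env, extra_flags):
    """Append flags to env["XLA_FLAGS"], skipping ones already present."""
    current = env.get("XLA_FLAGS", "").split()
    for flag in extra_flags:
        if flag not in current:
            current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]


# A plain dict stands in for os.environ so the sketch has no side effects.
env = {"XLA_FLAGS": "--placeholder_flag=1"}  # placeholder for pre-existing flags
append_xla_flags(env, CUBLASMP_XLA_FLAGS)
append_xla_flags(env, CUBLASMP_XLA_FLAGS)  # second call changes nothing
```

In real use the mapping would be `os.environ`, and the call would presumably need to happen before the XLA backend is initialized, as it does in `_initialize_distributed`.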

print(
f"Initializing JAX distributed with coordinator={args.coordinator_address}, "
f"num_processes={args.num_processes}, process_id={args.process_id}"
@@ -118,6 +135,20 @@ def _initialize_distributed(args):
devices_per_process = 1
num_total_devices = args.num_processes

# Remove stale NCCL unique ID files from previous (possibly crashed) runs.
# These files are used for one-time coordination during bootstrap; stale files
# cause non-leader processes to read an old unique ID, breaking NCCL init.
# Only process 0 performs the cleanup; a global barrier ensures all processes
# wait for the cleanup to complete before any TP leader writes a fresh file.
nccl_base_path = os.environ.get("NVTE_JAX_NCCL_FILE_PATH", "/tmp")
if args.process_id == 0:
for f in glob.glob(os.path.join(nccl_base_path, "nccl_*_unique_id_*.bin")):
try:
os.remove(f)
except OSError:
pass
sync_global_devices("nccl_id_cleanup")
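
The cleanup step above can be exercised in isolation. This is a single-process sketch: the function name `remove_stale_nccl_ids` and the file names are made up for illustration, and the `sync_global_devices` barrier is omitted because it only makes sense in a multi-process JAX run.

```python
import glob
import os
import tempfile


def remove_stale_nccl_ids(base_path: str, process_id: int) -> list:
    """Delete leftover unique-ID files; only process 0 does the work."""
    removed = []
    if process_id != 0:
        return removed
    for path in glob.glob(os.path.join(base_path, "nccl_*_unique_id_*.bin")):
        try:
            os.remove(path)
            removed.append(os.path.basename(path))
        except OSError:
            # The file may already be gone; ignore and continue.
            pass
    return sorted(removed)


with tempfile.TemporaryDirectory() as tmp:
    # File names are invented to match the glob pattern used in the diff.
    for name in ("nccl_tp_unique_id_0.bin", "nccl_tp_unique_id_1.bin", "keep.txt"):
        open(os.path.join(tmp, name), "wb").close()
    skipped = remove_stale_nccl_ids(tmp, process_id=1)
    removed = remove_stale_nccl_ids(tmp, process_id=0)
    remaining = sorted(os.listdir(tmp))
```

Only files matching the pattern are touched, and non-zero ranks return without deleting anything, which is why the real code needs the barrier before any TP leader writes a fresh file.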

print(
f"Initializing CGEMM communicator with num_total_devices={num_total_devices},"
f" devices_per_process={devices_per_process}, process_id={args.process_id}"
@@ -128,6 +159,7 @@ def _initialize_distributed(args):
num_devices_per_process=devices_per_process,
process_id=args.process_id,
tensor_parallel_size=args.tensor_parallel_size,
use_cublasmp=args.use_cublasmp,
)


@@ -224,5 +256,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para
parser.add_argument(
"--enable-result-check", action="store_true", default=True, help="Enable result checking"
)
parser.add_argument(
"--use-cublasmp",
action="store_true",
default=False,
help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation",
)

return parser
11 changes: 11 additions & 0 deletions examples/jax/collective_gemm/conftest.py
@@ -5,19 +5,29 @@
"""config for collective_gemm tests"""
import pytest

import transformer_engine.jax # noqa: F401 - must load libtransformer_engine.so before transformer_engine_jax
from transformer_engine_jax import nvte_built_with_cublasmp


def pytest_addoption(parser):
"""Pytest hook for collective_gemm tests"""
parser.addoption("--coordinator-address", action="store", default="localhost:12345")
parser.addoption("--num-processes", action="store", default=1)
parser.addoption("--process-id", action="store", default=0)
parser.addoption("--local-device-ids", action="store", default=None)
parser.addoption("--use-cublasmp", action="store_true", default=False)


@pytest.fixture(autouse=True)
def distributed_args(request):
"""Fixture for querying distributed initialization arguments"""
if request.cls:
use_cublasmp = request.config.getoption("--use-cublasmp")
if use_cublasmp and not nvte_built_with_cublasmp():
pytest.skip(
"Collective GEMM cuBLASMp backend tests require Transformer Engine to be built "
"with NVTE_WITH_CUBLASMP=1."
)
Comment on lines +25 to +30
Collaborator

Are we going to build CUBLASMP in our CI images by default?

Collaborator Author

I would like to, but only starting with 26.05.

We need NCCL 2.30+ for cuBLASMp to be graph-safe. I'm told JAX containers are going to satisfy that requirement starting with 26.05 so we can safely install cuBLASMp without breaking anything else.

request.cls.coordinator_address = request.config.getoption("--coordinator-address")
request.cls.num_processes = int(request.config.getoption("--num-processes"))
request.cls.process_id = int(request.config.getoption("--process-id"))
@@ -27,3 +37,4 @@ def distributed_args(request):
if request.cls.local_device_ids is None
else len(request.cls.local_device_ids.split(","))
)
request.cls.use_cublasmp = use_cublasmp
118 changes: 77 additions & 41 deletions examples/jax/collective_gemm/run_test_cgemm.sh
@@ -23,6 +23,30 @@ else
echo "NVLINK support detected"
fi

echo "*** Checking cuBLASMp support in TE build ***"
CUBLASMP_SUPPORT=$(python3 - <<'PY'
try:
import transformer_engine.jax
from transformer_engine_jax import nvte_built_with_cublasmp
except Exception as exc:
print(f"error:{exc}")
raise SystemExit(0)

print("1" if nvte_built_with_cublasmp() else "0")
PY
)

if [[ "$CUBLASMP_SUPPORT" == "1" ]]; then
echo "cuBLASMp backend support detected"
BACKENDS=("cublasmp" "userbuffers")
elif [[ "$CUBLASMP_SUPPORT" == "0" ]]; then
echo "cuBLASMp backend support not detected; skipping cuBLASMp backend tests"
BACKENDS=("userbuffers")
else
echo "Failed to query cuBLASMp support from transformer_engine_jax: $CUBLASMP_SUPPORT"
exit 1
fi
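
The three-way branch above (supported, not supported, probe failed) can be restated as a small function. This is a sketch only; `select_backends` is a hypothetical name, and the error string in the example is invented.

```python
def select_backends(probe_output: str) -> list:
    """Map the probe's stdout to the list of backends to test."""
    result = probe_output.strip()
    if result == "1":
        return ["cublasmp", "userbuffers"]
    if result == "0":
        return ["userbuffers"]
    # Anything else (for example "error:<message>") means the probe itself failed.
    raise RuntimeError(f"Failed to query cuBLASMp support: {result}")


with_support = select_backends("1\n")
without_support = select_backends("0\n")
try:
    select_backends("error:example import failure")
    probe_failed = False
except RuntimeError:
    probe_failed = True
```

Treating an unexpected probe result as a hard error, rather than silently falling back to Userbuffers, keeps a broken TE install from being reported as a clean run.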

# Define individual test cases to run (file::class::method)
# DelayedScalingFP8 and CurrentScalingFP8 use the same GEMM so we don't need to test both cases all
# the time.
@@ -93,50 +117,62 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
# Clear PIDs array for this test case
PIDS=()

-for i in $(seq 0 $(($NUM_GPUS - 1))); do
-  # Define output file for logs
-  LOG_FILE="${TEST_NAME}_gpu_${i}.log"
-
-  if [ $i -eq 0 ]; then
-    # For process 0: show live output AND save to log file using tee
-    echo "=== Live output from process 0 ==="
-    pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-      -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \
-      "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
-      --num-processes=$NUM_GPUS \
-      --process-id=$i 2>&1 | tee "$LOG_FILE" &
-    PID=$!
-    PIDS+=($PID)
+for BACKEND in "${BACKENDS[@]}"; do
+  echo "Setting backend to $BACKEND for test $TEST_NAME"
+
+  for i in $(seq 0 $(($NUM_GPUS - 1))); do
+    # Define output file for logs
+    LOG_FILE="${TEST_NAME}_gpu_${i}_${BACKEND}.log"
+
+    test_args=(
+      "--num-processes=$NUM_GPUS"
+      "--process-id=$i"
+    )
+    if [ "$BACKEND" == "cublasmp" ]; then
+      test_args+=("--use-cublasmp")
+    fi
+
+    if [ $i -eq 0 ]; then
+      # For process 0: show live output AND save to log file using tee
+      echo "=== Live output from process 0 ==="
+      pytest -s -c "${TE_PATH}/tests/jax/pytest.ini" -vs \
+        "--junitxml=${XML_LOG_DIR}/${TEST_NAME}_gpu_${i}_${BACKEND}.xml" \
+        "${TE_PATH}/examples/jax/collective_gemm/${TEST_CASE}" \
+        "${test_args[@]}" 2>&1 | tee "$LOG_FILE" &
+      PID=$!
+      PIDS+=($PID)
+    else
+      # For other processes: redirect to log files only
+      pytest -s -c "${TE_PATH}/tests/jax/pytest.ini" -vs \
+        "${TE_PATH}/examples/jax/collective_gemm/${TEST_CASE}" \
+        "${test_args[@]}" > "$LOG_FILE" 2>&1 &
+      PID=$!
+      PIDS+=($PID)
+    fi
+  done
+
+  # Wait for all processes to finish
+  wait
+
+  # Check and print the log content from process 0
+  if grep -q "SKIPPED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then
+    echo "... $TEST_CASE SKIPPED"
+  elif grep -q "FAILED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then
+    echo "... $TEST_CASE FAILED"
+    HAS_FAILURE=1
+  elif grep -q "PASSED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then
+    echo "... $TEST_CASE PASSED"
   else
-    # For other processes: redirect to log files only
-    pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-      -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
-      --num-processes=$NUM_GPUS \
-      --process-id=$i > "$LOG_FILE" 2>&1 &
-    PID=$!
-    PIDS+=($PID)
+    echo "... $TEST_CASE INVALID"
+    HAS_FAILURE=1
   fi
-done
-
-# Wait for all processes to finish
-wait
-
-# Check and print the log content from process 0
-if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then
-  echo "... $TEST_CASE SKIPPED"
-elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then
-  echo "... $TEST_CASE FAILED"
-  HAS_FAILURE=1
-elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then
-  echo "... $TEST_CASE PASSED"
-else
-  echo "... $TEST_CASE INVALID"
-  HAS_FAILURE=1
-fi
 
-# Remove the log files after processing them
-wait
-rm ${TEST_NAME}_gpu_*.log
+
+  # Remove the log files after processing them
+  wait
+  rm ${TEST_NAME}_gpu_*_${BACKEND}.log
+
+done
done

wait
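
Two pieces of the runner loop above are easy to check in isolation: how the per-process pytest arguments are assembled for each backend, and how the process-0 log is classified. The sketch below restates both; the function names are hypothetical.

```python
def build_test_args(num_gpus: int, process_id: int, backend: str) -> list:
    """Per-process pytest arguments, as assembled in the shell loop."""
    args = [f"--num-processes={num_gpus}", f"--process-id={process_id}"]
    if backend == "cublasmp":
        args.append("--use-cublasmp")
    return args


def classify_log(log_text: str) -> str:
    """Same precedence as the grep chain: SKIPPED, then FAILED, then PASSED."""
    for marker in ("SKIPPED", "FAILED", "PASSED"):
        if marker in log_text:
            return marker
    return "INVALID"


cublasmp_args = build_test_args(4, 2, "cublasmp")
userbuffers_args = build_test_args(4, 2, "userbuffers")
mixed_status = classify_log("1 FAILED, 1 SKIPPED")
empty_status = classify_log("")
```

One consequence of this ordering is that a log containing both SKIPPED and FAILED is reported as SKIPPED, and a log with none of the three markers counts as a failure (INVALID).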
1 change: 1 addition & 0 deletions examples/jax/collective_gemm/test_dense_grad.py
@@ -183,6 +183,7 @@ def setUp(self):
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.use_cublasmp = self.use_cublasmp
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
37 changes: 23 additions & 14 deletions examples/jax/collective_gemm/test_gemm.py
@@ -16,6 +16,7 @@
import os
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec, NamedSharding
@@ -151,20 +152,27 @@ def run_gemm_tests(args, mesh=None):
jax.block_until_ready(gathered_output)

if args.enable_result_check and args.process_id == 0:
-    # CGEMM + RS + BF16 uses TE's reduce_bf16 kernel (sequential left-to-right in FP32).
-    # With catastrophic cancellation the output is near zero while the absolute diff can
-    # reach 1 ULP of the partial GEMM magnitude (~0.0625 for typical transformer
-    # activations at O(8) scale), which exceeds the previous atol=1e-5. The 2x
-    # margin (0.125) covers this worst-case 1-ULP absolute difference.
-    is_cgemm_rs_bf16 = collective_op == CollectiveOp.REDUCE_SCATTER and not use_quantization
-    rtol = 1e-2 if is_cgemm_rs_bf16 else None
-    atol = 0.125 if is_cgemm_rs_bf16 else None
-    assert_allclose(
-        gathered_ref_output,
-        gathered_output,
-        dtype=get_tolerance_dtype(quantizer_set),
-        rtol=rtol,
-        atol=atol,
+    if use_quantization:
+        # FP8 quantization noise on near-zero outputs can exceed the rtol
+        # gate; allow a small absolute tolerance.
+        rtol, atol = 0.125, 0.625
+    else:
+        rtol, atol = 0.02, 0.002
+    # Use NumPy (not JAX) for the result check to avoid triggering new XLA compilations
+    # on process 0 only, which would deadlock in multi-process JAX because XLA compilation
+    # of distributed arrays requires collective synchronization across all processes.
+    actual = np.asarray(gathered_output, dtype=np.float32)
+    desired = np.asarray(gathered_ref_output, dtype=np.float32)
+    diff = np.abs(actual - desired)
+    abs_desired = np.abs(desired)
+    failures = (diff > atol) & (diff > rtol * abs_desired)
+    num_failures = int(np.sum(failures))
+    assert num_failures == 0, (
+        f"NUMERICAL CHECK FAILED: {num_failures}/{diff.size} elements "
+        f"({100 * num_failures / diff.size:.4f}%) exceed tolerances "
+        f"(rtol={rtol}, atol={atol}). "
+        f"Max abs error: {float(np.max(diff)):.6f}, "
+        f"max rel error: {float(np.max(diff / np.maximum(abs_desired, 1e-5))):.6f}"
)
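
The new check counts an element as a failure only when its error exceeds both the absolute and the relative bound, so an element passes if it is within either one. That is not the same rule as `np.testing.assert_allclose`, which compares against the sum `atol + rtol * |desired|`; the either-bound rule is never looser than the additive one. A pure-Python restatement of the criterion, with a hypothetical function name and made-up sample values, might look like this:

```python
def count_tolerance_failures(actual, desired, rtol, atol):
    """Count elements whose error exceeds BOTH the absolute and relative bound."""
    failures = 0
    for a, d in zip(actual, desired):
        diff = abs(a - d)
        if diff > atol and diff > rtol * abs(d):
            failures += 1
    return failures


# Non-quantized tolerances from the diff: rtol=0.02, atol=0.002.
desired = [0.0, 100.0, 1.0]
actual = [0.001, 101.0, 1.5]
num_failures = count_tolerance_failures(actual, desired, rtol=0.02, atol=0.002)
```

Here the first element passes on the absolute bound (error 0.001), the second passes on the relative bound (error 1.0 against an allowance of 2.0), and only the third (error 0.5) fails both.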


@@ -180,6 +188,7 @@ def setUp(self):
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.use_cublasmp = self.use_cublasmp
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)
1 change: 1 addition & 0 deletions examples/jax/collective_gemm/test_layernorm_mlp_grad.py
@@ -249,6 +249,7 @@ def setUp(self):
self.args.process_id = self.process_id
self.args.local_device_ids = self.local_device_ids
self.args.num_devices_per_process = self.num_devices_per_process
self.args.use_cublasmp = self.use_cublasmp
self.args.enable_data_parallel = True
self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1]
_initialize_distributed(self.args)