Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
177b2ec
cuBlasMp backend logic added to TE/common with connections to framewo…
denera Dec 2, 2025
7d46b0b
added use_cublasmp flags to CollectiveGemm bootstrapping to avoid UB …
denera Dec 2, 2025
6d4a141
added cuBLASMp backend option to JAX unit tests for CollectiveGEMM
denera Dec 16, 2025
35d0f19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
dd8eaf3
added pytorch unit tests for comm+GEMM overlap with cuBLASMp backend
denera Dec 16, 2025
d79bf21
greptile fixes
denera Dec 17, 2025
ee517d3
linting
denera Dec 17, 2025
51b64fb
function argument call order fixes
denera Dec 17, 2025
9be771c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
4cec043
JAX collective GEMM modified to inherit cublasmp usage from global bo…
denera Jan 16, 2026
898cf30
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Jan 16, 2026
422a654
typos and style fixes
pre-commit-ci[bot] Jan 16, 2026
6e42235
documentation and build fixes
denera Jan 27, 2026
626dd1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 27, 2026
d44cfc4
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Mar 13, 2026
e341a8b
fixed default SM margin option and JAX cgemm test runner cleanup
denera Mar 13, 2026
6942d20
cublasmp running with TE/PyTorch
denera Mar 16, 2026
bef5c7e
cublasmp working with TE/JAX
denera Mar 16, 2026
81d6383
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Mar 16, 2026
6c6cc4d
cublasmp working with TE/JAX (JAX container is missing cuBLASMp insta…
denera Mar 16, 2026
9ed2adf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 16, 2026
ca913b9
added arch suffixes for CUBLASMP lib lookup in CMAKE
denera Mar 16, 2026
c55626d
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera Mar 16, 2026
f863ba8
fixed TE/JAX collective gemm test runner
denera Mar 16, 2026
5a8c7ae
TE/JAX CGEMM test runner script fix
denera Mar 17, 2026
5b9df92
fixed the cublasmp option in the pytest runners
denera Mar 17, 2026
775df95
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Mar 20, 2026
441472a
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Apr 7, 2026
3df11fc
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Apr 17, 2026
58f1e68
cuBLASMp passing tests with TE/PyTorch
denera Apr 21, 2026
f05f849
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera Apr 21, 2026
f95f229
updated cuBLASMp C++ tests to also test local chunks instead of globa…
denera Apr 21, 2026
f84e8f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 21, 2026
c67c183
cuBLASmp C++ tests switched to NCCL comms for reference results, now …
denera Apr 22, 2026
e9c79a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2026
1b8fb1e
[JAX] Fix cuBLASMp collective GEMM tests and document XLA command buf…
denera Apr 24, 2026
caa741e
changed cuBLASMp call sizing to use flat first/last dims
denera May 1, 2026
9cca8a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 1, 2026
ff4187c
cuBLASMp backend passing tests with both PyT and JAX, CUDA graph comp…
denera May 11, 2026
218257f
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 12, 2026
c2af15b
fixed JAX cublasmp bootstrapping TP rank argument, fixed PyTorch Comm…
denera May 12, 2026
f75d98e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 12, 2026
c208d83
C++ tests restored to working order, TE/PyTorch layer failures diagno…
denera May 15, 2026
5bd8ff9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2026
509c12e
fixed linting issues, corrected Hopper/Blackwell FP8 GEMM layout hand…
denera May 15, 2026
a51bd3b
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 15, 2026
b0bbe6d
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 15, 2026
04c52ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 15, 2026
4ea7334
updated TE/JAX CollectiveGemm tests to use normal distributions with …
denera May 18, 2026
6d6c7b2
added cuda stream sync to CollectiveGemm XLA custom op prepare stage …
denera May 18, 2026
f4740ea
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 18, 2026
ee80f69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 18, 2026
85292f3
fixed TE/PyTorch cublasmp backend flag, warmup workspace now cleaned …
denera May 19, 2026
0cdbd6a
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 19, 2026
80b0a71
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 19, 2026
cc25997
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2026
0b4ecba
handling ncclComm_t via shared pointers to make sure they don't get p…
denera May 19, 2026
f753353
dummy warmup cuBLASMp GEMM buffers are locally allocated and destroye…
denera May 20, 2026
cf54c14
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2026
deb0890
fixed bulk-overlap fallback for cuBLASMP backend, all comm+GEMM overl…
denera May 22, 2026
f959f34
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 22, 2026
89f5d8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
67521f7
test skip condition when TE is NOT built with cuBLASMp
denera May 22, 2026
6d5ca20
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 22, 2026
8bcdaff
enforcing initialize_ub() call before module construction
denera May 22, 2026
e90498d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 22, 2026
a77c914
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 22, 2026
8d254af
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 22, 2026
a5c9117
fixed UB initializer flag
denera May 22, 2026
cd3ad03
cublasmp backend support in comm+GEMM overlap extended to fusible ops
denera May 27, 2026
5b413cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
4e54318
added non-multicast algo fallback for cuBLASMp
denera May 27, 2026
d29895b
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 27, 2026
15c44e4
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 27, 2026
cf07453
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
32402f7
disabling fused attention in TE/PyTorch comm+GEMM layers test to avoi…
denera May 27, 2026
0421e2b
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 27, 2026
24169fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2026
0206400
removed requirement for PyTorch PG to be on the NCCL backend when boo…
denera May 29, 2026
1f2710b
Merge branch 'common/tp-overlap-cublasmp' of github.com:denera/Transf…
denera May 29, 2026
94e90b9
Merge remote-tracking branch 'upstream/main' into common/tp-overlap-c…
denera May 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions build_tools/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""PyTorch related extensions."""
import os
from pathlib import Path
from importlib import metadata

import setuptools

Expand Down Expand Up @@ -87,6 +88,15 @@ def setup_pytorch_extension(
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")

if bool(int(os.getenv("NVTE_WITH_CUBLASMP", 0))):
# Creating a cuBlasMp context requires direct access to the underlying NCCL
# communicator in a tensor-parallel process group. The header for ProcessGroupNCCL
# needs this CPP directive to be included properly.
cxx_flags.append("-DNVTE_WITH_CUBLASMP")
torch_lib_path = metadata.distribution("torch").locate_file("torch/lib")
library_dirs.append(torch_lib_path)
libraries.append("torch_cuda")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this NCCL communicator literally stored by cuBLASMp or is it copied somehow? I'm worried about the case where the process group gets destroyed and the NCCL communicator that was used in it also gets destroyed from underneath us.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually we don't use NCCL comm from PyTorch anymore. We create our own NCCL comm that spans the same devices as the PyTorch PG with the same TP ranks, so there's no risk of the NCCL comm disappearing if the PyT PG gets destroyed.

The comment you quoted here is leftover from an earlier iteration when we did try to extract the NCCL comm from PyT PG. I will remove it to avoid confusion, and I believe we might also no longer need to directly link to the libtorch_cuda.so anymore either because of this. I will double-check that and remove it as well.


# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
Expand Down
7 changes: 7 additions & 0 deletions examples/jax/collective_gemm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def _initialize_distributed(args):
num_devices_per_process=devices_per_process,
process_id=args.process_id,
tensor_parallel_size=args.tensor_parallel_size,
use_cublasmp=args.use_cublasmp,
)


Expand Down Expand Up @@ -224,5 +225,11 @@ def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor para
parser.add_argument(
"--enable-result-check", action="store_true", default=True, help="Enable result checking"
)
parser.add_argument(
"--use-cublasmp",
action="store_true",
default=False,
help="Use the cuBLASMp backend for overlapping collective operations with GEMM computation",
)

return parser
101 changes: 60 additions & 41 deletions examples/jax/collective_gemm/run_test_cgemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -93,50 +93,69 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
# Clear PIDs array for this test case
PIDS=()

for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_NAME}_gpu_${i}.log"

if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_NAME}.xml \
"$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
--num-processes=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
BACKENDS=("userbuffers", "cublasmp")
Comment thread
denera marked this conversation as resolved.
Outdated
for BACKEND in "${BACKENDS[@]}"; do
echo "Setting backend to $BACKEND for test $TEST_NAME"

for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_NAME}_gpu_${i}_${BACKEND}.log"

test_case_args=(

"--num-processes=$NUM_GPUS"
"--process-id=$i"
)
if [ "$BACKEND" == "cublasmp" ]; then
pytest_args+=("--use-cublasmp")
fi

pytest_args=(
"-s"
"-c $TE_PATH/tests/jax/pytest.ini"
"-vs"
)
Comment thread
denera marked this conversation as resolved.
Outdated
if [ $i -eq 0 ]; then
# For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ==="
pytest_args+=("--junitxml=${XML_LOG_DIR}/${TEST_NAME}_gpu_${i}_${BACKEND}.xml")
pytest "${pytest_args[@]}" \
"$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
"${test_case_args[@]}" 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
# For other processes: redirect to log files only
pytest "${pytest_args[@]}" \
"$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
"${test_case_args[@]}" > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
fi
done

# Wait for all processes to finish
wait

# Check and print the log content from process 0
if grep -q "SKIPPED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "FAILED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then
echo "... $TEST_CASE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_NAME}_gpu_0_${BACKEND}.log"; then
echo "... $TEST_CASE PASSED"
else
# For other processes: redirect to log files only
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_CASE" \
--num-processes=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
echo "... $TEST_CASE INVALID"
HAS_FAILURE=1
fi
done

# Wait for all processes to finish
wait

# Check and print the log content from process 0
if grep -q "SKIPPED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "FAILED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_NAME}_gpu_0.log"; then
echo "... $TEST_CASE PASSED"
else
echo "... $TEST_CASE INVALID"
HAS_FAILURE=1
fi

# Remove the log files after processing them
wait
rm ${TEST_NAME}_gpu_*.log

# Remove the log files after processing them
wait
rm ${TEST_NAME}_gpu_*_${BACKEND}.log

done
done

wait
Expand Down
124 changes: 84 additions & 40 deletions tests/pytorch/distributed/run_gemm_with_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ def _parse_args(argv=None, namespace=None):
parser.add_argument(
"--use-cuda-graphs", action="store_true", default=False, help="Use CUDA graphs."
)
parser.add_argument(
"--use-cublasmp", action="store_true", default=False, help="Use cuBLASMp backend."
)
parser.add_argument(
"-v", "--verbose", action="store_true", default=False, help="Verbose info messages."
)
Expand Down Expand Up @@ -203,6 +206,7 @@ def _main(opts):
capture_output=True,
text=True,
shell=True,
check=False,
)

if result.stdout == "0": # Extra checks for non-MNNVL platforms
Expand Down Expand Up @@ -306,7 +310,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
helper = (
tex.CommOverlapHelper()
if tex.ubuf_built_with_mpi()
else tex.CommOverlapHelper(bootstrap_pg)
else tex.CommOverlapHelper(bootstrap_pg, tp_group)
)

# Initialize userbuffers with (M, N) buffer
Expand All @@ -323,47 +327,75 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
):
buffer_dtype = torch.uint8
ub_obj = (
tex.CommOverlapP2P(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
opts.comm_type,
set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
atomic_gemm=opts.atomic,
aggregate=opts.aggregate,
use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
(
tex.CommOverlapP2P(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
opts.comm_type,
set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
atomic_gemm=opts.atomic,
aggregate=opts.aggregate,
use_ce=not (opts.atomic and bool(int(os.getenv("NVTE_AG_P2P_MULTI_ATOMIC", "0")))),
)
if not opts.use_cublasmp
else tex.CommOverlapP2P(
helper,
tp_rank,
tp_size,
num_comm_sm=3,
atomic_gemm=opts.atomic,
)
)
if opts.p2p
else tex.CommOverlap(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=opts.atomic,
else (
tex.CommOverlap(
(outer_size, hidden_size),
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=opts.atomic,
)
if not opts.use_cublasmp
else tex.CommOverlap(
helper,
tp_rank,
tp_size,
num_comm_sm=16,
atomic_gemm=opts.atomic,
)
)
)

# Numerical check on AG + atomic GEMM requires testing an AG+RS pair
ub_obj2 = None
if opts.atomic and opts.comm_type == tex.CommOverlapType.AG and opts.check_numerics:
ub_obj2 = (
tex.CommOverlapP2P(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
tex.CommOverlapType.RS,
set_sm_margin=True,
atomic_gemm=True,
(
tex.CommOverlapP2P(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
tex.CommOverlapType.RS,
set_sm_margin=True,
atomic_gemm=True,
)
if not opts.use_cublasmp
else tex.CommOverlapP2P(helper, tp_rank, tp_size, num_comm_sm=16, atomic_gemm=True)
)
if opts.atomic_rs_p2p
else tex.CommOverlap(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=True,
else (
tex.CommOverlap(
(outer_size, hidden_size),
torch.uint8 if opts.fp8_output else torch.bfloat16,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
atomic_gemm=True,
)
if not opts.use_cublasmp
else tex.CommOverlap(helper, tp_rank, tp_size, num_comm_sm=3, atomic_gemm=True)
)
)

Expand Down Expand Up @@ -408,7 +440,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
mean=0.0,
std=opts.std,
)
if ub_obj2 is not None:
if opts.comm_type == tex.CommOverlapType.AG and ub_obj2 is not None:
kernel2_t = torch.nn.init.normal_(
torch.empty(local_kernel2_t_shape, dtype=torch.bfloat16, device="cuda"),
mean=0.0,
Expand All @@ -429,22 +461,22 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
# AG Kernel: (K/P, N) -> gather -> (K, N) -> T -> (N, K)
ker_g = torch.transpose(
te.distributed.gather_along_first_dim(kernel_t, tp_group)[0], 0, 1
).to(dtype=torch.float32)
)
# AG Input: (M/P, N) -> gather -> (M, N)
inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0].to(dtype=torch.float32)
inp_g = te.distributed.gather_along_first_dim(inp, tp_group)[0]
if ub_obj2 is not None:
ker2_g = te.distributed.gather_along_first_dim(
torch.transpose(kernel2_t, 0, 1), tp_group
)[0].to(dtype=torch.float32)
)[0]
else:
# RS Kernel: (N, K/P) -> T -> (K/P, N) -> gather -> (K, N)
ker_g = te.distributed.gather_along_first_dim(
torch.transpose(kernel_t, 0, 1), tp_group
)[0].to(dtype=torch.float32)
)[0]
# RS Input: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K)
inp_g = torch.transpose(
te.distributed.gather_along_first_dim(torch.transpose(inp, 0, 1), tp_group)[0], 0, 1
).to(dtype=torch.float32)
)

if opts.bulk_overlap:
if opts.comm_type == tex.CommOverlapType.AG:
Expand All @@ -456,10 +488,20 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
# Sum the list together for final global result
ref_g = torch.stack(bulk_inp_list).sum(dim=0)
else:
ref_g = torch.matmul(inp_g, ker_g)
ref_g, *_ = tex.general_gemm(
torch.transpose(ker_g, 0, 1),
inp_g,
out_dtype=torch.bfloat16,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
)
if ub_obj2 is not None:
inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable
ref2_g = torch.matmul(inp2_g, ker2_g)
ref2_g = tex.general_gemm(
torch.transpose(ker2_g),
inp2_g,
out_dtype=torch.bfloat16,
use_split_accumulator=te.module.base._2X_ACC_FPROP,
)
Comment thread
denera marked this conversation as resolved.
Outdated

# Initialize quantizers
with_quantized_compute = opts.quantization != "none"
Expand Down Expand Up @@ -580,14 +622,16 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None
tp_group,
)
gemm_inp = inp
else:
elif not opts.use_cublasmp:
ag_out, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
inp_fp8 if with_quantized_compute else inp,
inp_quantizer,
tp_group,
)
gemm_inp = ag_out
else:
gemm_inp = inp_fp8 if with_quantized_compute else inp
if ub_obj2 is not None:
rs_out2 = torch.empty(
(outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
Expand Down
7 changes: 7 additions & 0 deletions tests/pytorch/distributed/run_layer_with_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ def _parse_args(argv=None, namespace=None):
default=0,
help="Number of layers at the end to run in bf16.",
)
parser.add_argument(
"--use-cublasmp",
action="store_true",
default=False,
help="Use cuBLASMp backend.",
)
args = parser.parse_args(argv, namespace)

if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
Expand Down Expand Up @@ -436,6 +442,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg,
with_cublasmp=opts.use_cublasmp,
)

with te.quantized_model_init(enabled=opts.fp8_init):
Expand Down
Loading
Loading