[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API) #2443
base: main
Changes from 78 commits
@@ -5,19 +5,31 @@
"""config for collective_gemm tests"""
import pytest

import transformer_engine.jax  # noqa: F401 - must load libtransformer_engine.so before transformer_engine_jax
from transformer_engine_jax import nvte_built_with_cublasmp


def pytest_addoption(parser):
    """Pytest hook for collective_gemm tests"""
    parser.addoption("--coordinator-address", action="store", default="localhost:12345")
    parser.addoption("--num-processes", action="store", default=1)
    parser.addoption("--process-id", action="store", default=0)
    parser.addoption("--local-device-ids", action="store", default=None)
    parser.addoption("--use-cublasmp", action="store_true", default=False)


@pytest.fixture(autouse=True)
def distributed_args(request):
    """Fixture for querying distributed initialization arguments"""
    if request.cls:
        use_cublasmp = request.config.getoption("--use-cublasmp")
        if use_cublasmp and not nvte_built_with_cublasmp():
            pytest.skip(
                "Collective GEMM cuBLASMp backend tests require Transformer Engine to be built "
                "with NVTE_WITH_CUBLASMP=1."
            )
Comment on lines +25 to +30

Collaborator:
Are we going to build CUBLASMP in our CI images by default?

Collaborator (Author):
I would like to, but only starting with 26.05. We need NCCL 2.30+ for cuBLASMp to be graph-safe. I'm told JAX containers are going to satisfy that requirement starting with 26.05, so we can safely install cuBLASMp without breaking anything else.
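To make the "NCCL 2.30+" requirement from the reply concrete, here is a minimal sketch of a version gate. Nothing in it comes from the PR: `parse_version`, `MIN_GRAPH_SAFE_NCCL`, and `nccl_is_graph_safe_for_cublasmp` are hypothetical names, and the version string is assumed to be supplied by the caller rather than queried from NCCL.

```python
# Hypothetical helper, not part of Transformer Engine or NCCL.
# The (2, 30) threshold is taken from the review comment ("NCCL 2.30+").
MIN_GRAPH_SAFE_NCCL = (2, 30)


def parse_version(text):
    """Turn a dotted version string such as "2.30.1" into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def nccl_is_graph_safe_for_cublasmp(nccl_version):
    """Return True if the major.minor version meets the assumed 2.30 minimum."""
    # Compare integer tuples, not strings, so that "2.9" sorts below "2.30".
    return parse_version(nccl_version)[:2] >= MIN_GRAPH_SAFE_NCCL
```

Comparing integer tuples rather than raw strings matters here, because a plain string comparison would rank "2.9" above "2.30".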
        if use_cublasmp and "mxfp8" in request.node.name.lower():
            pytest.skip("MXFP8 is not supported by the cuBLASMp backend wrappers in TE/common.")
        request.cls.coordinator_address = request.config.getoption("--coordinator-address")
        request.cls.num_processes = int(request.config.getoption("--num-processes"))
        request.cls.process_id = int(request.config.getoption("--process-id"))
@@ -27,3 +39,4 @@ def distributed_args(request):
            if request.cls.local_device_ids is None
            else len(request.cls.local_device_ids.split(","))
        )
        request.cls.use_cublasmp = use_cublasmp
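
The fixture's gating logic can be read as two small pure functions, which may be easier to reason about than the fixture itself. This is only a sketch: `skip_reason` and `devices_per_process` are hypothetical names that do not appear in the PR, the build flag is passed in as a plain boolean instead of calling `nvte_built_with_cublasmp()`, and the skip messages are shortened.

```python
def skip_reason(use_cublasmp, built_with_cublasmp, test_name):
    """Return a skip message, or None if the test should run.

    Mirrors the order of checks in the fixture: the build check comes
    first, then the MXFP8 name check. Neither applies unless the
    --use-cublasmp flag was given.
    """
    if not use_cublasmp:
        return None
    if not built_with_cublasmp:
        return "requires Transformer Engine built with NVTE_WITH_CUBLASMP=1"
    if "mxfp8" in test_name.lower():
        return "MXFP8 is not supported by the cuBLASMp backend wrappers"
    return None


def devices_per_process(local_device_ids):
    """Count comma-separated device IDs; default to 1 when the option is unset."""
    if local_device_ids is None:
        return 1
    return len(local_device_ids.split(","))
```

In this sketch a test whose name contains "mxfp8" (in any letter case) is skipped only when `--use-cublasmp` is set, so the same test still runs against the default backend.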