[JAX] Fix L0_jax_unittest docs example test to enforce single-GPU#3059
Conversation
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Greptile SummaryThis PR fixes the
Confidence Score: 5/5Safe to merge — the change is a single-line, targeted fix that correctly scopes GPU visibility to GPU 0 for one pytest invocation without affecting any other commands in the script. The inline CUDA_VISIBLE_DEVICES=0 prefix is standard shell idiom that applies only to the immediately following command, leaving all other test invocations unaffected. The fix is minimal, correct, and directly addresses the reported issue of distributed tests running on multi-GPU CI runners. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[L0_jax_unittest/test.sh] --> B[Run JAX unit tests not distributed]
A --> C[Run fused attn determinism tests]
A --> D[Run mnist example tests]
A --> E[Run single-GPU encoder tests]
A --> F["CUDA_VISIBLE_DEVICES=0 python3 -m pytest docs/examples/jax/"]
F --> G{GPU count visible to process}
G -->|Always 1 GPU| H[Single-GPU docs tests pass]
G -.->|Before fix: N GPUs on multi-GPU runner| I[Distributed tests attempt to run]
H --> J{Any failures?}
I --> J
J -->|Yes| K[test_fail + exit 1]
J -->|No| L[All tests passed + exit 0]
Reviews (2): Last reviewed commit: "Merge branch 'main' into jberchtold/fix-..." | Re-trigger Greptile |
| # Exercise the docs/examples/jax tutorials. The multi-GPU tests are | ||
| # skipped at runtime when fewer than 4 devices are visible, so this is safe on | ||
| # single-GPU runners. | ||
| python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax" | ||
| CUDA_VISIBLE_DEVICES=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax" |
There was a problem hiding this comment.
The comment above still reflects the old rationale — that multi-GPU tests would self-skip when fewer than 4 devices are visible. Now that
CUDA_VISIBLE_DEVICES=0 actively enforces single-GPU visibility, that explanation is stale and slightly misleading. Updating it prevents future readers from removing the env var under the mistaken belief that runtime skipping is a sufficient guard.
| # Exercise the docs/examples/jax tutorials. The multi-GPU tests are | |
| # skipped at runtime when fewer than 4 devices are visible, so this is safe on | |
| # single-GPU runners. | |
| python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax" | |
| CUDA_VISIBLE_DEVICES=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax" | |
| # Exercise the docs/examples/jax tutorials. CUDA_VISIBLE_DEVICES=0 restricts | |
| # JAX to a single GPU so that distributed tests are not attempted on | |
| # multi-GPU runners. | |
| CUDA_VISIBLE_DEVICES=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax" |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
|
/te-ci |
…IDIA#3059) Update test.sh Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Co-authored-by: Teddy Do <tdophung@nvidia.com> Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
…IDIA#3059) Update test.sh Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Co-authored-by: Teddy Do <tdophung@nvidia.com> Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Description
The L0_jax_unittest invocation of our docs test_dense.py examples did not set
CUDA_VISIBLE_DEVICESso on a runner with >1 GPU it would still try to execute the distributed tests.Type of change
Changes
Please list the changes introduced in this PR:
Checklist: