Skip to content

[JAX] Fix L0_jax_unittest docs example test to enforce single-GPU#3059

Merged
jberchtold-nvidia merged 2 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/fix-l0-docs-test
May 29, 2026
Merged

[JAX] Fix L0_jax_unittest docs example test to enforce single-GPU#3059
jberchtold-nvidia merged 2 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/fix-l0-docs-test

Conversation

@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator

Description

The L0_jax_unittest invocation of our docs test_dense.py examples did not set CUDA_VISIBLE_DEVICES so on a runner with >1 GPU it would still try to execute the distributed tests.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Set CUDA_VISIBLE_DEVICES=0 to enforce single-GPU tests for docs examples in L0_jax_unittest

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR fixes the L0_jax_unittest CI script so that the docs/examples/jax pytest invocation always runs with only a single GPU visible, preventing distributed tests from executing on multi-GPU runners.

  • Adds CUDA_VISIBLE_DEVICES=0 as an inline environment variable prefix to the python3 -m pytest … docs/examples/jax/ command, scoping GPU visibility to GPU 0 for that invocation only without affecting any surrounding commands.

Confidence Score: 5/5

Safe to merge — the change is a single-line, targeted fix that correctly scopes GPU visibility to GPU 0 for one pytest invocation without affecting any other commands in the script.

The inline CUDA_VISIBLE_DEVICES=0 prefix is standard shell idiom that applies only to the immediately following command, leaving all other test invocations unaffected. The fix is minimal, correct, and directly addresses the reported issue of distributed tests running on multi-GPU CI runners.

No files require special attention.

Important Files Changed

Filename Overview
qa/L0_jax_unittest/test.sh One-line fix: adds CUDA_VISIBLE_DEVICES=0 inline prefix to the docs/examples/jax pytest invocation so distributed tests cannot run on multi-GPU CI runners.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[L0_jax_unittest/test.sh] --> B[Run JAX unit tests not distributed]
    A --> C[Run fused attn determinism tests]
    A --> D[Run mnist example tests]
    A --> E[Run single-GPU encoder tests]
    A --> F["CUDA_VISIBLE_DEVICES=0 python3 -m pytest docs/examples/jax/"]
    F --> G{GPU count visible to process}
    G -->|Always 1 GPU| H[Single-GPU docs tests pass]
    G -.->|Before fix: N GPUs on multi-GPU runner| I[Distributed tests attempt to run]
    H --> J{Any failures?}
    I --> J
    J -->|Yes| K[test_fail + exit 1]
    J -->|No| L[All tests passed + exit 0]
Loading

Reviews (2): Last reviewed commit: "Merge branch 'main' into jberchtold/fix-..." | Re-trigger Greptile

Comment on lines 45 to +48
# Exercise the docs/examples/jax tutorials. The multi-GPU tests are
# skipped at runtime when fewer than 4 devices are visible, so this is safe on
# single-GPU runners.
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax"
CUDA_VISIBLE_DEVICES=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The comment above still reflects the old rationale — that multi-GPU tests would self-skip when fewer than 4 devices are visible. Now that CUDA_VISIBLE_DEVICES=0 actively enforces single-GPU visibility, that explanation is stale and slightly misleading. Updating it prevents future readers from removing the env var under the mistaken belief that runtime skipping is a sufficient guard.

Suggested change
# Exercise the docs/examples/jax tutorials. The multi-GPU tests are
# skipped at runtime when fewer than 4 devices are visible, so this is safe on
# single-GPU runners.
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax"
CUDA_VISIBLE_DEVICES=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax"
# Exercise the docs/examples/jax tutorials. CUDA_VISIBLE_DEVICES=0 restricts
# JAX to a single GPU so that distributed tests are not attempted on
# multi-GPU runners.
CUDA_VISIBLE_DEVICES=0 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_docs_examples_jax.xml $TE_PATH/docs/examples/jax/ || test_fail "docs/examples/jax"

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator Author

/te-ci

@jberchtold-nvidia jberchtold-nvidia merged commit af5d1e0 into NVIDIA:main May 29, 2026
12 of 13 checks passed
@jberchtold-nvidia jberchtold-nvidia deleted the jberchtold/fix-l0-docs-test branch May 29, 2026 17:55
KshitijLakhani pushed a commit that referenced this pull request May 29, 2026
)

Update test.sh

Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Baibaifan pushed a commit to Baibaifan/TransformerEngine that referenced this pull request Jun 1, 2026
…IDIA#3059)

Update test.sh

Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Baibaifan pushed a commit to Baibaifan/TransformerEngine that referenced this pull request Jun 1, 2026
…IDIA#3059)

Update test.sh

Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants