Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,8 @@ def test_sanity_grouped_linear(
# Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527.
bs = bs * 16
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)

if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM") == "0" and single_param:
Comment thread
vthumbe1503 marked this conversation as resolved.
Outdated
pytest.skip("single parameter grouped linear requires NVTE_GROUPED_LINEAR_SINGLE_PARAM=1")
skip_unsupported_backward_override("grouped_linear", fp8_recipe, backward_override)
if fp8_recipe is not None:
fp8_recipe = copy.deepcopy(fp8_recipe)
Expand Down
Loading