[fix] Fix CUTLASS grouped GEMM segfault for empty groups #3037

Closed
Baibaifan wants to merge 4 commits into NVIDIA:main from Baibaifan:zero_groupgemm

@Baibaifan

Description

Handle grouped GEMM calls where all groups are empty.

MoE routing can legally produce a microbatch where no local expert receives
tokens. The PyTorch grouped GEMM wrapper filters those zero-token GEMMs, but
the CUTLASS grouped GEMM path could still be reached with num_gemms == 0 and
then dereference A[0]/B[0]/D[0], causing a native segfault.

Return early after filtering all GEMMs in te_general_grouped_gemm, and add a
defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.
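
The control flow described above can be modelled in a few lines of plain Python. This is only an illustrative sketch, not the actual C++ in the PR: `grouped_gemm`, `multi_tensor_gemm`, and the list-of-lists matrices are hypothetical stand-ins, the `guard` flag exists only to compare behaviour with and without the fix, and the `IndexError` plays the role of the native segfault.

```python
from typing import List, Optional

Matrix = List[List[float]]


def multi_tensor_gemm(a: List[Matrix], b: List[Matrix], guard: bool = True) -> Optional[int]:
    """Hypothetical stand-in for the backend entry point.

    Returns the inner dimension read from the first group, mimicking a
    backend that inspects A[0]/B[0] before launching any work.
    """
    num_gemms = len(a)
    if guard and num_gemms <= 0:
        return None  # defensive guard: nothing to do
    return len(a[0][0])  # without the guard: IndexError on an empty list


def grouped_gemm(a: List[Matrix], b: List[Matrix], guard: bool = True) -> Optional[int]:
    """Hypothetical wrapper: drop empty groups, then call the backend."""
    kept = [(x, y) for x, y in zip(a, b) if x and y]
    if guard and not kept:
        return None  # early return: every group was filtered out
    return multi_tensor_gemm([x for x, _ in kept], [y for _, y in kept], guard=guard)


# One group with zero tokens: the guarded path returns without touching a[0].
print(grouped_gemm([[]], [[]]))
```

Calling `grouped_gemm([[]], [[]], guard=False)` raises `IndexError` in this model, which is the analogue of the crash the PR describes.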

Add a Hopper/CUTLASS regression test covering all-empty grouped GEMM inputs for
TN, NN, and NT layouts.

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Changes

Please list the changes introduced in this PR:

  • Return early from te_general_grouped_gemm when all GEMMs were filtered.
  • Add a defensive num_gemms <= 0 guard in nvte_multi_tensor_gemm.
  • Add a Hopper-only regression test for all-empty grouped GEMM inputs under
    NVTE_USE_CUTLASS_GROUPED_GEMM=1.
  • Cover TN, NN, and NT layouts.

Testing

pytest -q tests/pytorch/test_numerics.py::test_grouped_gemm_cutlass_empty_groups -s

Checklist:

  • I have read and followed the [contributing guidelines]
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@Baibaifan Baibaifan requested a review from ksivaman as a code owner May 22, 2026 04:06
@github-actions github-actions Bot added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) May 22, 2026
Signed-off-by: yangfan.bai <yangfan.bai@shopee.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR fixes a native segfault that occurred when the CUTLASS grouped-GEMM path was reached with all groups empty: te_general_grouped_gemm filtered every GEMM but still called nvte_multi_tensor_gemm with empty vectors, which then dereferenced A[0]/B[0]/D[0].

  • gemm.cpp: Adds if (te_A_wrappers.empty()) return bias; immediately after the per-group filtering loop, short-circuiting before any downstream array access. The return value matches the normal function exit.
  • cublaslt_gemm.cu: Adds a defensive if (num_gemms <= 0) return; guard at the top of nvte_multi_tensor_gemm to protect any direct callers of the C API.
  • test_grouped_linear.py: Adds a Hopper-only parametrized test for TN, NN, and NT layouts with m_splits=[0]; the NT case meaningfully verifies that the pre-allocated wgrad buffer is zeroed on early return, while TN/NN serve as crash-only guards.
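
Why only NT yields a real postcondition can be seen from the output shapes. The sketch below assumes the usual linear-layer roles for the three layouts (TN for the forward pass, NN for the input gradient, NT for the weight gradient), which is consistent with the shapes quoted in this conversation; the helper name `output_shape` and the example sizes are hypothetical.

```python
from math import prod
from typing import Dict, Tuple


def output_shape(layout: str, m: int, n: int, k: int) -> Tuple[int, int]:
    """Output shape of one group for token count m and weight shape (n, k)."""
    shapes: Dict[str, Tuple[int, int]] = {
        "TN": (m, n),  # forward output: one row per token
        "NN": (m, k),  # input gradient: one row per token
        "NT": (n, k),  # weight gradient: independent of the token count
    }
    return shapes[layout]


# m = 0 models the all-empty case (m_splits=[0]); n and k are arbitrary.
for layout in ("TN", "NN", "NT"):
    shape = output_shape(layout, m=0, n=4, k=8)
    print(layout, shape, "numel =", prod(shape))
```

Under these assumptions the TN and NN outputs have zero elements when no tokens arrive, so only the NT output leaves anything to inspect after the early return.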

Confidence Score: 5/5

The change is safe to merge: both guards are minimal, idempotent, and sit on paths that were already broken (they only fire when there is nothing to do).

Both fixes are tightly scoped: the te_A_wrappers.empty() guard in gemm.cpp returns the same value as the normal function exit, and the num_gemms <= 0 guard in cublaslt_gemm.cu is purely defensive with no observable side effects on non-empty inputs. The NT regression test provides a real postcondition check; TN/NN act as crash guards. No existing behaviour is altered for the non-empty case.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/csrc/extensions/gemm.cpp | Adds an early return after the per-GEMM filtering loop when all wrappers are empty, preventing downstream dereference of empty vectors; return value bias matches the normal function exit. |
| transformer_engine/common/gemm/cublaslt_gemm.cu | Adds a defensive num_gemms <= 0 early-exit guard at the top of nvte_multi_tensor_gemm to block entry into the CUTLASS path with an empty array. |
| tests/pytorch/test_grouped_linear.py | Adds a Hopper-only regression test for all-empty grouped GEMM inputs; NT layout produces a meaningful zero-check on the pre-allocated output buffer, while TN/NN comparisons are trivially satisfied on zero-numel tensors (crash/segfault guards only). |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[general_grouped_gemm Python] --> B[te_general_grouped_gemm C++]
    B --> C{Loop over A/B groups}
    C -->|te_A.numel==0 or te_B.numel==0| D[zero out_tensor / bias / pre_gelu_out\ncontinue]
    D --> C
    C -->|all groups processed| E{te_A_wrappers.empty?\nNEW GUARD}
    E -->|yes - all groups filtered| F[return bias early\navoid segfault]
    E -->|no - some non-empty groups| G[swizzle scales\nbuild NVTETensor vectors]
    G --> H[nvte_multi_tensor_gemm]
    H --> I{num_gemms <= 0?\nNEW GUARD}
    I -->|yes| J[return early]
    I -->|no - Hopper + CUTLASS| K[CUTLASS grouped GEMM]
    I -->|no - other path| L[cuBLAS multi-stream GEMM]

Reviews (3): Last reviewed commit: "[fix] add empty_groups unit test."

Comment thread tests/pytorch/test_numerics.py Outdated
Comment on lines 2850 to 2855
    for tensor in out:
        torch.testing.assert_close(tensor, torch.zeros_like(tensor), rtol=0, atol=0)


def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
    data = grouped_tensor.rowwise_data
Contributor


P2 Zero-assertion is trivially true for TN and NN layouts

For TN and NN, out is constructed as a list containing a single 0-element tensor (torch.empty(0, n/k, ...)). torch.testing.assert_close on two empty tensors passes unconditionally regardless of any computation, so those two sub-cases only serve as crash/segfault guards. The meaningful assertion only fires for NT, where out[0] is a full (n, k) buffer that the C++ code zeros in-place. Consider either documenting this in a comment or, for TN/NN, adding a small non-empty output tensor and asserting it is zero to provide the same level of postcondition coverage as NT.
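
The vacuous-truth issue is not specific to torch.testing.assert_close: any element-wise check over zero elements passes. A plain-Python sketch, where `all_zero` is a hypothetical stand-in for the zero assertion and flat lists stand in for tensors:

```python
from typing import List


def all_zero(values: List[float]) -> bool:
    """Hypothetical stand-in for an element-wise 'equals zero' assertion."""
    return all(v == 0.0 for v in values)


empty_out: List[float] = []          # models a TN/NN output with 0 rows
stale_wgrad = [1.5, -2.0, 0.0, 3.0]  # models an NT buffer left unzeroed
zeroed_wgrad = [0.0] * 4             # models an NT buffer that was zeroed

assert all_zero(empty_out)           # passes no matter what the kernel did
assert not all_zero(stale_wgrad)     # a non-empty buffer can expose a bug
assert all_zero(zeroed_wgrad)
```

Only the checks on non-empty buffers can fail, which is why a small non-empty output would give TN and NN the same coverage as NT.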

ptrendx previously approved these changes May 29, 2026
@ptrendx
Member

ptrendx commented May 29, 2026

Hi @Baibaifan, could you resolve the conflicts?

@Baibaifan
Author

> Hi @Baibaifan, could you resolve the conflicts?

Hi @ptrendx, due to a code merging issue I have resubmitted this change as a new pull request: #3067.

