[PyTorch] Make modules.GroupedLinear graph-safe #3038

Merged
timmoon10 merged 5 commits into NVIDIA:main from yaox12:xiny/enable-grouped-quantize-cublaslt
May 29, 2026
Conversation

@yaox12
Member

@yaox12 yaox12 commented May 22, 2026

Description

  • Enable grouped quantization and cuBLASLt grouped GEMM in modules.GroupedLinear, for cases where the cuteDSL fused grouped GEMM is not available. The goals are to:

    1. Reduce CPU overhead by launching fewer kernels.
    2. Be CUDA-Graph-safe.
    3. Improve kernel performance.
  • Move the grouped GEMM and grouped linear tests to a standalone file.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps Bot commented May 22, 2026

Greptile Summary

This PR introduces a new cuBLASLt grouped GEMM path for GroupedLinear (gated by NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1) that enables CUDA-graph-safe execution on SM100+ hardware by operating on GroupedTensor metadata rather than splitting Python lists at runtime. It also extracts grouped-linear tests into a standalone file.

  • Adds _is_grouped_tensor_path_supported, _forward_grouped_tensor, and _backward_grouped_tensor static methods to _GroupedLinear; the new forward path uses group_quantize / bgrad_group_quantize and general_grouped_gemm_for_grouped_tensor to reduce per-kernel CPU overhead.
  • Refactors GroupedLinear.forward to normalize m_splits from list or tensor and dispatch to the new path when supported; moves test cases from test_numerics.py into a new test_grouped_linear.py and registers it in qa/L0_pytorch_unittest/test.sh.
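
To make "operating on GroupedTensor metadata" more concrete, here is a minimal, illustrative sketch of how per-group row ranges can be derived from split sizes with a prefix sum. The helper name `group_row_ranges` is hypothetical, and plain Python lists stand in for tensors. This is not Transformer Engine's implementation; it only shows the idea that a grouped kernel can work from sizes and offsets rather than from pre-split inputs.

```python
from itertools import accumulate


def group_row_ranges(m_splits):
    """Return (start, end) row ranges per group, given per-group row counts.

    Hypothetical helper for illustration; `end` is exclusive.
    """
    if any(m < 0 for m in m_splits):
        raise ValueError("split sizes must be non-negative")
    ends = list(accumulate(m_splits))   # inclusive prefix sum of the sizes
    starts = [0] + ends[:-1]            # each group starts where the previous ended
    return list(zip(starts, ends))


# An empty group (size 0) simply yields an empty range.
ranges = group_row_ranges([2, 0, 3])
print(ranges)
```

Because the ranges depend only on the split sizes, the same computation can in principle stay on the device when the sizes already live there, which is what makes a metadata-based path attractive for graph capture.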

Confidence Score: 4/5

Merge with caution: the new grouped-tensor path silently loses its CUDA-graph-safety guarantee when callers pass m_splits as a list rather than a CUDA tensor.

When the caller passes m_splits as a list, the new _forward_grouped_tensor path moves it from CPU to the compute device (m_splits.to(device=device)). This CPU→GPU transfer cannot be captured in a CUDA Graph, and graph-safe execution is the primary goal of this PR. The transfer happens silently, with no warning or validation, so any caller using the backward-compatible list API (including all current accuracy tests) will not get graph-safe behavior even with the env var enabled.

transformer_engine/pytorch/module/grouped_linear.py — specifically the GroupedLinear.forward normalization block (lines 1688-1697) and _forward_grouped_tensor (line 241).

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/grouped_linear.py | Core implementation: adds grouped-tensor forward/backward paths with new static helpers; overall logic is sound but the CUDA-graph guarantee is broken when m_splits is passed as a Python list rather than a pre-placed CUDA tensor. |
| tests/pytorch/test_grouped_linear.py | New dedicated test file covering accuracy, padding, grouped GEMM layouts, and CUTLASS path; all GroupedLinear accuracy tests pass m_splits as a list, so the grouped-tensor (graph-safe) code path is never exercised under the default env-var setting. |
| benchmarks/linear/benchmark_grouped_linear.py | Adds env-var-gated conversion of m_splits to a CUDA tensor before the benchmark loop; straightforward and correct. |
| qa/L0_pytorch_unittest/test.sh | Adds test_grouped_linear.py to the CI pipeline with appropriate JIT/compile flags; no issues. |
| tests/pytorch/test_numerics.py | Grouped-linear tests moved out; only minor removals, no new logic introduced. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["GroupedLinear.forward(inp, m_splits)"] --> B["Normalize m_splits\n(list → CPU tensor)"]
    B --> C{_is_grouped_tensor_path_supported?}
    C -- "Yes\n(SM100+, MXFP8/BF16/FP16,\nenv var set)" --> D["_forward_grouped_tensor\n(GroupedTensor path)"]
    C -- No --> E["Legacy path\ntex.split_quantize / general_grouped_gemm"]
    D --> D1["tex.group_quantize\n(or _make_grouped_tensor)"]
    D1 --> D2["general_grouped_gemm_for_grouped_tensor\n(TN layout — fprop)"]
    D2 --> D3["Return out\n(CUDA-graph-safe iff m_splits is CUDA tensor)"]
    E --> E1["tex.split_quantize\n(or torch.split)"]
    E1 --> E2["general_grouped_gemm\n(per-group GEMM)"]

    subgraph Backward grouped tensor path
        B1["_backward_grouped_tensor"] --> B2["dgrad: general_grouped_gemm_for_grouped_tensor NN"]
        B1 --> B3["bias grad: compute_grouped_dbias\nor bgrad_group_quantize"]
        B1 --> B4{delay_wgrad_compute?}
        B4 -- Yes --> B5["wgrad_store.put\ngrouped_x grouped_dy wgrad_list"]
        B5 --> B6["backward_dw called later\n grouped_gemm_wgrad NT"]
        B4 -- No --> B7["grouped_gemm_wgrad NT\ngeneral_grouped_gemm_for_grouped_tensor"]
    end

Reviews (6): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread tests/pytorch/test_grouped_linear.py
Comment thread tests/pytorch/test_grouped_linear.py
Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment thread benchmarks/linear/benchmark_grouped_linear.py
Comment thread tests/pytorch/test_grouped_linear.py Outdated
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/enable-grouped-quantize-cublaslt branch from d176247 to 698383e May 25, 2026 03:56
@yaox12
Member Author

yaox12 commented May 25, 2026

/te-ci pytorch

Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12
Member Author

yaox12 commented May 26, 2026

/te-ci pytorch

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Member

In the future, we should consider moving the grouped MLP tests from test_fusible_ops.py into this file.

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
timmoon10 and others added 3 commits May 29, 2026 02:24
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Handle tensor splits in both legacy and graph-safe impls. Create weight grad tensors as subviews of a larger buffer.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
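
The commit message above mentions creating weight grad tensors as subviews of a larger buffer. As a rough, stdlib-only illustration of that idea (not the code in this PR), the sketch below allocates one flat single-precision buffer and hands out one non-copying view per group. The function name `allocate_grouped_wgrads` and the shapes are hypothetical; in PyTorch the analogous pattern would be allocating a single tensor and indexing or slicing it, which also returns views that share storage.

```python
from array import array


def allocate_grouped_wgrads(num_gemms, out_features, in_features):
    """Allocate one flat buffer and return it with one view per group.

    Hypothetical helper: array/memoryview stand in for a tensor and its subviews.
    """
    per_group = out_features * in_features
    flat = array("f", [0.0]) * (num_gemms * per_group)   # single contiguous allocation
    whole = memoryview(flat)
    # Slicing a memoryview does not copy; each slice aliases part of `flat`.
    views = [whole[g * per_group:(g + 1) * per_group] for g in range(num_gemms)]
    return flat, views


flat, views = allocate_grouped_wgrads(num_gemms=3, out_features=2, in_features=4)
views[1][0] = 3.5      # write through the second group's view
print(flat[8])         # same element, read back from the shared buffer
```

Keeping all groups in one allocation means a grouped kernel can address every weight gradient from a single base pointer plus offsets, while callers still see one object per group.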
Comment on lines +1665 to +1666
m_splits : torch.Tensor
    Split sizes for the input tensor.
Member

@timmoon10 timmoon10 May 29, 2026

Going into the future, I think we should prefer passing in the splits as a torch.Tensor, even if they are on CPU. This makes the API more consistent. We need to support lists of ints for backward compatibility, but it should be considered deprecated.

Member

@timmoon10 timmoon10 left a comment

LGTM, pending CI.

As a followup, we should consolidate the implementation so that we can reuse the same code in te.ops.GroupedLinear and te.GroupedLinear.

@timmoon10
Member

/te-ci pytorch

Comment on lines +1688 to 1697
  if not isinstance(m_splits, torch.Tensor):
      # Convert list of ints to tensor for backward compatibility
      m_splits = torch.tensor(m_splits, dtype=torch.int64, device="cpu")
  elif m_splits.dtype != torch.int64:
      m_splits = m_splits.to(dtype=torch.int64)
  if m_splits.size() != (num_gemms,):
      raise ValueError(
-         f"Number of splits ({len(m_splits)}) should match number of"
-         f" GEMMs ({self.num_gemms})."
+         f"Shape of splits tensor ({tuple(m_splits.size())}) "
+         f"does not match number of GEMMs ({num_gemms})."
      )
Contributor

P1 CPU→GPU sync silently breaks CUDA-graph capture on the new path

When NVTE_GROUPED_LINEAR_USE_FUSED_GROUPED_GEMM=1 is active and m_splits is supplied as a Python list (or as a CPU tensor), the normalization at line 1690 creates a CPU tensor. Inside _forward_grouped_tensor that CPU tensor is moved to the compute device via split_sizes = m_splits.to(device=device) (line 241). This CPU→GPU transfer cannot be captured by CUDA Graph, silently defeating the PR's primary graph-safety guarantee for any caller that still passes a list.

Since the legacy path (which correctly handles lists) is only an m_splits.tolist() call away, a simple guard here would prevent the silent misuse: raise (or warn) when _is_grouped_tensor_path_supported() would return True but m_splits is not already on the compute device.
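
A guard of the kind suggested here might look roughly like the sketch below. Everything in it is hypothetical: the names `splits_device_type` and `guard_grouped_path`, the `strict` flag, and the simplification of checking only the device type rather than the exact compute device. It uses duck typing on a `.device.type` attribute (which `torch.Tensor` provides) so that it runs without PyTorch; a `SimpleNamespace` stands in for a CUDA tensor in the demo.

```python
import warnings
from types import SimpleNamespace


def splits_device_type(m_splits):
    """Return 'cuda' or 'cpu' for tensor-like inputs, or None for plain lists."""
    device = getattr(m_splits, "device", None)
    return getattr(device, "type", None)


def guard_grouped_path(m_splits, grouped_path_supported, strict=False):
    """Flag inputs that would force a host-to-device copy on the grouped path.

    Returns True when a copy would be needed. Warns in that case,
    or raises ValueError if `strict` is set.
    """
    needs_copy = grouped_path_supported and splits_device_type(m_splits) != "cuda"
    if needs_copy:
        message = (
            "m_splits is not on the compute device; the grouped-tensor path "
            "will copy it host-to-device and cannot be CUDA-graph captured."
        )
        if strict:
            raise ValueError(message)
        warnings.warn(message, RuntimeWarning, stacklevel=2)
    return needs_copy


# Demo with stand-ins: a Python list versus a fake CUDA tensor.
fake_cuda_splits = SimpleNamespace(device=SimpleNamespace(type="cuda"))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    flagged_list = guard_grouped_path([128, 256], grouped_path_supported=True)
flagged_cuda = guard_grouped_path(fake_cuda_splits, grouped_path_supported=True)
print(flagged_list, flagged_cuda, len(caught))
```

Whether such a guard should warn, raise, or do nothing is a policy choice for the maintainers; the sketch only shows where the check would sit.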

Member

If the splits are passed in as a list or on CPU, there's no hope of CUDA Graph capture anyways. A blocking H2D memcpy is the best we can do. Warning is also excessive, since CPU splits are perfectly valid (if suboptimal) when running without CUDA Graphs.

@timmoon10 timmoon10 merged commit f8bda5d into NVIDIA:main May 29, 2026
23 of 25 checks passed