Skip to content

Add shape-aware recommended alignment for SM90 small-M grouped GEMM#350

Open
qescccczmr wants to merge 2 commits into
deepseek-ai:mainfrom
qescccczmr:sm90-small-m-alignment
Open

Add shape-aware recommended alignment for SM90 small-M grouped GEMM#350
qescccczmr wants to merge 2 commits into
deepseek-ai:mainfrom
qescccczmr:sm90-small-m-alignment

Conversation

@qescccczmr

Copy link
Copy Markdown

Summary

This PR adds a shape-aware recommended M/K alignment helper for contiguous grouped layouts.

The existing get_theoretical_mk_alignment_for_contiguous_layout() keeps SM90 on the legacy 128 alignment. For SM90 m-grouped contiguous GEMM, this can introduce extra
padded rows for some small-M non-psum shapes.

This PR adds an opt-in recommendation helper that returns 64 only for the empirically validated shape family:

  • SM90
  • non-psum m-grouped contiguous layout
  • num_groups == 4
  • expected_m == 128
  • expected_k <= 256

All other shapes fall back to the existing theoretical alignment.

This does not change the default alignment.

Motivation

Grouped contiguous layout requires each expert segment to be aligned by get_mk_alignment_for_contiguous_layout(). In the target small-M case, using 128 alignment pads
the generated test shape from M=640 to M=768.

Reducing the recommended alignment to 64 for the validated shape reduces padded rows while preserving correctness.

Broader scans showed regressions for other shapes such as expected_m=64, expected_m=192, num_groups=8/16, and k>=512, so this helper intentionally keeps the 64
recommendation restricted.

Changes

  • Add get_recommended_mk_alignment_for_contiguous_layout(use_psum_layout, expected_m, expected_k, expected_num_groups).
  • Expose the helper through the Python layout API.
  • Document the helper in README.
  • Rename the benchmark/test wording from small-K to small-M.
  • Add SM90 correctness and guardrail coverage for:
    • target shape returns 64
    • psum layout falls back
    • expected_m=64 falls back
    • k=512 falls back
    • num_groups=8 falls back

Benchmark

Hardware:

  • NVIDIA H200
  • SM90
  • CUDA 12.8
  • NVCC 12.8

Command:

python tools/bench_sm90_small_m_m_grouped.py \
  --compare-recommended \
  --num-groups-list 4,8,16 \
  --expected-m-list 128,192 \
  --n-list 3072,4096,6144,7168 \
  --k-list 128,256,512,1024 \
  --psum-list false,true \
  --repeat 20

Target shape results:

num_groups=4, expected_m=128, psum=false

n,k,baseline_alignment,recommended_alignment,baseline_m,recommended_m,speedup
3072,128,128,64,768,640,1.0361x
3072,256,128,64,768,640,1.0020x
4096,128,128,64,768,640,1.0944x
4096,256,128,64,768,640,1.0492x
6144,128,128,64,768,640,1.0975x
6144,256,128,64,768,640,1.0610x
7168,128,128,64,768,640,1.1163x
7168,256,128,64,768,640,1.0603x

Non-target shapes keep recommended_alignment=128, including:

- psum=true
- expected_m=192
- k=512/1024

This avoids the regressions observed when applying 64 alignment more broadly.

## Testing

python -m py_compile \
  tools/bench_sm90_small_m_m_grouped.py \
  tests/test_sm90_small_m_alignment.py

On H200:

python tests/test_sm90_small_m_alignment.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant