Add shape-aware recommended alignment for SM90 small-M grouped GEMM#350
Open
qescccczmr wants to merge 2 commits into
Open
Add shape-aware recommended alignment for SM90 small-M grouped GEMM#350qescccczmr wants to merge 2 commits into
qescccczmr wants to merge 2 commits into
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds a shape-aware recommended M/K alignment helper for contiguous grouped layouts.
The existing
get_theoretical_mk_alignment_for_contiguous_layout()keeps SM90 on the legacy 128 alignment. For SM90 m-grouped contiguous GEMM, this can introduce extrapadded rows for some small-M non-psum shapes.
This PR adds an opt-in recommendation helper that returns 64 only for the empirically validated shape family:
num_groups == 4expected_m == 128expected_k <= 256All other shapes fall back to the existing theoretical alignment.
This does not change the default alignment.
Motivation
Grouped contiguous layout requires each expert segment to be aligned by
get_mk_alignment_for_contiguous_layout(). In the target small-M case, using 128 alignment padsthe generated test shape from
M=640toM=768.Reducing the recommended alignment to 64 for the validated shape reduces padded rows while preserving correctness.
Broader scans showed regressions for other shapes such as
expected_m=64,expected_m=192,num_groups=8/16, andk>=512, so this helper intentionally keeps the 64recommendation restricted.
Changes
get_recommended_mk_alignment_for_contiguous_layout(use_psum_layout, expected_m, expected_k, expected_num_groups).expected_m=64falls backk=512falls backnum_groups=8falls backBenchmark
Hardware:
Command:
python tools/bench_sm90_small_m_m_grouped.py \ --compare-recommended \ --num-groups-list 4,8,16 \ --expected-m-list 128,192 \ --n-list 3072,4096,6144,7168 \ --k-list 128,256,512,1024 \ --psum-list false,true \ --repeat 20 Target shape results: num_groups=4, expected_m=128, psum=false n,k,baseline_alignment,recommended_alignment,baseline_m,recommended_m,speedup 3072,128,128,64,768,640,1.0361x 3072,256,128,64,768,640,1.0020x 4096,128,128,64,768,640,1.0944x 4096,256,128,64,768,640,1.0492x 6144,128,128,64,768,640,1.0975x 6144,256,128,64,768,640,1.0610x 7168,128,128,64,768,640,1.1163x 7168,256,128,64,768,640,1.0603x Non-target shapes keep recommended_alignment=128, including: - psum=true - expected_m=192 - k=512/1024 This avoids the regressions observed when applying 64 alignment more broadly. ## Testing python -m py_compile \ tools/bench_sm90_small_m_m_grouped.py \ tests/test_sm90_small_m_alignment.py On H200: python tests/test_sm90_small_m_alignment.py