metal: pack element-wise kernels into full-width threadgroups #2365
Open
czoli1976 wants to merge 1 commit into
The flat (one-thread-per-element) Metal kernels (element_wise, cast, copy_unicast, silu, gelu_approx, leaky_relu, the bin_op 1-row path and the iff/select generic kernel) dispatched `n` threadgroups of a *single* thread (`dispatch_thread_groups(grid = n, group = 1)`). On Apple GPUs each threadgroup owns its own SIMD-group, so a 1-thread group leaves 31 of 32 lanes idle.

Add `utils::dispatch_threads_1d`, which uses `dispatch_threads` (non-uniform threadgroups, already used by apply_rope) to pack the same `n` threads into threadgroups of up to the pipeline maximum. The kernels index by `thread_position_in_grid`, which is unchanged, so this is a pure dispatch-side change with no kernel edits.

The structural `threadgroup_position_in_grid` kernels (gather, broadcast bin_op) are left as-is: they encode one threadgroup per element and need separate reworks.

All 63 tract-metal GPU tests pass.

Benchmark (M-series GPU, us/call, silu f32 / cast f32->f16), via the added `threadgroup_bench` example, baseline = 1-thread groups:

| n | silu base -> fix | cast base -> fix |
|---|---|---|
| 16384 | 6.08 -> 3.32 (1.8x) | 6.85 -> 2.96 (2.3x) |
| 65536 | 24.91 -> 5.00 (5.0x) | 13.71 -> 6.15 (2.2x) |
| 262144 | 31.64 -> 13.1 (2.4x) | 54.60 -> 10.7 (5.1x) |
| 1048576 | 143.1 -> 68.7 (2.1x) | 69.1 -> 48.6 (1.4x) |
| 4194304 | 385 -> 391 (1.0x) | 279 -> 280 (1.0x) |
| 16777216 | 1541 -> 1556 (1.0x) | 1149 -> 1160 (1.0x) |

So: 1.4-5x on small/medium element-wise dispatches (the per-token decode sizes), neutral on large tensors (already memory-bandwidth-bound, where 1-thread groups provide enough parallelism to saturate). No regression observed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
What
The flat, one-thread-per-element Metal kernels dispatched `n` threadgroups of a single thread (`dispatch_thread_groups(grid = n, group = 1)`). On Apple GPUs each threadgroup owns its own SIMD-group, so a 1-thread threadgroup leaves 31 of every 32 lanes idle.

This adds `utils::dispatch_threads_1d`, which uses `dispatch_threads` (non-uniform threadgroups, already used by `apply_rope`) to pack the same `n` threads into threadgroups of up to the pipeline maximum. The kernels index by `thread_position_in_grid`, which is unchanged, so this is a dispatch-side-only change with no kernel edits.

Converted (all index by `thread_position_in_grid`): `element_wise`, `cast`, `copy_unicast`, `silu`, `gelu_approx`, `leaky_relu`, the `bin_op` 1-row path, and the iff/select `iff_generic` kernel.

Left as-is: the `threadgroup_position_in_grid` kernels (`gather`, `diag_gather`, broadcast `bin_op`). They structurally encode one threadgroup per element and need separate reworks; the broadcast `bin_op` already uses a real threadgroup via `build_metal_grid_and_groups_for_el_wise_op`.

Honest scope
The "31/32 SIMD lanes idle" framing oversells it: these ops are memory-bandwidth-bound, and on large tensors the GPU already has enough 1-thread threadgroups in flight to saturate bandwidth, so the change is neutral there. The win shows up on small/medium dispatches, which is exactly the per-token LLM-decode regime (element-wise/activation/cast over `hidden_size` × 1 token).

Benchmark
Added `metal/examples/threadgroup_bench.rs` (M-series GPU, µs/call, 200 iters; baseline = 1-thread threadgroups): 1.4–5× on small/medium element-wise dispatches; neutral on large memory-bound tensors. No regression.
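For a sense of scale at the benchmarked sizes, here is a rough sketch (not the code in this PR) of the dispatch geometry before and after. It is plain Rust with no Metal calls. Everything in it is an assumption for illustration: the names `Dispatch1d`, `one_thread_groups`, `packed_groups`, `simd_groups_used` and `lane_occupancy` are hypothetical, the 1024-thread pipeline maximum is a placeholder for whatever the real pipeline reports, and the `min(n, max)` group width is only one plausible reading of "up to the pipeline maximum".

```rust
/// Apple GPU SIMD-group width (32 lanes), as stated in the description.
const SIMD_WIDTH: u64 = 32;

/// Assumed pipeline maximum threads per threadgroup (placeholder value;
/// the real helper would query the pipeline instead).
#[allow(dead_code)]
const ASSUMED_MAX_THREADS_PER_GROUP: u64 = 1024;

/// Shape of a 1-D dispatch: number of threadgroups and width of a full one.
#[derive(Debug, PartialEq, Clone, Copy)]
struct Dispatch1d {
    groups: u64,
    group_width: u64,
}

/// Integer division rounding up; `b` must be non-zero.
fn ceil_div(a: u64, b: u64) -> u64 {
    (a + b - 1) / b
}

/// Old scheme: one threadgroup per element, each holding a single thread.
#[allow(dead_code)]
fn one_thread_groups(n: u64) -> Dispatch1d {
    Dispatch1d { groups: n, group_width: 1 }
}

/// Packed scheme: the same `n` threads, grouped up to `max_threads`.
/// With non-uniform threadgroups the last group may be partial, so the
/// grid is still exactly `n` threads wide.
#[allow(dead_code)]
fn packed_groups(n: u64, max_threads: u64) -> Dispatch1d {
    let group_width = n.min(max_threads).max(1);
    Dispatch1d { groups: ceil_div(n, group_width), group_width }
}

/// SIMD-groups occupied, given that a SIMD-group never spans threadgroups.
#[allow(dead_code)]
fn simd_groups_used(n: u64, d: Dispatch1d) -> u64 {
    let full = n / d.group_width; // threadgroups at full width
    let rem = n % d.group_width; // threads in the trailing partial group
    full * ceil_div(d.group_width, SIMD_WIDTH) + ceil_div(rem, SIMD_WIDTH)
}

/// Fraction of SIMD lanes that carry an active thread.
#[allow(dead_code)]
fn lane_occupancy(n: u64, d: Dispatch1d) -> f64 {
    let simd = simd_groups_used(n, d);
    if simd == 0 {
        return 0.0;
    }
    n as f64 / (simd * SIMD_WIDTH) as f64
}
```

Under these assumptions, `n = 16384` goes from 16384 one-thread threadgroups to 16 threadgroups of 1024 threads, and lane occupancy goes from 1/32 to 1. As noted under Honest scope, that occupancy figure does not translate into a speedup on large, bandwidth-bound tensors.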
Testing
All 63 `tract-metal` GPU tests pass (run on Apple Silicon). The change is +31/−33 in the library (the helper replaces 8 duplicated grid/group blocks); the example adds the benchmark.

Files
- `metal/src/kernels/utils.rs`: `dispatch_threads_1d` helper
- `metal/src/kernels/{element_wise,bin_ops}.rs`, `kernels/array/{cast,copy}.rs`, `kernels/nn/{silu,gelu_approximate,leaky_relu}.rs`: use it
- `metal/examples/threadgroup_bench.rs`: benchmark

🤖 Generated with Claude Code