metal: pack element-wise kernels into full-width threadgroups #2365

Open: czoli1976 wants to merge 1 commit into sonos:main from czoli1976:perf/metal-threadgroup-occupancy

@czoli1976

What

The flat, one-thread-per-element Metal kernels dispatched n threadgroups of a single thread (dispatch_thread_groups(grid = n, group = 1)). On Apple GPUs each threadgroup owns its own SIMD-group, so a 1-thread threadgroup leaves 31 of every 32 lanes idle.
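To make the lane arithmetic concrete, here is a minimal sketch. It assumes each threadgroup is rounded up to whole SIMD-groups of 32 lanes (the width quoted above); `lane_utilization` and the 1024-thread group width are illustrative, not part of the patch.

```rust
/// Fraction of allocated SIMD lanes that carry a real thread.
/// Assumption: every threadgroup occupies a whole number of SIMD-groups.
fn lane_utilization(n_threads: u64, group_width: u64, simd_width: u64) -> f64 {
    assert!(n_threads > 0 && group_width > 0 && simd_width > 0);
    let full_groups = n_threads / group_width;
    let tail = n_threads % group_width; // threads in the last, partial group
    let simd_groups_for = |threads: u64| threads.div_ceil(simd_width);
    let simd_groups = full_groups * simd_groups_for(group_width) + simd_groups_for(tail);
    n_threads as f64 / (simd_groups * simd_width) as f64
}

fn main() {
    let n = 65_536;
    // One thread per threadgroup: each group still occupies a full SIMD-group.
    println!("group width 1:    {:.4}", lane_utilization(n, 1, 32));
    // Hypothetical 1024-wide groups: every lane carries a thread.
    println!("group width 1024: {:.4}", lane_utilization(n, 1024, 32));
}
```

This only counts lanes; it says nothing about whether the kernel is actually limited by lane occupancy.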

This adds utils::dispatch_threads_1d, which uses dispatch_threads (non-uniform threadgroups — already used by apply_rope) to pack the same n threads into threadgroups of up to the pipeline maximum. The kernels index by thread_position_in_grid, which is unchanged, so this is a dispatch-side-only change — no kernel edits.
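The helper itself is not reproduced in this description, so the following is only a sketch of the sizing logic it describes: clamp the threadgroup width to the pipeline maximum and let the last group be partial. `group_width_1d`, `plan_groups` and the 1024 limit are hypothetical names and values; the real helper presumably hands the grid and group sizes to Metal's `dispatch_threads` rather than enumerating groups.

```rust
/// One threadgroup of a 1-D dispatch: index of its first thread and its width.
#[derive(Debug, PartialEq)]
struct Group {
    start: u64,
    width: u64,
}

/// Hypothetical sizing rule: as wide as the pipeline allows, never wider than n.
fn group_width_1d(n: u64, max_threads_per_group: u64) -> u64 {
    n.min(max_threads_per_group).max(1)
}

/// Enumerate the groups a non-uniform 1-D dispatch of `n` threads would cover.
/// The last group is narrower when `n` is not a multiple of the width.
fn plan_groups(n: u64, max_threads_per_group: u64) -> Vec<Group> {
    let width = group_width_1d(n, max_threads_per_group);
    (0..n)
        .step_by(width as usize)
        .map(|start| Group { start, width: width.min(n - start) })
        .collect()
}

fn main() {
    // 2500 elements with an assumed pipeline maximum of 1024 threads per group.
    for g in plan_groups(2500, 1024) {
        println!("threads {}..{}", g.start, g.start + g.width);
    }
}
```

Because every thread still gets a unique grid position in `0..n`, a kernel that indexes by `thread_position_in_grid` sees the same indices as before.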

Converted (all index by thread_position_in_grid): element_wise, cast, copy_unicast, silu, gelu_approx, leaky_relu, the bin_op 1-row path, and the iff/select iff_generic kernel.

Left as-is: the threadgroup_position_in_grid kernels (gather, diag_gather, broadcast bin_op) — they structurally encode one threadgroup per element and need separate reworks; the broadcast bin_op already uses a real threadgroup via build_metal_grid_and_groups_for_el_wise_op.
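A small simulation of why the dispatch-only change does not carry over to kernels that index by `threadgroup_position_in_grid`. It models a 1-D grid with uniform group width; `ids_for` and the two helper functions are illustrative and do not exist in the codebase.

```rust
/// The two position values one thread would see in a 1-D dispatch.
struct ThreadIds {
    thread_position_in_grid: u64,
    threadgroup_position_in_grid: u64,
}

fn ids_for(global_thread: u64, group_width: u64) -> ThreadIds {
    ThreadIds {
        thread_position_in_grid: global_thread,
        threadgroup_position_in_grid: global_thread / group_width,
    }
}

/// Element index each thread touches if the kernel indexes by thread position.
fn elements_by_thread(n: u64, group_width: u64) -> Vec<u64> {
    (0..n).map(|t| ids_for(t, group_width).thread_position_in_grid).collect()
}

/// Element index each thread touches if the kernel indexes by threadgroup position.
fn elements_by_group(n: u64, group_width: u64) -> Vec<u64> {
    (0..n).map(|t| ids_for(t, group_width).threadgroup_position_in_grid).collect()
}

fn main() {
    for width in [1, 4] {
        println!("width {width}: by thread {:?}", elements_by_thread(8, width));
        println!("width {width}: by group  {:?}", elements_by_group(8, width));
    }
}
```

With a width of 1 the two schemes coincide; once groups are packed, threadgroup-indexed kernels would map several threads to the same element and leave others untouched, which is why they need their own rework.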

Honest scope

The "31/32 SIMD lanes idle" framing oversells it: these ops are memory-bandwidth-bound, and on large tensors the GPU already has enough 1-thread threadgroups in flight to saturate bandwidth, so the change is neutral there. The win shows up on small/medium dispatches — which is exactly the per-token LLM-decode regime (element-wise/activation/cast over hidden_size × 1 token).

Benchmark

Added metal/examples/threadgroup_bench.rs (M-series GPU, µs/call, 200 iters; baseline = 1-thread threadgroups):

| n_elements | silu base → fix (µs) | silu speedup | cast (f32→f16) base → fix (µs) | cast speedup |
|---|---|---|---|---|
| 16 384 | 6.08 → 3.32 | 1.8× | 6.85 → 2.96 | 2.3× |
| 65 536 | 24.91 → 5.00 | 5.0× | 13.71 → 6.15 | 2.2× |
| 262 144 | 31.64 → 13.12 | 2.4× | 54.60 → 10.67 | 5.1× |
| 1 048 576 | 143.1 → 68.7 | 2.1× | 69.1 → 48.6 | 1.4× |
| 4 194 304 | 385 → 391 | 1.0× | 279 → 280 | 1.0× |
| 16 777 216 | 1541 → 1556 | 1.0× | 1149 → 1160 | 1.0× |

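1.4–5× on small/medium element-wise dispatches (16 384 to 1 048 576 elements); neutral on large memory-bound tensors, where the new path stays within about 2% of the baseline. No regression beyond that.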
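For anyone who wants to recheck the speedup column or turn the timings into an effective throughput, a minimal sketch follows. The timings are copied from the silu column above; the 8 bytes per element (one f32 read plus one f32 write) is an assumption about the kernel's memory traffic and ignores caching.

```rust
/// Speedup of the new dispatch over the baseline, both in µs per call.
fn speedup(base_us: f64, fix_us: f64) -> f64 {
    base_us / fix_us
}

/// Effective throughput in GB/s for `bytes` moved in `us` microseconds.
fn gb_per_s(bytes: f64, us: f64) -> f64 {
    bytes / (us * 1e3)
}

fn main() {
    // (n_elements, silu base µs, silu fix µs), copied from the benchmark table.
    let silu = [
        (16_384u64, 6.08, 3.32),
        (65_536, 24.91, 5.00),
        (262_144, 31.64, 13.12),
        (1_048_576, 143.1, 68.7),
        (4_194_304, 385.0, 391.0),
        (16_777_216, 1541.0, 1556.0),
    ];
    for (n, base, fix) in silu {
        // Assumed traffic: 4 bytes read + 4 bytes written per element.
        let bytes = 8.0 * n as f64;
        println!(
            "n = {n:>10}  speedup {:.1}x  fix {:.1} GB/s",
            speedup(base, fix),
            gb_per_s(bytes, fix)
        );
    }
}
```

Under that traffic assumption the two largest sizes come out at nearly the same effective throughput, which is at least consistent with the bandwidth-bound reading; the smaller sizes are likely affected by caches and launch overhead, so their figures should not be read as DRAM bandwidth.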

Testing

All 63 tract-metal GPU tests pass (run on Apple Silicon). The change is +31/−33 in the library (the helper replaces 8 duplicated grid/group blocks); the example adds the benchmark.

Files

  • metal/src/kernels/utils.rs: dispatch_threads_1d helper
  • metal/src/kernels/{element_wise,bin_ops}.rs, kernels/array/{cast,copy}.rs, kernels/nn/{silu,gelu_approximate,leaky_relu}.rs: use it
  • metal/examples/threadgroup_bench.rs: benchmark

🤖 Generated with Claude Code

The flat (one-thread-per-element) Metal kernels — element_wise, cast,
copy_unicast, silu, gelu_approx, leaky_relu, the bin_op 1-row path and the
iff/select generic kernel — dispatched `n` threadgroups of a *single* thread
(`dispatch_thread_groups(grid = n, group = 1)`). On Apple GPUs each threadgroup
owns its own SIMD-group, so a 1-thread group leaves 31 of 32 lanes idle.

Add `utils::dispatch_threads_1d`, which uses `dispatch_threads` (non-uniform
threadgroups, already used by apply_rope) to pack the same `n` threads into
threadgroups of up to the pipeline maximum. The kernels index by
`thread_position_in_grid`, which is unchanged, so this is a pure dispatch-side
change with no kernel edits. The structural `threadgroup_position_in_grid`
kernels (gather, broadcast bin_op) are left as-is — they encode one threadgroup
per element and need separate reworks.

All 63 tract-metal GPU tests pass.

Benchmark (M-series GPU, us/call, silu f32 / cast f32->f16), via the added
`threadgroup_bench` example, baseline = 1-thread groups:

  n         silu base -> fix     cast base -> fix
  16384       6.08 -> 3.32 (1.8x)  6.85 -> 2.96 (2.3x)
  65536      24.91 -> 5.00 (5.0x) 13.71 -> 6.15 (2.2x)
  262144     31.64 -> 13.1 (2.4x) 54.60 -> 10.7 (5.1x)
  1048576   143.1  -> 68.7 (2.1x) 69.1  -> 48.6 (1.4x)
  4194304   385    -> 391  (1.0x) 279   -> 280  (1.0x)
  16777216 1541    -> 1556 (1.0x) 1149  -> 1160 (1.0x)

So: 1.4-5x on small/medium element-wise dispatches (the per-token decode sizes),
neutral on large tensors (already memory-bandwidth-bound, where 1-thread groups
provide enough parallelism to saturate). No regression observed.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>