metal/ggml: keep q4_0 decode on the mat-vec kernel up to 8 rows#2369
Open
czoli1976 wants to merge 2 commits into
Open
metal/ggml: keep q4_0 decode on the mat-vec kernel up to 8 rows#2369czoli1976 wants to merge 2 commits into
czoli1976 wants to merge 2 commits into
Conversation
The GGML matmul kernels hardcoded f32 output, and the q4_0 / f16-weight GEMV+GEMM paths required f32 activations. So a q40ef16 model (Q4_0 weights, f16 activations — the common on-device LLM layout) bounced every matmul through f32: the transform inserted a f16->f32 cast on the activation and a f32->f16 cast on the output. Make the output dtype follow the activation dtype and let the kernels consume f16 activations directly: - ggml_mm_mv.metal: the mul_mv output pointer is now the activation type T1 (f16 activations -> f16 output); the q4_0 GEMV is templated on the activation/output type (new kernel_mul_mv_q4_0_f16, accumulating in f32); the GEMM (kernel_mul_mm) is templated on the activation/output type, converting f16 activations to f32 in threadgroup memory and writing f16 output through the f32 simdgroup scratch (simdgroup_store only targets float). New kernel_mul_mm_f16_f16 / kernel_mul_mm_q4_0_f16 instantiations. - ggml_gemm/mod.rs: output_dt returns the activation dtype; the GEMV/GEMM dispatch and dtype guards accept f16 activations and pick the f16 kernels. - transform.rs: drop the forced f16->f32 activation upcast; output_dt now makes the post-matmul f32->f16 cast a no-op too. Correctness: all 53 tract-metal GPU tests pass, including a new mmm_ggml_prop_q4_f16 prop test (q4_0 weights x f16 activations vs f32 CPU reference). End-to-end on Qwen3-1.7B q40ef16 (Metal), greedy output is identical before/after. Benchmark (Qwen3-1.7B q40ef16, Metal decode, examples/causal_llm complete_bench, mean of 3 x 96 tokens): baseline (f32 round-trip): ~41.6 tok/s (24.0 ms/token) this change (f16 direct) : ~45.6 tok/s (21.9 ms/token) ~10% faster No clash with sonos#2320 (it only flips `mod mfa` -> `pub mod mfa`; this touches the ggml_gemm kernels, output_dt and the matmul lowering). Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The q4_0 matrix-vector kernel is bandwidth-bound on the weight read and stays cheaper than the tiled GEMM up to ~8 activation rows, but the dispatcher switched to GEMM at m>4, making 5-8-row q4 decode (batched or speculative) needlessly slow. Raise the q4 mat-vec row cap to 8; f16/f32 stay at 4. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This was referenced Jun 14, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Depends on #2366 (
perf/metal-ggml-f16-roundtrip) and is stacked on it — please merge #2366 first. Until then this PR's diff includes #2366's commit; it collapses to the single threshold change once #2366 lands.The q4_0 matrix-vector kernel is bandwidth-bound on the weight read and stays cheaper than the tiled GEMM up to ~8 activation rows, but the dispatcher switched to GEMM at
m > 4, so 5–8-row q4 decode (batched, or speculative / lookahead) paid the full GEMM cost for no gain. This raises the q4 mat-vec row cap to 8; f16/f32 stay at 4.Perf
Forward-pass latency, Qwen3-1.7B q40ef16, Metal (Apple M-series), 256-token past, median ms/pass:
The 5–8-row band now lands on the mat-vec path (m=6: −26% vs #2366). Single-token decode (m=1) and prefill (m≥12) are unchanged. Downstream this turns k=4 speculative decoding on Qwen3-1.7B from a slowdown (~0.81×) into a ~1.19× speedup, and benefits any small-batch q4 decode.
The crossover (8) is measured on Apple GPUs and would ideally be device-tuned.
🤖 Generated with Claude Code