transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp#2318
transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp#2318czoli1976 wants to merge 1 commit into
Conversation
|
taking it back, correctness failure to investigate |
eb71f47 to
457c825
Compare
457c825 to
9a5fa12
Compare
9a5fa12 to
476ef48
Compare
|
rebased! |
ScaledMaskedSoftmax::eval hard-coded SoftmaxExp::Libc, so the fused
attention softmax always ran scalar libm expf and never reached the
linalg SIMD softmax kernels — a real perf gap. The naive fix (switch to
SoftmaxExp::FastCompact) trades correctness for speed and fails the
proptests two ways:
1. FastCompact's Schraudolph exp is ~0.5% off true softmax — outside
the suite's Approximate tolerance (f32 rtol 5e-4), 30%+ outliers.
2. On a fully-masked row (all -inf) the FastCompact kernel pads the
SIMD tail with f32::MIN and computes exp(f32::MIN - f32::MIN) ≈ 1,
so the row sums to a nonzero value and yields a finite 0 where the
scalar libc path (and the numpy reference) yield NaN (0 * 1/0).
Instead, add an accurate vectorizable exp and route the fused softmax
through it (mirrors ggml/llama.cpp, which kept an accurate vectorized
expf for softmax rather than a coarse approximation):
* linalg: `accurate_exp_f32`, a Cephes-style range-reduced exp
(Cody-Waite ln2 split + degree-6 poly + 2^n by exponent
construction). Measured max rel error ~1.9e-6 vs libc over the
softmax domain [0, -60]. exp(0)==1 and exp(-inf)==0 exactly; deep
underflow flushes to 0; NaN propagates.
* linalg: `SSoftMaxL2Accurate` / `HSoftMaxL2Accurate` map-reduce
kernels, exposed as `softmax2_accurate_{f32,f16}`. They pad the
SIMD tail with -inf (not f32::MIN), so masked/padding lanes
contribute exactly 0 and a fully-masked row sums to 0 -> NaN,
matching libc and the reference.
* core: new `SoftmaxExp::Accurate` variant + dispatch.
* nnef: `exp = "accurate"` de/serialization round-trip.
* transformers: ScaledMaskedSoftmax::eval uses SoftmaxExp::Accurate.
New linalg tests validate the accurate exp against libc (not against
itself, unlike the existing FastCompact frame test) and cover the
fully-masked degenerate row. scaled_masked_softmax + sdpa proptests
(f16/f32, raw/decluttered/optimized) pass on native and wasm32-wasip1.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
|
476ef48 to
a18d911
Compare
|
🔴 Bench vs main — 5 speed regression(s) Reference: main nightly, latest 2026-06-18 (0d old) · PR Speed — evaltime · prefill · decode
lower is better except prefill/decode (tok/s) · adaptive thresholds (max(floor, k×noise) vs the series' own history) · single-shot vs nightly reference · full table → run summary |
|
@czoli1976 surprising bench results ? |
|
how is MLX ? |
|
Mmm... I'm not sure I follow ? |
|
If you run these on macOS via Metal, do you see any difference? The Results
above seem to indicate a windows Only PC as an Host
Il giorno ven 19 giu 2026 alle ore 07:29 Mathieu Poumeyrol <
***@***.***> ha scritto:
… *kali* left a comment (sonos/tract#2318)
<#2318 (comment)>
Mmm... I'm not sure I follow ?
—
Reply to this email directly, view it on GitHub
<#2318?email_source=notifications&email_token=APL2Z6WEKA3626YLCMW5PWT5ATMWFA5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINZUHEYDOOBQGU2KM4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJLDGN5XXIZLSL5RWY2LDNM#issuecomment-4749078054>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/APL2Z6XPS7CRZRET2Z4C6CD5ATMWFAVCNFSNUABEKJSXA33TNF2G64TZHM4TSNJWGEZDCMZ3JFZXG5LFHM2DKNBZGE4TKNRVGOQXMAQ>
.
Triage notifications, keep track of coding agent tasks and review pull
requests on the go with GitHub Mobile for iOS
<https://github.com/notifications/mobile/ios/APL2Z6W77ZZF7LKTDH65WMT5ATMWFA5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINZUHEYDOOBQGU2KM4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJKTGN5XXIZLSL5UW64Y>
and Android
<https://github.com/notifications/mobile/android/APL2Z6XZ3M7LTGXQIEOPPBD5ATMWFA5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINZUHEYDOOBQGU2KM4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJLTGN5XXIZLSL5QW4ZDSN5UWI>.
Download it today!
You are receiving this because you were mentioned.Message ID:
***@***.***>
--
Best Regards
Ckristian Zoli
Email: ***@***.***
Mobile: +39.3474721699
|
|
so the link at the bottom leads to the full result reports, including the mac/metal ones (table is huge and flat, i'm gonna work on it soon). Search for "tok/s", the mac is the first block :) |
|
ok, will, for now let's take in what is obviously a winner.
will try to figure it out next week, life a bit in the way right now
Il giorno ven 19 giu 2026 alle ore 07:58 Mathieu Poumeyrol <
***@***.***> ha scritto:
… *kali* left a comment (sonos/tract#2318)
<#2318 (comment)>
so the link at the bottom leads to the full result reports, including the
mac/metal ones (table is huge and flat, i'm gonna work on it soon). Search
for "tok/s", the mac is the first block :)
—
Reply to this email directly, view it on GitHub
<#2318?email_source=notifications&email_token=APL2Z6TX7B2ROHCF5SPMN3L5ATP75A5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINZUHEZDGNRWHE42M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJLDGN5XXIZLSL5RWY2LDNM#issuecomment-4749236699>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/APL2Z6QOKUTNNRG7BYR7G4T5ATP75AVCNFSNUABEKJSXA33TNF2G64TZHM4TSNJWGEZDCMZ3JFZXG5LFHM2DKNBZGE4TKNRVGOQXMAQ>
.
Triage notifications, keep track of coding agent tasks and review pull
requests on the go with GitHub Mobile for iOS
<https://github.com/notifications/mobile/ios/APL2Z6V5YJEKEMWMYQBYKHD5ATP75A5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINZUHEZDGNRWHE42M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJKTGN5XXIZLSL5UW64Y>
and Android
<https://github.com/notifications/mobile/android/APL2Z6QYLOZZKZ2M7AEP6XT5ATP75A5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINZUHEZDGNRWHE42M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJLTGN5XXIZLSL5QW4ZDSN5UWI>.
Download it today!
You are receiving this because you were mentioned.Message ID:
***@***.***>
--
Best Regards
Ckristian Zoli
Email: ***@***.***
Mobile: +39.3474721699
|
|
No rush, I'm barely keeping up anyway. |
|
I happen to have canibalized this function for a different context, and struggle to get it to vectorize. My imaginary friend came with this: |
|
Let me ask mine
Best Regards
Ckristian Zoli
Email: ***@***.***
…On Fri, 19 Jun 2026 at 11:51 Mathieu Poumeyrol ***@***.***> wrote:
*kali* left a comment (sonos/tract#2318)
<#2318 (comment)>
I happen to have canibalized this function for a different context, and
struggle to get it to vectorize. My imaginary friend came with this:
accurate_exp_f32: branchless underflow flush (for autovectorization)
Change. Replace the trailing early-out branch with a bitmask flush:
// #2318 as written:
let e = y * pow2n;
if x < LO { 0.0 } else { e }
// branchless:
let e = y * pow2n;
// 0 when x < LO, else e
let flush = -((x < LO) as i32) as u32; // all-ones if x < LO, else 0
f32::from_bits(e.to_bits() & !flush)
Why. The if x < LO { 0.0 } early-out is the one thing that stops the function from
autovectorizing when it's inlined into a reduction loop (e.g. iter().map(|x|
accurate_exp_f32(x - max)).sum()). LLVM treats the underflow path as profitable to
skip and sinks the polynomial behind a per-element conditional branch (ucomiss/ja) —
so even compiled with AVX2/AVX-512 target features, the loop stays scalar. The
bitmask form keeps the dataflow straight-line, and the comparison lowers to a vector
mask (vcmpltps → vandps), letting the whole body pack.
—
Reply to this email directly, view it on GitHub
<#2318?email_source=notifications&email_token=APL2Z6UDUHSPRMOTTP46XQ35AULJNA5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINZVGA4DIOJWGU32M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJLDGN5XXIZLSL5RWY2LDNM#issuecomment-4750849657>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/APL2Z6SVCFNZK6GMSGJPSLL5AULJNAVCNFSNUABEKJSXA33TNF2G64TZHM4TSNJWGEZDCMZ3JFZXG5LFHM2DKNBZGE4TKNRVGOQXMAQ>
.
Triage notifications, keep track of coding agent tasks and review pull
requests on the go with GitHub Mobile for iOS
<https://github.com/notifications/mobile/ios/APL2Z6VRDEGMRTAMKIDOJ3T5AULJNA5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINZVGA4DIOJWGU32M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJKTGN5XXIZLSL5UW64Y>
and Android
<https://github.com/notifications/mobile/android/APL2Z6USLC5EVZN7KHQIBFD5AULJNA5CNFSNUABFM5UWIORPF5TWS5BNNB2WEL2JONZXKZKDN5WW2ZLOOQXTINZVGA4DIOJWGU32M4TFMFZW63VHNVSW45DJN5XKKZLWMVXHJLTGN5XXIZLSL5QW4ZDSN5UWI>.
Download it today!
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
What
ScaledMaskedSoftmax::evalhard-codedSoftmaxExp::Libc, so the fused attention softmax always ran scalar libmexpfand never reached the linalg SIMD softmax kernels — a real perf gap. This PR closes that gap without sacrificing accuracy, by adding an accurate vectorizableexpand routing the fused softmax through it.Why not FastCompact
Switching the fused softmax to the existing
SoftmaxExp::FastCompact(Schraudolph approximation) fails thescaled_masked_softmax/sdpaproptests two independent ways:expis ~0.5% off true softmax — outside the suite'sApproximatetolerance (f32 rtol5e-4), producing 30%+ outliers. (The existingsoftmax_l2frame test only compares FastCompact against itself, so it never caught this.)-infrow the FastCompact kernel pads the SIMD tail withf32::MINand computesexp(f32::MIN - f32::MIN) ≈ 1, so the row sums to a nonzero value and yields a finite0where the scalar libc path and the numpy reference both yieldNaN(0 * 1/0).This mirrors what ggml/llama.cpp concluded (ggml-org/llama.cpp#7154): keep an accurate vectorized
expffor softmax, reserve fast-approxexpfor error-tolerant ops.What changed
accurate_exp_f32— a Cephes-style range-reducedexp(Cody-Waiteln2split + degree-6 polynomial +2^nby exponent construction). Measured max rel error ~1.9e-6 vs libc over the softmax domain[0, -60].exp(0)==1andexp(-inf)==0exactly; deep underflow flushes to0;NaNpropagates.SSoftMaxL2Accurate/HSoftMaxL2Accuratemap-reduce kernels, exposed assoftmax2_accurate_{f32,f16}. They pad the SIMD tail with-inf(notf32::MIN), so masked/padding lanes contribute exactly0and a fully-masked row sums to0→NaN, matching libc and the reference.SoftmaxExp::Accuratevariant + dispatch (f32/f16).exp = "accurate"de/serialization round-trip.ScaledMaskedSoftmax::evalusesSoftmaxExp::Accurate.Libcremains the default everywhere;FastCompactis untouched. This adds a third, accurate-but-vectorized option and points fused attention at it.Tests
accurate_exp_f32against libc (not against itself) and cover the fully-masked degenerate row (sum == 0).scaled_masked_softmax+sdpaproptests (f16/f32 × raw/decluttered/optimized) pass on native and wasm32-wasip1 (the job that originally failed).tract-linalgsuite green;tract-core/nnef/transformersgreen;cargo fmtclean; no new clippy warnings on touched files.🤖 Generated with Claude Code