Skip to content

transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp#2318

Open
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feat/arm64-neon-fp16-activations
Open

transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp#2318
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feat/arm64-neon-fp16-activations

Conversation

@czoli1976

@czoli1976 czoli1976 commented May 29, 2026

Copy link
Copy Markdown
Contributor

What

ScaledMaskedSoftmax::eval hard-coded SoftmaxExp::Libc, so the fused attention softmax always ran scalar libm expf and never reached the linalg SIMD softmax kernels — a real perf gap. This PR closes that gap without sacrificing accuracy, by adding an accurate vectorizable exp and routing the fused softmax through it.

History: this PR originally just flipped Libc → FastCompact. CI (and a local full-suite run) showed that was wrong on two counts, so the approach was reworked — see below.

Why not FastCompact

Switching the fused softmax to the existing SoftmaxExp::FastCompact (Schraudolph approximation) fails the scaled_masked_softmax / sdpa proptests two independent ways:

  1. Precision. FastCompact's exp is ~0.5% off true softmax — outside the suite's Approximate tolerance (f32 rtol 5e-4), producing 30%+ outliers. (The existing softmax_l2 frame test only compares FastCompact against itself, so it never caught this.)
  2. Fully-masked rows → NaN mismatch. On an all--inf row the FastCompact kernel pads the SIMD tail with f32::MIN and computes exp(f32::MIN - f32::MIN) ≈ 1, so the row sums to a nonzero value and yields a finite 0 where the scalar libc path and the numpy reference both yield NaN (0 * 1/0).

This mirrors what ggml/llama.cpp concluded (ggml-org/llama.cpp#7154): keep an accurate vectorized expf for softmax, reserve fast-approx exp for error-tolerant ops.

What changed

  • linalg: accurate_exp_f32 — a Cephes-style range-reduced exp (Cody-Waite ln2 split + degree-6 polynomial + 2^n by exponent construction). Measured max rel error ~1.9e-6 vs libc over the softmax domain [0, -60]. exp(0)==1 and exp(-inf)==0 exactly; deep underflow flushes to 0; NaN propagates.
  • linalg: SSoftMaxL2Accurate / HSoftMaxL2Accurate map-reduce kernels, exposed as softmax2_accurate_{f32,f16}. They pad the SIMD tail with -inf (not f32::MIN), so masked/padding lanes contribute exactly 0 and a fully-masked row sums to 0NaN, matching libc and the reference.
  • core: new SoftmaxExp::Accurate variant + dispatch (f32/f16).
  • nnef: exp = "accurate" de/serialization round-trip.
  • transformers: ScaledMaskedSoftmax::eval uses SoftmaxExp::Accurate.

Libc remains the default everywhere; FastCompact is untouched. This adds a third, accurate-but-vectorized option and points fused attention at it.

Tests

  • New linalg tests validate accurate_exp_f32 against libc (not against itself) and cover the fully-masked degenerate row (sum == 0).
  • scaled_masked_softmax + sdpa proptests (f16/f32 × raw/decluttered/optimized) pass on native and wasm32-wasip1 (the job that originally failed).
  • Full tract-linalg suite green; tract-core / nnef / transformers green; cargo fmt clean; no new clippy warnings on touched files.

🤖 Generated with Claude Code

@czoli1976 czoli1976 marked this pull request as draft May 29, 2026 17:14
@czoli1976

Copy link
Copy Markdown
Contributor Author

taking it back, correctness failure to investigate

@czoli1976 czoli1976 force-pushed the feat/arm64-neon-fp16-activations branch from eb71f47 to 457c825 Compare May 29, 2026 21:41
@czoli1976 czoli1976 changed the title transformers: ScaledMaskedSoftmax eval uses FastCompact exp to unlock SIMD softmax transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp May 29, 2026
@czoli1976 czoli1976 force-pushed the feat/arm64-neon-fp16-activations branch from 457c825 to 9a5fa12 Compare May 30, 2026 06:36
@czoli1976 czoli1976 marked this pull request as ready for review May 30, 2026 07:07
@kali kali force-pushed the feat/arm64-neon-fp16-activations branch from 9a5fa12 to 476ef48 Compare June 8, 2026 11:56
@kali

kali commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

rebased!

ScaledMaskedSoftmax::eval hard-coded SoftmaxExp::Libc, so the fused
attention softmax always ran scalar libm expf and never reached the
linalg SIMD softmax kernels — a real perf gap. The naive fix (switch to
SoftmaxExp::FastCompact) trades correctness for speed and fails the
proptests two ways:

  1. FastCompact's Schraudolph exp is ~0.5% off true softmax — outside
     the suite's Approximate tolerance (f32 rtol 5e-4), 30%+ outliers.
  2. On a fully-masked row (all -inf) the FastCompact kernel pads the
     SIMD tail with f32::MIN and computes exp(f32::MIN - f32::MIN) ≈ 1,
     so the row sums to a nonzero value and yields a finite 0 where the
     scalar libc path (and the numpy reference) yield NaN (0 * 1/0).

Instead, add an accurate vectorizable exp and route the fused softmax
through it (mirrors ggml/llama.cpp, which kept an accurate vectorized
expf for softmax rather than a coarse approximation):

  * linalg: `accurate_exp_f32`, a Cephes-style range-reduced exp
    (Cody-Waite ln2 split + degree-6 poly + 2^n by exponent
    construction). Measured max rel error ~1.9e-6 vs libc over the
    softmax domain [0, -60]. exp(0)==1 and exp(-inf)==0 exactly; deep
    underflow flushes to 0; NaN propagates.
  * linalg: `SSoftMaxL2Accurate` / `HSoftMaxL2Accurate` map-reduce
    kernels, exposed as `softmax2_accurate_{f32,f16}`. They pad the
    SIMD tail with -inf (not f32::MIN), so masked/padding lanes
    contribute exactly 0 and a fully-masked row sums to 0 -> NaN,
    matching libc and the reference.
  * core: new `SoftmaxExp::Accurate` variant + dispatch.
  * nnef: `exp = "accurate"` de/serialization round-trip.
  * transformers: ScaledMaskedSoftmax::eval uses SoftmaxExp::Accurate.

New linalg tests validate the accurate exp against libc (not against
itself, unlike the existing FastCompact frame test) and cover the
fully-masked degenerate row. scaled_masked_softmax + sdpa proptests
(f16/f32, raw/decluttered/optimized) pass on native and wasm32-wasip1.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@kali

kali commented Jun 18, 2026

Copy link
Copy Markdown
Collaborator

⚠️⚠️⚠️ Just rebased! ⚠️⚠️⚠️

@kali kali force-pushed the feat/arm64-neon-fp16-activations branch from 476ef48 to a18d911 Compare June 18, 2026 14:01
@github-actions

Copy link
Copy Markdown

🔴 Bench vs main — 5 speed regression(s)

Reference: main nightly, latest 2026-06-18 (0d old) · PR 47aadbb21 · ran on apple-m1-max, i9-11900kb_rtx-4060, jetson-orin-nx · 1089 metrics compared

Speed — evaltime · prefill · decode

Δ metric device main → PR
🔴 -10.2% openelm_270M_q40ef16_541
prefill · cpu
i9-11900kb_rtx-4060 47 tok/s → 42.19 tok/s
🔴 -8.8% openelm_270M_q40ef16_516
prefill · cpu
i9-11900kb_rtx-4060 46.2 tok/s → 42.15 tok/s
🔴 -7.4% llama_3_2_1B_q40ef32_516
prefill · cpu
i9-11900kb_rtx-4060 14.4 tok/s → 13.34 tok/s
🔴 -5.5% llama_3_2_1B_instruct_q40ef16_541
prefill · cpu
i9-11900kb_rtx-4060 12.2 tok/s → 11.53 tok/s
🔴 +5.0% parakeet_tdt_600m_v3_f32f32_preprocessor_1s
evaltime · cuda
i9-11900kb_rtx-4060 3.26 ms → 3.43 ms

lower is better except prefill/decode (tok/s) · adaptive thresholds (max(floor, k×noise) vs the series' own history) · single-shot vs nightly reference · full table → run summary

@kali

kali commented Jun 18, 2026

Copy link
Copy Markdown
Collaborator

@czoli1976 surprising bench results ?

@czoli1976

Copy link
Copy Markdown
Contributor Author

how is MLX ?

@kali

kali commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

Mmm... I'm not sure I follow ?

@czoli1976

czoli1976 commented Jun 19, 2026 via email

Copy link
Copy Markdown
Contributor Author

@kali

kali commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

so the link at the bottom leads to the full result reports, including the mac/metal ones (table is huge and flat, i'm gonna work on it soon). Search for "tok/s", the mac is the first block :)

@czoli1976

czoli1976 commented Jun 19, 2026 via email

Copy link
Copy Markdown
Contributor Author

@kali

kali commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

No rush, I'm barely keeping up anyway.

@kali

kali commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

I happen to have canibalized this function for a different context, and struggle to get it to vectorize. My imaginary friend came with this:

 accurate_exp_f32: branchless underflow flush (for autovectorization)

  Change. Replace the trailing early-out branch with a bitmask flush:

  // #2318 as written:
      let e = y * pow2n;
      if x < LO { 0.0 } else { e }

  // branchless:
      let e = y * pow2n;
      // 0 when x < LO, else e
      let flush = -((x < LO) as i32) as u32;   // all-ones if x < LO, else 0
      f32::from_bits(e.to_bits() & !flush)

  Why. The if x < LO { 0.0 } early-out is the one thing that stops the function from
  autovectorizing when it's inlined into a reduction loop (e.g. iter().map(|x|
  accurate_exp_f32(x - max)).sum()). LLVM treats the underflow path as profitable to
  skip and sinks the polynomial behind a per-element conditional branch (ucomiss/ja) —
  so even compiled with AVX2/AVX-512 target features, the loop stays scalar. The
  bitmask form keeps the dataflow straight-line, and the comparison lowers to a vector
  mask (vcmpltps → vandps), letting the whole body pack.

@czoli1976

czoli1976 commented Jun 19, 2026 via email

Copy link
Copy Markdown
Contributor Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants