transformers: in-place KV cache for decode via a fused InPlaceKvSdpa op#2321
transformers: in-place KV cache for decode via a fused InPlaceKvSdpa op#2321czoli1976 wants to merge 1 commit into
Conversation
|
@kali look at me first |
For models trained with sliding-window attention (Mistral, Gemma-style local/global): a fixed-capacity ring buffer that overwrites the oldest slot on append, so decode runs at CONSTANT memory + per-step cost regardless of context length, losslessly (the model is trained to attend only within the window). Cheap because decode attention is ORDER-INVARIANT over keys (O = Σ softmax_j·V_j is unchanged under a (K,V) permutation), so the ring buffer never needs un-rotation — the consumer attends over the W physical slots as-is. Validated: holds the last-W as a set (incl. prefill chunk > window); windowed attention == ordered last-W attention (close, float summation order); memory bounded at W. Companion to the in-place cache (#2321) = 'in-place cache with a cap + wraparound'. 3 tests, fmt+clippy clean. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…rite + NNEF + resume
tract's DynKeyValueCache grows by TypedConcat([past, new]) each step, copying the
whole t-token past into a fresh buffer -> O(T^2) total copy over a T-token decode.
Apple Core ML "stateful in-place KV" lever. Pieces:
1. InPlaceKvCache: geometric-growth in-place cache. Buffer with spare capacity along
`axis`, write each new chunk at the cursor (Tensor::assign_slice, strided-safe for
any axis), double only when capacity is exceeded -> O(T) amortized copy.
valid_view() exposes the live [0..len] region as a ZERO-COPY ndarray view (the path
that realizes the win). For the seq axis of [B,H,S,D] a per-head slice of the
capacity buffer is a contiguous prefix, so a consumer reads it at concat cost.
2. InPlaceKvSdpa: stateful fused op owning the K/V in-place caches, running the CPU
SDPA (FlashSdpaOp::flash_attention_gqa) over the zero-copy views. tract Tensors
cannot be zero-copy views ACROSS an op boundary (Tensor::slice copies), so keeping
the buffers inside the consuming op is what makes the saving real. Drop-in for
{kv_cache(K), kv_cache(V), Sdpa}; does GQA internally.
3. InPlaceKvSdpaTransform: rewrite pass that strips the GQA broadcast chain
(fuse_kv_cache_broadcast_rule) then fuses {cache(K), cache(V), Sdpa} -> InPlaceKvSdpa
so existing decode models adopt the in-place cache transparently.
4. NNEF ser/de: round-trips via tract_transformers_inplace_kv_sdpa (registered).
5. Resume: save_to/load_from checkpoint the cache as [K,V] tensors; freeze/unfreeze
snapshot the running state in-process. Both bit-exact resume; snapshot is O(len).
Validated (11 tests): in-place bit-exact vs concat-grow; fused op matches concat-cache
+ FlashSdpaOp baseline (prefill+decode, GQA, causal/non-causal); runs end-to-end via a
persistent SimpleState; the rewrite fires + the rewritten model matches baseline; NNEF
round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized. fmt +
clippy clean; transformers lib 23/0 no-regression.
Benched (release, B=1 H=8 D=128):
- cache-update only: 21x (T=256) -> 709x (T=4096), O(T^2) -> O(T)
- end-to-end via the op: 1.10x (256) -> 1.63x (2048), 39% faster decode @2k
- resume checkpoint: O(len), 0.10ms (256) -> 1.76ms (4096), one-time
Follow-up: GPU coupling (sonos#2320 MFA kernel reading capacity buffer + length).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
eb63080 to
8837149
Compare
…e bits) Training-free affine quantize<->dequantize for the KV cache: keep every token but at fewer bits (configurable, 1..16). Keys per-CHANNEL (outlier channels get their own scale), Values per-TOKEN (KIVI, Liu et al. 2024). Gentler than evicting; works for any model. (CommVQ's RoPE-commutative codebook is a fancier follow-on.) Validated: round-trip error <= scale/2 and shrinks with bits; per-channel >> per-token on outlier channels; 8-bit near-lossless for attention output. Real GPT-2 (harness/ kv_quant_real.py): int8 ~0.5% attention deviation (near-lossless, 2x mem), graceful to int2; int4 per-channel-K beats per-token-K 1.75-1.9x on early layers. Memory = bits/16 of the f16 cache (int8 2x, int4 4x, int2 8x). 3 tests, fmt+clippy clean. Follow-on: packed-int storage + a quantized KV-cache op (dequant-on-attend), composing with the in-place (sonos#2321) / sliding-window (sonos#2327) caches; CommVQ codebook variant. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…e bits) Training-free affine quantize<->dequantize for the KV cache: keep every token but at fewer bits (configurable, 1..16). Keys per-CHANNEL (outlier channels get their own scale), Values per-TOKEN (KIVI, Liu et al. 2024). Gentler than evicting; works for any model. (CommVQ's RoPE-commutative codebook is a fancier follow-on.) Validated: round-trip error <= scale/2 and shrinks with bits; per-channel >> per-token on outlier channels; 8-bit near-lossless for attention output. Real GPT-2 (harness/ kv_quant_real.py): int8 ~0.5% attention deviation (near-lossless, 2x mem), graceful to int2; int4 per-channel-K beats per-token-K 1.75-1.9x on early layers. Memory = bits/16 of the f16 cache (int8 2x, int4 4x, int2 8x). 3 tests, fmt+clippy clean. Follow-on: packed-int storage + a quantized KV-cache op (dequant-on-attend), composing with the in-place (sonos#2321) / sliding-window (sonos#2327) caches; CommVQ codebook variant. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…e bits) Training-free affine quantize<->dequantize for the KV cache: keep every token but at fewer bits (configurable, 1..16). Keys per-CHANNEL (outlier channels get their own scale), Values per-TOKEN (KIVI, Liu et al. 2024). Gentler than evicting; works for any model. (CommVQ's RoPE-commutative codebook is a fancier follow-on.) Validated: round-trip error <= scale/2 and shrinks with bits; per-channel >> per-token on outlier channels; 8-bit near-lossless for attention output. Real GPT-2 (harness/ kv_quant_real.py): int8 ~0.5% attention deviation (near-lossless, 2x mem), graceful to int2; int4 per-channel-K beats per-token-K 1.75-1.9x on early layers. Memory = bits/16 of the f16 cache (int8 2x, int4 4x, int2 8x). 3 tests, fmt+clippy clean. Follow-on: packed-int storage + a quantized KV-cache op (dequant-on-attend), composing with the in-place (sonos#2321) / sliding-window (sonos#2327) caches; CommVQ codebook variant. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…e bits) Training-free affine quantize<->dequantize for the KV cache: keep every token but at fewer bits (configurable, 1..16). Keys per-CHANNEL (outlier channels get their own scale), Values per-TOKEN (KIVI, Liu et al. 2024). Gentler than evicting; works for any model. (CommVQ's RoPE-commutative codebook is a fancier follow-on.) Validated: round-trip error <= scale/2 and shrinks with bits; per-channel >> per-token on outlier channels; 8-bit near-lossless for attention output. Real GPT-2 (harness/ kv_quant_real.py): int8 ~0.5% attention deviation (near-lossless, 2x mem), graceful to int2; int4 per-channel-K beats per-token-K 1.75-1.9x on early layers. Memory = bits/16 of the f16 cache (int8 2x, int4 4x, int2 8x). 3 tests, fmt+clippy clean. Follow-on: packed-int storage + a quantized KV-cache op (dequant-on-attend), composing with the in-place (sonos#2321) / sliding-window (sonos#2327) caches; CommVQ codebook variant. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…e bits) Training-free affine quantize<->dequantize for the KV cache: keep every token but at fewer bits (configurable, 1..16). Keys per-CHANNEL (outlier channels get their own scale), Values per-TOKEN (KIVI, Liu et al. 2024). Gentler than evicting; works for any model. (CommVQ's RoPE-commutative codebook is a fancier follow-on.) Validated: round-trip error <= scale/2 and shrinks with bits; per-channel >> per-token on outlier channels; 8-bit near-lossless for attention output. Real GPT-2 (harness/ kv_quant_real.py): int8 ~0.5% attention deviation (near-lossless, 2x mem), graceful to int2; int4 per-channel-K beats per-token-K 1.75-1.9x on early layers. Memory = bits/16 of the f16 cache (int8 2x, int4 4x, int2 8x). 3 tests, fmt+clippy clean. Follow-on: packed-int storage + a quantized KV-cache op (dequant-on-attend), composing with the in-place (#2321) / sliding-window (#2327) caches; CommVQ codebook variant. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
|
So... we introduced the DynKVCache exactly for solving the concat thing, we did the groundwork, but not the actual optimisations. My (rough) initial idea was to introduce a new TensorStorage that would allow an indirection alongside one axis, the sequencing one. So this storage would be essentially a Vec<Arc>, sharing the Tensor with the DynKVCache state. Then fix the SDPA implementations so they can actually consume this input. For the non-monoblock SDPA, the consumer is a GEMM. We can either fix/extend the Packers on CPU and/or have a fallback that reify the actual tensor (so not benefiting from the optimisation at all). I don't know if this approach is better or not that fusing as you propose. But happy to hear your thoughts. Also, I'm not sure we want these new operators (whether the ones you're introducing or my proposed ReifyTensor) in the decluttered set (=> so no nnef serializing, keeping the SDPA form in NNEF because it's more semantic). So, WDYT ? |
@kali — opening this as much as an RFC as a PR: it adds an opt-in in-place KV-cache path for decode, and since it sits right next to your
DynKeyValueCache+freeze_intowork I'd like your read on the shape before it's merge-bound. It's additive and opt-in — no default behavior changes.Independent of #2319 / #2320 — not stacked
This branches off
mainand can be reviewed and merged on its own, in any order:flash_sdpa.rs; this only addsinplace_kv_cache.rsplus two wiring lines (ops/mod.rs,lib.rs). No conflict. The fused op reuses the existingFlashSdpaOp::flash_attention_gqaconsumer that transformers: CPU FlashSdpa — contiguous P·V GEMM + head-parallel exec + seq-len lowering heuristic #2319 happens to speed up, so the two are synergistic but not coupled — this works on the currentmainconsumer regardless of transformers: CPU FlashSdpa — contiguous P·V GEMM + head-parallel exec + seq-len lowering heuristic #2319.The problem
DynKeyValueCachegrows the cache withTypedConcat([past, new])every step, so steptcopies the wholet-token past into a fresh buffer — O(T²) total copy over aT-token decode (plus an allocation per step). The attention compute is already O(T²); this is pure cache-management overhead on top.The design choice (the RFC bit)
The catch: in-place growth only pays off if the consumer reads the cache without re-copying. A naive "preallocate + write-at-cursor, then slice
[0..len]for the consumer" is a wash —Tensor::slicecopies, so the per-step slice-to-valid reintroduces the O(T²). And tractTensors can't be zero-copy views across an op boundary.So I kept the K/V buffers inside the op that consumes them: a stateful fused op (
InPlaceKvSdpa) that owns two in-place caches and runs the existing CPU SDPA (FlashSdpaOp::flash_attention_gqa) over zero-copy[0..len]views. It's a drop-in for the{kv_cache(K), kv_cache(V), Sdpa}subgraph — same output by construction (same kernel, same K/V) — and because it does GQA internally, fusing also removes the unsqueeze/broadcast/reshape chain.The question for you: is fusing cache+attention the direction you want, or would you rather (a) make
DynKeyValueCacheitself in-place + a length-awareSdpareadingbuffer + length, or (b) a core-level zero-copy sub-tensor (Tensor=Arcbuffer + offset/len) so the cache can output a view — more general, bigger change? I built (the fused op) because it needs no core changes and is provably equivalent; happy to redirect.What's in the PR
InPlaceKvCache— geometric-growth (Vec-style doubling) in-place cache; appends viaTensor::assign_slice(strided-safe for any axis);valid_view()is a zero-copy ndarray view of[0..len]. O(T) amortized copy. (For the seq axis of[B,H,S,D]a per-head slice of the capacity buffer is a contiguous prefix, so the consumer reads it at concat cost.)InPlaceKvSdpa— the stateful fused op (Op/EvalOp/TypedOp+OpState+OpStateFreeze).InPlaceKvSdpaTransform— an opt-inModelTransform(likeKeyValueCacheTransform) that strips the GQA broadcast chain (reusing yourfuse_kv_cache_broadcast_rule) then fusescache → SdpaintoInPlaceKvSdpa. Apply it to get in-place decode; don't, and nothing changes.save_to/load_fromcheckpoint the cache as[K,V];freeze/unfreezesnapshot the running state — extends yourfreeze_into).Numbers (Apple M-series, f32, B=1 H=8 D=128)
save_to)These are op-level microbenches on synthetic attention — I can add a real decode-model wall-clock A/B (transform on/off, M1/M4) if you'd want that for the merge bar.
Correctness
InPlaceKvCachebit-exact vs concat-grow (multi-axis + decode); the fused op matches theconcat-cache + FlashSdpaOpbaseline over prefill+decode, GQA, causal/non-causal; runs end-to-end through a persistentSimpleState; the rewrite fires + the rewritten model matches the baseline; NNEF round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized (≤12 reallocs / 1024 pushes).cargo build --workspaceclean;tract-transformers+ blast-radius suite green; fmt + clippy clean.Apple research & prior art
MLState/StateType, in-place mutable cache/recurrent tensors).kv_cache.rs; llama.cppllama_memory_i/ past-present share-buffer; ONNX Runtimepast_present_share_buffer.Related
Cfor both the key-loop bound and the K address stride, and K is[H,D,C]so the valid prefix is strided — I validated this on-GPU). In-place K on Metal would need an owned.metalport — noted on metal: fused Sdpa via the vendored MetalFlashAttention kernel (~2×) #2320. This PR is the CPU/cross-backend path available today.🤖 Generated with Claude Code