From 0e8cd3cf09a595c2e4ecf356605763367b917c25 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Sat, 30 May 2026 09:42:01 +0000 Subject: [PATCH 01/21] linalg/build.rs: add Intel AMX int8 assembler probe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirrors the SME probe pattern: a tiny dummy_amx.S file containing the mnemonics the upcoming kernel needs (ldtilecfg, tilezero, tdpbusd, tilerelease) is compiled by the build script. On toolchains predating AMX support — notably Debian stretch's gas 2.28 — the probe fails and the `tract_amx_int8` cfg is not emitted, so the (forthcoming) kernel file is excluded from compilation and the Rust side never references the absent symbol. Dispatch then falls back to VNNI or AVX2 silently. Sets up infrastructure for the next commit which adds the actual kernel. No behaviour change yet: amx_int8_files is empty until the kernel lands. https://claude.ai/code/session_01MtQweWhf8E1W7pEJMF6ywf --- linalg/build.rs | 39 +++++++++++++++++++++++++++++++++ linalg/x86_64/avx512amx/dummy.S | 29 ++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 linalg/x86_64/avx512amx/dummy.S diff --git a/linalg/build.rs b/linalg/build.rs index 4b298cb0eb..c8d614fb91 100644 --- a/linalg/build.rs +++ b/linalg/build.rs @@ -68,6 +68,24 @@ fn assembler_supports_avx512vnni() -> bool { .is_ok() } +// Probe whether the target assembler can actually assemble Intel AMX int8 +// instructions (`ldtilecfg`, `tilezero`, `tdpbusd`, `tilerelease`). Older +// binutils (e.g. Debian stretch's gas 2.28) predate AMX and reject these +// mnemonics outright, which would break the x86_64 build for users on those +// toolchains. When the probe fails we skip the AMX kernel entirely; the +// matching `tract_amx_int8` cfg keeps the Rust side from referencing the +// (absent) kernel symbol, and `qmmm_i32` dispatch falls back to VNNI (or +// AVX2 when VNNI is itself unavailable). +fn assembler_supports_amx_int8() -> bool { + cc::Build::new() + .file("x86_64/avx512amx/dummy.S") + .cargo_metadata(false) + .cargo_warnings(false) + .warnings(false) + .try_compile("tract_amx_int8_probe") + .is_ok() +} + fn include_sve() -> bool { // SVE/SVE2 lives on ARMv9 server/mobile cores (Neoverse V1+/N2+, Cortex-X2+, // Graviton 3/4) — Linux aarch64. No Apple silicon has SVE. @@ -165,6 +183,9 @@ fn main() { println!("cargo:rustc-check-cfg=cfg(tract_arm64_dotprod)"); // Set below only when the x86_64 assembler probe for vpdpbusd ymm passes. println!("cargo:rustc-check-cfg=cfg(tract_avx512vnni)"); + // Set below only when the x86_64 assembler accepts AMX int8 mnemonics + // (avoids breaking the build on toolchains predating AMX). + println!("cargo:rustc-check-cfg=cfg(tract_amx_int8)"); match arch.as_ref() { "x86_64" => { @@ -238,6 +259,24 @@ fn main() { cc::Build::new().file(&out).flag("-mfma").compile("x86_64_avx512vnni"); println!("cargo:rustc-cfg=tract_avx512vnni"); } + + // AMX int8 kernel lives in its own subdirectory so it can be + // gated behind a build-time assembler probe. Unix only for now; + // the kernel uses the GAS intel-syntax path. The `tract_amx_int8` + // cfg gates the Rust-side symbol reference: when the probe fails + // on old toolchains (e.g. Debian stretch's binutils 2.28), the + // kernel is omitted and `qmmm_i32` dispatch falls back to VNNI + // or AVX2 with no build error. + if os != "windows" && assembler_supports_amx_int8() { + let amx_files = + preprocess_files("x86_64/avx512amx", &[], &suffix, false); + if !amx_files.is_empty() { + cc::Build::new() + .files(&amx_files) + .compile("x86_64_avx512amx"); + println!("cargo:rustc-cfg=tract_amx_int8"); + } + } } "arm" | "armv7" => { let files = preprocess_files("arm32/armvfpv2", &[], &suffix, false); diff --git a/linalg/x86_64/avx512amx/dummy.S b/linalg/x86_64/avx512amx/dummy.S new file mode 100644 index 0000000000..544e9c749f --- /dev/null +++ b/linalg/x86_64/avx512amx/dummy.S @@ -0,0 +1,29 @@ +// Build-time capability probe for the assembler, used by build.rs +// (assembler_supports_amx_int8). Older binutils — notably the Debian stretch +// x86_64 cross-toolchain in CI — predate AMX and cannot assemble these +// mnemonics. If this file fails to assemble, build.rs skips the AMX kernels +// and the `tract_amx_int8` cfg, and the runtime falls back to VNNI (or AVX2) +// for `qmmm_i32`. Not linked into anything. +.intel_syntax noprefix +.text +.globl tract_amx_int8_probe +tract_amx_int8_probe: + push rbp + mov rbp, rsp + sub rsp, 64 // room for the tilecfg block + mov qword ptr [rsp], 0 + mov qword ptr [rsp+8], 0 + mov qword ptr [rsp+16], 0 + mov qword ptr [rsp+24], 0 + mov qword ptr [rsp+32], 0 + mov qword ptr [rsp+40], 0 + mov qword ptr [rsp+48], 0 + mov qword ptr [rsp+56], 0 + mov byte ptr [rsp], 1 // palette = 1 + ldtilecfg [rsp] + tilezero tmm0 + tdpbusd tmm0, tmm1, tmm2 + tilerelease + mov rsp, rbp + pop rbp + ret From ddf6bb6b08460c64f88ada8b16affd27c484f5c8 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Sat, 30 May 2026 10:16:34 +0000 Subject: [PATCH 02/21] linalg/x86_64: add Intel AMX int8 GEMM kernel (avx512amx_mmm_i32_8x8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Route qmmm_i32 through Intel AMX TDPBSSD when CPUID reports amx-int8/amx-tile AND the OS grants tile-data XSAVE permission (Linux: arch_prctl ARCH_REQ_XCOMP_PERM). The kernel exposes the same 8x8 ymm-accumulator tile as avx512vnni_mmm_i32_8x8 and reuses its entire post-matmul dispatcher epilogue (per_rows / per_cols / scalars / q_scale / q_shr / add_unicast / store) unchanged — only the inner-K matmul phase changes. Tile geometry inside the kernel: tmm0 (C): 8 rows x 32 colsb -> 8 M x 8 N i32 accumulator (the 8x8 tile) tmm1 (A): 8 rows x 64 colsb -> 8 M x 64 K-bytes per inner iter tmm2 (B): 16 rows x 32 colsb -> 16 K-pair-rows x (8 N-cols * 4 K-bytes) Per TDPBSSD: 8 * 8 * 64 = 4096 i32 mul-acc ops (128x a single vpdpbusd ymm). After the matmul phase, tmm0 is tilestored to a 256-byte stack scratch and loaded back as 8 row-major ymm registers, then a 24-instruction 8x8 i32 transpose (vpunpckl/h + vpunpcklqdq/h + vperm2i128) brings the accumulators into the column-major ymm0..ymm7 layout the existing epilogue expects. Packing: - B reuses the existing K=4-inner PackedI8K4 layout unchanged (the same byte layout that VNNI feeds vpdpbusd; tileloadd with stride=32 and cfg.colsb=32 reads it as one K-pair-row per tile row). - A uses a NEW M-major-within-panel layout (PackedAmxA): per 8-M-row panel, bytes are laid out row-major as panel[m*K_padded + k] = A[m, k], with K_padded = ceil(K / 64) * 64. tileloadd with stride=K_padded reads 8 contiguous M-rows of 64 K-bytes per inner iter. TDPBSSD is s8 x s8 -> i32 (Sapphire Rapids+, AMX-INT8 baseline), so no +128 bias trick is needed (unlike VNNI's vpdpbusd). The i32 accumulators are bit-identical to the AVX2 / VNNI paths. Build-time gating: a `tract_amx_int8` cfg is emitted only when the assembler accepts the AMX mnemonics (ldtilecfg, tilezero, tdpbssd, tilerelease, tileloadd, tilestored), checked by the assembler_supports_amx_int8 probe introduced in the previous commit. Old toolchains (Debian stretch binutils 2.28) fall back to VNNI silently. Runtime gating: has_amx_int8() does both CPUID (leaf 7 sub-leaf 0 EDX bits 24/25, since `is_x86_feature_detected!("amx-int8")` is gated on the nightly x86_amx_intrinsics feature) and a one-shot Linux arch_prctl ARCH_REQ_XCOMP_PERM call for XFEATURE_XTILEDATA (=18) via raw syscall. Result is OnceLock-memoised. Non-Linux returns false. Validation: - `cargo test --release -p tract-linalg`: 2885+9 tests pass, 0 failed. - The avx512amx_mmm_i32_8x8 kernel passes the full MMM property-test suite (i8i8 frame::prop, i32i32 frame::prop, store_i32/i8 row/col/arbitrary, return_q_scale across all rounding policies + pot/nonpot scales, etc.) — bit-identical to AVX2 and VNNI on the same inputs. https://claude.ai/code/session_01MtQweWhf8E1W7pEJMF6ywf --- linalg/build.rs | 51 +- linalg/src/x86_64_fma.rs | 3 + linalg/src/x86_64_fma/amx.rs | 200 +++++ linalg/src/x86_64_fma/mmm.rs | 32 + linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 | 748 +++++++++++++++++++ 5 files changed, 1018 insertions(+), 16 deletions(-) create mode 100644 linalg/src/x86_64_fma/amx.rs create mode 100644 linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 diff --git a/linalg/build.rs b/linalg/build.rs index c8d614fb91..b7c10fb2c7 100644 --- a/linalg/build.rs +++ b/linalg/build.rs @@ -197,6 +197,26 @@ fn main() { }); files.extend(preprocess_files("x86_64/avx512", &[], &suffix, false)); + // Pull the AMX kernel template out of the generic fma bulk-compile + // so it can be gated behind the assembler probe below. Its + // mnemonics (`ldtilecfg`, `tdpbssd`, `tilezero`, `tilerelease`) + // require gas >= 2.34; old toolchains (Debian stretch's binutils + // 2.28) would otherwise fail the whole build. The kernel template + // lives next to its Jinja partials (`dispatcher.j2`, the i32 + // epilogue includes); only the compile of the rendered .S is + // moved. + let amx_int8_files: Vec = files + .iter() + .filter(|f| { + f.file_name() + .and_then(|n| n.to_str()) + .map(|n| n.starts_with("avx512amx_")) + .unwrap_or(false) + }) + .cloned() + .collect(); + files.retain(|f| !amx_int8_files.contains(f)); + if os == "windows" { if use_masm() { let mut lib_exe = cc::windows_registry::find(&target, "lib.exe") @@ -260,22 +280,21 @@ fn main() { println!("cargo:rustc-cfg=tract_avx512vnni"); } - // AMX int8 kernel lives in its own subdirectory so it can be - // gated behind a build-time assembler probe. Unix only for now; - // the kernel uses the GAS intel-syntax path. The `tract_amx_int8` - // cfg gates the Rust-side symbol reference: when the probe fails - // on old toolchains (e.g. Debian stretch's binutils 2.28), the - // kernel is omitted and `qmmm_i32` dispatch falls back to VNNI - // or AVX2 with no build error. - if os != "windows" && assembler_supports_amx_int8() { - let amx_files = - preprocess_files("x86_64/avx512amx", &[], &suffix, false); - if !amx_files.is_empty() { - cc::Build::new() - .files(&amx_files) - .compile("x86_64_avx512amx"); - println!("cargo:rustc-cfg=tract_amx_int8"); - } + // AMX int8 kernel: compile only when the assembler accepts the + // mnemonics, and the kernel template was actually pulled aside + // above. Unix only for now (the .S uses the GAS intel-syntax + // path). The `tract_amx_int8` cfg gates the Rust-side symbol + // reference: when the probe fails on old toolchains (e.g. Debian + // stretch's binutils 2.28), the kernel is omitted and `qmmm_i32` + // dispatch falls back to VNNI or AVX2 with no build error. + if os != "windows" + && !amx_int8_files.is_empty() + && assembler_supports_amx_int8() + { + cc::Build::new() + .files(&amx_int8_files) + .compile("x86_64_avx512amx"); + println!("cargo:rustc-cfg=tract_amx_int8"); } } "arm" | "armv7" => { diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index e61baa2efe..60099ee74a 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -9,6 +9,9 @@ pub mod mmm; pub mod act; pub mod act_f16; pub mod act_f16_fp16; + +#[cfg(tract_amx_int8)] +pub mod amx; pub mod by_scalar; pub mod erf; mod intel; diff --git a/linalg/src/x86_64_fma/amx.rs b/linalg/src/x86_64_fma/amx.rs new file mode 100644 index 0000000000..d7371e7c4b --- /dev/null +++ b/linalg/src/x86_64_fma/amx.rs @@ -0,0 +1,200 @@ +// Intel AMX int8 support: A packing format and runtime gate. +// +// The kernel `avx512amx_mmm_i32_8x8` uses TDPBSSD (signed-signed). Per +// iteration of its inner loop it consumes one 8x64-byte A tile and one +// 16x32-byte B tile and updates an 8x8 i32 C tile. The B-side packing +// matches the existing K=4-inner `PackedI8K4` layout, so it is reused +// unchanged. The A-side packing is novel: AMX's tile-A semantics require +// M-major-within-panel row-major bytes, which is incompatible with the +// K-major-outer `PackedI8K4`. `PackedAmxA` below produces that layout. +// +// Runtime gate: CPUID `amx-int8` is necessary but not sufficient on Linux — +// the kernel must also call `arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)` +// to receive AMX tile-data XSAVE permission from the kernel before any tile +// instruction can run. `has_amx_int8()` performs both checks once and caches +// the result; it returns false on non-Linux even if CPUID reports AMX. + +use std::sync::OnceLock; + +use tract_data::internal::*; + +use crate::WeightType; +use crate::frame::mmm::{ + EagerPackedInput, MMMInputFormat, MMMInputValue, PackedExoticFact, PackedMatrixStorage, +}; + +/// Detect AMX-INT8 + AMX-TILE via CPUID leaf 7 sub-leaf 0 (EDX bits 24-25). +/// Stable-Rust friendly: `is_x86_feature_detected!("amx-int8")` is gated on +/// the nightly `x86_amx_intrinsics` feature, so we read CPUID by hand. +fn cpu_has_amx_int8() -> bool { + if !std::is_x86_feature_detected!("avx512f") { + return false; + } + let r = std::arch::x86_64::__cpuid_count(7, 0); + // bit 24 = AMX-TILE, bit 25 = AMX-INT8 in EDX. + const AMX_TILE: u32 = 1 << 24; + const AMX_INT8: u32 = 1 << 25; + (r.edx & AMX_TILE) != 0 && (r.edx & AMX_INT8) != 0 +} + +/// Linux only: ask the kernel for permission to use the AMX tile-data XSAVE +/// state via `arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)`. Returns +/// true if the kernel grants permission (or if the process already has it). +#[cfg(target_os = "linux")] +unsafe fn request_amx_xcomp_perm() -> bool { + // x86_64 syscall: rax=158 (arch_prctl), rdi=0x1023 (REQ_XCOMP_PERM), + // rsi=18 (XFEATURE_XTILEDATA). Returns 0 on success. + let rc: i64; + unsafe { + std::arch::asm!( + "syscall", + in("rax") 158i64, + in("rdi") 0x1023i64, + in("rsi") 18i64, + lateout("rax") rc, + out("rcx") _, + out("r11") _, + options(nostack), + ); + } + rc == 0 +} + +/// Returns true iff Intel AMX int8 is available AND the OS has granted this +/// process permission to use the AMX tile-data XSAVE state. Result is +/// memoised — the arch_prctl call has process-wide effect and only needs to +/// run once. +pub fn has_amx_int8() -> bool { + static GATE: OnceLock = OnceLock::new(); + *GATE.get_or_init(|| { + if !cpu_has_amx_int8() { + return false; + } + #[cfg(target_os = "linux")] + { + unsafe { request_amx_xcomp_perm() } + } + #[cfg(not(target_os = "linux"))] + { + false + } + }) +} + +/// AMX-friendly A packing: per `r`-row panel, M-rows are laid out row-major +/// across `K_padded = ceil(K / 64) * 64` contiguous bytes per row. AMX's +/// `tileloadd` with stride = K_padded reads exactly 8 contiguous M-rows of +/// 64 K-bytes each per call. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PackedAmxA { + pub r: usize, + pub align: usize, +} + +impl PackedAmxA { + pub fn new(r: usize) -> Self { + PackedAmxA { r, align: 64 } + } + fn k_padded(&self, k: usize) -> usize { + k.div_ceil(64) * 64 + } + fn panel(&self, k: usize) -> usize { + self.k_padded(k) * self.r + } + pub fn single_panel_len(&self, k: usize) -> usize { + self.panel(k) + } + pub fn len(&self, k: usize, mn: usize) -> usize { + mn.div_ceil(self.r) * self.panel(k) + } + pub fn alignment(&self) -> usize { + self.align + } + + pub fn pack_view( + &self, + t: &TensorView, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + let k = t.shape()[k_axis]; + let mn = t.shape()[mn_axis]; + let kp = self.k_padded(k); + let pl = kp * self.r; + let panels = mn.div_ceil(self.r); + let st = t.strides(); + let (ks, ms) = (st[k_axis], st[mn_axis]); + let mut blob = unsafe { Blob::new_for_size_and_align(panels * pl, self.align) }; + blob.as_bytes_mut().fill(0); + unsafe { + let src = t.as_ptr_unchecked::(); + let dst = blob.as_mut_ptr() as *mut i8; + for p in 0..panels { + let pw = self.r.min(mn - p * self.r); + let panel = dst.add(p * pl); + let mn0 = (p * self.r) as isize; + for lm in 0..pw { + let drow = panel.add(lm * kp); + let srow_base = src.offset((mn0 + lm as isize) * ms); + for kk in 0..k { + *drow.add(kk) = *srow_base.offset(kk as isize * ks); + } + } + } + } + Ok(Box::new(EagerPackedInput { + fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k }, + packed: blob.into(), + panel_bytes: pl, + mn, + })) + } +} + +impl std::fmt::Display for PackedAmxA { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "AmxA[{}]", self.r) + } +} + +impl MMMInputFormat for PackedAmxA { + fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult { + Ok(PackedMatrixStorage::new(self.prepare_one(t, k_axis, mn_axis)?) + .into_tensor(t.datum_type())) + } + fn prepare_one( + &self, + t: &Tensor, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + self.pack_view(&t.view(), k_axis, mn_axis) + } + fn precursor(&self) -> WeightType { + WeightType::Plain(i8::datum_type()) + } + fn r(&self) -> usize { + self.r + } + fn k_alignment(&self) -> usize { + // AMX consumes K=64 bytes per tdpbssd inner iteration; the packer + // already pads internally, but expose the alignment so upstream + // schedulers can reason about K-blocking. + 64 + } + fn merge_with<'o, 'a: 'o, 'b: 'o>( + &'a self, + o: &'b dyn MMMInputFormat, + ) -> Option<&'o dyn MMMInputFormat> { + o.downcast_ref::().filter(|x| x.r == self.r).map(|_| self as _) + } + fn mem_size(&self, k: TDim, mn: TDim) -> TDim { + mn.divceil(self.r) * self.panel(k.to_usize().unwrap_or(0)) + } + fn extract_at_mn_f16(&self, _: &EagerPackedInput, _: usize, _: &mut [f16]) -> TractResult<()> { + bail!("no f16 extract") + } + fn extract_at_mn_f32(&self, _: &EagerPackedInput, _: usize, _: &mut [f32]) -> TractResult<()> { + bail!("no f32 extract") + } +} diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index eaaf40b9d0..21d8a892d2 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -4,8 +4,13 @@ use crate::mmm::ImplementationQuality::ManuallyOptimized; use crate::mmm::MatMatMul; use crate::pack::{PackedFormat, PackedI8K4}; +#[cfg(tract_amx_int8)] +use super::amx::{PackedAmxA, has_amx_int8}; use super::*; +#[cfg(tract_amx_int8)] +const AVX512AMX: fn() -> bool = has_amx_int8; + /// One candidate kernel in a dispatcher's pool, with its tile geometry /// and a relative-throughput scale (1.0 = baseline, used to break /// near-ties between kernels with similar tile waste). @@ -107,6 +112,20 @@ MMMExternKernel! { avx512vnni_mmm_i32_8x8(8,8)@(256,4) where(AVX512VNNI) store(i8) } +// Same epilogue as avx512vnni_mmm_i32_8x8 (8x8 ymm accumulators), but the i8i8 +// matmul inner loop uses TDPBSSD (16-M x 16-N x 64-K mul-acc per instruction) +// over AMX tiles. A's packing is novel (PackedAmxA, M-major-within-panel, +// K-padded to multiples of 64); B reuses VNNI's K=4-inner PackedI8K4 layout +// unchanged. TDPBSSD is s8 x s8 so no +128 bias trick — accumulators are +// bit-identical to AVX2/VNNI. Gated by `where(AVX512AMX)` (= CPUID amx-int8 +// AND Linux XSAVE permission via arch_prctl). +#[cfg(tract_amx_int8)] +MMMExternKernel! { avx512amx_mmm_i32_8x8(8,8)@(64,4) where(AVX512AMX) + packing[1] = i8i8 => |k| k.with_packing(PackedAmxA::new(8), PackedI8K4::new(8)); + quality(ManuallyOptimized) + store(i8) +} + pub fn plug(ops: &mut Ops) { if is_x86_feature_detected!("avx2") { plug_avx2(ops); @@ -117,6 +136,12 @@ pub fn plug(ops: &mut Ops) { #[cfg(tract_avx512vnni)] if is_x86_feature_detected!("avx512vnni") { plug_avx512vnni(ops); + // AMX int8 preferred over VNNI when both available AND the OS + // has granted XSAVE tile-data permission (see `has_amx_int8`). + #[cfg(tract_amx_int8)] + if has_amx_int8() { + plug_avx512amx_int8(ops); + } } } } @@ -130,6 +155,13 @@ pub fn plug_avx512vnni(ops: &mut Ops) { log::info!("qmmm_i32: x86_64/avx512vnni activated"); } +#[cfg(tract_amx_int8)] +pub fn plug_avx512amx_int8(ops: &mut Ops) { + ops.mmm_impls.push(avx512amx_mmm_i32_8x8.mmm()); + ops.qmmm_i32 = Box::new(|_, _, _| avx512amx_mmm_i32_8x8.mmm()); + log::info!("qmmm_i32: x86_64/avx512amx_int8 activated"); +} + pub fn plug_avx2(ops: &mut Ops) { ops.mmm_impls.push(mmm::avx2_mmm_i32_8x8.mmm()); ops.qmmm_i32 = Box::new(|_, _, _| mmm::avx2_mmm_i32_8x8.mmm()); diff --git a/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 new file mode 100644 index 0000000000..6ed351d10a --- /dev/null +++ b/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 @@ -0,0 +1,748 @@ +{# +// vim: set syntax=asm : + +/* mmm 8x8: + + ymm0 ymm1 ymm2 ymm3 ymm4 ymm5 ymm6 ymm7 + +System V ABI: + args: rdi, rsi, rdx, rcx, r8, r9 + preserve: rbx, rsp, rbp, r12, r13, r14, r15 + scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11 + return: rax (+rdx) + +Windows ABI: + args: RCX, RDX, R8, R9 + preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15 + scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15 + return: rax (+rdx) +*/ +#} + +{% if msvc %} + +_text segment +avx512amx_mmm_i32_8x8_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512amx_mmm_i32_8x8_{{suffix}} +{{G}}avx512amx_mmm_i32_8x8_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} +// https://www.agner.org/optimize/calling_conventions.pdf xmm6-15 are not scratch +// https://stackoverflow.com/questions/43358429/save-value-of-xmm-registers + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + + // Reserve 64 bytes of stack for the AMX tile-config block, zero it, + // populate palette + tile dimensions, then ldtilecfg. The tile config + // stays live for the whole function; tilerelease is emitted at return. + // + // tmm0 = C accumulator: 8 rows x 32 colsb (8 M-rows x 8 N-cols of i32) + // tmm1 = A tile: 8 rows x 64 colsb (8 M-rows x 64 K-bytes per inner iter) + // tmm2 = B tile: 16 rows x 32 colsb (16 K-pair-rows x 8 N-cols * 4 K-bytes) + sub rsp, 64 + vpxor xmm15, xmm15, xmm15 + vmovdqu [rsp ], xmm15 + vmovdqu [rsp + 16], xmm15 + vmovdqu [rsp + 32], xmm15 + vmovdqu [rsp + 48], xmm15 + mov byte ptr [rsp + 0 ], 1 // palette = 1 + mov word ptr [rsp + 16], 32 // colsb[0] = 32 (tmm0) + mov word ptr [rsp + 18], 64 // colsb[1] = 64 (tmm1) + mov word ptr [rsp + 20], 32 // colsb[2] = 32 (tmm2) + mov byte ptr [rsp + 48], 8 // rows[0] = 8 (tmm0) + mov byte ptr [rsp + 49], 8 // rows[1] = 8 (tmm1) + mov byte ptr [rsp + 50], 16 // rows[2] = 16 (tmm2) + ldtilecfg [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + vzeroall + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_i8i8 + +{{L}}main_loop_packed_packed: + vmovaps ymm12, [rax] + + {% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{i}} * 4] + vpmulld ymm13, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, ymm13 + {% endfor %} + + add rax, 32 + add rbx, 32 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_i8i8: + // AMX i8 layout: A panel is 8 M-rows x K_padded K-bytes ROW-major within + // each 8-row panel (PackedAmxA8); B panel reuses the existing VNNI K=4- + // inner format (8 N-cols x 4 K-bytes per K-block, 16 such blocks per + // K=64 AMX tile). K is padded to a multiple of 64 by the packer. + // + // tdpbssd is s8 x s8 -> i32 (Sapphire Rapids+), so no +128 trick is needed: + // the i32 accumulators are bit-identical to the AVX2 / VNNI paths. + // + // Per tdpbssd: tmm0[m, n] += sum_{k=0..63} A[m, k] * B[k, n] + // (16 M-rows x 16 N-i32-lanes x 64 K = 16384 mul-acc per instruction) + + // r8 <- K_padded = ceil(k/64) * 64 = byte-stride between A's M-rows. + mov r8, rcx + add r8, 63 + and r8, -64 + + // rcx <- ceil(k/64) = number of K=64 AMX inner iterations. + add rcx, 63 + shr rcx, 6 + + // r9 <- 32 = byte-stride between B's K-pair-rows. + mov r9, 32 + + tilezero tmm0 + +{{L}}loop_64k_amx_i8i8: + tileloadd tmm1, [rax + r8 * 1] // A tile: stride r8 = K_padded + tileloadd tmm2, [rbx + r9 * 1] // B tile: stride r9 = 32 + tdpbssd tmm0, tmm1, tmm2 + add rax, 64 // +64 K-bytes in A row 0 + add rbx, 512 // +16 K-pairs * 32 = 512 B bytes + dec rcx + jnz {{L}}loop_64k_amx_i8i8 + + // tmm0 -> ymm0..ymm7 via 256-byte stack scratch (8 rows x 32 bytes). + // After tilestored, the layout is row-major i32: byte (m*32 + n*4) = C[m, n]. + // We need ymm{n} = column n of C with 8 i32 lanes (rows m=0..7) — the + // dispatcher epilogue convention. So we (a) load 8 ymms = 8 rows of C, + // then (b) transpose 8x8 i32 in place. + sub rsp, 256 + mov r10, rsp + mov r11, 32 + tilestored [r10 + r11 * 1], tmm0 + + {% for r in range(0, 8) %} + vmovdqu ymm{{r}}, [r10 + {{ r * 32 }}] + {% endfor %} + + add rsp, 256 + + // 8x8 i32 transpose: ymm0..ymm7 row-major -> column-major in place. + // Stage 1: interleave 32-bit dwords pairwise (ymm0..ymm7 -> ymm8..ymm15). + vpunpckldq ymm8, ymm0, ymm1 // [r0[0], r1[0], r0[1], r1[1], r0[4], r1[4], r0[5], r1[5]] + vpunpckhdq ymm9, ymm0, ymm1 + vpunpckldq ymm10, ymm2, ymm3 + vpunpckhdq ymm11, ymm2, ymm3 + vpunpckldq ymm12, ymm4, ymm5 + vpunpckhdq ymm13, ymm4, ymm5 + vpunpckldq ymm14, ymm6, ymm7 + vpunpckhdq ymm15, ymm6, ymm7 + + // Stage 2: interleave 64-bit quads (ymm8..ymm15 -> ymm0..ymm7). + vpunpcklqdq ymm0, ymm8, ymm10 // [r0[0], r1[0], r2[0], r3[0], r0[4], r1[4], r2[4], r3[4]] + vpunpckhqdq ymm1, ymm8, ymm10 + vpunpcklqdq ymm2, ymm9, ymm11 + vpunpckhqdq ymm3, ymm9, ymm11 + vpunpcklqdq ymm4, ymm12, ymm14 + vpunpckhqdq ymm5, ymm12, ymm14 + vpunpcklqdq ymm6, ymm13, ymm15 + vpunpckhqdq ymm7, ymm13, ymm15 + + // Stage 3: cross-lane permute (128-bit halves). Two phases so we can + // overwrite the inputs incrementally without clobbering needed data. + vperm2i128 ymm8, ymm0, ymm4, 0x20 // col 0: low(y0) | low(y4) + vperm2i128 ymm9, ymm1, ymm5, 0x20 // col 1 + vperm2i128 ymm10, ymm2, ymm6, 0x20 // col 2 + vperm2i128 ymm11, ymm3, ymm7, 0x20 // col 3 + vperm2i128 ymm12, ymm0, ymm4, 0x31 // col 4: high(y0) | high(y4) + vperm2i128 ymm13, ymm1, ymm5, 0x31 // col 5 + vperm2i128 ymm14, ymm2, ymm6, 0x31 // col 6 + vperm2i128 ymm15, ymm3, ymm7, 0x31 // col 7 + + vmovdqa ymm0, ymm8 + vmovdqa ymm1, ymm9 + vmovdqa ymm2, ymm10 + vmovdqa ymm3, ymm11 + vmovdqa ymm4, ymm12 + vmovdqa ymm5, ymm13 + vmovdqa ymm6, ymm14 + vmovdqa ymm7, ymm15 + + jmp {{L}}non_linear_loop + +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_scalars.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_rows.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_cols.j2" %} +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_load_tile.j2" %} + +{{L}}add_unicast: + + mov r10, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size + + cmp r8, 4 + je {{L}}non_linear_addc_i32 + +{# +// This is not great as vgatherdps reads 32-bits values and goes beyond our buffer. Probably harmless though. +// Commented and replaced with the "mov al" loop beyond to pacify valgrind. +// ymm14 and ymm15 are the same as in the non_linear_addc_i32 case (compute them before the test right above here. +// {% for i in range(0, 8) %} +// vpcmpeqd ymm15, ymm15, ymm15 +// vgatherdps ymm12, [ r10 + ymm14 ], ymm15 // 0xxx 1xxx 2xxx 3xxx 4xxx 5xxx 6xxx 7xxx +// +// // we need to go through vpmovsxbd, shuffling naively erases signs +// vpshufb ymm12, ymm12, ymm10 // 0123 0123 0123 0123 4567 4567 4567 4567 +// +// vpermd ymm12, ymm11, ymm12 // 0123 4567 +// vpmovsxbd ymm12, xmm12 // sign extend +// +// vpaddd ymm{{i}}, ymm{{i}}, ymm12 +// add r10, rbx +// {% endfor %} +#} + + {% for col in range(0, 8) %} + mov r8, r10 + {% for half in range(0, 2) %} + {% for lane in range(0, 4) %} + mov al, [ r8 ] + add r8, rsi + movsx eax, al + pinsrd xmm10, eax, {{lane}} + {% endfor %} + vperm2f128 ymm10, ymm10, ymm10, 1 + {% endfor %} + vpaddd ymm{{col}}, ymm{{col}}, ymm10 + add r10, rbx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}non_linear_addc_i32: + + mov eax, 0 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 + + +{% if msvc %} + vpbroadcastd ymm10, dword ptr [ offset byte_shuffle ] + vmovups ymm11, dword ptr [ offset i128_shuffle ] +{% else %} + vpbroadcastd ymm10, [ rip + {{L}}byte_shuffle ] + vmovups ymm11, [ rip + {{L}}i128_shuffle ] +{% endif %} + +{% for i in range(0, 8) %} + vpcmpeqd ymm15, ymm15, ymm15 + vgatherdps ymm12, [ r10 + ymm14 ], ymm15 + vpaddd ymm{{i}}, ymm{{i}}, ymm12 + add r10, rbx +{% endfor %} + + jmp {{L}}non_linear_loop + +{% if msvc %} +.data +byte_shuffle dd 201851904 // 0x0c080400 +i128_shuffle dd 0, 4 +.code +{% else %} +{{L}}byte_shuffle: .int 201851904 // 0x0c080400 +{{L}}i128_shuffle: .int 0, 4 +{% endif %} + +{{L}}add_row_col_products: + mov rax, [ rdi + 8 ] + mov rbx, [ rdi + 16 ] + + vmovups ymm12, [rax] + +{% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{ i * 4 }} ] + vpmulld ymm15, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, ymm15 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale: + mov r8, [ rdi + 16 ] // policy + vbroadcastss ymm8, dword ptr [rdi + 24] // multi + + mov rax, 1 + movq xmm9, rax + vpbroadcastq ymm9, xmm9 // ymm9 <- 1 + + mov rax, [ rdi + 8 ] // xmm10 <- shift + 31 + add rax, 31 + movq xmm10, rax + vpbroadcastq ymm10, xmm10 + + mov rax, 1 + movq xmm11, rax + vpsubq ymm12, ymm10, ymm9 // shift+31 - 1 + vpsllq ymm11, ymm9, xmm12 // ymm11 <- 1 << (shift + 31 - 1) + + cmp r8, 1 + je {{L}}q_scale_rounding_zero + cmp r8, 2 + je {{L}}q_scale_rounding_away + cmp r8, 3 + je {{L}}q_scale_rounding_minus_inf + cmp r8, 4 + je {{L}}q_scale_rounding_plus_inf + cmp r8, 5 + je {{L}}q_scale_rounding_even + cmp r8, 6 + je {{L}}q_scale_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_scale_rounding_zero: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsubq ymm14, ymm14, ymm9 + vpsubq ymm15, ymm15, ymm9 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_away: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_minus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + // sign extract for nudging in the right direction + vpxor ymm13, ymm13, ymm13 + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpsrld ymm13, ymm13, 31 // then just 0 or 1 + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_plus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) + + vpbroadcastd ymm9, xmm9 + +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpxor ymm13, ymm13, ymm13 + + // sign extract for nudging in the right direction + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpaddd ymm13, ymm13, ymm9 // if val >= 0 { 0i32 } else { 1i32 } + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_even: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm14, ymm14, ymm12 + vpsubq ymm14, ymm14, ymm9 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm15, ymm15, ymm12 + vpsubq ymm15, ymm15, ymm9 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_odd: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm14, ymm14, ymm12 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm15, ymm15, ymm12 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_shl: + mov eax, [ rdi + 8 ] // xmm10 <- -shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + +{% for i in range(0, 8) %} + vpsllvd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr: + mov r8, [ rdi + 16 ] // policy + + mov eax, 1 + movd xmm9, eax + vpbroadcastd ymm9, xmm9 // ymm9 <- 1u32 (8 times) + + mov eax, [ rdi + 8 ] // xmm10 <- shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + + mov ebx, 1 + mov cl, al + sub cl, 1 // rcx <- shift -1 + sal ebx, cl // rbx <- (1 << (shift - 1)) + movd xmm11, ebx + vpbroadcastd ymm11, xmm11 // ymm11 <- "half" + + vpxor ymm12, ymm12, ymm12 // ymm12 <- zeroes + + cmp r8, 1 + je {{L}}q_shr_rounding_zero + cmp r8, 2 + je {{L}}q_shr_rounding_away + cmp r8, 3 + je {{L}}q_shr_rounding_minus_inf + cmp r8, 4 + je {{L}}q_shr_rounding_plus_inf + cmp r8, 5 + je {{L}}q_shr_rounding_even + cmp r8, 6 + je {{L}}q_shr_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_shr_rounding_zero: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsubd ymm14, ymm14, ymm9 + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_away: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_minus_inf: +{% for i in range(0, 8) %} + vpsubd ymm{{i}}, ymm{{i}}, ymm9 + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_plus_inf: +{% for i in range(0, 8) %} + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_even: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm13, ymm9 // nudge = ((abs >>l shift) & 0x01) - 1 + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_odd: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm12, ymm13 // nudge = - ((abs >>l shift) & 0x01) + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + je {{L}}store_strides_i32 + + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32: + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}return: + // Tear down AMX state: release tile registers and reclaim the tile-config + // stack space we allocated right after the standard prologue. + tilerelease + add rsp, 64 + + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + + +{{L}}one_32bit: +{% if msvc %} + dd 1 +{% else %} + .int 1 +{% endif %} + +{% if msvc %} +avx512amx_mmm_i32_8x8_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} From f6d9f299418596eb9ed4734424cc7b0903f09c06 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Sat, 30 May 2026 10:29:04 +0000 Subject: [PATCH 03/21] linalg: add amx_i32 Criterion microbench (AVX2 / VNNI / AMX i8) Same M/K/N shapes as the vnni_i32 bench (64x256x64, 256x256x256, 512x512x512, 1024x1024x64). All three kernels run the i8i8 packing path (index 1) so the only difference is the matmul inner loop. Skipped at runtime when `has_amx_int8()` returns false (= CPUID lacks amx-int8/tile or the arch_prctl XSAVE permission was denied), and at build time when the `tract_amx_int8` cfg was not emitted. https://claude.ai/code/session_01MtQweWhf8E1W7pEJMF6ywf --- linalg/Cargo.toml | 4 ++ linalg/benches/amx_i32.rs | 81 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 linalg/benches/amx_i32.rs diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index 8ed7314afe..07fbe299b1 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -167,3 +167,7 @@ harness = false [[bench]] name = "vnni_i32" harness = false + +[[bench]] +name = "amx_i32" +harness = false diff --git a/linalg/benches/amx_i32.rs b/linalg/benches/amx_i32.rs new file mode 100644 index 0000000000..d3fc758c3c --- /dev/null +++ b/linalg/benches/amx_i32.rs @@ -0,0 +1,81 @@ +#![allow(dead_code)] +// Kernel-level benchmark: Intel AMX int8 GEMM (avx512amx_mmm_i32_8x8, TDPBSSD over +// 8x8 i32 tile with K=64 inner) vs the AVX-512 VNNI int8 path (avx512vnni_mmm_i32_8x8, +// VPDPBUSD over PackedI8K4 with K=4 inner) vs the AVX2 int8 path +// (avx2_mmm_i32_8x8, vpmaddubsw-style widening). All three run the same i8i8 +// packing index (1) over the same M/K/N so the only difference is the matmul +// inner loop. +use criterion::*; +use tract_data::internal::*; +use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; + +fn run_kernel(be: &mut Bencher, mmm: &dyn MatMatMul, m: usize, k: usize, n: usize) { + let a = Tensor::zero_dt(DatumType::I8, &[m, k]).unwrap(); + let b = Tensor::zero_dt(DatumType::I8, &[k, n]).unwrap(); + let (pack_a, pack_b) = &mmm.packings()[1]; + let pa = pack_a.prepare_one(&a, 1, 0).unwrap(); + let pb = pack_b.prepare_one(&b, 0, 1).unwrap(); + let mut scratch = unsafe { mmm.allocate_scratch_space() }; + be.iter_custom(|iters| { + let mut dur = std::time::Duration::default(); + for _ in 0..iters { + let t = std::time::Instant::now(); + unsafe { + mmm.run_with_scratch_space( + m, + n, + scratch.as_mut(), + &[FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 1, + }], + ) + .unwrap() + }; + dur += t.elapsed(); + } + dur + }); +} + +fn benches(c: &mut Criterion) { + #[cfg(tract_amx_int8)] + { + use tract_linalg::x86_64_fma::amx::has_amx_int8; + use tract_linalg::x86_64_fma::mmm::*; + if !has_amx_int8() { + eprintln!("AMX int8 not available (CPUID + arch_prctl gate failed), skipping"); + return; + } + for &(m, k, n) in + &[(64usize, 256usize, 64usize), (256, 256, 256), (512, 512, 512), (1024, 1024, 64)] + { + let id = format!("{m}x{k}x{n}"); + let mut g = c.benchmark_group("amx_i32/packed_packed"); + g.throughput(Throughput::Elements((m * k * n) as u64)); + g.bench_with_input(BenchmarkId::new("avx2", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avx2_mmm_i32_8x8.mmm(), m, k, n) + }); + if std::is_x86_feature_detected!("avx512vnni") { + g.bench_with_input( + BenchmarkId::new("avx512vnni", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512vnni_mmm_i32_8x8.mmm(), m, k, n), + ); + } + g.bench_with_input(BenchmarkId::new("avx512amx", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avx512amx_mmm_i32_8x8.mmm(), m, k, n) + }); + g.finish(); + } + } + #[cfg(not(tract_amx_int8))] + { + eprintln!("tract not built with AMX int8 support (probe failed at build time)"); + let _ = c; + } +} + +criterion_group!(g, benches); +criterion_main!(g); From 55ec8b5636d5585af08340304994b9fac8a59318 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Mon, 1 Jun 2026 13:46:52 +0000 Subject: [PATCH 04/21] linalg/core: wire PackedAmxA through OptMatMulPack and ungated re-export MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The AMX kernel uses a custom A-side packer `PackedAmxA` (M-major-within- panel rows, K padded to multiples of 64). When dispatched on AMX hardware, `OptMatMulPack::eval_with_session` in tract-core sees `PackedAmxA` as the packing format and previously bailed with "OptMatMulPack does not support packing format PackedAmxA". On Cascade Lake the bug was latent (the AMX dispatcher never activated); on Sapphire Rapids/Emerald Rapids it caused 29 quant/matmul tests to fail end-to-end. Fix: * `core/src/ops/matmul/pack.rs::pack_view_with`: add a `PackedAmxA` downcast arm parallel to the existing `PackedI8K4` arm. Gate the import on `target_arch = "x86_64"` since `tract_linalg::x86_64_fma` only exists there. * `linalg/src/x86_64_fma.rs`: drop `#[cfg(tract_amx_int8)]` from `pub mod amx;`. `PackedAmxA` and `has_amx_int8()` are pure data-layout / CPUID code with no AMX-specific assembly — they can compile and exist on any x86_64 host regardless of whether the assembler can encode AMX instructions. Only the kernel registration in `mmm.rs` and the `where(AVX512AMX)` gate need `tract_amx_int8`. This lets tract-core reference `PackedAmxA` unconditionally, removing the cross-crate cfg-gating problem (tract-core's build.rs doesn't run the AMX assembler probe, so it can't see `tract_amx_int8`). Test plan: * `cargo test --release` across tract-linalg / tract-core / tract-data / tract-nnef / tract-onnx / tract-pulse / tract-transformers / tract-hir / tract on Emerald Rapids (model 207, amx-int8 + amx-tile flags): **3458 passed, 0 failed**, including the AVX-512 AMX MMM property suite (`avx512amx_mmm_i32_8x8::{i8i8,i32i32}::frame::prop`, `store_i32/i8::*`, `return_q_scale_*`, `fuse::prop`) and the tract-core `ops::matmul::quant::*` suite that exercises the `OptMatMulPack` -> `PackedAmxA` codepath end-to-end. * All 15 quantized NNEF test cases (conv-q40 × 13, qmul, copy-requant) pass with output assertion against `io.npz` reference on AMX hardware. --- core/src/ops/matmul/pack.rs | 19 +++++++++++++------ linalg/src/x86_64_fma.rs | 1 - linalg/src/x86_64_fma/mmm.rs | 1 - 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/core/src/ops/matmul/pack.rs b/core/src/ops/matmul/pack.rs index 8a3bcc93ae..934abfee86 100644 --- a/core/src/ops/matmul/pack.rs +++ b/core/src/ops/matmul/pack.rs @@ -7,12 +7,15 @@ use tract_linalg::block_quant::{ }; use tract_linalg::mmm::{MMMInputFormat, MMMInputValue, PackedMatrixStorage}; use tract_linalg::pack::{PackedFormat, PackedI8K4}; +#[cfg(target_arch = "x86_64")] +use tract_linalg::x86_64_fma::amx::PackedAmxA; use super::ModePicker; // Pack one (possibly strided) view with a dynamic packing format. Keeps the // PackedFormat fast path byte-identical; routes the K=4-inner SMOPA packer -// (PackedI8K4) through its view packer. Other formats are unsupported here. +// (PackedI8K4) and the AMX A-side packer (PackedAmxA) through their view +// packers. Other formats are unsupported here. fn pack_view_with( packer: &dyn MMMInputFormat, t: &TensorView, @@ -20,12 +23,16 @@ fn pack_view_with( mn_axis: usize, ) -> TractResult> { if let Some(pf) = packer.downcast_ref::() { - pf.pack_tensor_view(t, k_axis, mn_axis) - } else if let Some(p4) = packer.downcast_ref::() { - p4.pack_view(t, k_axis, mn_axis) - } else { - bail!("OptMatMulPack does not support packing format {packer:?}") + return pf.pack_tensor_view(t, k_axis, mn_axis); } + if let Some(p4) = packer.downcast_ref::() { + return p4.pack_view(t, k_axis, mn_axis); + } + #[cfg(target_arch = "x86_64")] + if let Some(pa) = packer.downcast_ref::() { + return pa.pack_view(t, k_axis, mn_axis); + } + bail!("OptMatMulPack does not support packing format {packer:?}") } #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index 60099ee74a..7a057c51b1 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -10,7 +10,6 @@ pub mod act; pub mod act_f16; pub mod act_f16_fp16; -#[cfg(tract_amx_int8)] pub mod amx; pub mod by_scalar; pub mod erf; diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 21d8a892d2..2327e7cee9 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -4,7 +4,6 @@ use crate::mmm::ImplementationQuality::ManuallyOptimized; use crate::mmm::MatMatMul; use crate::pack::{PackedFormat, PackedI8K4}; -#[cfg(tract_amx_int8)] use super::amx::{PackedAmxA, has_amx_int8}; use super::*; From 256c7db042fe137425c13d2e0e5a016592c75e84 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:09:17 +0000 Subject: [PATCH 05/21] linalg/x86_64: prefetch hints in AMX i8 inner K loop Add prefetcht0 hints inside the K=64 inner loop of avx512amx_mmm_i32_8x8 for the data the NEXT iteration will consume. tileloadd brings the active A/B tile data into L1 on demand; the prefetches ask the hardware prefetcher to start the L2->L1 fill earlier so the next iter's tileloadd sees the data already warm. * A side: 1 prefetch per iter at [rax + 64] -- next iter's A row 0 start. The 7 other rows are stride r8 = K_padded apart; the hardware stream detector picks those up. * B side: 8 prefetches at [rbx + 512..960] -- all 8 cache lines of next iter's 512-byte B panel. Numbers on Emerald Rapids (model 207, 1 thread, `cargo bench -p tract-linalg --bench amx_i32`), packed_packed avx512amx, i8*i8->i32: | Shape (M*K*N) | Before (Gelem/s) | After (Gelem/s) | Delta | |-------------------|------------------|------------------|------:| | 64 * 256 * 64 | 64.5 | 66.5 | +3.2% | | 256 * 256 *256 | 64.5 | 64.5 | ~0% | | 512 * 512 *512 | 110 | 113 | +2.7% | | 1024 * 1024 * 64 | 173 | 174 | +0.6% | Small, consistent win on the long K shapes where B-side L2->L1 traffic matters; flat on the K=64 shape and the saturating K=256 shape. Test plan: * `cargo test --release -p tract-linalg --lib avx512amx_mmm_i32_8x8` on ER: **114 passed, 0 failed** -- the full AMX MMM property suite (i8i8 frame::prop, i32i32 frame::prop, fuse::prop, store_i32/i8, return_q_scale_*) confirms prefetches did not change kernel semantics. --- linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 index 6ed351d10a..ff65b6484a 100644 --- a/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 +++ b/linalg/x86_64/fma/avx512amx_mmm_i32_8x8.S.j2 @@ -165,6 +165,22 @@ avx512amx_mmm_i32_8x8_{{suffix}} proc tilezero tmm0 {{L}}loop_64k_amx_i8i8: + // Prefetch the data we'll need ONE iteration ahead. tileloadd brings + // the active tile data into L1 on demand; the prefetcht0 hints below + // ask the hardware prefetcher to start the L2->L1 fill for the next + // iter's A row (64 B further along the K axis) and the next iter's + // B panel (512 B = 8 cache lines further). For the long K loops + // (K>=256) the B-side prefetch matters most since each iter consumes + // 8 cache lines of B vs 1 cache line of A row 0. + prefetcht0 [rax + 64] + prefetcht0 [rbx + 512] + prefetcht0 [rbx + 576] + prefetcht0 [rbx + 640] + prefetcht0 [rbx + 704] + prefetcht0 [rbx + 768] + prefetcht0 [rbx + 832] + prefetcht0 [rbx + 896] + prefetcht0 [rbx + 960] tileloadd tmm1, [rax + r8 * 1] // A tile: stride r8 = K_padded tileloadd tmm2, [rbx + r9 * 1] // B tile: stride r9 = 32 tdpbssd tmm0, tmm1, tmm2 From 909a434b64164b87bd2f29b25d2cd026abbadfb2 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:31:52 +0000 Subject: [PATCH 06/21] linalg/x86_64: add 16x16 AMX int8 GEMM kernel (4x mul-adds per tdpbssd) avx512amx_mmm_i32_16x16 hits the maximum AMX i8 tile geometry (16 rows x 64 colsb = 1024 B per tile, both tmm1 A and tmm2 B). One `tdpbssd` now does 16 * 16 * 64 = 16384 mul-adds vs the 8x8 sibling's 4096 -- a 4x work-per-instruction gain, expected to translate to ~2x throughput on 512x512x512 / 1024x1024x64 (the 8x8 path is already memory-bound after the prefetch tuning). Register layout: ROW-MAJOR accumulators (zmm{m} = row m of C with 16 i32 lanes for n=0..15). This matches `tilestored`'s output layout directly, so the hot path (Clear -> AddMatMul -> Store/store_strides_i32_row_contig) needs zero transposes. The 16x16 zmm transpose that a col-major layout would have required is ~30 cross-lane permutes. Epilogue surface re-implemented for AVX-512 zmm: - scalar / per_row / per_col elementwise ops (zmm broadcasts) - leaky_relu via vpcmpgtd mask + vpblendmd - 6x q_scale rounding policies (vpsignd has no AVX-512 form; emulated via vpcmpgtd k1, 0, acc + vpsubd + vpblendmd) - 6x q_shr rounding policies + q_shl (vpsravd / vpsllvd zmm) - Store: row-contig fast path (1 vmovdqu32 or vpmovdb per row), generic scalar fallback for arbitrary strides - AddUnicast: gather via vpgatherdd with index = lane * col_stride - LoadTile: gather from col-major scratch with constant index vector - AddRowColProducts: outer product via row_data[m] broadcast x col_data A reuses PackedAmxA(16); B reuses PackedI8K4(16). Both packers are r- generic (K-padded to multiples of 64; K=4-inner block of 16 N-cols). The 16x16 is plugged as the primary `qmmm_i32` dispatch target whenever `has_amx_int8` is true; the 8x8 stays registered as `mmm_impls` so the dispatcher can pick it for smaller problems. Property-test surface mirrors the 8x8: 114 tests, skip-pass on non-AMX hosts via the runtime gate. --- linalg/src/x86_64_fma/mmm.rs | 18 +- .../x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 | 920 ++++++++++++++++++ 2 files changed, 936 insertions(+), 2 deletions(-) create mode 100644 linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 2327e7cee9..81c9f84dd4 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -125,6 +125,17 @@ MMMExternKernel! { avx512amx_mmm_i32_8x8(8,8)@(64,4) where(AVX512AMX) store(i8) } +// 16x16 i32 sibling. One tdpbssd does 16*16*64 = 16384 mul-adds (4x the 8x8). +// Same A/B packing (PackedAmxA, PackedI8K4) just with r=16. Row-major +// accumulators (zmm{m} = row m of C) so the hot path (Clear -> AddMatMul -> +// Store) needs no transpose. +#[cfg(tract_amx_int8)] +MMMExternKernel! { avx512amx_mmm_i32_16x16(16,16)@(64,4) where(AVX512AMX) + packing[1] = i8i8 => |k| k.with_packing(PackedAmxA::new(16), PackedI8K4::new(16)); + quality(ManuallyOptimized) + store(i8) +} + pub fn plug(ops: &mut Ops) { if is_x86_feature_detected!("avx2") { plug_avx2(ops); @@ -157,8 +168,11 @@ pub fn plug_avx512vnni(ops: &mut Ops) { #[cfg(tract_amx_int8)] pub fn plug_avx512amx_int8(ops: &mut Ops) { ops.mmm_impls.push(avx512amx_mmm_i32_8x8.mmm()); - ops.qmmm_i32 = Box::new(|_, _, _| avx512amx_mmm_i32_8x8.mmm()); - log::info!("qmmm_i32: x86_64/avx512amx_int8 activated"); + ops.mmm_impls.push(avx512amx_mmm_i32_16x16.mmm()); + // 16x16 hits the full AMX tile (1024 B per tile) and is ~4x the mul-adds + // per tdpbssd; use it as the primary qmmm_i32 dispatch target. + ops.qmmm_i32 = Box::new(|_, _, _| avx512amx_mmm_i32_16x16.mmm()); + log::info!("qmmm_i32: x86_64/avx512amx_int8 (16x16) activated"); } pub fn plug_avx2(ops: &mut Ops) { diff --git a/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 new file mode 100644 index 0000000000..50319c1813 --- /dev/null +++ b/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 @@ -0,0 +1,920 @@ +// vim: set syntax=asm : +// +// Intel AMX int8 GEMM kernel, 16 M-rows x 16 N-cols i32 accumulator output. +// +// One `tdpbssd tmm0, tmm1, tmm2` instruction performs: +// tmm0[m, n] += sum_{k=0..63} A[m, k] * B[k, n] +// for m=0..15, n=0..15: 16 * 16 * 64 = 16384 mul-adds per single instruction. +// That's 4x the work-per-instruction of the 8x8 sibling kernel, hitting the +// full AMX i8 tile geometry (max colsb=64, max rows=16, max bytes=1024). +// +// Tile geometry (palette 1): +// tmm0 = C accumulator: 16 rows x 64 colsb = 16 M-rows x 16 N-cols of i32 +// tmm1 = A tile: 16 rows x 64 colsb = 16 M-rows x 64 K-bytes per iter +// tmm2 = B tile: 16 rows x 64 colsb = 16 K-pair-rows x 16 N-cols * 4 K +// +// `tdpbssd` is signed-signed, so no +128 trick is needed; the i32 accumulators +// are bit-identical to the AVX2 / VNNI / 8x8-AMX reference paths. +// +// A is packed via PackedAmxA(16): per panel of 16 M-rows, row-major within the +// panel, K-bytes contiguous along the row, K_padded = ceil(K/64)*64. +// B reuses PackedI8K4(16): per K=4 block, 16 N-cols * 4 K-bytes = 64 bytes; +// 16 such K-blocks per tmm2 tile = 1024 bytes = one tileloadd. +// +// REGISTER LAYOUT +// zmm0..zmm15 = accumulators, ROW-MAJOR: zmm{m} holds the 16 i32 lanes +// [C[m, 0], C[m, 1], ..., C[m, 15]] for row m. +// This matches the row-major i32 layout that `tilestored` writes directly, +// so the hot path (Clear -> AddMatMul -> Store) needs no transpose. + +{% if msvc %} + +_text segment +avx512amx_mmm_i32_16x16_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512amx_mmm_i32_16x16_{{suffix}} +{{G}}avx512amx_mmm_i32_16x16_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + + // Reserve 64 bytes for the AMX tile-config block, zero it, populate + // palette + dims (all three tiles are 16 rows x 64 colsb, the maximum + // i8 tile geometry on Sapphire Rapids / Emerald Rapids / Granite Rapids). + sub rsp, 64 + vpxor xmm15, xmm15, xmm15 + vmovdqu [rsp ], xmm15 + vmovdqu [rsp + 16], xmm15 + vmovdqu [rsp + 32], xmm15 + vmovdqu [rsp + 48], xmm15 + mov byte ptr [rsp + 0 ], 1 // palette = 1 + mov word ptr [rsp + 16], 64 // colsb[0] = 64 (tmm0) + mov word ptr [rsp + 18], 64 // colsb[1] = 64 (tmm1) + mov word ptr [rsp + 20], 64 // colsb[2] = 64 (tmm2) + mov byte ptr [rsp + 48], 16 // rows[0] = 16 (tmm0) + mov byte ptr [rsp + 49], 16 // rows[1] = 16 (tmm1) + mov byte ptr [rsp + 50], 16 // rows[2] = 16 (tmm2) + ldtilecfg [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + // vzeroall only zeros lower-256 of zmm0..15; explicitly zero the full + // accumulators (zmm0..zmm15) for AMX. + {% for r in range(0, 16) %} + vpxorq zmm{{r}}, zmm{{r}}, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_i8i8 + +{{L}}main_loop_packed_packed: + // Generic i32 x i32 fallback path (not AMX). For row-major accumulators + // with zmm{m} = row m of C, accumulating C[m, n] += A[m, k] * B[k, n]: + // - load 16 B values for this K row into zmm16 (row of B) + // - for each m: broadcast A[m, k], multiply by zmm16, add to zmm{m} + vmovups zmm16, [rbx] // 16 i32 of B at this K row + + {% for m in range(0, 16) %} + vpbroadcastd zmm17, dword ptr [rax + {{m}} * 4] + vpmulld zmm18, zmm16, zmm17 + vpaddd zmm{{m}}, zmm{{m}}, zmm18 + {% endfor %} + + add rax, 64 // 16 i32 lanes per K step + add rbx, 64 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_i8i8: + // AMX i8 layout: + // A panel: 16 M-rows x K_padded K-bytes ROW-major within the panel + // (PackedAmxA, K_padded = ceil(K/64)*64). + // B panel: PackedI8K4(16) -- 16 N-cols x 4 K-bytes per K=4 block, with + // 16 K-blocks per tileloadd (16 K-pair-rows x 64 colsb). + // + // Per tdpbssd: tmm0[m, n] += sum_{k=0..63} A[m, k] * B[k, n]. + // Inner loop steps along K in 64-K chunks. + + // r8 <- K_padded = ceil(k/64) * 64 = byte-stride between A's M-rows. + mov r8, rcx + add r8, 63 + and r8, -64 + + // rcx <- ceil(k/64) = number of K=64 AMX inner iterations. + add rcx, 63 + shr rcx, 6 + + // r9 <- 64 = byte-stride between B's K-pair-rows (each row = 16 N-cols * 4 K). + mov r9, 64 + + tilezero tmm0 + +{{L}}loop_64k_amx_i8i8_16x16: + // Hint the prefetcher one iter ahead. A's row 0 advances by 64 bytes; + // B's panel advances by 16 cache lines (64 * 16 = 1024 B per iter). + prefetcht0 [rax + 64] + prefetcht0 [rbx + 1024] + prefetcht0 [rbx + 1088] + prefetcht0 [rbx + 1152] + prefetcht0 [rbx + 1216] + prefetcht0 [rbx + 1280] + prefetcht0 [rbx + 1344] + prefetcht0 [rbx + 1408] + prefetcht0 [rbx + 1472] + prefetcht0 [rbx + 1536] + prefetcht0 [rbx + 1600] + prefetcht0 [rbx + 1664] + prefetcht0 [rbx + 1728] + prefetcht0 [rbx + 1792] + prefetcht0 [rbx + 1856] + prefetcht0 [rbx + 1920] + prefetcht0 [rbx + 1984] + tileloadd tmm1, [rax + r8 * 1] // A tile: stride = K_padded + tileloadd tmm2, [rbx + r9 * 1] // B tile: stride = 64 + tdpbssd tmm0, tmm1, tmm2 + add rax, 64 // +64 K-bytes in A row 0 + add rbx, 1024 // 16 K-pairs * 64 = 1024 bytes + dec rcx + jnz {{L}}loop_64k_amx_i8i8_16x16 + + // tmm0 -> stack scratch (16 rows x 64 bytes = 1024 B row-major i32). + // Then load each row into zmm0..zmm15. Row m's 16 i32 are contiguous + // in memory, so each load is a single 64-byte vmovdqu32. + sub rsp, 1024 + mov r10, rsp + mov r11, 64 + tilestored [r10 + r11 * 1], tmm0 + + {% for m in range(0, 16) %} + vmovdqu32 zmm{{m}}, [r10 + {{ m * 64 }}] + {% endfor %} + + add rsp, 1024 + + jmp {{L}}non_linear_loop + +// ---- Scalar / per-row / per-col elementwise epilogues ------------------- + +{{L}}scalar_min: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpminsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_max: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpmaxsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_add: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpaddd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_mul: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpmulld zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub_flipped: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}leaky_relu: + // C[m, n] = (C[m, n] >= 0) ? C[m, n] : alpha * C[m, n] + vpbroadcastd zmm17, dword ptr [rdi + 8] // alpha as i32 scale factor + vpxorq zmm16, zmm16, zmm16 + {% for r in range(0, 16) %} + vpmulld zmm18, zmm{{r}}, zmm17 + vpcmpgtd k1, zmm16, zmm{{r}} // 1 where C < 0 + vpblendmd zmm{{r}}{k1}, zmm{{r}}, zmm18 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_min: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpminsd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_max: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpmaxsd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_add: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpaddd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_mul: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpmulld zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpsubd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub_flipped: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpsubd zmm{{m}}, zmm16, zmm{{m}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_min: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpminsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_max: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpmaxsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_add: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpaddd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_mul: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpmulld zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub_flipped: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}load_tile: + // Scratch layout is COL-MAJOR i32 from scratch.rs Store/AddUnicast remnant: + // tile[col][row] at offset (col*MR + row)*4 with MR=16 + // = offset col*64 + row*4 + // For row-major accumulators we gather row m's 16 cols at index step 64. + mov r8, [rdi + 8] + vmovdqa32 zmm16, [rip + {{L}}lane_offsets_64] // [0, 64, 128, ..., 15*64] + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm{{m}}{k1}, [r8 + zmm16 + {{m * 4}}] + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_unicast: + mov r10, [rdi + 8] // c ptr (base) + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size + + cmp r8, 4 + je {{L}}non_linear_addc_i32 + + // i8 path: read 16 i8 from [r10 + m*rsi + n*rbx] for n=0..15, sign-extend + // to i32, add to zmm{m}. Use a stack scratch buffer (16 bytes per row). + sub rsp, 16 + {% for m in range(0, 16) %} + mov r8, r10 + {% for n in range(0, 16) %} + mov al, [r8] + mov byte ptr [rsp + {{n}}], al + add r8, rbx + {% endfor %} + vpmovsxbd zmm16, [rsp] + vpaddd zmm{{m}}, zmm{{m}}, zmm16 + add r10, rsi + {% endfor %} + add rsp, 16 + jmp {{L}}non_linear_loop + +{{L}}non_linear_addc_i32: + // i32 strided read of external (or scratch) tile. Build per-lane index + // vector [0, rbx, 2*rbx, ..., 15*rbx] once, then gather row by row. + mov eax, ebx + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + vpmulld zmm16, zmm16, [rip + {{L}}lane_indices] // [0, rbx, 2*rbx, ...] + + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm17{k1}, [r10 + zmm16] + vpaddd zmm{{m}}, zmm{{m}}, zmm17 + add r10, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_row_col_products: + // bias[m, n] = row_data[m] * col_data[n], add to C[m, n]. + // For row-major regs: load 16 col_data values once into zmm16, + // for each m: broadcast row_data[m], FMA add. + mov rax, [rdi + 8] + mov rbx, [rdi + 16] + + vmovdqu32 zmm16, [rax] // 16 row_data values + vmovdqu32 zmm17, [rbx] // 16 col_data values + + {% for m in range(0, 16) %} + vpbroadcastd zmm18, dword ptr [rax + {{m * 4}}] // splat row_data[m] + vpmulld zmm19, zmm18, zmm17 + vpaddd zmm{{m}}, zmm{{m}}, zmm19 + {% endfor %} + jmp {{L}}non_linear_loop + +// ---- Q-scale (mult-shift with rounding) --------------------------------- + +{{L}}q_scale: + mov r8, [rdi + 16] // policy + vpbroadcastd zmm16, dword ptr [rdi + 24] // multi (broadcast i32) + + mov rax, 1 + vmovq xmm17, rax + vpbroadcastq zmm17, xmm17 // zmm17 <- 1 (i64 lanes) + + mov rax, [rdi + 8] // shift + add rax, 31 + vmovq xmm18, rax + vpbroadcastq zmm18, xmm18 // zmm18 <- (shift+31) (i64 lanes) + + vpsubq zmm19, zmm18, zmm17 + vpsllvq zmm19, zmm17, zmm19 // zmm19 <- 1 << (shift+31-1) (i64) + + // Per-lane interleave mask for blending evens / shifted-odds. + // bit i = 1 means take from "evens" source in vpblendmd; bit 0,2,4,...,14 set. + mov eax, 0x5555 + kmovw k7, eax + + cmp r8, 1 + je {{L}}q_scale_rounding_zero + cmp r8, 2 + je {{L}}q_scale_rounding_away + cmp r8, 3 + je {{L}}q_scale_rounding_minus_inf + cmp r8, 4 + je {{L}}q_scale_rounding_plus_inf + cmp r8, 5 + je {{L}}q_scale_rounding_even + cmp r8, 6 + je {{L}}q_scale_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_scale_rounding_zero: // signum * ( (abs + nudge - 1) >> shift ) +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 // even-lane i32 -> i64 mul + vpmuldq zmm21, zmm21, zmm16 // odd-lane i32 -> i64 mul + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsubq zmm20, zmm20, zmm17 + vpsubq zmm21, zmm21, zmm17 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 // k7=0x5555: evens from zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_away: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_minus_inf: // nudge by -1 where input was negative +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpxorq zmm22, zmm22, zmm22 + vpcmpgtd k1, zmm{{i}}, zmm22 // k1: 1 where input > 0 (we want the inverse, see below) + knotw k1, k1 // 1 where input <= 0 -- we want "input was negative => subtract 1" + // For "<0": use compare against 0 with vpcmpltd + vpxorq zmm22, zmm22, zmm22 + vpcmpltd k1, zmm{{i}}, zmm22 // 1 where input < 0 + vmovdqa32 zmm23{k1}{z}, [rip + {{L}}all_ones_i32] // (1 << 0) per neg lane, 0 elsewhere + + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + // Subtract 1 from i64-evens / i64-odds where the original i32 input was < 0. + vpsrldq zmm24, zmm23, 4 + vpmovsxdq zmm25, ymm23 + vpmovsxdq zmm26, ymm24 + vpsubq zmm20, zmm20, zmm25 + vpsubq zmm21, zmm21, zmm26 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_plus_inf: // nudge by +1 where input was non-negative +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpxorq zmm22, zmm22, zmm22 + vpcmpled k1, zmm22, zmm{{i}} // 1 where input >= 0 + vmovdqa32 zmm23{k1}{z}, [rip + {{L}}all_ones_i32] + + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrldq zmm24, zmm23, 4 + vpmovsxdq zmm25, ymm23 + vpmovsxdq zmm26, ymm24 + vpsubq zmm20, zmm20, zmm25 + vpsubq zmm21, zmm21, zmm26 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_even: // banker's: round half to even +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpsrlq zmm22, zmm20, xmm18 + vpandq zmm22, zmm22, zmm17 + vpaddq zmm20, zmm20, zmm22 + vpsubq zmm20, zmm20, zmm17 + + vpsrlq zmm22, zmm21, xmm18 + vpandq zmm22, zmm22, zmm17 + vpaddq zmm21, zmm21, zmm22 + vpsubq zmm21, zmm21, zmm17 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_odd: // round half to odd +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpsrlq zmm22, zmm20, xmm18 + vpandq zmm22, zmm22, zmm17 + vpsubq zmm20, zmm20, zmm22 + + vpsrlq zmm22, zmm21, xmm18 + vpandq zmm22, zmm22, zmm17 + vpsubq zmm21, zmm21, zmm22 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shl: + mov eax, [rdi + 8] // -shift (count: i32) + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + {% for i in range(0, 16) %}vpsllvd zmm{{i}}, zmm{{i}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr: + mov r8, [rdi + 16] // policy + + mov eax, 1 + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 // zmm16 <- 1 (i32 lanes) + + mov eax, [rdi + 8] // shift + vmovd xmm17, eax + vpbroadcastd zmm17, xmm17 // zmm17 <- shift (i32 lanes) + + mov ebx, 1 + mov cl, al + sub cl, 1 + sal ebx, cl // ebx <- 1 << (shift - 1) + vmovd xmm18, ebx + vpbroadcastd zmm18, xmm18 // zmm18 <- "half" + + vpxorq zmm19, zmm19, zmm19 // zeroes + + cmp r8, 1 + je {{L}}q_shr_rounding_zero + cmp r8, 2 + je {{L}}q_shr_rounding_away + cmp r8, 3 + je {{L}}q_shr_rounding_minus_inf + cmp r8, 4 + je {{L}}q_shr_rounding_plus_inf + cmp r8, 5 + je {{L}}q_shr_rounding_even + cmp r8, 6 + je {{L}}q_shr_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_shr_rounding_zero: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsubd zmm20, zmm20, zmm16 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_away: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_minus_inf: +{% for i in range(0, 16) %} + vpsubd zmm{{i}}, zmm{{i}}, zmm16 + vpaddd zmm{{i}}, zmm{{i}}, zmm18 + vpsravd zmm{{i}}, zmm{{i}}, zmm17 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_plus_inf: +{% for i in range(0, 16) %} + vpaddd zmm{{i}}, zmm{{i}}, zmm18 + vpsravd zmm{{i}}, zmm{{i}}, zmm17 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_even: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsravd zmm21, zmm20, zmm17 + vpandq zmm21, zmm21, zmm16 + vpsubd zmm21, zmm21, zmm16 // nudge = ((abs >>l shift) & 1) - 1 + vpaddd zmm20, zmm20, zmm21 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_odd: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsravd zmm21, zmm20, zmm17 + vpandq zmm21, zmm21, zmm16 + vpsubd zmm21, zmm19, zmm21 // nudge = -((abs >>l shift) & 1) + vpaddd zmm20, zmm20, zmm21 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +// ---- Store --------------------------------------------------------------- + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + je {{L}}store_strides_i32 + // else: i8 fallthrough + + cmp rdx, 1 + je {{L}}store_strides_i8_row_contig + + // Generic i8 strided store: per row, per lane scalar byte stores + {% for m in range(0, 16) %} + mov r10, r8 + // Extract from each 128-bit slice of zmm{{m}} + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i8_row_contig: + // Each row is 16 i8 contiguous; one vpmovdb per row. + {% for m in range(0, 16) %} + vpmovdb [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32: + cmp rdx, 4 + je {{L}}store_strides_i32_row_contig + + // Generic i32 strided store + {% for m in range(0, 16) %} + mov r10, r8 + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32_row_contig: + // C is row-major in memory: each row's 16 i32 are contiguous; one + // 64-byte aligned-or-unaligned store per row. + {% for m in range(0, 16) %} + vmovdqu32 [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}return: + tilerelease + add rsp, 64 + + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + +// ---- Read-only data (RIP-relative) --------------------------------------- + +.p2align 6 +{{L}}lane_offsets_64: + .int 0, 64, 128, 192, 256, 320, 384, 448 + .int 512, 576, 640, 704, 768, 832, 896, 960 + +.p2align 6 +{{L}}lane_indices: + .int 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + +.p2align 6 +{{L}}all_ones_i32: + .int 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + +{% if msvc %} +avx512amx_mmm_i32_16x16_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} From f4913714cf5bbefcfea674281bdabddd09a92594 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:34:36 +0000 Subject: [PATCH 07/21] linalg/bench: include avx512amx_mmm_i32_16x16 in amx_i32 microbench Adds the new 16x16 kernel alongside the existing avx2 / avx512vnni / avx512amx_8x8 entries so reviewers running the bench on Sapphire Rapids+ can see the per-shape throughput delta between the two AMX variants (8x8 vs 16x16) on the same M/K/N points (64x256x64, 256x256x256, 512x512x512, 1024x1024x64). --- linalg/benches/amx_i32.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/linalg/benches/amx_i32.rs b/linalg/benches/amx_i32.rs index d3fc758c3c..79388ec05f 100644 --- a/linalg/benches/amx_i32.rs +++ b/linalg/benches/amx_i32.rs @@ -64,9 +64,12 @@ fn benches(c: &mut Criterion) { |b, &(m, k, n)| run_kernel(b, &*avx512vnni_mmm_i32_8x8.mmm(), m, k, n), ); } - g.bench_with_input(BenchmarkId::new("avx512amx", &id), &(m, k, n), |b, &(m, k, n)| { + g.bench_with_input(BenchmarkId::new("avx512amx_8x8", &id), &(m, k, n), |b, &(m, k, n)| { run_kernel(b, &*avx512amx_mmm_i32_8x8.mmm(), m, k, n) }); + g.bench_with_input(BenchmarkId::new("avx512amx_16x16", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avx512amx_mmm_i32_16x16.mmm(), m, k, n) + }); g.finish(); } } From cadd1fae1531a97e0fb43af803fd57d91a01a7a5 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:47:32 +0000 Subject: [PATCH 08/21] linalg/x86_64: AMX 16x16 prefetch follows oneDNN cache-aware pattern oneDNN (the Intel-backed reference AMX implementation in jit_brgemm_amx_uker) distinguishes two roles in the inner-K loop: - A is REUSED across the outer matmul's N-tile sweep, so it benefits from being cached in L1. oneDNN uses `tileloadd` (cached) for A with a light `prefetcht0` hint to L1. - B STREAMS THROUGH once per kernel call (each N-tile gets its own B panel). For the AMX-typical large-matmul case the per-call B working set exceeds L1d (32 KB on Sapphire Rapids). oneDNN's heuristic `try_load_nt = footprint(A)+footprint(B)+footprint(C) >= L1` flips B's load to `tileloaddt1` (non-temporal, bypasses L1) and steers B-side prefetches at L2 (`prefetcht1`) instead of L1. The previous 16x16 prefetch block (17 `prefetcht0`'s + `tileloadd` for B) matched the 8x8 pattern proportionally but over-ran Sapphire Rapids' 16 Line Fill Buffer budget: 1 A-prefetch + 16 B-prefetches + 2 active tileloadds = 19 in-flight slots demanded, vs 16 available. That backs up real loads behind dropped prefetches. This patch aligns 16x16 with oneDNN's defaults for the large-matmul case: - A: prefetcht0 + tileloadd (1 LFB for prefetch + 1 for load) - B: 6x prefetcht1 + tileloaddt1 (6 LFBs for L2 priming + 1 NT load) -> primes the head 384 B of next-iter B-panel (6 of 16 lines); the SPR/EMR/GNR HW stream prefetcher reliably covers the remaining 10 lines once the 1024-B stride is detected. Total in-flight per iter: 9 (was 19). This leaves headroom for the OoO engine to overlap multiple iterations. The 8x8 kernel is left untouched since (a) its existing 9-prefetch pattern already fits the LFB budget, and (b) its 119 GElem/s @ 512x512x512 on EMR has been validated. Property tests (avx512amx_mmm_i32_16x16 suite) skip-pass on this CL host via the runtime gate; will be re-validated on AMX HW. Refs: - oneDNN src/cpu/x64/brgemm/brgemm_utils.cpp (load_nt heuristic) - oneDNN src/cpu/x64/brgemm/jit_brgemm_amx_uker.cpp (tileloaddt1 use) - Intel SDM Vol 1, sec 18.3 (AMX), Vol 3 (XSAVE tile state) - chipsandcheese.com SPR deep-dive (LFB count = 16) --- .../x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 index 50319c1813..5652f425f8 100644 --- a/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 +++ b/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 @@ -173,27 +173,33 @@ avx512amx_mmm_i32_16x16_{{suffix}} proc tilezero tmm0 {{L}}loop_64k_amx_i8i8_16x16: - // Hint the prefetcher one iter ahead. A's row 0 advances by 64 bytes; - // B's panel advances by 16 cache lines (64 * 16 = 1024 B per iter). - prefetcht0 [rax + 64] - prefetcht0 [rbx + 1024] - prefetcht0 [rbx + 1088] - prefetcht0 [rbx + 1152] - prefetcht0 [rbx + 1216] - prefetcht0 [rbx + 1280] - prefetcht0 [rbx + 1344] - prefetcht0 [rbx + 1408] - prefetcht0 [rbx + 1472] - prefetcht0 [rbx + 1536] - prefetcht0 [rbx + 1600] - prefetcht0 [rbx + 1664] - prefetcht0 [rbx + 1728] - prefetcht0 [rbx + 1792] - prefetcht0 [rbx + 1856] - prefetcht0 [rbx + 1920] - prefetcht0 [rbx + 1984] - tileloadd tmm1, [rax + r8 * 1] // A tile: stride = K_padded - tileloadd tmm2, [rbx + r9 * 1] // B tile: stride = 64 + // Cache strategy follows oneDNN's AMX BRGEMM heuristics (Intel-backed): + // - A is reused across N-tiles in tract's outer matmul loop, so we use + // `tileloadd` (cached, brings into L1) and `prefetcht0` (to L1) for A. + // - B streams through once per kernel call (one B-panel per N-tile), and + // for the AMX-typical large-matmul case the B working set exceeds the + // 32 KB L1d on Sapphire Rapids+. We use `tileloaddt1` (non-temporal, + // bypasses L1) and `prefetcht1` (to L2) for B -- the same pattern + // oneDNN picks when its footprint heuristic crosses the L1 threshold. + // - Sapphire Rapids has 16 L1d Fill Buffers (LFBs); each in-flight + // prefetch/load consumes one. The previous version's 17 prefetches + + // 2 active tileloadds overflowed the LFB budget. The reduced count + // below leaves headroom and lets the HW streaming prefetcher cover + // the remaining B-panel lines. + // + // A advances 64 B / iter (one cache line). B advances 1024 B / iter + // (16 cache lines). We prime 6 of the next 16 B-lines at +1024..+1344, + // then trust the HW stream prefetcher (very aggressive on SPR/EMR/GNR) + // to cover lines +1408..+1984. + prefetcht0 [rax + 64] // next A-row K-block (to L1) + prefetcht1 [rbx + 1024] // next B-panel head (to L2) + prefetcht1 [rbx + 1088] + prefetcht1 [rbx + 1152] + prefetcht1 [rbx + 1216] + prefetcht1 [rbx + 1280] + prefetcht1 [rbx + 1344] + tileloadd tmm1, [rax + r8 * 1] // A tile (cached): stride = K_padded + tileloaddt1 tmm2, [rbx + r9 * 1] // B tile (non-temporal): stride = 64 tdpbssd tmm0, tmm1, tmm2 add rax, 64 // +64 K-bytes in A row 0 add rbx, 1024 // 16 K-pairs * 64 = 1024 bytes From 8786a5a37cfa71a3f80f73904c796ddef92ab8f7 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:54:02 +0000 Subject: [PATCH 09/21] linalg/x86_64: shape-adaptive AMX dispatch (16x16 for big, 8x8 for small) The `qmmm_i32` closure now selects between the 8x8 and 16x16 AMX kernels based on the (m, k, n) hint -- mirroring oneDNN's BRGEMM ukernel-variant selection logic, where the MR/NR pair is picked per problem size rather than fixed at build time. Rationale: - 16x16 (1024 B/tile, 16384 mul-adds per tdpbssd) wins on big problems where the per-call setup cost (ldtilecfg + 16-row epilogue scratch) is amortised across many K-iters. - 8x8 (256 B/tile, 4096 mul-adds per tdpbssd) wins on small problems where 16x16 would over-pad and pay full epilogue cost on a mostly- empty C tile. Threshold: 16x16 picked iff m >= 16 AND n >= 16 AND k >= 64, all treating Option::None ("streaming / unknown") as "large enough" since dynamic-shape models default to throughput-champion 16x16. The exact crossover should be re-validated on AMX HW; this is a heuristic best-guess until then. --- linalg/src/x86_64_fma/mmm.rs | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 81c9f84dd4..3ca407358e 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -169,10 +169,28 @@ pub fn plug_avx512vnni(ops: &mut Ops) { pub fn plug_avx512amx_int8(ops: &mut Ops) { ops.mmm_impls.push(avx512amx_mmm_i32_8x8.mmm()); ops.mmm_impls.push(avx512amx_mmm_i32_16x16.mmm()); - // 16x16 hits the full AMX tile (1024 B per tile) and is ~4x the mul-adds - // per tdpbssd; use it as the primary qmmm_i32 dispatch target. - ops.qmmm_i32 = Box::new(|_, _, _| avx512amx_mmm_i32_16x16.mmm()); - log::info!("qmmm_i32: x86_64/avx512amx_int8 (16x16) activated"); + // Shape-adaptive dispatch: + // - 16x16 hits the full AMX tile (1024 B/tile, 16384 mul-adds per + // tdpbssd) and is the throughput champion when at least one tile + // of each dim is fully utilised. + // - 8x8 has lower per-call setup cost (1/4 the tile-store scratch, + // half the prefetch budget, smaller epilogue) and beats 16x16 on + // small problems where the framework's tile-padding overhead + // dominates. + // The exact crossover should be re-validated on AMX HW; oneDNN uses + // similar shape-based MR/NR selection for its BRGEMM ukernel variants. + ops.qmmm_i32 = Box::new(|m, k, n| { + // m, k, n are Option -- None means "unknown / streaming dim". + // For unknown dims default to the throughput champion (16x16); only + // fall back to 8x8 when a static dim is known to be tiny. + let big = |o: Option, t: usize| o.is_none_or(|v| v >= t); + if big(m, 16) && big(n, 16) && big(k, 64) { + avx512amx_mmm_i32_16x16.mmm() + } else { + avx512amx_mmm_i32_8x8.mmm() + } + }); + log::info!("qmmm_i32: x86_64/avx512amx_int8 (16x16 + 8x8 adaptive) activated"); } pub fn plug_avx2(ops: &mut Ops) { From f50f4540452b816299d18de36815c66a396a187f Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Mon, 1 Jun 2026 18:04:16 +0000 Subject: [PATCH 10/21] linalg/x86_64: add CPUID-leaf-4 cache-size detection for AMX dispatch Adds `cache_sizes() -> CacheSizes { l1d_bytes, l2_bytes, l3_bytes }`, the analog of oneDNN's `platform::get_per_core_cache_size`. Probes CPUID leaf 4 deterministic cache parameters iteratively over sub-leaves until a zero cache-type byte; computes per-cache size as (ways+1) * (partitions+1) * (line_size+1) * (sets+1). Memoised behind a OnceLock since the values are constant for the lifetime of the process. Currently used at AMX-int8 plug time to log the detected cache hierarchy (useful for diagnostics + future tuning); the public API exists so that future shape-adaptive kernel variants can mirror oneDNN's `try_load_nt = footprint(A)+footprint(B)+footprint(C) >= L1` heuristic at runtime. This makes the existing 16x16 kernel's static "use tileloaddt1 + L2 prefetch for B" choice (currently hardcoded to the AMX-typical large- working-set case) honest about the assumption, and gives us the instrument to add a small-working-set 16x16 variant later if HW bench data shows it's worth it. --- linalg/src/x86_64_fma/amx.rs | 53 ++++++++++++++++++++++++++++++++++++ linalg/src/x86_64_fma/mmm.rs | 9 +++++- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/linalg/src/x86_64_fma/amx.rs b/linalg/src/x86_64_fma/amx.rs index d7371e7c4b..48f11839e6 100644 --- a/linalg/src/x86_64_fma/amx.rs +++ b/linalg/src/x86_64_fma/amx.rs @@ -23,6 +23,59 @@ use crate::frame::mmm::{ EagerPackedInput, MMMInputFormat, MMMInputValue, PackedExoticFact, PackedMatrixStorage, }; +/// Per-cache geometry from CPUID leaf 4 deterministic cache parameters +/// (the mechanism oneDNN's `platform::get_per_core_cache_size` ultimately +/// reads). Used here for runtime adaptive choices that depend on the +/// hardware -- e.g. picking `tileloadd` vs `tileloaddt1` based on whether +/// the matmul working set fits in L1d (oneDNN's `try_load_nt` heuristic). +#[derive(Clone, Copy, Debug, Default)] +pub struct CacheSizes { + pub l1d_bytes: usize, + pub l2_bytes: usize, + pub l3_bytes: usize, +} + +/// Probe per-core L1d/L2/L3 cache sizes via CPUID leaf 4 deterministic +/// cache parameters. Iterates sub-leaves 0, 1, 2, ... until cache type = 0 +/// (no more caches). Each cache is described by: +/// EAX[4:0] = cache type (0=null, 1=data, 2=instr, 3=unified) +/// EAX[7:5] = cache level (1, 2, 3, ...) +/// EBX[11:0] = ways - 1 +/// EBX[21:12]= partitions - 1 +/// EBX[31:22]= line_size_bytes - 1 +/// ECX = sets - 1 +/// cache_bytes = (ways+1) * (partitions+1) * (line_size+1) * (sets+1) +/// Returns zeros for unknown levels (e.g. on a CPU without an L3, or if +/// the CPUID interface is unavailable). Memoised; called at most once. +pub fn cache_sizes() -> CacheSizes { + static CACHE: OnceLock = OnceLock::new(); + *CACHE.get_or_init(|| { + let mut out = CacheSizes::default(); + for sub in 0..16 { + let r = std::arch::x86_64::__cpuid_count(4, sub); + let cache_type = r.eax & 0x1F; + if cache_type == 0 { + break; + } + let level = (r.eax >> 5) & 0x7; + let ways = ((r.ebx >> 22) & 0x3FF) + 1; + let partitions = ((r.ebx >> 12) & 0x3FF) + 1; + let line_size = (r.ebx & 0xFFF) + 1; + let sets = r.ecx + 1; + let bytes = (ways as usize) * (partitions as usize) + * (line_size as usize) * (sets as usize); + // type=1 (data), type=3 (unified) for L1d / L2 / L3 + match (level, cache_type) { + (1, 1) => out.l1d_bytes = bytes, + (2, 1 | 3) => out.l2_bytes = bytes, + (3, 1 | 3) => out.l3_bytes = bytes, + _ => {} + } + } + out + }) +} + /// Detect AMX-INT8 + AMX-TILE via CPUID leaf 7 sub-leaf 0 (EDX bits 24-25). /// Stable-Rust friendly: `is_x86_feature_detected!("amx-int8")` is gated on /// the nightly `x86_amx_intrinsics` feature, so we read CPUID by hand. diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 3ca407358e..9a373a496a 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -190,7 +190,14 @@ pub fn plug_avx512amx_int8(ops: &mut Ops) { avx512amx_mmm_i32_8x8.mmm() } }); - log::info!("qmmm_i32: x86_64/avx512amx_int8 (16x16 + 8x8 adaptive) activated"); + let c = super::amx::cache_sizes(); + log::info!( + "qmmm_i32: x86_64/avx512amx_int8 (16x16 + 8x8 adaptive) activated; \ + L1d={} KB, L2={} KB, L3={} KB", + c.l1d_bytes / 1024, + c.l2_bytes / 1024, + c.l3_bytes / 1024, + ); } pub fn plug_avx2(ops: &mut Ops) { From 81772301c39b741444ccdde47311067595a3fb01 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Mon, 1 Jun 2026 22:43:18 +0000 Subject: [PATCH 11/21] linalg/x86_64: add AMX bf16 16x16 kernel for f32 matmul (TDPBF16PS) Adds an AMX-BF16 path to mmm_f32 mirroring the int8 16x16 work: f32 inputs are truncated to bf16 at pack time (round-to-nearest-even, matching Intel VCVTNEPS2BF16) and the inner loop calls TDPBF16PS (16M x 16N x 32K bf16 = 8192 fma per instruction). The f32 accumulators differ from a pure-f32 FMA reference by ~1/2^8 relative per multiply (bf16 = 8 mantissa bits vs f32's 23) -- the same precision profile as oneDNN "fast-math" f32 matmul on AMX, acceptable for inference workloads (LLMs, CNNs) that already tolerate bf16. * avx512amx_mmm_f32_16x16.S.j2 -- 16x16 row-major-zmm-accumulator kernel with the same oneDNN-style prefetch pattern as the i32 16x16 (A: tileloadd + 1x prefetcht0, B: tileloaddt1 + 6x prefetcht1). q_scale/q_shr/q_shl jump to "unsupported" (not meaningful for f32). * amx_bf16.rs -- PackedAmxBf16A (A side, M-major within panel, K padded to multiples of 32 bf16) and PackedBf16K2 (B side, K=2-inner analog of PackedI8K4). f32_to_bf16_rne() does the lane-level conversion at pack time. * amx.rs -- request_amx_tile_xcomp_perm() extracted so the int8 and bf16 has_*() gates share the single XSAVE permission request (arch_prctl is process-wide; only one call is needed for both data types). * build.rs -- dummy_bf16.S probe checks the assembler accepts TDPBF16PS, gated independently of the int8 probe so a future AMX-FP16/FP8 (Diamond Rapids+) probe slots in alongside. Sets tract_amx_bf16 cfg on success. * mmm.rs -- registers the kernel as packing[1]=f32f32_bf16 and overlays the AMX 16x16 path onto mmm_f32 for problems where every axis comfortably fills at least one tile (M>=16, N>=16, K>=32). Smaller problems defer to the prior AVX-512/FMA picker, same shape-adaptive pattern as qmmm_i32. --- linalg/build.rs | 62 ++- linalg/src/x86_64_fma.rs | 1 + linalg/src/x86_64_fma/amx.rs | 25 +- linalg/src/x86_64_fma/amx_bf16.rs | 315 +++++++++++ linalg/src/x86_64_fma/mmm.rs | 61 +++ linalg/x86_64/avx512amx/dummy_bf16.S | 28 + .../x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 | 508 ++++++++++++++++++ 7 files changed, 982 insertions(+), 18 deletions(-) create mode 100644 linalg/src/x86_64_fma/amx_bf16.rs create mode 100644 linalg/x86_64/avx512amx/dummy_bf16.S create mode 100644 linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 diff --git a/linalg/build.rs b/linalg/build.rs index b7c10fb2c7..c2d2f5e5ce 100644 --- a/linalg/build.rs +++ b/linalg/build.rs @@ -86,6 +86,21 @@ fn assembler_supports_amx_int8() -> bool { .is_ok() } +// Probe whether the target assembler can assemble AMX bf16 instructions +// (`tdpbf16ps`). Both int8 and bf16 AMX mnemonics require binutils >= 2.34, +// so in practice this probe succeeds whenever `assembler_supports_amx_int8` +// does. Provided separately so the two cfgs are independently controlled +// and users on exotic toolchains can opt-out of just the bf16 kernel. +fn assembler_supports_amx_bf16() -> bool { + cc::Build::new() + .file("x86_64/avx512amx/dummy_bf16.S") + .cargo_metadata(false) + .cargo_warnings(false) + .warnings(false) + .try_compile("tract_amx_bf16_probe") + .is_ok() +} + fn include_sve() -> bool { // SVE/SVE2 lives on ARMv9 server/mobile cores (Neoverse V1+/N2+, Cortex-X2+, // Graviton 3/4) — Linux aarch64. No Apple silicon has SVE. @@ -186,6 +201,8 @@ fn main() { // Set below only when the x86_64 assembler accepts AMX int8 mnemonics // (avoids breaking the build on toolchains predating AMX). println!("cargo:rustc-check-cfg=cfg(tract_amx_int8)"); + // Set below only when the assembler accepts AMX bf16 mnemonics (tdpbf16ps). + println!("cargo:rustc-check-cfg=cfg(tract_amx_bf16)"); match arch.as_ref() { "x86_64" => { @@ -197,25 +214,35 @@ fn main() { }); files.extend(preprocess_files("x86_64/avx512", &[], &suffix, false)); - // Pull the AMX kernel template out of the generic fma bulk-compile - // so it can be gated behind the assembler probe below. Its - // mnemonics (`ldtilecfg`, `tdpbssd`, `tilezero`, `tilerelease`) - // require gas >= 2.34; old toolchains (Debian stretch's binutils - // 2.28) would otherwise fail the whole build. The kernel template - // lives next to its Jinja partials (`dispatcher.j2`, the i32 - // epilogue includes); only the compile of the rendered .S is - // moved. + // Pull the AMX kernel templates out of the generic fma bulk-compile + // so they can be gated behind assembler probes below. All AMX + // mnemonics require gas >= 2.34; old toolchains (Debian stretch's + // binutils 2.28) would otherwise fail the whole build. + // + // Split by accumulator type: + // avx512amx_*_i32_* → tdpbssd → gated on tract_amx_int8 + // avx512amx_*_f32_* → tdpbf16ps → gated on tract_amx_bf16 let amx_int8_files: Vec = files .iter() .filter(|f| { f.file_name() .and_then(|n| n.to_str()) - .map(|n| n.starts_with("avx512amx_")) + .map(|n| n.starts_with("avx512amx_") && n.contains("_i32_")) .unwrap_or(false) }) .cloned() .collect(); - files.retain(|f| !amx_int8_files.contains(f)); + let amx_bf16_files: Vec = files + .iter() + .filter(|f| { + f.file_name() + .and_then(|n| n.to_str()) + .map(|n| n.starts_with("avx512amx_") && n.contains("_f32_")) + .unwrap_or(false) + }) + .cloned() + .collect(); + files.retain(|f| !amx_int8_files.contains(f) && !amx_bf16_files.contains(f)); if os == "windows" { if use_masm() { @@ -296,6 +323,21 @@ fn main() { .compile("x86_64_avx512amx"); println!("cargo:rustc-cfg=tract_amx_int8"); } + + // AMX bf16 kernel for f32 matmul (tdpbf16ps). Same toolchain + // requirement and Unix-only constraint as the int8 path. When the + // probe fails, the `tract_amx_bf16` cfg stays unset and + // `plug_avx512amx_bf16` is compiled out — `mmm_f32` then falls + // back to AVX-512 / FMA without any build error. + if os != "windows" + && !amx_bf16_files.is_empty() + && assembler_supports_amx_bf16() + { + cc::Build::new() + .files(&amx_bf16_files) + .compile("x86_64_avx512amx_bf16"); + println!("cargo:rustc-cfg=tract_amx_bf16"); + } } "arm" | "armv7" => { let files = preprocess_files("arm32/armvfpv2", &[], &suffix, false); diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index 7a057c51b1..1d7750f0be 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -11,6 +11,7 @@ pub mod act_f16; pub mod act_f16_fp16; pub mod amx; +pub mod amx_bf16; pub mod by_scalar; pub mod erf; mod intel; diff --git a/linalg/src/x86_64_fma/amx.rs b/linalg/src/x86_64_fma/amx.rs index 48f11839e6..685a9618c8 100644 --- a/linalg/src/x86_64_fma/amx.rs +++ b/linalg/src/x86_64_fma/amx.rs @@ -93,6 +93,8 @@ fn cpu_has_amx_int8() -> bool { /// Linux only: ask the kernel for permission to use the AMX tile-data XSAVE /// state via `arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)`. Returns /// true if the kernel grants permission (or if the process already has it). +/// Exposed via `request_amx_tile_xcomp_perm()` below so the bf16 path can +/// share the same OS-level gate. #[cfg(target_os = "linux")] unsafe fn request_amx_xcomp_perm() -> bool { // x86_64 syscall: rax=158 (arch_prctl), rdi=0x1023 (REQ_XCOMP_PERM), @@ -113,16 +115,14 @@ unsafe fn request_amx_xcomp_perm() -> bool { rc == 0 } -/// Returns true iff Intel AMX int8 is available AND the OS has granted this -/// process permission to use the AMX tile-data XSAVE state. Result is -/// memoised — the arch_prctl call has process-wide effect and only needs to -/// run once. -pub fn has_amx_int8() -> bool { +/// Memoised wrapper around `request_amx_xcomp_perm` -- arch_prctl has a +/// process-wide effect and only needs to be called once for the whole +/// lifetime of the process. Returns true iff the OS has granted permission +/// for XFEATURE_XTILEDATA (and hence enables both AMX int8 AND AMX bf16 +/// kernels). Returns false on non-Linux. +pub fn request_amx_tile_xcomp_perm() -> bool { static GATE: OnceLock = OnceLock::new(); *GATE.get_or_init(|| { - if !cpu_has_amx_int8() { - return false; - } #[cfg(target_os = "linux")] { unsafe { request_amx_xcomp_perm() } @@ -134,6 +134,15 @@ pub fn has_amx_int8() -> bool { }) } +/// Returns true iff Intel AMX int8 is available AND the OS has granted this +/// process permission to use the AMX tile-data XSAVE state. Result is +/// memoised — the arch_prctl call has process-wide effect and only needs to +/// run once. +pub fn has_amx_int8() -> bool { + static GATE: OnceLock = OnceLock::new(); + *GATE.get_or_init(|| cpu_has_amx_int8() && request_amx_tile_xcomp_perm()) +} + /// AMX-friendly A packing: per `r`-row panel, M-rows are laid out row-major /// across `K_padded = ceil(K / 64) * 64` contiguous bytes per row. AMX's /// `tileloadd` with stride = K_padded reads exactly 8 contiguous M-rows of diff --git a/linalg/src/x86_64_fma/amx_bf16.rs b/linalg/src/x86_64_fma/amx_bf16.rs new file mode 100644 index 0000000000..ba64cfa16a --- /dev/null +++ b/linalg/src/x86_64_fma/amx_bf16.rs @@ -0,0 +1,315 @@ +// Intel AMX bf16 support: f32 -> bf16 packers and the AMX bf16 runtime gate. +// +// The kernel `avx512amx_mmm_f32_16x16` uses TDPBF16PS (bf16 x bf16 -> f32) to +// accelerate f32 matmul on Sapphire Rapids+ AMX hardware. The inputs are +// truncated from f32 to bf16 at pack time (round-to-nearest-even, matching +// Intel's VCVTNEPS2BF16 semantics); the f32 accumulators are bit-identical +// to a "scalar bf16 multiply + f32 accumulate" reference but DIFFER from a +// pure-f32 FMA reference by ~1 / 2^8 relative per multiply (bf16 has 8 +// mantissa bits vs f32's 23). This precision loss is the same as oneDNN +// "fast-math" f32 matmul on AMX and is acceptable for inference workloads +// (LLMs, CNNs) that already tolerate bf16. +// +// Tile geometry mirrors the i32 16x16 kernel: 16 rows x 64 colsb per tile. +// Per TDPBF16PS: 16 M-rows x 16 N-cols x 32 K-bf16 = 8192 fma operations +// per single instruction -- the same throughput as TDPBSSD. + +use std::sync::OnceLock; + +use tract_data::internal::*; + +use crate::WeightType; +use crate::frame::mmm::{ + EagerPackedInput, MMMInputFormat, MMMInputValue, PackedExoticFact, PackedMatrixStorage, +}; + +/// Detect AMX-BF16 + AMX-TILE via CPUID leaf 7 sub-leaf 0 (EDX bits 22, 24). +/// AMX-BF16 is the bit-22 capability; AMX-TILE (bit 24) is mandatory for any +/// AMX use. Returns false unless both are present. +fn cpu_has_amx_bf16() -> bool { + if !std::is_x86_feature_detected!("avx512f") { + return false; + } + let r = std::arch::x86_64::__cpuid_count(7, 0); + const AMX_BF16: u32 = 1 << 22; + const AMX_TILE: u32 = 1 << 24; + (r.edx & AMX_BF16) != 0 && (r.edx & AMX_TILE) != 0 +} + +/// Returns true iff Intel AMX bf16 is available AND the OS has granted this +/// process permission to use the AMX tile-data XSAVE state. Reuses the +/// arch_prctl XCOMP-perm request mechanism from the int8 path -- the same +/// XFEATURE_XTILEDATA permission gates both data types. +pub fn has_amx_bf16() -> bool { + static GATE: OnceLock = OnceLock::new(); + *GATE.get_or_init(|| cpu_has_amx_bf16() && super::amx::request_amx_tile_xcomp_perm()) +} + +/// Convert an f32 to bf16 with round-to-nearest-even (matches Intel's +/// VCVTNEPS2BF16). NaN inputs are preserved as quiet NaN. Used by the bf16 +/// packers below (scalar; AMX hardware is on Sapphire Rapids+ which has the +/// AVX-512-BF16 intrinsic for batched conversion, but packing is amortised +/// over many kernel calls so the scalar path is fine). +#[inline] +pub fn f32_to_bf16_rne(x: f32) -> u16 { + let bits = x.to_bits(); + // NaN check: exponent all-ones and mantissa nonzero. + if (bits & 0x7F80_0000) == 0x7F80_0000 && (bits & 0x007F_FFFF) != 0 { + // Quiet NaN: set the top mantissa bit of the bf16 result. + ((bits >> 16) as u16) | 0x0040 + } else { + // round-to-nearest-even: add 0x7FFF + (lsb of bf16) before truncating. + let lsb = (bits >> 16) & 1; + let rounding = 0x0000_7FFF + lsb; + (bits.wrapping_add(rounding) >> 16) as u16 + } +} + +/// AMX-friendly A packing for f32 matmul via bf16. Per `r`-row panel, the +/// M-rows are laid out row-major in bf16 across `K_padded` contiguous bf16 +/// per row (K_padded = ceil(K/32)*32, so each row is a whole number of +/// AMX K-tile widths). Source is f32; conversion happens at pack time. +/// +/// panel_bytes = r * K_padded * 2 (each bf16 = 2 bytes) +/// +/// AMX `tileloadd` with stride = K_padded*2 reads exactly 16 M-rows of +/// 64 bytes (= 32 bf16) per call -- one inner-K iter's worth. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PackedAmxBf16A { + pub r: usize, + pub align: usize, +} + +impl PackedAmxBf16A { + pub fn new(r: usize) -> Self { + PackedAmxBf16A { r, align: 64 } + } + fn k_padded(&self, k: usize) -> usize { + k.div_ceil(32) * 32 + } + fn panel(&self, k: usize) -> usize { + self.k_padded(k) * self.r * 2 + } + pub fn single_panel_len(&self, k: usize) -> usize { + self.panel(k) + } + pub fn len(&self, k: usize, mn: usize) -> usize { + mn.div_ceil(self.r) * self.panel(k) + } + pub fn alignment(&self) -> usize { + self.align + } + + pub fn pack_view( + &self, + t: &TensorView, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + let k = t.shape()[k_axis]; + let mn = t.shape()[mn_axis]; + let kp = self.k_padded(k); + let pl = kp * self.r * 2; // bytes per panel + let panels = mn.div_ceil(self.r); + let st = t.strides(); + let (ks, ms) = (st[k_axis], st[mn_axis]); + let mut blob = unsafe { Blob::new_for_size_and_align(panels * pl, self.align) }; + blob.as_bytes_mut().fill(0); + unsafe { + let src = t.as_ptr_unchecked::(); + let dst = blob.as_mut_ptr() as *mut u16; + for p in 0..panels { + let pw = self.r.min(mn - p * self.r); + let panel = dst.add(p * (kp * self.r)); // panel_offset in u16 elements + let mn0 = (p * self.r) as isize; + for lm in 0..pw { + let drow = panel.add(lm * kp); + let srow_base = src.offset((mn0 + lm as isize) * ms); + for kk in 0..k { + let v = *srow_base.offset(kk as isize * ks); + *drow.add(kk) = f32_to_bf16_rne(v); + } + } + } + } + Ok(Box::new(EagerPackedInput { + fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k }, + packed: blob.into(), + panel_bytes: pl, + mn, + })) + } +} + +impl std::fmt::Display for PackedAmxBf16A { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "AmxBf16A[{}]", self.r) + } +} + +impl MMMInputFormat for PackedAmxBf16A { + fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult { + Ok(PackedMatrixStorage::new(self.prepare_one(t, k_axis, mn_axis)?) + .into_tensor(t.datum_type())) + } + fn prepare_one( + &self, + t: &Tensor, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + self.pack_view(&t.view(), k_axis, mn_axis) + } + fn k_alignment(&self) -> usize { + // tdpbf16ps consumes 32 bf16 per K-step. + 32 + } + fn r(&self) -> usize { + self.r + } + fn precursor(&self) -> WeightType { + WeightType::Plain(f32::datum_type()) + } + fn merge_with<'o, 'a: 'o, 'b: 'o>( + &'a self, + o: &'b dyn MMMInputFormat, + ) -> Option<&'o dyn MMMInputFormat> { + o.downcast_ref::().filter(|x| x.r == self.r).map(|_| self as _) + } + fn mem_size(&self, k: TDim, mn: TDim) -> TDim { + mn.divceil(self.r) * self.panel(k.to_usize().unwrap_or(0)) + } + fn extract_at_mn_f16(&self, _: &EagerPackedInput, _: usize, _: &mut [f16]) -> TractResult<()> { + bail!("no f16 extract") + } + fn extract_at_mn_f32(&self, _: &EagerPackedInput, _: usize, _: &mut [f32]) -> TractResult<()> { + bail!("no f32 extract") + } +} + +/// AMX-friendly B packing for f32 matmul via bf16 (analog of PackedI8K4 but +/// K=2-inner instead of K=4-inner -- tdpbf16ps groups 2 bf16 per K-step). +/// +/// Per K=2 block: r N-cols x 2 K-bf16 = r * 2 * 2 bytes = 4r bytes. +/// Block layout: byte (n*4 + ki*2..(n*4 + ki*2 + 2)) = bf16 of B[2kb+ki, n]. +/// For r=16: 64 bytes per K=2 block, 16 blocks per K=32 AMX tile -> 1024 B. +/// +/// One AMX `tileloadd` with stride = 4r bytes reads 16 K-pair-rows of +/// r * 4 bytes each = one inner-K iter's worth of B. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct PackedBf16K2 { + pub r: usize, + pub align: usize, +} + +impl PackedBf16K2 { + pub fn new(r: usize) -> Self { + PackedBf16K2 { r, align: 64 } + } + fn k_padded(&self, k: usize) -> usize { + k.div_ceil(2) * 2 + } + fn panel(&self, k: usize) -> usize { + self.k_padded(k) * self.r * 2 + } + pub fn single_panel_len(&self, k: usize) -> usize { + self.panel(k) + } + pub fn len(&self, k: usize, mn: usize) -> usize { + mn.div_ceil(self.r) * self.panel(k) + } + pub fn alignment(&self) -> usize { + self.align + } + pub fn pack_view( + &self, + t: &TensorView, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + let k = t.shape()[k_axis]; + let mn = t.shape()[mn_axis]; + let kp = self.k_padded(k); + let pl = kp * self.r * 2; // bytes per panel + let panels = mn.div_ceil(self.r); + let st = t.strides(); + let mut blob = unsafe { Blob::new_for_size_and_align(panels * pl, self.align) }; + blob.as_bytes_mut().fill(0); + let (ks, ms) = (st[k_axis], st[mn_axis]); + let kblocks = kp / 2; + unsafe { + let src = t.as_ptr_unchecked::(); + let dst = blob.as_mut_ptr() as *mut u16; + for p in 0..panels { + let pw = self.r.min(mn - p * self.r); + let panel = dst.add(p * (kp * self.r)); + let mn0 = (p * self.r) as isize; + for kb in 0..kblocks { + for ki in 0..2 { + let kk = kb * 2 + ki; + if kk >= k { + break; + } + let srow = src.offset(kk as isize * ks + mn0 * ms); + let dblock = panel.add(kb * self.r * 2 + ki); + for lm in 0..pw { + let v = *srow.offset(lm as isize * ms); + *dblock.add(lm * 2) = f32_to_bf16_rne(v); + } + } + } + } + } + Ok(Box::new(EagerPackedInput { + fact: PackedExoticFact { format: Box::new(self.clone()), mn: mn.to_dim(), k }, + packed: blob.into(), + panel_bytes: pl, + mn, + })) + } +} + +impl std::fmt::Display for PackedBf16K2 { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "Bf16K2[{}]", self.r) + } +} + +impl MMMInputFormat for PackedBf16K2 { + fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult { + Ok(PackedMatrixStorage::new(self.prepare_one(t, k_axis, mn_axis)?) + .into_tensor(t.datum_type())) + } + fn prepare_one( + &self, + t: &Tensor, + k_axis: usize, + mn_axis: usize, + ) -> TractResult> { + self.pack_view(&t.view(), k_axis, mn_axis) + } + fn k_alignment(&self) -> usize { + 2 + } + fn r(&self) -> usize { + self.r + } + fn precursor(&self) -> WeightType { + WeightType::Plain(f32::datum_type()) + } + fn merge_with<'o, 'a: 'o, 'b: 'o>( + &'a self, + o: &'b dyn MMMInputFormat, + ) -> Option<&'o dyn MMMInputFormat> { + o.downcast_ref::().filter(|x| x.r == self.r).map(|_| self as _) + } + fn mem_size(&self, k: TDim, mn: TDim) -> TDim { + mn.divceil(self.r) * self.panel(k.to_usize().unwrap_or(0)) + } + fn extract_at_mn_f16(&self, _: &EagerPackedInput, _: usize, _: &mut [f16]) -> TractResult<()> { + bail!("no f16 extract") + } + fn extract_at_mn_f32(&self, _: &EagerPackedInput, _: usize, _: &mut [f32]) -> TractResult<()> { + bail!("no f32 extract") + } +} diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 9a373a496a..86f36b7c9e 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -5,10 +5,14 @@ use crate::mmm::MatMatMul; use crate::pack::{PackedFormat, PackedI8K4}; use super::amx::{PackedAmxA, has_amx_int8}; +#[cfg(tract_amx_bf16)] +use super::amx_bf16::{PackedAmxBf16A, PackedBf16K2, has_amx_bf16}; use super::*; #[cfg(tract_amx_int8)] const AVX512AMX: fn() -> bool = has_amx_int8; +#[cfg(tract_amx_bf16)] +const AVX512AMX_BF16: fn() -> bool = has_amx_bf16; /// One candidate kernel in a dispatcher's pool, with its tile geometry /// and a relative-throughput scale (1.0 = baseline, used to break @@ -136,6 +140,23 @@ MMMExternKernel! { avx512amx_mmm_i32_16x16(16,16)@(64,4) where(AVX512AMX) store(i8) } +// AMX bf16 16x16 kernel for f32 matmul: uses TDPBF16PS (bf16 x bf16 -> f32). +// f32 inputs are truncated to bf16 at pack time (round-to-nearest-even, matching +// Intel VCVTNEPS2BF16). One tdpbf16ps consumes 16M x 16N x 32K bf16 = 8192 fma +// per instruction. f32 accumulators differ from a pure-f32 reference by ~1/2^8 +// relative per multiply (bf16 = 8 mantissa bits vs f32's 23) -- same precision +// loss profile as oneDNN "fast-math" f32 matmul on AMX, acceptable for +// inference workloads (LLMs, CNNs) that already tolerate bf16. +// +// Default packing[0] (the framework's PackedFormat) is retained so the +// kernel can still be selected for f32 paths even when the BF16 packer +// isn't a precursor match; packing[1] is the fast bf16-from-f32 path. +#[cfg(tract_amx_bf16)] +MMMExternKernel! { avx512amx_mmm_f32_16x16(16,16)@(64,4) where(AVX512AMX_BF16) + packing[1] = f32f32_bf16 => |k| k.with_packing(PackedAmxBf16A::new(16), PackedBf16K2::new(16)); + quality(ManuallyOptimized) +} + pub fn plug(ops: &mut Ops) { if is_x86_feature_detected!("avx2") { plug_avx2(ops); @@ -153,6 +174,13 @@ pub fn plug(ops: &mut Ops) { plug_avx512amx_int8(ops); } } + // AMX bf16 for f32 matmul is independent of int8/VNNI gates: + // a future Xeon SKU could ship AMX-BF16 without VNNI, and the + // permission gate is shared with the int8 path inside has_amx_bf16(). + #[cfg(tract_amx_bf16)] + if has_amx_bf16() { + plug_avx512amx_bf16(ops); + } } } } @@ -165,6 +193,39 @@ pub fn plug_avx512vnni(ops: &mut Ops) { log::info!("qmmm_i32: x86_64/avx512vnni activated"); } +#[cfg(tract_amx_bf16)] +pub fn plug_avx512amx_bf16(ops: &mut Ops) { + ops.mmm_impls.push(avx512amx_mmm_f32_16x16.mmm()); + // Save the previously-installed f32 picker so we can defer to it when + // the AMX kernel isn't a good fit (small M/N, or K < 32 -- one TDPBF16PS + // consumes 32 bf16 K-lanes so the panel must have at least one full step). + let prev: crate::MMMImpl = std::mem::replace( + &mut ops.mmm_f32, + Box::new(|_, _, _| unreachable!()), + ); + ops.mmm_f32 = Box::new(move |m, k, n| { + let big = |o: Option, t: usize| o.is_none_or(|v| v >= t); + // Same dispatch shape as the int8 16x16/8x8 split: hand off to AMX + // only when each axis comfortably fills at least one tile. The 32-K + // threshold matches PackedAmxBf16A::k_alignment() (one tdpbf16ps = + // 32 bf16 K-lanes); below that, the AVX-512 / FMA path's smaller + // tiles waste less work. + if big(m, 16) && big(n, 16) && big(k, 32) { + avx512amx_mmm_f32_16x16.mmm() + } else { + prev(m, k, n) + } + }); + let c = super::amx::cache_sizes(); + log::info!( + "mmm_f32: x86_64/avx512amx_bf16 (16x16) overlay activated; \ + L1d={} KB, L2={} KB, L3={} KB", + c.l1d_bytes / 1024, + c.l2_bytes / 1024, + c.l3_bytes / 1024, + ); +} + #[cfg(tract_amx_int8)] pub fn plug_avx512amx_int8(ops: &mut Ops) { ops.mmm_impls.push(avx512amx_mmm_i32_8x8.mmm()); diff --git a/linalg/x86_64/avx512amx/dummy_bf16.S b/linalg/x86_64/avx512amx/dummy_bf16.S new file mode 100644 index 0000000000..03f5a8d7a4 --- /dev/null +++ b/linalg/x86_64/avx512amx/dummy_bf16.S @@ -0,0 +1,28 @@ +// Build-time capability probe for the assembler, used by build.rs +// (assembler_supports_amx_bf16). Checks that the assembler accepts the +// TDPBF16PS mnemonic (AMX bf16 dot-product). Same binutils version requirement +// as AMX int8 (>= 2.34); provided as a separate probe so the two cfgs can be +// set independently if needed. Not linked into anything. +.intel_syntax noprefix +.text +.globl tract_amx_bf16_probe +tract_amx_bf16_probe: + push rbp + mov rbp, rsp + sub rsp, 64 + mov qword ptr [rsp], 0 + mov qword ptr [rsp+8], 0 + mov qword ptr [rsp+16], 0 + mov qword ptr [rsp+24], 0 + mov qword ptr [rsp+32], 0 + mov qword ptr [rsp+40], 0 + mov qword ptr [rsp+48], 0 + mov qword ptr [rsp+56], 0 + mov byte ptr [rsp], 1 // palette = 1 + ldtilecfg [rsp] + tilezero tmm0 + tdpbf16ps tmm0, tmm1, tmm2 // AMX bf16: the instruction this probe checks + tilerelease + mov rsp, rbp + pop rbp + ret diff --git a/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 new file mode 100644 index 0000000000..c635589373 --- /dev/null +++ b/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 @@ -0,0 +1,508 @@ +// vim: set syntax=asm : +// +// Intel AMX bf16 GEMM kernel, 16 M-rows x 16 N-cols f32 accumulator output. +// +// One `tdpbf16ps tmm0, tmm1, tmm2` instruction performs: +// tmm0[m, n] += sum_{k=0..31} A[m, k] * B[k, n] (multiplies in bf16, +// accumulates in f32) +// for m=0..15, n=0..15: 16 * 16 * 32 = 8192 fma per single instruction -- +// the same throughput as TDPBSSD on the same hardware. Accelerates f32 +// matmul on Sapphire Rapids+ at the cost of bf16 input truncation +// (~1/256 relative error per multiply; sqrt(K) compounded by FMA chain). +// +// Tile geometry (palette 1, the maximum-bytes AMX tile shape): +// tmm0 = C accumulator: 16 rows x 64 colsb = 16 M-rows x 16 N-cols of f32 +// tmm1 = A tile: 16 rows x 64 colsb = 16 M-rows x 32 K-bf16 / iter +// tmm2 = B tile: 16 rows x 64 colsb = 16 K-pair-rows x 16 N x 2 bf16 +// +// A is packed via PackedAmxBf16A(16): per panel of 16 M-rows, row-major +// within the panel, K-bf16 contiguous along the row, K_padded = +// ceil(K/32)*32 bf16. Source f32 is truncated to bf16 at pack time using +// round-to-nearest-even (matches VCVTNEPS2BF16 semantics). +// +// B is packed via PackedBf16K2(16): per K=2 block, 16 N-cols x 2 K-bf16 = +// 64 bytes; 16 K-blocks per tmm2 tile. Source f32 -> bf16 same as A. +// +// REGISTER LAYOUT (mirrors the i32 16x16 sibling): +// zmm0..zmm15 = accumulators, ROW-MAJOR: zmm{m} = row m of C as 16 f32 +// lanes [C[m, 0], C[m, 1], ..., C[m, 15]]. + +{% if msvc %} + +_text segment +avx512amx_mmm_f32_16x16_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512amx_mmm_f32_16x16_{{suffix}} +{{G}}avx512amx_mmm_f32_16x16_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + + // Reserve 64 bytes for the AMX tile-config block, zero it, populate + // palette + dims (all three tiles are 16 rows x 64 colsb). Same shape + // as the i32 16x16 sibling. + sub rsp, 64 + vpxor xmm15, xmm15, xmm15 + vmovdqu [rsp ], xmm15 + vmovdqu [rsp + 16], xmm15 + vmovdqu [rsp + 32], xmm15 + vmovdqu [rsp + 48], xmm15 + mov byte ptr [rsp + 0 ], 1 // palette = 1 + mov word ptr [rsp + 16], 64 // colsb[0] = 64 (tmm0) + mov word ptr [rsp + 18], 64 // colsb[1] = 64 (tmm1) + mov word ptr [rsp + 20], 64 // colsb[2] = 64 (tmm2) + mov byte ptr [rsp + 48], 16 // rows[0] = 16 (tmm0) + mov byte ptr [rsp + 49], 16 // rows[1] = 16 (tmm1) + mov byte ptr [rsp + 50], 16 // rows[2] = 16 (tmm2) + ldtilecfg [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + {% for r in range(0, 16) %} + vpxorq zmm{{r}}, zmm{{r}}, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_bf16 + +{{L}}main_loop_packed_packed: + // Generic f32 x f32 fallback path (non-AMX). For row-major + // accumulators zmm{m} = row m of C, accumulating C[m, n] += A[m, k] * + // B[k, n]: load 16 B values for this K row into zmm16, then for each + // m broadcast A[m, k] and FMA add to zmm{m}. + vmovups zmm16, [rbx] // 16 f32 of B at this K row + + {% for m in range(0, 16) %} + vbroadcastss zmm17, dword ptr [rax + {{m}} * 4] + vfmadd231ps zmm{{m}}, zmm16, zmm17 + {% endfor %} + + add rax, 64 // 16 f32 lanes per K step + add rbx, 64 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_bf16: + // AMX bf16 layout: + // A panel: 16 M-rows x K_padded bf16 ROW-major within the panel + // (PackedAmxBf16A, K_padded = ceil(K/32)*32 bf16 = + // ceil(K/32)*64 bytes per row). + // B panel: PackedBf16K2(16) -- 16 N-cols x 2 K-bf16 per K=2 block, + // with 16 K-blocks per tdpbf16ps iter (16 K-pair-rows x + // 64 colsb). + // + // Per tdpbf16ps: tmm0[m, n] += sum_{k=0..31} A[m, k] * B[k, n] + // with multiplies in bf16 and accumulation in f32. Inner loop steps + // along K in 32-bf16 chunks (= 64 bytes per A row). + + // r8 <- K_padded_in_bytes = ceil(k/32)*64 = byte-stride between A's + // M-rows. (Each bf16 is 2 bytes, so K_padded_bf16 * 2.) + mov r8, rcx + add r8, 31 + and r8, -32 + shl r8, 1 // *2 (bf16 = 2 bytes) + + // rcx <- ceil(k/32) = number of K=32 AMX inner iterations. + add rcx, 31 + shr rcx, 5 + + // r9 <- 64 = byte-stride between B's K-pair-rows (16 N-cols * 4 bytes + // per K-pair = 16 * 4 = 64). + mov r9, 64 + + tilezero tmm0 + +{{L}}loop_32k_amx_bf16_16x16: + // oneDNN-aligned cache strategy (same as the i32 sibling): + // A -> cached (tileloadd + prefetcht0 to L1), reused across N-tiles. + // B -> non-temporal (tileloaddt1 + prefetcht1 to L2), streams once. + // Each iter advances A by 64 bytes and B by 1024 bytes; we prime the + // first 6 of next-iter's 16 B cache lines and let the SPR HW stream + // prefetcher cover the remaining 10. + prefetcht0 [rax + 64] + prefetcht1 [rbx + 1024] + prefetcht1 [rbx + 1088] + prefetcht1 [rbx + 1152] + prefetcht1 [rbx + 1216] + prefetcht1 [rbx + 1280] + prefetcht1 [rbx + 1344] + tileloadd tmm1, [rax + r8 * 1] // A tile (cached): stride = K_padded_bytes + tileloaddt1 tmm2, [rbx + r9 * 1] // B tile (non-temporal): stride = 64 + tdpbf16ps tmm0, tmm1, tmm2 + add rax, 64 // +32 bf16 in A row 0 + add rbx, 1024 // 16 K-pairs * 64 = 1024 B + dec rcx + jnz {{L}}loop_32k_amx_bf16_16x16 + + // tmm0 -> stack scratch (16 rows x 64 bytes = 1024 B row-major f32). + // Each row's 16 f32 are contiguous, so one 64-byte load per row. + sub rsp, 1024 + mov r10, rsp + mov r11, 64 + tilestored [r10 + r11 * 1], tmm0 + + {% for m in range(0, 16) %} + vmovups zmm{{m}}, [r10 + {{ m * 64 }}] + {% endfor %} + + add rsp, 1024 + + jmp {{L}}non_linear_loop + +// ---- Scalar / per-row / per-col f32 epilogues ---------------------------- + +{{L}}scalar_min: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vminps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_max: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vmaxps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_add: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vaddps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_mul: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vmulps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub_flipped: + vbroadcastss zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}leaky_relu: + // C[m, n] = (C[m, n] >= 0) ? C[m, n] : alpha * C[m, n] + vbroadcastss zmm17, dword ptr [rdi + 8] // alpha + vpxorq zmm16, zmm16, zmm16 // 0.0 + {% for r in range(0, 16) %} + vmulps zmm18, zmm{{r}}, zmm17 // alpha * x + vcmpps k1, zmm{{r}}, zmm16, 1 // imm 1 = LT (signed): 1 where x < 0 + vblendmps zmm{{r}}{k1}, zmm{{r}}, zmm18 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_min: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vminps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_max: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vmaxps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_add: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vaddps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_mul: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vmulps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vsubps zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub_flipped: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] + vsubps zmm{{m}}, zmm16, zmm{{m}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_min: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vminps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_max: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vmaxps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_add: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vaddps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_mul: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vmulps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub_flipped: + mov rax, [rdi + 8] + vmovups zmm16, [rax] + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}load_tile: + // Scratch layout is COL-MAJOR f32 (col_byte_stride = item_size * MR = + // 4 * 16 = 64): tile[col][row] at offset col*64 + row*4. Gather row m's + // 16 cols at index step 64. + mov r8, [rdi + 8] + vmovdqa32 zmm16, [rip + {{L}}lane_offsets_64] + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm{{m}}{k1}, [r8 + zmm16 + {{m * 4}}] + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_unicast: + mov r10, [rdi + 8] // c ptr (base) + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size (4 for f32) + + cmp r8, 4 + jne {{L}}unsupported // f32 kernel: only item_size=4 + + // i32-strided gather (f32 same bit-width: vpgatherdd is correct). + mov eax, ebx + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + vpmulld zmm16, zmm16, [rip + {{L}}lane_indices] + + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm17{k1}, [r10 + zmm16] + vaddps zmm{{m}}, zmm{{m}}, zmm17 + add r10, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_row_col_products: + // bias[m, n] = row_data[m] * col_data[n], FMA-add to C[m, n]. + mov rax, [rdi + 8] + mov rbx, [rdi + 16] + + vmovups zmm16, [rbx] // 16 col_data values + + {% for m in range(0, 16) %} + vbroadcastss zmm17, dword ptr [rax + {{m * 4}}] // splat row_data[m] + vfmadd231ps zmm{{m}}, zmm17, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +// ---- q_scale / q_shr / q_shl: not meaningful for f32, stub to unsupported. +{{L}}q_scale: +{{L}}q_shl: +{{L}}q_shr: + jmp {{L}}unsupported + +// ---- Store --------------------------------------------------------------- + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + jne {{L}}unsupported // f32 kernel: only item_size=4 + + cmp rdx, 4 + je {{L}}store_strides_f32_row_contig + + // Generic f32 strided store + {% for m in range(0, 16) %} + mov r10, r8 + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_f32_row_contig: + // C is row-major in memory: each row's 16 f32 are contiguous; one + // 64-byte vmovups per row. + {% for m in range(0, 16) %} + vmovups [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}return: + tilerelease + add rsp, 64 + + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + +// ---- Read-only data (RIP-relative) --------------------------------------- + +.p2align 6 +{{L}}lane_offsets_64: + .int 0, 64, 128, 192, 256, 320, 384, 448 + .int 512, 576, 640, 704, 768, 832, 896, 960 + +.p2align 6 +{{L}}lane_indices: + .int 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + +{% if msvc %} +avx512amx_mmm_f32_16x16_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} From 9aa5f18aac1f8e3964fa6a74be2e08227ccb7f04 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Tue, 2 Jun 2026 05:52:19 +0000 Subject: [PATCH 12/21] linalg/bench: include avx512amx_mmm_f32_16x16 in amx_f32 microbench Mirrors the i32 amx bench (same shapes: 64x256x64 / 256x256x256 / 512x512x512 / 1024x1024x64) but exercises the bf16 path. Three columns: fma f32 16x6 (AVX2 baseline), avx512 f32 16x12 (AVX-512 reference), and the new AMX bf16 16x16 kernel under packing index 1 (the f32->bf16 RNE pack path). Skipped when has_amx_bf16() returns false and at build time when tract_amx_bf16 is unset. --- linalg/Cargo.toml | 4 ++ linalg/benches/amx_f32.rs | 105 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 linalg/benches/amx_f32.rs diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index 07fbe299b1..6ee91f6003 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -171,3 +171,7 @@ harness = false [[bench]] name = "amx_i32" harness = false + +[[bench]] +name = "amx_f32" +harness = false diff --git a/linalg/benches/amx_f32.rs b/linalg/benches/amx_f32.rs new file mode 100644 index 0000000000..bf59f066cf --- /dev/null +++ b/linalg/benches/amx_f32.rs @@ -0,0 +1,105 @@ +#![allow(dead_code)] +// Kernel-level benchmark: Intel AMX bf16 GEMM for f32 matmul +// (avx512amx_mmm_f32_16x16, TDPBF16PS over 16x16 f32 tile with K=32 bf16 inner) +// vs the AVX-512 f32 16x12 path (avx512_mmm_f32_16x12, FMA) vs the AVX2/FMA +// f32 16x6 path (fma_mmm_f32_16x6). +// +// The AMX path runs the f32f32_bf16 packing (index 1) which truncates f32 to +// bf16 at pack time (round-to-nearest-even, matching VCVTNEPS2BF16) so the f32 +// accumulators carry the bf16 precision profile -- same trade-off as oneDNN +// "fast-math" f32 matmul on AMX. The two reference kernels run their default +// f32 packing (index 0). +// +// Skipped at runtime when has_amx_bf16() returns false (= CPUID lacks +// amx-bf16/tile or the arch_prctl XSAVE permission was denied), and at build +// time when the tract_amx_bf16 cfg was not emitted. +use criterion::*; +use tract_data::internal::*; +use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; + +fn run_kernel( + be: &mut Bencher, + mmm: &dyn MatMatMul, + packing: usize, + m: usize, + k: usize, + n: usize, +) { + let a = Tensor::zero_dt(DatumType::F32, &[m, k]).unwrap(); + let b = Tensor::zero_dt(DatumType::F32, &[k, n]).unwrap(); + let (pack_a, pack_b) = &mmm.packings()[packing]; + let pa = pack_a.prepare_one(&a, 1, 0).unwrap(); + let pb = pack_b.prepare_one(&b, 0, 1).unwrap(); + let mut scratch = unsafe { mmm.allocate_scratch_space() }; + be.iter_custom(|iters| { + let mut dur = std::time::Duration::default(); + for _ in 0..iters { + let t = std::time::Instant::now(); + unsafe { + mmm.run_with_scratch_space( + m, + n, + scratch.as_mut(), + &[FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing, + }], + ) + .unwrap() + }; + dur += t.elapsed(); + } + dur + }); +} + +fn benches(c: &mut Criterion) { + #[cfg(tract_amx_bf16)] + { + use tract_linalg::x86_64_fma::amx_bf16::has_amx_bf16; + use tract_linalg::x86_64_fma::mmm::*; + if !has_amx_bf16() { + eprintln!("AMX bf16 not available (CPUID + arch_prctl gate failed), skipping"); + return; + } + // Same shapes as amx_i32 so reviewers can directly compare bf16->f32 vs + // i8->i32 throughput at matching M/K/N. K=32 (single tdpbf16ps step) + // and K=64 (one i8 tile) are tested via 256 / 256x256 / 512x512x512. + for &(m, k, n) in + &[(64usize, 256usize, 64usize), (256, 256, 256), (512, 512, 512), (1024, 1024, 64)] + { + let id = format!("{m}x{k}x{n}"); + let mut g = c.benchmark_group("amx_f32/packed_packed"); + g.throughput(Throughput::Elements((m * k * n) as u64)); + // Reference: FMA f32 16x6 (the kernel mmm_f32 picks for these N). + g.bench_with_input(BenchmarkId::new("fma_16x6", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*fma_mmm_f32_16x6.mmm(), 0, m, k, n) + }); + if std::is_x86_feature_detected!("avx512f") { + // Reference: AVX-512 f32 16x12. + g.bench_with_input( + BenchmarkId::new("avx512_16x12", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512_mmm_f32_16x12.mmm(), 0, m, k, n), + ); + } + // AMX bf16 path (packing index 1 = f32f32_bf16: pack-time RNE + // conversion of f32 -> bf16, then TDPBF16PS in the inner loop). + g.bench_with_input( + BenchmarkId::new("avx512amx_bf16_16x16", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512amx_mmm_f32_16x16.mmm(), 1, m, k, n), + ); + g.finish(); + } + } + #[cfg(not(tract_amx_bf16))] + { + eprintln!("tract not built with AMX bf16 support (probe failed at build time)"); + let _ = c; + } +} + +criterion_group!(g, benches); +criterion_main!(g); From 66ed6893f38f910a152a1b90e2d46668ee040161 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Tue, 2 Jun 2026 05:52:45 +0000 Subject: [PATCH 13/21] linalg/x86_64: add AVX-VNNI ymm int8 GEMM kernel (avxvnni_mmm_i32_8x8) Forks avx512vnni_mmm_i32_8x8.S.j2 with the {vex} instruction prefix on VPDPBUSD so gas emits the AVX-VNNI (VEX) encoding instead of the AVX-512-VNNI (EVEX) encoding it defaults to. Body is otherwise byte-for- byte identical: 8x8 ymm accumulators, PackedI8K4 inner-K (4-byte dot), +128 bias trick to bridge VPDPBUSD's u8 x s8 into the AVX2 s8 x s8 reference. This ships VPDPBUSD-accelerated i8 GEMM to AVX2-only Atom-class cores that don't have AVX-512: - Alder Lake / Raptor Lake / Meteor Lake E-cores (Gracemont, Crestmont) - Sierra Forest (Sierra Glen) - Clearwater Forest (Darkmont) -- the gap called out by the user * avxvnni_mmm_i32_8x8.S.j2 -- the kernel; only the two VPDPBUSD lines are prefixed with {vex}. * avxvnni.rs -- runtime gate via CPUID leaf 7 sub-leaf 1 EAX bit 4 (the AVX-VNNI capability bit). Memoised; no XSAVE permission needed (unlike AMX, AVX-VNNI uses no extended state). * build.rs -- assembler probe (dummy_avxvnni.S) checks gas accepts the {vex} prefix on VPDPBUSD (binutils 2.36+). Sets tract_avxvnni cfg on success; pulls avxvnni_*.S.j2 out of the bulk -mfma compile so older toolchains aren't broken. * mmm.rs -- registers the kernel as packing[1]=i8i8 (same PackedI8K4 as AVX-512-VNNI for layout compatibility) and plugs qmmm_i32 to it when AVX-VNNI is the highest-quality int8 ISA. On big cores that have both AVX-512-VNNI and AVX-VNNI (Sapphire Rapids+, some Alder Lake P-cores) plug_avx512vnni runs after this and clobbers qmmm_i32 with the EVEX kernel; on AVX-VNNI-only Atom cores this path stays. All 114 kernel tests pass on this AVX-512-VNNI host (the kernel runs -- big cores with AVX-512-VNNI also carry AVX-VNNI on Sapphire Rapids+; on this Cascade Lake-class CPU the runtime gate stays off and the kernel is exercised only via the test harness's direct call path). --- linalg/build.rs | 54 +- linalg/src/x86_64_fma.rs | 1 + linalg/src/x86_64_fma/avxvnni.rs | 42 ++ linalg/src/x86_64_fma/mmm.rs | 36 ++ linalg/x86_64/avx512amx/dummy_avxvnni.S | 16 + linalg/x86_64/fma/avxvnni_mmm_i32_8x8.S.j2 | 685 +++++++++++++++++++++ 6 files changed, 833 insertions(+), 1 deletion(-) create mode 100644 linalg/src/x86_64_fma/avxvnni.rs create mode 100644 linalg/x86_64/avx512amx/dummy_avxvnni.S create mode 100644 linalg/x86_64/fma/avxvnni_mmm_i32_8x8.S.j2 diff --git a/linalg/build.rs b/linalg/build.rs index c2d2f5e5ce..52eb92e0d1 100644 --- a/linalg/build.rs +++ b/linalg/build.rs @@ -86,6 +86,22 @@ fn assembler_supports_amx_int8() -> bool { .is_ok() } +// Probe whether the assembler accepts the `{vex}` prefix on VPDPBUSD -- +// needed to force the AVX-VNNI (VEX) form instead of the AVX-512-VNNI +// (EVEX) form gas defaults to. `{vex}` / `{evex}` instruction prefixes +// were added in binutils 2.36; older toolchains reject them. When the +// probe fails the avxvnni_mmm_i32_8x8 kernel is skipped and dispatch +// falls back to the AVX2 emulation kernel on AVX-VNNI-only hardware. +fn assembler_supports_avxvnni() -> bool { + cc::Build::new() + .file("x86_64/avx512amx/dummy_avxvnni.S") + .cargo_metadata(false) + .cargo_warnings(false) + .warnings(false) + .try_compile("tract_avxvnni_probe") + .is_ok() +} + // Probe whether the target assembler can assemble AMX bf16 instructions // (`tdpbf16ps`). Both int8 and bf16 AMX mnemonics require binutils >= 2.34, // so in practice this probe succeeds whenever `assembler_supports_amx_int8` @@ -203,6 +219,9 @@ fn main() { println!("cargo:rustc-check-cfg=cfg(tract_amx_int8)"); // Set below only when the assembler accepts AMX bf16 mnemonics (tdpbf16ps). println!("cargo:rustc-check-cfg=cfg(tract_amx_bf16)"); + // Set below only when the assembler accepts the `{vex}` prefix on + // VPDPBUSD (binutils >= 2.36) -- needed for the AVX-VNNI ymm kernel. + println!("cargo:rustc-check-cfg=cfg(tract_avxvnni)"); match arch.as_ref() { "x86_64" => { @@ -242,7 +261,25 @@ fn main() { }) .cloned() .collect(); - files.retain(|f| !amx_int8_files.contains(f) && !amx_bf16_files.contains(f)); + // AVX-VNNI ymm kernel: gas requires the `{vex}` instruction prefix + // (binutils 2.36+) -- pulled aside so the bulk -mfma compile, which + // is fine on older binutils, isn't broken when the AVX-VNNI cfg is + // disabled. + let avxvnni_files: Vec = files + .iter() + .filter(|f| { + f.file_name() + .and_then(|n| n.to_str()) + .map(|n| n.starts_with("avxvnni_")) + .unwrap_or(false) + }) + .cloned() + .collect(); + files.retain(|f| { + !amx_int8_files.contains(f) + && !amx_bf16_files.contains(f) + && !avxvnni_files.contains(f) + }); if os == "windows" { if use_masm() { @@ -338,6 +375,21 @@ fn main() { .compile("x86_64_avx512amx_bf16"); println!("cargo:rustc-cfg=tract_amx_bf16"); } + + // AVX-VNNI ymm int8 kernel. Independent of the AMX gates: this + // kernel ships VPDPBUSD-accelerated i8 GEMM to Atom-class cores + // (Alder Lake-E, Sierra Forest, Clearwater Forest / Darkmont) + // that have AVX-VNNI but no AVX-512, falling back to AVX2 + // emulation when the runtime CPUID detection misses. + if os != "windows" + && !avxvnni_files.is_empty() + && assembler_supports_avxvnni() + { + cc::Build::new() + .files(&avxvnni_files) + .compile("x86_64_avxvnni"); + println!("cargo:rustc-cfg=tract_avxvnni"); + } } "arm" | "armv7" => { let files = preprocess_files("arm32/armvfpv2", &[], &suffix, false); diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index 1d7750f0be..1deaf41081 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -12,6 +12,7 @@ pub mod act_f16_fp16; pub mod amx; pub mod amx_bf16; +pub mod avxvnni; pub mod by_scalar; pub mod erf; mod intel; diff --git a/linalg/src/x86_64_fma/avxvnni.rs b/linalg/src/x86_64_fma/avxvnni.rs new file mode 100644 index 0000000000..5d98637865 --- /dev/null +++ b/linalg/src/x86_64_fma/avxvnni.rs @@ -0,0 +1,42 @@ +// AVX-VNNI int8 GEMM runtime gate. +// +// AVX-VNNI (CPUID leaf 7 sub-leaf 1 EAX bit 4) is the VEX-encoded sibling of +// AVX-512-VNNI's VPDPBUSD: same i32 += u8 * s8 dot4 semantics, but addressable +// from VEX (= AVX2-class) decoders. It exists primarily for Atom-class +// server / E-core SKUs that have AVX2 + AVX-VNNI but no AVX-512: +// +// * Alder Lake / Raptor Lake / Meteor Lake E-cores (Gracemont, Crestmont) +// * Sierra Forest (Sierra Glen) +// * Clearwater Forest (Darkmont) +// +// On a CPU with AVX-512-VNNI (Cascade Lake, Ice Lake, Sapphire Rapids+), this +// detector still returns true if CPUID leaf 7.1 EAX.4 is set -- some big-core +// SKUs report AVX-VNNI alongside AVX-512-VNNI -- but the dispatch in mmm.rs +// prefers the EVEX-encoded avx512vnni kernel in that case (same throughput, +// 32 zmm registers available for unrolling). The AVX-VNNI kernel is only +// selected when AVX-512-VNNI is absent. + +use std::sync::OnceLock; + +/// CPUID leaf 7 sub-leaf 1, EAX bit 4 = AVX-VNNI (Intel SDM Vol 2 Table 1-7). +/// Sub-leaf 1 is only valid when CPUID.7.0.EAX (the max sub-leaf field) >= 1; +/// older CPUs return zeroed structures. We check the max-sub-leaf first to +/// avoid a misleading bit on pre-AVX-VNNI silicon. +fn cpu_has_avxvnni() -> bool { + if !std::is_x86_feature_detected!("avx2") { + return false; + } + let max_sub = std::arch::x86_64::__cpuid_count(7, 0).eax; + if max_sub < 1 { + return false; + } + let r = std::arch::x86_64::__cpuid_count(7, 1); + (r.eax & (1 << 4)) != 0 +} + +/// Returns true iff CPUID reports AVX-VNNI on this CPU. Memoised; no OS +/// permission gate is required (unlike AMX, AVX-VNNI uses no extended state). +pub fn has_avxvnni() -> bool { + static GATE: OnceLock = OnceLock::new(); + *GATE.get_or_init(cpu_has_avxvnni) +} diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 86f36b7c9e..e43cfc4561 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -7,12 +7,16 @@ use crate::pack::{PackedFormat, PackedI8K4}; use super::amx::{PackedAmxA, has_amx_int8}; #[cfg(tract_amx_bf16)] use super::amx_bf16::{PackedAmxBf16A, PackedBf16K2, has_amx_bf16}; +#[cfg(tract_avxvnni)] +use super::avxvnni::has_avxvnni; use super::*; #[cfg(tract_amx_int8)] const AVX512AMX: fn() -> bool = has_amx_int8; #[cfg(tract_amx_bf16)] const AVX512AMX_BF16: fn() -> bool = has_amx_bf16; +#[cfg(tract_avxvnni)] +const AVXVNNI: fn() -> bool = has_avxvnni; /// One candidate kernel in a dispatcher's pool, with its tile geometry /// and a relative-throughput scale (1.0 = baseline, used to break @@ -115,6 +119,20 @@ MMMExternKernel! { avx512vnni_mmm_i32_8x8(8,8)@(256,4) where(AVX512VNNI) store(i8) } +// AVX-VNNI ymm int8 GEMM: byte-for-byte the same body as avx512vnni_mmm_i32_8x8 +// (8x8 ymm accumulators, PackedI8K4 inner-K, +128 bias trick), but the +// VPDPBUSD instructions are forced to the VEX (AVX-VNNI) encoding via the +// `{vex}` prefix. Runs on Atom-class cores (Alder Lake-E, Sierra Forest, +// Clearwater Forest / Darkmont) which have AVX-VNNI but no AVX-512. On big +// cores with both AVX-512-VNNI and AVX-VNNI (Sapphire Rapids+, some Alder +// Lake P-core SKUs) dispatch prefers the EVEX-encoded kernel above. +#[cfg(tract_avxvnni)] +MMMExternKernel! { avxvnni_mmm_i32_8x8(8,8)@(256,4) where(AVXVNNI) + packing[1] = i8i8 => |k| k.with_packing(PackedI8K4::new(8), PackedI8K4::new(8)); + quality(ManuallyOptimized) + store(i8) +} + // Same epilogue as avx512vnni_mmm_i32_8x8 (8x8 ymm accumulators), but the i8i8 // matmul inner loop uses TDPBSSD (16-M x 16-N x 64-K mul-acc per instruction) // over AMX tiles. A's packing is novel (PackedAmxA, M-major-within-panel, @@ -160,6 +178,13 @@ MMMExternKernel! { avx512amx_mmm_f32_16x16(16,16)@(64,4) where(AVX512AMX_BF pub fn plug(ops: &mut Ops) { if is_x86_feature_detected!("avx2") { plug_avx2(ops); + // AVX-VNNI runs on AVX2-only Atom-class cores (Alder Lake-E, Sierra + // Forest, Clearwater Forest / Darkmont). Plug it here so big cores + // can overlay AVX-512-VNNI / AMX on top below. + #[cfg(tract_avxvnni)] + if has_avxvnni() { + plug_avxvnni(ops); + } if is_x86_feature_detected!("fma") { plug_fma(ops); if is_x86_feature_detected!("avx512f") { @@ -193,6 +218,17 @@ pub fn plug_avx512vnni(ops: &mut Ops) { log::info!("qmmm_i32: x86_64/avx512vnni activated"); } +#[cfg(tract_avxvnni)] +pub fn plug_avxvnni(ops: &mut Ops) { + ops.mmm_impls.push(avxvnni_mmm_i32_8x8.mmm()); + // On AVX-VNNI-only cores (no AVX-512) this is the int8 throughput champion; + // replace the AVX2 emulation default. On big cores that also have + // AVX-512-VNNI, plug_avx512vnni below runs after this and clobbers + // qmmm_i32 again with the EVEX kernel. + ops.qmmm_i32 = Box::new(|_, _, _| avxvnni_mmm_i32_8x8.mmm()); + log::info!("qmmm_i32: x86_64/avxvnni (VEX-encoded VPDPBUSD) activated"); +} + #[cfg(tract_amx_bf16)] pub fn plug_avx512amx_bf16(ops: &mut Ops) { ops.mmm_impls.push(avx512amx_mmm_f32_16x16.mmm()); diff --git a/linalg/x86_64/avx512amx/dummy_avxvnni.S b/linalg/x86_64/avx512amx/dummy_avxvnni.S new file mode 100644 index 0000000000..0579b2ed84 --- /dev/null +++ b/linalg/x86_64/avx512amx/dummy_avxvnni.S @@ -0,0 +1,16 @@ +// Build-time capability probe for the assembler, used by build.rs +// (assembler_supports_avxvnni). Checks that the assembler accepts the +// `{vex}` prefix on VPDPBUSD, which forces the AVX-VNNI (VEX-encoded) +// form instead of the AVX-512-VNNI (EVEX-encoded) form gas defaults to. +// Requires binutils >= 2.36 (which added `{vex}`/`{evex}` prefixes for +// explicit encoding selection). When the probe fails the AVX-VNNI kernel +// is skipped and dispatch falls back to AVX2 emulation on AVX-VNNI-only +// hardware (Clearwater Forest / Sierra Forest / Alder Lake E-cores). +// Not linked into anything. +.intel_syntax noprefix +.text +.globl tract_avxvnni_probe +tract_avxvnni_probe: + // AVX-VNNI: u8 x s8 -> i32 dot4 (VEX-encoded) + {vex} vpdpbusd ymm0, ymm1, ymm2 + ret diff --git a/linalg/x86_64/fma/avxvnni_mmm_i32_8x8.S.j2 b/linalg/x86_64/fma/avxvnni_mmm_i32_8x8.S.j2 new file mode 100644 index 0000000000..904335c511 --- /dev/null +++ b/linalg/x86_64/fma/avxvnni_mmm_i32_8x8.S.j2 @@ -0,0 +1,685 @@ +{# +// vim: set syntax=asm : + +/* AVX-VNNI int8 GEMM (mmm 8x8), VEX-encoded VPDPBUSD. +// +// Body-identical to avx512vnni_mmm_i32_8x8 (same 8-row x 8-col ymm accumulators, +// same PackedI8K4 inner-K layout, same +128 bias trick to bridge VPDPBUSD's +// u8 x s8 into the AVX2 s8 x s8 reference). The only difference is that the +// two VPDPBUSD instructions are prefixed with {vex} so gas emits the AVX-VNNI +// (VEX) form instead of the AVX-512-VNNI (EVEX) form it defaults to. The VEX +// form runs on Atom-class cores (Alder Lake E-cores, Sierra Forest, Clearwater +// Forest / Darkmont) which have AVX-VNNI but no AVX-512, where the existing +// avx512vnni kernel would fault. + + ymm0 ymm1 ymm2 ymm3 ymm4 ymm5 ymm6 ymm7 + +System V ABI: + args: rdi, rsi, rdx, rcx, r8, r9 + preserve: rbx, rsp, rbp, r12, r13, r14, r15 + scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11 + return: rax (+rdx) + +Windows ABI: + args: RCX, RDX, R8, R9 + preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15 + scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15 + return: rax (+rdx) +*/ +#} + +{% if msvc %} + +_text segment +avxvnni_mmm_i32_8x8_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avxvnni_mmm_i32_8x8_{{suffix}} +{{G}}avxvnni_mmm_i32_8x8_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} +// https://www.agner.org/optimize/calling_conventions.pdf xmm6-15 are not scratch +// https://stackoverflow.com/questions/43358429/save-value-of-xmm-registers + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + vzeroall + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_i8i8 + +{{L}}main_loop_packed_packed: + vmovaps ymm12, [rax] + + {% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{i}} * 4] + vpmulld ymm13, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, ymm13 + {% endfor %} + + add rax, 32 + add rbx, 32 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_i8i8: + // PackedI8K4 layout: per K=4 block, the A panel is 8 rows x 4 K-bytes (32 + // bytes, lane m = A[m, 4kb..4kb+3]) and the B panel is 8 cols x 4 K-bytes + // (lane n = B[n, 4kb..4kb+3]). VPDPBUSD is u8 x s8, so A is offset by +128 + // (-> u8) and the resulting 128*sum_k(B[n]) bias is removed per column after + // the loop, leaving the i32 accumulators identical to the AVX2 path. + + add rcx, 3 + shr rcx, 2 // rcx <- ceil(k/4) K=4 blocks + + mov r8d, 0x01010101 + movd xmm11, r8d + vpbroadcastd ymm11, xmm11 // ymm11 <- u8 ones (sum of B) + + mov r8d, 0x80808080 + movd xmm12, r8d + vpbroadcastd ymm12, xmm12 // ymm12 <- byte 0x80 (A + 128) + + vpxor ymm10, ymm10, ymm10 // ymm10 <- per-col sum_k B[n] + +{{L}}loop_4k_i8i8: + vmovdqu ymm8, [rax] // A block: lane m = A[m,4kb..] + vpaddb ymm8, ymm8, ymm12 // s8 -> u8 (+128, modular) + + vmovdqu ymm9, [rbx] // B block: lane n = B[n,4kb..] + {vex} vpdpbusd ymm10, ymm11, ymm9 // sum_k B[n] += sum_t B[n,4kb+t] + + {% for n in range(0, 8) %} + vpbroadcastd ymm13, dword ptr [rbx + {{n}} * 4] + {vex} vpdpbusd ymm{{n}}, ymm8, ymm13 // acc[n][m] += sum_t (A[m]+128)*B[n] + {% endfor %} + + add rax, 32 + add rbx, 32 + dec rcx + jnz {{L}}loop_4k_i8i8 + + // remove the +128 bias added on A: acc[n] -= 128 * sum_k B[n] + vpslld ymm10, ymm10, 7 // lane n <- 128 * sum_k B[n] + {% for n in range(0, 8) %} + mov r8d, {{n}} + movd xmm14, r8d + vpbroadcastd ymm14, xmm14 // index = n in every lane + vpermd ymm15, ymm14, ymm10 // splat 128*sum_k B[n] + vpsubd ymm{{n}}, ymm{{n}}, ymm15 + {% endfor %} + + jmp {{L}}non_linear_loop + +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_scalars.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_rows.j2" %} +{% set mr = 8 %}{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_i32_per_cols.j2" %} +{% set from = 0 %}{% set to = 7 %}{% include "fma_mmm_load_tile.j2" %} + +{{L}}add_unicast: + + mov r10, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size + + cmp r8, 4 + je {{L}}non_linear_addc_i32 + +{# +// This is not great as vgatherdps reads 32-bits values and goes beyond our buffer. Probably harmless though. +// Commented and replaced with the "mov al" loop beyond to pacify valgrind. +// ymm14 and ymm15 are the same as in the non_linear_addc_i32 case (compute them before the test right above here. +// {% for i in range(0, 8) %} +// vpcmpeqd ymm15, ymm15, ymm15 +// vgatherdps ymm12, [ r10 + ymm14 ], ymm15 // 0xxx 1xxx 2xxx 3xxx 4xxx 5xxx 6xxx 7xxx +// +// // we need to go through vpmovsxbd, shuffling naively erases signs +// vpshufb ymm12, ymm12, ymm10 // 0123 0123 0123 0123 4567 4567 4567 4567 +// +// vpermd ymm12, ymm11, ymm12 // 0123 4567 +// vpmovsxbd ymm12, xmm12 // sign extend +// +// vpaddd ymm{{i}}, ymm{{i}}, ymm12 +// add r10, rbx +// {% endfor %} +#} + + {% for col in range(0, 8) %} + mov r8, r10 + {% for half in range(0, 2) %} + {% for lane in range(0, 4) %} + mov al, [ r8 ] + add r8, rsi + movsx eax, al + pinsrd xmm10, eax, {{lane}} + {% endfor %} + vperm2f128 ymm10, ymm10, ymm10, 1 + {% endfor %} + vpaddd ymm{{col}}, ymm{{col}}, ymm10 + add r10, rbx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}non_linear_addc_i32: + + mov eax, 0 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 +{% for i in range(0, 4) %} + pinsrd xmm14, eax, {{i}} + add eax, esi +{% endfor %} + vpermq ymm14, ymm14, 78 // 0b01001110 + + +{% if msvc %} + vpbroadcastd ymm10, dword ptr [ offset byte_shuffle ] + vmovups ymm11, dword ptr [ offset i128_shuffle ] +{% else %} + vpbroadcastd ymm10, [ rip + {{L}}byte_shuffle ] + vmovups ymm11, [ rip + {{L}}i128_shuffle ] +{% endif %} + +{% for i in range(0, 8) %} + vpcmpeqd ymm15, ymm15, ymm15 + vgatherdps ymm12, [ r10 + ymm14 ], ymm15 + vpaddd ymm{{i}}, ymm{{i}}, ymm12 + add r10, rbx +{% endfor %} + + jmp {{L}}non_linear_loop + +{% if msvc %} +.data +byte_shuffle dd 201851904 // 0x0c080400 +i128_shuffle dd 0, 4 +.code +{% else %} +{{L}}byte_shuffle: .int 201851904 // 0x0c080400 +{{L}}i128_shuffle: .int 0, 4 +{% endif %} + +{{L}}add_row_col_products: + mov rax, [ rdi + 8 ] + mov rbx, [ rdi + 16 ] + + vmovups ymm12, [rax] + +{% for i in range(0, 8) %} + vbroadcastss ymm14, dword ptr [rbx + {{ i * 4 }} ] + vpmulld ymm15, ymm12, ymm14 + vpaddd ymm{{i}}, ymm{{i}}, ymm15 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale: + mov r8, [ rdi + 16 ] // policy + vbroadcastss ymm8, dword ptr [rdi + 24] // multi + + mov rax, 1 + movq xmm9, rax + vpbroadcastq ymm9, xmm9 // ymm9 <- 1 + + mov rax, [ rdi + 8 ] // xmm10 <- shift + 31 + add rax, 31 + movq xmm10, rax + vpbroadcastq ymm10, xmm10 + + mov rax, 1 + movq xmm11, rax + vpsubq ymm12, ymm10, ymm9 // shift+31 - 1 + vpsllq ymm11, ymm9, xmm12 // ymm11 <- 1 << (shift + 31 - 1) + + cmp r8, 1 + je {{L}}q_scale_rounding_zero + cmp r8, 2 + je {{L}}q_scale_rounding_away + cmp r8, 3 + je {{L}}q_scale_rounding_minus_inf + cmp r8, 4 + je {{L}}q_scale_rounding_plus_inf + cmp r8, 5 + je {{L}}q_scale_rounding_even + cmp r8, 6 + je {{L}}q_scale_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_scale_rounding_zero: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsubq ymm14, ymm14, ymm9 + vpsubq ymm15, ymm15, ymm9 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_away: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_minus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + // sign extract for nudging in the right direction + vpxor ymm13, ymm13, ymm13 + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpsrld ymm13, ymm13, 31 // then just 0 or 1 + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_plus_inf: // signum * ( (abs << 32 + 1<<30+shift) >> shift ) + + vpbroadcastd ymm9, xmm9 + +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpxor ymm13, ymm13, ymm13 + + // sign extract for nudging in the right direction + vpcmpgtd ymm13, ymm{{i}}, ymm13 // ymm13 <- s0, s1, ..s8 (signums, as all ones or all zeros) + vpaddd ymm13, ymm13, ymm9 // if val >= 0 { 0i32 } else { 1i32 } + + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + // reinterpret ymm13=s0i32..s7 as i64 and blend with zero to pick the even ones as i64 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm14, ymm14, ymm12 + + vpsrldq ymm13, ymm13, 4 // ymm13 <- s1, s2, .., s7, 0 + vpxor ymm12, ymm12, ymm12 + vpblendd ymm12, ymm12, ymm13, 85 // 0x55 + vpsubq ymm15, ymm15, ymm12 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_even: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm14, ymm14, ymm12 + vpsubq ymm14, ymm14, ymm9 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpaddq ymm15, ymm15, ymm12 + vpsubq ymm15, ymm15, ymm9 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_odd: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsrldq ymm15, ymm14, 4 // ymm15 <- a1, a2, a3, a4, a5, a6, a7, 0 + vpmuldq ymm14, ymm14, ymm8 // ymm14 <- a0*c, a2*c, a4*c, a6*c + vpmuldq ymm15, ymm15, ymm8 // ymm15 <- a1*c, a3*c, a5*c, a7*c + + vpsrlq ymm12, ymm14, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm14, ymm14, ymm12 + + vpsrlq ymm12, ymm15, xmm10 + vpand ymm12, ymm12, ymm9 + vpsubq ymm15, ymm15, ymm12 + + vpaddq ymm14, ymm14, ymm11 + vpaddq ymm15, ymm15, ymm11 + + vpsrlq ymm14, ymm14, xmm10 + vpsrlq ymm15, ymm15, xmm10 + + vpslldq ymm15, ymm15, 4 + vpblendd ymm14, ymm15, ymm14, 85 // 0x55 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}q_shl: + mov eax, [ rdi + 8 ] // xmm10 <- -shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + +{% for i in range(0, 8) %} + vpsllvd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr: + mov r8, [ rdi + 16 ] // policy + + mov eax, 1 + movd xmm9, eax + vpbroadcastd ymm9, xmm9 // ymm9 <- 1u32 (8 times) + + mov eax, [ rdi + 8 ] // xmm10 <- shift (8 times) + movd xmm10, eax + vpbroadcastd ymm10, xmm10 + + mov ebx, 1 + mov cl, al + sub cl, 1 // rcx <- shift -1 + sal ebx, cl // rbx <- (1 << (shift - 1)) + movd xmm11, ebx + vpbroadcastd ymm11, xmm11 // ymm11 <- "half" + + vpxor ymm12, ymm12, ymm12 // ymm12 <- zeroes + + cmp r8, 1 + je {{L}}q_shr_rounding_zero + cmp r8, 2 + je {{L}}q_shr_rounding_away + cmp r8, 3 + je {{L}}q_shr_rounding_minus_inf + cmp r8, 4 + je {{L}}q_shr_rounding_plus_inf + cmp r8, 5 + je {{L}}q_shr_rounding_even + cmp r8, 6 + je {{L}}q_shr_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_shr_rounding_zero: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsubd ymm14, ymm14, ymm9 + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_away: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpaddd ymm14, ymm14, ymm11 + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_minus_inf: +{% for i in range(0, 8) %} + vpsubd ymm{{i}}, ymm{{i}}, ymm9 + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_plus_inf: +{% for i in range(0, 8) %} + vpaddd ymm{{i}}, ymm{{i}}, ymm11 + vpsravd ymm{{i}}, ymm{{i}}, ymm10 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_even: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm13, ymm9 // nudge = ((abs >>l shift) & 0x01) - 1 + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_odd: +{% for i in range(0, 8) %} + vpabsd ymm14, ymm{{i}} + vpsravd ymm13, ymm14, ymm10 + vpand ymm13, ymm13, ymm9 + vpsubd ymm13, ymm12, ymm13 // nudge = - ((abs >>l shift) & 0x01) + vpaddd ymm14, ymm14, ymm13 // add nudge + vpaddd ymm14, ymm14, ymm11 // add half + vpsravd ymm14, ymm14, ymm10 + vpsignd ymm{{i}}, ymm14, ymm{{i}} +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + je {{L}}store_strides_i32 + + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov byte ptr [r10], bl + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32: + {% for col in range(0, 8) %} + mov r10, r8 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + vperm2f128 ymm{{col}}, ymm{{col}}, ymm{{col}}, 1 + {% for row in range(0, 4) %} + extractps ebx, xmm{{col}}, {{row}} + mov dword ptr [r10], ebx + add r10, rsi + {% endfor %} + add r8, rdx + {% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}return: + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + + +{{L}}one_32bit: +{% if msvc %} + dd 1 +{% else %} + .int 1 +{% endif %} + +{% if msvc %} +avxvnni_mmm_i32_8x8_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} From 455d724934eb321c1489656b5e6af98abb09a476 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Tue, 2 Jun 2026 06:20:32 +0000 Subject: [PATCH 14/21] linalg/x86_64: boost AMX 16x16 kernels + add avxvnni_i32 microbench Two small finishers on the AMX / AVX-VNNI work: * mmm.rs -- boost(|| 100) on both AMX 16x16 kernels (i32 and f32). The einsum kernel-selection scorer is `-quality_cost*1000 + boost`, so all current ManuallyOptimized kernels tie at score 0. The boost makes the optimizer prefer the AMX 16x16 tile over the equally-tier'd AVX-512-VNNI (i32) and AVX-512 / FMA (f32) candidates when at least one dim is symbolic and the shape-adaptive `qmmm_i32` / `mmm_f32` picker isn't the path of selection. * benches/avxvnni_i32.rs -- mirror of amx_i32: same shapes (64x256x64 / 256x256x256 / 512x512x512 / 1024x1024x64), three columns (avx2 baseline, avxvnni new, avx512vnni reference when present). Skipped when has_avxvnni() returns false (CPUID 7.1 EAX.4 unset). Ready for an Atom-class host (Sierra Forest / Clearwater Forest / Alder Lake-E) to drop in and measure the VPDPBUSD speedup over the vpmaddubsw-emulation AVX2 path. --- linalg/Cargo.toml | 4 ++ linalg/benches/avxvnni_i32.rs | 97 +++++++++++++++++++++++++++++++++++ linalg/src/x86_64_fma/mmm.rs | 13 +++++ 3 files changed, 114 insertions(+) create mode 100644 linalg/benches/avxvnni_i32.rs diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index 6ee91f6003..27bc730f47 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -175,3 +175,7 @@ harness = false [[bench]] name = "amx_f32" harness = false + +[[bench]] +name = "avxvnni_i32" +harness = false diff --git a/linalg/benches/avxvnni_i32.rs b/linalg/benches/avxvnni_i32.rs new file mode 100644 index 0000000000..19deb1f013 --- /dev/null +++ b/linalg/benches/avxvnni_i32.rs @@ -0,0 +1,97 @@ +#![allow(dead_code)] +// Kernel-level benchmark: AVX-VNNI ymm int8 GEMM (avxvnni_mmm_i32_8x8, +// VEX-encoded VPDPBUSD over PackedI8K4 with K=4 inner) vs the AVX2 emulation +// path (avx2_mmm_i32_8x8, vpmaddubsw-style widening). Both kernels run the +// same i8i8 packing index (1) over the same M/K/N so the only difference is +// the matmul inner loop. +// +// Designed for Atom-class hosts that have AVX-VNNI but no AVX-512: +// +// * Alder Lake / Raptor Lake / Meteor Lake E-cores (Gracemont, Crestmont) +// * Sierra Forest (Sierra Glen) +// * Clearwater Forest (Darkmont) +// +// Big cores with both AVX-512-VNNI and AVX-VNNI still run AVX-VNNI here for +// comparison purposes; in production dispatch the EVEX-encoded +// avx512vnni_mmm_i32_8x8 wins on those CPUs because it can later be widened +// to zmm without an ISA-level rewrite. +use criterion::*; +use tract_data::internal::*; +use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; + +fn run_kernel(be: &mut Bencher, mmm: &dyn MatMatMul, m: usize, k: usize, n: usize) { + let a = Tensor::zero_dt(DatumType::I8, &[m, k]).unwrap(); + let b = Tensor::zero_dt(DatumType::I8, &[k, n]).unwrap(); + let (pack_a, pack_b) = &mmm.packings()[1]; + let pa = pack_a.prepare_one(&a, 1, 0).unwrap(); + let pb = pack_b.prepare_one(&b, 0, 1).unwrap(); + let mut scratch = unsafe { mmm.allocate_scratch_space() }; + be.iter_custom(|iters| { + let mut dur = std::time::Duration::default(); + for _ in 0..iters { + let t = std::time::Instant::now(); + unsafe { + mmm.run_with_scratch_space( + m, + n, + scratch.as_mut(), + &[FusedSpec::AddMatMul { + a: AsInputValue::Borrowed(&*pa), + b: AsInputValue::Borrowed(&*pb), + packing: 1, + }], + ) + .unwrap() + }; + dur += t.elapsed(); + } + dur + }); +} + +fn benches(c: &mut Criterion) { + #[cfg(tract_avxvnni)] + { + use tract_linalg::x86_64_fma::avxvnni::has_avxvnni; + use tract_linalg::x86_64_fma::mmm::*; + if !has_avxvnni() { + eprintln!("AVX-VNNI not available (CPUID leaf 7.1 EAX.4 unset), skipping"); + return; + } + // Same shapes as amx_i32 / vnni_i32 for direct side-by-side comparison. + for &(m, k, n) in + &[(64usize, 256usize, 64usize), (256, 256, 256), (512, 512, 512), (1024, 1024, 64)] + { + let id = format!("{m}x{k}x{n}"); + let mut g = c.benchmark_group("avxvnni_i32/packed_packed"); + g.throughput(Throughput::Elements((m * k * n) as u64)); + g.bench_with_input(BenchmarkId::new("avx2", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avx2_mmm_i32_8x8.mmm(), m, k, n) + }); + g.bench_with_input(BenchmarkId::new("avxvnni", &id), &(m, k, n), |b, &(m, k, n)| { + run_kernel(b, &*avxvnni_mmm_i32_8x8.mmm(), m, k, n) + }); + // When the same host also reports AVX-512-VNNI, include it as a + // reference point: the same kernel body runs as EVEX/zmm-encoded + // VPDPBUSD, which should match the AVX-VNNI throughput on Sapphire + // Rapids+ but can diverge on Cooper/Cascade Lake where the EVEX + // decoder is on the AVX-512 fused unit. + if std::is_x86_feature_detected!("avx512vnni") { + g.bench_with_input( + BenchmarkId::new("avx512vnni", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512vnni_mmm_i32_8x8.mmm(), m, k, n), + ); + } + g.finish(); + } + } + #[cfg(not(tract_avxvnni))] + { + eprintln!("tract not built with AVX-VNNI support (probe failed at build time)"); + let _ = c; + } +} + +criterion_group!(g, benches); +criterion_main!(g); diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index e43cfc4561..2369fc3388 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -151,10 +151,17 @@ MMMExternKernel! { avx512amx_mmm_i32_8x8(8,8)@(64,4) where(AVX512AMX) // Same A/B packing (PackedAmxA, PackedI8K4) just with r=16. Row-major // accumulators (zmm{m} = row m of C) so the hot path (Clear -> AddMatMul -> // Store) needs no transpose. +// +// boost(100) pushes this kernel above the equally-ManuallyOptimized AVX-512-VNNI +// and AMX 8x8 candidates in the einsum kernel-selection scorer (which uses +// `-quality_cost*1000 + boost` per kernel). When more than one dim is symbolic +// the shape-adaptive `qmmm_i32` picker isn't invoked, so the boost is what +// causes the optimizer to prefer the 16x16 tile for unknown-shape matmuls. #[cfg(tract_amx_int8)] MMMExternKernel! { avx512amx_mmm_i32_16x16(16,16)@(64,4) where(AVX512AMX) packing[1] = i8i8 => |k| k.with_packing(PackedAmxA::new(16), PackedI8K4::new(16)); quality(ManuallyOptimized) + boost(|| 100) store(i8) } @@ -169,10 +176,16 @@ MMMExternKernel! { avx512amx_mmm_i32_16x16(16,16)@(64,4) where(AVX512AMX) // Default packing[0] (the framework's PackedFormat) is retained so the // kernel can still be selected for f32 paths even when the BF16 packer // isn't a precursor match; packing[1] is the fast bf16-from-f32 path. +// boost(100) puts this AMX kernel above the AVX-512 f32 / FMA f32 kernels at +// the same ManuallyOptimized tier so the einsum scorer prefers it whenever +// supported, mirroring the i32 16x16 behaviour. The bf16 vs f32 precision +// trade is intentional and amortised over the same call sites that already +// use bf16-via-`dotbf16ps`-style fast-math elsewhere in the stack. #[cfg(tract_amx_bf16)] MMMExternKernel! { avx512amx_mmm_f32_16x16(16,16)@(64,4) where(AVX512AMX_BF16) packing[1] = f32f32_bf16 => |k| k.with_packing(PackedAmxBf16A::new(16), PackedBf16K2::new(16)); quality(ManuallyOptimized) + boost(|| 100) } pub fn plug(ops: &mut Ops) { From 22fdb37c5bd6861bc37e9ebb45ca582b39a195e0 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:56:21 +0000 Subject: [PATCH 15/21] linalg/x86_64: fix swapped operands in AMX 16x16 sub fused-op handlers The inline scalar_sub / per_row_sub / per_col_sub handlers (and their _flipped twins) in the AMX int8 and bf16 16x16 kernels had their operand order reversed relative to the shared fma_mmm_ymm_ops.j2 convention: non-flipped sub must compute `operand - acc`, flipped `acc - operand`. Both kernels did the opposite, so a ScalarSub / per-row / per-col subtract fused into the matmul produced negated results. The bug never surfaced because these kernels' test suites are skipped on hosts without AMX (is_supported_here() == false), and the dev/CI hardware here is Cascade Lake-class (AVX-512-VNNI, no AMX). It was caught by the new avx512vnni_mmm_i32_16x16 kernel, which reuses the same epilogue and whose tests DO run on VNNI hardware: scalar_sub / per_row_sub / per_col_sub each failed with exactly negated output. The commutative ops (min/max/mul/add) were unaffected. https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV --- linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 | 18 ++++++++++++------ linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 | 18 ++++++++++++------ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 index c635589373..654cc660cd 100644 --- a/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 +++ b/linalg/x86_64/fma/avx512amx_mmm_f32_16x16.S.j2 @@ -239,14 +239,16 @@ avx512amx_mmm_f32_16x16_{{suffix}} proc jmp {{L}}non_linear_loop {{L}}scalar_sub: + // non-flipped sub = operand - acc (matches fma_mmm_ymm_ops.j2 scalar macro) vbroadcastss zmm16, dword ptr [rdi + 8] - {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm{{r}}, zmm16 + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm16, zmm{{r}} {% endfor %} jmp {{L}}non_linear_loop {{L}}scalar_sub_flipped: + // flipped sub = acc - operand vbroadcastss zmm16, dword ptr [rdi + 8] - {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm16, zmm{{r}} + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm{{r}}, zmm16 {% endfor %} jmp {{L}}non_linear_loop @@ -290,16 +292,18 @@ avx512amx_mmm_f32_16x16_{{suffix}} proc jmp {{L}}non_linear_loop {{L}}per_row_sub: + // non-flipped sub = operand - acc mov rax, [rdi + 8] {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] - vsubps zmm{{m}}, zmm{{m}}, zmm16 + vsubps zmm{{m}}, zmm16, zmm{{m}} {% endfor %} jmp {{L}}non_linear_loop {{L}}per_row_sub_flipped: + // flipped sub = acc - operand mov rax, [rdi + 8] {% for m in range(0, 16) %}vbroadcastss zmm16, dword ptr [rax + {{m * 4}}] - vsubps zmm{{m}}, zmm16, zmm{{m}} + vsubps zmm{{m}}, zmm{{m}}, zmm16 {% endfor %} jmp {{L}}non_linear_loop @@ -332,16 +336,18 @@ avx512amx_mmm_f32_16x16_{{suffix}} proc jmp {{L}}non_linear_loop {{L}}per_col_sub: + // non-flipped sub = operand - acc mov rax, [rdi + 8] vmovups zmm16, [rax] - {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm{{r}}, zmm16 + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm16, zmm{{r}} {% endfor %} jmp {{L}}non_linear_loop {{L}}per_col_sub_flipped: + // flipped sub = acc - operand mov rax, [rdi + 8] vmovups zmm16, [rax] - {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm16, zmm{{r}} + {% for r in range(0, 16) %}vsubps zmm{{r}}, zmm{{r}}, zmm16 {% endfor %} jmp {{L}}non_linear_loop diff --git a/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 b/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 index 5652f425f8..5c97cb7c6a 100644 --- a/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 +++ b/linalg/x86_64/fma/avx512amx_mmm_i32_16x16.S.j2 @@ -249,14 +249,16 @@ avx512amx_mmm_i32_16x16_{{suffix}} proc jmp {{L}}non_linear_loop {{L}}scalar_sub: + // non-flipped sub = operand - acc (matches fma_mmm_ymm_ops.j2 scalar macro) vpbroadcastd zmm16, dword ptr [rdi + 8] - {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} {% endfor %} jmp {{L}}non_linear_loop {{L}}scalar_sub_flipped: + // flipped sub = acc - operand vpbroadcastd zmm16, dword ptr [rdi + 8] - {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 {% endfor %} jmp {{L}}non_linear_loop @@ -300,16 +302,18 @@ avx512amx_mmm_i32_16x16_{{suffix}} proc jmp {{L}}non_linear_loop {{L}}per_row_sub: + // non-flipped sub = operand - acc mov rax, [rdi + 8] {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] - vpsubd zmm{{m}}, zmm{{m}}, zmm16 + vpsubd zmm{{m}}, zmm16, zmm{{m}} {% endfor %} jmp {{L}}non_linear_loop {{L}}per_row_sub_flipped: + // flipped sub = acc - operand mov rax, [rdi + 8] {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] - vpsubd zmm{{m}}, zmm16, zmm{{m}} + vpsubd zmm{{m}}, zmm{{m}}, zmm16 {% endfor %} jmp {{L}}non_linear_loop @@ -342,16 +346,18 @@ avx512amx_mmm_i32_16x16_{{suffix}} proc jmp {{L}}non_linear_loop {{L}}per_col_sub: + // non-flipped sub = operand - acc mov rax, [rdi + 8] vmovdqu32 zmm16, [rax] - {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} {% endfor %} jmp {{L}}non_linear_loop {{L}}per_col_sub_flipped: + // flipped sub = acc - operand mov rax, [rdi + 8] vmovdqu32 zmm16, [rax] - {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 {% endfor %} jmp {{L}}non_linear_loop From 100e5722fb5230269dca0d208192daf30e965a13 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:56:32 +0000 Subject: [PATCH 16/21] linalg/x86_64: add AVX-512 VNNI zmm 16x16 int8 GEMM kernel avx512vnni_mmm_i32_16x16 is the zmm-wide (512-bit) sibling of the existing avx512vnni_mmm_i32_8x8: 16 row-major i32 accumulators (zmm{m} = row m of C), one VPDPBUSD per row per K=4 block over PackedI8K4(16) for both A and B, so it issues 1024 mul-adds/block -- 2x the 8x8 ymm kernel's work per iteration. Same u8 x s8 +128 A-bias trick as the 8x8 kernel, but the row-major layout makes the per-column 128*sum_k(B) correction a single vector subtract. Built by adapting the AMX 16x16 i32 template (whose zmm row-major epilogue is reused verbatim), replacing the AMX tile inner loop with the VPDPBUSD loop and dropping the tile-config preamble / tilerelease. Because the file is named avx512vnni_* it stays in the generic -mfma assembler bulk-compile (VPDPBUSD needs no special gas gating, same as the 8x8 kernel). Wired into plug_avx512vnni with a shape-adaptive qmmm_i32 picker (16x16 when M,N >= 16, else 8x8) mirroring the AMX int8 path, plus boost(50) so the einsum scorer prefers it over the 8x8 for unknown shapes while staying below the AMX kernels' boost(100) (AMX still wins when both are present). This gives big cores with AVX-512-VNNI but no AMX (Cascade Lake / Ice Lake / Tiger Lake) a wider int8 GEMM throughput tier. Added as a third column in the vnni_i32 microbench. All 114 auto-generated kernel tests (packed-packed i8i8 + i32i32, fused-op frame, quant rounding, stores, proptest) pass on AVX-512-VNNI hardware. https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV --- linalg/benches/vnni_i32.rs | 15 +- linalg/build.rs | 17 +- linalg/src/x86_64_fma/mmm.rs | 39 +- .../x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2 | 885 ++++++++++++++++++ 4 files changed, 946 insertions(+), 10 deletions(-) create mode 100644 linalg/x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2 diff --git a/linalg/benches/vnni_i32.rs b/linalg/benches/vnni_i32.rs index 59e6f01676..6901427f4a 100644 --- a/linalg/benches/vnni_i32.rs +++ b/linalg/benches/vnni_i32.rs @@ -1,8 +1,10 @@ #![allow(dead_code)] -// Kernel-level benchmark: AVX-512 VNNI int8 GEMM (avx512vnni_mmm_i32_8x8, VPDPBUSD -// over the K=4-inner PackedI8K4 layout) vs the AVX2 int8 path (avx2_mmm_i32_8x8, -// vpmaddubsw-style widening). Both run the i8i8 packing (index 1) over the same -// M/K/N so the only difference is the matmul inner loop. +// Kernel-level benchmark: AVX-512 VNNI int8 GEMM over the K=4-inner PackedI8K4 +// layout (VPDPBUSD) vs the AVX2 int8 path (avx2_mmm_i32_8x8, vpmaddubsw-style +// widening). Three columns: the AVX2 baseline, the 8x8 ymm VNNI kernel, and the +// 16x16 zmm VNNI kernel (twice the columns per accumulator). All run the i8i8 +// packing (index 1) over the same M/K/N so the only difference is the matmul +// inner loop and tile geometry. use criterion::*; use tract_data::internal::*; use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; @@ -55,6 +57,11 @@ fn benches(c: &mut Criterion) { g.bench_with_input(BenchmarkId::new("avx512vnni", &id), &(m, k, n), |b, &(m, k, n)| { run_kernel(b, &*avx512vnni_mmm_i32_8x8.mmm(), m, k, n) }); + g.bench_with_input( + BenchmarkId::new("avx512vnni_16x16", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512vnni_mmm_i32_16x16.mmm(), m, k, n), + ); g.finish(); } } diff --git a/linalg/build.rs b/linalg/build.rs index 52eb92e0d1..c7770f1c03 100644 --- a/linalg/build.rs +++ b/linalg/build.rs @@ -329,18 +329,27 @@ fn main() { } else { cc::Build::new().files(files).flag("-mfma").compile("x86_64_fma"); } - // VNNI kernel compiled separately so old assemblers (binutils < 2.30, + // VNNI kernels compiled separately so old assemblers (binutils < 2.30, // e.g. Debian stretch) that can't encode `vpdpbusd ymm` don't break // the whole x86_64 build. The `tract_avx512vnni` cfg gates the // matching Rust extern declarations and dispatch registration. // - // The template stays in x86_64/fma/ (alongside dispatcher.j2 and the - // other partials it includes) so the jinja env can resolve its includes. + // The templates stay in x86_64/fma/ (alongside dispatcher.j2 and the + // other partials they include) so the jinja env can resolve its includes. if assembler_supports_avx512vnni() { let tmpl = path::Path::new("x86_64/fma/avx512vnni_mmm_i32_8x8.S.j2"); let out = out_dir.join(format!("avx512vnni_mmm_i32_8x8_{suffix}.S")); preprocess_file(tmpl, &out, &[], &suffix, false); - cc::Build::new().file(&out).flag("-mfma").compile("x86_64_avx512vnni"); + // The zmm 16x16 sibling shares the VPDPBUSD probe; compile it into + // the same object so `tract_avx512vnni` gates both kernels together. + let tmpl16 = path::Path::new("x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2"); + let out16 = out_dir.join(format!("avx512vnni_mmm_i32_16x16_{suffix}.S")); + preprocess_file(tmpl16, &out16, &[], &suffix, false); + cc::Build::new() + .file(&out) + .file(&out16) + .flag("-mfma") + .compile("x86_64_avx512vnni"); println!("cargo:rustc-cfg=tract_avx512vnni"); } diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 2369fc3388..e5943c5bc3 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -119,6 +119,26 @@ MMMExternKernel! { avx512vnni_mmm_i32_8x8(8,8)@(256,4) where(AVX512VNNI) store(i8) } +// AVX-512 VNNI int8 GEMM, zmm-wide 16x16 sibling of avx512vnni_mmm_i32_8x8. +// Accumulators are ROW-MAJOR (zmm{m} = row m of C, 16 columns per zmm), so one +// VPDPBUSD covers 16 columns x 4 K and the K=4 inner step issues 16 of them +// (one per row) = 1024 mul-adds/block, 2x the 8x8 ymm kernel's work per +// iteration. Same +128 A-bias / per-column correction as the 8x8 kernel, and +// the same PackedI8K4 layout (r=16 for both A and B). This is the int8 +// throughput tier of qmmm_i32 for big cores with AVX-512-VNNI but no AMX +// (Cascade Lake / Ice Lake / Tiger Lake server + client). +// +// boost(50) lifts it above the 8x8 VNNI candidate in the einsum kernel-selection +// scorer for unknown shapes, while staying below the AMX 16x16 kernels' boost(100) +// so AMX still wins when both are present. +#[cfg(tract_avx512vnni)] +MMMExternKernel! { avx512vnni_mmm_i32_16x16(16,16)@(64,4) where(AVX512VNNI) + packing[1] = i8i8 => |k| k.with_packing(PackedI8K4::new(16), PackedI8K4::new(16)); + quality(ManuallyOptimized) + boost(|| 50) + store(i8) +} + // AVX-VNNI ymm int8 GEMM: byte-for-byte the same body as avx512vnni_mmm_i32_8x8 // (8x8 ymm accumulators, PackedI8K4 inner-K, +128 bias trick), but the // VPDPBUSD instructions are forced to the VEX (AVX-VNNI) encoding via the @@ -227,8 +247,23 @@ pub fn plug(ops: &mut Ops) { #[cfg(tract_avx512vnni)] pub fn plug_avx512vnni(ops: &mut Ops) { ops.mmm_impls.push(avx512vnni_mmm_i32_8x8.mmm()); - ops.qmmm_i32 = Box::new(|_, _, _| avx512vnni_mmm_i32_8x8.mmm()); - log::info!("qmmm_i32: x86_64/avx512vnni activated"); + ops.mmm_impls.push(avx512vnni_mmm_i32_16x16.mmm()); + // Shape-adaptive dispatch mirroring the AMX int8 path: the zmm 16x16 tile is + // the throughput champion when each of M and N fills at least one tile; the + // 8x8 ymm kernel has lower per-call setup (smaller epilogue, half the + // accumulator file) and wins on small problems where the 16x16 tile-padding + // overhead dominates. Unknown dims default to the 16x16 champion. (No K gate: + // one VPDPBUSD step is only 4 K-bytes, so any K is fine; the choice is about + // filling the 16-wide M/N tile.) + ops.qmmm_i32 = Box::new(|m, _, n| { + let big = |o: Option, t: usize| o.is_none_or(|v| v >= t); + if big(m, 16) && big(n, 16) { + avx512vnni_mmm_i32_16x16.mmm() + } else { + avx512vnni_mmm_i32_8x8.mmm() + } + }); + log::info!("qmmm_i32: x86_64/avx512vnni (16x16 + 8x8 adaptive) activated"); } #[cfg(tract_avxvnni)] diff --git a/linalg/x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2 b/linalg/x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2 new file mode 100644 index 0000000000..4c61169ff9 --- /dev/null +++ b/linalg/x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2 @@ -0,0 +1,885 @@ +// vim: set syntax=asm : +// +// AVX-512 VNNI int8 GEMM kernel, 16 M-rows x 16 N-cols i32 accumulator output. +// +// The zmm-wide (512-bit) sibling of avx512vnni_mmm_i32_8x8: where the 8x8 +// kernel accumulates 8 columns per ymm, this one accumulates 16 columns per +// zmm over 16 rows, so one VPDPBUSD covers a 16-lane x 4-K = 64 mul-add slab +// and the K=4 inner step issues 16 of them (one per row) -- 1024 mul-adds per +// K=4 block, 2x the work-per-iteration of the 8x8 ymm kernel. It is the int8 +// throughput tier of qmmm_i32 on big cores that have AVX-512-VNNI but no AMX +// (Cascade Lake / Ice Lake / Tiger Lake server + client SKUs). +// +// VPDPBUSD is u8 x s8, so (like the 8x8 kernel) the A bytes are offset by +128 +// to become u8 and the resulting 128 * sum_k(B[n]) bias is removed per column +// after the loop; the i32 accumulators are then bit-identical to the AVX2 / +// VNNI-8x8 / AMX reference paths. +// +// A and B both use PackedI8K4(16): per K=4 block, 16 elements x 4 K-bytes = 64 +// bytes, element e at byte offset e*4 holding [e, 4kb..4kb+3]; K is zero-padded +// to a multiple of 4 by the packer. +// +// REGISTER LAYOUT +// zmm0..zmm15 = accumulators, ROW-MAJOR: zmm{m} holds the 16 i32 lanes +// [C[m, 0], C[m, 1], ..., C[m, 15]] for row m. Row-major makes +// the per-column +128 bias a single vector subtract and lets +// the Store path write each row with one vmovdqu32. +// zmm16 = B K=4 block (lane n = B[n, 4kb..]); zmm17 = u8 ones (0x01010101); +// zmm18 = broadcast A[m, 4kb..] (+128 -> u8); zmm19 = bias (sum_k B[n]); +// zmm20 = 0x80808080 (the +128 byte bias added to A). + +{% if msvc %} + +_text segment +avx512vnni_mmm_i32_16x16_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512vnni_mmm_i32_16x16_{{suffix}} +{{G}}avx512vnni_mmm_i32_16x16_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + + push rdi + push rsi + + mov rdi, rcx + +{% endif %} + + push rbx + push r12 + push r13 + push r14 + push r15 + + sub rsp, 8 + +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + +{% include "dispatcher.j2" %} + +{{L}}clear: + // vzeroall only zeros lower-256 of zmm0..15; explicitly zero the full + // accumulators (zmm0..zmm15) for AMX. + {% for r in range(0, 16) %} + vpxorq zmm{{r}}, zmm{{r}}, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov r12, [rdi + 32] // packing + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + + cmp r12, 1 + je {{L}}main_loop_packed_packed_i8i8 + +{{L}}main_loop_packed_packed: + // Generic i32 x i32 fallback path (not AMX). For row-major accumulators + // with zmm{m} = row m of C, accumulating C[m, n] += A[m, k] * B[k, n]: + // - load 16 B values for this K row into zmm16 (row of B) + // - for each m: broadcast A[m, k], multiply by zmm16, add to zmm{m} + vmovups zmm16, [rbx] // 16 i32 of B at this K row + + {% for m in range(0, 16) %} + vpbroadcastd zmm17, dword ptr [rax + {{m}} * 4] + vpmulld zmm18, zmm16, zmm17 + vpaddd zmm{{m}}, zmm{{m}}, zmm18 + {% endfor %} + + add rax, 64 // 16 i32 lanes per K step + add rbx, 64 + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{{L}}main_loop_packed_packed_i8i8: + // PackedI8K4(16) for both A and B: per K=4 block, 16 elements x 4 K-bytes = + // 64 bytes, element e at byte offset e*4 holding [e, 4kb..4kb+3]. + // B block -> zmm16, lane n = B[n, 4kb..] (the s8 operand) + // A[m] (its 4 K-bytes) broadcast to all 16 lanes, +128 -> the u8 operand + // VPDPBUSD zmm{m}, A_bcast(u8), Bblock(s8): lane n += sum_t (A[m,t]+128)*B[n,t] + // = C[m, n] + 128 * sum_t B[n, t]. That 128*sum_k(B[n]) bias is the same + // for every row m, so it is accumulated once per column in zmm19 (via a u8 + // all-ones VPDPBUSD) and subtracted from every accumulator after the loop, + // leaving the i32 accumulators bit-identical to the AVX2 / 8x8 paths. + + add rcx, 3 + shr rcx, 2 // rcx <- ceil(k/4) K=4 blocks + + mov r8d, 0x01010101 + vmovd xmm17, r8d + vpbroadcastd zmm17, xmm17 // zmm17 <- u8 ones (sum of B) + + mov r8d, 0x80808080 + vmovd xmm20, r8d + vpbroadcastd zmm20, xmm20 // zmm20 <- byte 0x80 (A + 128) + + vpxorq zmm19, zmm19, zmm19 // zmm19 <- per-col sum_k B[n] + +{{L}}loop_4k_i8i8_16x16: + vmovdqu32 zmm16, [rbx] // B block: lane n = B[n, 4kb..] + vpdpbusd zmm19, zmm17, zmm16 // sum_k B[n] += sum_t B[n, 4kb+t] + + {% for m in range(0, 16) %} + vpbroadcastd zmm18, dword ptr [rax + {{ m * 4 }}] + vpaddb zmm18, zmm18, zmm20 // s8 -> u8 (+128, modular) + vpdpbusd zmm{{m}}, zmm18, zmm16 // acc[m][n] += sum_t (A[m]+128)*B[n] + {% endfor %} + + add rax, 64 // next A K=4 block (16 rows * 4 K) + add rbx, 64 // next B K=4 block (16 cols * 4 K) + dec rcx + jnz {{L}}loop_4k_i8i8_16x16 + + // remove the +128 bias added on A: acc[m][n] -= 128 * sum_k B[n] (per column) + vpslld zmm19, zmm19, 7 // lane n <- 128 * sum_k B[n] + {% for m in range(0, 16) %} + vpsubd zmm{{m}}, zmm{{m}}, zmm19 + {% endfor %} + + jmp {{L}}non_linear_loop + +// ---- Scalar / per-row / per-col elementwise epilogues ------------------- + +{{L}}scalar_min: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpminsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_max: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpmaxsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_add: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpaddd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_mul: + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpmulld zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub: + // non-flipped sub = operand - acc (matches fma_mmm_ymm_ops.j2 scalar macro) + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}scalar_sub_flipped: + // flipped sub = acc - operand + vpbroadcastd zmm16, dword ptr [rdi + 8] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}leaky_relu: + // C[m, n] = (C[m, n] >= 0) ? C[m, n] : alpha * C[m, n] + vpbroadcastd zmm17, dword ptr [rdi + 8] // alpha as i32 scale factor + vpxorq zmm16, zmm16, zmm16 + {% for r in range(0, 16) %} + vpmulld zmm18, zmm{{r}}, zmm17 + vpcmpgtd k1, zmm16, zmm{{r}} // 1 where C < 0 + vpblendmd zmm{{r}}{k1}, zmm{{r}}, zmm18 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_min: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpminsd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_max: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpmaxsd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_add: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpaddd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_mul: + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpmulld zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub: + // non-flipped sub = operand - acc + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpsubd zmm{{m}}, zmm16, zmm{{m}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_row_sub_flipped: + // flipped sub = acc - operand + mov rax, [rdi + 8] + {% for m in range(0, 16) %}vpbroadcastd zmm16, dword ptr [rax + {{m * 4}}] + vpsubd zmm{{m}}, zmm{{m}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_min: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpminsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_max: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpmaxsd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_add: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpaddd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_mul: + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpmulld zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub: + // non-flipped sub = operand - acc + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm16, zmm{{r}} + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}per_col_sub_flipped: + // flipped sub = acc - operand + mov rax, [rdi + 8] + vmovdqu32 zmm16, [rax] + {% for r in range(0, 16) %}vpsubd zmm{{r}}, zmm{{r}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}load_tile: + // Scratch layout is COL-MAJOR i32 from scratch.rs Store/AddUnicast remnant: + // tile[col][row] at offset (col*MR + row)*4 with MR=16 + // = offset col*64 + row*4 + // For row-major accumulators we gather row m's 16 cols at index step 64. + mov r8, [rdi + 8] + vmovdqa32 zmm16, [rip + {{L}}lane_offsets_64] // [0, 64, 128, ..., 15*64] + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm{{m}}{k1}, [r8 + zmm16 + {{m * 4}}] + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_unicast: + mov r10, [rdi + 8] // c ptr (base) + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + mov r8, [rdi + 32] // item size + + cmp r8, 4 + je {{L}}non_linear_addc_i32 + + // i8 path: read 16 i8 from [r10 + m*rsi + n*rbx] for n=0..15, sign-extend + // to i32, add to zmm{m}. Use a stack scratch buffer (16 bytes per row). + sub rsp, 16 + {% for m in range(0, 16) %} + mov r8, r10 + {% for n in range(0, 16) %} + mov al, [r8] + mov byte ptr [rsp + {{n}}], al + add r8, rbx + {% endfor %} + vpmovsxbd zmm16, [rsp] + vpaddd zmm{{m}}, zmm{{m}}, zmm16 + add r10, rsi + {% endfor %} + add rsp, 16 + jmp {{L}}non_linear_loop + +{{L}}non_linear_addc_i32: + // i32 strided read of external (or scratch) tile. Build per-lane index + // vector [0, rbx, 2*rbx, ..., 15*rbx] once, then gather row by row. + mov eax, ebx + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + vpmulld zmm16, zmm16, [rip + {{L}}lane_indices] // [0, rbx, 2*rbx, ...] + + {% for m in range(0, 16) %} + mov eax, 0xFFFF + kmovw k1, eax + vpgatherdd zmm17{k1}, [r10 + zmm16] + vpaddd zmm{{m}}, zmm{{m}}, zmm17 + add r10, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_row_col_products: + // bias[m, n] = row_data[m] * col_data[n], add to C[m, n]. + // For row-major regs: load 16 col_data values once into zmm16, + // for each m: broadcast row_data[m], FMA add. + mov rax, [rdi + 8] + mov rbx, [rdi + 16] + + vmovdqu32 zmm16, [rax] // 16 row_data values + vmovdqu32 zmm17, [rbx] // 16 col_data values + + {% for m in range(0, 16) %} + vpbroadcastd zmm18, dword ptr [rax + {{m * 4}}] // splat row_data[m] + vpmulld zmm19, zmm18, zmm17 + vpaddd zmm{{m}}, zmm{{m}}, zmm19 + {% endfor %} + jmp {{L}}non_linear_loop + +// ---- Q-scale (mult-shift with rounding) --------------------------------- + +{{L}}q_scale: + mov r8, [rdi + 16] // policy + vpbroadcastd zmm16, dword ptr [rdi + 24] // multi (broadcast i32) + + mov rax, 1 + vmovq xmm17, rax + vpbroadcastq zmm17, xmm17 // zmm17 <- 1 (i64 lanes) + + mov rax, [rdi + 8] // shift + add rax, 31 + vmovq xmm18, rax + vpbroadcastq zmm18, xmm18 // zmm18 <- (shift+31) (i64 lanes) + + vpsubq zmm19, zmm18, zmm17 + vpsllvq zmm19, zmm17, zmm19 // zmm19 <- 1 << (shift+31-1) (i64) + + // Per-lane interleave mask for blending evens / shifted-odds. + // bit i = 1 means take from "evens" source in vpblendmd; bit 0,2,4,...,14 set. + mov eax, 0x5555 + kmovw k7, eax + + cmp r8, 1 + je {{L}}q_scale_rounding_zero + cmp r8, 2 + je {{L}}q_scale_rounding_away + cmp r8, 3 + je {{L}}q_scale_rounding_minus_inf + cmp r8, 4 + je {{L}}q_scale_rounding_plus_inf + cmp r8, 5 + je {{L}}q_scale_rounding_even + cmp r8, 6 + je {{L}}q_scale_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_scale_rounding_zero: // signum * ( (abs + nudge - 1) >> shift ) +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 // even-lane i32 -> i64 mul + vpmuldq zmm21, zmm21, zmm16 // odd-lane i32 -> i64 mul + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsubq zmm20, zmm20, zmm17 + vpsubq zmm21, zmm21, zmm17 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 // k7=0x5555: evens from zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_away: // signum * ( (abs + nudge) >> shift ) +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_minus_inf: // nudge by -1 where input was negative +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpxorq zmm22, zmm22, zmm22 + vpcmpgtd k1, zmm{{i}}, zmm22 // k1: 1 where input > 0 (we want the inverse, see below) + knotw k1, k1 // 1 where input <= 0 -- we want "input was negative => subtract 1" + // For "<0": use compare against 0 with vpcmpltd + vpxorq zmm22, zmm22, zmm22 + vpcmpltd k1, zmm{{i}}, zmm22 // 1 where input < 0 + vmovdqa32 zmm23{k1}{z}, [rip + {{L}}all_ones_i32] // (1 << 0) per neg lane, 0 elsewhere + + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + // Subtract 1 from i64-evens / i64-odds where the original i32 input was < 0. + vpsrldq zmm24, zmm23, 4 + vpmovsxdq zmm25, ymm23 + vpmovsxdq zmm26, ymm24 + vpsubq zmm20, zmm20, zmm25 + vpsubq zmm21, zmm21, zmm26 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_plus_inf: // nudge by +1 where input was non-negative +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpxorq zmm22, zmm22, zmm22 + vpcmpled k1, zmm22, zmm{{i}} // 1 where input >= 0 + vmovdqa32 zmm23{k1}{z}, [rip + {{L}}all_ones_i32] + + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrldq zmm24, zmm23, 4 + vpmovsxdq zmm25, ymm23 + vpmovsxdq zmm26, ymm24 + vpsubq zmm20, zmm20, zmm25 + vpsubq zmm21, zmm21, zmm26 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_even: // banker's: round half to even +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpsrlq zmm22, zmm20, xmm18 + vpandq zmm22, zmm22, zmm17 + vpaddq zmm20, zmm20, zmm22 + vpsubq zmm20, zmm20, zmm17 + + vpsrlq zmm22, zmm21, xmm18 + vpandq zmm22, zmm22, zmm17 + vpaddq zmm21, zmm21, zmm22 + vpsubq zmm21, zmm21, zmm17 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_scale_rounding_odd: // round half to odd +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsrldq zmm21, zmm20, 4 + vpmuldq zmm20, zmm20, zmm16 + vpmuldq zmm21, zmm21, zmm16 + + vpsrlq zmm22, zmm20, xmm18 + vpandq zmm22, zmm22, zmm17 + vpsubq zmm20, zmm20, zmm22 + + vpsrlq zmm22, zmm21, xmm18 + vpandq zmm22, zmm22, zmm17 + vpsubq zmm21, zmm21, zmm22 + + vpaddq zmm20, zmm20, zmm19 + vpaddq zmm21, zmm21, zmm19 + + vpsrlq zmm20, zmm20, xmm18 + vpsrlq zmm21, zmm21, xmm18 + + vpslldq zmm21, zmm21, 4 + vpblendmd zmm20{k7}, zmm21, zmm20 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shl: + mov eax, [rdi + 8] // -shift (count: i32) + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 + {% for i in range(0, 16) %}vpsllvd zmm{{i}}, zmm{{i}}, zmm16 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr: + mov r8, [rdi + 16] // policy + + mov eax, 1 + vmovd xmm16, eax + vpbroadcastd zmm16, xmm16 // zmm16 <- 1 (i32 lanes) + + mov eax, [rdi + 8] // shift + vmovd xmm17, eax + vpbroadcastd zmm17, xmm17 // zmm17 <- shift (i32 lanes) + + mov ebx, 1 + mov cl, al + sub cl, 1 + sal ebx, cl // ebx <- 1 << (shift - 1) + vmovd xmm18, ebx + vpbroadcastd zmm18, xmm18 // zmm18 <- "half" + + vpxorq zmm19, zmm19, zmm19 // zeroes + + cmp r8, 1 + je {{L}}q_shr_rounding_zero + cmp r8, 2 + je {{L}}q_shr_rounding_away + cmp r8, 3 + je {{L}}q_shr_rounding_minus_inf + cmp r8, 4 + je {{L}}q_shr_rounding_plus_inf + cmp r8, 5 + je {{L}}q_shr_rounding_even + cmp r8, 6 + je {{L}}q_shr_rounding_odd + + jmp {{L}}unsupported + +{{L}}q_shr_rounding_zero: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsubd zmm20, zmm20, zmm16 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_away: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_minus_inf: +{% for i in range(0, 16) %} + vpsubd zmm{{i}}, zmm{{i}}, zmm16 + vpaddd zmm{{i}}, zmm{{i}}, zmm18 + vpsravd zmm{{i}}, zmm{{i}}, zmm17 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_plus_inf: +{% for i in range(0, 16) %} + vpaddd zmm{{i}}, zmm{{i}}, zmm18 + vpsravd zmm{{i}}, zmm{{i}}, zmm17 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_even: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsravd zmm21, zmm20, zmm17 + vpandq zmm21, zmm21, zmm16 + vpsubd zmm21, zmm21, zmm16 // nudge = ((abs >>l shift) & 1) - 1 + vpaddd zmm20, zmm20, zmm21 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +{{L}}q_shr_rounding_odd: +{% for i in range(0, 16) %} + vpabsd zmm20, zmm{{i}} + vpsravd zmm21, zmm20, zmm17 + vpandq zmm21, zmm21, zmm16 + vpsubd zmm21, zmm19, zmm21 // nudge = -((abs >>l shift) & 1) + vpaddd zmm20, zmm20, zmm21 + vpaddd zmm20, zmm20, zmm18 + vpsravd zmm20, zmm20, zmm17 + // emulate AVX2 vpsignd (no AVX-512 form): apply sign of original acc. + vpxorq zmm26, zmm26, zmm26 + vpcmpgtd k1, zmm26, zmm{{i}} // k1 = 1 where acc < 0 + vpsubd zmm27, zmm26, zmm20 // zmm27 = -zmm20 + vpblendmd zmm{{i}}{k1}, zmm20, zmm27 +{% endfor %} + jmp {{L}}non_linear_loop + +// ---- Store --------------------------------------------------------------- + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rdx, [rdi + 24] // col stride + mov rcx, [rdi + 32] // item size + + cmp rcx, 4 + je {{L}}store_strides_i32 + // else: i8 fallthrough + + cmp rdx, 1 + je {{L}}store_strides_i8_row_contig + + // Generic i8 strided store: per row, per lane scalar byte stores + {% for m in range(0, 16) %} + mov r10, r8 + // Extract from each 128-bit slice of zmm{{m}} + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov byte ptr [r10], bl + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i8_row_contig: + // Each row is 16 i8 contiguous; one vpmovdb per row. + {% for m in range(0, 16) %} + vpmovdb [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32: + cmp rdx, 4 + je {{L}}store_strides_i32_row_contig + + // Generic i32 strided store + {% for m in range(0, 16) %} + mov r10, r8 + vextracti32x4 xmm20, zmm{{m}}, 0 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 1 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 2 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + vextracti32x4 xmm20, zmm{{m}}, 3 + {% for n in range(0, 4) %} + vpextrd ebx, xmm20, {{n}} + mov dword ptr [r10], ebx + add r10, rdx + {% endfor %} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}store_strides_i32_row_contig: + // C is row-major in memory: each row's 16 i32 are contiguous; one + // 64-byte aligned-or-unaligned store per row. + {% for m in range(0, 16) %} + vmovdqu32 [r8], zmm{{m}} + add r8, rsi + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}return: + ldmxcsr [rsp + 4] + add rsp, 8 + + pop r15 + pop r14 + pop r13 + pop r12 + pop rbx + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + +// ---- Read-only data (RIP-relative) --------------------------------------- + +.p2align 6 +{{L}}lane_offsets_64: + .int 0, 64, 128, 192, 256, 320, 384, 448 + .int 512, 576, 640, 704, 768, 832, 896, 960 + +.p2align 6 +{{L}}lane_indices: + .int 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + +.p2align 6 +{{L}}all_ones_i32: + .int 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + +{% if msvc %} +avx512vnni_mmm_i32_16x16_{{suffix}} endp +_text ends +end +{% else %} +.cfi_endproc +{% endif %} From 3b78fc8b43860e0efebd84e1492c3811a0b37a54 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:11:01 +0000 Subject: [PATCH 17/21] linalg/x86_64: document the int8 GEMM kernel cascade Maintainer note covering the AVX2 / AVX-512-VNNI (8x8 + 16x16) / AVX-VNNI / AMX (8x8 + 16x16 int8, 16x16 bf16) kernel family: the u8 x s8 +128 bias trick, the PackedI8K4 / PackedAmxA / bf16 packing layouts, the build.rs assembler-probe cfg gates (tract_amx_int8 / tract_amx_bf16 / tract_avxvnni), the plug() and qmmm_i32 dispatch cascade with the einsum scorer boost values, the testing model and why the AMX sub-handler bug stayed hidden (kernel tests skip when the host CPU lacks the feature), and a short follow-up list. https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV --- linalg/X86_64_INT8_GEMM.md | 135 +++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 linalg/X86_64_INT8_GEMM.md diff --git a/linalg/X86_64_INT8_GEMM.md b/linalg/X86_64_INT8_GEMM.md new file mode 100644 index 0000000000..c960b5f8e5 --- /dev/null +++ b/linalg/X86_64_INT8_GEMM.md @@ -0,0 +1,135 @@ +# x86_64 int8 GEMM kernels + +This note documents the int8 (i32-accumulator) matrix-multiply kernel family for +x86_64, for maintainers touching `linalg/src/x86_64_fma/mmm.rs` (Rust +registration + dispatch) and `linalg/x86_64/fma/*.S.j2` (assembly templates). + +The kernels form a throughput cascade from the portable AVX2 emulation up to +Intel AMX, with AVX-512-VNNI in between. The right kernel is chosen at runtime +from CPUID + (for selection among ties) the einsum kernel scorer. + +## Kernel family + +| Kernel | ISA | Tile M×N | Matmul instr | A packing | B packing | Build gate | +|---|---|---|---|---|---|---| +| `avx2_mmm_i32_8x8` | AVX2 | 8×8 (ymm) | `vpmaddubsw` emulation | `PackedFormat` i8 | `PackedFormat` i8 | always | +| `avx512vnni_mmm_i32_8x8` | AVX-512-VNNI | 8×8 (ymm) | `vpdpbusd` | `PackedI8K4(8)` | `PackedI8K4(8)` | always | +| `avx512vnni_mmm_i32_16x16` | AVX-512-VNNI | 16×16 (zmm) | `vpdpbusd` ×16 rows | `PackedI8K4(16)` | `PackedI8K4(16)` | always | +| `avxvnni_mmm_i32_8x8` | AVX-VNNI (VEX) | 8×8 (ymm) | `{vex} vpdpbusd` | `PackedI8K4(8)` | `PackedI8K4(8)` | `tract_avxvnni` | +| `avx512amx_mmm_i32_8x8` | AMX-INT8 | 8×8 | `tdpbssd` | `PackedAmxA(8)` | `PackedI8K4(8)` | `tract_amx_int8` | +| `avx512amx_mmm_i32_16x16` | AMX-INT8 | 16×16 | `tdpbssd` (16384 MACs) | `PackedAmxA(16)` | `PackedI8K4(16)` | `tract_amx_int8` | +| `avx512amx_mmm_f32_16x16` (f32) | AMX-BF16 | 16×16 | `tdpbf16ps` | `PackedAmxBf16A(16)` | `PackedBf16K2(16)` | `tract_amx_bf16` | + +The two AVX-512-VNNI kernels and the AVX2 one are always compiled (their +mnemonics are in every supported binutils); the AMX and AVX-VNNI kernels are +behind assembler-probe cfgs (see below). + +## The u8×s8 `+128` bias trick (VNNI / AVX-VNNI) + +`vpdpbusd` is **u8 × s8** (unsigned first operand). To compute the s8×s8 product +we need, the kernel offsets the A bytes by `+128` (modular `vpaddb`, making them +u8 in `[0,255]`) and then removes the resulting per-column bias +`128 * sum_k(B[n])` after the K loop. The bias is accumulated cheaply during the +loop with a `vpdpbusd` against an all-`0x01` u8 vector. + +- **8×8 (ymm)** accumulators are *column-major* (`ymm{n}` = column n), so the + bias is computed per column and splatted back with `vpermd`. +- **16×16 (zmm)** accumulators are *row-major* (`zmm{m}` = row m, 16 columns in + the 16 lanes). The per-column bias is then a single lane-aligned vector, so the + correction is one `vpsubd` per row — cleaner and cheaper than the 8×8 path. + +AMX `tdpbssd` is **s8 × s8**, so the AMX int8 kernels need no `+128` trick; their +i32 accumulators are bit-identical to the AVX2 / VNNI reference. + +## Packing formats (see `linalg/src/frame/pack.rs`) + +- **`PackedI8K4(r)`** — K=4-inner. Per K=4 block, `r` elements × 4 K-bytes (= `4r` + bytes); element `e` sits at byte offset `e*4` holding `[e, 4kb..4kb+3]`. K is + zero-padded to a multiple of 4, so kernels read `ceil(k/4)` full blocks safely. +- **`PackedAmxA(r)`** — AMX A layout: per panel of `r` M-rows, row-major within + the panel, K-bytes contiguous, K padded to a multiple of 64 (one `tdpbssd` step + consumes 64 K-bytes). +- **`PackedAmxBf16A` / `PackedBf16K2`** — f32 inputs truncated to bf16 at pack + time (round-to-nearest-even, matching `VCVTNEPS2BF16`) for the AMX-BF16 f32 path. + +## Build-time cfg gating (`linalg/build.rs`) + +Some mnemonics are too new for old toolchains, so each is guarded by an +**assembler probe** that tries to compile a tiny dummy `.S`. The probe sets a cfg +that gates *both* compiling the kernel template and referencing its Rust symbol: + +| cfg | enables | requires | +|---|---|---| +| `tract_amx_int8` | AMX int8 kernels (`tdpbssd`) | gas ≥ 2.34 | +| `tract_amx_bf16` | AMX bf16 kernel (`tdpbf16ps`) | gas ≥ 2.34 | +| `tract_avxvnni` | AVX-VNNI ymm kernel (`{vex}` prefix) | binutils ≥ 2.36 | + +Kernel `.S.j2` templates are sorted by filename prefix in `build.rs`: +`avx512amx_*_i32_*` and `*_f32_*` are pulled into their own gated compiles; +`avxvnni_*` likewise; everything else (including `avx512vnni_*`) stays in the +generic `-mfma` bulk compile. **A new `avx512vnni_*` kernel needs no `build.rs` +change** — but note that adding a brand-new template file may not trigger a +`build.rs` re-run on an incremental build (it emits per-file `rerun-if-changed`), +so `touch linalg/build.rs` after creating one. + +These cfgs reflect **assembler** capability, not the host CPU. A kernel can be +*compiled* (assembler supports the mnemonic) yet never *run* (CPU lacks the +feature) — which matters for tests (below). + +## Dispatch + +`plug()` installs kernels in nested feature order, richest ISA last: + +``` +avx2 → [avxvnni] → fma → avx512f → avx512vnni → [amx_int8] (int8 path) + → [amx_bf16 overlay] (f32 path) +``` + +Each `plug_*` pushes kernels into `ops.mmm_impls` and may set the explicit int8 +picker `ops.qmmm_i32`. Because later plugs overwrite `qmmm_i32`, the best +available ISA wins. The pickers are **shape-adaptive**: the 16×16 tile is the +throughput champion when both M and N fill at least one tile; the 8×8 kernel has +lower per-call setup and wins on small problems. (AMX additionally requires +K ≥ 64; VNNI has no K gate since one `vpdpbusd` step is just 4 K-bytes.) + +For paths that don't go through `qmmm_i32` (symbolic / unknown shapes via the +einsum kernel scorer), selection among equal-quality kernels uses +`-quality_cost*1000 + boost`. All `ManuallyOptimized` kernels tie on quality, so +`boost` breaks the tie: + +| Kernel | boost | +|---|---| +| `avx512amx_mmm_i32_16x16`, `avx512amx_mmm_f32_16x16` | 100 | +| `avx512vnni_mmm_i32_16x16` | 50 | +| all 8×8 kernels | 0 | + +So for unknown shapes: AMX 16×16 ≻ VNNI 16×16 ≻ {VNNI/AMX 8×8}. When AMX is +absent, VNNI 16×16 is the int8 champion. + +## Testing and a cautionary tale + +`MMMExternKernel!` auto-generates a `#[cfg(test)] mod test_` with +packed-packed (per packing), fused-op frame, quant-rounding, store, and proptest +coverage. The harness **skips a kernel when `ker.is_supported_here()` is false** +(runtime CPUID). Consequently **AMX kernel tests only run on AMX hardware.** + +The usual dev/CI host is Cascade Lake-class (AVX-512-VNNI, no AMX), so the AMX +tests are skipped there. That let a swapped-operand bug in the AMX 16×16 `sub` +fused-op handlers (`scalar_sub` / `per_row_sub` / `per_col_sub` and their +`_flipped` twins computed `acc - operand` instead of the correct `operand - acc`) +go unnoticed — until `avx512vnni_mmm_i32_16x16`, which **reuses the same zmm +row-major epilogue** and *does* run on VNNI hardware, exposed it (negated +results). Takeaway: a VNNI kernel that shares an AMX kernel's epilogue effectively +becomes the on-hardware test for that shared epilogue. The convention for the +non-commutative `sub` lives in `linalg/x86_64/fma/fma_mmm_ymm_ops.j2` +(`scalar` / `per_row` / `per_col` macros, `flipped` flag). + +## Possible follow-ups + +- A dispatch integration test asserting `qmmm_i32` selects the 16×16 kernel for + large M,N and the 8×8 for small (no precedent for kernel-selection asserts + in-tree yet; would need a small helper to read back the chosen `MatMatMul`). +- On Sapphire Rapids+ hardware: validate the AMX `sub` fix end-to-end, benchmark + the AMX kernels, and re-check the 16×16/8×8 crossover and the `boost` values. +- A wider AVX-512-BF16 (`vdpbf16ps`) f32 kernel for Cooper Lake-class cores, and + a Q4_0/Q8_0 → s8 packer feeding the AMX/VNNI 16×16 path directly. From e0bacc252a60bd946204c7e90e1a22be7cdbf7a7 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Tue, 2 Jun 2026 18:35:45 +0000 Subject: [PATCH 18/21] linalg/x86_64: add AMX validation + benchmark runbook for AMX-capable hosts Self-contained runbook for a session on a CPU with Intel AMX. Tasks that session to benchmark every int8/bf16 GEMM kernel in the tree -- the AMX kernels (int8 8x8 + 16x16, bf16 16x16) and the improved AVX-512-VNNI kernels (8x8 + the new zmm 16x16) -- and to run the AMX correctness suite, which validates the AMX 16x16 sub fused-op bugfix that could not be exercised on the non-AMX dev box. Covers: AMX prerequisites (CPUID amx_*, kernel >= 5.16 for the arch_prctl XTILEDATA permission), the gotcha that AMX kernel tests silently no-op (report "ok") when the host can't run AMX, using the benches as the authoritative runtime gate-check, exact test/bench commands, the bench column layout, the head-to-head comparisons to report (AMX 16x16 vs VNNI 16x16 etc.), a one-shot script, and a note that Intel SDE can emulate AMX for correctness but not perf. https://claude.ai/code/session_015PiiJ2Ave1m7PsXbnMPpBV --- linalg/AMX_BENCH_RUNBOOK.md | 211 ++++++++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 linalg/AMX_BENCH_RUNBOOK.md diff --git a/linalg/AMX_BENCH_RUNBOOK.md b/linalg/AMX_BENCH_RUNBOOK.md new file mode 100644 index 0000000000..c5fde39c07 --- /dev/null +++ b/linalg/AMX_BENCH_RUNBOOK.md @@ -0,0 +1,211 @@ +# AMX validation & benchmark runbook + +**For: a Claude Code session (or human) on an x86_64 CPU that has Intel AMX.** + +The kernel work on branch `claude/zealous-galileo-fEQ3d` was developed on a +Cascade Lake-class container (AVX-512-VNNI, **no AMX**). Everything that can run +without AMX is already validated there. This runbook covers the two things that +box **could not** do and that need a real AMX CPU. + +## Your task + +**Benchmark every int8 / bf16 GEMM kernel in this tree on this AMX CPU — all the +AMX kernels *and* the AVX-512-VNNI kernels we just improved — and run the AMX +correctness suite.** Full kernel inventory to cover: + +| Kernel | ISA | Covered by bench | +|---|---|---| +| `avx512amx_mmm_i32_8x8` | AMX int8 (`tdpbssd`) | `amx_i32` | +| `avx512amx_mmm_i32_16x16` | AMX int8 (`tdpbssd`) | `amx_i32` | +| `avx512amx_mmm_f32_16x16` | AMX bf16→f32 (`tdpbf16ps`) | `amx_f32` | +| `avx512vnni_mmm_i32_8x8` | AVX-512-VNNI (`vpdpbusd`) | `vnni_i32`, `amx_i32` | +| **`avx512vnni_mmm_i32_16x16`** ← new | AVX-512-VNNI (`vpdpbusd`, zmm) | `vnni_i32` | +| `avx2_mmm_i32_8x8` (baseline) | AVX2 | both i32 benches | + +Running the three benches in Step 4 covers all of the above. Yes — bench the VNNI +kernels here too: an AMX CPU (Sapphire Rapids+) also has AVX-512-VNNI, so it's the +one place you can measure AMX 16×16 and VNNI 16×16 **on the same silicon** and see +how much AMX actually wins. + +In addition, this AMX CPU is the only place that can: + +1. **Correctness-test the AMX kernels** — including a recent bugfix to the AMX + 16×16 `sub` fused-op handlers that was invisible on non-AMX hardware. +2. **Benchmark** the AMX int8 / bf16 kernels and the new AVX-512-VNNI 16×16 + kernel head-to-head. + +> ⚠️ **Most important caveat:** every AMX kernel test short-circuits to "ok" when +> the host can't run AMX (`is_supported_here()` is false). So a green +> `cargo test` on the wrong box proves **nothing**. You must first confirm AMX is +> actually live (Step 2). The **benchmarks are the authoritative gate-check** — +> they print an explicit "AMX … not available, skipping" message and emit no AMX +> columns if the gate is closed. + +--- + +## 0. Prerequisites + +| Requirement | Why | Check | +|---|---|---| +| AMX-capable CPU (Sapphire Rapids / Emerald Rapids / Granite Rapids Xeon, or Xeon Max) | `tdpbssd` / `tdpbf16ps` | `grep -o 'amx[_a-z]*' /proc/cpuinfo \| sort -u` → expect `amx_bf16 amx_int8 amx_tile` | +| Linux kernel ≥ 5.16 | AMX tile-data XSAVE permission via `arch_prctl(ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)` | `uname -r` | +| binutils/gas ≥ 2.34 (≥ 2.36 ideal) | assembles AMX mnemonics (and `{vex}` for AVX-VNNI) | `as --version` | +| Rust stable (dev used 1.94–1.96) | build | `cargo --version` | + +If `/proc/cpuinfo` shows no `amx_*` flags, this is the wrong machine — stop here. + +--- + +## 1. Get the code + +**Fresh clone (preferred):** +```sh +git clone https://github.com/czoli1976/tract.git +cd tract +git checkout claude/zealous-galileo-fEQ3d +``` + +**Existing checkout:** +```sh +git fetch origin claude/zealous-galileo-fEQ3d +git checkout claude/zealous-galileo-fEQ3d && git pull +# IMPORTANT when pulling into a checkout that was built before: the new kernel +# template (avx512vnni_mmm_i32_16x16.S.j2) may not trigger a build-script rerun +# (build.rs emits per-file rerun-if-changed). Force it once: +touch linalg/build.rs +``` +(A fresh clone needs no `touch` — it renders every template on first build.) + +--- + +## 2. Confirm AMX is actually live (do this first) + +The AMX kernels are gated by CPUID **and** the kernel granting tile-data XSAVE +permission. The benchmark is the cleanest runtime probe — if AMX is unavailable +it prints a skip line instead of numbers: + +```sh +cargo bench -p tract-linalg --bench amx_i32 -- --warm-up-time 0.2 --measurement-time 0.5 --sample-size 10 2>&1 | head -20 +``` + +- ✅ **Good:** you see `avx512amx_8x8` and `avx512amx_16x16` lines with `thrpt:`. +- ❌ **Bad:** `AMX int8 not available (CPUID + arch_prctl gate failed), skipping` + → AMX isn't usable (check kernel ≥ 5.16, not in a VM that masks AMX, XSAVE + permission not blocked by a seccomp/container policy). Don't proceed — the + correctness tests would silently no-op. + +Optional: `RUST_LOG=info cargo test -p tract-linalg --lib avx512amx_mmm_i32_16x16 -- --nocapture 2>&1 | grep -i activated` +should log `qmmm_i32: x86_64/avx512amx_int8 (16x16 + 8x8 adaptive) activated`. + +--- + +## 3. Correctness validation (the priority) + +Only meaningful once Step 2 confirms AMX is live. + +```sh +# All three AMX kernel suites: int8 8x8, int8 16x16, bf16 16x16. +cargo test -p tract-linalg --lib avx512amx 2>&1 | tail -30 + +# Full x86_64 mmm suite (AMX + VNNI + AVX2 + FMA + AVX-512), for completeness. +cargo test -p tract-linalg --lib x86_64_fma::mmm 2>&1 | tail -5 +``` + +**Expected:** `test result: ok. passed; 0 failed`. + +**What this specifically proves (and the dev box couldn't):** the +`scalar_sub` / `per_row_sub` / `per_col_sub` (+ `_flipped`) fused-op tests for +`test_avx512amx_mmm_i32_16x16` and `test_avx512amx_mmm_f32_16x16` **actually +execute**. Those guard commit `99eb75b9d`, which fixed swapped operands in the +AMX `sub` handlers (they were computing `acc − operand` instead of +`operand − acc`, i.e. negated results). This fix is currently only +build-verified — **this run is what confirms it on real silicon.** + +--- + +## 4. Benchmarks + +On real hardware use default sampling (drop the reduced flags) and pin a core for +stable numbers. Idle box, turbo/frequency-scaling fixed if you can. + +```sh +# int8: AVX2 vs VNNI 8x8 vs AMX 8x8 vs AMX 16x16 +taskset -c 2 cargo bench -p tract-linalg --bench amx_i32 + +# f32 via bf16: FMA 16x6 vs AVX-512 16x12 vs AMX-BF16 16x16 +taskset -c 2 cargo bench -p tract-linalg --bench amx_f32 + +# the new kernel in isolation: AVX2 vs VNNI 8x8 vs VNNI 16x16 +taskset -c 2 cargo bench -p tract-linalg --bench vnni_i32 +``` + +Bench layout (group `… /packed_packed`, shapes `64x256x64`, `256x256x256`, +`512x512x512`, `1024x1024x64`, throughput in `Gelem/s`): + +| Bench | Columns | +|---|---| +| `amx_i32` | `avx2`, `avx512vnni`, `avx512amx_8x8`, `avx512amx_16x16` | +| `amx_f32` | `fma_16x6`, `avx512_16x12`, `avx512amx_bf16_16x16` | +| `vnni_i32` | `avx2`, `avx512vnni` (8×8), `avx512vnni_16x16` | + +Criterion writes HTML reports under `target/criterion/`. + +--- + +## 5. What to report back + +**Correctness** +- Confirm AMX was live (Step 2 showed AMX columns / cpuinfo has `amx_int8`). +- `cargo test … avx512amx` result line (`N passed; 0 failed`), confirming the + AMX `*_sub` fused-op tests passed → bugfix `99eb75b9d` validated on hardware. + +**Performance** — the `thrpt:` (Gelem/s) per shape per column for all three +benches, plus these head-to-head reads: + +1. **AMX 16×16 vs VNNI 16×16** (compare `amx_i32`'s `avx512amx_16x16` against + `vnni_i32`'s `avx512vnni_16x16`, same shapes). AMX should win — that justifies + the dispatch ordering (`boost(100)` for AMX 16×16 > `boost(50)` for VNNI + 16×16). Report the ratio. +2. **AMX 16×16 vs AMX 8×8** — the 4×-work-per-instruction claim and where 8×8 + wins on small shapes (informs the `qmmm_i32` 16/8 crossover). +3. **VNNI 16×16 vs 8×8** — does the ~1.3–2.1× measured on Cascade Lake hold on + this CPU too? +4. **AMX-BF16 16×16 vs AVX-512 f32 16×12** — bf16 throughput win (with the bf16 + precision trade-off noted in `linalg/X86_64_INT8_GEMM.md`). + +--- + +## Appendix A — one-shot script + +```sh +set -e +echo "## CPU AMX flags:"; grep -o 'amx[_a-z]*' /proc/cpuinfo | sort -u || true +echo "## kernel:"; uname -r +echo "## gate check (expect AMX columns, not a skip message):" +cargo bench -p tract-linalg --bench amx_i32 -- --warm-up-time 0.2 --measurement-time 0.5 --sample-size 10 2>&1 | grep -iE "amx|skipping|thrpt" | head +echo "## correctness:" +cargo test -p tract-linalg --lib avx512amx 2>&1 | tail -3 +cargo test -p tract-linalg --lib x86_64_fma::mmm 2>&1 | tail -3 +echo "## full benches:" +taskset -c 2 cargo bench -p tract-linalg --bench amx_i32 +taskset -c 2 cargo bench -p tract-linalg --bench amx_f32 +taskset -c 2 cargo bench -p tract-linalg --bench vnni_i32 +``` + +## Appendix B — what's on this branch + +Three commits on top of the prior AMX/VNNI work: + +| Commit | Summary | +|---|---| +| `9e8f1c5aa` | doc: `linalg/X86_64_INT8_GEMM.md` — the full int8 GEMM kernel cascade | +| `26726db8e` | **feat**: `avx512vnni_mmm_i32_16x16` — zmm-wide int8 VNNI kernel (1.3–2.1× over 8×8 on Cascade Lake) | +| `99eb75b9d` | **fix**: swapped operands in AMX 16×16 `sub` fused-op handlers (int8 + bf16) — **needs AMX to validate** | + +Background and the kernel-selection/dispatch model: see +`linalg/X86_64_INT8_GEMM.md`. + +> Note on Intel SDE: SDE *can* emulate AMX for **functional/correctness** checks +> on a non-AMX box (`sde64 -spr -- `), but it is **not** a +> performance model — timings under SDE are meaningless. Use it only if no AMX +> hardware is available, and never for the benchmark numbers above. From fbf7dcaed352dbb1ee8a55aa86f7a4056ca108fc Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Wed, 3 Jun 2026 06:44:52 +0000 Subject: [PATCH 19/21] linalg/x86_64: add AMX int8/bf16 GEMM validation + benchmark results Results from running linalg/AMX_BENCH_RUNBOOK.md on an AMX-capable Xeon (2026-06-02): AMX-live confirmation; correctness (bugfix 99eb75b9d validated on silicon; the 3 bf16 test failures root-caused to an f32-grade harness tolerance and empirically verified against a bf16 reference -- not a kernel defect); the three int8/bf16 throughput tables; and the four head-to-head ratios. Includes a reproducibility note (the AMX host was later reclaimed). https://claude.ai/code/session_018Hes6yEvk2TSWB26SAJfqT --- linalg/AMX_BENCH_RESULTS.md | 85 +++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 linalg/AMX_BENCH_RESULTS.md diff --git a/linalg/AMX_BENCH_RESULTS.md b/linalg/AMX_BENCH_RESULTS.md new file mode 100644 index 0000000000..b1261e61d3 --- /dev/null +++ b/linalg/AMX_BENCH_RESULTS.md @@ -0,0 +1,85 @@ +# AMX validation & benchmark results + +Run of `linalg/AMX_BENCH_RUNBOOK.md` on real Intel AMX hardware. + +- **Host:** `Intel(R) Xeon(R) Processor @ 2.10GHz` (Sapphire/Emerald Rapids-class), 4 vCPU +- **ISA:** `amx_tile amx_int8 amx_bf16` + AVX-512-VNNI; kernel `6.18.5` (≥5.16); binutils `2.42`; rustc `1.94.1` +- **Branch:** `claude/zealous-galileo-fEQ3d` @ `7a23812` +- **Method:** `cargo bench`, default criterion sampling, pinned to core 2 (`taskset -c 2`), idle box (load ≈ 1.0) +- **Date:** 2026-06-02 + +## 1. AMX live confirmation ✅ + +Gate-check (`amx_i32` bench) produced `avx512amx_8x8`/`avx512amx_16x16` columns with real `thrpt:` numbers — **neither** "tract not built with AMX" (build probe) **nor** "AMX not available, skipping" (runtime CPUID + `arch_prctl` XTILEDATA gate) appeared. AMX is genuinely exercised. + +## 2. Correctness + +| Suite | Result | +|---|---| +| `cargo test -p tract-linalg --lib avx512amx` | **297 passed; 3 failed** | +| `cargo test -p tract-linalg --lib x86_64_fma::mmm` | **1833 passed; 3 failed** | + +**Bugfix `99eb75b9d` VALIDATED on silicon** ✅ — every `scalar_sub` / `per_row_sub` / `per_col_sub` (+`_f`) test passed for **both** `avx512amx_mmm_i32_16x16` and `avx512amx_mmm_f32_16x16`. + +**3 failures — all in the AMX bf16 path** (`avx512amx_mmm_f32_16x16::f32f32_bf16`): `fuse::prop`, `frame::prop`, `fuse::packed_packed_bug_3`. + +**Root cause = test-harness tolerance, NOT a kernel defect.** `packed_packed.rs:367` selects the comparison tolerance from the **accumulator** dtype: +```rust +let app = if K::Acc::datum_type() == f16::datum_type() + { Approximation::SuperApproximate } else { Approximation::Approximate }; +``` +This kernel accumulates in **f32** (TDPBF16PS: bf16×bf16→f32), so it gets `Approximate` = `(atol 1e-4, rtol 5e-4, 0 outliers)` — but the `f32f32_bf16` packing truncates inputs to bf16 (~2⁻⁸ ≈ 0.39% rel). bf16-grade error is checked against an f32-grade bar with zero tolerated outliers ⇒ guaranteed failure. `SuperApproximate` `(atol 0.1, rtol 0.05, 1e-4 outliers)` would pass. The structurally identical int8 16×16 kernel passes 100%. + +**Proposed fix:** in `check()`, pick `SuperApproximate` when the packing is bf16-based, not only when `K::Acc == f16`. + +**Empirically verified (on the AMX host):** the kernel was run on 7 cases (including the exact `bug_3` input) and compared against an independent **bf16-truncated** reference — built with the project's own `f32_to_bf16_rne` — judged by the *same* tight `Approximate` bar: **0 outliers across ~335k output elements** (max abs err ≤ 1.3e-5), versus **282,788 outliers** against a pure-f32 reference. The kernel reproduces "truncate inputs→bf16, accumulate→f32" exactly; the 3 red tests are 100% the f32 oracle, with no kernel defect. + +## 3. Benchmarks — throughput (Gelem/s, point estimate) + +### `amx_i32` — int8 GEMM +| M×K×N | avx2 | avx512vnni (8×8) | avx512amx_8×8 | avx512amx_16×16 | +|---|---:|---:|---:|---:| +| 64×256×64 | 0.41 | 11.21 | 68.41 | **233.64** | +| 256×256×256 | 0.41 | 11.31 | 68.47 | **237.29** | +| 512×512×512 | 0.39 | 8.94 † | 112.86 | **228.15** | +| 1024×1024×64 | 0.41 | 34.84 | 178.42 | **279.51** | + +### `amx_f32` — bf16→f32 GEMM +| M×K×N | fma_16×6 | avx512_16×12 | avx512amx_bf16_16×16 | +|---|---:|---:|---:| +| 64×256×64 | 37.12 | 64.31 | **207.35** | +| 256×256×256 | 37.90 | 71.90 | **225.74** | +| 512×512×512 | 39.37 | 64.69 | **348.38** | +| 1024×1024×64 | 36.85 | 59.22 | **318.36** | + +### `vnni_i32` — int8 GEMM (new 16×16 in isolation) +| M×K×N | avx2 | avx512vnni (8×8) | avx512vnni_16×16 | +|---|---:|---:|---:| +| 64×256×64 | 0.41 | 10.90 | **135.74** | +| 256×256×256 | 0.40 | 10.78 | **134.92** | +| 512×512×512 | 0.40 | 20.53 | **154.39** | +| 1024×1024×64 | 0.41 | 34.77 | **161.27** | + +† `avx512vnni`@512³ read 8.94 here vs 20.53 in `vnni_i32` (same kernel/shape). Treat **20.53** as the credible value (it fits the size trend 11.3→20.5→34.8); 8.94 was an outlier. A higher-sampling re-measure was attempted but could not complete — see §6. + +## 4. Head-to-head ratios + +| Comparison | 64×256×64 | 256×256×256 | 512×512×512 | 1024×1024×64 | +|---|---:|---:|---:|---:| +| **AMX 16×16 ÷ VNNI 16×16** (int8, same CPU) | 1.72× | 1.76× | 1.48× | 1.73× | +| **AMX 16×16 ÷ AMX 8×8** (int8) | 3.42× | 3.47× | 2.02× | 1.57× | +| **VNNI 16×16 ÷ VNNI 8×8** (int8) | 12.45× | 12.51× | 7.52× | 4.64× | +| **AMX bf16 16×16 ÷ AVX-512 f32 16×12** | 3.22× | 3.14× | 5.39× | 5.38× | +| *(bonus) AMX bf16 ÷ FMA f32 16×6* | 5.59× | 5.96× | 8.85× | 8.64× | + +## 5. Findings + +1. **AMX int8 16×16 wins everywhere — justifies `boost(100)` > VNNI `boost(50)`.** 1.48–1.76× over the new VNNI 16×16 on the *same* silicon. Dispatch ordering is correct. +2. **AMX 16×16 vs 8×8: 1.57–3.47×.** 16×16 leads on all tested shapes; the 4×-work/instr advantage is largest on compact shapes (3.4× @ 64×256×64) and narrowest on tall-skinny 1024×1024×64 (1.57×, N=64). No tested shape favors 8×8 — any crossover lives below this suite (smaller M or N<16). `qmmm_i32` defaulting to 16×16 here is sound. +3. **VNNI 16×16 vs 8×8: 4.64–12.5× — far above the dev box's 1.3–2.1×.** Likely the 8×8 kernel's ymm (256-bit) accumulators vs the new kernel's zmm (512-bit), amplified on Sapphire Rapids (no AVX-512 license downclock that Cascade Lake suffers). Strongly validates the new kernel; the magnitude warrants one sanity re-check (see #4). +4. **Data-quality flag (resolved by inspection):** `avx512vnni` 8×8 @ 512³ read 8.94 (in `amx_i32`) vs 20.53 (in `vnni_i32`) — a 2.3× swing on the same kernel/shape. **20.53 is the credible figure** (it continues the monotone size trend 11.3 @ 256³ → 20.5 @ 512³ → 34.8 @ 1024×1024×64; 8.94 breaks it). A `--sample-size 200` re-measure was launched but the AMX host was reclaimed before it could run (see §6); the ratio table already uses the consistent 20.53 pairing. AMX columns were stable across runs. +5. **AMX bf16 is 3.1–5.4× the AVX-512 f32 kernel** (5.6–8.9× over FMA), scaling up on larger shapes (348 Gelem/s @ 512³) — with the documented bf16 precision trade (see §2 and `X86_64_INT8_GEMM.md`). + +## 6. Reproducibility note + +Numbers were collected **2026-06-02** on an AMX-capable `Intel(R) Xeon(R) @ 2.10GHz` (`amx_tile/int8/bf16` + AVX-512-VNNI, kernel 6.18.5). The ephemeral session container was subsequently reclaimed and re-provisioned onto a different `Intel(R) Xeon(R) @ 2.80GHz` with **neither AMX nor AVX-512-VNNI** (only `avx512f`), on which `amx_i32`/`vnni_i32` both short-circuit and skip — so the one outstanding re-measure (VNNI-8×8 @ 512³) could not be completed in this session. To reproduce or extend, run on an AMX host (Sapphire Rapids / Emerald Rapids / Granite Rapids Xeon, or Xeon Max) following `linalg/AMX_BENCH_RUNBOOK.md`. From 51196c897d7bd516922118f37a48ab5a3417c9d2 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Wed, 3 Jun 2026 09:13:15 +0000 Subject: [PATCH 20/21] linalg/x86_64: rustfmt the AMX kernels, benches, and build probes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply the workspace rustfmt.toml (use_small_heuristics = "Max") to the AMX / AVX-VNNI additions so the slice passes `cargo fmt --check`, which upstream main already satisfies. Pure formatting — collapses call chains, if-conditions, and a fn signature onto single lines that fit the width. No functional change. https://claude.ai/code/session_018Hes6yEvk2TSWB26SAJfqT --- linalg/benches/amx_f32.rs | 9 +-------- linalg/benches/amx_i32.rs | 16 ++++++++++------ linalg/build.rs | 33 +++++++-------------------------- linalg/src/x86_64_fma/amx.rs | 4 ++-- linalg/src/x86_64_fma/mmm.rs | 6 ++---- 5 files changed, 22 insertions(+), 46 deletions(-) diff --git a/linalg/benches/amx_f32.rs b/linalg/benches/amx_f32.rs index bf59f066cf..1206bbde57 100644 --- a/linalg/benches/amx_f32.rs +++ b/linalg/benches/amx_f32.rs @@ -17,14 +17,7 @@ use criterion::*; use tract_data::internal::*; use tract_linalg::mmm::{AsInputValue, FusedSpec, MatMatMul}; -fn run_kernel( - be: &mut Bencher, - mmm: &dyn MatMatMul, - packing: usize, - m: usize, - k: usize, - n: usize, -) { +fn run_kernel(be: &mut Bencher, mmm: &dyn MatMatMul, packing: usize, m: usize, k: usize, n: usize) { let a = Tensor::zero_dt(DatumType::F32, &[m, k]).unwrap(); let b = Tensor::zero_dt(DatumType::F32, &[k, n]).unwrap(); let (pack_a, pack_b) = &mmm.packings()[packing]; diff --git a/linalg/benches/amx_i32.rs b/linalg/benches/amx_i32.rs index 79388ec05f..ae5f58ebc8 100644 --- a/linalg/benches/amx_i32.rs +++ b/linalg/benches/amx_i32.rs @@ -64,12 +64,16 @@ fn benches(c: &mut Criterion) { |b, &(m, k, n)| run_kernel(b, &*avx512vnni_mmm_i32_8x8.mmm(), m, k, n), ); } - g.bench_with_input(BenchmarkId::new("avx512amx_8x8", &id), &(m, k, n), |b, &(m, k, n)| { - run_kernel(b, &*avx512amx_mmm_i32_8x8.mmm(), m, k, n) - }); - g.bench_with_input(BenchmarkId::new("avx512amx_16x16", &id), &(m, k, n), |b, &(m, k, n)| { - run_kernel(b, &*avx512amx_mmm_i32_16x16.mmm(), m, k, n) - }); + g.bench_with_input( + BenchmarkId::new("avx512amx_8x8", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512amx_mmm_i32_8x8.mmm(), m, k, n), + ); + g.bench_with_input( + BenchmarkId::new("avx512amx_16x16", &id), + &(m, k, n), + |b, &(m, k, n)| run_kernel(b, &*avx512amx_mmm_i32_16x16.mmm(), m, k, n), + ); g.finish(); } } diff --git a/linalg/build.rs b/linalg/build.rs index c7770f1c03..81a3a0baf6 100644 --- a/linalg/build.rs +++ b/linalg/build.rs @@ -345,11 +345,7 @@ fn main() { let tmpl16 = path::Path::new("x86_64/fma/avx512vnni_mmm_i32_16x16.S.j2"); let out16 = out_dir.join(format!("avx512vnni_mmm_i32_16x16_{suffix}.S")); preprocess_file(tmpl16, &out16, &[], &suffix, false); - cc::Build::new() - .file(&out) - .file(&out16) - .flag("-mfma") - .compile("x86_64_avx512vnni"); + cc::Build::new().file(&out).file(&out16).flag("-mfma").compile("x86_64_avx512vnni"); println!("cargo:rustc-cfg=tract_avx512vnni"); } @@ -360,13 +356,8 @@ fn main() { // reference: when the probe fails on old toolchains (e.g. Debian // stretch's binutils 2.28), the kernel is omitted and `qmmm_i32` // dispatch falls back to VNNI or AVX2 with no build error. - if os != "windows" - && !amx_int8_files.is_empty() - && assembler_supports_amx_int8() - { - cc::Build::new() - .files(&amx_int8_files) - .compile("x86_64_avx512amx"); + if os != "windows" && !amx_int8_files.is_empty() && assembler_supports_amx_int8() { + cc::Build::new().files(&amx_int8_files).compile("x86_64_avx512amx"); println!("cargo:rustc-cfg=tract_amx_int8"); } @@ -375,13 +366,8 @@ fn main() { // probe fails, the `tract_amx_bf16` cfg stays unset and // `plug_avx512amx_bf16` is compiled out — `mmm_f32` then falls // back to AVX-512 / FMA without any build error. - if os != "windows" - && !amx_bf16_files.is_empty() - && assembler_supports_amx_bf16() - { - cc::Build::new() - .files(&amx_bf16_files) - .compile("x86_64_avx512amx_bf16"); + if os != "windows" && !amx_bf16_files.is_empty() && assembler_supports_amx_bf16() { + cc::Build::new().files(&amx_bf16_files).compile("x86_64_avx512amx_bf16"); println!("cargo:rustc-cfg=tract_amx_bf16"); } @@ -390,13 +376,8 @@ fn main() { // (Alder Lake-E, Sierra Forest, Clearwater Forest / Darkmont) // that have AVX-VNNI but no AVX-512, falling back to AVX2 // emulation when the runtime CPUID detection misses. - if os != "windows" - && !avxvnni_files.is_empty() - && assembler_supports_avxvnni() - { - cc::Build::new() - .files(&avxvnni_files) - .compile("x86_64_avxvnni"); + if os != "windows" && !avxvnni_files.is_empty() && assembler_supports_avxvnni() { + cc::Build::new().files(&avxvnni_files).compile("x86_64_avxvnni"); println!("cargo:rustc-cfg=tract_avxvnni"); } } diff --git a/linalg/src/x86_64_fma/amx.rs b/linalg/src/x86_64_fma/amx.rs index 685a9618c8..9877468095 100644 --- a/linalg/src/x86_64_fma/amx.rs +++ b/linalg/src/x86_64_fma/amx.rs @@ -62,8 +62,8 @@ pub fn cache_sizes() -> CacheSizes { let partitions = ((r.ebx >> 12) & 0x3FF) + 1; let line_size = (r.ebx & 0xFFF) + 1; let sets = r.ecx + 1; - let bytes = (ways as usize) * (partitions as usize) - * (line_size as usize) * (sets as usize); + let bytes = + (ways as usize) * (partitions as usize) * (line_size as usize) * (sets as usize); // type=1 (data), type=3 (unified) for L1d / L2 / L3 match (level, cache_type) { (1, 1) => out.l1d_bytes = bytes, diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index e5943c5bc3..d862c893dc 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -283,10 +283,8 @@ pub fn plug_avx512amx_bf16(ops: &mut Ops) { // Save the previously-installed f32 picker so we can defer to it when // the AMX kernel isn't a good fit (small M/N, or K < 32 -- one TDPBF16PS // consumes 32 bf16 K-lanes so the panel must have at least one full step). - let prev: crate::MMMImpl = std::mem::replace( - &mut ops.mmm_f32, - Box::new(|_, _, _| unreachable!()), - ); + let prev: crate::MMMImpl = + std::mem::replace(&mut ops.mmm_f32, Box::new(|_, _, _| unreachable!())); ops.mmm_f32 = Box::new(move |m, k, n| { let big = |o: Option, t: usize| o.is_none_or(|v| v >= t); // Same dispatch shape as the int8 16x16/8x8 split: hand off to AMX From 495c5154afacba843b974886bc7c3184e252abb0 Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Wed, 3 Jun 2026 10:11:05 +0000 Subject: [PATCH 21/21] linalg/x86_64: build on the 1.91 MSRV and the cfg-off path The AMX / AVX-VNNI detection reads CPUID via std::arch::x86_64::__cpuid_count. That intrinsic is `unsafe` on tract's MSRV (rustc 1.91) but was made safe in a later release, so the calls compiled locally yet broke every 1.91 CI job with E0133. Wrap each call in `unsafe { }` (required on 1.91) and add `#[allow(unused_unsafe)]` so newer toolchains, where the call is safe, don't trip `unused_unsafe` under `-D warnings`. Also gate the `PackedI8K4` and `super::amx` imports in mmm.rs on the cfgs that actually use them, so the old-assembler build (Debian stretch, all kernel cfgs off) has no unused-import warnings. Verified: tract-linalg compiles on rustc 1.91.0 and 1.94, clippy-clean. https://claude.ai/code/session_018Hes6yEvk2TSWB26SAJfqT --- linalg/src/x86_64_fma/amx.rs | 6 ++++-- linalg/src/x86_64_fma/amx_bf16.rs | 3 ++- linalg/src/x86_64_fma/avxvnni.rs | 6 ++++-- linalg/src/x86_64_fma/mmm.rs | 5 ++++- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/linalg/src/x86_64_fma/amx.rs b/linalg/src/x86_64_fma/amx.rs index 9877468095..a830ac2ac2 100644 --- a/linalg/src/x86_64_fma/amx.rs +++ b/linalg/src/x86_64_fma/amx.rs @@ -52,7 +52,8 @@ pub fn cache_sizes() -> CacheSizes { *CACHE.get_or_init(|| { let mut out = CacheSizes::default(); for sub in 0..16 { - let r = std::arch::x86_64::__cpuid_count(4, sub); + #[allow(unused_unsafe)] + let r = unsafe { std::arch::x86_64::__cpuid_count(4, sub) }; let cache_type = r.eax & 0x1F; if cache_type == 0 { break; @@ -83,7 +84,8 @@ fn cpu_has_amx_int8() -> bool { if !std::is_x86_feature_detected!("avx512f") { return false; } - let r = std::arch::x86_64::__cpuid_count(7, 0); + #[allow(unused_unsafe)] + let r = unsafe { std::arch::x86_64::__cpuid_count(7, 0) }; // bit 24 = AMX-TILE, bit 25 = AMX-INT8 in EDX. const AMX_TILE: u32 = 1 << 24; const AMX_INT8: u32 = 1 << 25; diff --git a/linalg/src/x86_64_fma/amx_bf16.rs b/linalg/src/x86_64_fma/amx_bf16.rs index ba64cfa16a..abda445c63 100644 --- a/linalg/src/x86_64_fma/amx_bf16.rs +++ b/linalg/src/x86_64_fma/amx_bf16.rs @@ -30,7 +30,8 @@ fn cpu_has_amx_bf16() -> bool { if !std::is_x86_feature_detected!("avx512f") { return false; } - let r = std::arch::x86_64::__cpuid_count(7, 0); + #[allow(unused_unsafe)] + let r = unsafe { std::arch::x86_64::__cpuid_count(7, 0) }; const AMX_BF16: u32 = 1 << 22; const AMX_TILE: u32 = 1 << 24; (r.edx & AMX_BF16) != 0 && (r.edx & AMX_TILE) != 0 diff --git a/linalg/src/x86_64_fma/avxvnni.rs b/linalg/src/x86_64_fma/avxvnni.rs index 5d98637865..ee6af7d857 100644 --- a/linalg/src/x86_64_fma/avxvnni.rs +++ b/linalg/src/x86_64_fma/avxvnni.rs @@ -26,11 +26,13 @@ fn cpu_has_avxvnni() -> bool { if !std::is_x86_feature_detected!("avx2") { return false; } - let max_sub = std::arch::x86_64::__cpuid_count(7, 0).eax; + #[allow(unused_unsafe)] + let max_sub = unsafe { std::arch::x86_64::__cpuid_count(7, 0) }.eax; if max_sub < 1 { return false; } - let r = std::arch::x86_64::__cpuid_count(7, 1); + #[allow(unused_unsafe)] + let r = unsafe { std::arch::x86_64::__cpuid_count(7, 1) }; (r.eax & (1 << 4)) != 0 } diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index d862c893dc..1579ac5d37 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -2,8 +2,11 @@ use crate::Ops; use crate::block_quant::*; use crate::mmm::ImplementationQuality::ManuallyOptimized; use crate::mmm::MatMatMul; -use crate::pack::{PackedFormat, PackedI8K4}; +use crate::pack::PackedFormat; +#[cfg(any(tract_avx512vnni, tract_avxvnni, tract_amx_int8))] +use crate::pack::PackedI8K4; +#[cfg(tract_amx_int8)] use super::amx::{PackedAmxA, has_amx_int8}; #[cfg(tract_amx_bf16)] use super::amx_bf16::{PackedAmxBf16A, PackedBf16K2, has_amx_bf16};