From b8e8d291d119bba4be8baa6423b49e34b7ad2df4 Mon Sep 17 00:00:00 2001 From: Tim Burke Date: Sun, 15 Mar 2026 21:27:42 -0400 Subject: [PATCH] =?UTF-8?q?ggml:=20refactor=20x86=20AVX2=20and=20ARM=20NEO?= =?UTF-8?q?N=20MXFP=20dequant=20=E2=80=94=20shared=20traits=20and=20helper?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add mxfp_dequant_traits_t to ggml-common.h as single source of truth for MXFP IEEE-754 reconstruction parameters. Define static const instances for all 4 formats (E4M3, E5M2, E2M3, E3M2), ready for CUDA/Metal/Vulkan reuse. Extract shared dequant and FP6 unpack helpers on both architectures, replacing duplicated inline code and macros. Net -215 lines. --- ggml/src/ggml-common.h | 41 ++ ggml/src/ggml-cpu/arch/arm/quants.c | 662 ++++++++++++---------------- ggml/src/ggml-cpu/arch/x86/quants.c | 562 +++++++++-------------- ggml/src/ggml-cpu/ops.cpp | 10 +- ggml/src/ggml-quants.h | 2 +- 5 files changed, 531 insertions(+), 746 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 9945fef137..cc9a4a0aca 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -267,6 +267,47 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4 #define MXFP6_E3M2_MANT_SHIFT 21 // 23-2 #define MXFP6_E3M2_SUB_SCALE (1.0f/16.0f) // 2^(-4) = 2^(1-3-2) +// Unified MXFP dequantization traits for SIMD backends (CPU x86/ARM, CUDA, Metal, Vulkan). +// Contains all parameters needed for IEEE-754 bit reconstruction of FP8/FP6 elements. +// FP4 uses LUT-based dequant and does not need this struct. +typedef struct { + int exp_mask; // (1<> 0) & 0x3F; + u[1] = (pk >> 6) & 0x3F; + u[2] = (pk >> 12) & 0x3F; + u[3] = (pk >> 18) & 0x3F; + const uint8x8_t raw8 = vcreate_u8( + (uint64_t)u[0] | ((uint64_t)u[1] << 8) | + ((uint64_t)u[2] << 16) | ((uint64_t)u[3] << 24)); + return vmovl_u16(vget_low_u16(vmovl_u8(raw8))); +} + +// Widen 8 raw bytes to two uint32x4_t halves. +static inline void widen_u8x8_to_u32x4x2(const uint8_t * src, + uint32x4_t * lo, uint32x4_t * hi) { + const uint8x8_t raw8 = vld1_u8(src); + const uint16x8_t raw16 = vmovl_u8(raw8); + *lo = vmovl_u16(vget_low_u16(raw16)); + *hi = vmovl_u16(vget_high_u16(raw16)); +} + +// Widen 8 Q8_0 int8 values to two float32x4_t halves. 
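+// Used only by the FP8/FP6 vec_dot kernels below, to pair Q8_0 activations with the dequantized MXFP lanes.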
+static inline void widen_s8x8_to_f32x4x2(const int8_t * src, + float32x4_t * lo, float32x4_t * hi) { + const int8x8_t q8 = vld1_s8(src); + const int16x8_t q16 = vmovl_s8(q8); + *lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16))); + *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); +} + +// ── MXFP FP8/FP6 vec_dot ────────────────────────────────────────────────── + +static void ggml_vec_dot_mxfp8_q8_0_neon( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, - // FP8 format parameters: - const uint32_t exp_mask, // 0xF for E4M3, 0x1F for E5M2 - const uint32_t mant_mask, // 0x7 for E4M3, 0x3 for E5M2 - const int exp_shift, // 3 for E4M3, 2 for E5M2 - const uint32_t ieee_exp_off, // 120 for E4M3, 112 for E5M2 - const int mant_shift, // 20 for E4M3, 21 for E5M2 - const float sub_scale) { // 1/512 for E4M3, 1/65536 for E5M2 + const mxfp_neon_traits_t * t) { assert(n % QK_MXFP8 == 0); const int nb = n / QK_MXFP8; const block_mxfp8 * GGML_RESTRICT x = vx; const block_q8_0 * GGML_RESTRICT y = vy; + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); + float32x4_t acc0 = vdupq_n_f32(0.0f); float32x4_t acc1 = vdupq_n_f32(0.0f); - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - // Use variable shifts (vshlq_u32) instead of constant shifts (vshlq_n_u32) - // because exp_shift/mant_shift are function parameters, not compile-time constants. - // Clang requires _n_ intrinsics to have literal constant arguments. 
- const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift = vdupq_n_s32(mant_shift); - for (int ib = 0; ib < nb; ++ib) { - const float scale = GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d); - const float32x4_t v_scale = vdupq_n_f32(scale); + const float32x4_t v_scale = vdupq_n_f32( + GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - // Process 32 FP8 elements in 8 groups of 4 for (int j = 0; j < 32; j += 8) { - // Load 8 FP8 bytes, extend to two uint32x4_t - const uint8x8_t raw8 = vld1_u8(x[ib].qs + j); - const uint16x8_t raw16 = vmovl_u8(raw8); - const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); - const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + uint32x4_t v_lo, v_hi; + widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi); - // Load 8 Q8_0 int8 values, extend to two int32x4_t → float32x4_t - const int8x8_t q8 = vld1_s8(y[ib].qs + j); - const int16x8_t q16 = vmovl_s8(q8); - const float32x4_t qf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16))); - const float32x4_t qf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); + float32x4_t qf_lo, qf_hi; + widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi); - // Dequant FP8 → float for both groups of 4 - #define DEQUANT_FP8_NEON(v_raw, qf, acc) do { \ - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \ - const uint32x4_t exp = vandq_u32( \ - vshlq_u32(v_raw, v_neg_exp_shift), \ - v_exp_mask); \ - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ - /* Normal: IEEE bits = (exp + offset) << 23 | mant << mant_shift */ \ - const uint32x4_t ieee = vorrq_u32( \ - vorrq_u32(vshlq_n_u32(sign, 24), \ - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ - vshlq_u32(mant, v_mant_shift)); \ - const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ - /* Subnormal: sign * mant * sub_scale */ \ - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ - const uint32x4_t sub_bits = vorrq_u32( \ - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \ - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ - /* Select: subnormal when exp == 0, else normal */ \ - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ - /* Multiply by scale and Q8 value, accumulate */ \ - (acc) = vfmaq_f32((acc), vmulq_f32(val, v_scale), qf); \ - } while (0) + const float32x4_t val_lo = mxfp8_dequant_neon(v_lo, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + const float32x4_t val_hi = mxfp8_dequant_neon(v_hi, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - DEQUANT_FP8_NEON(v_lo, qf_lo, acc0); - DEQUANT_FP8_NEON(v_hi, qf_hi, acc1); - #undef DEQUANT_FP8_NEON + acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo); + acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi); } } *s = vaddvq_f32(vaddq_f32(acc0, acc1)); } -#endif -void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__ARM_NEON) - // E4M3: sign(1) exp(4) mant(3), bias=7 - ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); -#else - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -// NEON-optimized MXFP6 × Q8_0 dot product. 
-// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float. -#if defined(__ARM_NEON) -static inline void ggml_vec_dot_mxfp6_q8_0_neon( +static void ggml_vec_dot_mxfp6_q8_0_neon( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, - size_t block_size, - // FP6 format parameters: - const uint32_t exp_mask, // 0x3 for E2M3, 0x7 for E3M2 - const uint32_t mant_mask, // 0x7 for E2M3, 0x3 for E3M2 - const int exp_shift, // 3 for E2M3, 2 for E3M2 - const uint32_t ieee_exp_off, // 126 for E2M3, 124 for E3M2 - const int mant_shift, // 20 for E2M3, 21 for E3M2 - const float sub_scale) { // 1/8 for E2M3, 1/16 for E3M2 + const mxfp_neon_traits_t * t) { assert(n % QK_MXFP6 == 0); const int nb = n / QK_MXFP6; const block_q8_0 * GGML_RESTRICT y = vy; + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); + float32x4_t acc0 = vdupq_n_f32(0.0f); float32x4_t acc1 = vdupq_n_f32(0.0f); - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift = vdupq_n_s32(mant_shift); - for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); - const float scale = GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d); - const float32x4_t v_scale = vdupq_n_f32(scale); + const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; + const float32x4_t v_scale = vdupq_n_f32( + GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - // Process 32 FP6 elements: 8 groups of 4, each packed in 3 bytes for (int j = 0; j < 32; j += 8) { - // Unpack two groups of 4 FP6 values (6 bytes → 8 values) - uint8_t unpacked[8]; - // Group 1: 3 bytes → 4 values - { - const uint8_t * p = xb->qs + (j * 3 / 4); - const uint32_t packed = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[0] = (packed >> 0) & 0x3F; - unpacked[1] = (packed >> 6) & 0x3F; - unpacked[2] = (packed >> 12) & 0x3F; - unpacked[3] = (packed >> 18) & 0x3F; - } - // Group 2: next 3 bytes → 4 values - { - const uint8_t * p = xb->qs + ((j + 4) * 3 / 4); - const uint32_t packed = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[4] = (packed >> 0) & 0x3F; - unpacked[5] = (packed >> 6) & 0x3F; - unpacked[6] = (packed >> 12) & 0x3F; - unpacked[7] = (packed >> 18) & 0x3F; - } + const uint32x4_t v_lo = unpack_fp6x4_neon(xb->qs + (j * 3 / 4)); + const uint32x4_t v_hi = unpack_fp6x4_neon(xb->qs + ((j + 4) * 3 / 4)); - // Extend to uint32x4_t - const uint8x8_t raw8 = vld1_u8(unpacked); - const uint16x8_t raw16 = vmovl_u8(raw8); - const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); - const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + float32x4_t qf_lo, qf_hi; + widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi); - // Load Q8_0 int8 values - const int8x8_t q8 = vld1_s8(y[ib].qs + j); - const int16x8_t q16 = vmovl_s8(q8); - const float32x4_t qf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16))); - const float32x4_t qf_hi = 
vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16))); + const float32x4_t val_lo = mxfp6_dequant_neon(v_lo, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + const float32x4_t val_hi = mxfp6_dequant_neon(v_hi, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - // Dequant FP6 → float (same IEEE construction as FP8, sign bit at position 5) - #define DEQUANT_FP6_NEON(v_raw, qf, acc) do { \ - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); \ - const uint32x4_t exp = vandq_u32( \ - vshlq_u32(v_raw, v_neg_exp_shift), \ - v_exp_mask); \ - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ - const uint32x4_t ieee = vorrq_u32( \ - vorrq_u32(vshlq_n_u32(sign, 26), \ - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ - vshlq_u32(mant, v_mant_shift)); \ - const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ - const uint32x4_t sub_bits = vorrq_u32( \ - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); \ - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ - (acc) = vfmaq_f32((acc), vmulq_f32(val, v_scale), qf); \ - } while (0) - - DEQUANT_FP6_NEON(v_lo, qf_lo, acc0); - DEQUANT_FP6_NEON(v_hi, qf_hi, acc1); - #undef DEQUANT_FP6_NEON + acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo); + acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi); } } *s = vaddvq_f32(vaddq_f32(acc0, acc1)); } -#endif -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__ARM_NEON) - // E2M3: sign(1) exp(2) mant(3), bias=1 - ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, sizeof(block_mxfp6), - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); -#else - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} +// ── MXFP FP8/FP6 dequantize_row (AoS) ───────────────────────────────────── -// ---- MXFP dequantize_row (to_float) — NEON-optimized ---- - -#if defined(__ARM_NEON) -static inline void dequantize_row_mxfp8_neon( +static void dequantize_row_mxfp8_neon( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - const uint32_t exp_mask, const uint32_t mant_mask, - const int exp_shift, const uint32_t ieee_exp_off, - const int mant_shift, const float sub_scale) { + const mxfp_neon_traits_t * t) { assert(k % QK_MXFP8 == 0); const int nb = k / QK_MXFP8; const block_mxfp8 * GGML_RESTRICT x = vx; - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); for (int ib = 0; ib < nb; ++ib) { const float32x4_t 
v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(x[ib].e)); for (int j = 0; j < 32; j += 8) { - const uint8x8_t raw8 = vld1_u8(x[ib].qs + j); - const uint16x8_t raw16 = vmovl_u8(raw8); - const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); - const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); + uint32x4_t v_lo, v_hi; + widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi); - #define DEQUANT_FP8_STORE(v_raw, dst) do { \ - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \ - const uint32x4_t exp = vandq_u32( \ - vshlq_u32(v_raw, v_neg_exp_shift), \ - v_exp_mask); \ - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ - const uint32x4_t ieee = vorrq_u32( \ - vorrq_u32(vshlq_n_u32(sign, 24), \ - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ - vshlq_u32(mant, v_mant_shift_v)); \ - const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ - const uint32x4_t sub_bits = vorrq_u32( \ - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \ - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ - vst1q_f32(dst, vmulq_f32(val, v_scale)); \ - } while (0) + const float32x4_t val_lo = mxfp8_dequant_neon(v_lo, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + const float32x4_t val_hi = mxfp8_dequant_neon(v_hi, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); - DEQUANT_FP8_STORE(v_lo, y + ib * QK_MXFP8 + j); - DEQUANT_FP8_STORE(v_hi, y + ib * QK_MXFP8 + j + 4); - #undef DEQUANT_FP8_STORE + vst1q_f32(y + ib * QK_MXFP8 + j, vmulq_f32(val_lo, v_scale)); + vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale)); } } } -static inline void dequantize_row_mxfp6_neon( +static void dequantize_row_mxfp6_neon( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - size_t block_size, - const uint32_t exp_mask, const uint32_t mant_mask, - const int exp_shift, const uint32_t ieee_exp_off, - const int mant_shift, const float sub_scale) { + const mxfp_neon_traits_t * t) { assert(k % QK_MXFP6 == 0); const int nb = k / QK_MXFP6; - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); + const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(xb->e)); for (int j = 0; j < 32; j += 4) { - const uint8_t * p = xb->qs + (j * 3 / 4); - const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - uint8_t unpacked[4]; - unpacked[0] = (pk >> 0) & 0x3F; - unpacked[1] = (pk >> 6) & 0x3F; - unpacked[2] = (pk >> 12) & 0x3F; - unpacked[3] = (pk >> 18) & 0x3F; + const uint32x4_t 
v_raw = unpack_fp6x4_neon(xb->qs + (j * 3 / 4)); - const uint8x8_t raw8 = vcreate_u8( - (uint64_t)unpacked[0] | ((uint64_t)unpacked[1] << 8) | - ((uint64_t)unpacked[2] << 16) | ((uint64_t)unpacked[3] << 24)); - const uint32x4_t v_raw = vmovl_u16(vget_low_u16(vmovl_u8(raw8))); - - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); - const uint32x4_t exp = vandq_u32( - vshlq_u32(v_raw, v_neg_exp_shift), - v_exp_mask); - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); - - const uint32x4_t ieee = vorrq_u32( - vorrq_u32(vshlq_n_u32(sign, 26), - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), - vshlq_u32(mant, v_mant_shift_v)); - const float32x4_t normal = vreinterpretq_f32_u32(ieee); - - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); - const uint32x4_t sub_bits = vorrq_u32( - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); - - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); + const float32x4_t val = mxfp6_dequant_neon(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale)); } } } -#endif -void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__ARM_NEON) - dequantize_row_mxfp8_neon(x, y, k, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); -#else - dequantize_row_mxfp8_cpu_generic(x, y, k); -#endif +// ── MXFP SoA dequant (flash attention) ───────────────────────────────────── + +static void dequantize_row_mxfp8_soa_neon( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const mxfp_neon_traits_t * t) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); + + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 8) { + uint32x4_t v_lo, v_hi; + widen_u8x8_to_u32x4x2(qs + j, &v_lo, &v_hi); + + const float32x4_t val_lo = mxfp8_dequant_neon(v_lo, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + const float32x4_t val_hi = mxfp8_dequant_neon(v_hi, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + + vst1q_f32(y + ib * QK_MXFP8 + j, vmulq_f32(val_lo, v_scale)); + vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale)); + } + } } -void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__ARM_NEON) - dequantize_row_mxfp6_neon(x, y, k, sizeof(block_mxfp6), - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); -#else - dequantize_row_mxfp6_cpu_generic(x, y, k); -#endif +static void 
dequantize_row_mxfp6_soa_neon( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const mxfp_neon_traits_t * t) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); + + const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask); + const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask); + const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off); + const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale); + const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift); + const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift); + + for (int ib = 0; ib < nb; ++ib) { + const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 4) { + const uint32x4_t v_raw = unpack_fp6x4_neon(qs + (j * 3 / 4)); + + const float32x4_t val = mxfp6_dequant_neon(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh); + + vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale)); + } + } } -// ---- MXFP SoA dequantize_row (to_float) — NEON-optimized ---- - -#if defined(__ARM_NEON) -static inline void dequantize_row_mxfp4_soa_neon( +// MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed. +static void dequantize_row_mxfp4_soa_neon( const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { assert(k % QK_MXFP4 == 0); const int nb = k / QK_MXFP4; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); const int8x16_t values = vld1q_s8(kvalues_mxfp4); const uint8x16_t m4b = vdupq_n_u8(0x0f); @@ -4527,122 +4490,45 @@ static inline void dequantize_row_mxfp4_soa_neon( } } -static inline void dequantize_row_mxfp8_soa_neon( - const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, - const uint32_t exp_mask, const uint32_t mant_mask, - const int exp_shift, const uint32_t ieee_exp_off, - const int mant_shift, const float sub_scale) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); +#endif // __ARM_NEON - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); +// ── Public dispatch functions ────────────────────────────────────────────── - for (int ib = 0; ib < nb; ++ib) { - const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < 32; j += 8) { - const uint8x8_t raw8 = vld1_u8(qs + j); - const uint16x8_t raw16 = vmovl_u8(raw8); - const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16)); - const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16)); - - #define DEQUANT_FP8_STORE_SOA(v_raw, dst) do { \ - const uint32x4_t sign = vandq_u32(v_raw, 
vdupq_n_u32(0x80)); \ - const uint32x4_t exp = vandq_u32( \ - vshlq_u32(v_raw, v_neg_exp_shift), \ - v_exp_mask); \ - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \ - const uint32x4_t ieee = vorrq_u32( \ - vorrq_u32(vshlq_n_u32(sign, 24), \ - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \ - vshlq_u32(mant, v_mant_shift_v)); \ - const float32x4_t normal = vreinterpretq_f32_u32(ieee); \ - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \ - const uint32x4_t sub_bits = vorrq_u32( \ - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \ - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \ - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \ - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \ - vst1q_f32(dst, vmulq_f32(val, v_scale)); \ - } while (0) - - DEQUANT_FP8_STORE_SOA(v_lo, y + ib * QK_MXFP8 + j); - DEQUANT_FP8_STORE_SOA(v_hi, y + ib * QK_MXFP8 + j + 4); - #undef DEQUANT_FP8_STORE_SOA - } - } -} - -static inline void dequantize_row_mxfp6_soa_neon( - const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, - const uint32_t exp_mask, const uint32_t mant_mask, - const int exp_shift, const uint32_t ieee_exp_off, - const int mant_shift, const float sub_scale) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); - - const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask); - const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask); - const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off); - const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale); - const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift); - const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift); - - for (int ib = 0; ib < nb; ++ib) { - const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < 32; j += 4) { - const uint8_t * p = qs + (j * 3 / 4); - const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - uint8_t unpacked[4]; - unpacked[0] = (pk >> 0) & 0x3F; - unpacked[1] = (pk >> 6) & 0x3F; - unpacked[2] = (pk >> 12) & 0x3F; - unpacked[3] = (pk >> 18) & 0x3F; - - const uint8x8_t raw8 = vcreate_u8( - (uint64_t)unpacked[0] | ((uint64_t)unpacked[1] << 8) | - ((uint64_t)unpacked[2] << 16) | ((uint64_t)unpacked[3] << 24)); - const uint32x4_t v_raw = vmovl_u16(vget_low_u16(vmovl_u8(raw8))); - - const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); - const uint32x4_t exp = vandq_u32( - vshlq_u32(v_raw, v_neg_exp_shift), - v_exp_mask); - const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); - - const uint32x4_t ieee = vorrq_u32( - vorrq_u32(vshlq_n_u32(sign, 26), - vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), - vshlq_u32(mant, v_mant_shift_v)); - const float32x4_t normal = vreinterpretq_f32_u32(ieee); - - const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); - const uint32x4_t sub_bits = vorrq_u32( - vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); - const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); - - const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); - const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); - - vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale)); - } - } -} +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__ARM_NEON) + ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E4M3); +#else + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif +} + +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__ARM_NEON) + ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E2M3); +#else + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__ARM_NEON) + dequantize_row_mxfp8_neon(x, y, k, &MXFP_TRAITS_E4M3); +#else + dequantize_row_mxfp8_cpu_generic(x, y, k); +#endif +} + +void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { +#if defined(__ARM_NEON) + dequantize_row_mxfp6_neon(x, y, k, &MXFP_TRAITS_E2M3); +#else + dequantize_row_mxfp6_cpu_generic(x, y, k); +#endif +} void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) @@ -4654,9 +4540,7 @@ void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) - dequantize_row_mxfp8_soa_neon(x, y, k, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); + dequantize_row_mxfp8_soa_neon(x, y, k, &MXFP_TRAITS_E4M3); #else dequantize_row_mxfp8_soa_cpu_generic(x, y, k); #endif @@ -4664,9 +4548,7 @@ void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { #if defined(__ARM_NEON) - dequantize_row_mxfp6_soa_neon(x, y, k, - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); + dequantize_row_mxfp6_soa_neon(x, y, k, &MXFP_TRAITS_E2M3); #else dequantize_row_mxfp6_soa_cpu_generic(x, y, k); #endif diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 0c6f6ed49a..b00b1467d3 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -3819,30 +3819,77 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -// AVX2-optimized MXFP8 × Q8_0 dot product. -// Dequants FP8 elements to float via IEEE 754 bit construction, then dots against Q8_0. -// Parameters encode the FP8 format: exp_mask, mant_mask, exp_shift, ieee_exp_offset, mant_shift, sub_scale. +// ── MXFP FP8/FP6 AVX2 helpers ────────────────────────────────────────────── +// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot, +// dequantize_row, and SoA dequant functions. + #if defined(__AVX2__) -static inline void ggml_vec_dot_mxfp8_q8_0_avx2( + +// Use shared mxfp_dequant_traits_t from ggml-common.h. +// Aliases for readability within this file. +#define mxfp_avx2_traits_t mxfp_dequant_traits_t + +// Dequantize 8 raw MXFP values (widened to int32) → 8 IEEE-754 floats. +// Handles both normal and subnormal paths. 
Works for any FP6/FP8 format. +static inline __m256 mxfp_dequant_avx2( + const __m256i v_raw, + const __m256i v_exp_mask, const __m256i v_mant_mask, + const __m256i v_ieee_off, const __m256 v_sub_sc, + const __m256i v_sign_mask, const __m256i v_zero, + int exp_shift, int sign_shift, int mant_shift) { + const __m256i sign = _mm256_and_si256(v_raw, v_sign_mask); + const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); + const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); + + const __m256i ieee = _mm256_or_si256( + _mm256_or_si256(_mm256_slli_epi32(sign, sign_shift), + _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), + _mm256_slli_epi32(mant, mant_shift)); + const __m256 normal = _mm256_castsi256_ps(ieee); + + const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); + const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( + _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, sign_shift))); + + const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); + return _mm256_blendv_ps(normal, sub_val, is_sub); +} + +// Unpack 4 tightly-packed 6-bit values from 3 bytes into separate bytes. +static inline void unpack_fp6x4_avx2(const uint8_t * p, uint8_t out[4]) { + const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); + out[0] = (pk >> 0) & 0x3F; + out[1] = (pk >> 6) & 0x3F; + out[2] = (pk >> 12) & 0x3F; + out[3] = (pk >> 18) & 0x3F; +} + +// Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j. +static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) { + uint8_t unpacked[8]; + unpack_fp6x4_avx2(qs + (j * 3 / 4), unpacked); + unpack_fp6x4_avx2(qs + ((j + 4) * 3 / 4), unpacked + 4); + return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked)); +} + +// ── MXFP FP8/FP6 vec_dot ────────────────────────────────────────────────── + +// Unified FP8 × Q8_0 dot product (works for E4M3 and E5M2). 
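+// The per-block scale (E8M0 block exponent × Q8_0 delta) is hoisted out of the inner loop
+// and folded into the dequantized lanes before each FMA accumulate.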
+static void ggml_vec_dot_mxfp8_q8_0_avx2( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, - // FP8 format parameters: - const int exp_mask, // 0xF for E4M3, 0x1F for E5M2 - const int mant_mask, // 0x7 for E4M3, 0x3 for E5M2 - const int exp_shift, // 3 for E4M3, 2 for E5M2 - const int ieee_exp_off, // 120 for E4M3, 112 for E5M2 - const int mant_shift, // 20 for E4M3, 21 for E5M2 - const float sub_scale) { // 1/512 for E4M3, 1/65536 for E5M2 + const mxfp_avx2_traits_t * t) { assert(n % QK_MXFP8 == 0); const int nb = n / QK_MXFP8; const block_mxfp8 * GGML_RESTRICT x = vx; const block_q8_0 * GGML_RESTRICT y = vy; - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); const __m256i v_zero = _mm256_setzero_si256(); __m256 acc = _mm256_setzero_ps(); @@ -3851,141 +3898,55 @@ static inline void ggml_vec_dot_mxfp8_q8_0_avx2( const __m256 v_scale = _mm256_set1_ps( GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - // Process 32 FP8 elements in 4 groups of 8 - // AVX2 _mm256_cvtepu8_epi32 widens 8 bytes → 8 int32s directly for (int j = 0; j < 32; j += 8) { - // Load 8 FP8 bytes → 8 int32s - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + const __m256i v_raw = _mm256_cvtepu8_epi32( + _mm_loadl_epi64((const __m128i *)(x[ib].qs + j))); + const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( + _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)))); - // Load 8 Q8_0 int8 values → float - const __m128i q8 = _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)); - const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q8)); + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); - // Extract sign (bit 7), exponent, mantissa - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - // Normal path: IEEE bits = (sign << 24) | ((exp + offset) << 23) | (mant << mant_shift) - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 24), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - // Subnormal path: |val| = mant * sub_scale, then apply sign - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24))); - - // Select: subnormal when exp == 0, else normal - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); - - // Accumulate: val * scale * q8_float acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); } } *s = hsum_float_8(acc); } -#endif -void ggml_vec_dot_mxfp8_q8_0(int n, 
float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__AVX2__) - // E4M3: sign(1) exp(4) mant(3), bias=7 - ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); -#else - ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -// AVX2-optimized MXFP6 × Q8_0 dot product. -// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float. -#if defined(__AVX2__) -static inline void ggml_vec_dot_mxfp6_q8_0_avx2( +// Unified FP6 × Q8_0 dot product (works for E2M3 and E3M2). +static void ggml_vec_dot_mxfp6_q8_0_avx2( int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, - size_t block_size, - // FP6 format parameters: - const int exp_mask, // 0x3 for E2M3, 0x7 for E3M2 - const int mant_mask, // 0x7 for E2M3, 0x3 for E3M2 - const int exp_shift, // 3 for E2M3, 2 for E3M2 - const int ieee_exp_off, // 126 for E2M3, 124 for E3M2 - const int mant_shift, // 20 for E2M3, 21 for E3M2 - const float sub_scale) { // 1/8 for E2M3, 1/16 for E3M2 + const mxfp_avx2_traits_t * t) { assert(n % QK_MXFP6 == 0); const int nb = n / QK_MXFP6; const block_q8_0 * GGML_RESTRICT y = vy; - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); const __m256i v_zero = _mm256_setzero_si256(); __m256 acc = _mm256_setzero_ps(); for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size); + const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; const __m256 v_scale = _mm256_set1_ps( GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d)); - // Process 32 FP6 elements in 4 groups of 8 (each group = 2 × 3-byte packs) for (int j = 0; j < 32; j += 8) { - // Unpack 8 FP6 values from 6 bytes (two groups of 3 bytes → 4 values each) - uint8_t unpacked[8]; - { - const uint8_t * p = xb->qs + (j * 3 / 4); - const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[0] = (pk0 >> 0) & 0x3F; - unpacked[1] = (pk0 >> 6) & 0x3F; - unpacked[2] = (pk0 >> 12) & 0x3F; - unpacked[3] = (pk0 >> 18) & 0x3F; - } - { - const uint8_t * p = xb->qs + ((j + 4) * 3 / 4); - const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[4] = (pk1 >> 0) & 0x3F; - unpacked[5] = (pk1 >> 6) & 0x3F; - unpacked[6] = (pk1 >> 12) & 0x3F; - unpacked[7] = (pk1 >> 18) & 0x3F; - } + const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j); + const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32( + _mm_loadl_epi64((const __m128i *)(y[ib].qs + j)))); - // Widen 8 bytes → 8 int32s - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); - - // Load 8 Q8_0 int8 values → float - const __m128i q8 = _mm_loadl_epi64((const __m128i 
*)(y[ib].qs + j)); - const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q8)); - - // Extract sign (bit 5 for FP6), exponent, mantissa - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - // Normal: IEEE bits = (sign << 26) | ((exp + offset) << 23) | (mant << mant_shift) - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 26), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - // Subnormal: |val| = mant * sub_scale, apply sign - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26))); - - // Select: subnormal when exp == 0 - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc); } @@ -3993,162 +3954,140 @@ static inline void ggml_vec_dot_mxfp6_q8_0_avx2( *s = hsum_float_8(acc); } -#endif -void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); -#if defined(__AVX2__) - // E2M3: sign(1) exp(2) mant(3), bias=1 - ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, sizeof(block_mxfp6), - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); -#else - ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} +// ── MXFP FP8/FP6 dequantize_row (AoS) ───────────────────────────────────── -// ---- MXFP dequantize_row (to_float) — AVX2-optimized ---- -// Extracts the SIMD dequant logic from vec_dot above, writing floats to output buffer -// instead of accumulating a dot product. 
- -#if defined(__AVX2__) -static inline void dequantize_row_mxfp8_avx2( +static void dequantize_row_mxfp8_avx2( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - const int exp_mask, const int mant_mask, const int exp_shift, - const int ieee_exp_off, const int mant_shift, const float sub_scale) { + const mxfp_avx2_traits_t * t) { assert(k % QK_MXFP8 == 0); const int nb = k / QK_MXFP8; const block_mxfp8 * GGML_RESTRICT x = vx; - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); const __m256i v_zero = _mm256_setzero_si256(); for (int ib = 0; ib < nb; ++ib) { const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(x[ib].e)); for (int j = 0; j < 32; j += 8) { - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); + const __m256i v_raw = _mm256_cvtepu8_epi32( + _mm_loadl_epi64((const __m128i *)(x[ib].qs + j))); - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 24), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24))); - - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); _mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale)); } } } -static inline void dequantize_row_mxfp6_avx2( +static void dequantize_row_mxfp6_avx2( const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k, - size_t block_size, - const int exp_mask, const int mant_mask, const int exp_shift, - const int ieee_exp_off, const int mant_shift, const float sub_scale) { + const mxfp_avx2_traits_t * t) { assert(k % QK_MXFP6 == 0); const int nb = k / QK_MXFP6; - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); const __m256i v_zero = _mm256_setzero_si256(); for (int ib = 0; ib < nb; ++ib) { - const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib 
* block_size); + const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib; const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(xb->e)); for (int j = 0; j < 32; j += 8) { - // Unpack 8 FP6 values from 6 bytes (two groups of 3 bytes → 4 values each) - uint8_t unpacked[8]; - { - const uint8_t * p = xb->qs + (j * 3 / 4); - const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[0] = (pk0 >> 0) & 0x3F; - unpacked[1] = (pk0 >> 6) & 0x3F; - unpacked[2] = (pk0 >> 12) & 0x3F; - unpacked[3] = (pk0 >> 18) & 0x3F; - } - { - const uint8_t * p = xb->qs + ((j + 4) * 3 / 4); - const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[4] = (pk1 >> 0) & 0x3F; - unpacked[5] = (pk1 >> 6) & 0x3F; - unpacked[6] = (pk1 >> 12) & 0x3F; - unpacked[7] = (pk1 >> 18) & 0x3F; - } + const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j); - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); - - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 26), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26))); - - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale)); } } } -#endif -void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__AVX2__) - dequantize_row_mxfp8_avx2(x, y, k, - MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT, - MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE); -#else - dequantize_row_mxfp8_cpu_generic(x, y, k); -#endif +// ── MXFP SoA dequant (flash attention) ───────────────────────────────────── + +static void dequantize_row_mxfp8_soa_avx2( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const mxfp_avx2_traits_t * t) { + assert(k % QK_MXFP8 == 0); + const int nb = k / QK_MXFP8; + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); + + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); + const __m256i v_zero = _mm256_setzero_si256(); + + for (int ib = 0; ib < nb; ++ib) { + const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 8) { + const __m256i v_raw = _mm256_cvtepu8_epi32( + 
_mm_loadl_epi64((const __m128i *)(qs + j))); + + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); + + _mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale)); + } + } } -void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { -#if defined(__AVX2__) - dequantize_row_mxfp6_avx2(x, y, k, sizeof(block_mxfp6), - MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT, - MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE); -#else - dequantize_row_mxfp6_cpu_generic(x, y, k); -#endif +static void dequantize_row_mxfp6_soa_avx2( + const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, + const mxfp_avx2_traits_t * t) { + assert(k % QK_MXFP6 == 0); + const int nb = k / QK_MXFP6; + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); + + const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask); + const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask); + const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off); + const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale); + const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask); + const __m256i v_zero = _mm256_setzero_si256(); + + for (int ib = 0; ib < nb; ++ib) { + const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); + const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); + + for (int j = 0; j < 32; j += 8) { + const __m256i v_raw = unpack_fp6x8_avx2(qs, j); + + const __m256 val = mxfp_dequant_avx2(v_raw, + v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, + v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift); + + _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale)); + } + } } -// SoA dequant for flash attention — contiguous qs region + separate e8m0 region -#if defined(__AVX2__) -static inline void dequantize_row_mxfp4_soa_avx2( +// MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed. 
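+// Each qs byte packs two 4-bit codes: both nibbles index the 16-entry kvalues_mxfp4
+// table via pshufb, then the E8M0 block scale is applied to the resulting floats.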
+static void dequantize_row_mxfp4_soa_avx2( const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) { assert(k % QK_MXFP4 == 0); const int nb = k / QK_MXFP4; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); + const char * qs_base = (const char *)src; + const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK); const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4); const __m128i m4b = _mm_set1_epi8(0x0f); @@ -4163,13 +4102,11 @@ static inline void dequantize_row_mxfp4_soa_avx2( const __m128i lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits, m4b)); const __m128i hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4b)); - // lo nibbles → first 16 floats const __m256i lo32_0 = _mm256_cvtepi8_epi32(lo); const __m256i lo32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(lo, 8)); _mm256_storeu_ps(y + i * QK_MXFP4 + 0, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_0), v_scale)); _mm256_storeu_ps(y + i * QK_MXFP4 + 8, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_1), v_scale)); - // hi nibbles → second 16 floats const __m256i hi32_0 = _mm256_cvtepi8_epi32(hi); const __m256i hi32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(hi, 8)); _mm256_storeu_ps(y + i * QK_MXFP4 + 16, _mm256_mul_ps(_mm256_cvtepi32_ps(hi32_0), v_scale)); @@ -4177,116 +4114,45 @@ static inline void dequantize_row_mxfp4_soa_avx2( } } -static inline void dequantize_row_mxfp8_soa_avx2( - const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, - const int exp_mask, const int mant_mask, const int exp_shift, - const int ieee_exp_off, const int mant_shift, const float sub_scale) { - assert(k % QK_MXFP8 == 0); - const int nb = k / QK_MXFP8; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK); +#endif // __AVX2__ - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); - const __m256i v_zero = _mm256_setzero_si256(); +// ── Public dispatch functions ────────────────────────────────────────────── - for (int ib = 0; ib < nb; ++ib) { - const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < 32; j += 8) { - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(qs + j)); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); - - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 24), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24))); - - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); - - _mm256_storeu_ps(y + ib * QK_MXFP8 
+ j, _mm256_mul_ps(val, v_scale)); - } - } -} - -static inline void dequantize_row_mxfp6_soa_avx2( - const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k, - const int exp_mask, const int mant_mask, const int exp_shift, - const int ieee_exp_off, const int mant_shift, const float sub_scale) { - assert(k % QK_MXFP6 == 0); - const int nb = k / QK_MXFP6; - const char * row = (const char *)src; - const char * qs_base = row; - const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK); - - const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask); - const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask); - const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off); - const __m256 v_sub_sc = _mm256_set1_ps(sub_scale); - const __m256i v_zero = _mm256_setzero_si256(); - - for (int ib = 0; ib < nb; ++ib) { - const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib])); - const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK)); - - for (int j = 0; j < 32; j += 8) { - uint8_t unpacked[8]; - { - const uint8_t * p = qs + (j * 3 / 4); - const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[0] = (pk0 >> 0) & 0x3F; - unpacked[1] = (pk0 >> 6) & 0x3F; - unpacked[2] = (pk0 >> 12) & 0x3F; - unpacked[3] = (pk0 >> 18) & 0x3F; - } - { - const uint8_t * p = qs + ((j + 4) * 3 / 4); - const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16); - unpacked[4] = (pk1 >> 0) & 0x3F; - unpacked[5] = (pk1 >> 6) & 0x3F; - unpacked[6] = (pk1 >> 12) & 0x3F; - unpacked[7] = (pk1 >> 18) & 0x3F; - } - - const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked); - const __m256i v_raw = _mm256_cvtepu8_epi32(raw8); - - const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20)); - const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask); - const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask); - - const __m256i ieee = _mm256_or_si256( - _mm256_or_si256(_mm256_slli_epi32(sign, 26), - _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)), - _mm256_slli_epi32(mant, mant_shift)); - const __m256 normal = _mm256_castsi256_ps(ieee); - - const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc); - const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256( - _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26))); - - const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero)); - const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub); - - _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale)); - } - } -} +void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__AVX2__) + ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E4M3); +#else + ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif +} + +void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by); +#if defined(__AVX2__) + ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E2M3); +#else + ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, 
+void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+#if defined(__AVX2__)
+    dequantize_row_mxfp8_avx2(x, y, k, &MXFP_TRAITS_E4M3);
+#else
+    dequantize_row_mxfp8_cpu_generic(x, y, k);
+#endif
+}
+
+void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+#if defined(__AVX2__)
+    dequantize_row_mxfp6_avx2(x, y, k, &MXFP_TRAITS_E2M3);
+#else
+    dequantize_row_mxfp6_cpu_generic(x, y, k);
+#endif
+}

 void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 #if defined(__AVX2__)
@@ -4298,9 +4164,7 @@ void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES

 void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 #if defined(__AVX2__)
-    dequantize_row_mxfp8_soa_avx2(x, y, k,
-        MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
-        MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
+    dequantize_row_mxfp8_soa_avx2(x, y, k, &MXFP_TRAITS_E4M3);
 #else
     dequantize_row_mxfp8_soa_cpu_generic(x, y, k);
 #endif
@@ -4308,9 +4172,7 @@ void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES

 void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
 #if defined(__AVX2__)
-    dequantize_row_mxfp6_soa_avx2(x, y, k,
-        MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
-        MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
+    dequantize_row_mxfp6_soa_avx2(x, y, k, &MXFP_TRAITS_E2M3);
 #else
     dequantize_row_mxfp6_soa_cpu_generic(x, y, k);
 #endif
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 27275ca1e1..12f04905af 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8297,8 +8297,8 @@ static mxfp_fa_params mxfp_fa_params_init(

     if (is_mxfp_k) {
         switch (k->type) {
-            case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break;
-            case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break;
+            case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break;
+            case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break;
             case GGML_TYPE_MXFP6_E2M3: p.q_quantize = quantize_row_mxfp6_soa; p.k_dequantize = dequantize_row_mxfp6_soa_cpu; break;
             default: GGML_ABORT("unsupported MXFP K type");
         }
@@ -8306,8 +8306,8 @@ static mxfp_fa_params mxfp_fa_params_init(

     if (is_mxfp_v) {
         switch (v->type) {
-            case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break;
-            case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break;
+            case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break;
+            case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break;
             case GGML_TYPE_MXFP6_E2M3: p.v_dequantize = dequantize_row_mxfp6_soa_cpu; break;
             default: GGML_ABORT("unsupported MXFP V type");
         }
@@ -8328,6 +8328,7 @@ static mxfp_fa_params mxfp_fa_params_init(

     // Per-head SoA addressing for multihead mode.
     // Precompute byte offsets so the hot loop can skip per-head pointer math.
+    // qs_per_block values come from the centralized MXFP*_SOA_QS_PER_BLOCK defines in ggml-common.h.
     auto mxfp_qs_per_block = [](ggml_type type) -> int {
         switch (type) {
             case GGML_TYPE_MXFP4_E2M1: return MXFP4_SOA_QS_PER_BLOCK;
@@ -8341,7 +8342,6 @@ static mxfp_fa_params mxfp_fa_params_init(
         p.k_qs_per_block = mxfp_qs_per_block(k->type);
         p.k_blocks_per_head = (int)(DK / 32);
         p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block;
-        // e8m0 offset from row start = total_blocks * qs_per_block
         const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head;
         p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block;
     }
diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h
index a0f6928e10..b386446035 100644
--- a/ggml/src/ggml-quants.h
+++ b/ggml/src/ggml-quants.h
@@ -61,7 +61,7 @@ GGML_API void quantize_row_mxfp4_soa (const float * GGML_RESTRICT x, void * G
 GGML_API void dequantize_row_mxfp4_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_mxfp8_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
 GGML_API void dequantize_row_mxfp8_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
-GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
+GGML_API void quantize_row_mxfp6_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
 GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);

 GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
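
For reviewers: a scalar reference for the traits-driven FP8/FP6 reconstruction that the AVX2 path above vectorizes. This is a minimal sketch, not code from the patch; the struct mirrors the six parameters the removed helpers took, and mxfp_traits_example_t / decode_mxfp_elem are illustrative names, not the shared instances the dispatchers pass via &MXFP_TRAITS_E4M3 and &MXFP_TRAITS_E2M3.

#include <stdint.h>
#include <string.h>

// Illustrative stand-in for the shared traits struct; the six fields match
// the parameters of the removed dequantize_row_mxfp{8,6}_soa_avx2 helpers.
typedef struct {
    uint32_t exp_mask;     // (1 << exp_bits) - 1
    uint32_t mant_mask;    // (1 << mant_bits) - 1
    int      exp_shift;    // mant_bits
    uint32_t ieee_exp_off; // 127 - format bias
    int      mant_shift;   // 23 - mant_bits
    float    sub_scale;    // subnormal step: 2^(1 - bias - mant_bits)
} mxfp_traits_example_t;

// Decode one raw FP8/FP6 code to float. sign_bit is 0x80 for FP8 and 0x20 for
// FP6, matching the constants used in the vector code above.
static float decode_mxfp_elem(uint32_t raw, uint32_t sign_bit, const mxfp_traits_example_t * t) {
    const uint32_t exp  = (raw >> t->exp_shift) & t->exp_mask;
    const uint32_t mant = raw & t->mant_mask;

    float val;
    if (exp == 0) {
        // Subnormal (or zero): linear in the mantissa.
        val = (float) mant * t->sub_scale;
    } else {
        // Normal: rebuild the IEEE-754 bit pattern directly.
        const uint32_t bits = ((exp + t->ieee_exp_off) << 23) | (mant << t->mant_shift);
        memcpy(&val, &bits, sizeof(val));
    }
    return (raw & sign_bit) ? -val : val;
}

The AVX2 helpers compute both branches and select on exp == 0 with _mm256_blendv_ps; the scalar form is element-for-element equivalent and is exactly what the per-format traits instances parameterize.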
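
The FP6 path additionally has to undo the 6-bit packing before this reconstruction: every 3 payload bytes carry 4 codes, least-significant bits first, exactly as in the pk0/pk1 blocks removed from dequantize_row_mxfp6_soa_avx2. A scalar sketch (helper name illustrative):

#include <stdint.h>

// Unpack 4 six-bit FP6 codes from 3 packed bytes (little-endian bit order).
static void unpack_fp6x4(const uint8_t * p, uint8_t out[4]) {
    const uint32_t pk = (uint32_t) p[0] | ((uint32_t) p[1] << 8) | ((uint32_t) p[2] << 16);
    out[0] = (pk >>  0) & 0x3F;
    out[1] = (pk >>  6) & 0x3F;
    out[2] = (pk >> 12) & 0x3F;
    out[3] = (pk >> 18) & 0x3F;
}

Element j of a 32-element block therefore starts at byte offset (j / 4) * 3, which is the qs + (j * 3 / 4) addressing in the removed code (j is always a multiple of 4 there).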
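
Finally, the SoA addressing that mxfp_fa_params_init precomputes: per the offset arithmetic above, a row stores all qs blocks first and then one e8m0 scale byte per block. The helpers below are an illustrative reading of that layout under those assumptions, not the MXFP_SOA_* macros themselves:

#include <stddef.h>
#include <stdint.h>

// qs bytes of block ib within a row of nb blocks (qs_per_block is the
// per-format value returned by mxfp_qs_per_block above).
static inline const uint8_t * soa_block_qs(const uint8_t * row, int ib, int qs_per_block) {
    return row + (size_t) ib * qs_per_block;
}

// e8m0 scale of block ib: all scales sit after the nb * qs_per_block qs bytes,
// matching k_head_e8m0_offset = k_total_blocks * k_qs_per_block above.
static inline uint8_t soa_block_e8m0(const uint8_t * row, int ib, int nb, int qs_per_block) {
    return row[(size_t) nb * qs_per_block + ib];
}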