ggml: refactor x86 AVX2 and ARM NEON MXFP dequant — shared traits and helpers

Add mxfp_dequant_traits_t to ggml-common.h as single source of truth for
MXFP IEEE-754 reconstruction parameters. Define static const instances for
all 4 formats (E4M3, E5M2, E2M3, E3M2), ready for CUDA/Metal/Vulkan reuse.

Extract shared dequant and FP6 unpack helpers on both architectures,
replacing duplicated inline code and macros. Net -215 lines.
This commit is contained in:
Tim Burke 2026-03-15 21:27:42 -04:00
parent c913ab36d2
commit b8e8d291d1
5 changed files with 531 additions and 746 deletions

View File

@ -267,6 +267,47 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4
#define MXFP6_E3M2_MANT_SHIFT 21 // 23-2
#define MXFP6_E3M2_SUB_SCALE (1.0f/16.0f) // 2^(-4) = 2^(1-3-2)
// Unified MXFP dequantization traits for SIMD backends (CPU x86/ARM, CUDA, Metal, Vulkan).
// Contains all parameters needed for IEEE-754 bit reconstruction of FP8/FP6 elements.
// FP4 uses LUT-based dequant and does not need this struct.
// NOTE: the MXFP_TRAITS_* instances below initialize this struct positionally;
// keep the field order here in sync with those initializers.
typedef struct {
    int   exp_mask;     // (1<<E)-1: exponent field mask
    int   mant_mask;    // (1<<M)-1: mantissa field mask
    int   exp_shift;    // M: right-shift to extract exponent
    int   ieee_exp_off; // 127-bias: offset to convert to IEEE exponent
    int   mant_shift;   // 23-M: left-shift to align mantissa in IEEE float
    float sub_scale;    // 2^(1-bias-M): subnormal scale factor
    int   sign_mask;    // 0x80 for 8-bit, 0x20 for 6-bit formats
    int   sign_shift;   // 24 for 8-bit, 26 for 6-bit formats
    int   qs_per_block; // bytes of quantized data per 32-element block
    int   emax_offset;  // type-specific offset for E8M0 MSE search
} mxfp_dequant_traits_t;
// Static const trait instances for each MXFP format.
// Gated by GGML_COMMON_IMPL to ensure single definition per translation unit.
#if defined(GGML_COMMON_IMPL)
// All four instances use positional initialization: the argument order must
// match the field order of mxfp_dequant_traits_t exactly.
// E4M3: sign(1) exp(4) mant(3), bias=7
static const mxfp_dequant_traits_t MXFP_TRAITS_E4M3 = {
    MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
    MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE,
    0x80, 24, MXFP_QS_PER_BLOCK_E4M3, MXFP8_E4M3_EMAX_OFFSET
};
// E5M2: sign(1) exp(5) mant(2)
static const mxfp_dequant_traits_t MXFP_TRAITS_E5M2 = {
    MXFP8_E5M2_EXP_MASK, MXFP8_E5M2_MANT_MASK, MXFP8_E5M2_EXP_SHIFT,
    MXFP8_E5M2_IEEE_EXP_OFF, MXFP8_E5M2_MANT_SHIFT, MXFP8_E5M2_SUB_SCALE,
    0x80, 24, MXFP_QS_PER_BLOCK_E5M2, MXFP8_E5M2_EMAX_OFFSET
};
// E2M3: sign(1) exp(2) mant(3), bias=1
static const mxfp_dequant_traits_t MXFP_TRAITS_E2M3 = {
    MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
    MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE,
    0x20, 26, MXFP_QS_PER_BLOCK_E2M3, MXFP6_E2M3_EMAX_OFFSET
};
// E3M2: sign(1) exp(3) mant(2)
static const mxfp_dequant_traits_t MXFP_TRAITS_E3M2 = {
    MXFP6_E3M2_EXP_MASK, MXFP6_E3M2_MANT_MASK, MXFP6_E3M2_EXP_SHIFT,
    MXFP6_E3M2_IEEE_EXP_OFF, MXFP6_E3M2_MANT_SHIFT, MXFP6_E3M2_SUB_SCALE,
    0x20, 26, MXFP_QS_PER_BLOCK_E3M2, MXFP6_E3M2_EMAX_OFFSET
};
#endif // GGML_COMMON_IMPL
#define QK_MXFP4 32
typedef struct {
uint8_t e; // E8M0

View File

@ -4134,363 +4134,326 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#endif
}
// NEON-optimized MXFP8 × Q8_0 dot product.
// Dequants FP8 elements to float via IEEE 754 bit construction, then dots against Q8_0.
// Parameters encode the FP8 format: sign_shift, exp_mask, mant_mask, ieee_exp_bias, mant_shift, sub_scale.
// ── MXFP FP8/FP6 NEON helpers ──────────────────────────────────────────────
// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot,
// dequantize_row, and SoA dequant functions.
//
// NEON requires vshlq_n_u32 to have a compile-time literal constant, so we use
// two separate helpers for FP8 (sign at bit 7, shift 24) and FP6 (sign at bit 5,
// shift 26) rather than a single parameterized function.
#if defined(__ARM_NEON)
// Use shared mxfp_dequant_traits_t from ggml-common.h.
#define mxfp_neon_traits_t mxfp_dequant_traits_t
// Dequantize 4 raw FP8 values (uint32x4_t) → 4 IEEE-754 floats.
// Sign bit at position 7, sign shift = 24.
// Dequantize 4 raw FP8 values (uint32x4_t) → 4 IEEE-754 floats.
// Sign bit sits at position 7 and is moved to bit 31 with a literal shift of 24
// (vshlq_n_u32 requires a compile-time constant, hence the dedicated helper).
static inline float32x4_t mxfp8_dequant_neon(
    const uint32x4_t v_raw,
    const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask,
    const uint32x4_t v_ieee_off, const float32x4_t v_sub_sc,
    const int32x4_t v_neg_exp_shift, const int32x4_t v_mant_shift) {
    // Split the raw value into its three fields.
    const uint32x4_t s_bits = vandq_u32(v_raw, vdupq_n_u32(0x80));
    const uint32x4_t e_bits = vandq_u32(vshlq_u32(v_raw, v_neg_exp_shift), v_exp_mask);
    const uint32x4_t m_bits = vandq_u32(v_raw, v_mant_mask);
    const uint32x4_t s_ieee = vshlq_n_u32(s_bits, 24);
    // Normal path: IEEE bits = sign | (exp + offset) << 23 | mant << mant_shift.
    const uint32x4_t e_ieee = vshlq_n_u32(vaddq_u32(e_bits, v_ieee_off), 23);
    const uint32x4_t m_ieee = vshlq_u32(m_bits, v_mant_shift);
    const float32x4_t v_norm = vreinterpretq_f32_u32(
        vorrq_u32(vorrq_u32(s_ieee, e_ieee), m_ieee));
    // Subnormal path: |x| = mant * sub_scale, then OR the sign bit back in.
    const float32x4_t v_sub_abs = vmulq_f32(vcvtq_f32_u32(m_bits), v_sub_sc);
    const float32x4_t v_sub = vreinterpretq_f32_u32(
        vorrq_u32(vreinterpretq_u32_f32(v_sub_abs), s_ieee));
    // exp == 0 selects the subnormal result.
    return vbslq_f32(vceqq_u32(e_bits, vdupq_n_u32(0)), v_sub, v_norm);
}
// Dequantize 4 raw FP6 values (uint32x4_t) → 4 IEEE-754 floats.
// Sign bit at position 5, sign shift = 26.
// Dequantize 4 raw FP6 values (uint32x4_t) → 4 IEEE-754 floats.
// Same construction as mxfp8_dequant_neon, but the sign bit sits at
// position 5 and is moved to bit 31 with a literal shift of 26.
static inline float32x4_t mxfp6_dequant_neon(
    const uint32x4_t v_raw,
    const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask,
    const uint32x4_t v_ieee_off, const float32x4_t v_sub_sc,
    const int32x4_t v_neg_exp_shift, const int32x4_t v_mant_shift) {
    // Split the raw value into its three fields.
    const uint32x4_t s_bits = vandq_u32(v_raw, vdupq_n_u32(0x20));
    const uint32x4_t e_bits = vandq_u32(vshlq_u32(v_raw, v_neg_exp_shift), v_exp_mask);
    const uint32x4_t m_bits = vandq_u32(v_raw, v_mant_mask);
    const uint32x4_t s_ieee = vshlq_n_u32(s_bits, 26);
    // Normal path: IEEE bits = sign | (exp + offset) << 23 | mant << mant_shift.
    const uint32x4_t e_ieee = vshlq_n_u32(vaddq_u32(e_bits, v_ieee_off), 23);
    const uint32x4_t m_ieee = vshlq_u32(m_bits, v_mant_shift);
    const float32x4_t v_norm = vreinterpretq_f32_u32(
        vorrq_u32(vorrq_u32(s_ieee, e_ieee), m_ieee));
    // Subnormal path: |x| = mant * sub_scale, then OR the sign bit back in.
    const float32x4_t v_sub_abs = vmulq_f32(vcvtq_f32_u32(m_bits), v_sub_sc);
    const float32x4_t v_sub = vreinterpretq_f32_u32(
        vorrq_u32(vreinterpretq_u32_f32(v_sub_abs), s_ieee));
    // exp == 0 selects the subnormal result.
    return vbslq_f32(vceqq_u32(e_bits, vdupq_n_u32(0)), v_sub, v_norm);
}
// Unpack 4 tightly-packed 6-bit values from 3 bytes, widen to uint32x4_t.
// Unpack 4 tightly-packed 6-bit values from 3 bytes, widen to uint32x4_t.
static inline uint32x4_t unpack_fp6x4_neon(const uint8_t * p) {
    // Assemble the 24-bit little-endian group, then slice out four 6-bit fields
    // directly into 32-bit lanes.
    const uint32_t bits = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
    const uint32_t vals[4] = {
        (bits >>  0) & 0x3F,
        (bits >>  6) & 0x3F,
        (bits >> 12) & 0x3F,
        (bits >> 18) & 0x3F,
    };
    return vld1q_u32(vals);
}
// Widen 8 raw bytes to two uint32x4_t halves.
// Widen 8 raw bytes to two uint32x4_t halves (u8 → u16 → u32).
static inline void widen_u8x8_to_u32x4x2(const uint8_t * src,
                                         uint32x4_t * lo, uint32x4_t * hi) {
    const uint16x8_t w16 = vmovl_u8(vld1_u8(src));
    *lo = vmovl_u16(vget_low_u16(w16));
    *hi = vmovl_u16(vget_high_u16(w16));
}
// Widen 8 Q8_0 int8 values to two float32x4_t halves.
// Widen 8 Q8_0 int8 values to two float32x4_t halves (s8 → s16 → s32 → f32).
static inline void widen_s8x8_to_f32x4x2(const int8_t * src,
                                         float32x4_t * lo, float32x4_t * hi) {
    const int16x8_t w16 = vmovl_s8(vld1_s8(src));
    *lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(w16)));
    *hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(w16)));
}
// ── MXFP FP8/FP6 vec_dot ──────────────────────────────────────────────────
// NEON MXFP8 × Q8_0 dot product.
// Dequantizes FP8 elements via IEEE-754 bit reconstruction
// (mxfp8_dequant_neon), scales by the per-block E8M0 × Q8_0 scale, and
// accumulates in two independent FMA chains.
static void ggml_vec_dot_mxfp8_q8_0_neon(
    int n, float * GGML_RESTRICT s,
    const void * GGML_RESTRICT vx,
    const void * GGML_RESTRICT vy,
    const mxfp_neon_traits_t * t) {
    assert(n % QK_MXFP8 == 0);
    const int nb = n / QK_MXFP8;

    const block_mxfp8 * GGML_RESTRICT x = vx;
    const block_q8_0  * GGML_RESTRICT y = vy;

    // Broadcast the format constants once per call.
    const uint32x4_t  v_exp_mask  = vdupq_n_u32(t->exp_mask);
    const uint32x4_t  v_mant_mask = vdupq_n_u32(t->mant_mask);
    const uint32x4_t  v_ieee_off  = vdupq_n_u32(t->ieee_exp_off);
    const float32x4_t v_sub_sc    = vdupq_n_f32(t->sub_scale);
    // Variable shifts (vshlq_u32) are required because the shift counts come
    // from the traits struct, not compile-time literals; a negative count
    // performs a right shift.
    const int32x4_t   v_neg_exp   = vdupq_n_s32(-(int)t->exp_shift);
    const int32x4_t   v_mant_sh   = vdupq_n_s32(t->mant_shift);

    float32x4_t acc0 = vdupq_n_f32(0.0f);
    float32x4_t acc1 = vdupq_n_f32(0.0f);

    for (int ib = 0; ib < nb; ++ib) {
        const float32x4_t v_scale = vdupq_n_f32(
            GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
        // Process 32 FP8 elements: two groups of 4 per iteration.
        for (int j = 0; j < 32; j += 8) {
            uint32x4_t v_lo, v_hi;
            widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi);
            float32x4_t qf_lo, qf_hi;
            widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi);
            const float32x4_t val_lo = mxfp8_dequant_neon(v_lo,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            const float32x4_t val_hi = mxfp8_dequant_neon(v_hi,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo);
            acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi);
        }
    }

    *s = vaddvq_f32(vaddq_f32(acc0, acc1));
}
#endif
// NOTE(review): pre-refactor dispatcher left behind by the diff rendering; it
// passes the old explicit parameter list, while a traits-based dispatcher for
// the same symbol appears later in this file — confirm only one is kept.
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__ARM_NEON)
    // E4M3: sign(1) exp(4) mant(3), bias=7
    ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy,
        MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
        MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
#else
    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// NEON-optimized MXFP6 × Q8_0 dot product.
// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float.
#if defined(__ARM_NEON)
// NEON MXFP6 × Q8_0 dot product.
// Unpacks the tight 6-bit packing (4 values per 3 bytes) with
// unpack_fp6x4_neon, dequantizes via mxfp6_dequant_neon, and accumulates
// against Q8_0 in two independent FMA chains.
static void ggml_vec_dot_mxfp6_q8_0_neon(
    int n, float * GGML_RESTRICT s,
    const void * GGML_RESTRICT vx,
    const void * GGML_RESTRICT vy,
    const mxfp_neon_traits_t * t) {
    assert(n % QK_MXFP6 == 0);
    const int nb = n / QK_MXFP6;

    const block_q8_0 * GGML_RESTRICT y = vy;

    // Broadcast the format constants once per call.
    const uint32x4_t  v_exp_mask  = vdupq_n_u32(t->exp_mask);
    const uint32x4_t  v_mant_mask = vdupq_n_u32(t->mant_mask);
    const uint32x4_t  v_ieee_off  = vdupq_n_u32(t->ieee_exp_off);
    const float32x4_t v_sub_sc    = vdupq_n_f32(t->sub_scale);
    // Negative count → right shift via vshlq_u32 (counts are not literals).
    const int32x4_t   v_neg_exp   = vdupq_n_s32(-(int)t->exp_shift);
    const int32x4_t   v_mant_sh   = vdupq_n_s32(t->mant_shift);

    float32x4_t acc0 = vdupq_n_f32(0.0f);
    float32x4_t acc1 = vdupq_n_f32(0.0f);

    for (int ib = 0; ib < nb; ++ib) {
        const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
        const float32x4_t v_scale = vdupq_n_f32(
            GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
        // 32 FP6 elements per block: 8 groups of 4, each group packed in 3 bytes.
        for (int j = 0; j < 32; j += 8) {
            const uint32x4_t v_lo = unpack_fp6x4_neon(xb->qs + (j * 3 / 4));
            const uint32x4_t v_hi = unpack_fp6x4_neon(xb->qs + ((j + 4) * 3 / 4));
            float32x4_t qf_lo, qf_hi;
            widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi);
            const float32x4_t val_lo = mxfp6_dequant_neon(v_lo,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            const float32x4_t val_hi = mxfp6_dequant_neon(v_hi,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo);
            acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi);
        }
    }

    *s = vaddvq_f32(vaddq_f32(acc0, acc1));
}
#endif
// NOTE(review): pre-refactor dispatcher left behind by the diff rendering; it
// passes the old explicit parameter list (including sizeof(block_mxfp6)),
// while a traits-based dispatcher for the same symbol appears later in this
// file — confirm only one is kept.
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__ARM_NEON)
    // E2M3: sign(1) exp(2) mant(3), bias=1
    ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, sizeof(block_mxfp6),
        MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
        MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
#else
    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// ── MXFP FP8/FP6 dequantize_row (AoS) ─────────────────────────────────────
// ---- MXFP dequantize_row (to_float) — NEON-optimized ----
#if defined(__ARM_NEON)
// NEON row-wise MXFP8 → float dequantization (AoS block layout).
// Reconstructs each FP8 element via mxfp8_dequant_neon and applies the
// per-block E8M0 scale.
static void dequantize_row_mxfp8_neon(
    const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
    const mxfp_neon_traits_t * t) {
    assert(k % QK_MXFP8 == 0);
    const int nb = k / QK_MXFP8;

    const block_mxfp8 * GGML_RESTRICT x = vx;

    // Broadcast the format constants once per call.
    const uint32x4_t  v_exp_mask  = vdupq_n_u32(t->exp_mask);
    const uint32x4_t  v_mant_mask = vdupq_n_u32(t->mant_mask);
    const uint32x4_t  v_ieee_off  = vdupq_n_u32(t->ieee_exp_off);
    const float32x4_t v_sub_sc    = vdupq_n_f32(t->sub_scale);
    // Negative count → right shift via vshlq_u32 (counts are not literals).
    const int32x4_t   v_neg_exp   = vdupq_n_s32(-(int)t->exp_shift);
    const int32x4_t   v_mant_sh   = vdupq_n_s32(t->mant_shift);

    for (int ib = 0; ib < nb; ++ib) {
        const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(x[ib].e));
        // Process 32 FP8 elements: two groups of 4 per iteration.
        for (int j = 0; j < 32; j += 8) {
            uint32x4_t v_lo, v_hi;
            widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi);
            const float32x4_t val_lo = mxfp8_dequant_neon(v_lo,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            const float32x4_t val_hi = mxfp8_dequant_neon(v_hi,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            vst1q_f32(y + ib * QK_MXFP8 + j,     vmulq_f32(val_lo, v_scale));
            vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale));
        }
    }
}
// NEON row-wise MXFP6 → float dequantization (AoS block layout).
// Unpacks the tight 6-bit packing (4 values per 3 bytes), reconstructs each
// element via mxfp6_dequant_neon, and applies the per-block E8M0 scale.
static void dequantize_row_mxfp6_neon(
    const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
    const mxfp_neon_traits_t * t) {
    assert(k % QK_MXFP6 == 0);
    const int nb = k / QK_MXFP6;

    // Broadcast the format constants once per call.
    const uint32x4_t  v_exp_mask  = vdupq_n_u32(t->exp_mask);
    const uint32x4_t  v_mant_mask = vdupq_n_u32(t->mant_mask);
    const uint32x4_t  v_ieee_off  = vdupq_n_u32(t->ieee_exp_off);
    const float32x4_t v_sub_sc    = vdupq_n_f32(t->sub_scale);
    // Negative count → right shift via vshlq_u32 (counts are not literals).
    const int32x4_t   v_neg_exp   = vdupq_n_s32(-(int)t->exp_shift);
    const int32x4_t   v_mant_sh   = vdupq_n_s32(t->mant_shift);

    for (int ib = 0; ib < nb; ++ib) {
        const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
        const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(xb->e));
        // 32 FP6 elements per block: one packed group of 4 per iteration.
        for (int j = 0; j < 32; j += 4) {
            const uint32x4_t v_raw = unpack_fp6x4_neon(xb->qs + (j * 3 / 4));
            const float32x4_t val = mxfp6_dequant_neon(v_raw,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale));
        }
    }
}
#endif
// NOTE(review): pre-refactor dispatcher left behind by the diff rendering —
// its closing brace was lost and a traits-based dispatcher for the same
// symbol appears later in this file; confirm this copy is removed.
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
    dequantize_row_mxfp8_neon(x, y, k,
        MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
        MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
#else
    dequantize_row_mxfp8_cpu_generic(x, y, k);
#endif
// ── MXFP SoA dequant (flash attention) ─────────────────────────────────────
// SoA-layout MXFP8 → float dequantization: all qs blocks are contiguous,
// with the per-block E8M0 scales placed at MXFP_SOA_E8M0_OFFSET past them.
static void dequantize_row_mxfp8_soa_neon(
    const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
    const mxfp_neon_traits_t * t) {
    assert(k % QK_MXFP8 == 0);
    const int nb = k / QK_MXFP8;

    const char * base   = (const char *)src;
    const char * scales = base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);

    // Broadcast the format constants once.
    const uint32x4_t  v_exp_mask  = vdupq_n_u32(t->exp_mask);
    const uint32x4_t  v_mant_mask = vdupq_n_u32(t->mant_mask);
    const uint32x4_t  v_ieee_off  = vdupq_n_u32(t->ieee_exp_off);
    const float32x4_t v_sub_sc    = vdupq_n_f32(t->sub_scale);
    const int32x4_t   v_neg_exp   = vdupq_n_s32(-(int)t->exp_shift);
    const int32x4_t   v_mant_sh   = vdupq_n_s32(t->mant_shift);

    for (int ib = 0; ib < nb; ++ib) {
        const uint8_t * qs = (const uint8_t *)(base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
        const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)scales[ib]));
        float * out = y + ib * QK_MXFP8;
        // 32 FP8 elements per block: two groups of 4 per iteration.
        for (int j = 0; j < 32; j += 8) {
            uint32x4_t raw_lo, raw_hi;
            widen_u8x8_to_u32x4x2(qs + j, &raw_lo, &raw_hi);
            const float32x4_t f_lo = mxfp8_dequant_neon(raw_lo,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            const float32x4_t f_hi = mxfp8_dequant_neon(raw_hi,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            vst1q_f32(out + j,     vmulq_f32(f_lo, v_scale));
            vst1q_f32(out + j + 4, vmulq_f32(f_hi, v_scale));
        }
    }
}
// NOTE(review): pre-refactor dispatcher left behind by the diff rendering —
// its closing brace was lost and a traits-based dispatcher for the same
// symbol appears later in this file; confirm this copy is removed.
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
    dequantize_row_mxfp6_neon(x, y, k, sizeof(block_mxfp6),
        MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
        MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
#else
    dequantize_row_mxfp6_cpu_generic(x, y, k);
#endif
// SoA-layout MXFP6 → float dequantization: all qs blocks are contiguous,
// with the per-block E8M0 scales placed at MXFP_SOA_E8M0_OFFSET past them.
static void dequantize_row_mxfp6_soa_neon(
    const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
    const mxfp_neon_traits_t * t) {
    assert(k % QK_MXFP6 == 0);
    const int nb = k / QK_MXFP6;

    const char * base   = (const char *)src;
    const char * scales = base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);

    // Broadcast the format constants once.
    const uint32x4_t  v_exp_mask  = vdupq_n_u32(t->exp_mask);
    const uint32x4_t  v_mant_mask = vdupq_n_u32(t->mant_mask);
    const uint32x4_t  v_ieee_off  = vdupq_n_u32(t->ieee_exp_off);
    const float32x4_t v_sub_sc    = vdupq_n_f32(t->sub_scale);
    const int32x4_t   v_neg_exp   = vdupq_n_s32(-(int)t->exp_shift);
    const int32x4_t   v_mant_sh   = vdupq_n_s32(t->mant_shift);

    for (int ib = 0; ib < nb; ++ib) {
        const uint8_t * qs = (const uint8_t *)(base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
        const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)scales[ib]));
        float * out = y + ib * QK_MXFP6;
        // 32 FP6 elements per block: one packed group of 4 per iteration.
        for (int j = 0; j < 32; j += 4) {
            const uint32x4_t raw = unpack_fp6x4_neon(qs + (j * 3 / 4));
            const float32x4_t f = mxfp6_dequant_neon(raw,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
            vst1q_f32(out + j, vmulq_f32(f, v_scale));
        }
    }
}
// ---- MXFP SoA dequantize_row (to_float) — NEON-optimized ----
#if defined(__ARM_NEON)
static inline void dequantize_row_mxfp4_soa_neon(
// MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed.
static void dequantize_row_mxfp4_soa_neon(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
const char * row = (const char *)src;
const char * qs_base = row;
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
const uint8x16_t m4b = vdupq_n_u8(0x0f);
@ -4527,122 +4490,45 @@ static inline void dequantize_row_mxfp4_soa_neon(
}
}
#endif // __ARM_NEON

// ── Public dispatch functions ──────────────────────────────────────────────
// Public dispatch: MXFP8 (E4M3 traits) × Q8_0 dot product.
// Routes to the NEON kernel when available, otherwise the scalar reference.
// bs/bx/by strides and nrc are unused here (single-row path; nrc must be 1).
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__ARM_NEON)
    ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E4M3);
#else
    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// Public dispatch: MXFP6 (E2M3 traits) × Q8_0 dot product.
// Routes to the NEON kernel when available, otherwise the scalar reference.
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__ARM_NEON)
    ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E2M3);
#else
    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// Public dispatch: dequantize k MXFP8 (E4M3 traits) elements to float.
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
    dequantize_row_mxfp8_neon(x, y, k, &MXFP_TRAITS_E4M3);
#else
    dequantize_row_mxfp8_cpu_generic(x, y, k);
#endif
}
// Public dispatch: dequantize k MXFP6 (E2M3 traits) elements to float.
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
    dequantize_row_mxfp6_neon(x, y, k, &MXFP_TRAITS_E2M3);
#else
    dequantize_row_mxfp6_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
@ -4654,9 +4540,7 @@ void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES
// Public dispatch: MXFP8 (E4M3 traits) SoA row dequantization (flash attention path).
// Fix: the visible span carried diff residue — the stale six-parameter call to
// dequantize_row_mxfp8_soa_neon duplicated alongside the traits-based call, and the
// closing brace replaced by a hunk marker. Reconstructed to the intended final form.
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
    dequantize_row_mxfp8_soa_neon(x, y, k, &MXFP_TRAITS_E4M3);
#else
    dequantize_row_mxfp8_soa_cpu_generic(x, y, k);
#endif
}
// Public dispatch: MXFP6 (E2M3 traits) SoA row dequantization (flash attention path).
// Fix: the visible span carried diff residue — the stale six-parameter call to
// dequantize_row_mxfp6_soa_neon duplicated alongside the traits-based call, and the
// function's closing brace was lost at a hunk boundary. Reconstructed final form.
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__ARM_NEON)
    dequantize_row_mxfp6_soa_neon(x, y, k, &MXFP_TRAITS_E2M3);
#else
    dequantize_row_mxfp6_soa_cpu_generic(x, y, k);
#endif
}

View File

@ -3819,30 +3819,77 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#endif
}
// AVX2-optimized MXFP8 × Q8_0 dot product.
// Dequants FP8 elements to float via IEEE 754 bit construction, then dots against Q8_0.
// Parameters encode the FP8 format: exp_mask, mant_mask, exp_shift, ieee_exp_offset, mant_shift, sub_scale.
// ── MXFP FP8/FP6 AVX2 helpers ──────────────────────────────────────────────
// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot,
// dequantize_row, and SoA dequant functions.
#if defined(__AVX2__)
static inline void ggml_vec_dot_mxfp8_q8_0_avx2(
// Use shared mxfp_dequant_traits_t from ggml-common.h.
// Aliases for readability within this file.
#define mxfp_avx2_traits_t mxfp_dequant_traits_t
// Dequantize 8 raw MXFP values (widened to int32) → 8 IEEE-754 floats.
// Handles both normal and subnormal paths. Works for any FP6/FP8 format.
static inline __m256 mxfp_dequant_avx2(
    const __m256i v_raw,
    const __m256i v_exp_mask, const __m256i v_mant_mask,
    const __m256i v_ieee_off, const __m256 v_sub_sc,
    const __m256i v_sign_mask, const __m256i v_zero,
    int exp_shift, int sign_shift, int mant_shift) {
    // Split the raw encoding into sign, biased exponent, and mantissa fields.
    const __m256i sign = _mm256_and_si256(v_raw, v_sign_mask);
    const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
    const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
    // Normal path: rebase the exponent by v_ieee_off and splice the fields into
    // IEEE-754 single-precision layout (exponent at bits 30..23, mantissa left-aligned).
    const __m256i ieee = _mm256_or_si256(
        _mm256_or_si256(_mm256_slli_epi32(sign, sign_shift),
        _mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
        _mm256_slli_epi32(mant, mant_shift));
    const __m256 normal = _mm256_castsi256_ps(ieee);
    // Subnormal path (exp == 0): |val| = mant * sub_scale, then OR the sign bit back in.
    const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
    const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
        _mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, sign_shift)));
    // Per-lane blend: take the subnormal reconstruction where the exponent field is zero.
    const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
    return _mm256_blendv_ps(normal, sub_val, is_sub);
}
// Unpack 4 tightly-packed 6-bit values from 3 bytes into separate bytes.
// Unpack 4 tightly-packed 6-bit values from 3 bytes into separate bytes.
static inline void unpack_fp6x4_avx2(const uint8_t * p, uint8_t out[4]) {
    // Assemble the little-endian 24-bit word that holds the four 6-bit fields.
    uint32_t word = 0;
    for (int b = 2; b >= 0; --b) {
        word = (word << 8) | (uint32_t)p[b];
    }
    // Peel one 6-bit field per iteration, lowest bits first.
    for (int i = 0; i < 4; ++i) {
        out[i] = (uint8_t)(word & 0x3F);
        word >>= 6;
    }
}
// Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j.
// Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j.
static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) {
    // Each group of 4 elements occupies 3 bytes, so element j starts at byte j*3/4.
    uint8_t bytes[8];
    const uint8_t * lo_src = qs + (j * 3 / 4);
    const uint8_t * hi_src = qs + ((j + 4) * 3 / 4);
    unpack_fp6x4_avx2(lo_src, &bytes[0]);
    unpack_fp6x4_avx2(hi_src, &bytes[4]);
    // Widen the 8 unpacked bytes to 8 int32 lanes.
    return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)bytes));
}
// ── MXFP FP8/FP6 vec_dot ──────────────────────────────────────────────────
// Unified FP8 × Q8_0 dot product (works for E4M3 and E5M2).
// Unified FP8 × Q8_0 dot product (works for E4M3 and E5M2); format selected by
// the traits pointer. Fix: the visible span carried diff residue — the old
// six-parameter signature/comments, duplicated constant broadcasts and loads,
// the superseded inline dequant sequence, and an embedded hunk marker — all
// interleaved with the new traits-based code. Reconstructed the intended form
// that delegates bit reconstruction to mxfp_dequant_avx2.
static void ggml_vec_dot_mxfp8_q8_0_avx2(
    int n, float * GGML_RESTRICT s,
    const void * GGML_RESTRICT vx,
    const void * GGML_RESTRICT vy,
    const mxfp_avx2_traits_t * t) {
    assert(n % QK_MXFP8 == 0);
    const int nb = n / QK_MXFP8;
    const block_mxfp8 * GGML_RESTRICT x = vx;
    const block_q8_0 * GGML_RESTRICT y = vy;
    // Broadcast the format constants once per call.
    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
    const __m256i v_zero = _mm256_setzero_si256();
    __m256 acc = _mm256_setzero_ps();
    for (int ib = 0; ib < nb; ++ib) {
        // Combined per-block scale: MXFP E8M0 shared exponent × Q8_0 delta.
        const __m256 v_scale = _mm256_set1_ps(
            GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
        // Process 32 FP8 elements in 4 groups of 8.
        for (int j = 0; j < 32; j += 8) {
            // Widen 8 FP8 bytes to 8 int32 lanes.
            const __m256i v_raw = _mm256_cvtepu8_epi32(
                _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)));
            // Widen 8 Q8_0 int8 values to float.
            const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
                _mm_loadl_epi64((const __m128i *)(y[ib].qs + j))));
            const __m256 val = mxfp_dequant_avx2(v_raw,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
            // Accumulate val * scale * q8.
            acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc);
        }
    }
    *s = hsum_float_8(acc);
}
#endif
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__AVX2__)
// E4M3: sign(1) exp(4) mant(3), bias=7
ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy,
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
#else
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// AVX2-optimized MXFP6 × Q8_0 dot product.
// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float.
#if defined(__AVX2__)
// Unified FP6 × Q8_0 dot product (works for E2M3 and E3M2); format selected by
// the traits pointer. Fix: the visible span carried diff residue — both old
// (block_size + six format parameters, inline unpack, inline dequant) and new
// signatures/bodies interleaved, plus an embedded hunk marker. Reconstructed
// the intended form that uses unpack_fp6x8_avx2 and mxfp_dequant_avx2.
static void ggml_vec_dot_mxfp6_q8_0_avx2(
    int n, float * GGML_RESTRICT s,
    const void * GGML_RESTRICT vx,
    const void * GGML_RESTRICT vy,
    const mxfp_avx2_traits_t * t) {
    assert(n % QK_MXFP6 == 0);
    const int nb = n / QK_MXFP6;
    const block_q8_0 * GGML_RESTRICT y = vy;
    // Broadcast the format constants once per call.
    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
    const __m256i v_zero = _mm256_setzero_si256();
    __m256 acc = _mm256_setzero_ps();
    for (int ib = 0; ib < nb; ++ib) {
        const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
        // Combined per-block scale: MXFP E8M0 shared exponent × Q8_0 delta.
        const __m256 v_scale = _mm256_set1_ps(
            GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
        // Process 32 FP6 elements in 4 groups of 8 (each group = two 3-byte packs).
        for (int j = 0; j < 32; j += 8) {
            const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j);
            // Widen 8 Q8_0 int8 values to float.
            const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
                _mm_loadl_epi64((const __m128i *)(y[ib].qs + j))));
            const __m256 val = mxfp_dequant_avx2(v_raw,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
            // Accumulate val * scale * q8.
            acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc);
        }
    }
    *s = hsum_float_8(acc);
}
#endif
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__AVX2__)
// E2M3: sign(1) exp(2) mant(3), bias=1
ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, sizeof(block_mxfp6),
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
#else
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// ── MXFP FP8/FP6 dequantize_row (AoS) ─────────────────────────────────────
// ---- MXFP dequantize_row (to_float) — AVX2-optimized ----
// Extracts the SIMD dequant logic from vec_dot above, writing floats to output buffer
// instead of accumulating a dot product.
#if defined(__AVX2__)
// Dequantize k MXFP8 elements (AoS block_mxfp8 layout) into floats; format
// selected by the traits pointer. Fix: the visible span carried diff residue —
// both the old six-parameter signature and the traits signature, duplicated
// constant broadcasts, and the superseded inline dequant sequence alongside the
// mxfp_dequant_avx2 call. Reconstructed the intended traits-based form.
static void dequantize_row_mxfp8_avx2(
    const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
    const mxfp_avx2_traits_t * t) {
    assert(k % QK_MXFP8 == 0);
    const int nb = k / QK_MXFP8;
    const block_mxfp8 * GGML_RESTRICT x = vx;
    // Broadcast the format constants once per call.
    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
    const __m256i v_zero = _mm256_setzero_si256();
    for (int ib = 0; ib < nb; ++ib) {
        // Per-block shared E8M0 scale.
        const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(x[ib].e));
        for (int j = 0; j < 32; j += 8) {
            // Widen 8 FP8 bytes to 8 int32 lanes and reconstruct floats.
            const __m256i v_raw = _mm256_cvtepu8_epi32(
                _mm_loadl_epi64((const __m128i *)(x[ib].qs + j)));
            const __m256 val = mxfp_dequant_avx2(v_raw,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
            _mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale));
        }
    }
}
// Dequantize k MXFP6 elements (AoS block_mxfp6 layout) into floats; format
// selected by the traits pointer. Fix: the visible span carried diff residue —
// old signature (with block_size + six format parameters), inline 6-bit unpack,
// and inline dequant interleaved with the new traits-based code. Reconstructed
// the intended form using unpack_fp6x8_avx2 and mxfp_dequant_avx2.
static void dequantize_row_mxfp6_avx2(
    const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
    const mxfp_avx2_traits_t * t) {
    assert(k % QK_MXFP6 == 0);
    const int nb = k / QK_MXFP6;
    // Broadcast the format constants once per call.
    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
    const __m256i v_zero = _mm256_setzero_si256();
    for (int ib = 0; ib < nb; ++ib) {
        const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
        // Per-block shared E8M0 scale.
        const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(xb->e));
        for (int j = 0; j < 32; j += 8) {
            const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j);
            const __m256 val = mxfp_dequant_avx2(v_raw,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
            _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale));
        }
    }
}
#endif
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
dequantize_row_mxfp8_avx2(x, y, k,
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
#else
dequantize_row_mxfp8_cpu_generic(x, y, k);
#endif
// ── MXFP SoA dequant (flash attention) ─────────────────────────────────────
// Dequantize an MXFP8 SoA row (contiguous qs region followed by per-block
// E8M0 scales) into k floats; format selected by the traits pointer.
static void dequantize_row_mxfp8_soa_avx2(
    const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
    const mxfp_avx2_traits_t * t) {
    assert(k % QK_MXFP8 == 0);
    const int nb = k / QK_MXFP8;
    // SoA layout: all qs blocks first, then one E8M0 scale byte per block.
    const char * qs_base = (const char *)src;
    const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
    // Broadcast the format constants once per call.
    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
    const __m256i v_zero = _mm256_setzero_si256();
    for (int ib = 0; ib < nb; ++ib) {
        // Per-block shared scale decoded from the E8M0 region.
        const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
        const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
        for (int j = 0; j < 32; j += 8) {
            // Widen 8 FP8 bytes to 8 int32 lanes and reconstruct floats.
            const __m256i v_raw = _mm256_cvtepu8_epi32(
                _mm_loadl_epi64((const __m128i *)(qs + j)));
            const __m256 val = mxfp_dequant_avx2(v_raw,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
            _mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale));
        }
    }
}
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
dequantize_row_mxfp6_avx2(x, y, k, sizeof(block_mxfp6),
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
#else
dequantize_row_mxfp6_cpu_generic(x, y, k);
#endif
// Dequantize an MXFP6 SoA row (contiguous packed qs region followed by
// per-block E8M0 scales) into k floats; format selected by the traits pointer.
static void dequantize_row_mxfp6_soa_avx2(
    const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
    const mxfp_avx2_traits_t * t) {
    assert(k % QK_MXFP6 == 0);
    const int nb = k / QK_MXFP6;
    // SoA layout: all qs blocks first, then one E8M0 scale byte per block.
    const char * qs_base = (const char *)src;
    const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
    // Broadcast the format constants once per call.
    const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
    const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
    const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
    const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
    const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
    const __m256i v_zero = _mm256_setzero_si256();
    for (int ib = 0; ib < nb; ++ib) {
        // Per-block shared scale decoded from the E8M0 region.
        const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
        const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
        for (int j = 0; j < 32; j += 8) {
            // Unpack 8 six-bit values, then reconstruct floats.
            const __m256i v_raw = unpack_fp6x8_avx2(qs, j);
            const __m256 val = mxfp_dequant_avx2(v_raw,
                v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
                v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
            _mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale));
        }
    }
}
// SoA dequant for flash attention — contiguous qs region + separate e8m0 region
#if defined(__AVX2__)
static inline void dequantize_row_mxfp4_soa_avx2(
// MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed.
static void dequantize_row_mxfp4_soa_avx2(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
assert(k % QK_MXFP4 == 0);
const int nb = k / QK_MXFP4;
const char * row = (const char *)src;
const char * qs_base = row;
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
const char * qs_base = (const char *)src;
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
const __m128i m4b = _mm_set1_epi8(0x0f);
@ -4163,13 +4102,11 @@ static inline void dequantize_row_mxfp4_soa_avx2(
const __m128i lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits, m4b));
const __m128i hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4b));
// lo nibbles → first 16 floats
const __m256i lo32_0 = _mm256_cvtepi8_epi32(lo);
const __m256i lo32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(lo, 8));
_mm256_storeu_ps(y + i * QK_MXFP4 + 0, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_0), v_scale));
_mm256_storeu_ps(y + i * QK_MXFP4 + 8, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_1), v_scale));
// hi nibbles → second 16 floats
const __m256i hi32_0 = _mm256_cvtepi8_epi32(hi);
const __m256i hi32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(hi, 8));
_mm256_storeu_ps(y + i * QK_MXFP4 + 16, _mm256_mul_ps(_mm256_cvtepi32_ps(hi32_0), v_scale));
@ -4177,116 +4114,45 @@ static inline void dequantize_row_mxfp4_soa_avx2(
}
}
static inline void dequantize_row_mxfp8_soa_avx2(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
const int exp_mask, const int mant_mask, const int exp_shift,
const int ieee_exp_off, const int mant_shift, const float sub_scale) {
assert(k % QK_MXFP8 == 0);
const int nb = k / QK_MXFP8;
const char * row = (const char *)src;
const char * qs_base = row;
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
#endif // __AVX2__
const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask);
const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask);
const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off);
const __m256 v_sub_sc = _mm256_set1_ps(sub_scale);
const __m256i v_zero = _mm256_setzero_si256();
// ── Public dispatch functions ──────────────────────────────────────────────
for (int ib = 0; ib < nb; ++ib) {
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
for (int j = 0; j < 32; j += 8) {
const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(qs + j));
const __m256i v_raw = _mm256_cvtepu8_epi32(raw8);
const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80));
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
const __m256i ieee = _mm256_or_si256(
_mm256_or_si256(_mm256_slli_epi32(sign, 24),
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
_mm256_slli_epi32(mant, mant_shift));
const __m256 normal = _mm256_castsi256_ps(ieee);
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24)));
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub);
_mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale));
}
}
}
static inline void dequantize_row_mxfp6_soa_avx2(
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
const int exp_mask, const int mant_mask, const int exp_shift,
const int ieee_exp_off, const int mant_shift, const float sub_scale) {
assert(k % QK_MXFP6 == 0);
const int nb = k / QK_MXFP6;
const char * row = (const char *)src;
const char * qs_base = row;
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask);
const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask);
const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off);
const __m256 v_sub_sc = _mm256_set1_ps(sub_scale);
const __m256i v_zero = _mm256_setzero_si256();
for (int ib = 0; ib < nb; ++ib) {
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
for (int j = 0; j < 32; j += 8) {
uint8_t unpacked[8];
{
const uint8_t * p = qs + (j * 3 / 4);
const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
unpacked[0] = (pk0 >> 0) & 0x3F;
unpacked[1] = (pk0 >> 6) & 0x3F;
unpacked[2] = (pk0 >> 12) & 0x3F;
unpacked[3] = (pk0 >> 18) & 0x3F;
}
{
const uint8_t * p = qs + ((j + 4) * 3 / 4);
const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
unpacked[4] = (pk1 >> 0) & 0x3F;
unpacked[5] = (pk1 >> 6) & 0x3F;
unpacked[6] = (pk1 >> 12) & 0x3F;
unpacked[7] = (pk1 >> 18) & 0x3F;
}
const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked);
const __m256i v_raw = _mm256_cvtepu8_epi32(raw8);
const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20));
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
const __m256i ieee = _mm256_or_si256(
_mm256_or_si256(_mm256_slli_epi32(sign, 26),
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
_mm256_slli_epi32(mant, mant_shift));
const __m256 normal = _mm256_castsi256_ps(ieee);
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26)));
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub);
_mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale));
}
}
}
// Public dispatch: MXFP8 (E4M3 traits) × Q8_0 dot product.
// Routes to the AVX2 kernel when available, otherwise the scalar reference.
// bs/bx/by strides and nrc are unused here (single-row path; nrc must be 1).
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__AVX2__)
    ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E4M3);
#else
    ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// Public dispatch: MXFP6 (E2M3 traits) × Q8_0 dot product.
// Routes to the AVX2 kernel when available, otherwise the scalar reference.
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(nrc == 1);
    UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
#if defined(__AVX2__)
    ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E2M3);
#else
    ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
// Public dispatch: dequantize k MXFP8 (E4M3 traits) elements to float.
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
    dequantize_row_mxfp8_avx2(x, y, k, &MXFP_TRAITS_E4M3);
#else
    dequantize_row_mxfp8_cpu_generic(x, y, k);
#endif
}
// Public dispatch: dequantize k MXFP6 (E2M3 traits) elements to float.
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
    dequantize_row_mxfp6_avx2(x, y, k, &MXFP_TRAITS_E2M3);
#else
    dequantize_row_mxfp6_cpu_generic(x, y, k);
#endif
}
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
@ -4298,9 +4164,7 @@ void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES
// Public dispatch: MXFP8 (E4M3 traits) SoA row dequantization (flash attention path).
// Fix: the visible span carried diff residue — the stale six-parameter call to
// dequantize_row_mxfp8_soa_avx2 duplicated alongside the traits-based call, and
// the closing brace replaced by a hunk marker. Reconstructed intended final form.
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
    dequantize_row_mxfp8_soa_avx2(x, y, k, &MXFP_TRAITS_E4M3);
#else
    dequantize_row_mxfp8_soa_cpu_generic(x, y, k);
#endif
}
// Public dispatch: MXFP6 (E2M3 traits) SoA row dequantization (flash attention path).
// Fix: the visible span carried diff residue — the stale six-parameter call to
// dequantize_row_mxfp6_soa_avx2 duplicated alongside the traits-based call, and
// the function's closing brace lost at a hunk boundary. Reconstructed final form.
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
#if defined(__AVX2__)
    dequantize_row_mxfp6_soa_avx2(x, y, k, &MXFP_TRAITS_E2M3);
#else
    dequantize_row_mxfp6_soa_cpu_generic(x, y, k);
#endif
}

View File

@ -8297,8 +8297,8 @@ static mxfp_fa_params mxfp_fa_params_init(
if (is_mxfp_k) {
switch (k->type) {
case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break;
case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break;
case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break;
case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break;
case GGML_TYPE_MXFP6_E2M3: p.q_quantize = quantize_row_mxfp6_soa; p.k_dequantize = dequantize_row_mxfp6_soa_cpu; break;
default: GGML_ABORT("unsupported MXFP K type");
}
@ -8306,8 +8306,8 @@ static mxfp_fa_params mxfp_fa_params_init(
if (is_mxfp_v) {
switch (v->type) {
case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break;
case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break;
case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break;
case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break;
case GGML_TYPE_MXFP6_E2M3: p.v_dequantize = dequantize_row_mxfp6_soa_cpu; break;
default: GGML_ABORT("unsupported MXFP V type");
}
@ -8328,6 +8328,7 @@ static mxfp_fa_params mxfp_fa_params_init(
// Per-head SoA addressing for multihead mode.
// Precompute byte offsets so the hot loop can skip per-head pointer math.
// qs_per_block values from centralized MXFP_QS_PER_BLOCK_* defines in ggml-common.h.
auto mxfp_qs_per_block = [](ggml_type type) -> int {
switch (type) {
case GGML_TYPE_MXFP4_E2M1: return MXFP4_SOA_QS_PER_BLOCK;
@ -8341,7 +8342,6 @@ static mxfp_fa_params mxfp_fa_params_init(
p.k_qs_per_block = mxfp_qs_per_block(k->type);
p.k_blocks_per_head = (int)(DK / 32);
p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block;
// e8m0 offset from row start = total_blocks * qs_per_block
const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head;
p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block;
}

View File

@ -61,7 +61,7 @@ GGML_API void quantize_row_mxfp4_soa (const float * GGML_RESTRICT x, void * G
GGML_API void dequantize_row_mxfp4_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp8_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp8_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void quantize_row_mxfp6_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);