ggml: refactor x86 AVX2 and ARM NEON MXFP dequant — shared traits and helpers
Add mxfp_dequant_traits_t to ggml-common.h as single source of truth for MXFP IEEE-754 reconstruction parameters. Define static const instances for all 4 formats (E4M3, E5M2, E2M3, E3M2), ready for CUDA/Metal/Vulkan reuse. Extract shared dequant and FP6 unpack helpers on both architectures, replacing duplicated inline code and macros. Net -215 lines.
This commit is contained in:
parent
c913ab36d2
commit
b8e8d291d1
|
|
@ -267,6 +267,47 @@ static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4
|
|||
#define MXFP6_E3M2_MANT_SHIFT 21 // 23-2
|
||||
#define MXFP6_E3M2_SUB_SCALE (1.0f/16.0f) // 2^(-4) = 2^(1-3-2)
|
||||
|
||||
// Unified MXFP dequantization traits for SIMD backends (CPU x86/ARM, CUDA, Metal, Vulkan).
|
||||
// Contains all parameters needed for IEEE-754 bit reconstruction of FP8/FP6 elements.
|
||||
// FP4 uses LUT-based dequant and does not need this struct.
|
||||
typedef struct {
|
||||
int exp_mask; // (1<<E)-1: exponent field mask
|
||||
int mant_mask; // (1<<M)-1: mantissa field mask
|
||||
int exp_shift; // M: right-shift to extract exponent
|
||||
int ieee_exp_off; // 127-bias: offset to convert to IEEE exponent
|
||||
int mant_shift; // 23-M: left-shift to align mantissa in IEEE float
|
||||
float sub_scale; // 2^(1-bias-M): subnormal scale factor
|
||||
int sign_mask; // 0x80 for 8-bit, 0x20 for 6-bit formats
|
||||
int sign_shift; // 24 for 8-bit, 26 for 6-bit formats
|
||||
int qs_per_block; // bytes of quantized data per 32-element block
|
||||
int emax_offset; // type-specific offset for E8M0 MSE search
|
||||
} mxfp_dequant_traits_t;
|
||||
|
||||
// Static const trait instances for each MXFP format.
|
||||
// Gated by GGML_COMMON_IMPL to ensure single definition per translation unit.
|
||||
#if defined(GGML_COMMON_IMPL)
|
||||
static const mxfp_dequant_traits_t MXFP_TRAITS_E4M3 = {
|
||||
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
|
||||
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE,
|
||||
0x80, 24, MXFP_QS_PER_BLOCK_E4M3, MXFP8_E4M3_EMAX_OFFSET
|
||||
};
|
||||
static const mxfp_dequant_traits_t MXFP_TRAITS_E5M2 = {
|
||||
MXFP8_E5M2_EXP_MASK, MXFP8_E5M2_MANT_MASK, MXFP8_E5M2_EXP_SHIFT,
|
||||
MXFP8_E5M2_IEEE_EXP_OFF, MXFP8_E5M2_MANT_SHIFT, MXFP8_E5M2_SUB_SCALE,
|
||||
0x80, 24, MXFP_QS_PER_BLOCK_E5M2, MXFP8_E5M2_EMAX_OFFSET
|
||||
};
|
||||
static const mxfp_dequant_traits_t MXFP_TRAITS_E2M3 = {
|
||||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE,
|
||||
0x20, 26, MXFP_QS_PER_BLOCK_E2M3, MXFP6_E2M3_EMAX_OFFSET
|
||||
};
|
||||
static const mxfp_dequant_traits_t MXFP_TRAITS_E3M2 = {
|
||||
MXFP6_E3M2_EXP_MASK, MXFP6_E3M2_MANT_MASK, MXFP6_E3M2_EXP_SHIFT,
|
||||
MXFP6_E3M2_IEEE_EXP_OFF, MXFP6_E3M2_MANT_SHIFT, MXFP6_E3M2_SUB_SCALE,
|
||||
0x20, 26, MXFP_QS_PER_BLOCK_E3M2, MXFP6_E3M2_EMAX_OFFSET
|
||||
};
|
||||
#endif // GGML_COMMON_IMPL
|
||||
|
||||
#define QK_MXFP4 32
|
||||
typedef struct {
|
||||
uint8_t e; // E8M0
|
||||
|
|
|
|||
|
|
@ -4134,363 +4134,326 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
#endif
|
||||
}
|
||||
|
||||
// NEON-optimized MXFP8 × Q8_0 dot product.
|
||||
// Dequants FP8 elements to float via IEEE 754 bit construction, then dots against Q8_0.
|
||||
// Parameters encode the FP8 format: sign_shift, exp_mask, mant_mask, ieee_exp_bias, mant_shift, sub_scale.
|
||||
// ── MXFP FP8/FP6 NEON helpers ──────────────────────────────────────────────
|
||||
// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot,
|
||||
// dequantize_row, and SoA dequant functions.
|
||||
//
|
||||
// NEON requires vshlq_n_u32 to have a compile-time literal constant, so we use
|
||||
// two separate helpers for FP8 (sign at bit 7, shift 24) and FP6 (sign at bit 5,
|
||||
// shift 26) rather than a single parameterized function.
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
static inline void ggml_vec_dot_mxfp8_q8_0_neon(
|
||||
|
||||
// Use shared mxfp_dequant_traits_t from ggml-common.h.
|
||||
#define mxfp_neon_traits_t mxfp_dequant_traits_t
|
||||
|
||||
// Dequantize 4 raw FP8 values (uint32x4_t) → 4 IEEE-754 floats.
|
||||
// Sign bit at position 7, sign shift = 24.
|
||||
static inline float32x4_t mxfp8_dequant_neon(
|
||||
const uint32x4_t v_raw,
|
||||
const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask,
|
||||
const uint32x4_t v_ieee_off, const float32x4_t v_sub_sc,
|
||||
const int32x4_t v_neg_exp_shift, const int32x4_t v_mant_shift) {
|
||||
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80));
|
||||
const uint32x4_t exp = vandq_u32(vshlq_u32(v_raw, v_neg_exp_shift), v_exp_mask);
|
||||
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask);
|
||||
|
||||
const uint32x4_t ieee = vorrq_u32(
|
||||
vorrq_u32(vshlq_n_u32(sign, 24),
|
||||
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)),
|
||||
vshlq_u32(mant, v_mant_shift));
|
||||
const float32x4_t normal = vreinterpretq_f32_u32(ieee);
|
||||
|
||||
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc);
|
||||
const float32x4_t sub_val = vreinterpretq_f32_u32(
|
||||
vorrq_u32(vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)));
|
||||
|
||||
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0));
|
||||
return vbslq_f32(is_sub, sub_val, normal);
|
||||
}
|
||||
|
||||
// Dequantize 4 raw FP6 values (uint32x4_t) → 4 IEEE-754 floats.
|
||||
// Sign bit at position 5, sign shift = 26.
|
||||
static inline float32x4_t mxfp6_dequant_neon(
|
||||
const uint32x4_t v_raw,
|
||||
const uint32x4_t v_exp_mask, const uint32x4_t v_mant_mask,
|
||||
const uint32x4_t v_ieee_off, const float32x4_t v_sub_sc,
|
||||
const int32x4_t v_neg_exp_shift, const int32x4_t v_mant_shift) {
|
||||
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20));
|
||||
const uint32x4_t exp = vandq_u32(vshlq_u32(v_raw, v_neg_exp_shift), v_exp_mask);
|
||||
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask);
|
||||
|
||||
const uint32x4_t ieee = vorrq_u32(
|
||||
vorrq_u32(vshlq_n_u32(sign, 26),
|
||||
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)),
|
||||
vshlq_u32(mant, v_mant_shift));
|
||||
const float32x4_t normal = vreinterpretq_f32_u32(ieee);
|
||||
|
||||
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc);
|
||||
const float32x4_t sub_val = vreinterpretq_f32_u32(
|
||||
vorrq_u32(vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)));
|
||||
|
||||
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0));
|
||||
return vbslq_f32(is_sub, sub_val, normal);
|
||||
}
|
||||
|
||||
// Unpack 4 tightly-packed 6-bit values from 3 bytes, widen to uint32x4_t.
|
||||
static inline uint32x4_t unpack_fp6x4_neon(const uint8_t * p) {
|
||||
const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
uint8_t u[4];
|
||||
u[0] = (pk >> 0) & 0x3F;
|
||||
u[1] = (pk >> 6) & 0x3F;
|
||||
u[2] = (pk >> 12) & 0x3F;
|
||||
u[3] = (pk >> 18) & 0x3F;
|
||||
const uint8x8_t raw8 = vcreate_u8(
|
||||
(uint64_t)u[0] | ((uint64_t)u[1] << 8) |
|
||||
((uint64_t)u[2] << 16) | ((uint64_t)u[3] << 24));
|
||||
return vmovl_u16(vget_low_u16(vmovl_u8(raw8)));
|
||||
}
|
||||
|
||||
// Widen 8 raw bytes to two uint32x4_t halves.
|
||||
static inline void widen_u8x8_to_u32x4x2(const uint8_t * src,
|
||||
uint32x4_t * lo, uint32x4_t * hi) {
|
||||
const uint8x8_t raw8 = vld1_u8(src);
|
||||
const uint16x8_t raw16 = vmovl_u8(raw8);
|
||||
*lo = vmovl_u16(vget_low_u16(raw16));
|
||||
*hi = vmovl_u16(vget_high_u16(raw16));
|
||||
}
|
||||
|
||||
// Widen 8 Q8_0 int8 values to two float32x4_t halves.
|
||||
static inline void widen_s8x8_to_f32x4x2(const int8_t * src,
|
||||
float32x4_t * lo, float32x4_t * hi) {
|
||||
const int8x8_t q8 = vld1_s8(src);
|
||||
const int16x8_t q16 = vmovl_s8(q8);
|
||||
*lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16)));
|
||||
*hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16)));
|
||||
}
|
||||
|
||||
// ── MXFP FP8/FP6 vec_dot ──────────────────────────────────────────────────
|
||||
|
||||
static void ggml_vec_dot_mxfp8_q8_0_neon(
|
||||
int n, float * GGML_RESTRICT s,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
// FP8 format parameters:
|
||||
const uint32_t exp_mask, // 0xF for E4M3, 0x1F for E5M2
|
||||
const uint32_t mant_mask, // 0x7 for E4M3, 0x3 for E5M2
|
||||
const int exp_shift, // 3 for E4M3, 2 for E5M2
|
||||
const uint32_t ieee_exp_off, // 120 for E4M3, 112 for E5M2
|
||||
const int mant_shift, // 20 for E4M3, 21 for E5M2
|
||||
const float sub_scale) { // 1/512 for E4M3, 1/65536 for E5M2
|
||||
const mxfp_neon_traits_t * t) {
|
||||
assert(n % QK_MXFP8 == 0);
|
||||
const int nb = n / QK_MXFP8;
|
||||
const block_mxfp8 * GGML_RESTRICT x = vx;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
|
||||
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
|
||||
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
|
||||
|
||||
float32x4_t acc0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t acc1 = vdupq_n_f32(0.0f);
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale);
|
||||
// Use variable shifts (vshlq_u32) instead of constant shifts (vshlq_n_u32)
|
||||
// because exp_shift/mant_shift are function parameters, not compile-time constants.
|
||||
// Clang requires _n_ intrinsics to have literal constant arguments.
|
||||
const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift);
|
||||
const int32x4_t v_mant_shift = vdupq_n_s32(mant_shift);
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const float scale = GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d);
|
||||
const float32x4_t v_scale = vdupq_n_f32(scale);
|
||||
const float32x4_t v_scale = vdupq_n_f32(
|
||||
GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
// Process 32 FP8 elements in 8 groups of 4
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
// Load 8 FP8 bytes, extend to two uint32x4_t
|
||||
const uint8x8_t raw8 = vld1_u8(x[ib].qs + j);
|
||||
const uint16x8_t raw16 = vmovl_u8(raw8);
|
||||
const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16));
|
||||
const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16));
|
||||
uint32x4_t v_lo, v_hi;
|
||||
widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi);
|
||||
|
||||
// Load 8 Q8_0 int8 values, extend to two int32x4_t → float32x4_t
|
||||
const int8x8_t q8 = vld1_s8(y[ib].qs + j);
|
||||
const int16x8_t q16 = vmovl_s8(q8);
|
||||
const float32x4_t qf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16)));
|
||||
const float32x4_t qf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16)));
|
||||
float32x4_t qf_lo, qf_hi;
|
||||
widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi);
|
||||
|
||||
// Dequant FP8 → float for both groups of 4
|
||||
#define DEQUANT_FP8_NEON(v_raw, qf, acc) do { \
|
||||
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \
|
||||
const uint32x4_t exp = vandq_u32( \
|
||||
vshlq_u32(v_raw, v_neg_exp_shift), \
|
||||
v_exp_mask); \
|
||||
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \
|
||||
/* Normal: IEEE bits = (exp + offset) << 23 | mant << mant_shift */ \
|
||||
const uint32x4_t ieee = vorrq_u32( \
|
||||
vorrq_u32(vshlq_n_u32(sign, 24), \
|
||||
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \
|
||||
vshlq_u32(mant, v_mant_shift)); \
|
||||
const float32x4_t normal = vreinterpretq_f32_u32(ieee); \
|
||||
/* Subnormal: sign * mant * sub_scale */ \
|
||||
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \
|
||||
const uint32x4_t sub_bits = vorrq_u32( \
|
||||
vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \
|
||||
const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \
|
||||
/* Select: subnormal when exp == 0, else normal */ \
|
||||
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \
|
||||
const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \
|
||||
/* Multiply by scale and Q8 value, accumulate */ \
|
||||
(acc) = vfmaq_f32((acc), vmulq_f32(val, v_scale), qf); \
|
||||
} while (0)
|
||||
const float32x4_t val_lo = mxfp8_dequant_neon(v_lo,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
const float32x4_t val_hi = mxfp8_dequant_neon(v_hi,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
|
||||
DEQUANT_FP8_NEON(v_lo, qf_lo, acc0);
|
||||
DEQUANT_FP8_NEON(v_hi, qf_hi, acc1);
|
||||
#undef DEQUANT_FP8_NEON
|
||||
acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo);
|
||||
acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi);
|
||||
}
|
||||
}
|
||||
|
||||
*s = vaddvq_f32(vaddq_f32(acc0, acc1));
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__ARM_NEON)
|
||||
// E4M3: sign(1) exp(4) mant(3), bias=7
|
||||
ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy,
|
||||
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
|
||||
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
|
||||
#else
|
||||
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
// NEON-optimized MXFP6 × Q8_0 dot product.
|
||||
// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float.
|
||||
#if defined(__ARM_NEON)
|
||||
static inline void ggml_vec_dot_mxfp6_q8_0_neon(
|
||||
static void ggml_vec_dot_mxfp6_q8_0_neon(
|
||||
int n, float * GGML_RESTRICT s,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
size_t block_size,
|
||||
// FP6 format parameters:
|
||||
const uint32_t exp_mask, // 0x3 for E2M3, 0x7 for E3M2
|
||||
const uint32_t mant_mask, // 0x7 for E2M3, 0x3 for E3M2
|
||||
const int exp_shift, // 3 for E2M3, 2 for E3M2
|
||||
const uint32_t ieee_exp_off, // 126 for E2M3, 124 for E3M2
|
||||
const int mant_shift, // 20 for E2M3, 21 for E3M2
|
||||
const float sub_scale) { // 1/8 for E2M3, 1/16 for E3M2
|
||||
const mxfp_neon_traits_t * t) {
|
||||
assert(n % QK_MXFP6 == 0);
|
||||
const int nb = n / QK_MXFP6;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
|
||||
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
|
||||
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
|
||||
|
||||
float32x4_t acc0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t acc1 = vdupq_n_f32(0.0f);
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale);
|
||||
const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift);
|
||||
const int32x4_t v_mant_shift = vdupq_n_s32(mant_shift);
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size);
|
||||
const float scale = GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d);
|
||||
const float32x4_t v_scale = vdupq_n_f32(scale);
|
||||
const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
|
||||
const float32x4_t v_scale = vdupq_n_f32(
|
||||
GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
// Process 32 FP6 elements: 8 groups of 4, each packed in 3 bytes
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
// Unpack two groups of 4 FP6 values (6 bytes → 8 values)
|
||||
uint8_t unpacked[8];
|
||||
// Group 1: 3 bytes → 4 values
|
||||
{
|
||||
const uint8_t * p = xb->qs + (j * 3 / 4);
|
||||
const uint32_t packed = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
unpacked[0] = (packed >> 0) & 0x3F;
|
||||
unpacked[1] = (packed >> 6) & 0x3F;
|
||||
unpacked[2] = (packed >> 12) & 0x3F;
|
||||
unpacked[3] = (packed >> 18) & 0x3F;
|
||||
}
|
||||
// Group 2: next 3 bytes → 4 values
|
||||
{
|
||||
const uint8_t * p = xb->qs + ((j + 4) * 3 / 4);
|
||||
const uint32_t packed = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
unpacked[4] = (packed >> 0) & 0x3F;
|
||||
unpacked[5] = (packed >> 6) & 0x3F;
|
||||
unpacked[6] = (packed >> 12) & 0x3F;
|
||||
unpacked[7] = (packed >> 18) & 0x3F;
|
||||
}
|
||||
const uint32x4_t v_lo = unpack_fp6x4_neon(xb->qs + (j * 3 / 4));
|
||||
const uint32x4_t v_hi = unpack_fp6x4_neon(xb->qs + ((j + 4) * 3 / 4));
|
||||
|
||||
// Extend to uint32x4_t
|
||||
const uint8x8_t raw8 = vld1_u8(unpacked);
|
||||
const uint16x8_t raw16 = vmovl_u8(raw8);
|
||||
const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16));
|
||||
const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16));
|
||||
float32x4_t qf_lo, qf_hi;
|
||||
widen_s8x8_to_f32x4x2(y[ib].qs + j, &qf_lo, &qf_hi);
|
||||
|
||||
// Load Q8_0 int8 values
|
||||
const int8x8_t q8 = vld1_s8(y[ib].qs + j);
|
||||
const int16x8_t q16 = vmovl_s8(q8);
|
||||
const float32x4_t qf_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(q16)));
|
||||
const float32x4_t qf_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16)));
|
||||
const float32x4_t val_lo = mxfp6_dequant_neon(v_lo,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
const float32x4_t val_hi = mxfp6_dequant_neon(v_hi,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
|
||||
// Dequant FP6 → float (same IEEE construction as FP8, sign bit at position 5)
|
||||
#define DEQUANT_FP6_NEON(v_raw, qf, acc) do { \
|
||||
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20)); \
|
||||
const uint32x4_t exp = vandq_u32( \
|
||||
vshlq_u32(v_raw, v_neg_exp_shift), \
|
||||
v_exp_mask); \
|
||||
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \
|
||||
const uint32x4_t ieee = vorrq_u32( \
|
||||
vorrq_u32(vshlq_n_u32(sign, 26), \
|
||||
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \
|
||||
vshlq_u32(mant, v_mant_shift)); \
|
||||
const float32x4_t normal = vreinterpretq_f32_u32(ieee); \
|
||||
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \
|
||||
const uint32x4_t sub_bits = vorrq_u32( \
|
||||
vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26)); \
|
||||
const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \
|
||||
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \
|
||||
const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \
|
||||
(acc) = vfmaq_f32((acc), vmulq_f32(val, v_scale), qf); \
|
||||
} while (0)
|
||||
|
||||
DEQUANT_FP6_NEON(v_lo, qf_lo, acc0);
|
||||
DEQUANT_FP6_NEON(v_hi, qf_hi, acc1);
|
||||
#undef DEQUANT_FP6_NEON
|
||||
acc0 = vfmaq_f32(acc0, vmulq_f32(val_lo, v_scale), qf_lo);
|
||||
acc1 = vfmaq_f32(acc1, vmulq_f32(val_hi, v_scale), qf_hi);
|
||||
}
|
||||
}
|
||||
|
||||
*s = vaddvq_f32(vaddq_f32(acc0, acc1));
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__ARM_NEON)
|
||||
// E2M3: sign(1) exp(2) mant(3), bias=1
|
||||
ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, sizeof(block_mxfp6),
|
||||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
#else
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
// ── MXFP FP8/FP6 dequantize_row (AoS) ─────────────────────────────────────
|
||||
|
||||
// ---- MXFP dequantize_row (to_float) — NEON-optimized ----
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
static inline void dequantize_row_mxfp8_neon(
|
||||
static void dequantize_row_mxfp8_neon(
|
||||
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
|
||||
const uint32_t exp_mask, const uint32_t mant_mask,
|
||||
const int exp_shift, const uint32_t ieee_exp_off,
|
||||
const int mant_shift, const float sub_scale) {
|
||||
const mxfp_neon_traits_t * t) {
|
||||
assert(k % QK_MXFP8 == 0);
|
||||
const int nb = k / QK_MXFP8;
|
||||
const block_mxfp8 * GGML_RESTRICT x = vx;
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale);
|
||||
const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift);
|
||||
const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift);
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
|
||||
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
|
||||
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(x[ib].e));
|
||||
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
const uint8x8_t raw8 = vld1_u8(x[ib].qs + j);
|
||||
const uint16x8_t raw16 = vmovl_u8(raw8);
|
||||
const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16));
|
||||
const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16));
|
||||
uint32x4_t v_lo, v_hi;
|
||||
widen_u8x8_to_u32x4x2(x[ib].qs + j, &v_lo, &v_hi);
|
||||
|
||||
#define DEQUANT_FP8_STORE(v_raw, dst) do { \
|
||||
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \
|
||||
const uint32x4_t exp = vandq_u32( \
|
||||
vshlq_u32(v_raw, v_neg_exp_shift), \
|
||||
v_exp_mask); \
|
||||
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \
|
||||
const uint32x4_t ieee = vorrq_u32( \
|
||||
vorrq_u32(vshlq_n_u32(sign, 24), \
|
||||
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \
|
||||
vshlq_u32(mant, v_mant_shift_v)); \
|
||||
const float32x4_t normal = vreinterpretq_f32_u32(ieee); \
|
||||
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \
|
||||
const uint32x4_t sub_bits = vorrq_u32( \
|
||||
vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \
|
||||
const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \
|
||||
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \
|
||||
const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \
|
||||
vst1q_f32(dst, vmulq_f32(val, v_scale)); \
|
||||
} while (0)
|
||||
const float32x4_t val_lo = mxfp8_dequant_neon(v_lo,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
const float32x4_t val_hi = mxfp8_dequant_neon(v_hi,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
|
||||
DEQUANT_FP8_STORE(v_lo, y + ib * QK_MXFP8 + j);
|
||||
DEQUANT_FP8_STORE(v_hi, y + ib * QK_MXFP8 + j + 4);
|
||||
#undef DEQUANT_FP8_STORE
|
||||
vst1q_f32(y + ib * QK_MXFP8 + j, vmulq_f32(val_lo, v_scale));
|
||||
vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline void dequantize_row_mxfp6_neon(
|
||||
static void dequantize_row_mxfp6_neon(
|
||||
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
|
||||
size_t block_size,
|
||||
const uint32_t exp_mask, const uint32_t mant_mask,
|
||||
const int exp_shift, const uint32_t ieee_exp_off,
|
||||
const int mant_shift, const float sub_scale) {
|
||||
const mxfp_neon_traits_t * t) {
|
||||
assert(k % QK_MXFP6 == 0);
|
||||
const int nb = k / QK_MXFP6;
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale);
|
||||
const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift);
|
||||
const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift);
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
|
||||
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
|
||||
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size);
|
||||
const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
|
||||
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32(xb->e));
|
||||
|
||||
for (int j = 0; j < 32; j += 4) {
|
||||
const uint8_t * p = xb->qs + (j * 3 / 4);
|
||||
const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
uint8_t unpacked[4];
|
||||
unpacked[0] = (pk >> 0) & 0x3F;
|
||||
unpacked[1] = (pk >> 6) & 0x3F;
|
||||
unpacked[2] = (pk >> 12) & 0x3F;
|
||||
unpacked[3] = (pk >> 18) & 0x3F;
|
||||
const uint32x4_t v_raw = unpack_fp6x4_neon(xb->qs + (j * 3 / 4));
|
||||
|
||||
const uint8x8_t raw8 = vcreate_u8(
|
||||
(uint64_t)unpacked[0] | ((uint64_t)unpacked[1] << 8) |
|
||||
((uint64_t)unpacked[2] << 16) | ((uint64_t)unpacked[3] << 24));
|
||||
const uint32x4_t v_raw = vmovl_u16(vget_low_u16(vmovl_u8(raw8)));
|
||||
|
||||
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20));
|
||||
const uint32x4_t exp = vandq_u32(
|
||||
vshlq_u32(v_raw, v_neg_exp_shift),
|
||||
v_exp_mask);
|
||||
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask);
|
||||
|
||||
const uint32x4_t ieee = vorrq_u32(
|
||||
vorrq_u32(vshlq_n_u32(sign, 26),
|
||||
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)),
|
||||
vshlq_u32(mant, v_mant_shift_v));
|
||||
const float32x4_t normal = vreinterpretq_f32_u32(ieee);
|
||||
|
||||
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc);
|
||||
const uint32x4_t sub_bits = vorrq_u32(
|
||||
vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26));
|
||||
const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits);
|
||||
|
||||
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0));
|
||||
const float32x4_t val = vbslq_f32(is_sub, sub_val, normal);
|
||||
const float32x4_t val = mxfp6_dequant_neon(v_raw,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
|
||||
vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__ARM_NEON)
|
||||
dequantize_row_mxfp8_neon(x, y, k,
|
||||
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
|
||||
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
|
||||
#else
|
||||
dequantize_row_mxfp8_cpu_generic(x, y, k);
|
||||
#endif
|
||||
// ── MXFP SoA dequant (flash attention) ─────────────────────────────────────
|
||||
|
||||
static void dequantize_row_mxfp8_soa_neon(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
|
||||
const mxfp_neon_traits_t * t) {
|
||||
assert(k % QK_MXFP8 == 0);
|
||||
const int nb = k / QK_MXFP8;
|
||||
const char * qs_base = (const char *)src;
|
||||
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
|
||||
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
|
||||
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
|
||||
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
uint32x4_t v_lo, v_hi;
|
||||
widen_u8x8_to_u32x4x2(qs + j, &v_lo, &v_hi);
|
||||
|
||||
const float32x4_t val_lo = mxfp8_dequant_neon(v_lo,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
const float32x4_t val_hi = mxfp8_dequant_neon(v_hi,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
|
||||
vst1q_f32(y + ib * QK_MXFP8 + j, vmulq_f32(val_lo, v_scale));
|
||||
vst1q_f32(y + ib * QK_MXFP8 + j + 4, vmulq_f32(val_hi, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__ARM_NEON)
|
||||
dequantize_row_mxfp6_neon(x, y, k, sizeof(block_mxfp6),
|
||||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
#else
|
||||
dequantize_row_mxfp6_cpu_generic(x, y, k);
|
||||
#endif
|
||||
static void dequantize_row_mxfp6_soa_neon(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
|
||||
const mxfp_neon_traits_t * t) {
|
||||
assert(k % QK_MXFP6 == 0);
|
||||
const int nb = k / QK_MXFP6;
|
||||
const char * qs_base = (const char *)src;
|
||||
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(t->exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(t->mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(t->ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(t->sub_scale);
|
||||
const int32x4_t v_neg_exp = vdupq_n_s32(-(int)t->exp_shift);
|
||||
const int32x4_t v_mant_sh = vdupq_n_s32(t->mant_shift);
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
|
||||
|
||||
for (int j = 0; j < 32; j += 4) {
|
||||
const uint32x4_t v_raw = unpack_fp6x4_neon(qs + (j * 3 / 4));
|
||||
|
||||
const float32x4_t val = mxfp6_dequant_neon(v_raw,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc, v_neg_exp, v_mant_sh);
|
||||
|
||||
vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- MXFP SoA dequantize_row (to_float) — NEON-optimized ----
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
static inline void dequantize_row_mxfp4_soa_neon(
|
||||
// MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed.
|
||||
static void dequantize_row_mxfp4_soa_neon(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
|
||||
assert(k % QK_MXFP4 == 0);
|
||||
const int nb = k / QK_MXFP4;
|
||||
const char * row = (const char *)src;
|
||||
const char * qs_base = row;
|
||||
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
|
||||
const char * qs_base = (const char *)src;
|
||||
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
|
||||
|
||||
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
|
|
@ -4527,122 +4490,45 @@ static inline void dequantize_row_mxfp4_soa_neon(
|
|||
}
|
||||
}
|
||||
|
||||
static inline void dequantize_row_mxfp8_soa_neon(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
|
||||
const uint32_t exp_mask, const uint32_t mant_mask,
|
||||
const int exp_shift, const uint32_t ieee_exp_off,
|
||||
const int mant_shift, const float sub_scale) {
|
||||
assert(k % QK_MXFP8 == 0);
|
||||
const int nb = k / QK_MXFP8;
|
||||
const char * row = (const char *)src;
|
||||
const char * qs_base = row;
|
||||
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
|
||||
#endif // __ARM_NEON
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale);
|
||||
const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift);
|
||||
const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift);
|
||||
// ── Public dispatch functions ──────────────────────────────────────────────
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
|
||||
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
const uint8x8_t raw8 = vld1_u8(qs + j);
|
||||
const uint16x8_t raw16 = vmovl_u8(raw8);
|
||||
const uint32x4_t v_lo = vmovl_u16(vget_low_u16(raw16));
|
||||
const uint32x4_t v_hi = vmovl_u16(vget_high_u16(raw16));
|
||||
|
||||
#define DEQUANT_FP8_STORE_SOA(v_raw, dst) do { \
|
||||
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x80)); \
|
||||
const uint32x4_t exp = vandq_u32( \
|
||||
vshlq_u32(v_raw, v_neg_exp_shift), \
|
||||
v_exp_mask); \
|
||||
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask); \
|
||||
const uint32x4_t ieee = vorrq_u32( \
|
||||
vorrq_u32(vshlq_n_u32(sign, 24), \
|
||||
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)), \
|
||||
vshlq_u32(mant, v_mant_shift_v)); \
|
||||
const float32x4_t normal = vreinterpretq_f32_u32(ieee); \
|
||||
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc); \
|
||||
const uint32x4_t sub_bits = vorrq_u32( \
|
||||
vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 24)); \
|
||||
const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits); \
|
||||
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0)); \
|
||||
const float32x4_t val = vbslq_f32(is_sub, sub_val, normal); \
|
||||
vst1q_f32(dst, vmulq_f32(val, v_scale)); \
|
||||
} while (0)
|
||||
|
||||
DEQUANT_FP8_STORE_SOA(v_lo, y + ib * QK_MXFP8 + j);
|
||||
DEQUANT_FP8_STORE_SOA(v_hi, y + ib * QK_MXFP8 + j + 4);
|
||||
#undef DEQUANT_FP8_STORE_SOA
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline void dequantize_row_mxfp6_soa_neon(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
|
||||
const uint32_t exp_mask, const uint32_t mant_mask,
|
||||
const int exp_shift, const uint32_t ieee_exp_off,
|
||||
const int mant_shift, const float sub_scale) {
|
||||
assert(k % QK_MXFP6 == 0);
|
||||
const int nb = k / QK_MXFP6;
|
||||
const char * row = (const char *)src;
|
||||
const char * qs_base = row;
|
||||
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
|
||||
|
||||
const uint32x4_t v_exp_mask = vdupq_n_u32(exp_mask);
|
||||
const uint32x4_t v_mant_mask = vdupq_n_u32(mant_mask);
|
||||
const uint32x4_t v_ieee_off = vdupq_n_u32(ieee_exp_off);
|
||||
const float32x4_t v_sub_sc = vdupq_n_f32(sub_scale);
|
||||
const int32x4_t v_neg_exp_shift = vdupq_n_s32(-exp_shift);
|
||||
const int32x4_t v_mant_shift_v = vdupq_n_s32(mant_shift);
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const float32x4_t v_scale = vdupq_n_f32(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
|
||||
|
||||
for (int j = 0; j < 32; j += 4) {
|
||||
const uint8_t * p = qs + (j * 3 / 4);
|
||||
const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
uint8_t unpacked[4];
|
||||
unpacked[0] = (pk >> 0) & 0x3F;
|
||||
unpacked[1] = (pk >> 6) & 0x3F;
|
||||
unpacked[2] = (pk >> 12) & 0x3F;
|
||||
unpacked[3] = (pk >> 18) & 0x3F;
|
||||
|
||||
const uint8x8_t raw8 = vcreate_u8(
|
||||
(uint64_t)unpacked[0] | ((uint64_t)unpacked[1] << 8) |
|
||||
((uint64_t)unpacked[2] << 16) | ((uint64_t)unpacked[3] << 24));
|
||||
const uint32x4_t v_raw = vmovl_u16(vget_low_u16(vmovl_u8(raw8)));
|
||||
|
||||
const uint32x4_t sign = vandq_u32(v_raw, vdupq_n_u32(0x20));
|
||||
const uint32x4_t exp = vandq_u32(
|
||||
vshlq_u32(v_raw, v_neg_exp_shift),
|
||||
v_exp_mask);
|
||||
const uint32x4_t mant = vandq_u32(v_raw, v_mant_mask);
|
||||
|
||||
const uint32x4_t ieee = vorrq_u32(
|
||||
vorrq_u32(vshlq_n_u32(sign, 26),
|
||||
vshlq_n_u32(vaddq_u32(exp, v_ieee_off), 23)),
|
||||
vshlq_u32(mant, v_mant_shift_v));
|
||||
const float32x4_t normal = vreinterpretq_f32_u32(ieee);
|
||||
|
||||
const float32x4_t sub_abs = vmulq_f32(vcvtq_f32_u32(mant), v_sub_sc);
|
||||
const uint32x4_t sub_bits = vorrq_u32(
|
||||
vreinterpretq_u32_f32(sub_abs), vshlq_n_u32(sign, 26));
|
||||
const float32x4_t sub_val = vreinterpretq_f32_u32(sub_bits);
|
||||
|
||||
const uint32x4_t is_sub = vceqq_u32(exp, vdupq_n_u32(0));
|
||||
const float32x4_t val = vbslq_f32(is_sub, sub_val, normal);
|
||||
|
||||
vst1q_f32(y + ib * QK_MXFP6 + j, vmulq_f32(val, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__ARM_NEON)
|
||||
ggml_vec_dot_mxfp8_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E4M3);
|
||||
#else
|
||||
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__ARM_NEON)
|
||||
ggml_vec_dot_mxfp6_q8_0_neon(n, s, vx, vy, &MXFP_TRAITS_E2M3);
|
||||
#else
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__ARM_NEON)
|
||||
dequantize_row_mxfp8_neon(x, y, k, &MXFP_TRAITS_E4M3);
|
||||
#else
|
||||
dequantize_row_mxfp8_cpu_generic(x, y, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__ARM_NEON)
|
||||
dequantize_row_mxfp6_neon(x, y, k, &MXFP_TRAITS_E2M3);
|
||||
#else
|
||||
dequantize_row_mxfp6_cpu_generic(x, y, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__ARM_NEON)
|
||||
|
|
@ -4654,9 +4540,7 @@ void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES
|
|||
|
||||
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__ARM_NEON)
|
||||
dequantize_row_mxfp8_soa_neon(x, y, k,
|
||||
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
|
||||
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
|
||||
dequantize_row_mxfp8_soa_neon(x, y, k, &MXFP_TRAITS_E4M3);
|
||||
#else
|
||||
dequantize_row_mxfp8_soa_cpu_generic(x, y, k);
|
||||
#endif
|
||||
|
|
@ -4664,9 +4548,7 @@ void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES
|
|||
|
||||
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__ARM_NEON)
|
||||
dequantize_row_mxfp6_soa_neon(x, y, k,
|
||||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
dequantize_row_mxfp6_soa_neon(x, y, k, &MXFP_TRAITS_E2M3);
|
||||
#else
|
||||
dequantize_row_mxfp6_soa_cpu_generic(x, y, k);
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -3819,30 +3819,77 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
#endif
|
||||
}
|
||||
|
||||
// AVX2-optimized MXFP8 × Q8_0 dot product.
|
||||
// Dequants FP8 elements to float via IEEE 754 bit construction, then dots against Q8_0.
|
||||
// Parameters encode the FP8 format: exp_mask, mant_mask, exp_shift, ieee_exp_offset, mant_shift, sub_scale.
|
||||
// ── MXFP FP8/FP6 AVX2 helpers ──────────────────────────────────────────────
|
||||
// Shared IEEE-754 bit reconstruction and FP6 unpacking used by vec_dot,
|
||||
// dequantize_row, and SoA dequant functions.
|
||||
|
||||
#if defined(__AVX2__)
|
||||
static inline void ggml_vec_dot_mxfp8_q8_0_avx2(
|
||||
|
||||
// Use shared mxfp_dequant_traits_t from ggml-common.h.
|
||||
// Aliases for readability within this file.
|
||||
#define mxfp_avx2_traits_t mxfp_dequant_traits_t
|
||||
|
||||
// Dequantize 8 raw MXFP values (widened to int32) → 8 IEEE-754 floats.
|
||||
// Handles both normal and subnormal paths. Works for any FP6/FP8 format.
|
||||
static inline __m256 mxfp_dequant_avx2(
|
||||
const __m256i v_raw,
|
||||
const __m256i v_exp_mask, const __m256i v_mant_mask,
|
||||
const __m256i v_ieee_off, const __m256 v_sub_sc,
|
||||
const __m256i v_sign_mask, const __m256i v_zero,
|
||||
int exp_shift, int sign_shift, int mant_shift) {
|
||||
const __m256i sign = _mm256_and_si256(v_raw, v_sign_mask);
|
||||
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
|
||||
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
|
||||
|
||||
const __m256i ieee = _mm256_or_si256(
|
||||
_mm256_or_si256(_mm256_slli_epi32(sign, sign_shift),
|
||||
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
|
||||
_mm256_slli_epi32(mant, mant_shift));
|
||||
const __m256 normal = _mm256_castsi256_ps(ieee);
|
||||
|
||||
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
|
||||
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
|
||||
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, sign_shift)));
|
||||
|
||||
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
|
||||
return _mm256_blendv_ps(normal, sub_val, is_sub);
|
||||
}
|
||||
|
||||
// Unpack 4 tightly-packed 6-bit values from 3 bytes into separate bytes.
|
||||
static inline void unpack_fp6x4_avx2(const uint8_t * p, uint8_t out[4]) {
|
||||
const uint32_t pk = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
out[0] = (pk >> 0) & 0x3F;
|
||||
out[1] = (pk >> 6) & 0x3F;
|
||||
out[2] = (pk >> 12) & 0x3F;
|
||||
out[3] = (pk >> 18) & 0x3F;
|
||||
}
|
||||
|
||||
// Unpack 8 FP6 values (two groups of 4) from packed qs data at offset j.
|
||||
static inline __m256i unpack_fp6x8_avx2(const uint8_t * qs, int j) {
|
||||
uint8_t unpacked[8];
|
||||
unpack_fp6x4_avx2(qs + (j * 3 / 4), unpacked);
|
||||
unpack_fp6x4_avx2(qs + ((j + 4) * 3 / 4), unpacked + 4);
|
||||
return _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)unpacked));
|
||||
}
|
||||
|
||||
// ── MXFP FP8/FP6 vec_dot ──────────────────────────────────────────────────
|
||||
|
||||
// Unified FP8 × Q8_0 dot product (works for E4M3 and E5M2).
|
||||
static void ggml_vec_dot_mxfp8_q8_0_avx2(
|
||||
int n, float * GGML_RESTRICT s,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
// FP8 format parameters:
|
||||
const int exp_mask, // 0xF for E4M3, 0x1F for E5M2
|
||||
const int mant_mask, // 0x7 for E4M3, 0x3 for E5M2
|
||||
const int exp_shift, // 3 for E4M3, 2 for E5M2
|
||||
const int ieee_exp_off, // 120 for E4M3, 112 for E5M2
|
||||
const int mant_shift, // 20 for E4M3, 21 for E5M2
|
||||
const float sub_scale) { // 1/512 for E4M3, 1/65536 for E5M2
|
||||
const mxfp_avx2_traits_t * t) {
|
||||
assert(n % QK_MXFP8 == 0);
|
||||
const int nb = n / QK_MXFP8;
|
||||
const block_mxfp8 * GGML_RESTRICT x = vx;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(sub_scale);
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
|
||||
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
|
||||
const __m256i v_zero = _mm256_setzero_si256();
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
|
@ -3851,141 +3898,55 @@ static inline void ggml_vec_dot_mxfp8_q8_0_avx2(
|
|||
const __m256 v_scale = _mm256_set1_ps(
|
||||
GGML_E8M0_TO_FP32(x[ib].e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
// Process 32 FP8 elements in 4 groups of 8
|
||||
// AVX2 _mm256_cvtepu8_epi32 widens 8 bytes → 8 int32s directly
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
// Load 8 FP8 bytes → 8 int32s
|
||||
const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(x[ib].qs + j));
|
||||
const __m256i v_raw = _mm256_cvtepu8_epi32(raw8);
|
||||
const __m256i v_raw = _mm256_cvtepu8_epi32(
|
||||
_mm_loadl_epi64((const __m128i *)(x[ib].qs + j)));
|
||||
const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
|
||||
_mm_loadl_epi64((const __m128i *)(y[ib].qs + j))));
|
||||
|
||||
// Load 8 Q8_0 int8 values → float
|
||||
const __m128i q8 = _mm_loadl_epi64((const __m128i *)(y[ib].qs + j));
|
||||
const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q8));
|
||||
const __m256 val = mxfp_dequant_avx2(v_raw,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
|
||||
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
|
||||
|
||||
// Extract sign (bit 7), exponent, mantissa
|
||||
const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80));
|
||||
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
|
||||
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
|
||||
|
||||
// Normal path: IEEE bits = (sign << 24) | ((exp + offset) << 23) | (mant << mant_shift)
|
||||
const __m256i ieee = _mm256_or_si256(
|
||||
_mm256_or_si256(_mm256_slli_epi32(sign, 24),
|
||||
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
|
||||
_mm256_slli_epi32(mant, mant_shift));
|
||||
const __m256 normal = _mm256_castsi256_ps(ieee);
|
||||
|
||||
// Subnormal path: |val| = mant * sub_scale, then apply sign
|
||||
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
|
||||
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
|
||||
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24)));
|
||||
|
||||
// Select: subnormal when exp == 0, else normal
|
||||
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
|
||||
const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub);
|
||||
|
||||
// Accumulate: val * scale * q8_float
|
||||
acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc);
|
||||
}
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__AVX2__)
|
||||
// E4M3: sign(1) exp(4) mant(3), bias=7
|
||||
ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy,
|
||||
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
|
||||
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
|
||||
#else
|
||||
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
// AVX2-optimized MXFP6 × Q8_0 dot product.
|
||||
// Unpacks tight 6-bit packing (4 values per 3 bytes), then dequants to float.
|
||||
#if defined(__AVX2__)
|
||||
static inline void ggml_vec_dot_mxfp6_q8_0_avx2(
|
||||
// Unified FP6 × Q8_0 dot product (works for E2M3 and E3M2).
|
||||
static void ggml_vec_dot_mxfp6_q8_0_avx2(
|
||||
int n, float * GGML_RESTRICT s,
|
||||
const void * GGML_RESTRICT vx,
|
||||
const void * GGML_RESTRICT vy,
|
||||
size_t block_size,
|
||||
// FP6 format parameters:
|
||||
const int exp_mask, // 0x3 for E2M3, 0x7 for E3M2
|
||||
const int mant_mask, // 0x7 for E2M3, 0x3 for E3M2
|
||||
const int exp_shift, // 3 for E2M3, 2 for E3M2
|
||||
const int ieee_exp_off, // 126 for E2M3, 124 for E3M2
|
||||
const int mant_shift, // 20 for E2M3, 21 for E3M2
|
||||
const float sub_scale) { // 1/8 for E2M3, 1/16 for E3M2
|
||||
const mxfp_avx2_traits_t * t) {
|
||||
assert(n % QK_MXFP6 == 0);
|
||||
const int nb = n / QK_MXFP6;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(sub_scale);
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
|
||||
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
|
||||
const __m256i v_zero = _mm256_setzero_si256();
|
||||
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size);
|
||||
const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
|
||||
const __m256 v_scale = _mm256_set1_ps(
|
||||
GGML_E8M0_TO_FP32(xb->e) * GGML_CPU_FP16_TO_FP32(y[ib].d));
|
||||
|
||||
// Process 32 FP6 elements in 4 groups of 8 (each group = 2 × 3-byte packs)
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
// Unpack 8 FP6 values from 6 bytes (two groups of 3 bytes → 4 values each)
|
||||
uint8_t unpacked[8];
|
||||
{
|
||||
const uint8_t * p = xb->qs + (j * 3 / 4);
|
||||
const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
unpacked[0] = (pk0 >> 0) & 0x3F;
|
||||
unpacked[1] = (pk0 >> 6) & 0x3F;
|
||||
unpacked[2] = (pk0 >> 12) & 0x3F;
|
||||
unpacked[3] = (pk0 >> 18) & 0x3F;
|
||||
}
|
||||
{
|
||||
const uint8_t * p = xb->qs + ((j + 4) * 3 / 4);
|
||||
const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
unpacked[4] = (pk1 >> 0) & 0x3F;
|
||||
unpacked[5] = (pk1 >> 6) & 0x3F;
|
||||
unpacked[6] = (pk1 >> 12) & 0x3F;
|
||||
unpacked[7] = (pk1 >> 18) & 0x3F;
|
||||
}
|
||||
const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j);
|
||||
const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
|
||||
_mm_loadl_epi64((const __m128i *)(y[ib].qs + j))));
|
||||
|
||||
// Widen 8 bytes → 8 int32s
|
||||
const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked);
|
||||
const __m256i v_raw = _mm256_cvtepu8_epi32(raw8);
|
||||
|
||||
// Load 8 Q8_0 int8 values → float
|
||||
const __m128i q8 = _mm_loadl_epi64((const __m128i *)(y[ib].qs + j));
|
||||
const __m256 qf = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(q8));
|
||||
|
||||
// Extract sign (bit 5 for FP6), exponent, mantissa
|
||||
const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20));
|
||||
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
|
||||
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
|
||||
|
||||
// Normal: IEEE bits = (sign << 26) | ((exp + offset) << 23) | (mant << mant_shift)
|
||||
const __m256i ieee = _mm256_or_si256(
|
||||
_mm256_or_si256(_mm256_slli_epi32(sign, 26),
|
||||
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
|
||||
_mm256_slli_epi32(mant, mant_shift));
|
||||
const __m256 normal = _mm256_castsi256_ps(ieee);
|
||||
|
||||
// Subnormal: |val| = mant * sub_scale, apply sign
|
||||
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
|
||||
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
|
||||
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26)));
|
||||
|
||||
// Select: subnormal when exp == 0
|
||||
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
|
||||
const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub);
|
||||
const __m256 val = mxfp_dequant_avx2(v_raw,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
|
||||
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
|
||||
|
||||
acc = _mm256_fmadd_ps(_mm256_mul_ps(val, v_scale), qf, acc);
|
||||
}
|
||||
|
|
@ -3993,162 +3954,140 @@ static inline void ggml_vec_dot_mxfp6_q8_0_avx2(
|
|||
|
||||
*s = hsum_float_8(acc);
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__AVX2__)
|
||||
// E2M3: sign(1) exp(2) mant(3), bias=1
|
||||
ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, sizeof(block_mxfp6),
|
||||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
#else
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
// ── MXFP FP8/FP6 dequantize_row (AoS) ─────────────────────────────────────
|
||||
|
||||
// ---- MXFP dequantize_row (to_float) — AVX2-optimized ----
|
||||
// Extracts the SIMD dequant logic from vec_dot above, writing floats to output buffer
|
||||
// instead of accumulating a dot product.
|
||||
|
||||
#if defined(__AVX2__)
|
||||
static inline void dequantize_row_mxfp8_avx2(
|
||||
static void dequantize_row_mxfp8_avx2(
|
||||
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
|
||||
const int exp_mask, const int mant_mask, const int exp_shift,
|
||||
const int ieee_exp_off, const int mant_shift, const float sub_scale) {
|
||||
const mxfp_avx2_traits_t * t) {
|
||||
assert(k % QK_MXFP8 == 0);
|
||||
const int nb = k / QK_MXFP8;
|
||||
const block_mxfp8 * GGML_RESTRICT x = vx;
|
||||
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(sub_scale);
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
|
||||
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
|
||||
const __m256i v_zero = _mm256_setzero_si256();
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(x[ib].e));
|
||||
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(x[ib].qs + j));
|
||||
const __m256i v_raw = _mm256_cvtepu8_epi32(raw8);
|
||||
const __m256i v_raw = _mm256_cvtepu8_epi32(
|
||||
_mm_loadl_epi64((const __m128i *)(x[ib].qs + j)));
|
||||
|
||||
const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80));
|
||||
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
|
||||
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
|
||||
|
||||
const __m256i ieee = _mm256_or_si256(
|
||||
_mm256_or_si256(_mm256_slli_epi32(sign, 24),
|
||||
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
|
||||
_mm256_slli_epi32(mant, mant_shift));
|
||||
const __m256 normal = _mm256_castsi256_ps(ieee);
|
||||
|
||||
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
|
||||
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
|
||||
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24)));
|
||||
|
||||
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
|
||||
const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub);
|
||||
const __m256 val = mxfp_dequant_avx2(v_raw,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
|
||||
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
|
||||
|
||||
_mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline void dequantize_row_mxfp6_avx2(
|
||||
static void dequantize_row_mxfp6_avx2(
|
||||
const void * GGML_RESTRICT vx, float * GGML_RESTRICT y, int64_t k,
|
||||
size_t block_size,
|
||||
const int exp_mask, const int mant_mask, const int exp_shift,
|
||||
const int ieee_exp_off, const int mant_shift, const float sub_scale) {
|
||||
const mxfp_avx2_traits_t * t) {
|
||||
assert(k % QK_MXFP6 == 0);
|
||||
const int nb = k / QK_MXFP6;
|
||||
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(sub_scale);
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
|
||||
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
|
||||
const __m256i v_zero = _mm256_setzero_si256();
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const block_mxfp6 * GGML_RESTRICT xb = (const block_mxfp6 *)((const char *)vx + ib * block_size);
|
||||
const block_mxfp6 * GGML_RESTRICT xb = ((const block_mxfp6 *)vx) + ib;
|
||||
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32(xb->e));
|
||||
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
// Unpack 8 FP6 values from 6 bytes (two groups of 3 bytes → 4 values each)
|
||||
uint8_t unpacked[8];
|
||||
{
|
||||
const uint8_t * p = xb->qs + (j * 3 / 4);
|
||||
const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
unpacked[0] = (pk0 >> 0) & 0x3F;
|
||||
unpacked[1] = (pk0 >> 6) & 0x3F;
|
||||
unpacked[2] = (pk0 >> 12) & 0x3F;
|
||||
unpacked[3] = (pk0 >> 18) & 0x3F;
|
||||
}
|
||||
{
|
||||
const uint8_t * p = xb->qs + ((j + 4) * 3 / 4);
|
||||
const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
unpacked[4] = (pk1 >> 0) & 0x3F;
|
||||
unpacked[5] = (pk1 >> 6) & 0x3F;
|
||||
unpacked[6] = (pk1 >> 12) & 0x3F;
|
||||
unpacked[7] = (pk1 >> 18) & 0x3F;
|
||||
}
|
||||
const __m256i v_raw = unpack_fp6x8_avx2(xb->qs, j);
|
||||
|
||||
const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked);
|
||||
const __m256i v_raw = _mm256_cvtepu8_epi32(raw8);
|
||||
|
||||
const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20));
|
||||
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
|
||||
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
|
||||
|
||||
const __m256i ieee = _mm256_or_si256(
|
||||
_mm256_or_si256(_mm256_slli_epi32(sign, 26),
|
||||
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
|
||||
_mm256_slli_epi32(mant, mant_shift));
|
||||
const __m256 normal = _mm256_castsi256_ps(ieee);
|
||||
|
||||
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
|
||||
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
|
||||
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26)));
|
||||
|
||||
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
|
||||
const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub);
|
||||
const __m256 val = mxfp_dequant_avx2(v_raw,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
|
||||
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
|
||||
|
||||
_mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__AVX2__)
|
||||
dequantize_row_mxfp8_avx2(x, y, k,
|
||||
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
|
||||
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
|
||||
#else
|
||||
dequantize_row_mxfp8_cpu_generic(x, y, k);
|
||||
#endif
|
||||
// ── MXFP SoA dequant (flash attention) ─────────────────────────────────────
|
||||
|
||||
static void dequantize_row_mxfp8_soa_avx2(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
|
||||
const mxfp_avx2_traits_t * t) {
|
||||
assert(k % QK_MXFP8 == 0);
|
||||
const int nb = k / QK_MXFP8;
|
||||
const char * qs_base = (const char *)src;
|
||||
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
|
||||
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
|
||||
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
|
||||
const __m256i v_zero = _mm256_setzero_si256();
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
|
||||
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
const __m256i v_raw = _mm256_cvtepu8_epi32(
|
||||
_mm_loadl_epi64((const __m128i *)(qs + j)));
|
||||
|
||||
const __m256 val = mxfp_dequant_avx2(v_raw,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
|
||||
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
|
||||
|
||||
_mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__AVX2__)
|
||||
dequantize_row_mxfp6_avx2(x, y, k, sizeof(block_mxfp6),
|
||||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
#else
|
||||
dequantize_row_mxfp6_cpu_generic(x, y, k);
|
||||
#endif
|
||||
static void dequantize_row_mxfp6_soa_avx2(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
|
||||
const mxfp_avx2_traits_t * t) {
|
||||
assert(k % QK_MXFP6 == 0);
|
||||
const int nb = k / QK_MXFP6;
|
||||
const char * qs_base = (const char *)src;
|
||||
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
|
||||
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(t->exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(t->mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(t->ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(t->sub_scale);
|
||||
const __m256i v_sign_mask = _mm256_set1_epi32(t->sign_mask);
|
||||
const __m256i v_zero = _mm256_setzero_si256();
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
|
||||
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
const __m256i v_raw = unpack_fp6x8_avx2(qs, j);
|
||||
|
||||
const __m256 val = mxfp_dequant_avx2(v_raw,
|
||||
v_exp_mask, v_mant_mask, v_ieee_off, v_sub_sc,
|
||||
v_sign_mask, v_zero, t->exp_shift, t->sign_shift, t->mant_shift);
|
||||
|
||||
_mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SoA dequant for flash attention — contiguous qs region + separate e8m0 region
|
||||
#if defined(__AVX2__)
|
||||
static inline void dequantize_row_mxfp4_soa_avx2(
|
||||
// MXFP4 SoA dequant — LUT-based, no IEEE reconstruction needed.
|
||||
static void dequantize_row_mxfp4_soa_avx2(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k) {
|
||||
assert(k % QK_MXFP4 == 0);
|
||||
const int nb = k / QK_MXFP4;
|
||||
const char * row = (const char *)src;
|
||||
const char * qs_base = row;
|
||||
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
|
||||
const char * qs_base = (const char *)src;
|
||||
const char * e8m0_base = qs_base + MXFP_SOA_E8M0_OFFSET(nb, MXFP4_SOA_QS_PER_BLOCK);
|
||||
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
|
||||
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||
|
|
@ -4163,13 +4102,11 @@ static inline void dequantize_row_mxfp4_soa_avx2(
|
|||
const __m128i lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits, m4b));
|
||||
const __m128i hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4b));
|
||||
|
||||
// lo nibbles → first 16 floats
|
||||
const __m256i lo32_0 = _mm256_cvtepi8_epi32(lo);
|
||||
const __m256i lo32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(lo, 8));
|
||||
_mm256_storeu_ps(y + i * QK_MXFP4 + 0, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_0), v_scale));
|
||||
_mm256_storeu_ps(y + i * QK_MXFP4 + 8, _mm256_mul_ps(_mm256_cvtepi32_ps(lo32_1), v_scale));
|
||||
|
||||
// hi nibbles → second 16 floats
|
||||
const __m256i hi32_0 = _mm256_cvtepi8_epi32(hi);
|
||||
const __m256i hi32_1 = _mm256_cvtepi8_epi32(_mm_srli_si128(hi, 8));
|
||||
_mm256_storeu_ps(y + i * QK_MXFP4 + 16, _mm256_mul_ps(_mm256_cvtepi32_ps(hi32_0), v_scale));
|
||||
|
|
@ -4177,116 +4114,45 @@ static inline void dequantize_row_mxfp4_soa_avx2(
|
|||
}
|
||||
}
|
||||
|
||||
static inline void dequantize_row_mxfp8_soa_avx2(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
|
||||
const int exp_mask, const int mant_mask, const int exp_shift,
|
||||
const int ieee_exp_off, const int mant_shift, const float sub_scale) {
|
||||
assert(k % QK_MXFP8 == 0);
|
||||
const int nb = k / QK_MXFP8;
|
||||
const char * row = (const char *)src;
|
||||
const char * qs_base = row;
|
||||
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP8_SOA_QS_PER_BLOCK);
|
||||
#endif // __AVX2__
|
||||
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(sub_scale);
|
||||
const __m256i v_zero = _mm256_setzero_si256();
|
||||
// ── Public dispatch functions ──────────────────────────────────────────────
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP8_SOA_QS_PER_BLOCK));
|
||||
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
const __m128i raw8 = _mm_loadl_epi64((const __m128i *)(qs + j));
|
||||
const __m256i v_raw = _mm256_cvtepu8_epi32(raw8);
|
||||
|
||||
const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x80));
|
||||
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
|
||||
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
|
||||
|
||||
const __m256i ieee = _mm256_or_si256(
|
||||
_mm256_or_si256(_mm256_slli_epi32(sign, 24),
|
||||
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
|
||||
_mm256_slli_epi32(mant, mant_shift));
|
||||
const __m256 normal = _mm256_castsi256_ps(ieee);
|
||||
|
||||
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
|
||||
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
|
||||
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 24)));
|
||||
|
||||
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
|
||||
const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub);
|
||||
|
||||
_mm256_storeu_ps(y + ib * QK_MXFP8 + j, _mm256_mul_ps(val, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline void dequantize_row_mxfp6_soa_avx2(
|
||||
const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k,
|
||||
const int exp_mask, const int mant_mask, const int exp_shift,
|
||||
const int ieee_exp_off, const int mant_shift, const float sub_scale) {
|
||||
assert(k % QK_MXFP6 == 0);
|
||||
const int nb = k / QK_MXFP6;
|
||||
const char * row = (const char *)src;
|
||||
const char * qs_base = row;
|
||||
const char * e8m0_base = row + MXFP_SOA_E8M0_OFFSET(nb, MXFP6_SOA_QS_PER_BLOCK);
|
||||
|
||||
const __m256i v_exp_mask = _mm256_set1_epi32(exp_mask);
|
||||
const __m256i v_mant_mask = _mm256_set1_epi32(mant_mask);
|
||||
const __m256i v_ieee_off = _mm256_set1_epi32(ieee_exp_off);
|
||||
const __m256 v_sub_sc = _mm256_set1_ps(sub_scale);
|
||||
const __m256i v_zero = _mm256_setzero_si256();
|
||||
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
const __m256 v_scale = _mm256_set1_ps(GGML_E8M0_TO_FP32((uint8_t)e8m0_base[ib]));
|
||||
const uint8_t * qs = (const uint8_t *)(qs_base + MXFP_SOA_QS_OFFSET(ib, MXFP6_SOA_QS_PER_BLOCK));
|
||||
|
||||
for (int j = 0; j < 32; j += 8) {
|
||||
uint8_t unpacked[8];
|
||||
{
|
||||
const uint8_t * p = qs + (j * 3 / 4);
|
||||
const uint32_t pk0 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
unpacked[0] = (pk0 >> 0) & 0x3F;
|
||||
unpacked[1] = (pk0 >> 6) & 0x3F;
|
||||
unpacked[2] = (pk0 >> 12) & 0x3F;
|
||||
unpacked[3] = (pk0 >> 18) & 0x3F;
|
||||
}
|
||||
{
|
||||
const uint8_t * p = qs + ((j + 4) * 3 / 4);
|
||||
const uint32_t pk1 = (uint32_t)p[0] | ((uint32_t)p[1] << 8) | ((uint32_t)p[2] << 16);
|
||||
unpacked[4] = (pk1 >> 0) & 0x3F;
|
||||
unpacked[5] = (pk1 >> 6) & 0x3F;
|
||||
unpacked[6] = (pk1 >> 12) & 0x3F;
|
||||
unpacked[7] = (pk1 >> 18) & 0x3F;
|
||||
}
|
||||
|
||||
const __m128i raw8 = _mm_loadl_epi64((const __m128i *)unpacked);
|
||||
const __m256i v_raw = _mm256_cvtepu8_epi32(raw8);
|
||||
|
||||
const __m256i sign = _mm256_and_si256(v_raw, _mm256_set1_epi32(0x20));
|
||||
const __m256i exp = _mm256_and_si256(_mm256_srli_epi32(v_raw, exp_shift), v_exp_mask);
|
||||
const __m256i mant = _mm256_and_si256(v_raw, v_mant_mask);
|
||||
|
||||
const __m256i ieee = _mm256_or_si256(
|
||||
_mm256_or_si256(_mm256_slli_epi32(sign, 26),
|
||||
_mm256_slli_epi32(_mm256_add_epi32(exp, v_ieee_off), 23)),
|
||||
_mm256_slli_epi32(mant, mant_shift));
|
||||
const __m256 normal = _mm256_castsi256_ps(ieee);
|
||||
|
||||
const __m256 sub_abs = _mm256_mul_ps(_mm256_cvtepi32_ps(mant), v_sub_sc);
|
||||
const __m256 sub_val = _mm256_castsi256_ps(_mm256_or_si256(
|
||||
_mm256_castps_si256(sub_abs), _mm256_slli_epi32(sign, 26)));
|
||||
|
||||
const __m256 is_sub = _mm256_castsi256_ps(_mm256_cmpeq_epi32(exp, v_zero));
|
||||
const __m256 val = _mm256_blendv_ps(normal, sub_val, is_sub);
|
||||
|
||||
_mm256_storeu_ps(y + ib * QK_MXFP6 + j, _mm256_mul_ps(val, v_scale));
|
||||
}
|
||||
}
|
||||
}
|
||||
void ggml_vec_dot_mxfp8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__AVX2__)
|
||||
ggml_vec_dot_mxfp8_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E4M3);
|
||||
#else
|
||||
ggml_vec_dot_mxfp8_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_mxfp6_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc); UNUSED(bs); UNUSED(bx); UNUSED(by);
|
||||
#if defined(__AVX2__)
|
||||
ggml_vec_dot_mxfp6_q8_0_avx2(n, s, vx, vy, &MXFP_TRAITS_E2M3);
|
||||
#else
|
||||
ggml_vec_dot_mxfp6_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp8_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__AVX2__)
|
||||
dequantize_row_mxfp8_avx2(x, y, k, &MXFP_TRAITS_E4M3);
|
||||
#else
|
||||
dequantize_row_mxfp8_cpu_generic(x, y, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp6_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__AVX2__)
|
||||
dequantize_row_mxfp6_avx2(x, y, k, &MXFP_TRAITS_E2M3);
|
||||
#else
|
||||
dequantize_row_mxfp6_cpu_generic(x, y, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__AVX2__)
|
||||
|
|
@ -4298,9 +4164,7 @@ void dequantize_row_mxfp4_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES
|
|||
|
||||
void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__AVX2__)
|
||||
dequantize_row_mxfp8_soa_avx2(x, y, k,
|
||||
MXFP8_E4M3_EXP_MASK, MXFP8_E4M3_MANT_MASK, MXFP8_E4M3_EXP_SHIFT,
|
||||
MXFP8_E4M3_IEEE_EXP_OFF, MXFP8_E4M3_MANT_SHIFT, MXFP8_E4M3_SUB_SCALE);
|
||||
dequantize_row_mxfp8_soa_avx2(x, y, k, &MXFP_TRAITS_E4M3);
|
||||
#else
|
||||
dequantize_row_mxfp8_soa_cpu_generic(x, y, k);
|
||||
#endif
|
||||
|
|
@ -4308,9 +4172,7 @@ void dequantize_row_mxfp8_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RES
|
|||
|
||||
void dequantize_row_mxfp6_soa_cpu(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
|
||||
#if defined(__AVX2__)
|
||||
dequantize_row_mxfp6_soa_avx2(x, y, k,
|
||||
MXFP6_E2M3_EXP_MASK, MXFP6_E2M3_MANT_MASK, MXFP6_E2M3_EXP_SHIFT,
|
||||
MXFP6_E2M3_IEEE_EXP_OFF, MXFP6_E2M3_MANT_SHIFT, MXFP6_E2M3_SUB_SCALE);
|
||||
dequantize_row_mxfp6_soa_avx2(x, y, k, &MXFP_TRAITS_E2M3);
|
||||
#else
|
||||
dequantize_row_mxfp6_soa_cpu_generic(x, y, k);
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -8297,8 +8297,8 @@ static mxfp_fa_params mxfp_fa_params_init(
|
|||
|
||||
if (is_mxfp_k) {
|
||||
switch (k->type) {
|
||||
case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP4_E2M1: p.q_quantize = quantize_row_mxfp4_soa; p.k_dequantize = dequantize_row_mxfp4_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP8_E4M3: p.q_quantize = quantize_row_mxfp8_soa; p.k_dequantize = dequantize_row_mxfp8_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP6_E2M3: p.q_quantize = quantize_row_mxfp6_soa; p.k_dequantize = dequantize_row_mxfp6_soa_cpu; break;
|
||||
default: GGML_ABORT("unsupported MXFP K type");
|
||||
}
|
||||
|
|
@ -8306,8 +8306,8 @@ static mxfp_fa_params mxfp_fa_params_init(
|
|||
|
||||
if (is_mxfp_v) {
|
||||
switch (v->type) {
|
||||
case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP4_E2M1: p.v_dequantize = dequantize_row_mxfp4_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP8_E4M3: p.v_dequantize = dequantize_row_mxfp8_soa_cpu; break;
|
||||
case GGML_TYPE_MXFP6_E2M3: p.v_dequantize = dequantize_row_mxfp6_soa_cpu; break;
|
||||
default: GGML_ABORT("unsupported MXFP V type");
|
||||
}
|
||||
|
|
@ -8328,6 +8328,7 @@ static mxfp_fa_params mxfp_fa_params_init(
|
|||
|
||||
// Per-head SoA addressing for multihead mode.
|
||||
// Precompute byte offsets so the hot loop can skip per-head pointer math.
|
||||
// qs_per_block values from centralized MXFP_QS_PER_BLOCK_* defines in ggml-common.h.
|
||||
auto mxfp_qs_per_block = [](ggml_type type) -> int {
|
||||
switch (type) {
|
||||
case GGML_TYPE_MXFP4_E2M1: return MXFP4_SOA_QS_PER_BLOCK;
|
||||
|
|
@ -8341,7 +8342,6 @@ static mxfp_fa_params mxfp_fa_params_init(
|
|||
p.k_qs_per_block = mxfp_qs_per_block(k->type);
|
||||
p.k_blocks_per_head = (int)(DK / 32);
|
||||
p.k_head_qs_bytes = p.k_blocks_per_head * p.k_qs_per_block;
|
||||
// e8m0 offset from row start = total_blocks * qs_per_block
|
||||
const int64_t k_total_blocks = p.k_multihead ? nek2 * p.k_blocks_per_head : p.k_blocks_per_head;
|
||||
p.k_head_e8m0_offset = k_total_blocks * p.k_qs_per_block;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ GGML_API void quantize_row_mxfp4_soa (const float * GGML_RESTRICT x, void * G
|
|||
GGML_API void dequantize_row_mxfp4_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_mxfp8_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
|
||||
GGML_API void dequantize_row_mxfp8_soa (const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
|
||||
GGML_API void quantize_row_mxfp6_soa(const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
|
||||
GGML_API void quantize_row_mxfp6_soa (const float * GGML_RESTRICT x, void * GGML_RESTRICT dst, int64_t k);
|
||||
GGML_API void dequantize_row_mxfp6_soa(const void * GGML_RESTRICT src, float * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
|
|
|
|||
Loading…
Reference in New Issue