Optimize ggml_vec_dot_mxfp4_q8_0 dot product on ARM SVE

jiang.suhang 2026-01-29 15:40:06 +09:00
parent b33df266d0
commit 18ad28ca0a
1 changed file with 105 additions and 1 deletion
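For reference, the per-block computation that the SVE kernel below vectorizes can be sketched in plain C. This is a minimal sketch, assuming the ggml block layouts (block_mxfp4: an E8M0 exponent plus 16 packed nibbles whose low halves are elements 0..15 and high halves elements 16..31; block_q8_0: an fp16 scale plus 32 int8 values); the helper name and its parameters are illustrative only, with lut standing for the 16-entry kvalues_mxfp4 table and d for the already-converted combined per-block scale.

#include <stdint.h>

// Reference-only sketch of one block pair of the mxfp4 x q8_0 dot product.
// x_qs: 16 packed nibbles (low nibble = element j, high nibble = element j + 16)
// y_qs: 32 int8 quants; lut: the 16-entry kvalues_mxfp4 table;
// d: combined scale, e.g. GGML_CPU_FP16_TO_FP32(y->d) * GGML_E8M0_TO_FP32_HALF(x->e)
static float mxfp4_q8_0_block_dot(const uint8_t x_qs[16], const int8_t y_qs[32],
                                  const int8_t lut[16], float d) {
    int32_t isum = 0;
    for (int j = 0; j < 16; ++j) {
        isum += (int32_t) lut[x_qs[j] & 0x0F] * y_qs[j];       // low nibbles pair with y[0..15]
        isum += (int32_t) lut[x_qs[j] >> 4]   * y_qs[j + 16];  // high nibbles pair with y[16..31]
    }
    return d * (float) isum;
}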


@@ -606,7 +606,111 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
    int ib = 0;
    float sumf = 0;
#if defined __ARM_NEON
#if defined(__ARM_FEATURE_SVE)
    svfloat32_t sumv0 = svdup_n_f32(0.0f);
    svfloat32_t sumv1 = svdup_n_f32(0.0f);

    const int vector_length = ggml_cpu_get_sve_cnt() * 8;
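    // SVE register width in bits; ggml_cpu_get_sve_cnt() reports it in bytes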
    const svbool_t ph16 = svptrue_pat_b8(SV_VL16);

    // load LUT
    svint8_t lut = svld1_s8(ph16, kvalues_mxfp4);
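    // the 16-entry kvalues_mxfp4 table is loaded once; svtbl_s8 below maps each
    // 4-bit index to its int8 value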
    switch (vector_length) {
        case 128:
            {
                const svbool_t ph4 = svptrue_pat_b32(SV_VL4);

                for (; ib + 1 < nb; ib += 2) {
                    const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(ph16, x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(ph16, x1->qs);
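                    // in block_mxfp4 the low nibble of byte j is element j and the
                    // high nibble is element j + 16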
                    // extract nibble
                    const svuint8_t idx0l = svand_n_u8_m(ph16, qx0r, 0x0F);
                    const svuint8_t idx0h = svlsr_n_u8_m(ph16, qx0r, 0x04);
                    const svuint8_t idx1l = svand_n_u8_m(ph16, qx1r, 0x0F);
                    const svuint8_t idx1h = svlsr_n_u8_m(ph16, qx1r, 0x04);

                    // 4-bit -> 8-bit
                    const svint8_t qx0l = svtbl_s8(lut, idx0l);
                    const svint8_t qx0h = svtbl_s8(lut, idx0h);
                    const svint8_t qx1l = svtbl_s8(lut, idx1l);
                    const svint8_t qx1h = svtbl_s8(lut, idx1h);

                    // load y
                    const svint8_t qy0h = svld1_s8(ph16, (const int8_t *) (y0->qs));
                    const svint8_t qy0l = svld1_s8(ph16, (const int8_t *) (y0->qs + 16));
                    const svint8_t qy1h = svld1_s8(ph16, (const int8_t *) (y1->qs));
                    const svint8_t qy1l = svld1_s8(ph16, (const int8_t *) (y1->qs + 16));
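                    // qy*h hold y elements 0..15 (paired with the low-nibble values),
                    // qy*l hold y elements 16..31 (paired with the high-nibble values)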
                    // dot product
                    const svint32_t dot0 = svdot_s32(
                        svdot_s32(svdup_n_s32(0), qx0l, qy0h),
                        qx0h, qy0l
                    );
                    const svint32_t dot1 = svdot_s32(
                        svdot_s32(svdup_n_s32(0), qx1l, qy1h),
                        qx1h, qy1l
                    );
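                    // each svdot_s32 sums groups of four int8 products into one int32 lane,
                    // so the four lanes of dot0/dot1 together hold one block's integer dot product;
                    // the per-block scale below is the q8_0 fp16 scale times the mxfp4 E8M0 scale,
                    // halved via GGML_E8M0_TO_FP32_HALF because kvalues_mxfp4 stores doubled FP4 values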
                    sumv0 = svmla_n_f32_x(ph4, sumv0,
                        svcvt_f32_s32_x(ph4, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e));
                    sumv1 = svmla_n_f32_x(ph4, sumv1,
                        svcvt_f32_s32_x(ph4, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e));
                }

                sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
            } break;
        case 256:
        case 512:
            {
                const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
                const svbool_t pl16 = svnot_b_z(ph32, ph16);
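                // ph16 covers bytes 0..15 and pl16 bytes 16..31 of the 32-byte working set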

                for (; ib + 1 < nb; ib += 2) {
                    const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
                    const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
                    const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
                    const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

                    // load x
                    const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
                    const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);

                    // extract nibble
                    const svuint8_t idx0 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04);
                    const svuint8_t idx1 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04);
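                    // svld1rq replicated the 16 packed bytes, so the merged AND keeps the low
                    // nibbles in lanes 0..15 while the merged shift puts the high nibbles in
                    // lanes 16..31, matching the q8_0 element order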
                    // 4-bit -> 8-bit
                    const svint8_t qx0 = svtbl_s8(lut, idx0);
                    const svint8_t qx1 = svtbl_s8(lut, idx1);

                    // load y
                    const svint8_t qy0 = svld1_s8(ph32, y0->qs);
                    const svint8_t qy1 = svld1_s8(ph32, y1->qs);

                    // dot product
                    const svint32_t dot0 = svdot_s32(svdup_n_s32(0), qx0, qy0);
                    const svint32_t dot1 = svdot_s32(svdup_n_s32(0), qx1, qy1);
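                    // one svdot_s32 per block covers all 32 elements; any lanes past 31
                    // (512-bit case) contribute zero because the predicated y load zeroes them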
                    sumv0 = svmla_n_f32_x(ph32, sumv0,
                        svcvt_f32_s32_x(ph32, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e));
                    sumv1 = svmla_n_f32_x(ph32, sumv1,
                        svcvt_f32_s32_x(ph32, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e));
                }

                sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
            } break;
        default:
            assert(false && "Unsupported vector length");
            break;
    }
#elif defined (__ARM_NEON)
    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
    const uint8x16_t m4b = vdupq_n_u8(0x0f);
    uint8x16x2_t q4bits;