diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index b390ab61c7..8ee5deff60 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -606,7 +606,111 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo int ib = 0; float sumf = 0; -#if defined __ARM_NEON +#if defined(__ARM_FEATURE_SVE) + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + + const int vector_length = ggml_cpu_get_sve_cnt() * 8; + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // load LUT + svint8_t lut = svld1_s8(ph16, kvalues_mxfp4); + + switch (vector_length) { + case 128: + { + const svbool_t ph4 = svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { + const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(ph16, x0->qs); + const svuint8_t qx1r = svld1rq_u8(ph16, x1->qs); + + // extract nibble + const svuint8_t idx0l = svand_n_u8_m(ph16, qx0r, 0x0F); + const svuint8_t idx0h = svlsr_n_u8_m(ph16, qx0r, 0x04); + const svuint8_t idx1l = svand_n_u8_m(ph16, qx1r, 0x0F); + const svuint8_t idx1h = svlsr_n_u8_m(ph16, qx1r, 0x04); + + // 4-bit -> 8-bit + const svint8_t qx0l = svtbl_s8(lut, idx0l); + const svint8_t qx0h = svtbl_s8(lut, idx0h); + const svint8_t qx1l = svtbl_s8(lut, idx1l); + const svint8_t qx1h = svtbl_s8(lut, idx1h); + + // load y + const svint8_t qy0h = svld1_s8(ph16, (const int8_t *) (y0->qs)); + const svint8_t qy0l = svld1_s8(ph16, (const int8_t *) (y0->qs + 16)); + const svint8_t qy1h = svld1_s8(ph16, (const int8_t *) (y1->qs)); + const svint8_t qy1l = svld1_s8(ph16, (const int8_t *) (y1->qs + 16)); + + // dot product + const svint32_t dot0 = svdot_s32( + svdot_s32(svdup_n_s32(0), qx0l, qy0h), + qx0h, qy0l + ); + const svint32_t dot1 = svdot_s32( + svdot_s32(svdup_n_s32(0), qx1l, qy1h), + qx1h, qy1l + ); + + sumv0 = svmla_n_f32_x(ph4, sumv0, + svcvt_f32_s32_x(ph4, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e)); + sumv1 = svmla_n_f32_x(ph4, sumv1, + svcvt_f32_s32_x(ph4, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e)); + } + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 256: + case 512: + { + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + const svbool_t pl16 = svnot_b_z(ph32, ph16); + + for (; ib + 1 < nb; ib += 2) { + const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs); + const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs); + + // extract nibble + const svuint8_t idx0 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04); + const svuint8_t idx1 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04); + + // 4-bit -> 8-bit + const svint8_t qx0 = svtbl_s8(lut, idx0); + const svint8_t qx1 = svtbl_s8(lut, idx1); + + // load y + const svint8_t qy0 = svld1_s8(ph32, y0->qs); + const svint8_t qy1 = svld1_s8(ph32, y1->qs); + + // dot product + const svint32_t dot0 = svdot_s32(svdup_n_s32(0), qx0, qy0); + const svint32_t dot1 = svdot_s32(svdup_n_s32(0), qx1, qy1); + + sumv0 = svmla_n_f32_x(ph32, sumv0, + svcvt_f32_s32_x(ph32, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e)); + sumv1 = svmla_n_f32_x(ph32, sumv1, + svcvt_f32_s32_x(ph32, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e)); + } + sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1)); + } break; + + default: + assert(false && "Unsupported vector length"); + break; + } + +#elif defined (__ARM_NEON) const int8x16_t values = vld1q_s8(kvalues_mxfp4); const uint8x16_t m4b = vdupq_n_u8(0x0f); uint8x16x2_t q4bits;