Merge 18ad28ca0a into 2634ed207a
This commit is contained in:
commit
5d2b111463
|
|
@ -606,7 +606,111 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
int ib = 0;
|
||||
float sumf = 0;
|
||||
|
||||
#if defined __ARM_NEON
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||
|
||||
const int vector_length = ggml_cpu_get_sve_cnt() * 8;
|
||||
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
||||
// load LUT
|
||||
svint8_t lut = svld1_s8(ph16, kvalues_mxfp4);
|
||||
|
||||
switch (vector_length) {
|
||||
case 128:
|
||||
{
|
||||
const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
|
||||
|
||||
for (; ib + 1 < nb; ib += 2) {
|
||||
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||
|
||||
// load x
|
||||
const svuint8_t qx0r = svld1rq_u8(ph16, x0->qs);
|
||||
const svuint8_t qx1r = svld1rq_u8(ph16, x1->qs);
|
||||
|
||||
// extract nibble
|
||||
const svuint8_t idx0l = svand_n_u8_m(ph16, qx0r, 0x0F);
|
||||
const svuint8_t idx0h = svlsr_n_u8_m(ph16, qx0r, 0x04);
|
||||
const svuint8_t idx1l = svand_n_u8_m(ph16, qx1r, 0x0F);
|
||||
const svuint8_t idx1h = svlsr_n_u8_m(ph16, qx1r, 0x04);
|
||||
|
||||
// 4-bit -> 8-bit
|
||||
const svint8_t qx0l = svtbl_s8(lut, idx0l);
|
||||
const svint8_t qx0h = svtbl_s8(lut, idx0h);
|
||||
const svint8_t qx1l = svtbl_s8(lut, idx1l);
|
||||
const svint8_t qx1h = svtbl_s8(lut, idx1h);
|
||||
|
||||
// load y
|
||||
const svint8_t qy0h = svld1_s8(ph16, (const int8_t *) (y0->qs));
|
||||
const svint8_t qy0l = svld1_s8(ph16, (const int8_t *) (y0->qs + 16));
|
||||
const svint8_t qy1h = svld1_s8(ph16, (const int8_t *) (y1->qs));
|
||||
const svint8_t qy1l = svld1_s8(ph16, (const int8_t *) (y1->qs + 16));
|
||||
|
||||
// dot product
|
||||
const svint32_t dot0 = svdot_s32(
|
||||
svdot_s32(svdup_n_s32(0), qx0l, qy0h),
|
||||
qx0h, qy0l
|
||||
);
|
||||
const svint32_t dot1 = svdot_s32(
|
||||
svdot_s32(svdup_n_s32(0), qx1l, qy1h),
|
||||
qx1h, qy1l
|
||||
);
|
||||
|
||||
sumv0 = svmla_n_f32_x(ph4, sumv0,
|
||||
svcvt_f32_s32_x(ph4, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e));
|
||||
sumv1 = svmla_n_f32_x(ph4, sumv1,
|
||||
svcvt_f32_s32_x(ph4, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e));
|
||||
}
|
||||
sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||
} break;
|
||||
case 256:
|
||||
case 512:
|
||||
{
|
||||
const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
|
||||
const svbool_t pl16 = svnot_b_z(ph32, ph16);
|
||||
|
||||
for (; ib + 1 < nb; ib += 2) {
|
||||
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||
|
||||
// load x
|
||||
const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
|
||||
const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
|
||||
|
||||
// extract nibble
|
||||
const svuint8_t idx0 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04);
|
||||
const svuint8_t idx1 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04);
|
||||
|
||||
// 4-bit -> 8-bit
|
||||
const svint8_t qx0 = svtbl_s8(lut, idx0);
|
||||
const svint8_t qx1 = svtbl_s8(lut, idx1);
|
||||
|
||||
// load y
|
||||
const svint8_t qy0 = svld1_s8(ph32, y0->qs);
|
||||
const svint8_t qy1 = svld1_s8(ph32, y1->qs);
|
||||
|
||||
// dot product
|
||||
const svint32_t dot0 = svdot_s32(svdup_n_s32(0), qx0, qy0);
|
||||
const svint32_t dot1 = svdot_s32(svdup_n_s32(0), qx1, qy1);
|
||||
|
||||
sumv0 = svmla_n_f32_x(ph32, sumv0,
|
||||
svcvt_f32_s32_x(ph32, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e));
|
||||
sumv1 = svmla_n_f32_x(ph32, sumv1,
|
||||
svcvt_f32_s32_x(ph32, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e));
|
||||
}
|
||||
sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
|
||||
} break;
|
||||
|
||||
default:
|
||||
assert(false && "Unsupported vector length");
|
||||
break;
|
||||
}
|
||||
|
||||
#elif defined (__ARM_NEON)
|
||||
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
uint8x16x2_t q4bits;
|
||||
|
|
|
|||
Loading…
Reference in New Issue