Merge 18ad28ca0a into 2634ed207a
This commit is contained in:
commit
5d2b111463
|
|
@ -606,7 +606,111 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||||
int ib = 0;
|
int ib = 0;
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
|
||||||
#if defined __ARM_NEON
|
#if defined(__ARM_FEATURE_SVE)
|
||||||
|
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||||
|
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||||
|
|
||||||
|
const int vector_length = ggml_cpu_get_sve_cnt() * 8;
|
||||||
|
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
||||||
|
// load LUT
|
||||||
|
svint8_t lut = svld1_s8(ph16, kvalues_mxfp4);
|
||||||
|
|
||||||
|
switch (vector_length) {
|
||||||
|
case 128:
|
||||||
|
{
|
||||||
|
const svbool_t ph4 = svptrue_pat_b32(SV_VL4);
|
||||||
|
|
||||||
|
for (; ib + 1 < nb; ib += 2) {
|
||||||
|
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||||
|
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||||
|
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||||
|
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||||
|
|
||||||
|
// load x
|
||||||
|
const svuint8_t qx0r = svld1rq_u8(ph16, x0->qs);
|
||||||
|
const svuint8_t qx1r = svld1rq_u8(ph16, x1->qs);
|
||||||
|
|
||||||
|
// extract nibble
|
||||||
|
const svuint8_t idx0l = svand_n_u8_m(ph16, qx0r, 0x0F);
|
||||||
|
const svuint8_t idx0h = svlsr_n_u8_m(ph16, qx0r, 0x04);
|
||||||
|
const svuint8_t idx1l = svand_n_u8_m(ph16, qx1r, 0x0F);
|
||||||
|
const svuint8_t idx1h = svlsr_n_u8_m(ph16, qx1r, 0x04);
|
||||||
|
|
||||||
|
// 4-bit -> 8-bit
|
||||||
|
const svint8_t qx0l = svtbl_s8(lut, idx0l);
|
||||||
|
const svint8_t qx0h = svtbl_s8(lut, idx0h);
|
||||||
|
const svint8_t qx1l = svtbl_s8(lut, idx1l);
|
||||||
|
const svint8_t qx1h = svtbl_s8(lut, idx1h);
|
||||||
|
|
||||||
|
// load y
|
||||||
|
const svint8_t qy0h = svld1_s8(ph16, (const int8_t *) (y0->qs));
|
||||||
|
const svint8_t qy0l = svld1_s8(ph16, (const int8_t *) (y0->qs + 16));
|
||||||
|
const svint8_t qy1h = svld1_s8(ph16, (const int8_t *) (y1->qs));
|
||||||
|
const svint8_t qy1l = svld1_s8(ph16, (const int8_t *) (y1->qs + 16));
|
||||||
|
|
||||||
|
// dot product
|
||||||
|
const svint32_t dot0 = svdot_s32(
|
||||||
|
svdot_s32(svdup_n_s32(0), qx0l, qy0h),
|
||||||
|
qx0h, qy0l
|
||||||
|
);
|
||||||
|
const svint32_t dot1 = svdot_s32(
|
||||||
|
svdot_s32(svdup_n_s32(0), qx1l, qy1h),
|
||||||
|
qx1h, qy1l
|
||||||
|
);
|
||||||
|
|
||||||
|
sumv0 = svmla_n_f32_x(ph4, sumv0,
|
||||||
|
svcvt_f32_s32_x(ph4, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e));
|
||||||
|
sumv1 = svmla_n_f32_x(ph4, sumv1,
|
||||||
|
svcvt_f32_s32_x(ph4, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e));
|
||||||
|
}
|
||||||
|
sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||||
|
} break;
|
||||||
|
case 256:
|
||||||
|
case 512:
|
||||||
|
{
|
||||||
|
const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
|
||||||
|
const svbool_t pl16 = svnot_b_z(ph32, ph16);
|
||||||
|
|
||||||
|
for (; ib + 1 < nb; ib += 2) {
|
||||||
|
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
|
||||||
|
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
|
||||||
|
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
|
||||||
|
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];
|
||||||
|
|
||||||
|
// load x
|
||||||
|
const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
|
||||||
|
const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);
|
||||||
|
|
||||||
|
// extract nibble
|
||||||
|
const svuint8_t idx0 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04);
|
||||||
|
const svuint8_t idx1 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04);
|
||||||
|
|
||||||
|
// 4-bit -> 8-bit
|
||||||
|
const svint8_t qx0 = svtbl_s8(lut, idx0);
|
||||||
|
const svint8_t qx1 = svtbl_s8(lut, idx1);
|
||||||
|
|
||||||
|
// load y
|
||||||
|
const svint8_t qy0 = svld1_s8(ph32, y0->qs);
|
||||||
|
const svint8_t qy1 = svld1_s8(ph32, y1->qs);
|
||||||
|
|
||||||
|
// dot product
|
||||||
|
const svint32_t dot0 = svdot_s32(svdup_n_s32(0), qx0, qy0);
|
||||||
|
const svint32_t dot1 = svdot_s32(svdup_n_s32(0), qx1, qy1);
|
||||||
|
|
||||||
|
sumv0 = svmla_n_f32_x(ph32, sumv0,
|
||||||
|
svcvt_f32_s32_x(ph32, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e));
|
||||||
|
sumv1 = svmla_n_f32_x(ph32, sumv1,
|
||||||
|
svcvt_f32_s32_x(ph32, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e));
|
||||||
|
}
|
||||||
|
sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
|
||||||
|
} break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
assert(false && "Unsupported vector length");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
#elif defined (__ARM_NEON)
|
||||||
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
|
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
|
||||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||||
uint8x16x2_t q4bits;
|
uint8x16x2_t q4bits;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue