diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp
index 7b6f2c6fee..f8e24830ab 100644
--- a/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -784,6 +784,133 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q8_0_4x4_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t acc = vdupq_n_f32(0);
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret = vdupq_n_s32(0);
+
+            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
+            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
+            ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
+            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
+
+            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
+            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
+            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
+            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_4x8_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t acc = vdupq_n_f32(0);
+
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
+            int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
+            int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
+            int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
+            int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret0 = vdupq_n_s32(0);
+            int32x4_t ret1 = vdupq_n_s32(0);
+
+            // 0..7
+            ret0 = vdotq_s32(ret0, b_low.val[0], a0);
+            ret1 = vdotq_s32(ret1, b_low.val[1], a0);
+            // 8..15
+            ret0 = vdotq_s32(ret0, b_low.val[2], a1);
+            ret1 = vdotq_s32(ret1, b_low.val[3], a1);
+            // 16..23
+            ret0 = vdotq_s32(ret0, b_high.val[0], a2);
+            ret1 = vdotq_s32(ret1, b_high.val[1], a2);
+            // 24..31
+            ret0 = vdotq_s32(ret0, b_high.val[2], a3);
+            ret1 = vdotq_s32(ret1, b_high.val[3], a3);
+
+            int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -2609,132 +2736,6 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
     ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
-void ggml_gemv_q8_0_4x4_q8_0(int n,
-                             float * GGML_RESTRICT s,
-                             size_t bs,
-                             const void * GGML_RESTRICT vx,
-                             const void * GGML_RESTRICT vy,
-                             int nr,
-                             int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 4;
-
-    assert(n % qk == 0);
-    assert(nc % ncols_interleaved == 0);
-
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
-
-    for (int c = 0; c < nc; c += ncols_interleaved) {
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        float32x4_t acc = vdupq_n_f32(0);
-        for (int b = 0; b < nb; b++) {
-            int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
-            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
-            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
-
-            int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
-            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
-
-            int32x4_t ret = vdupq_n_s32(0);
-
-            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
-            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
-            ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
-            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
-
-            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
-            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
-            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
-            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
-
-            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
-            a_ptr++;
-            b_ptr++;
-        }
-        vst1q_f32(s, acc);
-        s += ncols_interleaved;
-    }
-    return;
-
-#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
-}
-
-void ggml_gemv_q8_0_4x8_q8_0(int n,
-                             float * GGML_RESTRICT s,
-                             size_t bs,
-                             const void * GGML_RESTRICT vx,
-                             const void * GGML_RESTRICT vy,
-                             int nr,
-                             int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 8;
-
-    assert(n % qk == 0);
-    assert(nc % ncols_interleaved == 0);
-
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
-
-    for (int c = 0; c < nc; c += ncols_interleaved) {
-        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-        float32x4_t acc = vdupq_n_f32(0);
-
-        for (int b = 0; b < nb; b++) {
-            int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
-            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
-            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
-
-            int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
-            int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
-            int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
-            int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
-            int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
-            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
-
-            int32x4_t ret0 = vdupq_n_s32(0);
-            int32x4_t ret1 = vdupq_n_s32(0);
-
-            // 0..7
-            ret0 = vdotq_s32(ret0, b_low.val[0], a0);
-            ret1 = vdotq_s32(ret1, b_low.val[1], a0);
-            // 8..15
-            ret0 = vdotq_s32(ret0, b_low.val[2], a1);
-            ret1 = vdotq_s32(ret1, b_low.val[3], a1);
-            // 16..23
-            ret0 = vdotq_s32(ret0, b_high.val[0], a2);
-            ret1 = vdotq_s32(ret1, b_high.val[1], a2);
-            // 24..31
-            ret0 = vdotq_s32(ret0, b_high.val[2], a3);
-            ret1 = vdotq_s32(ret1, b_high.val[3], a3);
-
-            int32x4_t ret = vpaddq_s32(ret0, ret1);
-
-            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
-            a_ptr++;
-            b_ptr++;
-        }
-        vst1q_f32(s, acc);
-        s += ncols_interleaved;
-    }
-    return;
-
-#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
-}
 
 void ggml_gemm_q8_0_4x4_q8_0(int n,
                              float * GGML_RESTRICT s,