diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp
index a351d54823..b3df909f8c 100644
--- a/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -217,7 +217,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR

     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;

-#if defined(__ARM_NEON)
+#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
     const int blck_size_interleave = 8;
     float32x4_t srcv[4][64]; // 64 = QK_K/4
     float iscale[4];
@@ -334,11 +334,10 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR

 #else
     // NOTE:
-    // Current C implementation is actually aligned with x86 AVX2 design, but differs from above ARM NEON design.
-    // This is because the [bsums] layout is different in block_q8_Kx4 for the 2 designs.
-    // As NEON is supported in almost all the modern ARM platforms, this generic path can be rarely arrived nowadays.
-    // (The exceptional cases aren't suitable for AI work)
-    // However logically we may still need a corresponding generic version for ARM, called xxx_generic_arm for example.
+    // The current C implementation of Q8_K quantization was originally designed to work with block_q4_Kx8 in the x86 AVX
+    // design, and differs from the AArch64 NEON Q8_K quantization logic above, which works with block_q4_Kx4. The main
+    // difference is their "[bsums] layout". However, we can still reuse the x86 C implementation for AArch64, as long
+    // as we access the "[bsums] layout" correctly in ggml_gemm_q4_K_4x8_q8_K_generic().
     UNUSED(nb);
     UNUSED(y);
     ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
@@ -564,6 +563,76 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }

+void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    float * res_ptr = s;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+        float32x4_t sumf = vdupq_n_f32(0);
+        for (int l = 0; l < nb; l++) {
+            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+            int32x4_t sumi = vdupq_n_s32(0);
+            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+            float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+            float32x4_t d = a_d * b_d;
+
+            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+        }
+
+        vst1q_f32(res_ptr + x * 4, sumf);
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -583,7 +652,7 @@ void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);

-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
         const block_q4_Kx4 *GGML_RESTRICT q4 = (const block_q4_Kx4*) vx;
         const uint8x16_t m4b = vdupq_n_u8(0xf);
@@ -683,76 +752,6 @@ void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
 }

-void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 4;
-
-    assert (n % qk == 0);
-    assert (nc % ncols_interleaved == 0);
-
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
-    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
-    float * res_ptr = s;
-
-    for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
-
-        float32x4_t sumf = vdupq_n_f32(0);
-        for (int l = 0; l < nb; l++) {
-            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
-            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
-            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
-            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
-
-            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
-            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
-            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
-            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
-            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
-            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
-            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
-            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
-
-            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
-            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
-
-            int32x4_t sumi = vdupq_n_s32(0);
-            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
-            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
-            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
-            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
-            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
-            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
-            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
-            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
-
-            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
-            float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
-            float32x4_t d = a_d * b_d;
-
-            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
-        }
-
-        vst1q_f32(res_ptr + x * 4, sumf);
-    }
-    return;
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
-}
-
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     constexpr int qk = QK_K;
     const int nb = n / qk;
@@ -2507,6 +2506,82 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }

+void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            float32x4_t sumf[4];
+            for (int m = 0; m < 4; m++) {
+                sumf[m] = vdupq_n_f32(0);
+            }
+
+            for (int l = 0; l < nb; l++) {
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+
+                int32x4_t sumi_0 = vdupq_n_s32(0);
+                int32x4_t sumi_1 = vdupq_n_s32(0);
+                int32x4_t sumi_2 = vdupq_n_s32(0);
+                int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                for (int k = 0; k < 4; k++) {
+                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                }
+
+                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+            }
+
+            for (int m = 0; m < 4; m++) {
+                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+            }
+        }
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -2527,7 +2602,7 @@ void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     UNUSED(ncols_interleaved);
     UNUSED(blocklen);

-#if defined(__ARM_FEATURE_MATMUL_INT8)
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
     const block_q8_Kx4 * GGML_RESTRICT q8_ptr_start = (const block_q8_Kx4 *) vy;
     const block_q4_Kx4 * GGML_RESTRICT q4_ptr_start = (const block_q4_Kx4 *) vx;

@@ -2800,82 +2875,6 @@ void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
 }

-void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
-    const int qk = QK8_0;
-    const int nb = n / qk;
-    const int ncols_interleaved = 4;
-    const int blocklen = 4;
-
-    assert (n % qk == 0);
-    assert (nr % 4 == 0);
-    assert (nc % ncols_interleaved == 0);
-
-    UNUSED(s);
-    UNUSED(bs);
-    UNUSED(vx);
-    UNUSED(vy);
-    UNUSED(nr);
-    UNUSED(nc);
-    UNUSED(nb);
-    UNUSED(ncols_interleaved);
-    UNUSED(blocklen);
-
-#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
-
-            float32x4_t sumf[4];
-            for (int m = 0; m < 4; m++) {
-                sumf[m] = vdupq_n_f32(0);
-            }
-
-            for (int l = 0; l < nb; l++) {
-                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
-                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
-
-                int32x4_t sumi_0 = vdupq_n_s32(0);
-                int32x4_t sumi_1 = vdupq_n_s32(0);
-                int32x4_t sumi_2 = vdupq_n_s32(0);
-                int32x4_t sumi_3 = vdupq_n_s32(0);
-
-                for (int k = 0; k < 4; k++) {
-                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
-                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
-
-                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
-                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
-                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
-
-                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
-                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
-                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
-                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
-                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
-                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
-                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
-                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
-                }
-
-                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
-                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
-                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
-                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
-            }
-
-            for (int m = 0; m < 4; m++) {
-                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
-            }
-        }
-    }
-    return;
-#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
-    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
-}
-
 void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     constexpr int qk = QK_K;
     const int nb = n / qk;
diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index 74285101af..7e2d0bf4ef 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -1099,11 +1099,14 @@ void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
        }

        // prepare partial results for tail work
+       //
+       // NOTE:
+       // The "[bsums] layout" here is from ggml_quantize_mat_q8_K_4x8_generic().
        for (int m = 0; m < 4; m++) {
            for (int j = 0; j < ncols_interleaved; j++) {
                for (int i = 0; i < QK_K / 32; i++) {
-                   const int16_t bsums = a_ptr[n].bsums[i * 2 + m * 16] + a_ptr[n].bsums[i * 2 + 1 + m * 16];
-                   sum_minf[m][j] += mins[j][i] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d[m];
+                   const int16_t *bsums = a_ptr[n].bsums + (i * 8) - ((i % 2) * 6) + (m * 4);
+                   sum_minf[m][j] += mins[j][i] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d[m];
                }
            }
        }
@@ -2314,10 +2317,6 @@ template <> void gemm(int n, float * s, size_t
     ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }

-template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
-}
-
 template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -2326,6 +2325,10 @@ template <> void gemm(int n, float * s, size_t
     ggml_gemm_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }

+template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
@@ -2866,9 +2869,9 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf
     GGML_UNUSED(buft);
 }

-// For the aarch64 q4_k repack implementation only:
-// New tensor storage size calculation after q4_kx4 repacking.
-// Because sizeof(block_q4_Kx4) is a bit larger than default sizeof(block_q4_K)*4.
+// The function below is for the aarch64 q4_K_4x8_q8_K repack case only:
+// tensor storage after repacking is slightly larger than before -- sizeof(block_q4_Kx4) > sizeof(block_q4_K)*4.
+// This is because the "scales" member is pre-decoded during the repack stage rather than at execution time.
 static inline size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
     size_t nbytes;
     const size_t blck_size = 256;
@@ -2880,7 +2883,9 @@ static inline size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
 static size_t ggml_backend_cpu_aarch64_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
     if (tensor->type == GGML_TYPE_Q4_K) {
         if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
-            return ggml_nbytes_q4_kx4(tensor);
+            if (tensor->ne[1] % 4 == 0) {
+                return ggml_nbytes_q4_kx4(tensor); // for q4_K_4x8_q8_K only
+            }
         }
     }
     return ggml_nbytes(tensor);
diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h
index 5323fa39e0..dd467f479a 100644
--- a/ggml/src/ggml-cpu/repack.h
+++ b/ggml/src/ggml-cpu/repack.h
@@ -39,35 +39,8 @@ using block_q8_0x8 = block<8, 8>;
 struct block_q4_Kx4 {
     ggml_half d[4];       // super-block scale for quantized scales
     ggml_half dmin[4];    // super-block scale for quantized mins
-    int8_t scales[64];    // scales and mins, quantized with 8 bits (recover from 6-bit during repack) TODO: consider if uint8_t?
+    int8_t scales[64];    // scales and mins, quantized with 8 bits (recover from 6-bit during repack)
     uint8_t qs[512];      // 4--bit quants
-
-    /********************************************************layout***************************************************************/
-    // low <-------------------------------------------------------------------------------------> high
-    //
-    // d:      |s0|s1|s2|s3|
-    //
-    // dmin:   |s0|s1|s2|s3|
-    //
-    // scales: |-------- d --------|-------- m --------|
-    //         |s0b0|s1b0|s2b0|s3b0|s0b0|s1b0|s2b0|s3b0|
-    //         |s0b1|s1b1|s2b1|s3b1|s0b1|s1b1|s2b1|s3b1|
-    //         ......
-    //         |s0b7|s1b7|s2b7|s3b7|s0b7|s1b7|s2b7|s3b7|
-    //
-    // qs:
-    // |s0w0 |s0w16|s0w1 |s0w17|s0w2 |s0w18|s0w3 |s0w19|s0w4 |s0w20|s0w5 |s0w21|s0w6 |s0w22|s0w7 |s0w23| --- 8B from s0
-    // |s1w0 |s1w16|s1w1 |s1w17|s1w2 |s1w18|s1w3 |s1w19|s1w4 |s1w20|s1w5 |s1w21|s1w6 |s1w22|s1w7 |s1w23| --- 8B from s1
-    // |s2w0 |s2w16|s2w1 |s2w17|s2w2 |s2w18|s2w3 |s2w19|s2w4 |s2w20|s2w5 |s2w21|s2w6 |s2w22|s2w7 |s2w23| --- 8B from s2
-    // |s3w0 |s3w16|s3w1 |s3w17|s3w2 |s3w18|s3w3 |s3w19|s3w4 |s3w20|s3w5 |s3w21|s3w6 |s3w22|s3w7 |s3w23| --- 8B from s3
-    // |s0w8 |s0w24|s0w9 |s0w25|s0w10|s0w26|s0w11|s0w27|s0w12|s0w28|s0w13|s0w29|s0w14|s0w30|s0w15|s0w31| --- 8B from s0
-    // |s1w8 |s1w24|s1w9 |s1w25|s1w10|s1w26|s1w11|s1w27|s1w12|s1w28|s1w13|s1w29|s1w14|s1w30|s1w15|s1w31| --- 8B from s1
-    // |s2w8 |s2w24|s2w9 |s2w25|s2w10|s2w26|s2w11|s2w27|s2w12|s2w28|s2w13|s2w29|s2w14|s2w30|s2w15|s2w31| --- 8B from s2
-    // |s3w8 |s3w24|s3w9 |s3w25|s3w10|s3w26|s3w11|s3w27|s3w12|s3w28|s3w13|s3w29|s3w14|s3w30|s3w15|s3w31| --- 8B from s3
-    //
-    // ......
-    //
-    /*****************************************************************************************************************************/
 };

 struct block_q4_Kx8 {
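
Note (not part of the patch): the new [bsums] indexing in ggml_gemm_q4_K_4x8_q8_K_generic() can be sanity-checked in isolation. The sketch below is illustrative only; it assumes nothing beyond the 64-entry int16_t bsums array of block_q8_Kx4 and simply prints which pair of entries the expression (i * 8) - ((i % 2) * 6) + (m * 4) picks for each row m and 32-wide sub-block i.

// Standalone illustration: enumerate the bsums pairs summed by the generic GEMM tail work.
#include <stdio.h>

int main(void) {
    for (int m = 0; m < 4; m++) {           // 4 interleaved rows
        for (int i = 0; i < 8; i++) {       // QK_K / 32 == 8 sub-blocks
            int base = (i * 8) - ((i % 2) * 6) + (m * 4);
            printf("m=%d i=%d -> bsums[%d] + bsums[%d]\n", m, i, base, base + 1);
        }
    }
    return 0;
}

Every index produced stays within 0..63, i.e. inside the bsums array of one block_q8_Kx4.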
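Note (illustrative arithmetic only): assuming the usual upstream block_q4_K layout (two ggml_half values, 12 packed 6-bit scale/min bytes, QK_K/2 = 128 quant bytes, 144 bytes per super-block), the size gap that ggml_nbytes_q4_kx4() accounts for works out as follows.

// Rough size check; the 144-byte figure for block_q4_K is an assumption about the upstream layout.
enum {
    BYTES_FOUR_PLAIN_Q4_K = 4 * 144,                               // 576 bytes: four plain block_q4_K
    BYTES_BLOCK_Q4_KX4    = 8 + 8 + 64 + 512,                      // 592 bytes: d[4] + dmin[4] + scales[64] + qs[512]
    EXTRA_PER_SUPERBLOCK  = BYTES_BLOCK_Q4_KX4 - BYTES_FOUR_PLAIN_Q4_K // 16 bytes from expanding 6-bit scales/mins to int8_t
};

This is why the aarch64 buffer type must size Q4_K tensors with ggml_nbytes_q4_kx4() rather than ggml_nbytes() when the q4_K_4x8 repack path is taken.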
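Note (illustrative only): the relocated IQ4_NL kernels decode each nibble through a 16-entry table lookup (vqtbl1q_s8 against kvalues_iq4nl). A scalar sketch of that decode for one 16-byte group might look like the helper below; iq4_nl_decode_16 is a hypothetical name used only for this example, and kvalues points at the same 16-entry table.

#include <stdint.h>

// Scalar equivalent of the vqtbl1q_s8-based nibble decode used in the NEON kernels above.
static inline void iq4_nl_decode_16(const uint8_t * qs, const int8_t * kvalues,
                                    int8_t * lo, int8_t * hi) {
    for (int j = 0; j < 16; j++) {
        lo[j] = kvalues[qs[j] & 0x0F];  // low nibble  -> first set of 16 weights
        hi[j] = kvalues[qs[j] >> 4];    // high nibble -> second set of 16 weights
    }
}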