diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index f88cd7f627..fefe39785a 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -617,7 +617,84 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 }
 
 void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+    float sumf[8];
+    int sumi1, sumi2, sumi3, sumi4;
+    int sumi;
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+        }
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (4 * blocklen)); k++) {
+                const uint8_t * scales_0 = b_ptr[l].scales + (k / 4) * 64;       // interleaved sub-block scales, four 16-byte chunks per half-block
+                const uint8_t * scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
+                const uint8_t * scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
+                const uint8_t * scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi1 = 0;
+                    sumi2 = 0;
+                    sumi3 = 0;
+                    sumi4 = 0;
+                    sumi = 0;
+                    int offset = ((k / 2) % 2) + j * 2;  // scale entry for column j within each 16-byte chunk
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int hbits_index = k * ncols_interleaved * blocklen + j * blocklen + i;   // qh byte: top 2 bits of four quants for column j, lane i
+                        const int lbits_index = (hbits_index / 32) * 64 + (hbits_index % 32);          // matching ql byte: low nibbles of v0/v2 (v1/v3 sit 32 bytes on)
+                        const int v0_hbits = (int8_t) ((b_ptr[l].qh[hbits_index] & 3) << 4);
+                        const int v1_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 2) & 3) << 4);
+                        const int v2_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 4) & 3) << 4);
+                        const int v3_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 6) & 3) << 4);
+
+                        const int v0_lbits = (int8_t) (b_ptr[l].ql[lbits_index] & 0xF);
+                        const int v1_lbits = (int8_t) (b_ptr[l].ql[lbits_index + 32] & 0xF);
+                        const int v2_lbits = (int8_t) ((b_ptr[l].ql[lbits_index] >> 4) & 0xF);
+                        const int v3_lbits = (int8_t) ((b_ptr[l].ql[lbits_index + 32] >> 4) & 0xF);
+
+                        const int v0 = v0_hbits | v0_lbits;
+                        const int v1 = v1_hbits | v1_lbits;
+                        const int v2 = v2_hbits | v2_lbits;
+                        const int v3 = v3_hbits | v3_lbits;
+
+                        sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
+                        sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
+                        sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
+                        sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
+
+                        sumi1 = sumi1 * scales_0[offset];
+                        sumi2 = sumi2 * scales_1[offset];
+                        sumi3 = sumi3 * scales_2[offset];
+                        sumi4 = sumi4 * scales_3[offset];
+                        sumi += sumi1 + sumi2 + sumi3 + sumi4;
+                    }
+                    sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;  // per-column fp16 delta times activation scale
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
 }
 
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
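
For reference, below is a minimal standalone sketch of the 6-bit reconstruction done in the inner loop, assuming the layout implied by the index math above: each qh byte carries the top 2 bits of four quants (v0 in bits 0-1 through v3 in bits 6-7), and the matching ql bytes carry the low nibbles, with v0/v2 sharing one byte and v1/v3 sharing the byte 32 positions further on. The names hb, lb0, lb1 and the sample bit patterns are illustrative only; they are not part of repack.cpp.

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // one qh byte: top 2 bits of four quants (v0 in bits 0-1, ..., v3 in bits 6-7)
        const uint8_t hb  = 0xE4;
        // two ql bytes: low nibble / high nibble pairs (lb0 -> v0/v2, lb1 -> v1/v3)
        const uint8_t lb0 = 0xA5;
        const uint8_t lb1 = 0x3C;

        const int v0 = (((hb     ) & 3) << 4) | ( lb0       & 0xF);
        const int v1 = (((hb >> 2) & 3) << 4) | ( lb1       & 0xF);
        const int v2 = (((hb >> 4) & 3) << 4) | ((lb0 >> 4) & 0xF);
        const int v3 = (((hb >> 6) & 3) << 4) | ((lb1 >> 4) & 0xF);

        // each v* is an unsigned 6-bit quant in [0, 63]
        printf("v0=%d v1=%d v2=%d v3=%d\n", v0, v1, v2, v3);  // prints: v0=5 v1=28 v2=42 v3=51
        return 0;
    }

In the patch the same recombination runs per lane i across the 8 interleaved columns; the hbits_index/lbits_index arithmetic only serves to locate these three bytes inside the 8x-interleaved block_q6_Kx8 layout.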