From d19cdcfac72421cff035cbe7876eb7bbbb351243 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Tue, 3 Feb 2026 01:42:41 +0500 Subject: [PATCH] ggml-cpu: extend rvv gemm, gemv to other vlens Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 1423 ++++++++------ ggml/src/ggml-cpu/repack.cpp | 2390 +++++++++++------------ ggml/src/ggml-cpu/repack.h | 176 +- 3 files changed, 2051 insertions(+), 1938 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index c37488cae5..3401c35876 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -203,10 +203,10 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_zvfh -void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +template +static inline void ggml_gemv_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 16; const int blocklen = 1; assert (n % qk == 0); @@ -224,230 +224,56 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); // 1x16 Accumulator - vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { // 1x16 Integer Accumulator - vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_lo_16 = 
__riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK4_0 / 2; i++) { // Load `b_ptr`. - const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); - const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, ncols_interleaved), 4, ncols_interleaved); const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); - sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16); - sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16); + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, ncols_interleaved); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, ncols_interleaved); } - const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved); - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, ncols_interleaved); - sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); } - __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); } } -void 
ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - const block_q8_K * a_ptr = (const block_q8_K *) vy; - - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); - - // 1x16 Accumulator - vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); - - for (int l = 0; l < nb; l++) { - vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16); - - // Load `dmin`. - const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( - __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16); - - // We process 4 sub-blocks at once. - for (int j = 0; j < QK_K / 128; j++) { - // Extract the scales and the mins. - // - // Low bits. - vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); - vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); - vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); - - // High bits. 
- vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); - vuint8m2_t scales_hi; - vuint8m2_t mins_hi; - if (!j) { - scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); - mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); - } else { - scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); - mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); - } - vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); - vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); - - // Reduce the mins and multiply with `dmin`. - // - // Correct in `sumf`. - vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16); - bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - - sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16); - - // Accumulation for 2 sub-blocks. - // - // This might overflow, so we accumulate in two steps. - // - // Recheck. - for (int k = 0; k < 2; k++) { - vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { - // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); - const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); - const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); - - sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16); - sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16); - } - - sumi = __riscv_vwmacc_vv_i32m2(sumi, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_s_0_16, 16); - sumi = __riscv_vwmacc_vv_i32m2(sumi, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_s_1_16, 16); - } - // Accumulation for 2 sub-blocks. - // - // This might overflow, so we accumulate in two steps. - // - // Recheck. - for (int k = 0; k < 2; k++) { - vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { - // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); - const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); - const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); - - sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16); - sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16); - } - - sumi = __riscv_vwmacc_vv_i32m2(sumi, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_s_0_16, 16); - sumi = __riscv_vwmacc_vv_i32m2(sumi, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_s_1_16, 16); - } - } - - const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16); - const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16); - - sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); - } - - __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); - } +void ggml_gemv_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_0_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_0_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_0_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); } -void 
ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +template +void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - - // 1x16 Accumulator1 - vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); - - for (int l = 0; l < nb; l++) { - // 1x16 integer accumulator - vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); - - // Accumulation loop. - for (int i = 0; i < QK4_NL / 2; i++) { - // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); - // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); - // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); - - const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16); - const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16); - sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16); - } - - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); - - sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); - } - - __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); - } -} - -void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; const int blocklen = 1; assert (n % qk == 0); @@ -466,55 +292,68 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); // 1x16 Accumulator - vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { // 1x16 Integer Accumulator - vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi = 
__riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK8_0; i++) { // Load `b_ptr`. - const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); - sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); + sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], ncols_interleaved), ncols_interleaved); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, ncols_interleaved); - sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); } - __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); } } -void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q8_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q8_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +} + +template +void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); assert(nr == 1); assert(nc % 16 == 0); UNUSED(bs); - const int N_COLS_TILE = 16; const int num_k_blocks = n / QK_K; - const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); - for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + const size_t vl = __riscv_vsetvl_e32m2(ncols_interleaved); + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy; - const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + const block_q2_Kx* rhs_base_ptr = (const block_q2_Kx*)vx + (col_tile / ncols_interleaved) * num_k_blocks; vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); for (int k_block = 0; k_block < num_k_blocks; ++k_block) { const block_q8_K* lhs_current = &lhs_base_ptr[k_block]; - const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + const block_q2_Kx* rhs_current = &rhs_base_ptr[k_block]; // 1. 
Prepare Global Min Scales vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); @@ -665,6 +504,228 @@ void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); } } + +void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +} + +template +void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 4 sub-blocks at once. 
+ const int vl = ncols_interleaved * 4; + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * vl], vl); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, vl); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl * 2], vl); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, vl), 4, vl); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, vl), 2, vl); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, vl); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, vl), 2, vl); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, vl), vl); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, vl), vl)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. 
+ vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); + sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); + + // Accumulation for 2 sub-blocks. + { + // 4x16 integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, ncols_interleaved); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_s_1_16, 16); + } + { + // 4x16 integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, ncols_interleaved); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_s_0_16, ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_s_1_16, ncols_interleaved); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + 
ggml_gemv_q4_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +} + +template +void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); + + // 1x16 Accumulator1 + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + // 1x16 integer accumulator + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + + // Accumulation loop. + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + + const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], ncols_interleaved); + const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], ncols_interleaved); + sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, ncols_interleaved), ncols_interleaved); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0<64>(n, s, bs, 
vx, vy, nr, nc); +} #endif void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -894,10 +955,10 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_zvfh -void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +template +void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 16; const int blocklen = 1; assert (n % qk == 0); @@ -917,426 +978,86 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); // 4x16 Accumulators - vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { // 4x16 integer accumulators - vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_lo_16 = 
__riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK4_0 / 2; i++) { // Load `b_ptr`. - const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); - const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); - const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, ncols_interleaved), 4, ncols_interleaved); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, ncols_interleaved); - sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16); - sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16); - sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16); - sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16); + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, 
a_ptr[l].qs[i * 4], b_0_lo, ncols_interleaved); + sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, ncols_interleaved); + sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, ncols_interleaved); + sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, ncols_interleaved); - sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16); - sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16); - sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16); - sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, ncols_interleaved); + sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, ncols_interleaved); + sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, ncols_interleaved); + sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, ncols_interleaved); } // Do the final accumulation in i32 to prevent overflow. 
- const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); - const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16); - const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16); - const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16); + const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, ncols_interleaved); + const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, ncols_interleaved); + const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, ncols_interleaved); + const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, ncols_interleaved); - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); - const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); - const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); - const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], ncols_interleaved); - sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); - sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); - sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, 
__riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); - sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); } - __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); } } } -void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx16 * b_ptr = 
(const block_q4_Kx16 *) vx + (x * nb); - - // 4x16 Accumulators - vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - - for (int l = 0; l < nb; l++) { - vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16); - - // Load `dmin`. - const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16); - - // We process 4 sub-blocks at once. - for (int j = 0; j < QK_K / 128; j++) { - // Extract the scales and the mins. - // - // Low bits. - vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); - vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); - vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); - - // High bits. - vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); - vuint8m2_t scales_hi; - vuint8m2_t mins_hi; - if (!j) { - scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); - mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); - } else { - scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); - mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); - } - vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); - vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); - - // Reduce the mins and multiply with `dmin`. - // - // Correct in `sumf`. 
- vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16); - - bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, - a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, - a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, - a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, - a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); - bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, - a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, - a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, - a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, - a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); - bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, - a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, - a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, - a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, - a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 
+ 7], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); - bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, - a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, - a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, - a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, - a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); - - const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], 16); - const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], 16); - const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], 16); - const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], 16); - - sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16); - sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16); - sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16); - sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16); - - - // Accumulation for 2 sub-blocks. - // - // This might overflow, so we accumulate in two steps. - // - // Recheck. 
- for (int k = 0; k < 2; k++) { - // 4x16 integer accumulators - vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { - // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); - const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); - const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); - - sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16); - sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16); - sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16); - sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16); - - sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16); - sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16); - sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16); - sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16); - } - - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_0_s_0_16, 16); - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, - 
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_0_s_1_16, 16); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_1_s_0_16, 16); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_1_s_1_16, 16); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_2_s_0_16, 16); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_2_s_1_16, 16); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_3_s_0_16, 16); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_3_s_1_16, 16); - } - // Accumulation for 2 sub-blocks. - // - // This might overflow, so we accumulate in two steps. - // - // Recheck. - for (int k = 0; k < 2; k++) { - // 4x16 integer accumulators - vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - - for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { - // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); - const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); - const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); - - sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16); - sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16); - sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16); - sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16); - - sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16); - sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16); - sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, 16); - sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16); - } - - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_0_s_0_16, 16); - sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_0_s_1_16, 16); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_1_s_0_16, 16); - sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_1_s_1_16, 16); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_2_s_0_16, 16); - sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, - 
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_2_s_1_16, 16); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), - sumi_3_s_0_16, 16); - sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), - sumi_3_s_1_16, 16); - } - } - - const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16); - const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16); - const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16); - const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16); - const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16); - - sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); - sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); - sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); - sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); - } - - __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); - } - } +void ggml_gemm_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const 
void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +template +void ggml_gemm_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - - // 4x16 Accumulators - vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - - for (int l = 0; l < nb; l++) { - // 4x16 integer accumulators - vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); - - // Accumulation loop. - for (int i = 0; i < QK4_NL / 2; i++) { - // Load `b_ptr`. 
- const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); - // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); - // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); - - const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16); - const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16); - const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16); - const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16); - - const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16); - const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16); - const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16); - const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16); - - sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16); - sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16); - sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16); - sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16); - } - - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); - const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); - const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); - const vfloat32m2_t d_3 = 
__riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); - - sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); - sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); - sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); - sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); - } - - __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); - } - } -} - -void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; const int blocklen = 1; assert (n % qk == 0); @@ -1356,70 +1077,82 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); // 4x16 Accumulators - vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); - vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l 
= 0; l < nb; l++) { // 4x16 Integer Accumulators - vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); - vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK8_0; i++) { // Load `b_ptr`. - const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); - // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); - sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], 16), 16); - sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], 16), 16); - sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], 16), 16); - sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16); + sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], ncols_interleaved), ncols_interleaved); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); - const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const 
_Float16 *)&a_ptr[l].d[0], 16); - const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); - const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); - const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], ncols_interleaved); - sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); - sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); - sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); - sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); } - __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); - __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + 
__riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); } } } -void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_q8_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q8_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +} + +template +void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); const int num_k_blocks = n / QK_K; const int N_ROWS_TILE = 4; - const int N_COLS_TILE = 16; assert(nr % N_ROWS_TILE == 0); - assert(nc % N_COLS_TILE == 0); + assert(nc % ncols_interleaved == 0); - const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + const size_t vl = __riscv_vsetvl_e32m2(ncols_interleaved); // --- Tiling Loops --- #pragma GCC unroll 1 for (int row_tile = 0; row_tile < nr; row_tile += 
N_ROWS_TILE) { #pragma GCC unroll 1 - for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { // Base Pointers const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks; - const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / ncols_interleaved) * num_k_blocks; // Persistent Float Accumulators vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); @@ -1700,4 +1433,400 @@ void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } } } + +void ggml_gemm_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +} + +template +void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); 
+ UNUSED(ncols_interleaved); + UNUSED(blocklen); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 4 sub-blocks at once. + const int vl = ncols_interleaved * 4; + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * vl], vl); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, vl); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); + + // High bits. 
+ vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl], vl); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, vl), 4, vl); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, vl), 2, vl); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, vl); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, vl), 2, vl); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, vl), vl); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, vl), vl)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_2 = 
__riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved); + 
const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, ncols_interleaved), ncols_interleaved), ncols_interleaved); + + + // Accumulation for 2 sub-blocks. + { + // 4x8 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, ncols_interleaved); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, ncols_interleaved); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_0_s_0_16, ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_1_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_1_s_1_16, ncols_interleaved); + sumi_2 = 
__riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_2_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_3_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_3_s_1_16, ncols_interleaved); + } + { + // 4x16 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, ncols_interleaved); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, ncols_interleaved); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_0_s_0_16, ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_1_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + 
sumi_1_s_1_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_2_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_3_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_3_s_1_16, ncols_interleaved); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y 
* 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } +} + +void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +} + +template +void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + 
vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + // 4x16 integer accumulators + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); + + // Accumulation loop. + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + + const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], ncols_interleaved); + const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], ncols_interleaved); + const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], ncols_interleaved); + const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], ncols_interleaved); + + const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], ncols_interleaved); + const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], ncols_interleaved); + const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], ncols_interleaved); + const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], ncols_interleaved); + + sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, 
sumi_0_hi, ncols_interleaved), ncols_interleaved); + sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, ncols_interleaved), ncols_interleaved); + sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, ncols_interleaved), ncols_interleaved); + sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, ncols_interleaved), ncols_interleaved); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } + return; +#endif + ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_iq4_nl_8x1_q8_0(int n, 
float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0<64>(n, s, bs, vx, vy, nr, nc); +} #endif diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index f18758f16b..eaf88f174a 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1,3 +1,4 @@ +#include "ggml.h" #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -353,287 +354,247 @@ template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_K>(const float * GGML_RESTR } #endif -template -static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - constexpr int blocklen = M; - constexpr int ncols_interleaved = N; - const int qk = QK_K; - const int nb = n / qk; - const int blocks_per_half = 64 / blocklen; +#if defined __riscv_zvfh +template +static inline void ggml_gemv_q4_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); 
+ UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +template +static inline void ggml_gemv_q8_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert(nr == 1); assert(n % qk == 0); assert(nc % ncols_interleaved == 0); UNUSED(bs); UNUSED(nr); - float sumf[8]; + float sumf[16]; + int sumi; - const block_q8_K * a_ptr = (const block_q8_K *) vy; + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0f; + sumf[j] = 0.0; } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); 
k++) { - const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; - const int base_h = base_l + 64; - - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; - - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; - - const int qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; - + for (int k = 0; k < (qk / blocklen); k++) { for (int j = 0; j < ncols_interleaved; j++) { - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; - - int sumi_l = 0; - int sumi_h = 0; - - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; - - const int qh_idx_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_idx_l / blocklen; - const int qh_pos_l = qh_idx_l % blocklen; - const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; - - const int qh_idx_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_idx_h / blocklen; - const int qh_pos_h = qh_idx_h % blocklen; - const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; - - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; - - const int8_t a_l = a_ptr[l].qs[base_l + i]; - const int8_t a_h = a_ptr[l].qs[base_h + i]; - - sumi_l += q_l * a_l; - sumi_h += q_h * a_h; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; } - - sumf[j] += - (sumi_l * scale_l + sumi_h * 
scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); } } } - for (int j = 0; j < ncols_interleaved; j++) { s[x * ncols_interleaved + j] = sumf[j]; } } } -template -static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - constexpr int blocklen = M; - constexpr int ncols_interleaved = N; - const int qk = QK_K; - const int nb = n / qk; - const int blocks_per_half = 64 / blocklen; - const int q8_half_stride = 512; - const int q8_low_high_step = 256; - - assert(n % qk == 0); - assert(nr % 4 == 0); +template +static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); assert(nc % ncols_interleaved == 0); UNUSED(bs); - float sumf[4][8]; + const int nb = n / QK_K; + const block_q2_Kx * x = (const block_q2_Kx *)vx; + const block_q8_K * y = (const block_q8_K *)vy; - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + // Layout: Even-Low(0,2,4,6), Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...) 
+ const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, // 0-7 + 8, 12, 9, 13, 10, 14, 11, 15 // 8-15 + }; - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0f; + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { + const block_q2_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; + const block_q8_K * y_ptr = y; + + float sumf[16] = {0}; + + // Loop over K-blocks + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[16] = {0}; + int32_t summs[16] = {0}; + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + // Iterate over sub-blocks 0..15 + for (int sb = 0; sb < 16; ++sb) { + // Correction Term + int16_t bsum = bs_lhs[sb]; + int scale_offset = sb_perm[sb] * 16; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits } - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; - const int base_h = base_l + 64; + // Main Dot Product + // Calculate base offsets for Q2 unpacking based on SB + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; + int shift = ((sb / 2) % 4) * 2; - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits - const int qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; + // Process 16 elements (l=0..15) + for (int l = 0; l < 16; ++l) { + // Q2: Interleaved by column. 
Byte `l` contains 4 k-values. + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; - const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4); + // Q8: Linear access + int k = sb * 16 + l; + int8_t q8_val = qs_lhs[k]; - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; - - int sumi_l = 0; - int sumi_h = 0; - - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; - - const int qh_idx_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_idx_l / blocklen; - const int qh_pos_l = qh_idx_l % blocklen; - const int qh_offset_l = - qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; - - const int qh_idx_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_idx_h / blocklen; - const int qh_pos_h = qh_idx_h % blocklen; - const int qh_offset_h = - qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; - - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; - - const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i]; - const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step]; - - sumi_l += q_l * q8_l; - sumi_h += q_h * q8_h; - } - - sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * - a_ptr[l].d[m]; - } + isum[col] += q8_val * q2_val * d_sb; } } } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * 
ncols_interleaved + j] = sumf[m][j]; - } + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_lhs = y_ptr[k_block].d; + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + + sumf[col] += (isum[col] * d_all) - (summs[col] * d_min); } } + + for (int col = 0; col < 16; ++col) { + s[col_tile + col] = sumf[col]; + } } } -template -static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - constexpr int blocklen = M; - constexpr int ncols_interleaved = N; - const int qk = QK_K; - const int nb = n / qk; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - +template +static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + UNUSED(s); UNUSED(bs); + UNUSED(vx); + UNUSED(vy); UNUSED(nr); - - float sumf[ncols_interleaved]; - float sum_minf[ncols_interleaved]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; - + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + float sumf[ncols_interleaved]; + float sum_minf[ncols_interleaved]; + uint8_t scales[ncols_interleaved * 8]; + uint8_t mins[ncols_interleaved * 8]; + int sumi1; + int sumi2; + int sumi; const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); - + const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + 
(x * nb); for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - sum_minf[j] = 0.0; + sumf[j] = 0.0f; + sum_minf[j] = 0.0f; } for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; + for (int i = 0; i < ncols_interleaved * 8; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - constexpr int scale_stride = 32; - uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; - uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; - - const int qh_shift = (k / (32 / blocklen)) * 2; + for (int i = 0; i < ncols_interleaved * 4; i++) { + scales[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x0C) << 2; + scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); + mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * ncols_interleaved]; for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; - - const int qh_idx = (k * blocklen + i) % 32; - const int qh_chunk = qh_idx / blocklen; - const int qh_pos = qh_idx % blocklen; - const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; - - const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; - const uint8_t h0 = (qh_val 
>> qh_shift) & 1; - const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; - - const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); - const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); - - const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i; - - sumi1 = (v0 * a_ptr[l].qs[q8_offset]); - sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); + sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * ncols_interleaved]; + uint8_t *scales_1 = &scales[(sb + 1) * ncols_interleaved]; + for (int i = 0; i < QK4_0; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + const int v0 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); + sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); sumi1 = sumi1 * scales_0[j]; sumi2 = sumi2 * scales_1[j]; sumi += sumi1 + sumi2; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * - GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; } } } @@ -643,99 +604,323 @@ static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n, } } -template -static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - constexpr int blocklen = M; - constexpr int ncols_interleaved = N; - const int qk = 
QK_K; - const int nb = n / qk; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; +template +static inline void ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} +#endif + +#if defined __riscv_zvfh +template +static inline void ggml_gemm_q4_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + 
UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][ncols_interleaved]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +template +static inline void ggml_gemm_q8_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; assert(n % qk == 0); assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); - float sumf[4][ncols_interleaved]; - float sum_minf[4][ncols_interleaved]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; + float sumf[4][ncols_interleaved]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block<4, 
ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; + } + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +template +static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr % 4 == 0); + assert(nc % 16 == 0); + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; + + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, + 8, 12, 9, 13, 10, 14, 11, 15 + }; + + // Iterate Rows in tiles of 4 + for (int row_tile = 0; row_tile < nr; row_tile += 4) { + // Iterate Columns in tiles of 16 + for (int col_tile = 0; col_tile < nc; col_tile += 16) { + + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; + + float sumf[4][16]; + memset(sumf, 0, sizeof(sumf)); + + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[4][16]; + int32_t summs[4][16]; + memset(isum, 0, sizeof(isum)); + memset(summs, 0, sizeof(summs)); + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const 
int16_t * bs_lhs = y_ptr[k_block].bsums; + + for (int sb = 0; sb < 16; ++sb) { + int scale_offset = sb_perm[sb] * 16; + + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + int shift = ((sb / 2) % 4) * 2; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; + int32_t m_sb = sc_val >> 4; + + // Correction Term + for (int r = 0; r < 4; ++r) { + int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4); + summs[r][col] += bs_lhs[bsum_idx] * m_sb; + } + + // Main Dot Product + for (int l = 0; l < 16; ++l) { + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; + + // Calculate Q8 index for this specific k and row + int k = sb * 16 + l; + int q8_idx = (k / 4) * 16 + (k % 4); + + for (int r = 0; r < 4; ++r) { + // Add r*4 to jump to the correct row within the 4x4 chunk + int8_t q8_val = qs_lhs[q8_idx + r * 4]; + isum[r][col] += q8_val * q2_val * d_sb; + } + } + } + } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + for (int r = 0; r < 4; ++r) { + float d_lhs = y_ptr[k_block].d[r]; + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + sumf[r][col] += (isum[r][col] * d_all) - (summs[r][col] * d_min); + } + } + } + + for (int r = 0; r < 4; ++r) { + for (int col = 0; col < 16; ++col) { + s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; + } + } + } + } +} + +template +static inline void ggml_gemm_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); 
+ UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][ncols_interleaved]; + float sum_minf[4][ncols_interleaved]; + uint8_t scales[8 * ncols_interleaved]; + uint8_t mins[8 * ncols_interleaved]; + int sumi1; + int sumi2; + int sumi; for (int y = 0; y < nr / 4; y++) { const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; + sumf[m][j] = 0.0; sum_minf[m][j] = 0.0; } } for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; + for (int i = 0; i < ncols_interleaved * 8; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < ncols_interleaved * 4; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); + mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - constexpr int scale_stride = 32; - uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; - uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; - const int qh_shift = (k / (32 / blocklen)) * 2; - for (int m = 0; m < 4; m++) { - for (int j = 0; j < 
ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; - - const int qh_idx = (k * blocklen + i) % 32; - const int qh_chunk = qh_idx / blocklen; - const int qh_pos = qh_idx % blocklen; - const int b_qh_offset = - qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; - - const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; - const uint8_t h0 = (qh_val >> qh_shift) & 1; - const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; - - const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); - const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); - - const int q8_offset = (k / (32 / blocklen)) * 256 + - (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i; - - sumi1 = (v0 * a_ptr[l].qs[q8_offset]); - sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * ncols_interleaved]; + for(int m = 0; m < 4; m++) { + const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; } } } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; - for (int m = 0; m < 4; m++) { - const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * - GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * ncols_interleaved]; + uint8_t *scales_1 = &scales[(sb + 1) * ncols_interleaved]; + + for (int i = 0; i < QK4_0; i++) { + for (int m = 0; 
m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + + const int v0 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); + sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } } } } @@ -749,6 +934,51 @@ static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n, } } +template +static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); + } + sumf[m][j] += sumi * 
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} +#endif + extern "C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1367,291 +1597,74 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, // Only enable these for RISC-V. #if defined __riscv_zvfh +// Q4_0 +void ggml_gemv_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - float sumf[16]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * 
a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } + ggml_gemv_q4_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_0_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_0_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - float sumf[16]; - float sum_minf[16]; - uint8_t scales[128]; - uint8_t mins[128]; - int sumi1; - int sumi2; - int sumi; - const block_q8_K * a_ptr = (const block_q8_K *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0f; - sum_minf[j] = 0.0f; - } - for (int l = 0; l < nb; l++) { - for (int i = 0; i < 128; i++) { - scales[i] = b_ptr[l].scales[i] & 0x0F; - mins[i] = b_ptr[l].scales[i] >> 4; - } - for (int i = 0; i < 64; i++) { - scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; - mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; - scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); - 
mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; - } - for (int sb = 0; sb < 8; sb++) { - uint8_t *min = &mins[sb * 16]; - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; - } - } - for (int sb = 0; sb < 8; sb += 2) { - uint8_t *scales_0 = &scales[sb * 16]; - uint8_t *scales_1 = &scales[(sb + 1) * 16]; - for (int i = 0; i < QK4_0; i++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); - const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); - sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); - sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; - } - } +// Q8_0 +void ggml_gemv_q8_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); } - -void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[16]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l 
= 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[16]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * blocklen + i]; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j]; - } - } + ggml_gemv_q8_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q8_0_32x1_q8_0_generic(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q8_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q8_0_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q2_K +void ggml_gemv_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - assert(n % QK_K == 0); - assert(nr == 1); - assert(nc % 16 == 0); + ggml_gemv_q2_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q2_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} - UNUSED(bs); - UNUSED(nr); +// Q4_K +void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} - const int nb = n / QK_K; - const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; - const block_q8_K * y = (const block_q8_K *)vy; - - // Layout: Even-Low(0,2,4,6), Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...) - const int sb_perm[16] = { - 0, 4, 1, 5, 2, 6, 3, 7, // 0-7 - 8, 12, 9, 13, 10, 14, 11, 15 // 8-15 - }; - - for (int col_tile = 0; col_tile < nc; col_tile += 16) { - const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; - const block_q8_K * y_ptr = y; - - float sumf[16] = {0}; - - // Loop over K-blocks - for (int k_block = 0; k_block < nb; ++k_block) { - int32_t isum[16] = {0}; - int32_t summs[16] = {0}; - - const uint8_t * qs_rhs = x_ptr[k_block].qs; - const uint8_t * sc_rhs = x_ptr[k_block].scales; - const int8_t * qs_lhs = y_ptr[k_block].qs; - const int16_t * bs_lhs = y_ptr[k_block].bsums; - - // Iterate over sub-blocks 0..15 - for (int sb = 0; sb < 16; ++sb) { - // Correction Term - int16_t bsum = bs_lhs[sb]; - int scale_offset = sb_perm[sb] * 16; - - for (int col = 0; col < 16; ++col) { - uint8_t sc_val = sc_rhs[scale_offset + col]; - summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits - } - - // Main Dot Product - // Calculate base offsets for Q2 unpacking based on SB - int byte_base; - if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; - else byte_base = (sb % 2 == 0) ? 32 : 48; - - int shift = ((sb / 2) % 4) * 2; - - for (int col = 0; col < 16; ++col) { - uint8_t sc_val = sc_rhs[scale_offset + col]; - int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits - - // Process 16 elements (l=0..15) - for (int l = 0; l < 16; ++l) { - // Q2: Interleaved by column. Byte `l` contains 4 k-values. 
- int qs_idx = (byte_base + l) * 16 + col; - uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; - - // Q8: Linear access - int k = sb * 16 + l; - int8_t q8_val = qs_lhs[k]; - - isum[col] += q8_val * q2_val * d_sb; - } - } - } - - // Finalize K-Block - for (int col = 0; col < 16; ++col) { - float d_lhs = y_ptr[k_block].d; - float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); - float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); - - float d_all = d_lhs * d_rhs; - float d_min = d_lhs * dm_rhs; - - sumf[col] += (isum[col] * d_all) - (summs[col] * d_min); - } - } - - for (int col = 0; col < 16; ++col) { - s[col_tile + col] = sumf[col]; - } - } +// IQ4_NL +void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } #endif @@ -2385,338 +2398,74 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, // Only enable these for RISC-V. 
#if defined __riscv_zvfh +// Q4_0 +void ggml_gemm_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - float sumf[4][16]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } + 
ggml_gemm_q4_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_0_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - float sumf[4][16]; - float sum_minf[4][16]; - uint8_t scales[128]; - uint8_t mins[128]; - int sumi1; - int sumi2; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - sum_minf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int i = 0; i < 128; i++) { - scales[i] = b_ptr[l].scales[i] & 0x0F; - mins[i] = b_ptr[l].scales[i] >> 4; - } - for (int i = 0; i < 64; i++) { - scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; - mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; - scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); - mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; - } - - for (int sb = 0; sb < 8; sb++) { - uint8_t *min = &mins[sb 
* 16]; - for(int m = 0; m < 4; m++) { - const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; - for(int j = 0; j < ncols_interleaved; j++) { - sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; - } - } - } - - for (int sb = 0; sb < 8; sb += 2) { - uint8_t *scales_0 = &scales[sb * 16]; - uint8_t *scales_1 = &scales[(sb + 1) * 16]; - - for (int i = 0; i < QK4_0; i++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - - const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); - const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); - sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); - sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; - } - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; - } - } - } - } +// Q8_0 +void ggml_gemm_q8_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); } - -void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - float sumf[4][16]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const 
block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -} - void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - float sumf[4][16]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * 
ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; - } - sumf[m][j] += - sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } + ggml_gemm_q8_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q8_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q8_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q8_0_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } - +// Q2_K +void ggml_gemm_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - assert(n % QK_K == 0); - assert(nr % 4 == 0); - assert(nc % 16 == 0); - const int nb = n / QK_K; - const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; - const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; + ggml_gemm_q2_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q2_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + 
ggml_gemm_q2_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} - const int sb_perm[16] = { - 0, 4, 1, 5, 2, 6, 3, 7, - 8, 12, 9, 13, 10, 14, 11, 15 - }; +// Q4_K +void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} - // Iterate Rows in tiles of 4 - for (int row_tile = 0; row_tile < nr; row_tile += 4) { - // Iterate Columns in tiles of 16 - for (int col_tile = 0; col_tile < nc; col_tile += 16) { - - const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; - const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; - - float sumf[4][16]; - memset(sumf, 0, sizeof(sumf)); - - for (int k_block = 0; k_block < nb; ++k_block) { - int32_t isum[4][16]; - int32_t summs[4][16]; - memset(isum, 0, sizeof(isum)); - memset(summs, 0, sizeof(summs)); - - const uint8_t * qs_rhs = x_ptr[k_block].qs; - const uint8_t * sc_rhs = x_ptr[k_block].scales; - const int8_t * qs_lhs = y_ptr[k_block].qs; - const int16_t * bs_lhs = y_ptr[k_block].bsums; - - for (int sb = 0; sb < 16; ++sb) { - int scale_offset = sb_perm[sb] * 16; - - int byte_base; - if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; - else byte_base = (sb % 2 == 0) ? 
32 : 48; - int shift = ((sb / 2) % 4) * 2; - - for (int col = 0; col < 16; ++col) { - uint8_t sc_val = sc_rhs[scale_offset + col]; - int32_t d_sb = sc_val & 0xF; - int32_t m_sb = sc_val >> 4; - - // Correction Term - for (int r = 0; r < 4; ++r) { - int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4); - summs[r][col] += bs_lhs[bsum_idx] * m_sb; - } - - // Main Dot Product - for (int l = 0; l < 16; ++l) { - int qs_idx = (byte_base + l) * 16 + col; - uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; - - // Calculate Q8 index for this specific k and row - int k = sb * 16 + l; - int q8_idx = (k / 4) * 16 + (k % 4); - - for (int r = 0; r < 4; ++r) { - // Add r*4 to jump to the correct row within the 4x4 chunk - int8_t q8_val = qs_lhs[q8_idx + r * 4]; - isum[r][col] += q8_val * q2_val * d_sb; - } - } - } - } - - // Finalize K-Block - for (int col = 0; col < 16; ++col) { - float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); - float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); - - for (int r = 0; r < 4; ++r) { - float d_lhs = y_ptr[k_block].d[r]; - float d_all = d_lhs * d_rhs; - float d_min = d_lhs * dm_rhs; - sumf[r][col] += (isum[r][col] * d_all) - (summs[r][col] * d_min); - } - } - } - - for (int r = 0; r < 4; ++r) { - for (int col = 0; col < 16; ++col) { - s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; - } - } - } - } +// IQ4_NL +void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + 
ggml_gemm_iq4_nl_Mx1_q8_0_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_iq4_nl_Mx1_q8_0_generic<64>(n, s, bs, vx, vy, nr, nc); } #endif @@ -2808,31 +2557,6 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } -static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x16 out; - - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } - - const int end = QK4_0 * 8 / blck_size_interleave; - - if (blck_size_interleave == 1) { - const uint8_t xor_mask = 0x88; - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = i / 16; - int dst_offset = i; - - out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask; - } - } else { - GGML_ASSERT(false); - } - - return out; -} - static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { block_q4_Kx8 out; //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure @@ -2910,58 +2634,6 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in return out; } -static block_q4_Kx16 make_block_q4_Kx16(block_q4_K * in, unsigned int blck_size_interleave) { - block_q4_Kx16 out; - //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; - } - - for (int i = 0; i < 16; i++) { - out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; - } - - const int end = QK_K * 8 / blck_size_interleave; - - if (blck_size_interleave == 1) { - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = i / 16; - int dst_offset = i; - - out.qs[dst_offset] = in[src_id].qs[src_offset]; - } - - // RVV repacking. 
- // - // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. - uint8_t s[128], m[128]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 16; j++) { - s[i * 16 + j] = in[j].scales[i] & 63; - m[i * 16 + j] = in[j].scales[i + 4] & 63; - } - } - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 16; j++) { - s[64 + i * 16 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[64 + i * 16 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); - } - } - - for (int i = 0; i < 128; i++) { - out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); - } - for (int i = 0; i < 64; i++) { - out.scales[128 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[64 + i] & 48) | ((m[64 + i] & 48) << 2); - } - } else { - GGML_ASSERT(false); - } - - return out; -} - static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) { block_q2_Kx8 out; @@ -3135,68 +2807,6 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in return out; } -static block_q2_Kx16 make_block_q2_Kx16(const block_q2_K * in, unsigned int blck_size_interleave) { - block_q2_Kx16 out; - constexpr int N_COLS = 16; - - // 1. Copy Super-Scales (d) and Super-Mins (dmin) - for (int i = 0; i < N_COLS; i++) { - out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; - out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; - } - - // 2. Interleave Q2_K Data - const int bytes_per_col = 64; - const int total_bytes = N_COLS * bytes_per_col; - const int end = total_bytes / blck_size_interleave; - - for (int i = 0; i < end; ++i) { - int src_col_id = i % N_COLS; - int src_offset = (i / N_COLS) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], blck_size_interleave); - } - - // 3. 
Repack Scales into the Optimized "Sequential-Parallel" Layout - int out_idx = 0; - - // Arrays define the sub-block order for each group - const int even_low_sbs[] = {0, 2, 4, 6}; - const int odd_low_sbs[] = {1, 3, 5, 7}; - const int even_high_sbs[] = {8, 10, 12, 14}; - const int odd_high_sbs[] = {9, 11, 13, 15}; - - // Pack Group 1: Even-Low - for (int sb : even_low_sbs) { - for (int col = 0; col < N_COLS; col++) { - out.scales[out_idx++] = in[col].scales[sb]; - } - } - - // Pack Group 2: Odd-Low - for (int sb : odd_low_sbs) { - for (int col = 0; col < N_COLS; col++) { - out.scales[out_idx++] = in[col].scales[sb]; - } - } - - // Pack Group 3: Even-High - for (int sb : even_high_sbs) { - for (int col = 0; col < N_COLS; col++) { - out.scales[out_idx++] = in[col].scales[sb]; - } - } - - // Pack Group 4: Odd-High - for (int sb : odd_high_sbs) { - for (int col = 0; col < N_COLS; col++) { - out.scales[out_idx++] = in[col].scales[sb]; - } - } - - return out; -} - static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 4 || interleave_block == 8); @@ -3259,36 +2869,6 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } -static int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - constexpr int nrows_interleaved = 16; - - block_q4_Kx16 * dst = (block_q4_Kx16*)t->data; - const block_q4_K * src = (const block_q4_K*) data; - block_q4_K dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - 
for (int i = 0; i < nrows_interleaved; i++ ) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q4_Kx16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} - static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q2_K); GGML_ASSERT(interleave_block == 8); @@ -3320,71 +2900,6 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } -static int repack_q2_K_to_q2_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q2_K); - constexpr int nrows_interleaved = 16; - - block_q2_Kx16 * dst = (block_q2_Kx16*)t->data; - const block_q2_K * src = (const block_q2_K*) data; - - block_q2_K dst_tmp[nrows_interleaved]; - - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - // This loop gathers 16 separate blocks (one from each column) - // that correspond to the same K-dimension chunk. 
- for (int i = 0; i < nrows_interleaved; i++ ) { - dst_tmp[i] = src[x + i * nblocks]; - } - - *dst++ = make_block_q2_Kx16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} - -static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - constexpr int nrows_interleaved = 16; - - block_q4_0x16 * dst = (block_q4_0x16*)t->data; - const block_q4_0 * src = (const block_q4_0*) data; - block_q4_0 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++ ) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} - static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, @@ -3509,60 +3024,6 @@ static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, return 0; } -static block_q8_0x16 make_block_q8_0x16(block_q8_0 * in, unsigned int blck_size_interleave) { - block_q8_0x16 out; - - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } - - const int end = QK8_0 * 16 / blck_size_interleave; - - if (blck_size_interleave == 1) { - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = i / 16; - int dst_offset = i; - out.qs[dst_offset] = in[src_id].qs[src_offset]; - } - } else { - GGML_ASSERT(false); - } - - return out; -} - -static int repack_q8_0_to_q8_0_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - 
GGML_ASSERT(t->type == GGML_TYPE_Q8_0); - constexpr int nrows_interleaved = 16; - - block_q8_0x16 * dst = (block_q8_0x16 *) t->data; - const block_q8_0 * src = (const block_q8_0 *) data; - block_q8_0 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK8_0; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q8_0x16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; -} - static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) { block_iq4_nlx4 out; @@ -3688,41 +3149,315 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } -static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck_size_interleave) { - block_iq4_nlx16 out; +#if defined __riscv_zvfh +template +static block<4, nrows_interleaved> make_block_q4_0xMx1(block_q4_0 * in) { + block<4, nrows_interleaved> out; - for (int i = 0; i < 16; i++) { + for (int i = 0; i < nrows_interleaved; i++) { out.d[i] = in[i].d; } - const int end = QK4_NL * 8 / blck_size_interleave; + const int end = QK4_0 * nrows_interleaved / 2; - if (blck_size_interleave == 1) { - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = i / 16; - int dst_offset = i; + const uint8_t xor_mask = 0x88; + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; - out.qs[dst_offset] = in[src_id].qs[src_offset]; - } - } else { - GGML_ASSERT(false); + out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask; } return out; } -static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, 
int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - GGML_ASSERT(interleave_block == 1); +template +static int repack_q4_0_to_q4_0_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - const block_iq4_nl * src = (const block_iq4_nl *)data; - block_iq4_nlx16 * dst = ( block_iq4_nlx16 *)t->data; + block<4, nrows_interleaved> * dst = (block<4, nrows_interleaved>*)t->data; + const block_q4_0 * src = (const block_q4_0*) data; + block_q4_0 dst_tmp[nrows_interleaved]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; - block_iq4_nl dst_tmp[16]; + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0xMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +template +static block<8, nrows_interleaved> make_block_q8_0xMx1(block_q8_0 * in) { + block<8, nrows_interleaved> out; + + for (int i = 0; i < nrows_interleaved; i++) { + out.d[i] = in[i].d; + } + + const int end = QK8_0 * nrows_interleaved; + + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + return out; +} + +template +static int repack_q8_0_to_q8_0_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + + block<8, nrows_interleaved> * dst = (block<8, nrows_interleaved> *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[nrows_interleaved]; + int nrow = ggml_nrows(t); + int 
nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q8_0xMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +template +static block_q2_Kx make_block_q2_KxMx1(const block_q2_K * in) { + block_q2_Kx out; + constexpr int N_COLS = nrows_interleaved; + + // 1. Copy Super-Scales (d) and Super-Mins (dmin) + for (int i = 0; i < N_COLS; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + // 2. Interleave Q2_K Data + const int bytes_per_col = 64; + const int total_bytes = N_COLS * bytes_per_col; + const int end = total_bytes; + + for (int i = 0; i < end; ++i) { + int src_col_id = i % N_COLS; + int src_offset = (i / N_COLS); + int dst_offset = i * 1; + memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], 1); + } + + // 3. 
Repack Scales into the Optimized "Sequential-Parallel" Layout + int out_idx = 0; + + // Arrays define the sub-block order for each group + const int even_low_sbs[] = {0, 2, 4, 6}; + const int odd_low_sbs[] = {1, 3, 5, 7}; + const int even_high_sbs[] = {8, 10, 12, 14}; + const int odd_high_sbs[] = {9, 11, 13, 15}; + + // Pack Group 1: Even-Low + for (int sb : even_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 2: Odd-Low + for (int sb : odd_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 3: Even-High + for (int sb : even_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 4: Odd-High + for (int sb : odd_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + return out; +} + +template +static int repack_q2_K_to_q2_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + + block_q2_Kx * dst = (block_q2_Kx*)t->data; + const block_q2_K * src = (const block_q2_K*) data; + + block_q2_K dst_tmp[nrows_interleaved]; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + // This loop gathers nrows_interleaved separate blocks (one from each column) + // that correspond to the same K-dimension chunk.
+ for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + + *dst++ = make_block_q2_KxMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +template +static block_q4_Kx make_block_q4_KxMx1(block_q4_K * in) { + block_q4_Kx out; + //Delta(scale) and dmin values of the nrows_interleaved Q4_K structures are copied onto the output interleaved structure + for (int i = 0; i < nrows_interleaved; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < nrows_interleaved; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * nrows_interleaved / 2; + + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + // RVV repacking. + // + // Extract scales and mins for all 8 sub-blocks for each block of Q4_K. + uint8_t s[8 * nrows_interleaved], m[8 * nrows_interleaved]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + s[i * nrows_interleaved + j] = in[j].scales[i] & 63; + m[i * nrows_interleaved + j] = in[j].scales[i + 4] & 63; + } + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + s[nrows_interleaved * 8 / 2 + i * nrows_interleaved + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[nrows_interleaved * 8 / 2 + i * nrows_interleaved + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + + for (int i = 0; i < 8 * nrows_interleaved; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 8 * nrows_interleaved / 2; i++) { + out.scales[nrows_interleaved * 8 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[nrows_interleaved * 8 / 2 + i] & 48) | ((m[nrows_interleaved * 8 / 2 + i] & 48) << 2); + } + + return out; +} + +template +static int
repack_q4_K_to_q4_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + + block_q4_Kx * dst = (block_q4_Kx*)t->data; + const block_q4_K * src = (const block_q4_K*) data; + block_q4_K dst_tmp[nrows_interleaved]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_KxMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +template +static block_iq4_nlx make_block_iq4_nlxMx1(block_iq4_nl * in) { + block_iq4_nlx out; + + for (int i = 0; i < nrows_interleaved; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_NL * nrows_interleaved / 2; + + for (int i = 0; i < end; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + return out; +} + +template +static int repack_iq4_nl_to_iq4_nl_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); + + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nlx * dst = ( block_iq4_nlx *)t->data; + + block_iq4_nl dst_tmp[nrows_interleaved]; int nrow = ggml_nrows(t); - int nrows_interleaved = 16; int nblocks = t->ne[0] / QK4_NL; GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); @@ -3736,7 +3471,7 @@ static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_ for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_iq4_nlx16(dst_tmp, interleave_block); + *dst++ = 
make_block_iq4_nlxMx1(dst_tmp); } src += nrows_interleaved * nblocks; } @@ -3744,6 +3479,7 @@ static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_ GGML_UNUSED(data_size); } +#endif static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) { block_mxfp4x4 out; @@ -3935,24 +3671,74 @@ template <> int repack(struct ggml_tensor * t, const void * da } #if defined __riscv_zvfh +// Q4_0 +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_Mx1_bl<8>(t, data, data_size); +} template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); + return repack_q4_0_to_q4_0_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_Mx1_bl<64>(t, data, data_size); } -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size); +// Q8_0 +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_Mx1_bl<8>(t, data, data_size); } - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); -} - template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size); + return repack_q8_0_to_q8_0_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, 
size_t data_size) { + return repack_q8_0_to_q8_0_Mx1_bl<64>(t, data, data_size); } +// Q2_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_Mx1_bl<8>(t, data, data_size); +} template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q2_K_to_q2_K_16_bl(t, 1, data, data_size); + return repack_q2_K_to_q2_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_Mx1_bl<64>(t, data, data_size); +} + +// Q4_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_Mx1_bl<64>(t, data, data_size); +} + +// IQ4_NL +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_Mx1_bl<64>(t, data, data_size); } 
#endif @@ -4032,25 +3818,75 @@ template <> void gemv(int n, float * s, size_t } #if defined __riscv_zvfh +// Q4_0 +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_64x1_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +// Q8_0 +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_8x1_q8_0(n, s, bs, vx, vy, nr, nc); } - template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +// Q2_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { 
ggml_gemv_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +// Q4_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +// IQ4_NL +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} #endif // gemm @@ -4129,25 +3965,75 @@ template <> void gemm(int n, float * s, size_t } #if defined __riscv_zvfh +// Q4_0 +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} template <> void 
gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } - -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_64x1_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +// Q8_0 +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_8x1_q8_0(n, s, bs, vx, vy, nr, nc); } - template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +// Q2_K +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * 
vy, int nr, int nc) { + ggml_gemm_q2_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +// Q4_K +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +// IQ4_NL +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_32x1_q8_0(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_64x1_q8_0(n, s, bs, vx, vy, nr, nc); +} #endif class tensor_traits_base : public ggml::cpu::tensor_traits { @@ -4563,11 +4449,35 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // These implement outer-product style matrix multiplication kernels with // an interleave of 1. 
#if defined __riscv_zvfh + // Q4_0 + static const ggml::cpu::repack::tensor_traits q4_0_8x1_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_16x1_q8_0; - static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; - static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q4_0_32x1_q8_0; + static const ggml::cpu::repack::tensor_traits q4_0_64x1_q8_0; + + // Q8_0 + static const ggml::cpu::repack::tensor_traits q8_0_8x1_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q8_0_32x1_q8_0; + static const ggml::cpu::repack::tensor_traits q8_0_64x1_q8_0; + + // Q2_K + static const ggml::cpu::repack::tensor_traits q2_K_8x1_q8_K; static const ggml::cpu::repack::tensor_traits q2_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q2_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q2_K_64x1_q8_K; + + // Q4_K + static const ggml::cpu::repack::tensor_traits q4_K_8x1_q8_K; + static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q4_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q4_K_64x1_q8_K; + + // IQ4_NL + static const ggml::cpu::repack::tensor_traits iq4_nl_8x1_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_32x1_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_64x1_q8_0; #endif if (cur->type == GGML_TYPE_Q4_0) { @@ -4589,10 +4499,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return &q4_0_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; } - case 512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 
32 == 0) { return &q4_0_32x1_q8_0; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q4_0_64x1_q8_0; } break; } default: { return nullptr; } } #endif @@ -4616,10 +4526,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return &q4_K_8x1_q8_K; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; } - case 512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 32 == 0) { return &q4_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q4_K_64x1_q8_K; } break; } default: { return nullptr; } } #endif @@ -4633,10 +4543,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return &q2_K_8x1_q8_K; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; } - case 512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 32 == 0) { return &q2_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q2_K_64x1_q8_K; } break; } default: { return nullptr; } } #endif @@ -4677,10 +4587,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; } - case 512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 32 == 0) { return &iq4_nl_32x1_q8_0; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return 
&iq4_nl_64x1_q8_0; } break; } default: { return nullptr; } } #endif @@ -4710,10 +4620,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { break; } // TODO + case 128: { if (cur->ne[1] % 8 == 0) { return &q8_0_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q8_0_16x1_q8_0; } break; } - case 512: { break; } // TODO - case 1024: { break; } // TODO + case 512: { if (cur->ne[1] % 32 == 0) { return &q8_0_32x1_q8_0; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q8_0_64x1_q8_0; } break; } default: { return nullptr; } } #endif diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index cb21edf623..0d97c9b9b5 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -29,48 +29,58 @@ template struct block { static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding"); static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding"); static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block<4, 32>) == 32 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<4,32> size/padding"); +static_assert(sizeof(block<4, 64>) == 64 * sizeof(ggml_half) + QK8_0 * 32, "wrong block<4,64> size/padding"); static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding"); +static_assert(sizeof(block<8, 32>) == 32 * sizeof(ggml_half) + QK8_0 * 32, "wrong block<8,32> size/padding"); +static_assert(sizeof(block<8, 64>) == 64 * sizeof(ggml_half) + QK8_0 * 64, "wrong block<8,64> size/padding"); 
using block_q4_0x4 = block<4, 4>; using block_q4_0x8 = block<4, 8>; using block_q4_0x16 = block<4, 16>; +using block_q4_0x32 = block<4, 32>; +using block_q4_0x64 = block<4, 64>; using block_q8_0x4 = block<8, 4>; using block_q8_0x8 = block<8, 8>; using block_q8_0x16 = block<8, 16>; +using block_q8_0x32 = block<8, 32>; +using block_q8_0x64 = block<8, 64>; -struct block_q4_Kx8 { - ggml_half d[8]; // super-block scale for quantized scales - ggml_half dmin[8]; // super-block scale for quantized mins - uint8_t scales[96]; // scales and mins, quantized with 6 bits - uint8_t qs[1024]; // 4--bit quants +template struct block_q4_Kx{ + ggml_half d[N]; // super-block scale for quantized scales, one per interleaved block + ggml_half dmin[N]; // super-block scale for quantized mins, one per interleaved block + uint8_t scales[12 * N]; // scales and mins, quantized with 6 bits (12 bytes per interleaved block) + uint8_t qs[128 * N]; // 4-bit quants (128 bytes per interleaved block) }; +using block_q4_Kx8 = block_q4_Kx<8>; +using block_q4_Kx16 = block_q4_Kx<16>; +using block_q4_Kx32 = block_q4_Kx<32>; +using block_q4_Kx64 = block_q4_Kx<64>; + static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); -struct block_q4_Kx16 { - ggml_half d[16]; // super-block scale for quantized scales - ggml_half dmin[16]; // super-block scale for quantized mins - uint8_t scales[192]; // scales and mins, quantized with 6 bits - uint8_t qs[2048]; // 4--bit quants +static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding"); +static_assert(sizeof(block_q4_Kx32) == sizeof(ggml_half) * 64 + K_SCALE_SIZE * 32 + QK_K * 16, "wrong q4_K block size/padding"); +static_assert(sizeof(block_q4_Kx64) == sizeof(ggml_half) * 128 + K_SCALE_SIZE * 64 + QK_K * 32, "wrong q4_K block size/padding"); + +template struct block_q2_Kx { + ggml_half d[N]; // super-block scale for quantized scales + ggml_half dmin[N]; // super-block scale for quantized mins + uint8_t scales[16 * N]; // scales and mins, quantized
with 4 bits + uint8_t qs[64 * N]; // 2--bit quants }; -static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding"); -struct block_q2_Kx8 { - ggml_half d[8]; // super-block scale for quantized scales - ggml_half dmin[8]; // super-block scale for quantized mins - uint8_t scales[128]; // scales and mins, quantized with 4 bits - uint8_t qs[512]; // 2--bit quants -}; +using block_q2_Kx8 = block_q2_Kx<8>; +using block_q2_Kx16 = block_q2_Kx<16>; +using block_q2_Kx32 = block_q2_Kx<32>; +using block_q2_Kx64 = block_q2_Kx<64>; static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); -struct block_q2_Kx16 { - ggml_half d[16]; // Super-block scale for quantized scales - ggml_half dmin[16]; // Super-block scale for quantized mins - uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks) - uint8_t qs[1024]; // Data (16 cols * 64 bytes per block) -}; static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding"); +static_assert(sizeof(block_q2_Kx32) == sizeof(ggml_half) * 64 + QK_K * 2 + QK_K * 8, "wrong q2_K block size/padding"); +static_assert(sizeof(block_q2_Kx64) == sizeof(ggml_half) * 128 + QK_K * 4 + QK_K * 16, "wrong q2_K block size/padding"); struct block_q5_Kx8 { ggml_half d[8]; // super-block scale for quantized scales @@ -83,15 +93,22 @@ struct block_q5_Kx8 { static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, "wrong q5_K block size/padding"); -struct block_q6_Kx8 { - ggml_half d[8]; - int8_t scales[QK_K / 16 * 8]; - uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2) - uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4) +template struct block_q6_Kx { + ggml_half d[N]; + int8_t scales[QK_K / 16 * N]; + uint8_t ql[QK_K / 2 * N]; // low bits of 6-bit quants (groups of 2) + uint8_t qh[QK_K / 4 * N]; // high bits of 
6-bit quants (groups of 4) }; -static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, - "wrong q6_K block size/padding"); +using block_q6_Kx8 = block_q6_Kx<8>; +using block_q6_Kx16 = block_q6_Kx<16>; +using block_q6_Kx32 = block_q6_Kx<32>; +using block_q6_Kx64 = block_q6_Kx<64>; + +static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q6_Kx16) == sizeof(ggml_half) * 16 + QK_K / 16 * 16 + 3 * QK_K / 4 * 16, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q6_Kx32) == sizeof(ggml_half) * 32 + QK_K / 16 * 32 + 3 * QK_K / 4 * 32, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q6_Kx64) == sizeof(ggml_half) * 64 + QK_K / 16 * 64 + 3 * QK_K / 4 * 64, "wrong q6_K block size/padding"); struct block_q8_Kx4 { float d[4]; // delta @@ -101,26 +118,23 @@ struct block_q8_Kx4 { static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding"); -struct block_iq4_nlx4 { - ggml_half d[4]; // deltas for 4 iq4_nl blocks - uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks +template struct block_iq4_nlx { + ggml_half d[N]; // deltas for `N` iq4_nl blocks + uint8_t qs[QK4_NL * N / 2]; // nibbles / quants for N iq4_nl blocks }; +using block_iq4_nlx4 = block_iq4_nlx<4>; +using block_iq4_nlx8 = block_iq4_nlx<8>; +using block_iq4_nlx16 = block_iq4_nlx<16>; +using block_iq4_nlx32 = block_iq4_nlx<32>; +using block_iq4_nlx64 = block_iq4_nlx<64>; + static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding"); - -struct block_iq4_nlx8 { - ggml_half d[8]; // deltas for 8 iq4_nl blocks - uint8_t qs[QK4_NL * 4]; // nibbles / quants for 8 iq4_nl blocks -}; - static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding"); - -struct block_iq4_nlx16 { - 
ggml_half d[16]; // deltas for 16 iq4_nl blocks - uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks -}; - static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding"); +static_assert(sizeof(block_iq4_nlx32) == 32 * sizeof(ggml_half) + QK4_NL * 16, "wrong iq4_nlx32 block size/padding"); +static_assert(sizeof(block_iq4_nlx64) == 64 * sizeof(ggml_half) + QK4_NL * 32, "wrong iq4_nlx64 block size/padding"); + struct block_mxfp4x4 { uint8_t e[4]; uint8_t qs[QK_MXFP4 * 2]; @@ -176,16 +190,46 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT 
s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, 
size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, 
const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif // Native implementations @@ -228,16 +272,46 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void 
ggml_gemv_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t 
bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT 
vy, int nr, int nc); +void ggml_gemm_q4_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float 
* GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_64x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif #if defined(__cplusplus)