diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 6fcb13d676..bc9d35a073 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -45,7 +45,6 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -58,7 +57,6 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K # define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -66,10 +64,8 @@ // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) @@ -82,7 +78,6 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -91,7 +86,6 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) @@ -115,7 +109,6 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -128,7 +121,6 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define 
ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -153,7 +145,6 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -166,7 +157,6 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -186,11 +176,11 @@ #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp +#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 -#define ggml_quantize_mat_q8_0_4x16_generic ggml_quantize_mat_q8_0_4x16 +#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -238,7 +228,6 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -251,7 +240,6 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -284,7 +272,6 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -297,7 +284,6 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define 
ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 923902786e..21dda5c0a9 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -1,3 +1,4 @@ +#include #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -203,7 +204,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 16; @@ -223,26 +224,32 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(blocklen); #if defined __riscv_v_intrinsic + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); - // 1x16 Accumulator1 + // 1x16 Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); for (int l = 0; l < nb; l++) { - // 1x32 integer accumulator - vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + // 1x16 Integer Accumulator + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); // Accumulation loop. - for (int i = 0; i < QK8_0; i++) { + for (int i = 0; i < QK4_0 / 2; i++) { // Load `b_ptr`. 
- const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); - // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); - sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16); } + const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); @@ -253,13 +260,154 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } return; #endif - ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + // TODO +} + +void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + // TODO +} + +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const block_q8_K * a_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16); + + // We process 4 sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + + // High bits. 
+ vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16); + sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16); + + // Accumulation for 2 sub-blocks. + { + // 4x16 integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_s_1_16, 16); + } + { + // 4x16 integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_s_1_16, 16); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } + return; +#endif + ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 4; + const int ncols_interleaved = 8; const int blocklen = 8; assert (n % qk == 0); @@ -276,173 +424,71 @@ void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(blocklen); #if defined __riscv_v_intrinsic - const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); - vfloat32mf2_t sumf = __riscv_vfmv_v_f_f32mf2(0.0, 4); + vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, 8); for (int l = 0; l < nb; l++) { - // Load first 8 bytes of `a`. const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; __asm__ __volatile__("" ::: "memory"); + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + // Load `b_ptr`. 
- const vuint8m2_t b_0_packed = __riscv_vle8_v_u8m2((const uint8_t *)b_ptr[l].qs, QK4_NL * 2); - const vint8m2_t b_0_lo = __riscv_vrgather_vv_i8m2(values, __riscv_vand_vx_u8m2(b_0_packed, 0xf, QK4_NL * 2), QK4_NL * 2); - const vint8m2_t b_0_hi = __riscv_vrgather_vv_i8m2(values, __riscv_vsrl_vx_u8m2(b_0_packed, 4, QK4_NL * 2), QK4_NL * 2); + const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); + const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); + const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); // Create 4 segments from `b`. - const vint8m1_t b_lo_0 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 0); - const vint8m1_t b_lo_1 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 1); - const vint8m1_t b_hi_0 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 0); - const vint8m1_t b_hi_1 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 1); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); + const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); + const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); + const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); // Multiply and accumulate. - const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); // In-place reduction. 
- const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); // Multiply with scales. 
- const vfloat16mf4_t b_d = __riscv_vle16_v_f16mf4((const _Float16 *)b_ptr[l].d, 4); - const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d, 4); - sumf = __riscv_vfmacc_vv_f32mf2(sumf, facc, d_0, QK4_NL / 8); + const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d, 8); + sumf = __riscv_vfmacc_vv_f32m1(sumf, facc, d_0, QK4_NL / 4); } - __riscv_vse32_v_f32mf2(s + x * ncols_interleaved, sumf, QK4_NL / 8); + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, QK4_NL / 4); } return; #endif - ggml_gemv_iq4_nl_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - -void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0f, 4); - for (int l = 0; l + 1 < nb; l += 2) { - vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 0, 16); - vuint8m1_t b_1_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 16, 16); - vuint8m1_t b_2_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 32, 16); - vuint8m1_t b_3_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 48, 16); - vuint8m1_t b_4_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 0, 16); - vuint8m1_t b_5_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 16, 16); - vuint8m1_t b_6_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 32, 16); - vuint8m1_t b_7_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 48, 16); - - vuint8m1_t b_0_lo = __riscv_vand_vx_u8m1(b_0_packed, 0xf, 16); - vuint8m1_t b_0_hi = __riscv_vsrl_vx_u8m1(b_0_packed, 4, 16); - vuint8m1_t b_1_lo = __riscv_vand_vx_u8m1(b_1_packed, 0xf, 16); - vuint8m1_t b_1_hi = __riscv_vsrl_vx_u8m1(b_1_packed, 4, 16); - vuint8m1_t b_2_lo = __riscv_vand_vx_u8m1(b_2_packed, 0xf, 16); - vuint8m1_t b_2_hi = __riscv_vsrl_vx_u8m1(b_2_packed, 4, 16); - vuint8m1_t b_3_lo = __riscv_vand_vx_u8m1(b_3_packed, 0xf, 16); - vuint8m1_t b_3_hi = __riscv_vsrl_vx_u8m1(b_3_packed, 4, 16); - vuint8m1_t b_4_lo = __riscv_vand_vx_u8m1(b_4_packed, 0xf, 16); - vuint8m1_t b_4_hi = __riscv_vsrl_vx_u8m1(b_4_packed, 4, 16); - vuint8m1_t b_5_lo = __riscv_vand_vx_u8m1(b_5_packed, 0xf, 16); - vuint8m1_t b_5_hi = __riscv_vsrl_vx_u8m1(b_5_packed, 4, 16); - vuint8m1_t b_6_lo = __riscv_vand_vx_u8m1(b_6_packed, 0xf, 16); - vuint8m1_t b_6_hi = __riscv_vsrl_vx_u8m1(b_6_packed, 4, 16); - vuint8m1_t b_7_lo = __riscv_vand_vx_u8m1(b_7_packed, 0xf, 16); - vuint8m1_t b_7_hi = __riscv_vsrl_vx_u8m1(b_7_packed, 4, 16); - - vint8m1_t b_0 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_0_lo, b_0_hi, 16, 32), 32); - vint8m1_t b_1 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_1_lo, b_1_hi, 16, 32), 32); - vint8m1_t b_2 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_2_lo, b_2_hi, 16, 32), 32); - vint8m1_t b_3 = __riscv_vrgather_vv_i8m1(values, 
__riscv_vslideup_vx_u8m1(b_3_lo, b_3_hi, 16, 32), 32); - vint8m1_t b_4 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_4_lo, b_4_hi, 16, 32), 32); - vint8m1_t b_5 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_5_lo, b_5_hi, 16, 32), 32); - vint8m1_t b_6 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_6_lo, b_6_hi, 16, 32), 32); - vint8m1_t b_7 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_7_lo, b_7_hi, 16, 32), 32); - - vint8m1_t a_0 = __riscv_vle8_v_i8m1(a_ptr[l].qs, 32); - vint8m1_t a_1 = __riscv_vle8_v_i8m1(a_ptr[l + 1].qs, 32); - - vint32m1_t sumi_0 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_0, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_1 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_1, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_2 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_2, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_3 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_3, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_4 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_4, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_5 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_5, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_6 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_6, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_7 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_7, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - - int sumi_temp[8]; - __riscv_vse32_v_i32m1(&sumi_temp[0], sumi_0, 1); - __riscv_vse32_v_i32m1(&sumi_temp[1], sumi_1, 1); - __riscv_vse32_v_i32m1(&sumi_temp[2], sumi_2, 1); - __riscv_vse32_v_i32m1(&sumi_temp[3], sumi_3, 1); - __riscv_vse32_v_i32m1(&sumi_temp[4], sumi_4, 1); - __riscv_vse32_v_i32m1(&sumi_temp[5], sumi_5, 1); - __riscv_vse32_v_i32m1(&sumi_temp[6], sumi_6, 1); - __riscv_vse32_v_i32m1(&sumi_temp[7], sumi_7, 1); - vint32m1_t sum_0 = __riscv_vle32_v_i32m1(&sumi_temp[0], 4); - vint32m1_t sum_1 = __riscv_vle32_v_i32m1(&sumi_temp[4], 4); - - vfloat16mf2_t b_d_0 = __riscv_vle16_v_f16mf2((_Float16 *)b_ptr[l].d, 4); - vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d_0, *(const _Float16 *)&a_ptr[l].d, 4); - vfloat16mf2_t b_d_1 = __riscv_vle16_v_f16mf2((_Float16 *)b_ptr[l + 1].d, 4); - vfloat32m1_t d_1 = __riscv_vfwmul_vf_f32m1(b_d_1, *(const _Float16 *)&a_ptr[l + 1].d, 4); - - sumf = __riscv_vfmacc_vv_f32m1(sumf, d_0, __riscv_vfcvt_f_x_v_f32m1(sum_0, 4), 4); - sumf = __riscv_vfmacc_vv_f32m1(sumf, d_1, __riscv_vfcvt_f_x_v_f32m1(sum_1, 4), 4); - } - __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, 4); - } - return; -#endif - ggml_gemv_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -504,11 +550,64 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = 
n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x16 Integer Accumulator + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK8_0; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + + sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } + return; +#endif + ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; assert (n % qk == 0); assert (nr % 4 == 0); @@ -525,81 +624,354 @@ void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(blocklen); #if defined __riscv_v_intrinsic - const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); - // 4x4 Accumulators - vfloat32m1_t sumf_0 = __riscv_vfmv_v_f_f32m1(0.0f, 4); - vfloat32m1_t sumf_1 = __riscv_vfmv_v_f_f32m1(0.0f, 4); - vfloat32m1_t sumf_2 = __riscv_vfmv_v_f_f32m1(0.0f, 4); - vfloat32m1_t sumf_3 = __riscv_vfmv_v_f_f32m1(0.0f, 4); + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); for (int l = 0; l < nb; l++) { - int sumi_temp[16]; - uint8_t index[4] = {0, 8, 64, 72}; - vuint8mf8_t i_vec = __riscv_vle8_v_u8mf8(&index[0], 4); - vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 0, 16); - vuint8m1_t b_1_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 16, 16); - vuint8m1_t b_2_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 32, 16); - vuint8m1_t b_3_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 48, 16); + // 4x16 integer accumulators + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t 
sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); - vuint8m1_t b_0_lo = __riscv_vand_vx_u8m1(b_0_packed, 0xf, 16); - vuint8m1_t b_0_hi = __riscv_vsrl_vx_u8m1(b_0_packed, 4, 16); - vuint8m1_t b_1_lo = __riscv_vand_vx_u8m1(b_1_packed, 0xf, 16); - vuint8m1_t b_1_hi = __riscv_vsrl_vx_u8m1(b_1_packed, 4, 16); - vuint8m1_t b_2_lo = __riscv_vand_vx_u8m1(b_2_packed, 0xf, 16); - vuint8m1_t b_2_hi = __riscv_vsrl_vx_u8m1(b_2_packed, 4, 16); - vuint8m1_t b_3_lo = __riscv_vand_vx_u8m1(b_3_packed, 0xf, 16); - vuint8m1_t b_3_hi = __riscv_vsrl_vx_u8m1(b_3_packed, 4, 16); + // Accumulation loop. + for (int i = 0; i < QK4_0 / 2; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); - vint8m1_t b_0 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_0_lo, b_0_hi, 16, 32), 32); - vint8m1_t b_1 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_1_lo, b_1_hi, 16, 32), 32); - vint8m1_t b_2 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_2_lo, b_2_hi, 16, 32), 32); - vint8m1_t b_3 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_3_lo, b_3_hi, 16, 32), 32); + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16); + sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16); + sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16); + sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16); - #pragma unroll 4 - for (int i = 0; i < 4; i++) { - vint8m1_t a_i = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vloxei8_v_i64m1((int64_t*)(a_ptr[l].qs + i * 16), i_vec, 4)); - vint32m1_t sumi_0 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_0, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_1 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_1, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_2 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_2, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - vint32m1_t sumi_3 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_3, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); - __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 0], sumi_0, 1); - __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 1], sumi_1, 1); - __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 2], sumi_2, 1); - __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 3], sumi_3, 1); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16); + sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16); + sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16); + sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16); } - vint32m1_t sum_0 = __riscv_vle32_v_i32m1(&sumi_temp[0], 4); - vint32m1_t sum_1 = __riscv_vle32_v_i32m1(&sumi_temp[4], 4); - vint32m1_t sum_2 = __riscv_vle32_v_i32m1(&sumi_temp[8], 4); - vint32m1_t sum_3 = __riscv_vle32_v_i32m1(&sumi_temp[12], 4); + // Do the final accumulation in i32 to prevent overflow. 
+ const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16); + const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16); + const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16); - vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((_Float16 *)b_ptr[l].d, 4); - vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[0], 4); - vfloat32m1_t d_1 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[1], 4); - vfloat32m1_t d_2 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[2], 4); - vfloat32m1_t d_3 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[3], 4); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); - sumf_0 = __riscv_vfmacc_vv_f32m1(sumf_0, d_0, __riscv_vfcvt_f_x_v_f32m1(sum_0, 4), 4); - sumf_1 = __riscv_vfmacc_vv_f32m1(sumf_1, d_1, __riscv_vfcvt_f_x_v_f32m1(sum_1, 4), 4); - sumf_2 = __riscv_vfmacc_vv_f32m1(sumf_2, d_2, __riscv_vfcvt_f_x_v_f32m1(sum_2, 4), 4); - sumf_3 = __riscv_vfmacc_vv_f32m1(sumf_3, d_3, __riscv_vfcvt_f_x_v_f32m1(sum_3, 4), 4); + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); } - __riscv_vse32_v_f32m1(s + (y * 4 + 0) * bs + x * 4, sumf_0, 4); - __riscv_vse32_v_f32m1(s + (y * 4 + 1) * bs + x * 4, sumf_1, 4); - __riscv_vse32_v_f32m1(s + (y * 4 + 2) * bs + x * 4, sumf_2, 4); - __riscv_vse32_v_f32m1(s + (y * 4 + 3) * bs + x * 4, sumf_3, 4); + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } return; #endif - ggml_gemm_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + 
vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16); + + // We process 4 sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16); + + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], + 
__riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[0], 16); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[1], 16); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[2], 16); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16); + sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16); + sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16); + sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16); + + + // Accumulation for 2 sub-blocks. + { + // 4x8 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_0_s_0_16, 16); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_0_s_1_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_1_s_0_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_1_s_1_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_2_s_0_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_2_s_1_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_3_s_0_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_3_s_1_16, 16); + } + { + // 4x8 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // This might overflow. + // + // Recheck. + for (int i = 0; i < QK4_0; i++) { + // Load `b_ptr`. 
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, 16); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_0_s_0_16, 16); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_0_s_1_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_1_s_0_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_1_s_1_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_2_s_0_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_2_s_1_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_3_s_0_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_3_s_1_16, 16); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16), 16); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } + return; +#endif + ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, 
const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + // TODO +} + +void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + // TODO } void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -724,7 +1096,7 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); for (int l = 0; l < nb; l++) { - // 4x16 integer accumulators + // 4x16 Integer Accumulators vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); @@ -765,93 +1137,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); - - vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, 8); - for (int l = 0; l < nb; l++) { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Load `b_ptr`. - const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); - const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); - const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); - - // Create 4 segments from `b`. - const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); - const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); - const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); - const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); - - // Multiply and accumulate. 
- const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. - const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. - const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d, 8); - sumf = __riscv_vfmacc_vv_f32m1(sumf, facc, d_0, QK4_NL / 4); - } - __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, QK4_NL / 4); - } - return; - -#endif - ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1078,239 +1363,6 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - // 4x4 accumulators. 
- vfloat32mf2_t sumf0 = __riscv_vfmv_v_f_f32mf2(0.0, 4); - vfloat32mf2_t sumf1 = __riscv_vfmv_v_f_f32mf2(0.0, 4); - vfloat32mf2_t sumf2 = __riscv_vfmv_v_f_f32mf2(0.0, 4); - vfloat32mf2_t sumf3 = __riscv_vfmv_v_f_f32mf2(0.0, 4); - - for (int l = 0; l < nb; l++) { - // Load `b_ptr`. - const vuint8m2_t b_0_packed = __riscv_vle8_v_u8m2((const uint8_t *)b_ptr[l].qs, QK4_NL * 2); - const vint8m2_t b_0_lo = __riscv_vrgather_vv_i8m2(values, __riscv_vand_vx_u8m2(b_0_packed, 0xf, QK4_NL * 2), QK4_NL * 2); - const vint8m2_t b_0_hi = __riscv_vrgather_vv_i8m2(values, __riscv_vsrl_vx_u8m2(b_0_packed, 4, QK4_NL * 2), QK4_NL * 2); - - // Create 4 segments from `b`. - const vint8m1_t b_lo_0 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 0); - const vint8m1_t b_lo_1 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 1); - const vint8m1_t b_hi_0 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 0); - const vint8m1_t b_hi_1 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 1); - - // Load scales for `b`. - const vfloat16mf4_t b_d = __riscv_vle16_v_f16mf4((const _Float16 *)b_ptr[l].d, 4); - - // Load first 8 bytes of `a`. - { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[32]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[64]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[96]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); - - // Multiply and accumulate. - const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); - - // In-place reduction. 
- const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); - - // Multiply with scales. - const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[0], 4); - sumf0 = __riscv_vfmacc_vv_f32mf2(sumf0, facc, d_0, QK4_NL / 8); - } - - // Load second 8 bytes of `a`. - { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[40]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[72]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[104]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); - - // Multiply and accumulate. - const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); - - // In-place reduction. 
- const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); - - // Multiply with scales. - const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[1], 4); - sumf1 = __riscv_vfmacc_vv_f32mf2(sumf1, facc, d_0, QK4_NL / 8); - } - - // Load third 8 bytes of `a`. - { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[48]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[80]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[112]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); - - // Multiply and accumulate. - const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); - - // In-place reduction. 
- const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); - - // Multiply with scales. - const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[2], 4); - sumf2 = __riscv_vfmacc_vv_f32mf2(sumf2, facc, d_0, QK4_NL / 8); - } - - // Load fourth 8 bytes of `a`. - { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[24]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[56]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[88]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[120]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); - const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); - const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); - const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); - - // Multiply and accumulate. - const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); - const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); - const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); - const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); - const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); - const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); - const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); - - // In-place reduction. 
- const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); - const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); - const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); - const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); - const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); - const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); - const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); - const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); - const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); - const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); - const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); - - // Multiply with scales. - const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[3], 4); - sumf3 = __riscv_vfmacc_vv_f32mf2(sumf3, facc, d_0, QK4_NL / 8); - } - } - - __riscv_vse32_v_f32mf2(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, 8); - __riscv_vse32_v_f32mf2(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, 8); - __riscv_vse32_v_f32mf2(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, 8); - __riscv_vse32_v_f32mf2(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, 8); - } - } - return; - -#endif - ggml_gemm_iq4_nl_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index e0bf9e6354..5ca3aa06e0 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1,3 +1,4 @@ +#include "ggml.h" #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -48,6 +49,7 @@ static inline int nearest_int(float fval) { extern "C" { +#if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -86,6 +88,51 @@ void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GG } } +void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + + const int blck_size_interleave = 1; + float srcv[4][QK_K]; + float iscale[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + float max = 0; + + for (int j = 0; j < QK_K; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK_K + j]; + // Update the maximum value of the corresponding super block + if(amax < fabsf(srcv[row_iter][j])) { + amax = fabsf(srcv[row_iter][j]); + max = srcv[row_iter][j]; + } + } + + iscale[row_iter] = amax ? -127.f/max : 0; + y[i].d[row_iter] = amax ? 
1/iscale[row_iter] : 0; + } + + for (int j = 0; j < QK_K / 4; j++) { + y[i].bsums[j] = 0; + } + for (int j = 0; j < QK_K * 4; j++) { + int src_id = j % 4; + int src_offset = j / 4; + int index = ((j >> 6) << 2) + (j & 3); + + float x0 = srcv[src_id][src_offset] * iscale[src_id]; + y[i].qs[j] = nearest_int(x0); + y[i].bsums[index] += y[i].qs[j]; + } + } +} +#endif + void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -162,44 +209,6 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } -void ggml_quantize_mat_q8_0_4x16_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; - - // scalar - const int blck_size_interleave = 16; - float srcv[4][QK8_0]; - float id[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); - } - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d); - } - - for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (j % blck_size_interleave); - - float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0); - } - } -} - void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -313,24 +322,12 @@ template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTR ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { - assert(nrow == 4); - UNUSED(nrow); - ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row); -} - template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<16, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { - assert(nrow == 4); - UNUSED(nrow); - ggml_quantize_mat_q8_0_4x16(x, vy, n_per_row); -} - template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); @@ -343,6 +340,20 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } +#if defined __riscv_zvfh +template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row); +} + +template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_K_4x1(x, vy, n_per_row); +} +#endif + extern 
"C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -919,82 +930,6 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -void ggml_gemv_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - -void ggml_gemv_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1033,44 +968,6 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - 
const int blocklen = 1; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[16]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -1165,6 +1062,204 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, } } +#if defined __riscv_zvfh +void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + float sumf[16]; + float sum_minf[16]; + uint8_t scales[128]; + uint8_t mins[128]; + int sumi1; + int sumi2; + int sumi; + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + for (int j = 0; j < 
ncols_interleaved; j++) { + sumf[j] = 0.0f; + sum_minf[j] = 0.0f; + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < 128; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < 64; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); + mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * 16]; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * 16]; + uint8_t *scales_1 = &scales[(sb + 1) * 16]; + for (int i = 0; i < QK4_0; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); + sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + } + for (int l = 0; l 
< nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} +#endif + void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1847,118 +1942,6 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -void ggml_gemm_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - { - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } -} - -void ggml_gemm_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - - { - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { 
- for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } -} - void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -2003,50 +1986,6 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 16; - const int blocklen = 1; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - float sumf[4][16]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -} - void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -2099,6 +2038,8 @@ void ggml_gemm_q8_0_4x4_q8_0_generic(int n, } } + + void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -2151,6 +2092,246 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, } } +#if defined __riscv_zvfh +void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const 
block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][16]; + float sum_minf[4][16]; + uint8_t scales[128]; + uint8_t mins[128]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < 128; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < 64; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); + mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; + } + + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * 16]; + for(int m = 0; m < 4; m++) { + const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * 16]; + uint8_t *scales_1 = &scales[(sb + 1) * 16]; + + for (int i = 0; i < QK4_0; i++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + + const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); + sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + + sumf[m][j] += sumi * 
GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; + } + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} +#endif + } // extern "C" static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) { @@ -2239,6 +2420,31 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + 
block_q4_0x16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + const uint8_t xor_mask = 0x88; + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask; + } + } else { + GGML_ASSERT(false); + } + + return out; +} + static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { block_q4_Kx8 out; //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure @@ -2253,63 +2459,150 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in const int end = QK_K * 4 / blck_size_interleave; // Interleave Q4_K quants by taking 8 bytes at a time - for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = (i / 8) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; + if (blck_size_interleave == 8) { + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; - uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); - } - - // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K - // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) - // The output Q4_Kx8 structure has 96 bytes - // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure - // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures - uint8_t s[8], m[8]; - - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = in[j].scales[i] & 63; - m[j] = in[j].scales[i + 4] & 63; + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); } - out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K + // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q4_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures + uint8_t s[8], m[8]; - } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; + } + + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 
12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); } - out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + + } + + } else if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = i / 8; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + // RVV repacking. + // + // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. 
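+ //
+ // Resulting layout (it mirrors the q4_Kx16 path below): byte sb * 8 + r of
+ // out.scales holds the low 4 bits of the scale (low nibble) and of the min
+ // (high nibble) for sub-block sb of interleaved row r, while byte
+ // 64 + sb * 8 + r (sb = 0..3) packs the top 2 bits of scale/min of sub-block
+ // sb in bits 0-1/2-3 and of sub-block sb + 4 in bits 4-5/6-7.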
+ uint8_t s[64], m[64]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[i * 8 + j] = in[j].scales[i] & 63; + m[i * 8 + j] = in[j].scales[i + 4] & 63; + } + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[32 + i * 8 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[32 + i * 8 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + + for (int i = 0; i < 64; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 32; i++) { + out.scales[64 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[32 + i] & 48) | ((m[32 + i] & 48) << 2); + } + } + + return out; +} + +static block_q4_Kx16 make_block_q4_Kx16(block_q4_K * in, unsigned int blck_size_interleave) { + block_q4_Kx16 out; + //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 16; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + // RVV repacking. + // + // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. + uint8_t s[128], m[128]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 16; j++) { + s[i * 16 + j] = in[j].scales[i] & 63; + m[i * 16 + j] = in[j].scales[i + 4] & 63; + } + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 16; j++) { + s[64 + i * 16 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[64 + i * 16 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + + for (int i = 0; i < 128; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 64; i++) { + out.scales[128 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[64 + i] & 48) | ((m[64 + i] & 48) << 2); + } + } else { + GGML_ASSERT(false); } return out; @@ -2525,7 +2818,7 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 8 || interleave_block == 4); + GGML_ASSERT(interleave_block == 8 || interleave_block == 4 || interleave_block == 1); constexpr int nrows_interleaved = 8; block_q4_Kx8 * dst = (block_q4_Kx8*)t->data; @@ -2554,6 +2847,36 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + constexpr int nrows_interleaved = 16; + + block_q4_Kx16 * dst = (block_q4_Kx16*)t->data; + const block_q4_K * src = (const block_q4_K*) data; + block_q4_K dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * 
nblocks]; + } + *dst++ = make_block_q4_Kx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q2_K); GGML_ASSERT(interleave_block == 8); @@ -2585,6 +2908,36 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16*)t->data; + const block_q4_0 * src = (const block_q4_0*) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, @@ -2717,11 +3070,16 @@ static block_q8_0x16 make_block_q8_0x16(block_q8_0 * in, unsigned int blck_size_ } const int end = QK8_0 * 16 / blck_size_interleave; - for (int i = 0; i < end; ++i) { - int src_id = i % 16; - int src_offset = (i / 16) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave); + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + } else { + GGML_ASSERT(false); } return out; @@ -2792,25 +3150,9 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s int src_offset = (i / 4) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - for (int b = 0; b < 8; ++b) { - out.qs[dst_offset + b] = in[src_id].qs[src_offset + b]; - } - - // Generates bus error on RVV as this is auto-vectorized and the - // source might possible not be 8-byte aligned - // - // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); } - } else if (blck_size_interleave == 16) { - for (int i = 0; i < end; ++i) { - int src_id = i; - int src_offset = 0; - int dst_offset = i * 16; - - memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], 4 * sizeof(uint32_t)); - } - } - else { + } else { GGML_ASSERT(false); } @@ -2819,7 +3161,7 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - // GGML_ASSERT(interleave_block == 4); + GGML_ASSERT(interleave_block == 4); const block_iq4_nl * src = (const block_iq4_nl *)data; block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data; @@ -2926,15 +3268,10 @@ static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * 
in, unsigned int blck if (blck_size_interleave == 1) { for (int i = 0; i < end; ++i) { int src_id = i % 16; - int src_offset = (i / 16) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; + int src_offset = i / 16; + int dst_offset = i; out.qs[dst_offset] = in[src_id].qs[src_offset]; - - // Generates bus error on RVV as this is auto-vectorized and the - // source might possible not be 8-byte aligned - // - // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); } } else { GGML_ASSERT(false); @@ -3023,22 +3360,9 @@ template <> int repack(struct ggml_tensor * t, const void * // return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); //} -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); -} - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_iq4_nl_to_iq4_nl_4_bl(t, 16, data, data_size); -} - template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); -} - template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size); } @@ -3047,9 +3371,27 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); } +#if defined __riscv_zvfh +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_8_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size); } +#endif // gemv template @@ -3098,22 +3440,9 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_iq4_nl_4x8_q8_0(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc); -} - template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); -} - template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -3122,9 +3451,27 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * 
@@ -3098,22 +3440,9 @@ template <> void gemv(int n, float * s, size
     ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }

-template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemv_iq4_nl_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
-}
-
-template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemv_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc);
-}
-
 template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
-
-template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
-}
-
 template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -3122,9 +3451,27 @@ template <> void gemv(int n, float * s, size_t
     ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }

+#if defined __riscv_zvfh
+template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
 }
+#endif

 // gemm
 template
@@ -3173,22 +3520,10 @@ template <> void gemm(int n, float * s, size
     ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }

-template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_iq4_nl_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
-}
-
 template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }

-template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc);
-}
-
-template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
-    ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
-}
-
 template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -3197,9 +3532,27 @@ template <> void gemm(int n, float * s, size_t
     ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }

+#if defined __riscv_zvfh
+template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
 }
+#endif

 class tensor_traits_base : public ggml::cpu::tensor_traits {
   public:
@@ -3597,19 +3950,25 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons

     // instance for IQ4
     static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0;
-    static const ggml::cpu::repack::tensor_traits iq4_nl_4x16_q8_0;
-    static const ggml::cpu::repack::tensor_traits iq4_nl_4x8_q8_0;
     static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0;
-    static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0;

     // instance for Q8_0
     static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0;
+
+    // instances for RISC-V
+    //
+    // These implement outer-product style multiplication with interleave of 1.
+#if defined __riscv_zvfh
+    static const ggml::cpu::repack::tensor_traits q4_0_16x1_q8_0;
+    static const ggml::cpu::repack::tensor_traits q4_K_8x1_q8_K;
+    static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K;
+    static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0;
     static const ggml::cpu::repack::tensor_traits q8_0_16x1_q8_0;
+#endif

     if (cur->type == GGML_TYPE_Q4_0) {
-        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)
-            || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) {
+        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
             if (cur->ne[1] % 8 == 0) {
                 return &q4_0_8x8_q8_0;
             }
@@ -3624,6 +3983,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q4_0_4x4_q8_0;
             }
         }
+        if (ggml_cpu_has_riscv_v()) {
+            #if defined __riscv_zvfh
+            switch (__riscv_vlenb() * 8) {
+                case 128: { break; } // TODO
+                case 256: { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; }
+                case 512: { break; } // TODO
+                case 1024: { break; } // TODO
+                default: { return nullptr; }
+            }
+            #endif
+        }
     } else if (cur->type == GGML_TYPE_Q4_K) {
         if (ggml_cpu_has_avx2()) {
             if (cur->ne[1] % 8 == 0) {
@@ -3640,6 +4010,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q4_K_8x4_q8_K;
             }
         }
+        if (ggml_cpu_has_riscv_v()) {
+            #if defined __riscv_zvfh
+            switch (__riscv_vlenb() * 8) {
+                case 128: { break; } // TODO
+                case 256: { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; }
+                case 512: { break; } // TODO
+                case 1024: { break; } // TODO
+                default: { return nullptr; }
+            }
+            #endif
+        }
     } else if (cur->type == GGML_TYPE_Q2_K) {
         if (ggml_cpu_has_avx512()) {
             if (cur->ne[1] % 8 == 0) {
@@ -3673,8 +4054,8 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
             #if defined __riscv_zvfh
             switch (__riscv_vlenb() * 8) {
                 case 128: { break; } // TODO
-                case 256: { if (cur->ne[1] % 4 == 0) { return &iq4_nl_16x1_q8_0; } break; }
-                case 512: { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x8_q8_0; } break; }
+                case 256: { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; }
+                case 512: { break; } // TODO
                 case 1024: { break; } // TODO
                 default: { return nullptr; }
             }
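// Illustrative sketch (not from the patch): the selection logic added above keys off the
// runtime vector length, __riscv_vlenb() * 8, and only picks the 16x1 repacked layouts
// when VLEN is 256 bits and the row count is a multiple of 16; the other widths are left
// as TODO. The helper below is hypothetical and only restates that dispatch shape for a
// q4_0-like case.
#if defined __riscv_zvfh
#include <riscv_vector.h>
#include <cstdint>

enum class q4_0_layout { generic, x16_interleave_1 };

static q4_0_layout pick_q4_0_layout(int64_t ne1) {
    switch (__riscv_vlenb() * 8) {                     // VLEN in bits
        case 256:
            if (ne1 % 16 == 0) {
                return q4_0_layout::x16_interleave_1;  // 16 rows per tile, interleave 1
            }
            break;
        default:
            break;                                     // 128/512/1024-bit paths not handled here
    }
    return q4_0_layout::generic;
}
#endif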
diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h
index 2075d4961b..9029ee3eb9 100644
--- a/ggml/src/ggml-cpu/repack.h
+++ b/ggml/src/ggml-cpu/repack.h
@@ -28,11 +28,14 @@ template struct block {
 // control size
 static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
 static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding");
 static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
 static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding");

 using block_q4_0x4 = block<4, 4>;
 using block_q4_0x8 = block<4, 8>;
+using block_q4_0x16 = block<4, 16>;
 using block_q8_0x4 = block<8, 4>;
 using block_q8_0x8 = block<8, 8>;
 using block_q8_0x16 = block<8, 16>;
@@ -45,7 +48,14 @@ struct block_q4_Kx8 {
 };
 static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
+struct block_q4_Kx16 {
+    ggml_half d[16];      // super-block scale for quantized scales
+    ggml_half dmin[16];   // super-block scale for quantized mins
+    uint8_t scales[192];  // scales and mins, quantized with 6 bits
+    uint8_t qs[2048];     // 4-bit quants
+};
+static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding");

 struct block_q2_Kx8 {
     ggml_half d[8];      // super-block scale for quantized scales
     ggml_half dmin[8];   // super-block scale for quantized mins
@@ -110,10 +120,8 @@ static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "w
 extern "C" {
 #endif

-void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
-void ggml_quantize_mat_q8_0_4x16(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -125,13 +133,9 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
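// Illustrative check (not from the patch) of the block_q4_Kx16 static_assert above, using
// the usual ggml values as assumptions: sizeof(ggml_half) == 2, K_SCALE_SIZE == 12 and
// QK_K == 256. Sixteen Q4_K super-blocks carry 16 d plus 16 dmin half scales (64 bytes),
// 16 * 12 = 192 bytes of packed 6-bit scales/mins, and 16 * 256 / 2 = 2048 bytes of 4-bit
// quants, 2304 bytes in total.
#include <cstddef>

constexpr std::size_t half_size = 2, k_scale_size = 12, qk_k = 256;
static_assert(16 * half_size + 16 * half_size + 16 * k_scale_size + 16 * (qk_k / 2) == 2304,
              "field-by-field total of block_q4_Kx16");
static_assert(half_size * 32 + k_scale_size * 16 + qk_k * 8 == 2304,
              "matches the expression used in the static_assert above");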
@@ -141,19 +145,27 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+#if defined __riscv_zvfh
+void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+#endif

 // Native implementations
-void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
-void ggml_quantize_mat_q8_0_4x16_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -165,13 +177,9 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -181,13 +189,23 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+#if defined __riscv_zvfh
+void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+#endif

 #if defined(__cplusplus)
 } // extern "C"
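// Illustrative sketch (not from the patch) of the "outer-product style multiplication
// with interleave of 1" noted in the repack.cpp hunk above, written with plain float
// stand-ins instead of quantized blocks: element k of 16 consecutive rows is stored
// contiguously, so each activation x[k] is broadcast and accumulated into 16 running
// sums (a rank-1 update per k). All names below are hypothetical.
#include <cstddef>

static void gemv_16x1_reference(std::size_t n, float * s,
                                const float * w_packed,  // n groups of 16 row elements
                                const float * x) {
    float acc[16] = {0.0f};
    for (std::size_t k = 0; k < n; ++k) {
        const float * col = w_packed + 16 * k;   // element k of rows 0..15
        for (int r = 0; r < 16; ++r) {
            acc[r] += col[r] * x[k];             // outer-product / SAXPY-style update
        }
    }
    for (int r = 0; r < 16; ++r) {
        s[r] = acc[r];
    }
}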