diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 427c1146e4..6fcb13d676 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -45,6 +45,7 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -57,6 +58,7 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K # define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -64,8 +66,10 @@ // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) @@ -78,6 +82,7 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -86,6 +91,7 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) @@ -109,6 +115,7 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -121,6 +128,7 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define 
ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -145,6 +153,7 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -157,6 +166,7 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -177,9 +187,10 @@ #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 -#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_0_4x16_generic ggml_quantize_mat_q8_0_4x16 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -188,7 +199,6 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -199,7 +209,6 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__s390x__) @@ -229,6 +238,7 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -241,6 +251,7 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define 
ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -273,6 +284,7 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_4x8_q8_0_generic ggml_gemv_iq4_nl_4x8_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -285,6 +297,7 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_4x8_q8_0_generic ggml_gemm_iq4_nl_4x8_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 9f37197ab4..667d889a7d 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -203,6 +203,527 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + vfloat32mf2_t sumf = __riscv_vfmv_v_f_f32mf2(0.0, 4); + for (int l = 0; l < nb; l++) { + // Load first 8 bytes of `a`. + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; + __asm__ __volatile__("" ::: "memory"); + + // Load `b_ptr`. + const vuint8m2_t b_0_packed = __riscv_vle8_v_u8m2((const uint8_t *)b_ptr[l].qs, QK4_NL * 2); + const vint8m2_t b_0_lo = __riscv_vrgather_vv_i8m2(values, __riscv_vand_vx_u8m2(b_0_packed, 0xf, QK4_NL * 2), QK4_NL * 2); + const vint8m2_t b_0_hi = __riscv_vrgather_vv_i8m2(values, __riscv_vsrl_vx_u8m2(b_0_packed, 4, QK4_NL * 2), QK4_NL * 2); + + // Create 4 segments from `b`. + const vint8m1_t b_lo_0 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 0); + const vint8m1_t b_lo_1 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 1); + const vint8m1_t b_hi_0 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 0); + const vint8m1_t b_hi_1 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 1); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). 
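
// [Editor's aside] The idiom that follows is worth spelling out: one 8-byte chunk of the
// activation row is read as a scalar int64 and splatted across the vector with an
// i64-element vmv, so after the i8 reinterpret every interleaved column sees the same
// 8 quantized activations. A minimal scalar sketch of that replication (illustrative
// only, plain C, no RVV):

#include <stdint.h>
#include <string.h>
#include <stdio.h>

int main(void) {
    int8_t a[8] = {1, -2, 3, -4, 5, -6, 7, -8};  // 8 int8 activations
    int64_t chunk;
    memcpy(&chunk, a, sizeof(chunk));            // one 64-bit scalar load
    int8_t replicated[4 * 8];                    // stands in for the vector register
    for (int r = 0; r < 4; r++) {
        memcpy(replicated + 8 * r, &chunk, 8);   // vmv_v_x on i64 lanes + i8 reinterpret
    }
    // every 8-byte replica is identical
    printf("%d %d %d %d\n", replicated[0], replicated[8], replicated[16], replicated[24]);
    return 0;
}
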
+ const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); + const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); + const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); + const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + + // Multiply and accumulate. + const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); + const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); + const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); + const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); + const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); + const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); + const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + + // In-place reduction. + const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); + const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); + const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); + const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); + const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); + const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); + const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); + const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); + const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); + const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); + const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + + // Multiply with scales. 
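
// [Editor's aside] The scale step fuses the two per-block quantization scales in a single
// widening multiply: the four fp16 column scales of the iq4_nl block and the one fp16
// scale of the q8_0 activation block. Scalar sketch of the update (plain floats stand in
// for the fp16 values used by the kernel):

#include <stdio.h>

int main(void) {
    float b_d[4]  = {0.5f, 0.25f, 1.0f, 2.0f};    // per-column block scales (fp16 in the kernel)
    float a_d     = 0.125f;                       // per-block activation scale
    int   sumi[4] = {100, -40, 8, 0};             // reduced integer dot products
    float s[4]    = {0.0f, 0.0f, 0.0f, 0.0f};
    for (int c = 0; c < 4; c++) {
        s[c] += (b_d[c] * a_d) * (float) sumi[c]; // vfwmul_vf + vfmacc_vv above
    }
    printf("%g %g %g %g\n", s[0], s[1], s[2], s[3]);
    return 0;
}
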
+ const vfloat16mf4_t b_d = __riscv_vle16_v_f16mf4((const _Float16 *)b_ptr[l].d, 4); + const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d, 4); + sumf = __riscv_vfmacc_vv_f32mf2(sumf, facc, d_0, QK4_NL / 8); + } + __riscv_vse32_v_f32mf2(s + x * ncols_interleaved, sumf, QK4_NL / 8); + } + return; + +#endif + ggml_gemv_iq4_nl_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0f, 4); + for (int l = 0; l + 1 < nb; l += 2) { + vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 0, 16); + vuint8m1_t b_1_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 16, 16); + vuint8m1_t b_2_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 32, 16); + vuint8m1_t b_3_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 48, 16); + vuint8m1_t b_4_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 0, 16); + vuint8m1_t b_5_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 16, 16); + vuint8m1_t b_6_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 32, 16); + vuint8m1_t b_7_packed = __riscv_vle8_v_u8m1(b_ptr[l + 1].qs + 48, 16); + + vuint8m1_t b_0_lo = __riscv_vand_vx_u8m1(b_0_packed, 0xf, 16); + vuint8m1_t b_0_hi = __riscv_vsrl_vx_u8m1(b_0_packed, 4, 16); + vuint8m1_t b_1_lo = __riscv_vand_vx_u8m1(b_1_packed, 0xf, 16); + vuint8m1_t b_1_hi = __riscv_vsrl_vx_u8m1(b_1_packed, 4, 16); + vuint8m1_t b_2_lo = __riscv_vand_vx_u8m1(b_2_packed, 0xf, 16); + vuint8m1_t b_2_hi = __riscv_vsrl_vx_u8m1(b_2_packed, 4, 16); + vuint8m1_t b_3_lo = __riscv_vand_vx_u8m1(b_3_packed, 0xf, 16); + vuint8m1_t b_3_hi = __riscv_vsrl_vx_u8m1(b_3_packed, 4, 16); + vuint8m1_t b_4_lo = __riscv_vand_vx_u8m1(b_4_packed, 0xf, 16); + vuint8m1_t b_4_hi = __riscv_vsrl_vx_u8m1(b_4_packed, 4, 16); + vuint8m1_t b_5_lo = __riscv_vand_vx_u8m1(b_5_packed, 0xf, 16); + vuint8m1_t b_5_hi = __riscv_vsrl_vx_u8m1(b_5_packed, 4, 16); + vuint8m1_t b_6_lo = __riscv_vand_vx_u8m1(b_6_packed, 0xf, 16); + vuint8m1_t b_6_hi = __riscv_vsrl_vx_u8m1(b_6_packed, 4, 16); + vuint8m1_t b_7_lo = __riscv_vand_vx_u8m1(b_7_packed, 0xf, 16); + vuint8m1_t b_7_hi = __riscv_vsrl_vx_u8m1(b_7_packed, 4, 16); + + vint8m1_t b_0 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_0_lo, b_0_hi, 16, 32), 32); + vint8m1_t b_1 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_1_lo, b_1_hi, 16, 32), 32); + vint8m1_t b_2 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_2_lo, b_2_hi, 16, 32), 32); + vint8m1_t b_3 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_3_lo, b_3_hi, 16, 32), 32); + vint8m1_t b_4 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_4_lo, b_4_hi, 16, 32), 32); + vint8m1_t b_5 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_5_lo, b_5_hi, 16, 32), 32); + vint8m1_t b_6 = __riscv_vrgather_vv_i8m1(values, 
__riscv_vslideup_vx_u8m1(b_6_lo, b_6_hi, 16, 32), 32); + vint8m1_t b_7 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_7_lo, b_7_hi, 16, 32), 32); + + vint8m1_t a_0 = __riscv_vle8_v_i8m1(a_ptr[l].qs, 32); + vint8m1_t a_1 = __riscv_vle8_v_i8m1(a_ptr[l + 1].qs, 32); + + vint32m1_t sumi_0 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_0, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); + vint32m1_t sumi_1 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_1, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); + vint32m1_t sumi_2 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_2, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); + vint32m1_t sumi_3 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_0, b_3, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); + vint32m1_t sumi_4 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_4, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); + vint32m1_t sumi_5 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_5, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); + vint32m1_t sumi_6 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_6, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); + vint32m1_t sumi_7 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_1, b_7, 32), __riscv_vmv_v_x_i32m1(0, 1), 32); + + int sumi_temp[8]; + __riscv_vse32_v_i32m1(&sumi_temp[0], sumi_0, 1); + __riscv_vse32_v_i32m1(&sumi_temp[1], sumi_1, 1); + __riscv_vse32_v_i32m1(&sumi_temp[2], sumi_2, 1); + __riscv_vse32_v_i32m1(&sumi_temp[3], sumi_3, 1); + __riscv_vse32_v_i32m1(&sumi_temp[4], sumi_4, 1); + __riscv_vse32_v_i32m1(&sumi_temp[5], sumi_5, 1); + __riscv_vse32_v_i32m1(&sumi_temp[6], sumi_6, 1); + __riscv_vse32_v_i32m1(&sumi_temp[7], sumi_7, 1); + vint32m1_t sum_0 = __riscv_vle32_v_i32m1(&sumi_temp[0], 4); + vint32m1_t sum_1 = __riscv_vle32_v_i32m1(&sumi_temp[4], 4); + + vfloat16mf2_t b_d_0 = __riscv_vle16_v_f16mf2((_Float16 *)b_ptr[l].d, 4); + vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d_0, *(const _Float16 *)&a_ptr[l].d, 4); + vfloat16mf2_t b_d_1 = __riscv_vle16_v_f16mf2((_Float16 *)b_ptr[l + 1].d, 4); + vfloat32m1_t d_1 = __riscv_vfwmul_vf_f32m1(b_d_1, *(const _Float16 *)&a_ptr[l + 1].d, 4); + + sumf = __riscv_vfmacc_vv_f32m1(sumf, d_0, __riscv_vfcvt_f_x_v_f32m1(sum_0, 4), 4); + sumf = __riscv_vfmacc_vv_f32m1(sumf, d_1, __riscv_vfcvt_f_x_v_f32m1(sum_1, 4), 4); + } + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, 4); + } + return; +#endif + ggml_gemv_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + // 1x16 Accumulator1 + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x16 integer accumulator + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Load `b_ptr`. 
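
// [Editor's aside] `values` holds the fixed 16-entry IQ4_NL codebook, and the vrgather
// calls below are lane-wise table lookups: each packed byte of `b` contributes a low and
// a high 4-bit index. Scalar sketch of the unpack (the table literal matches
// kvalues_iq4nl in ggml):

#include <stdint.h>
#include <stdio.h>

static const int8_t table[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

int main(void) {
    uint8_t packed = 0x9c;               // two 4-bit codes in one byte
    int8_t lo = table[packed & 0x0f];    // vand_vx + vrgather_vv
    int8_t hi = table[packed >> 4];      // vsrl_vx + vrgather_vv
    printf("lo=%d hi=%d\n", lo, hi);     // lo=53 hi=13
    return 0;
}
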
+ const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs, 16); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + + // Accumulation loop. + for (int i = 0; i < 16; i++) { + const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16); + const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16); + sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16); + } + + vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } + return; +#endif + ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + // 4x4 Accumulators + vfloat32m1_t sumf_0 = __riscv_vfmv_v_f_f32m1(0.0f, 4); + vfloat32m1_t sumf_1 = __riscv_vfmv_v_f_f32m1(0.0f, 4); + vfloat32m1_t sumf_2 = __riscv_vfmv_v_f_f32m1(0.0f, 4); + vfloat32m1_t sumf_3 = __riscv_vfmv_v_f_f32m1(0.0f, 4); + + for (int l = 0; l < nb; l++) { + int sumi_temp[16]; + uint8_t index[4] = {0, 8, 64, 72}; + vuint8mf8_t i_vec = __riscv_vle8_v_u8mf8(&index[0], 4); + vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 0, 16); + vuint8m1_t b_1_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 16, 16); + vuint8m1_t b_2_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 32, 16); + vuint8m1_t b_3_packed = __riscv_vle8_v_u8m1(b_ptr[l].qs + 48, 16); + + vuint8m1_t b_0_lo = __riscv_vand_vx_u8m1(b_0_packed, 0xf, 16); + vuint8m1_t b_0_hi = __riscv_vsrl_vx_u8m1(b_0_packed, 4, 16); + vuint8m1_t b_1_lo = __riscv_vand_vx_u8m1(b_1_packed, 0xf, 16); + vuint8m1_t b_1_hi = __riscv_vsrl_vx_u8m1(b_1_packed, 4, 16); + vuint8m1_t b_2_lo = __riscv_vand_vx_u8m1(b_2_packed, 0xf, 16); + vuint8m1_t b_2_hi = __riscv_vsrl_vx_u8m1(b_2_packed, 4, 16); + vuint8m1_t b_3_lo = __riscv_vand_vx_u8m1(b_3_packed, 0xf, 16); + vuint8m1_t b_3_hi = __riscv_vsrl_vx_u8m1(b_3_packed, 4, 16); + + vint8m1_t b_0 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_0_lo, b_0_hi, 16, 32), 32); + vint8m1_t b_1 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_1_lo, b_1_hi, 16, 32), 32); + vint8m1_t b_2 = __riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_2_lo, b_2_hi, 16, 32), 32); + vint8m1_t b_3 = 
__riscv_vrgather_vv_i8m1(values, __riscv_vslideup_vx_u8m1(b_3_lo, b_3_hi, 16, 32), 32);
+
+                #pragma unroll 4
+                for (int i = 0; i < 4; i++) {
+                    vint8m1_t a_i = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vloxei8_v_i64m1((int64_t*)(a_ptr[l].qs + i * 16), i_vec, 4));
+                    vint32m1_t sumi_0 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_0, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+                    vint32m1_t sumi_1 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_1, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+                    vint32m1_t sumi_2 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_2, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+                    vint32m1_t sumi_3 = __riscv_vwredsum_vs_i16m2_i32m1(__riscv_vwmul_vv_i16m2(a_i, b_3, 32), __riscv_vmv_v_x_i32m1(0, 1), 32);
+                    __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 0], sumi_0, 1);
+                    __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 1], sumi_1, 1);
+                    __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 2], sumi_2, 1);
+                    __riscv_vse32_v_i32m1(&sumi_temp[i * 4 + 3], sumi_3, 1);
+                }
+
+                vint32m1_t sum_0 = __riscv_vle32_v_i32m1(&sumi_temp[0], 4);
+                vint32m1_t sum_1 = __riscv_vle32_v_i32m1(&sumi_temp[4], 4);
+                vint32m1_t sum_2 = __riscv_vle32_v_i32m1(&sumi_temp[8], 4);
+                vint32m1_t sum_3 = __riscv_vle32_v_i32m1(&sumi_temp[12], 4);
+
+                vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((_Float16 *)b_ptr[l].d, 4);
+                vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[0], 4);
+                vfloat32m1_t d_1 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[1], 4);
+                vfloat32m1_t d_2 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[2], 4);
+                vfloat32m1_t d_3 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16 *)&a_ptr[l].d[3], 4);
+
+                sumf_0 = __riscv_vfmacc_vv_f32m1(sumf_0, d_0, __riscv_vfcvt_f_x_v_f32m1(sum_0, 4), 4);
+                sumf_1 = __riscv_vfmacc_vv_f32m1(sumf_1, d_1, __riscv_vfcvt_f_x_v_f32m1(sum_1, 4), 4);
+                sumf_2 = __riscv_vfmacc_vv_f32m1(sumf_2, d_2, __riscv_vfcvt_f_x_v_f32m1(sum_2, 4), 4);
+                sumf_3 = __riscv_vfmacc_vv_f32m1(sumf_3, d_3, __riscv_vfcvt_f_x_v_f32m1(sum_3, 4), 4);
+            }
+
+            __riscv_vse32_v_f32m1(s + (y * 4 + 0) * bs + x * 4, sumf_0, 4);
+            __riscv_vse32_v_f32m1(s + (y * 4 + 1) * bs + x * 4, sumf_1, 4);
+            __riscv_vse32_v_f32m1(s + (y * 4 + 2) * bs + x * 4, sumf_2, 4);
+            __riscv_vse32_v_f32m1(s + (y * 4 + 3) * bs + x * 4, sumf_3, 4);
+        }
+    }
+    return;
+#endif
+    ggml_gemm_iq4_nl_4x16_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 16;
+    const int blocklen = 1;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16);
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb);
+
+            // 4x16 Accumulators
+            vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+            vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16);
+
+            for (int l = 0; l < nb; l++) {
+                //
4x16 integer accumulators + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)b_ptr[l].qs, 16); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + + // Accumulation loop. + for (int i = 0; i < 16; i++) { + const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16); + const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16); + const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16); + const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16); + + const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16); + const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16); + const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16); + const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16); + + sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16); + sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16); + sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16); + sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16); + } + + vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } + return; +#endif + ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + 
UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + + vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, 8); + for (int l = 0; l < nb; l++) { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Load `b_ptr`. + const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); + const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); + const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); + + // Create 4 segments from `b`. + const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); + const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); + const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); + const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); + + // Multiply and accumulate. + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. + const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. 
+ const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d, 8); + sumf = __riscv_vfmacc_vv_f32m1(sumf, facc, d_0, QK4_NL / 4); + } + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, QK4_NL / 4); + } + return; + +#endif + ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -428,3 +949,469 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } + +void ggml_gemm_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + // 4x4 accumulators. + vfloat32mf2_t sumf0 = __riscv_vfmv_v_f_f32mf2(0.0, 4); + vfloat32mf2_t sumf1 = __riscv_vfmv_v_f_f32mf2(0.0, 4); + vfloat32mf2_t sumf2 = __riscv_vfmv_v_f_f32mf2(0.0, 4); + vfloat32mf2_t sumf3 = __riscv_vfmv_v_f_f32mf2(0.0, 4); + + for (int l = 0; l < nb; l++) { + // Load `b_ptr`. + const vuint8m2_t b_0_packed = __riscv_vle8_v_u8m2((const uint8_t *)b_ptr[l].qs, QK4_NL * 2); + const vint8m2_t b_0_lo = __riscv_vrgather_vv_i8m2(values, __riscv_vand_vx_u8m2(b_0_packed, 0xf, QK4_NL * 2), QK4_NL * 2); + const vint8m2_t b_0_hi = __riscv_vrgather_vv_i8m2(values, __riscv_vsrl_vx_u8m2(b_0_packed, 4, QK4_NL * 2), QK4_NL * 2); + + // Create 4 segments from `b`. + const vint8m1_t b_lo_0 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 0); + const vint8m1_t b_lo_1 = __riscv_vget_v_i8m2_i8m1(b_0_lo, 1); + const vint8m1_t b_hi_0 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 0); + const vint8m1_t b_hi_1 = __riscv_vget_v_i8m2_i8m1(b_0_hi, 1); + + // Load scales for `b`. + const vfloat16mf4_t b_d = __riscv_vle16_v_f16mf4((const _Float16 *)b_ptr[l].d, 4); + + // Load first 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); + const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); + const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); + const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + + // Multiply and accumulate. 
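
// [Editor's aside] The qs[0]/qs[32]/qs[64]/qs[96] loads above come straight from the
// block_q8_0x4 layout: four activation rows interleaved in 8-byte groups, so chunk c of
// row r lives at byte c * 32 + r * 8. A tiny sketch that prints the map (assuming this
// 4x8 interleave, which matches ggml_quantize_mat_q8_0_4x8):

#include <stdio.h>

int main(void) {
    for (int r = 0; r < 4; r++) {            // activation row within the x4 block
        for (int c = 0; c < 4; c++) {        // 8-byte chunk within the row
            printf("row %d chunk %d -> qs[%d]\n", r, c, c * 32 + r * 8);
        }
    }
    return 0;
}
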
+ const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); + const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); + const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); + const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); + const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); + const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); + const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + + // In-place reduction. + const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); + const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); + const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); + const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); + const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); + const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); + const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); + const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); + const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); + const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); + const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + + // Multiply with scales. + const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[0], 4); + sumf0 = __riscv_vfmacc_vv_f32mf2(sumf0, facc, d_0, QK4_NL / 8); + } + + // Load second 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); + const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); + const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); + const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + + // Multiply and accumulate. + const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); + const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); + const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); + const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); + const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); + const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); + const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + + // In-place reduction. 
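
// [Editor's aside] The vnsrl/vadd ladder below is a log2 tree reduction done entirely in
// registers: adjacent i32 pairs are viewed as u64, split back apart with narrowing shifts
// by 0 and 32, and added, halving the element count each step (32 -> 16 -> 8 -> 4) until
// one sum per interleaved column remains. Scalar sketch of the same reduction:

#include <stdio.h>

int main(void) {
    int v[32];
    for (int i = 0; i < 32; i++) v[i] = i + 1;          // stand-in partial products
    for (int n = 32; n > 4; n /= 2) {                   // three halving steps
        for (int i = 0; i < n / 2; i++) {
            v[i] = v[2 * i] + v[2 * i + 1];             // vnsrl(0) + vnsrl(32) + vadd
        }
    }
    printf("%d %d %d %d\n", v[0], v[1], v[2], v[3]);    // per-column sums
    return 0;
}
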
+ const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); + const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); + const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); + const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); + const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); + const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); + const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); + const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); + const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); + const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); + const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + + // Multiply with scales. + const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[1], 4); + sumf1 = __riscv_vfmacc_vv_f32mf2(sumf1, facc, d_0, QK4_NL / 8); + } + + // Load third 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); + const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); + const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); + const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + + // Multiply and accumulate. + const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); + const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); + const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); + const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); + const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); + const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); + const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + + // In-place reduction. 
+ const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi)); + const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL/ 2); + const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2); + const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4); + const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4); + const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4); + const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4); + const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8)); + const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8)); + const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8); + const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8); + + // Multiply with scales. + const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[2], 4); + sumf2 = __riscv_vfmacc_vv_f32mf2(sumf2, facc, d_0, QK4_NL / 8); + } + + // Load fourth 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m1_t a_0 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a0, 4)); + const vint8m1_t a_1 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a1, 4)); + const vint8m1_t a_2 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a2, 4)); + const vint8m1_t a_3 = __riscv_vreinterpret_v_i64m1_i8m1(__riscv_vmv_v_x_i64m1(a3, 4)); + + // Multiply and accumulate. + const vint16m2_t sumi_lo_0 = __riscv_vwmul_vv_i16m2(b_lo_0, a_0, QK4_NL); + const vint16m2_t sumi_lo_1 = __riscv_vwmul_vv_i16m2(b_lo_1, a_1, QK4_NL); + const vint16m2_t sumi_hi_0 = __riscv_vwmul_vv_i16m2(b_hi_0, a_2, QK4_NL); + const vint16m2_t sumi_hi_1 = __riscv_vwmul_vv_i16m2(b_hi_1, a_3, QK4_NL); + const vint32m4_t sumi_lo = __riscv_vwadd_vv_i32m4(sumi_lo_0, sumi_lo_1, QK4_NL); + const vint32m4_t sumi_hi = __riscv_vwadd_vv_i32m4(sumi_hi_0, sumi_hi_1, QK4_NL); + const vint32m4_t sumi = __riscv_vadd_vv_i32m4(sumi_lo, sumi_hi, QK4_NL); + + // In-place reduction. 
+                    const vuint64m4_t sumi_i32 = __riscv_vreinterpret_v_i64m4_u64m4(__riscv_vreinterpret_v_i32m4_i64m4(sumi));
+                    const vuint32m2_t sumi_h2_0 = __riscv_vnsrl_wx_u32m2(sumi_i32, 0, QK4_NL / 2);
+                    const vuint32m2_t sumi_h2_1 = __riscv_vnsrl_wx_u32m2(sumi_i32, 32, QK4_NL / 2);
+                    const vuint32m2_t sumi_h2 = __riscv_vadd_vv_u32m2(sumi_h2_0, sumi_h2_1, QK4_NL / 2);
+                    const vuint64m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h2);
+                    const vuint32m1_t sumi_h4_0 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 0, QK4_NL / 4);
+                    const vuint32m1_t sumi_h4_1 = __riscv_vnsrl_wx_u32m1(sumi_h2_i32, 32, QK4_NL / 4);
+                    const vuint32m1_t sumi_h4 = __riscv_vadd_vv_u32m1(sumi_h4_0, sumi_h4_1, QK4_NL / 4);
+                    const vuint64m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m1_u64m1(sumi_h4);
+                    const vint32mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 0, QK4_NL / 8));
+                    const vint32mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vnsrl_wx_u32mf2(sumi_h4_i32, 32, QK4_NL / 8));
+                    const vint32mf2_t sumi_h8 = __riscv_vadd_vv_i32mf2(sumi_h8_0, sumi_h8_1, QK4_NL / 8);
+                    const vfloat32mf2_t facc = __riscv_vfcvt_f_x_v_f32mf2(sumi_h8, QK4_NL / 8);
+
+                    // Multiply with scales.
+                    const vfloat32mf2_t d_0 = __riscv_vfwmul_vf_f32mf2(b_d, *(const _Float16*)&a_ptr[l].d[3], 4);
+                    sumf3 = __riscv_vfmacc_vv_f32mf2(sumf3, facc, d_0, QK4_NL / 8);
+                }
+            }
+
+            __riscv_vse32_v_f32mf2(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, 4);
+            __riscv_vse32_v_f32mf2(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, 4);
+            __riscv_vse32_v_f32mf2(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, 4);
+            __riscv_vse32_v_f32mf2(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, 4);
+        }
+    }
+    return;
+
+#endif
+    ggml_gemm_iq4_nl_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined __riscv_v_intrinsic
+    const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16);
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
+
+            // 4x8 accumulators.
+            vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, 8);
+            vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, 8);
+            vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, 8);
+            vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, 8);
+
+            for (int l = 0; l < nb; l++) {
+                // Load `b_ptr`.
+                const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4);
+                const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4);
+                const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4);
+
+                // Create 4 segments from `b`.
+ const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); + const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); + const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); + const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); + + // Load scales for `b`. + const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); + + { + // Load first 8 bytes of `a`. + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Multiply and accumulate. + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. + const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[0], 8); + sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, facc, d_0, QK4_NL / 4); + } + + // Load second 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). 
+ const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Multiply and accumulate. + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. + const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[1], 8); + sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, facc, d_0, QK4_NL / 4); + } + + // Load third 8 bytes of `a`. + { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Multiply and accumulate. 
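
// [Editor's aside] The widening chain below is what keeps the integer math exact: an
// i8*i8 product is bounded by 127 * 127 = 16129, which fits i16 with room to spare, and
// the i16 -> i32 widening add happens before any further accumulation, so one q8_0 block
// (at most 32 products per output column) stays far below INT32_MAX. A quick bound check
// in plain C:

#include <stdio.h>

int main(void) {
    long worst_product = 127L * 127;              // widest |i8 * i8| = 16129
    long worst_column  = 32L * worst_product;     // one block's products per column
    printf("fits i16: %d\n", worst_product <= 32767);
    printf("per-column bound: %ld (<< 2^31)\n", worst_column);
    return 0;
}
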
+ const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. + const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); + const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); + const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); + const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); + const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); + const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); + const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); + const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); + const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); + const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); + const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); + const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); + + // Multiply with scales. + const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[2], 8); + sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, facc, d_0, QK4_NL / 4); + } + + { + // Load fourth 8 bytes of `a`. + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); + + // Broadcast `a_ptr` across 4 registers (8 bytes / register). + const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); + const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); + const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); + const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); + + // Multiply and accumulate. + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); + const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); + const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); + const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); + const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); + + // In-place reduction. 
+                    const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi));
+                    const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL);
+                    const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL);
+                    const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL);
+                    const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2);
+                    const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2);
+                    const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2);
+                    const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2);
+                    const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4);
+                    const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4));
+                    const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4));
+                    const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4);
+                    const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4);
+
+                    // Multiply with scales.
+                    const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[3], 8);
+                    sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, facc, d_0, QK4_NL / 4);
+                }
+            }
+
+            __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, 8);
+            __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, 8);
+            __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, 8);
+            __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, 8);
+        }
+    }
+    return;
+
+#endif
+    ggml_gemm_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
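// Note on the "in-place reduction" used throughout the kernel above: it folds
// the QK4_NL * 2 = 64 int32 partial products down to 8 per-column sums in three
// halving steps (64 -> 32 -> 16 -> 8). The i32 vector is reinterpreted as i64
// pairs, vnsrl with shift 0 / shift 32 extracts the even / odd lanes, and vadd
// folds each pair. A scalar sketch of one halving step, assuming little-endian
// lane packing (illustrative names, not part of the patch):

#include <stdint.h>
#include <string.h>

static void fold_pairs_once(const int32_t * src, int32_t * dst, int n) {
    // dst[i] = src[2*i] + src[2*i+1], which is exactly what the
    // vnsrl(x, 0) + vnsrl(x, 32) + vadd sequence computes per 64-bit pair.
    for (int i = 0; i < n / 2; ++i) {
        uint64_t pair;
        memcpy(&pair, &src[2 * i], sizeof(pair));           // one packed i32 pair
        const int32_t lo = (int32_t) (pair & 0xFFFFFFFFu);  // vnsrl shift 0
        const int32_t hi = (int32_t) (pair >> 32);          // vnsrl shift 32
        dst[i] = lo + hi;
    }
}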
diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index 24e8ab4618..5532e68035 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -48,6 +48,44 @@ static inline int nearest_int(float fval) {
 
 extern "C" {
 
+void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
+
+    // scalar
+    const int blck_size_interleave = 1;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+}
+
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
@@ -124,6 +162,43 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
     }
 }
 
+void ggml_quantize_mat_q8_0_4x16_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+    assert(QK8_0 == 32);
+    assert(k % QK8_0 == 0);
+    const int nb = k / QK8_0;
+
+    block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
+
+    // scalar
+    const int blck_size_interleave = 16;
+    float srcv[4][QK8_0];
+    float id[4];
+
+    for (int i = 0; i < nb; i++) {
+        for (int row_iter = 0; row_iter < 4; row_iter++) {
+            float amax = 0.0f; // absolute max
+
+            for (int j = 0; j < QK8_0; j++) {
+                srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
+                amax = MAX(amax, fabsf(srcv[row_iter][j]));
+            }
+
+            const float d = amax / ((1 << 7) - 1);
+            id[row_iter] = d ? 1.0f / d : 0.0f;
+
+            y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
+        }
+
+        for (int j = 0; j < QK8_0 * 4; j++) {
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);
+
+            float x0 = srcv[src_id][src_offset] * id[src_id];
+            y[i].qs[j] = roundf(x0);
+        }
+    }
+}
 
 void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK_K == 256);
@@ -238,12 +313,24 @@ template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTR
     ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
 }
 
+template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row);
+}
+
 template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
     assert(nrow == 4);
     UNUSED(nrow);
     ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
 }
 
+template <> void ggml_quantize_mat_t<16, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+    assert(nrow == 4);
+    UNUSED(nrow);
+    ggml_quantize_mat_q8_0_4x16(x, vy, n_per_row);
+}
+
 template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
     assert(nrow == 4);
     UNUSED(nrow);
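// The q8_0_4x1 and q8_0_4x16 quantizers added above share one scatter formula;
// only blck_size_interleave differs. A small driver showing where each output
// byte comes from (illustrative, not part of the patch):

#include <stdio.h>

int main(void) {
    const int blcks[2] = { 1, 16 };
    for (int b = 0; b < 2; ++b) {
        const int blck = blcks[b];
        printf("blck_size_interleave = %d:\n", blck);
        for (int j = 0; j < 8; ++j) {  // first 8 output bytes only
            int src_offset = (j / (4 * blck)) * blck + (j % blck);
            int src_id     = (j % (4 * blck)) / blck;
            printf("  qs[%d] <- row %d, elem %d\n", j, src_id, src_offset);
        }
    }
    return 0;
}

// With blck = 1 the four rows are interleaved byte-by-byte (r0,r1,r2,r3,r0,...);
// with blck = 16 each row contributes a contiguous 16-byte run before the next
// row starts.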
@@ -832,6 +919,82 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemv_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
+void ggml_gemv_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
 void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
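// All the generic gemv/gemm variants in this patch share the same inner step:
// each packed weight byte holds two 4-bit indices into the IQ4_NL codebook, the
// low nibble pairing with activation a[i] and the high nibble with a[i + qk/2].
// A sketch of that decode (table values as in ggml's kvalues_iq4nl):

#include <stdint.h>

static const int8_t kvalues_iq4nl_sketch[16] = {
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
};

static void decode_iq4_nl_byte(uint8_t packed, int * v0, int * v1) {
    *v0 = kvalues_iq4nl_sketch[packed & 0x0F]; // multiplies a[i]
    *v1 = kvalues_iq4nl_sketch[packed >> 4];   // multiplies a[i + qk/2]
}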
@@ -1646,6 +1809,118 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemm_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                             (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                                }
+                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
+void ggml_gemm_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                             (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                                }
+                                sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
 void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -2380,7 +2655,31 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s
             memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
         }
-    } else {
+    } else if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            for (int b = 0; b < 8; ++b) {
+                out.qs[dst_offset + b] = in[src_id].qs[src_offset + b];
+            }
+
+            // Generates a bus error on RVV as this is auto-vectorized and the
+            // source might possibly not be 8-byte aligned
+            //
+            // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else if (blck_size_interleave == 16) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i;
+            int src_offset = 0;
+            int dst_offset = i * 16;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], 4 * sizeof(uint32_t));
+        }
+    }
+    else {
         GGML_ASSERT(false);
     }
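// Net effect of the new blck_size_interleave == 8 branch above: out.qs becomes
// a sequence of 8-byte runs, run i taken from input row i % 4 at byte offset
// (i / 4) * 8. A tiny illustration of the mapping (not part of the patch):

#include <stdio.h>

int main(void) {
    const int qk4_nl = 32;                 // iq4_nl block size, as in ggml
    const int blck   = 8;
    const int end    = qk4_nl * 2 / blck;  // 8 runs per interleaved x4 block
    for (int i = 0; i < end; ++i) {
        printf("out.qs[%2d..%2d] <- in[%d].qs[%2d..%2d]\n",
               i * blck, i * blck + blck - 1,
               i % 4, (i / 4) * blck, (i / 4) * blck + blck - 1);
    }
    return 0;
}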
@@ -2389,7 +2688,7 @@
 static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
-    GGML_ASSERT(interleave_block == 4);
+    // GGML_ASSERT(interleave_block == 4);
 
     const block_iq4_nl * src = (const block_iq4_nl *)data;
     block_iq4_nlx4 * dst = ( block_iq4_nlx4 *)t->data;
@@ -2435,7 +2734,14 @@ static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_s
             int src_offset = (i / 8) * blck_size_interleave;
             int dst_offset = i * blck_size_interleave;
 
-            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+            for (int b = 0; b < 8; ++b) {
+                out.qs[dst_offset + b] = in[src_id].qs[src_offset + b];
+            }
+
+            // Generates a bus error on RVV as this is auto-vectorized and the
+            // source might possibly not be 8-byte aligned
+            //
+            // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
         }
     } else {
         GGML_ASSERT(false);
@@ -2477,6 +2783,68 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b
     GGML_UNUSED(data_size);
 }
 
+static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx16 out;
+
+    for (int i = 0; i < 16; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_NL * 8 / blck_size_interleave;
+
+    if (blck_size_interleave == 1) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 16;
+            int src_offset = (i / 16) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            out.qs[dst_offset] = in[src_id].qs[src_offset];
+
+            // Generates a bus error on RVV as this is auto-vectorized and the
+            // source might possibly not be 8-byte aligned
+            //
+            // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
+    GGML_ASSERT(interleave_block == 1);
+
+    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nlx16 * dst = ( block_iq4_nlx16 *)t->data;
+
+    block_iq4_nl dst_tmp[16];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 16;
+    int nblocks = t->ne[0] / QK4_NL;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (t->ne[1] % nrows_interleaved != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx16(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 namespace ggml::cpu::repack {
 // repack
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
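// repack_iq4_nl_to_iq4_nl_16_bl above walks the tensor in stripes of 16 rows:
// for each block column x it gathers block x of every row in the stripe
// (src[x + i * nblocks]) into dst_tmp before interleaving. Index sketch for a
// hypothetical 2-block-wide tensor (illustrative only, first rows shown):

#include <stdio.h>

int main(void) {
    const int nblocks = 2, nrows_interleaved = 16;
    for (int x = 0; x < nblocks; x++) {
        for (int i = 0; i < 4; i++) {  // first 4 of the 16 row slots
            printf("dst block %d, row slot %d <- src[%d]\n", x, i, x + i * nblocks);
        }
    }
    return 0;
}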
@@ -2524,10 +2892,22 @@ template <> int repack<block_iq4_nlx4, 4, 4>(struct ggml_tensor * t, const void
 //    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
 //}
 
+template <> int repack<block_iq4_nlx4, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_iq4_nlx4, 16, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_4_bl(t, 16, data, data_size);
+}
+
 template <> int repack<block_iq4_nlx8, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_iq4_nlx16, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size);
+}
+
 template <> int repack<block_q8_0x4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
 }
@@ -2583,10 +2963,22 @@ template <> void gemv<block_iq4_nlx4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size
     ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_iq4_nlx4, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_iq4_nl_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_iq4_nlx4, 16, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_iq4_nlx8, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_iq4_nlx16, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_q8_0x4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -2642,10 +3034,22 @@ template <> void gemm<block_iq4_nlx4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size
     ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_iq4_nlx4, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_iq4_nl_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm<block_iq4_nlx8, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_iq4_nlx4, 16, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_iq4_nl_4x16_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_iq4_nlx16, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm<block_q8_0x4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -3050,7 +3454,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
 
     // instance for IQ4
     static const ggml::cpu::repack::tensor_traits<block_iq4_nlx4, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
+    static const ggml::cpu::repack::tensor_traits<block_iq4_nlx4, 16, 4, GGML_TYPE_Q8_0> iq4_nl_4x16_q8_0;
+    static const ggml::cpu::repack::tensor_traits<block_iq4_nlx4, 8, 4, GGML_TYPE_Q8_0> iq4_nl_4x8_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_iq4_nlx8, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
+    static const ggml::cpu::repack::tensor_traits<block_iq4_nlx16, 1, 16, GGML_TYPE_Q8_0> iq4_nl_16x1_q8_0;
 
     // instance for Q8_0
     static const ggml::cpu::repack::tensor_traits<block_q8_0x4, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
@@ -3118,6 +3525,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &iq4_nl_4x4_q8_0;
             }
         }
+        if (ggml_cpu_has_riscv_v()) {
+            #if defined __riscv_zvfh
+            switch (__riscv_vlenb() * 8) {
+                case 128:  { break; } // TODO
+                case 256:  { if (cur->ne[1] % 4 == 0) { return &iq4_nl_16x1_q8_0; } break; }
+                case 512:  { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x8_q8_0;  } break; }
+                case 1024: { break; } // TODO
+                default:   { return nullptr; }
+            }
+            #endif
+        }
     } else if (cur->type == GGML_TYPE_Q8_0) {
         if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
             if (cur->ne[1] % 4 == 0) {
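// The dispatch above keys off the runtime vector length: __riscv_vlenb()
// returns VLEN in bytes, so VLEN in bits is __riscv_vlenb() * 8. The same
// selection expressed as a standalone sketch (names illustrative; the 128-
// and 1024-bit paths are still TODO in the patch):

static const char * pick_iq4_nl_repack(unsigned long vlen_bits, long ne1) {
    switch (vlen_bits) {
        case 256: return (ne1 % 4 == 0) ? "iq4_nl_16x1_q8_0" : "no repack";
        case 512: return (ne1 % 8 == 0) ? "iq4_nl_8x8_q8_0"  : "no repack";
        default:  return "no repack";
    }
}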
diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h
index 855320eeeb..d0431276bd 100644
--- a/ggml/src/ggml-cpu/repack.h
+++ b/ggml/src/ggml-cpu/repack.h
@@ -97,12 +97,22 @@ struct block_iq4_nlx8 {
 
 static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
 
+struct block_iq4_nlx16 {
+    ggml_half d[16];        // deltas for 16 iq4_nl blocks
+    uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding");
+
+
 #if defined(__cplusplus)
 extern "C" {
 #endif
 
+void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_0_4x16(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -114,7 +124,10 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -124,15 +137,20 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_4x16_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 
 // Native implementations
+void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_0_4x16_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -144,7 +162,10 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -154,7 +175,10 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_4x16_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
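// Shape contract shared by the declarations above, as inferred from the
// generic implementations in repack.cpp: n is the K dimension in elements
// (a multiple of QK8_0), vx holds nc columns of interleaved weights, vy holds
// nr rows of quantized activations, and result row m starts at s + m * bs.
// A hypothetical call for one 4-row tile against 8 interleaved columns:
//
//     ggml_gemm_iq4_nl_4x8_q8_0(/*n=*/256, /*s=*/out, /*bs=*/out_stride,
//                               /*vx=*/repacked_w, /*vy=*/q8_act,
//                               /*nr=*/4, /*nc=*/8);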