diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index bc9d35a073..a1b198a15c 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -189,6 +189,7 @@ #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -199,6 +200,7 @@ #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__s390x__) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 21dda5c0a9..8ef3d2b754 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -1,4 +1,3 @@ -#include #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -250,7 +249,7 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); @@ -263,14 +262,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - // TODO -} - -void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - // TODO -} - void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -404,93 +395,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx8 * b_ptr = (const 
block_iq4_nlx8 *) vx + (x * nb); - - vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, 8); - for (int l = 0; l < nb; l++) { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Load `b_ptr`. - const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); - const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); - const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); - - // Create 4 segments from `b`. - const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); - const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); - const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); - const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. - const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. 
- const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d, 8); - sumf = __riscv_vfmacc_vv_f32m1(sumf, facc, d_0, QK4_NL / 4); - } - __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, QK4_NL / 4); - } - return; - -#endif - ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} - void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -537,7 +441,7 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); @@ -568,6 +472,7 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(nb); UNUSED(ncols_interleaved); UNUSED(blocklen); + UNUSED(bs); #if defined __riscv_v_intrinsic const block_q8_0 * a_ptr = (const block_q8_0 *) vy; @@ -590,7 +495,7 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); @@ -603,6 +508,405 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % 16 == 0); + + UNUSED(bs); + + const int N_COLS_TILE = 16; + const int num_k_blocks = n / QK_K; + + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + + const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_K* lhs_current = &lhs_base_ptr[k_block]; + const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + + // 1. Prepare Global Min Scales + vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); + vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl); + + vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl); + + vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl); + + const uint8_t* rhs_qs_ptr = rhs_current->qs; + const uint8_t* rhs_sc_ptr = rhs_current->scales; + const int8_t* lhs_qs_ptr = lhs_current->qs; + + // --- Phase Loop (4 phases x 64 elements) --- + for (int phase = 0; phase < 4; ++phase) { + + // A. 
Load Scales/Mins + vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3; + vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3; + + { + vuint8mf2_t v_raw; + // Sub-block 0 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 1 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 2 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 3 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + rhs_sc_ptr += 64; + } + + int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16); + int k_offsets[4] = {0, 32, 64, 96}; + + // B. Inner Dot Product Loop + for (int l = 0; l < 16; ++l) { + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += 16; + + // Sub-block 0 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 1 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 2 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 3 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + } + + // correction + int sb_base_abs = base_k_phase / 16; + + // Sub-block 0 + { + int sb_idx = sb_base_abs + (k_offsets[0] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 1 + { + int sb_idx = sb_base_abs + (k_offsets[1] / 16); + int16_t bsum = 
lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 2 + { + int sb_idx = sb_base_abs + (k_offsets[2] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 3 + { + int sb_idx = sb_base_abs + (k_offsets[3] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + + } // End Phase Loop + + // Apply global Scales + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); + vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); + + vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl); + v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl); + + } // End K-Block + __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); + + } +} + +void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + + // vector version needs Zvfhmin extension + const float a_scales[4] = { + GGML_CPU_FP16_TO_FP32(a_ptr[l].d[0]), + GGML_CPU_FP16_TO_FP32(a_ptr[l].d[1]), + GGML_CPU_FP16_TO_FP32(a_ptr[l].d[2]), + 
GGML_CPU_FP16_TO_FP32(a_ptr[l].d[3]) + }; + const float b_scales[8] = { + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[0]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[1]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[2]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[3]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[4]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[5]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[6]), + GGML_CPU_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + + const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l0; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l0 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); + sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l1; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); + const vint8m2_t lhs_2_8 
=__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l1 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); + sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l2; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l2 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t 
sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); + sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l3; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l3 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); + sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); + } + } + __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * 
ncols_interleaved], sumf3, vl / 4); + } + } + + return; + } + +#endif + ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -672,7 +976,7 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16); const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16); - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); @@ -943,7 +1247,7 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } } - const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16), 16); + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16); const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16); const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16); const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16); @@ -966,14 +1270,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - // TODO -} - -void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - // TODO -} - void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1040,7 +1336,7 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); @@ -1114,7 +1410,7 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16); } - const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((_Float16 *)b_ptr[l].d, 16); + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, 
*(const _Float16 *)&a_ptr[l].d[2], 16); @@ -1137,461 +1433,300 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + const int num_k_blocks = n / QK_K; + const int N_ROWS_TILE = 4; + const int N_COLS_TILE = 16; + assert(nr % N_ROWS_TILE == 0); + assert(nc % N_COLS_TILE == 0); - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + // --- Tiling Loops --- +#pragma GCC unroll 1 + for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) { +#pragma GCC unroll 1 + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + // Base Pointers + const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); + // Persistent Float Accumulators + vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); -#if defined __riscv_v - if (__riscv_vlenb() >= QK4_0) { - const size_t vl = QK4_0; + // --- Super-Block Loop (K=0..255) --- +#pragma GCC unroll 1 + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block]; + const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - for (int l = 0; l < nb; l++) { - const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); - const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); - const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); - const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); - const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); - const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); - const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + // 1. 
Load Global Min Scales (Keep as F16/LMUL=1 to save registers) + vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); + vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl); - // vector version needs Zvfhmin extension - const float a_scales[4] = { - GGML_CPU_FP16_TO_FP32(a_ptr[l].d[0]), - GGML_CPU_FP16_TO_FP32(a_ptr[l].d[1]), - GGML_CPU_FP16_TO_FP32(a_ptr[l].d[2]), - GGML_CPU_FP16_TO_FP32(a_ptr[l].d[3]) - }; - const float b_scales[8] = { - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[0]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[1]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[2]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[3]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[4]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[5]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[6]), - GGML_CPU_FP16_TO_FP32(b_ptr[l].d[7]) - }; - const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + // 2. Initialize Integer Accumulators + vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl); - const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; - const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; - const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l0; + const uint8_t* rhs_qs_ptr = rhs_current->qs; + const uint8_t* rhs_sc_ptr = rhs_current->scales; + const int8_t* lhs_qs_ptr = lhs_current->qs; + + // --- Phase Loop (4 phases x 64 elements) --- +#pragma GCC unroll 1 + for (int phase = 0; phase < 4; ++phase) { + + // A. Load Scales/Mins for the 4 interleaved sub-blocks + vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3; + vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3; + + // Unrolled Load Logic { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + vuint8mf2_t v_raw; + // Sub-block 0 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); - sumi_l0 = sumi_hi_m; + // Sub-block 1 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 2 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 3 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + v_d_sb_3 = 
__riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + rhs_sc_ptr += 64; } - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16); + int k_offsets[4] = {0, 32, 64, 96}; - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); - sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); + // B. Inner Dot Product Loop +#pragma GCC unroll 1 + for (int l = 0; l < 16; ++l) { + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += 16; + + // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase) + + // --- Sub-block 0 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[0] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 1 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[1] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 2 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[2] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = 
__riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 3 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[3] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } } - const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; - const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; - const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l1; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + // C CORRECTION + int sb_base_abs = base_k_phase / 16; - sumi_l1 = sumi_hi_m; + // --- Correction Sub-block 0 --- + { + int sb_abs = sb_base_abs + (k_offsets[0] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); + + // Row 0 + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + // Row 1 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + // Row 2 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + // Row 3 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); } + // --- Correction Sub-block 1 --- { - const vuint32m4_t 
sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + int sb_abs = sb_base_abs + (k_offsets[1] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); - sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); } - const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; - const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; - const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l2; + // --- Correction Sub-block 2 --- { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); 
- const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + int sb_abs = sb_base_abs + (k_offsets[2] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); - sumi_l2 = sumi_hi_m; + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); } + // --- Correction Sub-block 3 --- { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + int sb_abs = sb_base_abs + (k_offsets[3] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); - sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + 
v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); } - const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; - const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; - const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; - const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l3; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + } // End Phase Loop - sumi_l3 = sumi_hi_m; - } - - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); - sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); - } - } - __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); - } - } - - return; - } - -#endif - ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, 
nc); -} - -void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if defined __riscv_v_intrinsic - const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); - - // 4x8 accumulators. - vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, 8); - vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, 8); - vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, 8); - vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, 8); - - for (int l = 0; l < nb; l++) { - // Load `b_ptr`. - const vuint8m4_t b_0_packed = __riscv_vle8_v_u8m4((const uint8_t *)b_ptr[l].qs, QK4_NL * 4); - const vint8m4_t b_0_lo = __riscv_vrgather_vv_i8m4(values, __riscv_vand_vx_u8m4(b_0_packed, 0xf, QK4_NL * 4), QK4_NL * 4); - const vint8m4_t b_0_hi = __riscv_vrgather_vv_i8m4(values, __riscv_vsrl_vx_u8m4(b_0_packed, 4, QK4_NL * 4), QK4_NL * 4); - - // Create 4 segments from `b`. - const vint8m2_t b_lo_0 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 0); - const vint8m2_t b_lo_1 = __riscv_vget_v_i8m4_i8m2(b_0_lo, 1); - const vint8m2_t b_hi_0 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 0); - const vint8m2_t b_hi_1 = __riscv_vget_v_i8m4_i8m2(b_0_hi, 1); - - // Load scales for `b`. - const vfloat16mf2_t b_d = __riscv_vle16_v_f16mf2((const _Float16 *)b_ptr[l].d, 8); + // --- Apply Main Scales --- + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); + vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); { - // Load first 8 bytes of `a`. - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[32]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[64]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[96]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. 
- const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[0], 8); - sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, facc, d_0, QK4_NL / 4); + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[0], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_0 = __riscv_vfadd_vv_f32m2(v_sumf_0, v_sum, vl); } - - // Load second 8 bytes of `a`. + // Row 1 { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[40]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[72]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[104]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. 
- const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[1], 8); - sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, facc, d_0, QK4_NL / 4); + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[1], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_1 = __riscv_vfadd_vv_f32m2(v_sumf_1, v_sum, vl); } - - // Load third 8 bytes of `a`. + // Row 2 { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[48]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[80]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[112]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. 
- const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[2], 8); - sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, facc, d_0, QK4_NL / 4); + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[2], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_2 = __riscv_vfadd_vv_f32m2(v_sumf_2, v_sum, vl); } - + // Row 3 { - // Load fourth 8 bytes of `a`. - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[24]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[56]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[88]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[120]; - __asm__ __volatile__("" ::: "memory"); - - // Broadcast `a_ptr` across 4 registers (8 bytes / register). - const vint8m2_t a_0 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, 8)); - const vint8m2_t a_1 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, 8)); - const vint8m2_t a_2 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, 8)); - const vint8m2_t a_3 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, 8)); - - // Multiply and accumulate. - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(b_lo_0, a_0, QK4_NL * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmul_vv_i16m4(b_lo_1, a_1, QK4_NL * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmul_vv_i16m4(b_hi_0, a_2, QK4_NL * 2); - const vint16m4_t sumi_hi_1 = __riscv_vwmul_vv_i16m4(b_hi_1, a_3, QK4_NL * 2); - const vint32m8_t sumi_lo = __riscv_vwadd_vv_i32m8(sumi_lo_0, sumi_lo_1, QK4_NL * 2); - const vint32m8_t sumi_hi = __riscv_vwadd_vv_i32m8(sumi_hi_0, sumi_hi_1, QK4_NL * 2); - const vint32m8_t sumi = __riscv_vadd_vv_i32m8(sumi_lo, sumi_hi, QK4_NL * 2); - - // In-place reduction. 
- const vuint64m8_t sumi_i32 = __riscv_vreinterpret_v_i64m8_u64m8(__riscv_vreinterpret_v_i32m8_i64m8(sumi)); - const vuint32m4_t sumi_h2_0 = __riscv_vnsrl_wx_u32m4(sumi_i32, 0, QK4_NL); - const vuint32m4_t sumi_h2_1 = __riscv_vnsrl_wx_u32m4(sumi_i32, 32, QK4_NL); - const vuint32m4_t sumi_h2 = __riscv_vadd_vv_u32m4(sumi_h2_0, sumi_h2_1, QK4_NL); - const vuint64m4_t sumi_h2_i32 = __riscv_vreinterpret_v_u32m4_u64m4(sumi_h2); - const vuint32m2_t sumi_h4_0 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 0, QK4_NL / 2); - const vuint32m2_t sumi_h4_1 = __riscv_vnsrl_wx_u32m2(sumi_h2_i32, 32, QK4_NL / 2); - const vuint32m2_t sumi_h4 = __riscv_vadd_vv_u32m2(sumi_h4_0, sumi_h4_1, QK4_NL / 2); - const vuint64m2_t sumi_h4_i32 = __riscv_vreinterpret_v_u32m2_u64m2(sumi_h4); - const vint32m1_t sumi_h8_0 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 0, QK4_NL / 4)); - const vint32m1_t sumi_h8_1 = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vnsrl_wx_u32m1(sumi_h4_i32, 32, QK4_NL / 4)); - const vint32m1_t sumi_h8 = __riscv_vadd_vv_i32m1(sumi_h8_0, sumi_h8_1, QK4_NL / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, QK4_NL / 4); - - // Multiply with scales. - const vfloat32m1_t d_0 = __riscv_vfwmul_vf_f32m1(b_d, *(const _Float16*)&a_ptr[l].d[3], 8); - sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, facc, d_0, QK4_NL / 4); + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[3], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_3 = __riscv_vfadd_vv_f32m2(v_sumf_3, v_sum, vl); } - } - __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, 8); - __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, 8); - __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, 8); - __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, 8); + } // End K-Block + + __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs + col_tile, v_sumf_2, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl); } } - return; - -#endif - ggml_gemm_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 5ca3aa06e0..fd72d92901 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1,4 +1,3 @@ -#include "ggml.h" #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -1258,6 +1257,94 @@ void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } } + +void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % 16 == 0); + + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_K * y = (const block_q8_K *)vy; + + // Layout: Even-Low(0,2,4,6), Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...) 
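(The permutation defined just below undoes the ordering in which make_block_q2_Kx16, added later in this patch, packs the per-sub-block scale bytes: groups of four sub-blocks in even-low, odd-low, even-high, odd-high order, with 16 column bytes per sub-block. A small illustrative check of that inverse relationship, assuming <assert.h>; the names order and check_sb_perm are not part of the patch.)

    // order[g] is the source sub-block that make_block_q2_Kx16 stores at group slot g,
    // so scales[g * 16 + col] holds column col's scale byte for sub-block order[g].
    // sb_perm maps a sub-block index back to its slot: sb_perm[order[g]] == g.
    static void check_sb_perm(const int sb_perm[16]) {
        const int order[16] = { 0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15 };
        for (int g = 0; g < 16; ++g) {
            assert(sb_perm[order[g]] == g);
        }
    }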
+ const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, // 0-7 + 8, 12, 9, 13, 10, 14, 11, 15 // 8-15 + }; + + for (int col_tile = 0; col_tile < nc; col_tile += 16) { + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_K * y_ptr = y; + + float sumf[16] = {0}; + + // Loop over K-blocks + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[16] = {0}; + int32_t summs[16] = {0}; + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + // Iterate over sub-blocks 0..15 + for (int sb = 0; sb < 16; ++sb) { + // Correction Term + int16_t bsum = bs_lhs[sb]; + int scale_offset = sb_perm[sb] * 16; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits + } + + // Main Dot Product + // Calculate base offsets for Q2 unpacking based on SB + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + + int shift = ((sb / 2) % 4) * 2; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits + + // Process 16 elements (l=0..15) + for (int l = 0; l < 16; ++l) { + // Q2: Interleaved by column. Byte `l` contains 4 k-values. + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; + + // Q8: Linear access + int k = sb * 16 + l; + int8_t q8_val = qs_lhs[k]; + + isum[col] += q8_val * q2_val * d_sb; + } + } + } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_lhs = y_ptr[k_block].d; + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + + sumf[col] += (isum[col] * d_all) - (summs[col] * d_min); + } + } + + for (int col = 0; col < 16; ++col) { + s[col_tile + col] = sumf[col]; + } + } +} #endif void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -2330,6 +2417,102 @@ void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } } + + +void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr % 4 == 0); + assert(nc % 16 == 0); + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; + + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, + 8, 12, 9, 13, 10, 14, 11, 15 + }; + + // Iterate Rows in tiles of 4 + for (int row_tile = 0; row_tile < nr; row_tile += 4) { + // Iterate Columns in tiles of 16 + for (int col_tile = 0; col_tile < nc; col_tile += 16) { + + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; + + float sumf[4][16]; + memset(sumf, 0, sizeof(sumf)); + + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[4][16]; + int32_t summs[4][16]; + memset(isum, 0, sizeof(isum)); + memset(summs, 0, sizeof(summs)); + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + for (int sb 
= 0; sb < 16; ++sb) { + int scale_offset = sb_perm[sb] * 16; + + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + int shift = ((sb / 2) % 4) * 2; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; + int32_t m_sb = sc_val >> 4; + + // Correction Term + for (int r = 0; r < 4; ++r) { + int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4); + summs[r][col] += bs_lhs[bsum_idx] * m_sb; + } + + // Main Dot Product + for (int l = 0; l < 16; ++l) { + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; + + // Calculate Q8 index for this specific k and row + int k = sb * 16 + l; + int q8_idx = (k / 4) * 16 + (k % 4); + + for (int r = 0; r < 4; ++r) { + // Add r*4 to jump to the correct row within the 4x4 chunk + int8_t q8_val = qs_lhs[q8_idx + r * 4]; + isum[r][col] += q8_val * q2_val * d_sb; + } + } + } + } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + for (int r = 0; r < 4; ++r) { + float d_lhs = y_ptr[k_block].d[r]; + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + sumf[r][col] += (isum[r][col] * d_all) - (summs[r][col] * d_min); + } + } + } + + for (int r = 0; r < 4; ++r) { + for (int col = 0; col < 16; ++col) { + s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; + } + } + } + } +} #endif } // extern "C" @@ -2459,98 +2642,63 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in const int end = QK_K * 4 / blck_size_interleave; // Interleave Q4_K quants by taking 8 bytes at a time - if (blck_size_interleave == 8) { - for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = (i / 8) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; - uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K + // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q4_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures + uint8_t s[8], m[8]; + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; } - // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K - // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) - // The output Q4_Kx8 structure has 96 bytes - // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure - // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures - uint8_t s[8], m[8]; + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 
2); + out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = in[j].scales[i] & 63; - m[j] = in[j].scales[i + 4] & 63; - } - - out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); } - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); - } + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); - out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); - - } - - } else if (blck_size_interleave == 1) { - for (int i = 0; i < end; ++i) { - int 
src_id = i % 8; - int src_offset = i / 8; - int dst_offset = i; - - out.qs[dst_offset] = in[src_id].qs[src_offset]; - } - - // RVV repacking. - // - // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. - uint8_t s[64], m[64]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[i * 8 + j] = in[j].scales[i] & 63; - m[i * 8 + j] = in[j].scales[i + 4] & 63; - } - } - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[32 + i * 8 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[32 + i * 8 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); - } - } - - for (int i = 0; i < 64; i++) { - out.scales[i] = (s[i] & 15) + (m[i] & 15 << 4); - } - for (int i = 0; i < 32; i++) { - out.scales[64 + i] = (s[i] & 48 >> 4) + (m[i] & 48 >> 2) + (s[32 + i] & 48) + (m[32 + i] & 48 << 2); - } } return out; @@ -2785,6 +2933,68 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in return out; } +static block_q2_Kx16 make_block_q2_Kx16(const block_q2_K * in, unsigned int blck_size_interleave) { + block_q2_Kx16 out; + constexpr int N_COLS = 16; + + // 1. Copy Super-Scales (d) and Super-Mins (dmin) + for (int i = 0; i < N_COLS; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + // 2. Interleave Q2_K Data + const int bytes_per_col = 64; + const int total_bytes = N_COLS * bytes_per_col; + const int end = total_bytes / blck_size_interleave; + + for (int i = 0; i < end; ++i) { + int src_col_id = i % N_COLS; + int src_offset = (i / N_COLS) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], blck_size_interleave); + } + + // 3. 
Repack Scales into the Optimized "Sequential-Parallel" Layout + int out_idx = 0; + + // Arrays define the sub-block order for each group + const int even_low_sbs[] = {0, 2, 4, 6}; + const int odd_low_sbs[] = {1, 3, 5, 7}; + const int even_high_sbs[] = {8, 10, 12, 14}; + const int odd_high_sbs[] = {9, 11, 13, 15}; + + // Pack Group 1: Even-Low + for (int sb : even_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 2: Odd-Low + for (int sb : odd_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 3: Even-High + for (int sb : even_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 4: Odd-High + for (int sb : odd_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + return out; +} + static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 4 || interleave_block == 8); @@ -2818,7 +3028,7 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 8 || interleave_block == 4 || interleave_block == 1); + GGML_ASSERT(interleave_block == 8 || interleave_block == 4); constexpr int nrows_interleaved = 8; block_q4_Kx8 * dst = (block_q4_Kx8*)t->data; @@ -2908,6 +3118,41 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q2_K_to_q2_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + constexpr int nrows_interleaved = 16; + + block_q2_Kx16 * dst = (block_q2_Kx16*)t->data; + const block_q2_K * src = (const block_q2_K*) data; + + block_q2_K dst_tmp[nrows_interleaved]; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + // This loop gathers 16 separate blocks (one from each column) + // that correspond to the same K-dimension chunk. 
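(Once the 16 column blocks are gathered below, make_block_q2_Kx16 is called with interleave_block == 1, the value this path is registered with further down, so the quant bytes end up fully byte-interleaved across columns while the scales follow the four-group ordering described at the generic kernels above. A hedged layout check under that assumption, also assuming <assert.h>; the function name is illustrative and not part of the patch.)

    static void q2_Kx16_qs_layout_check(const block_q2_K in[16], const block_q2_Kx16 * out) {
        for (int col = 0; col < 16; ++col) {
            assert(out->d[col]    == in[col].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d);
            assert(out->dmin[col] == in[col].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin);
            for (int b = 0; b < 64; ++b) {
                // byte b of column col's quants lands at qs[b * 16 + col]
                assert(out->qs[b * 16 + col] == in[col].qs[b]);
            }
        }
    }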
+ for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + + *dst++ = make_block_q2_Kx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); constexpr int nrows_interleaved = 16; @@ -3144,14 +3389,6 @@ static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_s memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); } - } else if (blck_size_interleave == 8) { - for (int i = 0; i < end; ++i) { - int src_id = i % 4; - int src_offset = (i / 4) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); - } } else { GGML_ASSERT(false); } @@ -3207,14 +3444,7 @@ static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_s int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - for (int b = 0; b < 8; ++b) { - out.qs[dst_offset + b] = in[src_id].qs[src_offset + b]; - } - - // Generates bus error on RVV as this is auto-vectorized and the - // source might possible not be 8-byte aligned - // - // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); } } else { GGML_ASSERT(false); @@ -3376,10 +3606,6 @@ template <> int repack(struct ggml_tensor * t, const void * d return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); } -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_K_to_q4_K_8_bl(t, 1, data, data_size); -} - template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size); } @@ -3391,6 +3617,10 @@ template <> int repack(struct ggml_tensor * t, const void * template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size); } + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_16_bl(t, 1, data, data_size); +} #endif // gemv @@ -3456,10 +3686,6 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); -} - template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -3471,6 +3697,10 @@ template <> void gemv(int n, float * s, siz template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} #endif // gemm @@ -3537,10 +3767,6 @@ template <> void gemm(int n, float * s, size_ ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); -} - template <> void gemm(int n, float * s, size_t 
bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -3552,6 +3778,10 @@ template <> void gemm(int n, float * s, siz template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); } + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} #endif class tensor_traits_base : public ggml::cpu::tensor_traits { @@ -3958,13 +4188,14 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // instances for RISC-V // - // These implement outer-product style multiplication with interleave of 1. + // These implement outer-product style matrix multiplication kernels with + // an interleave of 1. #if defined __riscv_zvfh static const ggml::cpu::repack::tensor_traits q4_0_16x1_q8_0; - static const ggml::cpu::repack::tensor_traits q4_K_8x1_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q2_K_16x1_q8_K; #endif if (cur->type == GGML_TYPE_Q4_0) { @@ -4027,6 +4258,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q2_K_8x8_q8_K; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_Q5_K) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 9029ee3eb9..2f0521de06 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -64,6 +64,13 @@ struct block_q2_Kx8 { }; static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); +struct block_q2_Kx16 { + ggml_half d[16]; // Super-block scale for quantized scales + ggml_half dmin[16]; // Super-block scale for quantized mins + uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks) + uint8_t qs[1024]; // Data (16 cols * 64 bytes per block) +}; +static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding"); struct block_q5_Kx8 { ggml_half d[8]; // super-block scale for quantized scales @@ -152,15 +159,15 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT 
vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #endif // Native implementations @@ -196,15 +203,15 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); 
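(For reference, the q2_K x q8_K kernels declared above all compute the same per-block scalar quantity: the 4-bit sub-block scales weight the integer dot products, and the 4-bit sub-block mins are corrected out via the q8_K bsums. A minimal scalar sketch of that formula for one column and one 256-wide K block, assuming q2 already holds the unpacked 2-bit values and <stdint.h> is available; the function and parameter names are illustrative, not declarations from this header.)

    static float q2K_q8K_block_dot(float d_q8, float d_q2, float dmin_q2,
                                   const uint8_t scales[16],   // low nibble: scale, high nibble: min
                                   const uint8_t q2[256],      // unpacked 2-bit weights, 0..3
                                   const int8_t  q8[256],
                                   const int16_t bsums[16]) {  // sums of 16 consecutive q8 values
        int32_t isum = 0, summs = 0;
        for (int sb = 0; sb < 16; ++sb) {
            int32_t s = 0;
            for (int l = 0; l < 16; ++l) {
                s += (int32_t)q2[sb * 16 + l] * q8[sb * 16 + l];
            }
            isum  += (scales[sb] & 0xF) * s;          // sub-block scale
            summs += (scales[sb] >> 4) * bsums[sb];   // sub-block min correction
        }
        return d_q8 * d_q2 * (float)isum - d_q8 * dmin_q2 * (float)summs;
    }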
#endif

#if defined(__cplusplus)