diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 3401c35876..af9ac7326f 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -226,11 +226,11 @@ static inline void ggml_gemv_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_ for (int x = 0; x < nc / ncols_interleaved; x++) { const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); - // 1x16 Accumulator + // 1xM Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 1x16 Integer Accumulator + // 1xM Integer Accumulator vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -239,7 +239,7 @@ static inline void ggml_gemv_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_ // Load `b_ptr`. const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, ncols_interleaved), 4, ncols_interleaved); - const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, ncols_interleaved); sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, ncols_interleaved); sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, ncols_interleaved); @@ -294,11 +294,11 @@ void ggml_gemv_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int x = 0; x < nc / ncols_interleaved; x++) { const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); - // 1x16 Accumulator + // 1xM Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 1x16 Integer Accumulator + // 1xM Integer Accumulator vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); // Accumulation loop. @@ -334,34 +334,33 @@ void ggml_gemv_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static void inline ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); assert(nr == 1); - assert(nc % 16 == 0); + assert(nc % ncols_interleaved == 0); UNUSED(bs); const int num_k_blocks = n / QK_K; - const size_t vl = __riscv_vsetvl_e32m2(ncols_interleaved); for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy; const block_q2_Kx* rhs_base_ptr = (const block_q2_Kx*)vx + (col_tile / ncols_interleaved) * num_k_blocks; - vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int k_block = 0; k_block < num_k_blocks; ++k_block) { const block_q8_K* lhs_current = &lhs_base_ptr[k_block]; const block_q2_Kx* rhs_current = &rhs_base_ptr[k_block]; // 1. 
Prepare Global Min Scales
-            vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl);
-            vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl);
+            vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, ncols_interleaved);
+            vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, ncols_interleaved);

-            vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl);
+            vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, ncols_interleaved);

-            vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl);
+            vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, ncols_interleaved);

             const uint8_t* rhs_qs_ptr = rhs_current->qs;
             const uint8_t* rhs_sc_ptr = rhs_current->scales;
@@ -377,75 +376,77 @@ void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                 {
                     vuint8mf2_t v_raw;
                     // Sub-block 0
-                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl);
-                    v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
-                    v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, ncols_interleaved);
+                    v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, ncols_interleaved), ncols_interleaved);
+                    v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, ncols_interleaved), ncols_interleaved);
                     // Sub-block 1
-                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl);
-                    v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
-                    v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+                    rhs_sc_ptr+=ncols_interleaved;
+                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, ncols_interleaved);
+                    v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, ncols_interleaved), ncols_interleaved);
+                    v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, ncols_interleaved), ncols_interleaved);
                     // Sub-block 2
-                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl);
-                    v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
-                    v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+                    rhs_sc_ptr+=ncols_interleaved;
+                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, ncols_interleaved);
+                    v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, ncols_interleaved), ncols_interleaved);
+                    v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, ncols_interleaved), ncols_interleaved);
                     // Sub-block 3
-                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl);
-                    v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl);
-                    v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl);
+                    rhs_sc_ptr+=ncols_interleaved;
+                    v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, ncols_interleaved);
+                    v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, ncols_interleaved), ncols_interleaved);
+                    v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, ncols_interleaved), ncols_interleaved);

-                    rhs_sc_ptr += 64;
+                    rhs_sc_ptr+=ncols_interleaved;
                 }

                 int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16);
                 int k_offsets[4] = {0, 32, 64, 96};

-                // B. 
Inner Dot Product Loop for (int l = 0; l < 16; ++l) { - vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); - rhs_qs_ptr += 16; + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, ncols_interleaved); + rhs_qs_ptr += ncols_interleaved; // Sub-block 0 { - vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, ncols_interleaved); vint16m1_t v_w = __riscv_vmul_vv_i16m1( - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), - __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, ncols_interleaved)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), ncols_interleaved); int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l]; - v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, ncols_interleaved); } // Sub-block 1 { - vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, ncols_interleaved), 3, ncols_interleaved); vint16m1_t v_w = __riscv_vmul_vv_i16m1( - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), - __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, ncols_interleaved)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), ncols_interleaved); int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l]; - v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, ncols_interleaved); } // Sub-block 2 { - vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, ncols_interleaved), 3, ncols_interleaved); vint16m1_t v_w = __riscv_vmul_vv_i16m1( - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), - __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, ncols_interleaved)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), ncols_interleaved); int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l]; - v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, ncols_interleaved); } // Sub-block 3 { - vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, ncols_interleaved), 3, ncols_interleaved); vint16m1_t v_w = __riscv_vmul_vv_i16m1( - __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), - __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl); + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, ncols_interleaved)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), ncols_interleaved); int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l]; - v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, ncols_interleaved); } } @@ -457,54 +458,56 @@ void ggml_gemv_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo int sb_idx = sb_base_abs + (k_offsets[0] / 16); int16_t bsum = lhs_current->bsums[sb_idx]; vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); - vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); - 
vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); - v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, ncols_interleaved); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, ncols_interleaved), v_g_min_final, ncols_interleaved); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, ncols_interleaved); } // Sub-block 1 { int sb_idx = sb_base_abs + (k_offsets[1] / 16); int16_t bsum = lhs_current->bsums[sb_idx]; vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); - vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); - vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); - v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, ncols_interleaved); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, ncols_interleaved), v_g_min_final, ncols_interleaved); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, ncols_interleaved); } // Sub-block 2 { int sb_idx = sb_base_abs + (k_offsets[2] / 16); int16_t bsum = lhs_current->bsums[sb_idx]; vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); - vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); - vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); - v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, ncols_interleaved); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, ncols_interleaved), v_g_min_final, ncols_interleaved); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, ncols_interleaved); } // Sub-block 3 { int sb_idx = sb_base_abs + (k_offsets[3] / 16); int16_t bsum = lhs_current->bsums[sb_idx]; vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); - vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); - vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); - v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, ncols_interleaved); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, ncols_interleaved), v_g_min_final, ncols_interleaved); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, ncols_interleaved); } } // End Phase Loop // Apply global Scales - vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); - vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, ncols_interleaved); + vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, ncols_interleaved); - vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl); - vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl); - v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl); - v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl); + vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, ncols_interleaved); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, ncols_interleaved); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, ncols_interleaved); + v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, ncols_interleaved); } // End K-Block - __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); + __riscv_vse32_v_f32m2(s + col_tile, v_sumf, ncols_interleaved); + } } + void 
ggml_gemv_q2_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q2_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); } @@ -542,7 +545,7 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); - // 1x16 Accumulator + // 1xM Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { @@ -586,15 +589,16 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); // Accumulation for 2 sub-blocks. - { - // 4x16 integer accumulators + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); @@ -606,20 +610,22 @@ void ggml_gemv_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), - sumi_s_0_16, 16); + sumi_s_0_16, ncols_interleaved); sumi = __riscv_vwmacc_vv_i32m2(sumi, __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), - sumi_s_1_16, 16); + sumi_s_1_16, ncols_interleaved); } - { - // 4x16 integer accumulators + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. 
const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); @@ -661,6 +667,166 @@ void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemv_q4_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); } +template +void ggml_gemv_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx * b_ptr = (const block_q5_Kx *) vx + (x * nb); + + // 1xM Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 4 sub-blocks at once. + const int vl = ncols_interleaved * 4; + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * vl], vl); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, vl); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl * 2], vl); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, vl), 4, vl); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, vl), 2, vl); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, vl); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, vl), 2, vl); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, vl), vl); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, vl), vl)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. 
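+                // NOTE (assumes the standard q8_K layout): each `bsums` entry is the sum of 16
+                // activation values, so adding adjacent pairs below yields per-32-value sub-block
+                // sums. Multiplied by the per-column mins and scaled by dmin * d, they form the
+                // usual K-quant min correction that is subtracted from `sumf`.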
+ vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d, ncols_interleaved); + sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, ncols_interleaved), ncols_interleaved), ncols_interleaved); + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in 4 steps. + // + // Recheck. + for (int k = 0; k < 4; k++) { + // 4xM integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + for (int i = k * 8; i < (k + 1) * 8; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_lo_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_lo_packed, 4, ncols_interleaved)); + + // Load high bits and merge with low bits. + const vuint8mf2_t b_hi_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[i * ncols_interleaved], ncols_interleaved); + const vbool16_t b_hi_0_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 0), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vadd_vx_i8mf2_mu(b_hi_0_mask, b_s_lo_0, b_s_lo_0, 16, ncols_interleaved); + const vbool16_t b_hi_1_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 1), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_1 = __riscv_vadd_vx_i8mf2_mu(b_hi_1_mask, b_s_lo_1, b_s_lo_1, 16, ncols_interleaved); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, ncols_interleaved); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_s_0_16, ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_s_1_16, ncols_interleaved); + } + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in 4 steps. + // + // Recheck. + for (int k = 0; k < 4; k++) { + // 4xM integer accumulators + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + for (int i = k * 8; i < (k + 1) * 8; i++) { + // Load `b_ptr`. 
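+                        // Second sub-block pair of this 128-value group: the low nibble of each
+                        // packed byte feeds the third sub-block, the high nibble the fourth, and
+                        // the matching 5th bits are taken from `qh` bits (j*4 + 2) and (j*4 + 3).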
+ const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_lo_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_lo_packed, 4, ncols_interleaved)); + + // Load high bits and merge with low bits. + const vuint8mf2_t b_hi_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[i * ncols_interleaved], ncols_interleaved); + const vbool16_t b_hi_0_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 2), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vadd_vx_i8mf2_mu(b_hi_0_mask, b_s_lo_0, b_s_lo_0, 16, ncols_interleaved); + const vbool16_t b_hi_1_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 3), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_1 = __riscv_vadd_vx_i8mf2_mu(b_hi_1_mask, b_s_lo_1, b_s_lo_1, 16, ncols_interleaved); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, ncols_interleaved); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, ncols_interleaved); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_s_0_16, ncols_interleaved); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_s_1_16, ncols_interleaved); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, ncols_interleaved); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, ncols_interleaved), d_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + x * ncols_interleaved, sumf, ncols_interleaved); + } +} + +void ggml_gemv_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +} + template void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -680,27 +846,27 @@ void ggml_gemv_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(ncols_interleaved); UNUSED(blocklen); - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block_iq4_nlx * b_ptr = (const 
block_iq4_nlx *) vx + (x * nb); - // 1x16 Accumulator1 + // 1xM Accumulator vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 1x16 integer accumulator + // 1xM integer accumulator vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); // Accumulation loop. for (int i = 0; i < QK4_NL / 2; i++) { // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); - const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], ncols_interleaved); - const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], ncols_interleaved); + const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i], ncols_interleaved); + const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[16 + i], ncols_interleaved); sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, ncols_interleaved), ncols_interleaved); } @@ -980,14 +1146,14 @@ void ggml_gemm_q4_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int x = 0; x < nc / ncols_interleaved; x++) { const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); - // 4x16 Accumulators + // 4xM Accumulators vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 4x16 integer accumulators + // 4xM integer accumulators vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -1079,14 +1245,14 @@ void ggml_gemm_q8_0_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int x = 0; x < nc / ncols_interleaved; x++) { const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); - // 4x16 Accumulators + // 4xM Accumulators vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 4x16 Integer Accumulators + // 4xM Integer Accumulators vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); @@ -1137,7 
+1303,7 @@ void ggml_gemm_q8_0_64x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } template -void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +static void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { assert(n % QK_K == 0); const int num_k_blocks = n / QK_K; const int N_ROWS_TILE = 4; @@ -1152,7 +1318,7 @@ void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { // Base Pointers const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks; - const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / ncols_interleaved) * num_k_blocks; + const block_q2_Kx* rhs_base_ptr = (const block_q2_Kx*)vx + (col_tile / ncols_interleaved) * num_k_blocks; // Persistent Float Accumulators vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); @@ -1164,7 +1330,7 @@ void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #pragma GCC unroll 1 for (int k_block = 0; k_block < num_k_blocks; ++k_block) { const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block]; - const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + const block_q2_Kx* rhs_current = &rhs_base_ptr[k_block]; // 1. Load Global Min Scales (Keep as F16/LMUL=1 to save registers) vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); @@ -1192,26 +1358,29 @@ void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo { vuint8mf2_t v_raw; // Sub-block 0 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr , vl); v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); // Sub-block 1 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + rhs_sc_ptr+=ncols_interleaved; + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr , vl); v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); // Sub-block 2 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + rhs_sc_ptr+=ncols_interleaved; + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, vl); v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); // Sub-block 3 - v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + rhs_sc_ptr+=ncols_interleaved; + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr, vl); v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + rhs_sc_ptr+=ncols_interleaved; - rhs_sc_ptr += 64; } int base_k_phase = (phase < 2) ? 
(phase * 16) : (128 + (phase-2)*16); @@ -1221,8 +1390,7 @@ void ggml_gemm_q2_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #pragma GCC unroll 1 for (int l = 0; l < 16; ++l) { vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); - rhs_qs_ptr += 16; - + rhs_qs_ptr+=ncols_interleaved; // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase) // --- Sub-block 0 --- @@ -1472,7 +1640,7 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q4_Kx * b_ptr = (const block_q4_Kx *) vx + (x * nb); - // 4x16 Accumulators + // 4xM Accumulators vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); @@ -1495,7 +1663,7 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); // High bits. - vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl], vl); + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl * 2], vl); vuint8m2_t scales_hi; vuint8m2_t mins_hi; if (!j) { @@ -1518,52 +1686,52 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], - 
__riscv_vget_v_i16m4_i16m1(mins, 2), 16); + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], - __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); @@ -1581,8 +1749,12 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo // Accumulation for 2 sub-blocks. - { - // 4x8 integer accumulators + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4xM integer accumulators vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -1592,12 +1764,9 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], 16); + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved)); @@ -1637,8 +1806,13 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), sumi_3_s_1_16, ncols_interleaved); } - { - // 4x16 integer accumulators + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. 
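+                    // Rough bound for the two-step split: a 4-bit weight times an int8 activation
+                    // is at most 15 * 128 = 1920, so 16 products per step stay within +/-30720 and
+                    // fit in int16; a single 32-iteration pass could exceed INT16_MAX.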
+ for (int k = 0; k < 2; k++) { + // 4xM integer accumulators vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); @@ -1648,10 +1822,7 @@ void ggml_gemm_q4_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); - // This might overflow. - // - // Recheck. - for (int i = 0; i < QK4_0; i++) { + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { // Load `b_ptr`. const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, ncols_interleaved)); @@ -1728,6 +1899,320 @@ void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_gemm_q4_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); } +template +void ggml_gemm_q5_K_Mx1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx * b_ptr = (const block_q5_Kx *) vx + (x * nb); + + // 4xM Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + // We process 4 sub-blocks at once. + const int vl = ncols_interleaved * 4; + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * vl], vl); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, vl); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, vl); + + // High bits. 
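+                    // Assumed repack layout (mirrors the q4_K path above): the first 2*vl scale
+                    // bytes hold the low 4 bits of the scales (low nibble) and mins (high nibble);
+                    // the bytes at offset vl * 2 hold the top 2 bits, taken from different bit
+                    // positions for j == 0 and j == 1.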
+ vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[vl * 2], vl); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, vl), 4, vl); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, vl), 2, vl); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, vl); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, vl), 2, vl); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, vl), vl); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, vl), vl)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, ncols_interleaved); + + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 0), ncols_interleaved); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 1), ncols_interleaved); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 2), ncols_interleaved); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], + 
__riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 3), ncols_interleaved); + + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, ncols_interleaved), ncols_interleaved), a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, ncols_interleaved), ncols_interleaved), ncols_interleaved); + sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, ncols_interleaved), ncols_interleaved), ncols_interleaved); + + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in 4 steps. + // + // Recheck. + for (int k = 0; k < 4; k++) { + // 4xM integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + for (int i = k * 8; i < (k + 1) * 8; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_lo_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_lo_packed, 4, ncols_interleaved)); + + // Load high bits and merge with low bits. 
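+                            // Each `qh` byte carries one high bit per sub-block; bits (j*4 + 0) and
+                            // (j*4 + 1) select this pair, and the masked add promotes a set bit to
+                            // +16 on the 4-bit low value.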
+ const vuint8mf2_t b_hi_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[i * ncols_interleaved], ncols_interleaved); + const vbool16_t b_hi_0_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 0), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vadd_vx_i8mf2_mu(b_hi_0_mask, b_s_lo_0, b_s_lo_0, 16, ncols_interleaved); + const vbool16_t b_hi_1_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 1), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_1 = __riscv_vadd_vx_i8mf2_mu(b_hi_1_mask, b_s_lo_1, b_s_lo_1, 16, ncols_interleaved); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, ncols_interleaved); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, ncols_interleaved); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_0_s_0_16, ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_1_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_1_s_1_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_2_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_3_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_3_s_1_16, ncols_interleaved); + } + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. 
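+                        // Rough bound for the four-step split: a 5-bit value times an int8
+                        // activation is at most 31 * 128 = 3968, so 8 products per step stay within
+                        // +/-31744 and fit in int16; 16 products per step could overflow.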
+ for (int k = 0; k < 4; k++) { + // 4xM integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, ncols_interleaved); + + for (int i = k * 8; i < (k + 1) * 8; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_lo_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 64 * ncols_interleaved + 32 * ncols_interleaved + i * ncols_interleaved], ncols_interleaved); + const vint8mf2_t b_s_lo_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_lo_packed, 0xF, ncols_interleaved)); + const vint8mf2_t b_s_lo_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_lo_packed, 4, ncols_interleaved)); + + // Load high bits and merge with low bits. + const vuint8mf2_t b_hi_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qh[i * ncols_interleaved], ncols_interleaved); + const vbool16_t b_hi_0_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 2), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_0 = __riscv_vadd_vx_i8mf2_mu(b_hi_0_mask, b_s_lo_0, b_s_lo_0, 16, ncols_interleaved); + const vbool16_t b_hi_1_mask = __riscv_vmsne_vx_u8mf2_b16(__riscv_vand_vx_u8mf2(b_hi_packed, 1 << (j*4 + 3), ncols_interleaved), 0, ncols_interleaved); + const vint8mf2_t b_s_1 = __riscv_vadd_vx_i8mf2_mu(b_hi_1_mask, b_s_lo_1, b_s_lo_1, 16, ncols_interleaved); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, ncols_interleaved); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, ncols_interleaved); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, ncols_interleaved); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, ncols_interleaved); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, ncols_interleaved); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, ncols_interleaved); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, ncols_interleaved); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, ncols_interleaved); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_0_s_0_16, ncols_interleaved); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_0_s_1_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_1_s_0_16, ncols_interleaved); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_1_s_1_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + 
__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_2_s_0_16, ncols_interleaved); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_2_s_1_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_3_s_0_16, ncols_interleaved); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_3_s_1_16, ncols_interleaved); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, ncols_interleaved), ncols_interleaved); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], ncols_interleaved); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], ncols_interleaved); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], ncols_interleaved); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, ncols_interleaved), d_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, ncols_interleaved), d_1, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, ncols_interleaved), d_2, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, ncols_interleaved), d_3, ncols_interleaved); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * ncols_interleaved, sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * ncols_interleaved, sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * ncols_interleaved, sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * ncols_interleaved, sumf_3, ncols_interleaved); + } + } +} + +void ggml_gemm_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q5_K_Mx1_q8_K<8>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q5_K_8x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q5_K_Mx1_q8_K<16>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q5_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q5_K_Mx1_q8_K<32>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q5_K_32x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} +void ggml_gemm_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_zvfh + ggml_gemm_q5_K_Mx1_q8_K<64>(n, s, bs, vx, vy, nr, nc); +#else + ggml_gemm_q5_K_64x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + template void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -1749,21 +2234,21 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, 
const UNUSED(blocklen); #if defined __riscv_v_intrinsic - const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const vint8m1_t values = __riscv_vle8_v_i8m1(kvalues_iq4nl, 16); for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { const block_iq4_nlx * b_ptr = (const block_iq4_nlx *) vx + (x * nb); - // 4x16 Accumulators + // 4xM Accumulators vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, ncols_interleaved); for (int l = 0; l < nb; l++) { - // 4x16 integer accumulators + // 4xM integer accumulators vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, ncols_interleaved); @@ -1772,21 +2257,21 @@ void ggml_gemm_iq4_nl_Mx1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const // Accumulation loop. for (int i = 0; i < QK4_NL / 2; i++) { // Load `b_ptr`. - const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); - const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); - const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, ncols_interleaved), ncols_interleaved); + const vuint8m1_t b_0_packed = __riscv_vle8_v_u8m1((const uint8_t *)&b_ptr[l].qs[i * ncols_interleaved], ncols_interleaved); + const vint8m1_t b_0_lo = __riscv_vrgather_vv_i8m1(values, __riscv_vand_vx_u8m1(b_0_packed, 0xf, ncols_interleaved), ncols_interleaved); + const vint8m1_t b_0_hi = __riscv_vrgather_vv_i8m1(values, __riscv_vsrl_vx_u8m1(b_0_packed, 4, ncols_interleaved), ncols_interleaved); // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); - const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], ncols_interleaved); - const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], ncols_interleaved); - const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], ncols_interleaved); - const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], ncols_interleaved); + const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4], ncols_interleaved); + const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4 + 1], ncols_interleaved); + const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4 + 2], ncols_interleaved); + const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_lo), a_ptr[l].qs[i * 4 + 3], ncols_interleaved); - const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], ncols_interleaved); - const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], ncols_interleaved); - const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], ncols_interleaved); - const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + 
i * 4 + 3], ncols_interleaved); + const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[64 + i * 4], ncols_interleaved); + const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[64 + i * 4 + 1], ncols_interleaved); + const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[64 + i * 4 + 2], ncols_interleaved); + const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(__riscv_vlmul_trunc_v_i8m1_i8mf2(b_0_hi), a_ptr[l].qs[64 + i * 4 + 3], ncols_interleaved); sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, ncols_interleaved), ncols_interleaved); sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, ncols_interleaved), ncols_interleaved); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index eaf88f174a..ab0c9bda07 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -361,6 +361,7 @@ static inline void ggml_gemv_q4_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT const int nb = n / qk; const int blocklen = 1; + assert(nr == 1); assert (n % qk == 0); assert (nc % ncols_interleaved == 0); @@ -374,7 +375,7 @@ static inline void ggml_gemv_q4_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT UNUSED(ncols_interleaved); UNUSED(blocklen); - float sumf[16]; + float sumf[ncols_interleaved]; int sumi; const block_q8_0 * a_ptr = (const block_q8_0 *) vy; @@ -412,7 +413,7 @@ static inline void ggml_gemv_q8_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT UNUSED(bs); UNUSED(nr); - float sumf[16]; + float sumf[ncols_interleaved]; int sumi; const block_q8_0 * a_ptr = (const block_q8_0 *) vy; @@ -462,12 +463,12 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT const block_q2_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; const block_q8_K * y_ptr = y; - float sumf[16] = {0}; + float sumf[ncols_interleaved] = {0}; // Loop over K-blocks for (int k_block = 0; k_block < nb; ++k_block) { - int32_t isum[16] = {0}; - int32_t summs[16] = {0}; + int32_t isum[ncols_interleaved] = {0}; + int32_t summs[ncols_interleaved] = {0}; const uint8_t * qs_rhs = x_ptr[k_block].qs; const uint8_t * sc_rhs = x_ptr[k_block].scales; @@ -478,9 +479,9 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT for (int sb = 0; sb < 16; ++sb) { // Correction Term int16_t bsum = bs_lhs[sb]; - int scale_offset = sb_perm[sb] * 16; + int scale_offset = sb_perm[sb] * ncols_interleaved; - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { uint8_t sc_val = sc_rhs[scale_offset + col]; summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits } @@ -493,14 +494,14 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT int shift = ((sb / 2) % 4) * 2; - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { uint8_t sc_val = sc_rhs[scale_offset + col]; int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits // Process 16 elements (l=0..15) for (int l = 0; l < 16; ++l) { // Q2: Interleaved by column. Byte `l` contains 4 k-values. 
- int qs_idx = (byte_base + l) * 16 + col; + int qs_idx = (byte_base + l) * ncols_interleaved + col; uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; // Q8: Linear access @@ -513,7 +514,7 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } // Finalize K-Block - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { float d_lhs = y_ptr[k_block].d; float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); @@ -525,7 +526,7 @@ static inline void ggml_gemv_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { s[col_tile + col] = sumf[col]; } } @@ -536,8 +537,11 @@ static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT const int qk = QK_K; const int nb = n / qk; const int blocklen = 1; + + assert(nr == 1); assert (n % qk == 0); assert (nc % ncols_interleaved == 0); + UNUSED(s); UNUSED(bs); UNUSED(vx); @@ -547,6 +551,7 @@ static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT UNUSED(nb); UNUSED(ncols_interleaved); UNUSED(blocklen); + float sumf[ncols_interleaved]; float sum_minf[ncols_interleaved]; uint8_t scales[ncols_interleaved * 8]; @@ -604,6 +609,85 @@ static inline void ggml_gemv_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +static inline void ggml_gemv_q5_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert(nr == 1); + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[ncols_interleaved]; + float sum_minf[ncols_interleaved]; + uint8_t scales[ncols_interleaved * 8]; + uint8_t mins[ncols_interleaved * 8]; + int sumi1; + int sumi2; + int sumi; + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx * b_ptr = (const block_q5_Kx *) vx + (x * nb); + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + sum_minf[j] = 0.0f; + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < ncols_interleaved * 8; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < ncols_interleaved * 4; i++) { + scales[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x0C) << 2; + scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); + mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * ncols_interleaved]; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * ncols_interleaved]; + uint8_t *scales_1 = &scales[(sb + 1) * ncols_interleaved]; + for (int i = 0; i < QK4_0; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + int v0 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * 
ncols_interleaved + j] & 0xF); + int v1 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + if (b_ptr[l].qh[i * ncols_interleaved + j] & (1 << (sb + 0))) { v0 += 16; } + if (b_ptr[l].qh[i * ncols_interleaved + j] & (1 << (sb + 1))) { v1 += 16; } + sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); + sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + template static inline void ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -617,7 +701,7 @@ static inline void ggml_gemv_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC UNUSED(bs); UNUSED(nr); - float sumf[16]; + float sumf[ncols_interleaved]; int sumi; const block_q8_0 * a_ptr = (const block_q8_0 *) vy; @@ -714,7 +798,7 @@ static inline void ggml_gemm_q8_0_Mx1_q8_0_generic(int n, float * GGML_RESTRICT for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block<4, ncols_interleaved> * b_ptr = (const block<4, ncols_interleaved> *) vx + (x * nb); + const block<8, ncols_interleaved> * b_ptr = (const block<8, ncols_interleaved> *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0; @@ -750,7 +834,7 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT assert(nr % 4 == 0); assert(nc % 16 == 0); const int nb = n / QK_K; - const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q2_Kx * x = (const block_q2_Kx *)vx; const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; const int sb_perm[16] = { @@ -761,17 +845,17 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT // Iterate Rows in tiles of 4 for (int row_tile = 0; row_tile < nr; row_tile += 4) { // Iterate Columns in tiles of 16 - for (int col_tile = 0; col_tile < nc; col_tile += 16) { + for (int col_tile = 0; col_tile < nc; col_tile += ncols_interleaved) { - const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q2_Kx * x_ptr = x + (col_tile / ncols_interleaved) * nb; const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; - float sumf[4][16]; + float sumf[4][ncols_interleaved]; memset(sumf, 0, sizeof(sumf)); for (int k_block = 0; k_block < nb; ++k_block) { - int32_t isum[4][16]; - int32_t summs[4][16]; + int32_t isum[4][ncols_interleaved]; + int32_t summs[4][ncols_interleaved]; memset(isum, 0, sizeof(isum)); memset(summs, 0, sizeof(summs)); @@ -781,14 +865,14 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT const int16_t * bs_lhs = y_ptr[k_block].bsums; for (int sb = 0; sb < 16; ++sb) { - int scale_offset = sb_perm[sb] * 16; + int scale_offset = sb_perm[sb] * ncols_interleaved; int byte_base; if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; else byte_base = (sb % 2 == 0) ? 
32 : 48; int shift = ((sb / 2) % 4) * 2; - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { uint8_t sc_val = sc_rhs[scale_offset + col]; int32_t d_sb = sc_val & 0xF; int32_t m_sb = sc_val >> 4; @@ -801,7 +885,7 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT // Main Dot Product for (int l = 0; l < 16; ++l) { - int qs_idx = (byte_base + l) * 16 + col; + int qs_idx = (byte_base + l) * ncols_interleaved + col; uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; // Calculate Q8 index for this specific k and row @@ -818,7 +902,7 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } // Finalize K-Block - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); @@ -832,7 +916,7 @@ static inline void ggml_gemm_q2_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } for (int r = 0; r < 4; ++r) { - for (int col = 0; col < 16; ++col) { + for (int col = 0; col < ncols_interleaved; ++col) { s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; } } @@ -884,8 +968,8 @@ static inline void ggml_gemm_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT mins[i] = b_ptr[l].scales[i] >> 4; } for (int i = 0; i < ncols_interleaved * 4; i++) { - scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; - mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x0C) << 2; scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; } @@ -934,6 +1018,102 @@ static inline void ggml_gemm_q4_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT } } +template +static inline void ggml_gemm_q5_K_Mx1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][ncols_interleaved]; + float sum_minf[4][ncols_interleaved]; + uint8_t scales[8 * ncols_interleaved]; + uint8_t mins[8 * ncols_interleaved]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx * b_ptr = (const block_q5_Kx *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < ncols_interleaved * 8; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < ncols_interleaved * 4; i++) { + scales[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x0C) << 2; + scales[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0x30); + mins[i + ncols_interleaved * 4] |= (b_ptr[l].scales[ncols_interleaved * 8 + i] & 0xC0) >> 2; + } + + for (int 
sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * ncols_interleaved]; + for(int m = 0; m < 4; m++) { + const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * ncols_interleaved]; + uint8_t *scales_1 = &scales[(sb + 1) * ncols_interleaved]; + + for (int i = 0; i < QK4_0; i++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + + int v0 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] & 0xF); + int v1 = (int8_t) (b_ptr[l].qs[sb * 16 * ncols_interleaved + i * ncols_interleaved + j] >> 4); + if (b_ptr[l].qh[i * ncols_interleaved + j] & (1 << (sb + 0))) { v0 += 16; } + if (b_ptr[l].qh[i * ncols_interleaved + j] & (1 << (sb + 1))) { v1 += 16; } + sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); + sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + template static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -944,7 +1124,7 @@ static inline void ggml_gemm_iq4_nl_Mx1_q8_0_generic(int n, float * GGML_RESTRIC assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); - float sumf[4][16]; + float sumf[4][ncols_interleaved]; int sumi; for (int y = 0; y < nr / 4; y++) { @@ -1340,11 +1520,6 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } - -void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); -} - void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } @@ -1599,7 +1774,7 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, #if defined __riscv_zvfh // Q4_0 void ggml_gemv_q4_0_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - ggml_gemm_q4_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); + ggml_gemv_q4_0_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_q4_0_Mx1_q8_0_generic<16>(n, s, bs, vx, vy, nr, nc); @@ -1653,6 +1828,20 @@ void ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemv_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q5_K +void ggml_gemv_q5_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int 
nc) { + ggml_gemv_q5_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemv_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -2454,6 +2643,20 @@ void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, ggml_gemm_q4_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); } +// Q5_K +void ggml_gemm_q5_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_Mx1_q8_K_generic<8>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_Mx1_q8_K_generic<16>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_Mx1_q8_K_generic<32>(n, s, bs, vx, vy, nr, nc); +} +void ggml_gemm_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_Mx1_q8_K_generic<64>(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_iq4_nl_Mx1_q8_0_generic<8>(n, s, bs, vx, vy, nr, nc); @@ -2900,6 +3103,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } + static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, @@ -3255,13 +3459,12 @@ static block_q2_Kx make_block_q2_KxMx1(const block_q2_K * in) block_q2_Kx out; constexpr int N_COLS = nrows_interleaved; - // 1. Copy Super-Scales (d) and Super-Mins (dmin) for (int i = 0; i < N_COLS; i++) { out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; } - // 2. Interleave Q2_K Data + // Interleave Q2_K Data const int bytes_per_col = 64; const int total_bytes = N_COLS * bytes_per_col; const int end = total_bytes; @@ -3273,7 +3476,7 @@ static block_q2_Kx make_block_q2_KxMx1(const block_q2_K * in) memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], 1); } - // 3. 
Repack Scales into the Optimized "Sequential-Parallel" Layout + // Repack Scales into the Optimized "Sequential-Parallel" Layout int out_idx = 0; // Arrays define the sub-block order for each group @@ -3333,7 +3536,7 @@ static int repack_q2_K_to_q2_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - // This loop gathers 16 separate blocks (one from each column) + // This loop gathers 16 separate blocks (one from each row of the transposed matrix) // that correspond to the same K-dimension chunk. for (int i = 0; i < nrows_interleaved; i++ ) { dst_tmp[i] = src[x + i * nblocks]; } @@ -3351,7 +3554,6 @@ static int repack_q2_K_to_q2_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ template static block_q4_Kx make_block_q4_KxMx1(block_q4_K * in) { block_q4_Kx out; - //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure for (int i = 0; i < nrows_interleaved; i++) { out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; } @@ -3427,6 +3629,94 @@ static int repack_q4_K_to_q4_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_ GGML_UNUSED(data_size); } +template +static block_q5_Kx make_block_q5_KxMx1(block_q5_K * in) { + block_q5_Kx out; + for (int i = 0; i < nrows_interleaved; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < nrows_interleaved; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end_ls = QK_K * nrows_interleaved / 2; + + for (int i = 0; i < end_ls; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + const int end_hs = 32 * nrows_interleaved; + + for (int i = 0; i < end_hs; ++i) { + int src_id = i % nrows_interleaved; + int src_offset = i / nrows_interleaved; + int dst_offset = i; + + out.qh[dst_offset] = in[src_id].qh[src_offset]; + } + + // RVV repacking. + // + // Extract scales and mins for all 8 sub-blocks for each block of Q5_K. 
+ uint8_t s[8 * nrows_interleaved], m[8 * nrows_interleaved]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + s[i * nrows_interleaved + j] = in[j].scales[i] & 63; + m[i * nrows_interleaved + j] = in[j].scales[i + 4] & 63; + } + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + s[nrows_interleaved * 8 / 2 + i * nrows_interleaved + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[nrows_interleaved * 8 / 2 + i * nrows_interleaved + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + + for (int i = 0; i < 8 * nrows_interleaved; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 8 * nrows_interleaved / 2; i++) { + out.scales[nrows_interleaved * 8 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[nrows_interleaved * 8 / 2 + i] & 48) | ((m[nrows_interleaved * 8 / 2 + i] & 48) << 2); + } + + return out; +} + +template +static int repack_q5_K_to_q5_K_Mx1_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + + block_q5_Kx * dst = (block_q5_Kx*)t->data; + const block_q5_K * src = (const block_q5_K*) data; + block_q5_K dst_tmp[nrows_interleaved]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q5_KxMx1(dst_tmp); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + template static block_iq4_nlx make_block_iq4_nlxMx1(block_iq4_nl * in) { block_iq4_nlx out; @@ -3727,6 +4017,20 @@ template <> int repack(struct ggml_tensor * t, const void * d return repack_q4_K_to_q4_K_Mx1_bl<64>(t, data, data_size); } +// Q5_K +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_Mx1_bl<8>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_Mx1_bl<16>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_Mx1_bl<32>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_Mx1_bl<64>(t, data, data_size); +} + // IQ4_NL template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_Mx1_bl<8>(t, data, data_size); @@ -3874,6 +4178,20 @@ template <> void gemv(int n, float * s, size_ ggml_gemv_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q5_K +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + 
ggml_gemv_q5_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4021,6 +4339,20 @@ template <> void gemm(int n, float * s, size_ ggml_gemm_q4_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); } +// Q5_K +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_8x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_32x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_64x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + // IQ4_NL template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_8x1_q8_0(n, s, bs, vx, vy, nr, nc); @@ -4473,6 +4805,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_32x1_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_64x1_q8_K; + // Q5_K + static const ggml::cpu::repack::tensor_traits q5_K_8x1_q8_K; + static const ggml::cpu::repack::tensor_traits q5_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits q5_K_32x1_q8_K; + static const ggml::cpu::repack::tensor_traits q5_K_64x1_q8_K; + // IQ4_NL static const ggml::cpu::repack::tensor_traits iq4_nl_8x1_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; @@ -4499,7 +4837,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &q4_0_8x1_q8_0; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &q4_0_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &q4_0_32x1_q8_0; } break; } case 1024: { if (cur->ne[1] % 64 == 0) { return &q4_0_64x1_q8_0; } break; } @@ -4526,7 +4864,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &q4_K_8x1_q8_K; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &q4_K_8x1_q8_K; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &q4_K_32x1_q8_K; } break; } case 1024: { if (cur->ne[1] % 64 == 0) { return &q4_K_64x1_q8_K; } break; } @@ -4543,10 +4881,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &q2_K_8x1_q8_K; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &q2_K_8x1_q8_K; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &q2_K_32x1_q8_K; } break; } - case 1024: { if (cur->ne[1] % 64 == 0) { return &q2_K_64x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return 
&q2_K_64x1_q8_K; } break; } default: { return nullptr; } } #endif @@ -4561,6 +4899,16 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (cur->ne[1] % 8 == 0) { return &q5_K_8x4_q8_K; } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { if (cur->ne[1] % 8 == 0) { return &q5_K_8x1_q8_K; } break; } + case 256: { if (cur->ne[1] % 16 == 0) { return &q5_K_16x1_q8_K; } break; } + case 512: { if (cur->ne[1] % 32 == 0) { return &q5_K_32x1_q8_K; } break; } + case 1024: { if (cur->ne[1] % 64 == 0) { return &q5_K_64x1_q8_K; } break; } + default: { return nullptr; } + } + #endif } } else if (cur->type == GGML_TYPE_Q6_K) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { @@ -4587,7 +4935,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x1_q8_0; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &iq4_nl_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &iq4_nl_32x1_q8_0; } break; } case 1024: { if (cur->ne[1] % 64 == 0) { return &iq4_nl_64x1_q8_0; } break; } @@ -4620,7 +4968,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh switch (__riscv_vlenb() * 8) { - case 128: { if (cur->ne[1] % 8 == 0) { return &q8_0_8x1_q8_0; } break; } + case 128: { if (cur->ne[1] % 8 == 0) { return &q8_0_8x1_q8_0; } break; } case 256: { if (cur->ne[1] % 16 == 0) { return &q8_0_16x1_q8_0; } break; } case 512: { if (cur->ne[1] % 32 == 0) { return &q8_0_32x1_q8_0; } break; } case 1024: { if (cur->ne[1] % 64 == 0) { return &q8_0_64x1_q8_0; } break; } diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 0d97c9b9b5..2a0b1d42a4 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -82,33 +82,23 @@ static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, static_assert(sizeof(block_q2_Kx32) == sizeof(ggml_half) * 64 + QK_K * 2 + QK_K * 8, "wrong q2_K block size/padding"); static_assert(sizeof(block_q2_Kx64) == sizeof(ggml_half) * 128 + QK_K * 4 + QK_K * 16, "wrong q2_K block size/padding"); -struct block_q5_Kx8 { - ggml_half d[8]; // super-block scale for quantized scales - ggml_half dmin[8]; // super-block scale for quantized mins - uint8_t scales[96]; // scales and mins, quantized with 6 bits - uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants - uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4) +template struct block_q5_Kx { + ggml_half d[N]; // super-block scale for quantized scales + ggml_half dmin[N]; // super-block scale for quantized mins + uint8_t scales[12 * N]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K * N / 8]; // high bits of 5-bit quants + uint8_t qs[QK_K * N / 2]; // low bits of 5-bit quants (in groups of 4) }; -static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, - "wrong q5_K block size/padding"); +using block_q5_Kx8 = block_q5_Kx<8>; +using block_q5_Kx16 = block_q5_Kx<16>; +using block_q5_Kx32 = block_q5_Kx<32>; +using block_q5_Kx64 = block_q5_Kx<64>; -template struct block_q6_Kx { - ggml_half d[N]; - int8_t scales[QK_K / 16 * N]; - uint8_t ql[QK_K / 2 * N]; // low bits of 6-bit quants (groups of 2) - 
uint8_t qh[QK_K / 4 * N]; // high bits of 6-bit quants (groups of 4) -}; - -using block_q6_Kx8 = block_q6_Kx<8>; -using block_q6_Kx16 = block_q6_Kx<16>; -using block_q6_Kx32 = block_q6_Kx<32>; -using block_q6_Kx64 = block_q6_Kx<64>; - -static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, "wrong q6_K block size/padding"); -static_assert(sizeof(block_q6_Kx16) == sizeof(ggml_half) * 16 + QK_K / 16 * 16 + 3 * QK_K / 4 * 16, "wrong q6_K block size/padding"); -static_assert(sizeof(block_q6_Kx32) == sizeof(ggml_half) * 32 + QK_K / 16 * 32 + 3 * QK_K / 4 * 32, "wrong q6_K block size/padding"); -static_assert(sizeof(block_q6_Kx64) == sizeof(ggml_half) * 64 + QK_K / 16 * 64 + 3 * QK_K / 4 * 64, "wrong q6_K block size/padding"); +static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 10, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_Kx32) == sizeof(ggml_half) * 64 + K_SCALE_SIZE * 32 + QK_K * 20, "wrong q5_K block size/padding"); +static_assert(sizeof(block_q5_Kx64) == sizeof(ggml_half) * 128 + K_SCALE_SIZE * 64 + QK_K * 40, "wrong q5_K block size/padding"); struct block_q8_Kx4 { float d[4]; // delta @@ -119,7 +109,7 @@ struct block_q8_Kx4 { static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding"); template struct block_iq4_nlx { - ggml_half d[N]; // deltas for `N` iq4_nl blocks + ggml_half d[N]; // deltas for `N` iq4_nl blocks uint8_t qs[QK4_NL * N / 2]; // nibbles / quants for N iq4_nl blocks }; @@ -189,7 +179,6 @@ void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined __riscv_zvfh void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); -void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -206,6 +195,10 @@ void ggml_gemv_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_32x1_q8_K(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -226,6 +219,10 @@ void ggml_gemm_q4_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_32x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_64x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -288,6 +285,10 @@ void ggml_gemv_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT 
vy, int nr, int nc); void ggml_gemv_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -308,6 +309,10 @@ void ggml_gemm_q4_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_32x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_64x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_32x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
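
For reference, the sizes asserted for the new block_q5_Kx<N> layouts above fall straight out of the field widths: 2-byte d[N] and dmin[N], K_SCALE_SIZE * N repacked scale/min bytes, QK_K * N / 8 high-bit bytes and QK_K * N / 2 low-nibble bytes, i.e. 176 * N bytes per interleaved super-block with QK_K = 256 and K_SCALE_SIZE = 12. A minimal standalone sketch of that arithmetic, using a hypothetical mirror struct rather than the real ggml headers:

#include <cstdint>
#include <cstdio>

// Assumed values, matching ggml-common.h.
constexpr int QK_K         = 256;
constexpr int K_SCALE_SIZE = 12;
using ggml_half_mirror = uint16_t; // ggml_half is a 2-byte type

// Hypothetical mirror of the templated block_q5_Kx<N> from repack.h,
// only here to make the size bookkeeping visible.
template <int N> struct block_q5_Kx_mirror {
    ggml_half_mirror d[N];                     //   2*N bytes: super-block scales
    ggml_half_mirror dmin[N];                  //   2*N bytes: super-block mins
    uint8_t          scales[K_SCALE_SIZE * N]; //  12*N bytes: 6-bit scales/mins, repacked
    uint8_t          qh[QK_K * N / 8];         //  32*N bytes: high bits of the 5-bit quants
    uint8_t          qs[QK_K * N / 2];         // 128*N bytes: low nibbles of the 5-bit quants
};

// 4*N + 12*N + 32*N + 128*N = 176*N, with no padding (2-byte alignment, even size).
static_assert(sizeof(block_q5_Kx_mirror<8>)  == 176 * 8,  "q5_Kx8 mirror size");
static_assert(sizeof(block_q5_Kx_mirror<64>) == 176 * 64, "q5_Kx64 mirror size");

int main() {
    std::printf("block_q5_Kx<16> mirror: %zu bytes\n", sizeof(block_q5_Kx_mirror<16>));
    return 0;
}

This agrees with the static_asserts added in repack.h, e.g. sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, which is 32 + 96 + 1280 = 1408 bytes = 176 * 8.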