From fbaa95bc29b9005586b76ab3d74f485eec05f282 Mon Sep 17 00:00:00 2001 From: rehan-10xengineer Date: Fri, 13 Mar 2026 20:36:04 +0500 Subject: [PATCH] ggml-cpu: add RVV vec dot kernels for quantization types (#18859) * ggml-cpu: add rvv quantize_row_q8_K kernel Co-authored-by: Rehan Qasim * ggml-cpu: add rvv vec_dot for iq4_nl, mxfp4, iq2_xxs Co-authored-by: Rehan Qasim * ggml-cpu: add rvv vec_dot for iq4_xs, refactor * ggml-cpu: remove ifunc for rvv vec dot * ggml-cpu: add vec_dot for iq2_xs, iq3_xxs Co-authored-by: Rehan Qasim * ggml-cpu: refactor quants.c --------- Co-authored-by: taimur-10x Co-authored-by: Rehan Qasim Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/arch-fallback.h | 7 - ggml/src/ggml-cpu/arch/riscv/quants.c | 1959 ++++++++++++++++++------- 2 files changed, 1421 insertions(+), 545 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 175aa4a4bb..41da829315 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -199,13 +199,6 @@ #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__riscv) // quants.c -#define quantize_row_q8_K_generic quantize_row_q8_K -#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K -#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K -#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K -#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 -#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K -#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index bf9f4df118..826055dd9a 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -113,6 +113,104 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i #endif } +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + block_q8_K * y_blocks = (block_q8_K *)y; + size_t nb = k / QK_K; + +#if defined(__riscv_v_intrinsic) + const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); + + for (size_t i = 0; i < nb; i++) { + const float* x_block = x + i * QK_K; + block_q8_K* y_block = &y_blocks[i]; + + // 1. Calculate Min/Max + vfloat32m8_t max_v = __riscv_vfmv_v_f_f32m8(-__builtin_inff(), vlmax_f32m8); + vfloat32m8_t min_v = __riscv_vfmv_v_f_f32m8(__builtin_inff(), vlmax_f32m8); + + size_t rem = QK_K; + size_t offset = 0; + while (rem > 0) { + size_t vl = __riscv_vsetvl_e32m8(rem); + vfloat32m8_t v_curr = __riscv_vle32_v_f32m8(x_block + offset, vl); + max_v = __riscv_vfmax_vv_f32m8(max_v, v_curr, vl); + min_v = __riscv_vfmin_vv_f32m8(min_v, v_curr, vl); + rem -= vl; + offset += vl; + } + + vfloat32m1_t v_init_max = __riscv_vfmv_s_f_f32m1(-__builtin_inff(), 1); + vfloat32m1_t v_init_min = __riscv_vfmv_s_f_f32m1(__builtin_inff(), 1); + + vfloat32m1_t v_scalar_max = __riscv_vfredmax_vs_f32m8_f32m1(max_v, v_init_max, vlmax_f32m8); + vfloat32m1_t v_scalar_min = __riscv_vfredmin_vs_f32m8_f32m1(min_v, v_init_min, vlmax_f32m8); + + float max_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_max); + float min_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_min); + + float amax = fabsf(max_val) > fabsf(min_val) ? fabsf(max_val) : fabsf(min_val); + + if (amax == 0.0f) { + y_block->d = 0.0f; + memset(y_block->qs, 0, QK_K); + memset(y_block->bsums, 0, sizeof(y_block->bsums)); + continue; + } + + const float iscale = -127.f / (fabsf(max_val) > fabsf(min_val) ? max_val : min_val); + y_block->d = 1.0f / iscale; + + // 2. Quantize and Calculate Sums + offset = 0; + rem = QK_K; + vint16m1_t v_zero_sum = __riscv_vmv_v_x_i16m1(0, 1); + + while (rem > 0) { + size_t vl = __riscv_vsetvl_e32m8(rem); + vfloat32m8_t v_f = __riscv_vle32_v_f32m8(x_block + offset, vl); + + v_f = __riscv_vfmul_vf_f32m8(v_f, iscale, vl); + + vint32m8_t v_i32 = __riscv_vfcvt_x_f_v_i32m8_rm(v_f, __RISCV_FRM_RNE, vl); + vint16m4_t v_i16 = __riscv_vnclip_wx_i16m4(v_i32, 0, __RISCV_VXRM_RNE, vl); + vint8m2_t v_q = __riscv_vnclip_wx_i8m2(v_i16, 0, __RISCV_VXRM_RNE, vl); + + __riscv_vse8_v_i8m2(y_block->qs + offset, v_q, vl); + + // first iteration clear + + int sum_idx; + vint8m1_t chunk_m1; + vint16m1_t v_sum; + sum_idx = offset / 16; + chunk_m1 = __riscv_vget_v_i8m2_i8m1(v_q, 0); + v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16); + y_block->bsums[sum_idx] = (int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum); + + // remaining iterations + vint8m2_t slid_q = v_q; + for (size_t k = 16; k < vl; k += 16) { + slid_q = __riscv_vslidedown_vx_i8m2(slid_q, 16, vl); + + sum_idx = (offset + k) / 16; + chunk_m1 = __riscv_vget_v_i8m2_i8m1(slid_q, 0); + + v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16); + y_block->bsums[sum_idx] =(int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum); + } + + rem -= vl; + offset += vl; + } + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_K_ref(x, y, k); +#endif +} + //===================================== Dot products ================================= void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { @@ -1954,544 +2052,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -static const uint8_t sign_gather_indices_arr[64] = { - 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, - 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 -}; - -static const uint8_t sign_bit_masks_arr[64] = { - 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, - 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 -}; - -static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); - - const block_iq2_s * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - const uint64_t * grid64 = (const uint64_t *)iq2s_grid; - - // --- Pre-load Constants --- - uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1}; - vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8); - uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5}; - vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8); - - // Constants for sign extraction - vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); - vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); - - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint8_t * GGML_RESTRICT scales = x[i].scales; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const uint8_t * signs_ptr = qs + 32; - - float sum_block = 0.0f; - - for (int ib = 0; ib < 4; ++ib) { - // Combine low + high bits - vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8); - qs += 8; - uint16_t qh_val; - memcpy(&qh_val, qh, 2); - qh += 2; - vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2); - vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2); - vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); - vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8); - v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8); - - // Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000 - v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8); - vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8); - - // Multiply by 8 to get byte offset, instead of element offset - v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8); - vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8); - - // Lookup Grid using Byte Offsets - vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8); - - vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); - vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); - - // Load signs and generate sign mask - vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8); - signs_ptr += 8; - - vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); - vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); - - vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); - vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); - - vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); - q8 += 64; - - vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); - vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64); - - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); - - int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16)); - int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16)); - int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16)); - int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16)); - - uint8_t sc0 = scales[0]; - uint8_t sc1 = scales[1]; - scales += 2; - - sum_block += s0 * (2 * (sc0 & 0xF) + 1); - sum_block += s1 * (2 * (sc0 >> 4) + 1); - sum_block += s2 * (2 * (sc1 & 0xF) + 1); - sum_block += s3 * (2 * (sc1 >> 4) + 1); - } - sumf += sum_block * combined_scale; - } - *s = 0.125f * sumf; -} - -static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); - - const block_iq2_s * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - const uint64_t * grid64 = (const uint64_t *)iq2s_grid; - - // Pre-load Constants - vuint8m2_t v_ids = __riscv_vid_v_u8m2(32); - vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32); - vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32); - vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32); - vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32); - uint16_t shift_qh_arr[4] = {11, 9, 7, 5}; - vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4); - - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint8_t * GGML_RESTRICT scales = x[i].scales; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const uint8_t * signs_ptr = qs + 32; - float sum_block = 0.0f; - - for (int ib = 0; ib < 8; ++ib) { - - // Load Low Bits [4 bytes] - vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4); - qs += 4; - - // Load 1 byte. It contains bits for 4 mini-blocks. - uint8_t qh_val = *qh++; - - // Combine Low + High bits of 10bit indices - vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4); - vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4); - vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4); - v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4); - vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4); - vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4); - vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4); - - // Lookup Grid - vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4))); - - vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4); - signs_ptr += 4; - vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); - vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32); - - // generating sign mask - vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32); - vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32); - - vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32); - q8 += 32; - - // apply signs - vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32); - vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32); - - // Reduction - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); - - // Reduce 0-15 (First Half) - int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( - __riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16)); - - // Reduce 16-31 (Second Half) - int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( - __riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16)); - - // Apply sub Scales - uint8_t sc = *scales++; - - sum_block += s0 * (2 * (sc & 0xF) + 1); - sum_block += s1 * (2 * (sc >> 4) + 1); - } - sumf += sum_block * combined_scale; - } - *s = 0.125f * sumf; -} - -void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic - switch (__riscv_vlenb() * 8) { - case 128: - ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); - break; - case 256: - ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); - break; - default: - ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); - break; - } -#else - ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq3_s * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - - const uint64_t * grid64 = (const uint64_t *)iq3s_grid; - - // --- Pre-load Constants --- - const uint16_t qh_bit_shifts_arr[16] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 - }; - vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); - vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); - vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16); - - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - const float d = GGML_CPU_FP16_TO_FP32(x[i].d); - const float combined_scale = d * y[i].d; - - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint8_t * GGML_RESTRICT scales = x[i].scales; - const uint8_t * GGML_RESTRICT signs = x[i].signs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - float sum_block = 0.0f; - - // Loop: Process 64 weights (16 mini-blocks of 4) per iteration - for (int ib = 0; ib < 4; ++ib) { - - vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16); - qs += 16; - - uint16_t qh_val; - memcpy(&qh_val, qh, 2); - qh += 2; - - vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16); - // Extract bits: (qh >> i) & 1 - v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16); - v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16); - - vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16); - v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16); - v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16); - vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16); - - // Grid value is 4xuint8 - vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16); - vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); - vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8); - signs += 8; - - // Generate sign mask - vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); - vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); - vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); - vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); - - vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); - q8 += 64; - - // Apply Signs - vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); - vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64); - - // Reduction - vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0); - vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1); - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); - - int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32)); - int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32)); - - // Apply sub-scales - uint8_t sc_byte = *scales++; - int sc_lo = (sc_byte & 0xF) * 2 + 1; - int sc_hi = (sc_byte >> 4) * 2 + 1; - - sum_block += s_lo * sc_lo + s_hi * sc_hi; - } - sumf += sum_block * combined_scale; - } - *s = 0.125f * sumf; -} - -void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic - switch (__riscv_vlenb() * 8) { - case 256: - ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); - break; - default: - ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); - break; - } -#else - ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_tq1_0 * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - - float sumf = 0.0f; - uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; - - for (int i = 0; i < nb; i++) { - // First loop. - vint32m4_t suml1; - { - const int vl = 32; - vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl); - - vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl); - vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl); - vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl); - vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl); - vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl); - - vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl); - vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl); - vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl); - vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl); - vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl); - - vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); - vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl); - vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl); - vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl); - vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl); - - vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl); - vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl); - suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl); - } - - // Second loop. - vint32m2_t suml2; - { - const int vl = 16; - vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl); - - vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl); - vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl); - vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl); - vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl); - vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl); - - vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl); - vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl); - vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl); - vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl); - vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl); - - vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); - vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl); - vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl); - vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); - vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); - - vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl); - vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl); - suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl); - } - - // Third loop. - vint32m2_t suml3; - { - const int vl = 16; - - uint32_t qh; - memcpy(&qh, &x[i].qh[0], 4); - // Prevent fusion with vmv. - __asm__ __volatile__("" : "+r"(qh)); - vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4)); - - vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl); - - vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl); - - vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl); - - vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); - suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl); - } - - vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16); - sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16); - sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16); - - vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); - sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - } - - *s = sumf; -} - -void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic - switch (__riscv_vlenb() * 8) { - case 256: - ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); - break; - default: - ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); - break; - } -#else - ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_tq2_0 * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - - float sumf = 0.0f; - for (int i = 0; i < nb; ++i) { - int32_t sumi = 0; - - for (size_t j = 0; j < sizeof(x[0].qs); j += 32) { - const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32]; - const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32]; - const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32]; - const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32]; - const uint8_t* px = &x[i].qs[j]; - - size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32); - vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2); - - size_t vl = __riscv_vsetvl_e8m1(32); - - vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl); - - vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0 , vl); - vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl); - vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl); - vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl); - - // l=0 (bits 1:0) - vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl); - vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl); - - // l=1 (bits 3:2) - vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl); - vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl); - - // l=2 (bits 5:4) - vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl); - vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl); - - // l=3 (bits 7:6) - vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros - vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl); - - // 4. Multiply and accumulate - vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl); - vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl); - vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl); - vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl); - - vlmax_16m2 = __riscv_vsetvl_e16m2(32); - vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1); - vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2); - - sumi += __riscv_vmv_x_s_i32m1_i32(vred32); - } - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - sumf += (float)sumi * d; - } - - *s = sumf; -} - -void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic - switch (__riscv_vlenb() * 8) { - case 256: - ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); - break; - default: - ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); - break; - } -#else - ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2724,3 +2284,1326 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } + +static const uint8_t sign_gather_indices_arr[64] = { + 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, + 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 +}; + +static const uint8_t sign_bit_masks_arr[64] = { + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 +}; + + +static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // Pre-load Constants + vuint8m2_t v_ids = __riscv_vid_v_u8m2(32); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32); + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32); + uint16_t shift_qh_arr[4] = {11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + float sum_block = 0.0f; + + for (int ib = 0; ib < 8; ++ib) { + + // Load Low Bits [4 bytes] + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4); + qs += 4; + + // Load 1 byte. It contains bits for 4 mini-blocks. + uint8_t qh_val = *qh++; + + // Combine Low + High bits of 10bit indices + vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4); + vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4); + vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4); + v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4); + vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4); + vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4); + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4); + + // Lookup Grid + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4))); + + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4); + signs_ptr += 4; + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32); + + // generating sign mask + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // apply signs + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32); + + // Reduction + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + // Reduce 0-15 (First Half) + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16)); + + // Reduce 16-31 (Second Half) + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16)); + + // Apply sub Scales + uint8_t sc = *scales++; + + sum_block += s0 * (2 * (sc & 0xF) + 1); + sum_block += s1 * (2 * (sc >> 4) + 1); + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // --- Pre-load Constants --- + uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8); + uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8); + + // Constants for sign extraction + vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); + vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 4; ++ib) { + // Combine low + high bits + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8); + qs += 8; + uint16_t qh_val; + memcpy(&qh_val, qh, 2); + qh += 2; + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2); + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8); + + // Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000 + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8); + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8); + + // Multiply by 8 to get byte offset, instead of element offset + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8); + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8); + + // Lookup Grid using Byte Offsets + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8); + + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + // Load signs and generate sign mask + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8); + signs_ptr += 8; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + scales += 2; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined(__riscv_v) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif + +static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + + int32_t sum_int = 0; + + // Loop over 4 subblocks of 64 elements (QK_K = 256) + for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { + // Load 8 uint16 indices (controls 64 values) + vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8); + qs += 8; + + // Extract indices for grid (low 9 bits) and signs (high 7 bits) + // Multiply by 8 (<< 3) for byte offsets into the uint64 tables + vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8); + vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8); + + vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8); + vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8); + + vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64)); + vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64)); + + vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64); + + vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16)); + int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16)); + int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16)); + int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16)); + + const uint8_t scale_byte_1 = scales[0]; + const uint8_t scale_byte_2 = scales[1]; + scales += 2; + + sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); + sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); + sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); + sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + } + + sumf += d * sum_int; + } + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + + uint32_t shift_constants[4] = {0, 7, 14, 21}; + vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shift_constants, 4); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum = 0.0f; + + #pragma GCC unroll 1 + for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) { + vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32; + vint8m2_t q8_2 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32; + + vuint8mf4_t v_raw_q2_1 = __riscv_vle8_v_u8mf4(q2_ptr, 4); + vuint8mf4_t v_raw_q2_2 = __riscv_vle8_v_u8mf4(q2_ptr + 8, 4); + + vuint16mf2_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_1, 4); + vuint16mf2_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_2, 4); + + vidx_q2_1 = __riscv_vsll_vx_u16mf2(vidx_q2_1, 3, 4); + vidx_q2_2 = __riscv_vsll_vx_u16mf2(vidx_q2_2, 3, 4); + + uint32_t s_packed_1, s_packed_2; + memcpy(&s_packed_1, q2_ptr + 4, 4); + memcpy(&s_packed_2, q2_ptr + 12, 4); + + vuint32m1_t v_s_1 = __riscv_vmv_v_x_u32m1(s_packed_1, 4); + vuint32m1_t v_s_2 = __riscv_vmv_v_x_u32m1(s_packed_2, 4); + v_s_1 = __riscv_vsrl_vv_u32m1(v_s_1, v_shifts, 4); + v_s_2 = __riscv_vsrl_vv_u32m1(v_s_2, v_shifts, 4); + + v_s_1 = __riscv_vand_vx_u32m1(v_s_1, 127, 4); + v_s_2 = __riscv_vand_vx_u32m1(v_s_2, 127, 4); + + vuint16mf2_t vidx_s2_1 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_1, 4), 3, 4); + vuint16mf2_t vidx_s2_2 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_2, 4), 3, 4); + + vuint64m2_t vq2_64_1 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_1, 4); + vuint64m2_t vq2_64_2 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_2, 4); + + vint8m2_t q2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_1)); + vint8m2_t q2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_2)); + + vuint64m2_t vs2_64_1 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_1, 4); + vuint64m2_t vs2_64_2 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_2, 4); + vint8m2_t s2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_1)); + vint8m2_t s2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_2)); + + vint8m2_t q8s_1 = __riscv_vmul_vv_i8m2(q8_1, s2_1, 32); + vint8m2_t q8s_2 = __riscv_vmul_vv_i8m2(q8_2, s2_2, 32); + + vint16m4_t dot1 = __riscv_vwmul_vv_i16m4(q8s_1, q2_1, 32); + vint16m4_t dot2 = __riscv_vwmul_vv_i16m4(q8s_2, q2_2, 32); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m4_i32m1(dot1, zero_vec, 32); + vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m4_i32m1(dot2, zero_vec, 32); + + int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1); + int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2); + + int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1; + int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1; + + sum += scalar_sum1 * scale1 + scalar_sum2 * scale2; + q2_ptr += 16; + } + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} + +static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + + uint32_t shift_constants[4] = {0, 7, 14, 21}; + vuint32mf2_t v_shifts = __riscv_vle32_v_u32mf2(shift_constants, 4); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum = 0.0f; + + for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) { + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32; + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32; + + vuint8mf8_t v_raw_q2_1 = __riscv_vle8_v_u8mf8(q2_ptr, 4); + vuint8mf8_t v_raw_q2_2 = __riscv_vle8_v_u8mf8(q2_ptr + 8, 4); + + vuint16mf4_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_1, 4); + vuint16mf4_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_2, 4); + + vidx_q2_1 = __riscv_vsll_vx_u16mf4(vidx_q2_1, 3, 4); + vidx_q2_2 = __riscv_vsll_vx_u16mf4(vidx_q2_2, 3, 4); + + uint32_t s_packed_1, s_packed_2; + memcpy(&s_packed_1, q2_ptr + 4, 4); + memcpy(&s_packed_2, q2_ptr + 12, 4); + + vuint32mf2_t v_s_1 = __riscv_vmv_v_x_u32mf2(s_packed_1, 4); + vuint32mf2_t v_s_2 = __riscv_vmv_v_x_u32mf2(s_packed_2, 4); + + v_s_1 = __riscv_vsrl_vv_u32mf2(v_s_1, v_shifts, 4); + v_s_2 = __riscv_vsrl_vv_u32mf2(v_s_2, v_shifts, 4); + + v_s_1 = __riscv_vand_vx_u32mf2(v_s_1, 127, 4); + v_s_2 = __riscv_vand_vx_u32mf2(v_s_2, 127, 4); + + // Narrow u32 -> u16 (vncvt) and Scale by 8 to get byte offsets + vuint16mf4_t vidx_s2_1 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_1, 4), 3, 4); + vuint16mf4_t vidx_s2_2 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_2, 4), 3, 4); + + // Load q2 values from lookup grid + vuint64m1_t vq2_64_1 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_1, 4); + vuint64m1_t vq2_64_2 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_2, 4); + vint8m1_t q2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_1)); + vint8m1_t q2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_2)); + + // Load sign values + vuint64m1_t vs2_64_1 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_1, 4); + vuint64m1_t vs2_64_2 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_2, 4); + vint8m1_t s2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_1)); + vint8m1_t s2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_2)); + + // Apply signs to q8 + vint8m1_t q8s_1 = __riscv_vmul_vv_i8m1(q8_1, s2_1, 32); + vint8m1_t q8s_2 = __riscv_vmul_vv_i8m1(q8_2, s2_2, 32); + + // multiplying q2 with q8 + vint16m2_t dot1 = __riscv_vwmul_vv_i16m2(q8s_1, q2_1, 32); + vint16m2_t dot2 = __riscv_vwmul_vv_i16m2(q8s_2, q2_2, 32); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m2_i32m1(dot1, zero_vec, 32); + vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m2_i32m1(dot2, zero_vec, 32); + int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1); + int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2); + int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1; + int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1; + + sum += scalar_sum1 * scale1 + scalar_sum2 * scale2; + q2_ptr += 16; + } + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const uint64_t * grid64 = (const uint64_t *)iq3s_grid; + + // --- Pre-load Constants --- + const uint16_t qh_bit_shifts_arr[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + }; + vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); + vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); + vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d); + const float combined_scale = d * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + + // Loop: Process 64 weights (16 mini-blocks of 4) per iteration + for (int ib = 0; ib < 4; ++ib) { + + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16); + qs += 16; + + uint16_t qh_val; + memcpy(&qh_val, qh, 2); + qh += 2; + + vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16); + // Extract bits: (qh >> i) & 1 + v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16); + v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16); + + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16); + v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16); + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16); + + // Grid value is 4xuint8 + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8); + signs += 8; + + // Generate sign mask + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + // Apply Signs + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64); + + // Reduction + vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0); + vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1); + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32)); + int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32)); + + // Apply sub-scales + uint8_t sc_byte = *scales++; + int sc_lo = (sc_byte & 0xF) * 2 + 1; + int sc_hi = (sc_byte >> 4) * 2 + 1; + + sum_block += s_lo * sc_lo + s_hi * sc_hi; + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} + +void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // constants for unpacking logic + const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21}; + vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shifts_val, 8); + + const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint32m1_t v_gather_idx = __riscv_vle32_v_u32m1(gather_idx_val, 8); + + uint32_t aux32[2]; + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + + for (int ib = 0; ib < QK_K / 64; ++ib) { + // Load q8 (64 bytes) + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + // load of metadata via memcpy + memcpy(aux32, metadata, 2 * sizeof(uint32_t)); + metadata += 2 * sizeof(uint32_t); + + // Load q3 indices and gather magnitudes + vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 16); + q3_indices += 16; + + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 16); + vuint32m2_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 16); + vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u32m2_u8m2(v_q3_magnitudes_u32)); + + // --- Unpacking of Sign Indices --- + + // 1. Load the 2 auxiliary 32-bit integers into a vector + vuint32m1_t v_aux = __riscv_vle32_v_u32m1(aux32, 2); + + // 2. Broadcast/Gather: replicate aux[0] to first 4 lanes, aux[1] to next 4 lanes + vuint32m1_t v_aux_expanded = __riscv_vrgather_vv_u32m1(v_aux, v_gather_idx, 8); + + // 3. Apply Shifts and Mask: ((val >> shift) & 127) + vuint32m1_t v_s_vals_raw = __riscv_vand_vx_u32m1(__riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 8), 127, 8); + + // 4. Narrow to u16 (required for vluxei index) and multiply by 8 (byte offset for u64 table) + vuint16mf2_t sign_indices_byte_offset = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_vals_raw, 8), 3, 8); + + // 5. Gather Signs + vuint64m2_t v_s_vals_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_indices_byte_offset, 8); + vint8m2_t v_s_vals = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(v_s_vals_u64)); + + vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_s_vals, 64); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_signed, 64); + + vint16m2_t v_dot_1 = __riscv_vget_v_i16m4_i16m2(v_dot, 0); + vint16m2_t v_dot_2 = __riscv_vget_v_i16m4_i16m2(v_dot, 1); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_1, v_zero, 32); + vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_2, v_zero, 32); + + int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1); + int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2); + + const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1); + const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1); + + block_sum += sum1_i * scale1_f + sum2_i * scale2_f; + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib += 2) { + // Weights and activations. + vuint8m1_t iq4_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16); + vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32); + vuint8m1_t iq4_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16); + vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); + + // Unpack the weight blocks. + vuint8m2_t iq4bits1; + iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 0, __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16)); + iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 1, __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16)); + vuint8m2_t iq4bits2; + iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 0, __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16)); + iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 1, __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16)); + + // Gather values from the lookup table. + vint8m2_t iq4b1 = __riscv_vrgather_vv_i8m2(values, iq4bits1, 32); + vint8m2_t iq4b2 = __riscv_vrgather_vv_i8m2(values, iq4bits2, 32); + + // Accumulation. + vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, iq4b1, 32); + vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, iq4b2, 32); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} + +static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib += 2) { + // Weights and activations. + vuint8mf2_t iq4_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16); + vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16); + vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16); + vuint8mf2_t iq4_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16); + vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16); + vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16); + + // Unpack the weight blocks. + vuint8mf2_t iq4bits_lo1 = __riscv_vand_vx_u8mf2(iq4_packed1, 0xf, 16); + vuint8mf2_t iq4bits_hi1 = __riscv_vsrl_vx_u8mf2(iq4_packed1, 4, 16); + vuint8mf2_t iq4bits_lo2 = __riscv_vand_vx_u8mf2(iq4_packed2, 0xf, 16); + vuint8mf2_t iq4bits_hi2 = __riscv_vsrl_vx_u8mf2(iq4_packed2, 4, 16); + + // Gather values from the lookup table. + vint8mf2_t iq4b_lo1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo1, 16); + vint8mf2_t iq4b_hi1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi1, 16); + vint8mf2_t iq4b_lo2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo2, 16); + vint8mf2_t iq4b_hi2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi2, 16); + + // Accumulation. + vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, iq4b_lo1, 16); + sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, iq4b_hi1, 16); + vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, iq4b_lo2, 16); + sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, iq4b_hi2, 16); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __riscv_v_intrinsic + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + int acc[4]; + + // Indices for re-ordering IQ4 data. + uint64_t index[16] = { + 0, 1, 8, 9, + 2, 3, 10, 11, + 4, 5,12, 13, + 6, 7, 14, 15, + }; + vuint64m4_t i_vec = __riscv_vle64_v_u64m4(index, 16); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + + for (int ib = 0; ib < QK_K / 128; ++ib) { + // Weights and activations. + vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 64); + vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); + iq4 += 64; + q8 += 128; + + // Unpack the weight blocks. + vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 64); + vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 64); + vuint8m4_t iq4bits; + iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 0, iq4bits_lo); + iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 1, iq4bits_hi); + vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgather_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16)); + vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 128); + + // Multiply with activations. + vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 128); + + // Reduce separately. + __riscv_vse32_v_i32m1(&acc[0],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc[1],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc[2],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc[3],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + + int ls1 = ((x[ibl].scales_l[ib * 2 + 0] & 0xf) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib * 2 + 0] >> 4) | ((h << 2) & 0x30)) - 32; + int ls3 = ((x[ibl].scales_l[ib * 2 + 1] & 0xf) | ((h << 0) & 0x30)) - 32; + int ls4 = ((x[ibl].scales_l[ib * 2 + 1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + + sumi1 += acc[0] * ls1; + sumi2 += acc[1] * ls2; + sumi3 += acc[2] * ls3; + sumi4 += acc[3] * ls4; + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2 + sumi3 + sumi4); + } + + *s = sumf; + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint32m4_t suml1; + { + const int vl = 32; + vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl); + vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl); + vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl); + vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl); + vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl); + + vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl); + vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl); + vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl); + vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl); + vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl); + + vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl); + vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl); + vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl); + vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl); + + vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl); + vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl); + } + + // Second loop. + vint32m2_t suml2; + { + const int vl = 16; + vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl); + vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl); + vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl); + vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl); + vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl); + vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl); + vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl); + vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl); + vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl); + vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl); + vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); + vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); + + vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl); + vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl); + } + + // Third loop. + vint32m2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4)); + + vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl); + } + + vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16); + sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16); + sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16); + + vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} + +void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x[0].qs); j += 32) { + const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32]; + const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32]; + const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32]; + const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32]; + const uint8_t* px = &x[i].qs[j]; + + size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32); + vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2); + + size_t vl = __riscv_vsetvl_e8m1(32); + + vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl); + + vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0 , vl); + vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl); + vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl); + vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl); + + // l=0 (bits 1:0) + vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl); + vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl); + + // l=1 (bits 3:2) + vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl); + vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl); + + // l=2 (bits 5:4) + vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl); + vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl); + + // l=3 (bits 7:6) + vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros + vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl); + + // 4. Multiply and accumulate + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl); + + vlmax_16m2 = __riscv_vsetvl_e16m2(32); + vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2); + + sumi += __riscv_vmv_x_s_i32m1_i32(vred32); + } + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + sumf += (float)sumi * d; + } + + *s = sumf; +} + +void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_mxfp4, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib += 2) { + // Weights and activations. + vuint8m1_t mx_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16); + vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32); + vuint8m1_t mx_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16); + vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); + + // Unpack the weight blocks. + vuint8m2_t mxbits1; + mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 0, __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16)); + mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 1, __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16)); + vuint8m2_t mxbits2; + mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 0, __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16)); + mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 1, __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16)); + + // Gather values from the lookup table. + vint8m2_t mxb1 = __riscv_vrgather_vv_i8m2(values, mxbits1, 32); + vint8m2_t mxb2 = __riscv_vrgather_vv_i8m2(values, mxbits2, 32); + + // Accumulation. + vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, mxb1, 32); + vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, mxb2, 32); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} + +static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_mxfp4, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib+=2) { + // Weights and activations. + vuint8mf2_t mx_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16); + vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16); + vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16); + vuint8mf2_t mx_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16); + vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16); + vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16); + + // Unpack the weight blocks. + vuint8mf2_t mxbits_lo1 = __riscv_vand_vx_u8mf2(mx_packed1, 0xf, 16); + vuint8mf2_t mxbits_hi1 = __riscv_vsrl_vx_u8mf2(mx_packed1, 4, 16); + vuint8mf2_t mxbits_lo2 = __riscv_vand_vx_u8mf2(mx_packed2, 0xf, 16); + vuint8mf2_t mxbits_hi2 = __riscv_vsrl_vx_u8mf2(mx_packed2, 4, 16); + + // Gather values from the lookup table. + vint8mf2_t mxb_lo1 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo1, 16); + vint8mf2_t mxb_hi1 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi1, 16); + vint8mf2_t mxb_lo2 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo2, 16); + vint8mf2_t mxb_hi2 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi2, 16); + + // Accumulation. + vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, mxb_lo1, 16); + sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, mxb_hi1, 16); + vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, mxb_lo2, 16); + sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, mxb_hi2, 16); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} + +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +}