ggml-cpu: add rvv vec_dot for iq3_s

Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai>
taimur-10x 2026-01-12 20:01:03 +05:00
parent e7083a7882
commit 1cf6b94c7c
1 changed file with 102 additions and 0 deletions


@@ -2174,3 +2174,105 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
    ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}

void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
    assert(n % QK_K == 0);

    UNUSED(nrc);
    UNUSED(bx);
    UNUSED(by);
    UNUSED(bs);

    const block_iq3_s * GGML_RESTRICT x = vx;
    const block_q8_K  * GGML_RESTRICT y = vy;

    const int nb = n / QK_K;

#if defined __riscv_v_intrinsic
    const uint64_t * grid64 = (const uint64_t *)iq3s_grid;
    // --- Pre-load Constants ---
    const uint16_t qh_bit_shifts_arr[16] = {
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
    };
    // bit (j % 8) of signs[j / 8] is the sign of weight j
    const uint8_t sign_gather_indices_arr[64] = {
        0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,
        4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7,
    };
    const uint8_t sign_bit_masks_arr[64] = {
        1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128,
        1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128,
    };
    vuint8m2_t  v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64);
    vuint8m2_t  v_sign_masks          = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64);
    vuint16m1_t v_qh_shifts           = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16);
    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        const float d = GGML_CPU_FP16_TO_FP32(x[i].d);
        const float combined_scale = d * y[i].d;

        const uint8_t * GGML_RESTRICT qs     = x[i].qs;
        const uint8_t * GGML_RESTRICT qh     = x[i].qh;
        const uint8_t * GGML_RESTRICT scales = x[i].scales;
        const uint8_t * GGML_RESTRICT signs  = x[i].signs;
        const int8_t  * GGML_RESTRICT q8     = y[i].qs;

        float sum_block = 0.0f;

        // Loop: process 64 weights (16 mini-blocks of 4) per iteration
        for (int ib = 0; ib < 4; ++ib) {
            vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16);
            qs += 16;

            uint16_t qh_val;
            memcpy(&qh_val, qh, 2);
            qh += 2;

            vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16);
            // Extract bits: (qh >> i) & 1
            v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16);
            v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16);

            // Byte offset into the uint32_t grid: (qs << 2) | (qh_bit << 10)
            vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16);
            v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16);
            v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16);
            vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16);

            // Each grid value packs 4 uint8 magnitudes
            vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16);
            vuint8m2_t  v_grid_u8     = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed);

            vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8);
            signs += 8;

            // Generate sign mask
            vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw);
            vuint8m2_t v_signs_bcast  = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64);
            vuint8m2_t v_sign_bits    = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64);
            vbool4_t   m_negative     = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64);

            vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
            q8 += 64;

            // Apply signs
            vint8m2_t  v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64);
            vint16m4_t v_dot       = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64);

            // Reduction
            vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
            vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
            vint32m1_t v_zero   = __riscv_vmv_v_x_i32m1(0, 1);
            int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32));
            int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32));

            // Apply sub-scales
            uint8_t sc_byte = *scales++;
            int sc_lo = (sc_byte & 0xF) * 2 + 1;
            int sc_hi = (sc_byte >> 4) * 2 + 1;
            sum_block += s_lo * sc_lo + s_hi * sc_hi;
        }

        sumf += sum_block * combined_scale;
    }

    *s = 0.125f * sumf;
#else
    UNUSED(x);
    UNUSED(y);
    UNUSED(nb);
    ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
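
For reference, the following scalar sketch (not part of the commit) models what one 64-weight group in the inner loop above computes. The function name iq3_s_group_ref is hypothetical, the iq3_grid parameter stands in for llama.cpp's iq3s_grid table, and little-endian byte order is assumed when the packed grid word is split into four magnitudes, matching the vreinterpret in the vector code.

#include <stdint.h>
#include <string.h>

// Scalar model of one ib-iteration: 16 qs bytes, 2 qh bytes, 8 sign bytes,
// one packed scale byte and 64 q8 values. Returns the integer dot product
// for the group, before the per-block combined_scale and final 0.125f factor.
static int32_t iq3_s_group_ref(const uint32_t * iq3_grid,   // 512-entry grid (stand-in for iq3s_grid)
                               const uint8_t qs[16], const uint8_t qh[2],
                               const uint8_t signs[8], uint8_t sc_byte,
                               const int8_t q8[64]) {
    uint16_t qh_val;
    memcpy(&qh_val, qh, 2);

    int32_t sum_lo = 0, sum_hi = 0;
    for (int k = 0; k < 16; ++k) {
        // Grid index = qs | (qh bit << 8); the vector code uses the byte
        // offset (qs << 2) | (qh_bit << 10) for the indexed load instead.
        const uint32_t idx = (uint32_t)qs[k] | (((qh_val >> k) & 1u) << 8);
        const uint32_t g   = iq3_grid[idx];                  // 4 packed uint8 magnitudes
        for (int l = 0; l < 4; ++l) {
            const int j   = 4*k + l;
            const int mag = (int)((g >> (8*l)) & 0xFF);      // little-endian lane order
            const int neg = (signs[j >> 3] >> (j & 7)) & 1;  // bit j%8 of signs[j/8]
            const int v   = neg ? -mag : mag;
            if (j < 32) sum_lo += v * q8[j];
            else        sum_hi += v * q8[j];
        }
    }
    // Sub-scales: low nibble for the first 32 weights, high nibble for the rest.
    return sum_lo * (2*(sc_byte & 0xF) + 1) + sum_hi * (2*(sc_byte >> 4) + 1);
}

In the kernel above, each block accumulates four such group results into sum_block, multiplies by combined_scale = d * y[i].d, and the final result is scaled by 0.125f.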