ggml-cpu: add vec_dot for iq2_xs, iq3_xxs

This commit is contained in:
Rehan Qasim 2025-12-20 02:57:06 +05:00 committed by taimur-10x
parent c646e71982
commit 1240fd1d84
2 changed files with 225 additions and 21 deletions

View File

@ -155,9 +155,7 @@
// quants.c
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K
#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K
#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K

View File

@ -2089,7 +2089,7 @@ static const int8_t keven_signs_q2xs[1024] = {
};
#endif
void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
@ -2116,7 +2116,7 @@ void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs,
float sum = 0.0f;
#pragma GCC nounroll
#pragma GCC unroll 1
for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) {
vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32;
vint8m2_t q8_2 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32;
@ -2180,7 +2180,7 @@ void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs,
*s = 0.125f * sumf;
}
void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
@ -2278,16 +2278,18 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 128:
return ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
return ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
return ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
@ -2340,7 +2342,7 @@ void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, c
*s = sumf;
}
void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
@ -2401,16 +2403,18 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 128:
return ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
return ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
return ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
@ -2463,7 +2467,7 @@ void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, co
*s = sumf;
}
void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
@ -2524,16 +2528,18 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 128:
return ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
return ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
@ -2621,11 +2627,211 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 256:
return ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
break;
default:
return ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
break;
}
#else
return ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq3_xxs * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid;
// constants for unpacking logic
const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21};
vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shifts_val, 8);
const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1};
vuint32m1_t v_gather_idx = __riscv_vle32_v_u32m1(gather_idx_val, 8);
uint32_t aux32[2];
float sumf = 0.0f;
for (int i = 0; i < nb; ++i) {
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * GGML_RESTRICT q3_indices = x[i].qs;
const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
float block_sum = 0.0f;
for (int ib = 0; ib < QK_K / 64; ++ib) {
// Load q8 (64 bytes)
vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64);
q8 += 64;
// load of metadata via memcpy
memcpy(aux32, metadata, 2 * sizeof(uint32_t));
metadata += 2 * sizeof(uint32_t);
// Load q3 indices and gather magnitudes
vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 16);
q3_indices += 16;
vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 16);
vuint32m2_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 16);
vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u32m2_u8m2(v_q3_magnitudes_u32));
// --- Unpacking of Sign Indices ---
// 1. Load the 2 auxiliary 32-bit integers into a vector
vuint32m1_t v_aux = __riscv_vle32_v_u32m1(aux32, 2);
// 2. Broadcast/Gather: replicate aux[0] to first 4 lanes, aux[1] to next 4 lanes
vuint32m1_t v_aux_expanded = __riscv_vrgather_vv_u32m1(v_aux, v_gather_idx, 8);
// 3. Apply Shifts and Mask: ((val >> shift) & 127)
vuint32m1_t v_s_vals_raw = __riscv_vand_vx_u32m1(__riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 8), 127, 8);
// 4. Narrow to u16 (required for vluxei index) and multiply by 8 (byte offset for u64 table)
vuint16mf2_t sign_indices_byte_offset = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_vals_raw, 8), 3, 8);
// 5. Gather Signs
vuint64m2_t v_s_vals_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_indices_byte_offset, 8);
vint8m2_t v_s_vals = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(v_s_vals_u64));
vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_s_vals, 64);
vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_signed, 64);
vint16m2_t v_dot_1 = __riscv_vget_v_i16m4_i16m2(v_dot, 0);
vint16m2_t v_dot_2 = __riscv_vget_v_i16m4_i16m2(v_dot, 1);
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1);
vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_1, v_zero, 32);
vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_2, v_zero, 32);
int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1);
int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2);
const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1);
const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1);
block_sum += sum1_i * scale1_f + sum2_i * scale2_f;
}
sumf += d * block_sum;
}
*s = 0.25f * sumf;
}
void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 256:
return ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
default:
return ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
}
#endif
return ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
}
static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_iq2_xs * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
const uint64_t * grid64 = (const uint64_t *)iq2xs_grid;
float sumf = 0.0f;
for (int i = 0; i < nb; ++i) {
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * GGML_RESTRICT qs = x[i].qs;
const int8_t * GGML_RESTRICT q8 = y[i].qs;
const uint8_t * GGML_RESTRICT scales = x[i].scales;
int32_t sum_int = 0;
// Loop over 4 subblocks of 64 elements (QK_K = 256)
for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) {
// Load 8 uint16 indices (controls 64 values)
vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8);
qs += 8;
// Extract indices for grid (low 9 bits) and signs (high 7 bits)
// Multiply by 8 (<< 3) for byte offsets into the uint64 tables
vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8);
vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8);
vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8);
vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8);
vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64));
vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64));
// Apply signs
vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64);
// Load Q8 weights (64 elements)
vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64);
q8 += 64;
// Multiply (Widening to int16, 64 elements -> LMUL=4)
vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64);
// Reduction
vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1);
int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
__riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16));
int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
__riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16));
int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
__riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16));
int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(
__riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16));
// Apply Scales
const uint8_t scale_byte_1 = scales[0];
const uint8_t scale_byte_2 = scales[1];
scales += 2;
sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1);
sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1);
sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1);
sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1);
}
sumf += d * sum_int;
}
*s = 0.125f * sumf;
}
void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
#if defined __riscv_v_intrinsic
switch (__riscv_vlenb() * 8) {
case 256:
return ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc);
default:
return ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
}
#endif
return ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
}