xctan 2026-01-02 13:20:56 -03:00 committed by GitHub
commit 0ce7f870b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 322 additions and 201 deletions

View File

@@ -443,6 +443,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND GGML_CPU_SOURCES
ggml-cpu/arch/riscv/quants.c
ggml-cpu/arch/riscv/repack.cpp
ggml-cpu/arch/riscv/dispatch.cpp
)
if (GGML_CPU_RISCV64_SPACEMIT)
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC})

View File

@@ -0,0 +1,100 @@
#include <asm/hwprobe.h>
#include <asm/unistd.h>
#include <unistd.h>
#include "ggml-cpu.h"
#include "quants.h"
#include "kernels.inc"
#if defined(__riscv_v)
// helper macros for runtime kernel dispatch
#define RVV_VEC_DOT_DISPATCH_PAIR(func_name, MINVLEN, SUFFIX) \
if (vlenb >= MINVLEN) { \
return func_name##SUFFIX; \
}
#define RVV_VEC_DOT_DISPATCH_2(func_name, c1, s1) \
RVV_VEC_DOT_DISPATCH_PAIR(func_name, c1, s1)
#define RVV_VEC_DOT_DISPATCH_4(func_name, c1, s1, ...) \
RVV_VEC_DOT_DISPATCH_PAIR(func_name, c1, s1) \
RVV_VEC_DOT_DISPATCH_2(func_name, __VA_ARGS__)
#define RVV_VEC_DOT_DISPATCH_6(func_name, c1, s1, ...) \
RVV_VEC_DOT_DISPATCH_PAIR(func_name, c1, s1) \
RVV_VEC_DOT_DISPATCH_4(func_name, __VA_ARGS__)
// add more if needed
#define GET_RVV_VEC_DOT_DISPATCH_MACRO(_1, _2, _3, _4, _5, _6, NAME, ...) NAME
#define RVV_VEC_DOT_DISPATCH_CHECKS(func_name, ...) \
GET_RVV_VEC_DOT_DISPATCH_MACRO(__VA_ARGS__, RVV_VEC_DOT_DISPATCH_6, \
SKIP, RVV_VEC_DOT_DISPATCH_4, \
SKIP, RVV_VEC_DOT_DISPATCH_2)(func_name, __VA_ARGS__)
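// GET_RVV_VEC_DOT_DISPATCH_MACRO counts the (min-vlenb, suffix) pairs passed in and picks the
// matching RVV_VEC_DOT_DISPATCH_{2,4,6} expansion, e.g. two pairs select RVV_VEC_DOT_DISPATCH_4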
#define RVV_VEC_DOT_DISPATCH(func_name, ...) \
static ggml_vec_dot_t func_name##_kernel_sel() { \
int vlenb = probe_vlenb(); \
RVV_VEC_DOT_DISPATCH_CHECKS(func_name, __VA_ARGS__) \
return func_name##_generic; \
} \
static ggml_vec_dot_t func_name##_kernel = func_name##_kernel_sel(); \
void func_name(int n, float * GGML_RESTRICT s, size_t bs, \
const void * GGML_RESTRICT vx, size_t bx, \
const void * GGML_RESTRICT vy, size_t by, int nrc) { \
(func_name##_kernel)(n, s, bs, vx, bx, vy, by, nrc); \
}
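// the selected kernel is cached in func_name##_kernel, a function pointer resolved once at
// program start-up (static initializer), so the hwprobe/vlenb check is not repeated per call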
#include <riscv_vector.h>
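// ask the kernel via the riscv_hwprobe syscall whether the V extension is available at runtime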
static bool probe_rvv() {
bool has_rvv = false;
struct riscv_hwprobe probe;
probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
probe.value = 0;
int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
if (0 == ret) {
has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
}
return has_rvv;
}
static int probe_vlenb() {
if (probe_rvv()) {
return __riscv_vlenb();
}
return 0;
}
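// without RVV, vlenb reports 0 here, so every dispatch check fails and the _generic kernel is used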
#elif defined(__riscv_xtheadvector)
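// XTheadVector (RVV 0.7.1 draft) is a compile-time choice, so always bind to the _071 kernel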
#define RVV_VEC_DOT_DISPATCH(func_name, ...) \
void func_name(int n, float * GGML_RESTRICT s, size_t bs, \
const void * GGML_RESTRICT vx, size_t bx, \
const void * GGML_RESTRICT vy, size_t by, int nrc) { \
(func_name##_071)(n, s, bs, vx, bx, vy, by, nrc); \
}
#else
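// no vector extension at compile time: always use the scalar _generic kernel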
#define RVV_VEC_DOT_DISPATCH(func_name, ...) \
void func_name(int n, float * GGML_RESTRICT s, size_t bs, \
const void * GGML_RESTRICT vx, size_t bx, \
const void * GGML_RESTRICT vy, size_t by, int nrc) { \
(func_name##_generic)(n, s, bs, vx, bx, vy, by, nrc); \
}
#endif
extern "C" {
RVV_VEC_DOT_DISPATCH(ggml_vec_dot_q2_K_q8_K, 32, _256, 16, _128)
}

View File

@@ -0,0 +1,3 @@
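// prototypes for the per-configuration q2_K x q8_K dot-product kernels defined in quants.c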
void ggml_vec_dot_q2_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K_256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_K_q8_K_128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);

View File

@@ -8,6 +8,8 @@
#include "../../quants.h"
#include "../../ggml-cpu-impl.h"
#include "kernels.inc"
#include <math.h>
#include <string.h>
#include <assert.h>
@@ -376,7 +378,9 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
#endif
}
#if defined(__riscv_xtheadvector)
void ggml_vec_dot_q2_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
@@ -388,8 +392,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int nb = n / QK_K;
float sumf = 0;
uint8_t atmp[16];
@@ -484,245 +486,260 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
}
*s = sumf;
}
#elif defined(__riscv_v)
void ggml_vec_dot_q2_K_q8_K_256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q2_K * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
float sumf = 0;
uint8_t atmp[16];
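// VLEN >= 256 path: per superblock, apply the dmin * sum(mins * bsums) correction, then
// accumulate the 2-bit q2 x q8 products weighted by the per-sub-block scales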
uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
for (int i = 0; i < nb; ++i) {
const uint8_t * q2 = x[i].qs;
const int8_t * q8 = y[i].qs;
const uint8_t * sc = x[i].scales;
const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
size_t vl = 16;
vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl);
vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl);
vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl);
vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl);
vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl);
vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl));
vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl);
vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl);
sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums);
vl = 32;
vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1);
vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl);
uint8_t is = 0;
int isum = 0;
for (int j = 0; j < QK_K / 128; ++j) {
// load Q2
vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl);
vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl);
vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl);
vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl);
vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl);
// duplicate scale elements for product
vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl);
vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl);
vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl);
vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl);
vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl));
vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl));
vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl));
vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl));
// load Q8
vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl);
vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl);
vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl);
vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl);
vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl);
vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl);
vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl);
vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl);
vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl);
vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl);
isum += __riscv_vmv_x_s_i32m1_i32(isum1);
q2 += 32;
q8 += 128;
is = 8;
}
sumf += dall * isum;
}
*s = sumf;
}
void ggml_vec_dot_q2_K_q8_K_128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q2_K * GGML_RESTRICT x = vx;
const block_q8_K * GGML_RESTRICT y = vy;
const int nb = n / QK_K;
float sumf = 0;
uint8_t atmp[16];
for (int i = 0; i < nb; ++i) {
const uint8_t * q2 = x[i].qs;
const int8_t * q8 = y[i].qs;
const uint8_t * sc = x[i].scales;
const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
uint8_t *patmp = atmp;
int vsums;
int tmp, t1, t2, t3, t4, t5, t6, t7;
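// first asm block: split the packed scales into low nibbles (stored to atmp for the q2 pass
// below) and high nibbles (mins), multiply the mins by the per-sub-block bsums, reduce into vsums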
__asm__ __volatile__(
"vsetivli zero, 16, e8, m1\n\t"
"vmv.v.x v8, zero\n\t"
"lb zero, 15(%[sc])\n\t"
"vle8.v v1, (%[sc])\n\t"
"vle8.v v2, (%[bsums])\n\t"
"addi %[tmp], %[bsums], 16\n\t"
"vand.vi v0, v1, 0xF\n\t"
"vsrl.vi v1, v1, 4\n\t"
"vle8.v v3, (%[tmp])\n\t"
"vse8.v v0, (%[scale])\n\t"
"vsetivli zero, 16, e16, m2\n\t"
"vzext.vf2 v0, v1\n\t"
"vwmul.vv v4, v0, v2\n\t"
"vsetivli zero, 16, e32, m4\n\t"
"vredsum.vs v8, v4, v8\n\t"
"vmv.x.s %[vsums], v8"
: [tmp] "=&r" (tmp), [vsums] "=&r" (vsums)
: [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
);
sumf += dmin * vsums;
int isum = 0;
for (int j = 0; j < QK_K/128; ++j) {
"vle8.v v0, (%[q2])\n\t"
"vle8.v v1, (%[tmp])\n\t"
"vsrl.vi v2, v0, 2\n\t"
"vsrl.vi v3, v1, 2\n\t"
"vsrl.vi v4, v0, 4\n\t"
"addi %[tmp], %[q8], 32\n\t"
"vle8.v v8, (%[q8])\n\t"
"vle8.v v9, (%[t1])\n\t"
"addi %[t1], %[t1], 32\n\t"
"vsrl.vi v5, v1, 4\n\t"
"vsrl.vi v6, v0, 6\n\t"
"vsrl.vi v7, v1, 6\n\t"
"vle8.v v10, (%[tmp])\n\t"
"vle8.v v11, (%[t1])\n\t"
"addi %[tmp], %[tmp], 32\n\t"
"addi %[t1], %[t1], 32\n\t"
"vand.vi v0, v0, 0x3\n\t"
"vand.vi v1, v1, 0x3\n\t"
"vand.vi v2, v2, 0x3\n\t"
"vle8.v v12, (%[tmp])\n\t"
"vle8.v v13, (%[t1])\n\t"
"addi %[tmp], %[tmp], 32\n\t"
"addi %[t1], %[t1], 32\n\t"
"vand.vi v3, v3, 0x3\n\t"
"vand.vi v4, v4, 0x3\n\t"
"vand.vi v5, v5, 0x3\n\t"
"vle8.v v14, (%[tmp])\n\t"
"vle8.v v15, (%[t1])\n\t"
"vwmul.vv v16, v0, v8\n\t"
"vwmul.vv v18, v1, v9\n\t"
"vwmul.vv v20, v2, v10\n\t"
"vwmul.vv v22, v3, v11\n\t"
"vwmul.vv v24, v4, v12\n\t"
"vwmul.vv v26, v5, v13\n\t"
"vwmul.vv v28, v6, v14\n\t"
"vwmul.vv v30, v7, v15\n\t"
"vsetivli zero, 8, e16, m1\n\t"
"vmv.v.x v0, zero\n\t"
"lbu %[tmp], 0(%[scale])\n\t"
"vwredsum.vs v8, v16, v0\n\t"
"vwredsum.vs v9, v18, v0\n\t"
"lbu %[t1], 1(%[scale])\n\t"
"vwredsum.vs v10, v20, v0\n\t"
"vwredsum.vs v11, v22, v0\n\t"
"lbu %[t2], 2(%[scale])\n\t"
"vwredsum.vs v12, v24, v0\n\t"
"vwredsum.vs v13, v26, v0\n\t"
"lbu %[t3], 3(%[scale])\n\t"
"vwredsum.vs v14, v28, v0\n\t"
"vwredsum.vs v15, v30, v0\n\t"
"lbu %[t4], 4(%[scale])\n\t"
"vwredsum.vs v8, v17, v8\n\t"
"vwredsum.vs v9, v19, v9\n\t"
"lbu %[t5], 5(%[scale])\n\t"
"vwredsum.vs v10, v21, v10\n\t"
"vwredsum.vs v11, v23, v11\n\t"
"lbu %[t6], 6(%[scale])\n\t"
"vwredsum.vs v12, v25, v12\n\t"
"vwredsum.vs v13, v27, v13\n\t"
"lbu %[t7], 7(%[scale])\n\t"
"vwredsum.vs v14, v29, v14\n\t"
"vwredsum.vs v15, v31, v15\n\t"
"vsetivli zero, 4, e32, m1\n\t"
"vmul.vx v0, v8, %[tmp]\n\t"
"vmul.vx v1, v9, %[t1]\n\t"
"vmacc.vx v0, %[t2], v10\n\t"
"vmacc.vx v1, %[t3], v11\n\t"
"vmacc.vx v0, %[t4], v12\n\t"
"vmacc.vx v1, %[t5], v13\n\t"
"vmacc.vx v0, %[t6], v14\n\t"
"vmacc.vx v1, %[t7], v15\n\t"
"vmv.x.s %[tmp], v0\n\t"
"vmv.x.s %[t1], v1\n\t"
"add %[isum], %[isum], %[tmp]\n\t"
"add %[isum], %[isum], %[t1]"
: [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
, [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
, [isum] "+&r" (isum)
: [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
);
sumf += dmin * vsums;
int isum = 0;
for (int j = 0; j < QK_K/128; ++j) {
__asm__ __volatile__(
"lb zero, 31(%[q2])\n\t"
"addi %[tmp], %[q2], 16\n\t"
"addi %[t1], %[q8], 16\n\t"
"vsetivli zero, 16, e8, m1\n\t"
"vle8.v v0, (%[q2])\n\t"
"vle8.v v1, (%[tmp])\n\t"
"vsrl.vi v2, v0, 2\n\t"
"vsrl.vi v3, v1, 2\n\t"
"vsrl.vi v4, v0, 4\n\t"
"addi %[tmp], %[q8], 32\n\t"
"vle8.v v8, (%[q8])\n\t"
"vle8.v v9, (%[t1])\n\t"
"addi %[t1], %[t1], 32\n\t"
"vsrl.vi v5, v1, 4\n\t"
"vsrl.vi v6, v0, 6\n\t"
"vsrl.vi v7, v1, 6\n\t"
"vle8.v v10, (%[tmp])\n\t"
"vle8.v v11, (%[t1])\n\t"
"addi %[tmp], %[tmp], 32\n\t"
"addi %[t1], %[t1], 32\n\t"
"vand.vi v0, v0, 0x3\n\t"
"vand.vi v1, v1, 0x3\n\t"
"vand.vi v2, v2, 0x3\n\t"
"vle8.v v12, (%[tmp])\n\t"
"vle8.v v13, (%[t1])\n\t"
"addi %[tmp], %[tmp], 32\n\t"
"addi %[t1], %[t1], 32\n\t"
"vand.vi v3, v3, 0x3\n\t"
"vand.vi v4, v4, 0x3\n\t"
"vand.vi v5, v5, 0x3\n\t"
"vle8.v v14, (%[tmp])\n\t"
"vle8.v v15, (%[t1])\n\t"
"vwmul.vv v16, v0, v8\n\t"
"vwmul.vv v18, v1, v9\n\t"
"vwmul.vv v20, v2, v10\n\t"
"vwmul.vv v22, v3, v11\n\t"
"vwmul.vv v24, v4, v12\n\t"
"vwmul.vv v26, v5, v13\n\t"
"vwmul.vv v28, v6, v14\n\t"
"vwmul.vv v30, v7, v15\n\t"
"vsetivli zero, 8, e16, m1\n\t"
"vmv.v.x v0, zero\n\t"
"lbu %[tmp], 0(%[scale])\n\t"
"vwredsum.vs v8, v16, v0\n\t"
"vwredsum.vs v9, v18, v0\n\t"
"lbu %[t1], 1(%[scale])\n\t"
"vwredsum.vs v10, v20, v0\n\t"
"vwredsum.vs v11, v22, v0\n\t"
"lbu %[t2], 2(%[scale])\n\t"
"vwredsum.vs v12, v24, v0\n\t"
"vwredsum.vs v13, v26, v0\n\t"
"lbu %[t3], 3(%[scale])\n\t"
"vwredsum.vs v14, v28, v0\n\t"
"vwredsum.vs v15, v30, v0\n\t"
"lbu %[t4], 4(%[scale])\n\t"
"vwredsum.vs v8, v17, v8\n\t"
"vwredsum.vs v9, v19, v9\n\t"
"lbu %[t5], 5(%[scale])\n\t"
"vwredsum.vs v10, v21, v10\n\t"
"vwredsum.vs v11, v23, v11\n\t"
"lbu %[t6], 6(%[scale])\n\t"
"vwredsum.vs v12, v25, v12\n\t"
"vwredsum.vs v13, v27, v13\n\t"
"lbu %[t7], 7(%[scale])\n\t"
"vwredsum.vs v14, v29, v14\n\t"
"vwredsum.vs v15, v31, v15\n\t"
"vsetivli zero, 4, e32, m1\n\t"
"vmul.vx v0, v8, %[tmp]\n\t"
"vmul.vx v1, v9, %[t1]\n\t"
"vmacc.vx v0, %[t2], v10\n\t"
"vmacc.vx v1, %[t3], v11\n\t"
"vmacc.vx v0, %[t4], v12\n\t"
"vmacc.vx v1, %[t5], v13\n\t"
"vmacc.vx v0, %[t6], v14\n\t"
"vmacc.vx v1, %[t7], v15\n\t"
"vmv.x.s %[tmp], v0\n\t"
"vmv.x.s %[t1], v1\n\t"
"add %[isum], %[isum], %[tmp]\n\t"
"add %[isum], %[isum], %[t1]"
: [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3)
, [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7)
, [isum] "+&r" (isum)
: [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
);
q2 += 32; q8 += 128; patmp += 8;
}
sumf += dall * isum;
q2 += 32; q8 += 128; patmp += 8;
}
break;
default:
assert(false && "Unsupported vector length");
break;
sumf += dall * isum;
}
*s = sumf;
#else
UNUSED(x);
UNUSED(y);
UNUSED(nb);
ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif
}
#endif // ggml_vec_dot_q2_K_q8_K
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);