From f7b3fc008b0b935b6a368d672870f2064d583814 Mon Sep 17 00:00:00 2001 From: Xiongchuan Tan Date: Wed, 24 Dec 2025 23:43:25 +0800 Subject: [PATCH] ggml-cpu : add riscv vec dot kernel dispatch based on vlen --- ggml/src/ggml-cpu/CMakeLists.txt | 1 + ggml/src/ggml-cpu/arch/riscv/dispatch.cpp | 100 ++++++ ggml/src/ggml-cpu/arch/riscv/kernels.inc | 3 + ggml/src/ggml-cpu/arch/riscv/quants.c | 419 +++++++++++----------- 4 files changed, 322 insertions(+), 201 deletions(-) create mode 100644 ggml/src/ggml-cpu/arch/riscv/dispatch.cpp create mode 100644 ggml/src/ggml-cpu/arch/riscv/kernels.inc diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 28fb7612e5..9f5ee430e0 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -443,6 +443,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/riscv/quants.c ggml-cpu/arch/riscv/repack.cpp + ggml-cpu/arch/riscv/dispatch.cpp ) if (GGML_CPU_RISCV64_SPACEMIT) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) diff --git a/ggml/src/ggml-cpu/arch/riscv/dispatch.cpp b/ggml/src/ggml-cpu/arch/riscv/dispatch.cpp new file mode 100644 index 0000000000..5f726c119b --- /dev/null +++ b/ggml/src/ggml-cpu/arch/riscv/dispatch.cpp @@ -0,0 +1,100 @@ +#include +#include +#include + +#include "ggml-cpu.h" +#include "quants.h" +#include "kernels.inc" + +#if defined(__riscv_v) + +// helper macros for runtime kernel dispatch + +#define RVV_VEC_DOT_DISPATCH_PAIR(func_name, MINVLEN, SUFFIX) \ + if (vlenb >= MINVLEN) { \ + return func_name##SUFFIX; \ + } + +#define RVV_VEC_DOT_DISPATCH_2(func_name, c1, s1) \ + RVV_VEC_DOT_DISPATCH_PAIR(func_name, c1, s1) + +#define RVV_VEC_DOT_DISPATCH_4(func_name, c1, s1, ...) \ + RVV_VEC_DOT_DISPATCH_PAIR(func_name, c1, s1) \ + RVV_VEC_DOT_DISPATCH_2(func_name, __VA_ARGS__) + +#define RVV_VEC_DOT_DISPATCH_6(func_name, c1, s1, ...) \ + RVV_VEC_DOT_DISPATCH_PAIR(func_name, c1, s1) \ + RVV_VEC_DOT_DISPATCH_4(func_name, __VA_ARGS__) +// add more if needed + +#define GET_RVV_VEC_DOT_DISPATCH_MACRO(_1, _2, _3, _4, _5, _6, NAME, ...) NAME + +#define RVV_VEC_DOT_DISPATCH_CHECKS(func_name, ...) \ + GET_RVV_VEC_DOT_DISPATCH_MACRO(__VA_ARGS__, RVV_VEC_DOT_DISPATCH_6, \ + SKIP, RVV_VEC_DOT_DISPATCH_4, \ + SKIP, RVV_VEC_DOT_DISPATCH_2)(func_name, __VA_ARGS__) + +#define RVV_VEC_DOT_DISPATCH(func_name, ...) \ + static ggml_vec_dot_t func_name##_kernel_sel() { \ + int vlenb = probe_vlenb(); \ + RVV_VEC_DOT_DISPATCH_CHECKS(func_name, __VA_ARGS__) \ + return func_name##_generic; \ + } \ + static ggml_vec_dot_t func_name##_kernel = func_name##_kernel_sel(); \ + void func_name(int n, float * GGML_RESTRICT s, size_t bs, \ + const void * GGML_RESTRICT vx, size_t bx, \ + const void * GGML_RESTRICT vy, size_t by, int nrc) { \ + (func_name##_kernel)(n, s, bs, vx, bx, vy, by, nrc); \ + } + +#include + +static bool probe_rvv() { + bool has_rvv = false; + + struct riscv_hwprobe probe; + probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0; + probe.value = 0; + + int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0); + + if (0 == ret) { + has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V); + } + + return has_rvv; +} + +static int probe_vlenb() { + if (probe_rvv()) { + return __riscv_vlenb(); + } + return 0; +} + +#elif defined(__riscv_xtheadvector) + +#define RVV_VEC_DOT_DISPATCH(func_name, ...) \ + void func_name(int n, float * GGML_RESTRICT s, size_t bs, \ + const void * GGML_RESTRICT vx, size_t bx, \ + const void * GGML_RESTRICT vy, size_t by, int nrc) { \ + (func_name##_071)(n, s, bs, vx, bx, vy, by, nrc); \ + } + +#else + +#define RVV_VEC_DOT_DISPATCH(func_name, ...) \ + void func_name(int n, float * GGML_RESTRICT s, size_t bs, \ + const void * GGML_RESTRICT vx, size_t bx, \ + const void * GGML_RESTRICT vy, size_t by, int nrc) { \ + (func_name##_generic)(n, s, bs, vx, bx, vy, by, nrc); \ + } + +#endif + +extern "C" { + +RVV_VEC_DOT_DISPATCH(ggml_vec_dot_q2_K_q8_K, 32, _256, 16, _128) + +} + diff --git a/ggml/src/ggml-cpu/arch/riscv/kernels.inc b/ggml/src/ggml-cpu/arch/riscv/kernels.inc new file mode 100644 index 0000000000..72bd090818 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/riscv/kernels.inc @@ -0,0 +1,3 @@ +void ggml_vec_dot_q2_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q2_K_q8_K_256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q2_K_q8_K_128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index ae0ebb3cad..2833dc86d9 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -8,6 +8,8 @@ #include "../../quants.h" #include "../../ggml-cpu-impl.h" +#include "kernels.inc" + #include #include #include @@ -376,7 +378,9 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined(__riscv_xtheadvector) + +void ggml_vec_dot_q2_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -388,8 +392,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; uint8_t atmp[16]; @@ -484,245 +486,260 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} -#elif defined __riscv_v +#elif defined(__riscv_v) + +void ggml_vec_dot_q2_K_q8_K_256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; uint8_t atmp[16]; - const int vector_length = __riscv_vlenb() * 8; uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - size_t vl = 16; + size_t vl = 16; - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - vl = 32; + vl = 32; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - uint8_t is = 0; - int isum = 0; + uint8_t is = 0; + int isum = 0; - for (int j = 0; j < QK_K / 128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + for (int j = 0; j < QK_K / 128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + isum += __riscv_vmv_x_s_i32m1_i32(isum1); - q2 += 32; - q8 += 128; - is = 8; - } - - sumf += dall * isum; + q2 += 32; + q8 += 128; + is = 8; } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - uint8_t *patmp = atmp; - int vsums; - int tmp, t1, t2, t3, t4, t5, t6, t7; + + sumf += dall * isum; + } + + *s = sumf; +} + +void ggml_vec_dot_q2_K_q8_K_128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + uint8_t atmp[16]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "lb zero, 15(%[sc])\n\t" + "vle8.v v1, (%[sc])\n\t" + "vle8.v v2, (%[bsums])\n\t" + "addi %[tmp], %[bsums], 16\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vle8.v v3, (%[tmp])\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { __asm__ __volatile__( + "lb zero, 31(%[q2])\n\t" + "addi %[tmp], %[q2], 16\n\t" + "addi %[t1], %[q8], 16\n\t" "vsetivli zero, 16, e8, m1\n\t" - "vmv.v.x v8, zero\n\t" - "lb zero, 15(%[sc])\n\t" - "vle8.v v1, (%[sc])\n\t" - "vle8.v v2, (%[bsums])\n\t" - "addi %[tmp], %[bsums], 16\n\t" - "vand.vi v0, v1, 0xF\n\t" - "vsrl.vi v1, v1, 4\n\t" - "vle8.v v3, (%[tmp])\n\t" - "vse8.v v0, (%[scale])\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vzext.vf2 v0, v1\n\t" - "vwmul.vv v4, v0, v2\n\t" - "vsetivli zero, 16, e32, m4\n\t" - "vredsum.vs v8, v4, v8\n\t" - "vmv.x.s %[vsums], v8" - : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) - : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + "vle8.v v0, (%[q2])\n\t" + "vle8.v v1, (%[tmp])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v3, v1, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "addi %[tmp], %[q8], 32\n\t" + "vle8.v v8, (%[q8])\n\t" + "vle8.v v9, (%[t1])\n\t" + "addi %[t1], %[t1], 32\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vsrl.vi v7, v1, 6\n\t" + "vle8.v v10, (%[tmp])\n\t" + "vle8.v v11, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v1, v1, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vle8.v v12, (%[tmp])\n\t" + "vle8.v v13, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v3, v3, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vand.vi v5, v5, 0x3\n\t" + "vle8.v v14, (%[tmp])\n\t" + "vle8.v v15, (%[t1])\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v18, v1, v9\n\t" + "vwmul.vv v20, v2, v10\n\t" + "vwmul.vv v22, v3, v11\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vwmul.vv v26, v5, v13\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v30, v7, v15\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vmv.v.x v0, zero\n\t" + "lbu %[tmp], 0(%[scale])\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "lbu %[t1], 1(%[scale])\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "lbu %[t2], 2(%[scale])\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "lbu %[t3], 3(%[scale])\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "lbu %[t4], 4(%[scale])\n\t" + "vwredsum.vs v8, v17, v8\n\t" + "vwredsum.vs v9, v19, v9\n\t" + "lbu %[t5], 5(%[scale])\n\t" + "vwredsum.vs v10, v21, v10\n\t" + "vwredsum.vs v11, v23, v11\n\t" + "lbu %[t6], 6(%[scale])\n\t" + "vwredsum.vs v12, v25, v12\n\t" + "vwredsum.vs v13, v27, v13\n\t" + "lbu %[t7], 7(%[scale])\n\t" + "vwredsum.vs v14, v29, v14\n\t" + "vwredsum.vs v15, v31, v15\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) : "memory" , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" ); - sumf += dmin * vsums; - int isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "lb zero, 31(%[q2])\n\t" - "addi %[tmp], %[q2], 16\n\t" - "addi %[t1], %[q8], 16\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vle8.v v0, (%[q2])\n\t" - "vle8.v v1, (%[tmp])\n\t" - "vsrl.vi v2, v0, 2\n\t" - "vsrl.vi v3, v1, 2\n\t" - "vsrl.vi v4, v0, 4\n\t" - "addi %[tmp], %[q8], 32\n\t" - "vle8.v v8, (%[q8])\n\t" - "vle8.v v9, (%[t1])\n\t" - "addi %[t1], %[t1], 32\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v6, v0, 6\n\t" - "vsrl.vi v7, v1, 6\n\t" - "vle8.v v10, (%[tmp])\n\t" - "vle8.v v11, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v0, v0, 0x3\n\t" - "vand.vi v1, v1, 0x3\n\t" - "vand.vi v2, v2, 0x3\n\t" - "vle8.v v12, (%[tmp])\n\t" - "vle8.v v13, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v3, v3, 0x3\n\t" - "vand.vi v4, v4, 0x3\n\t" - "vand.vi v5, v5, 0x3\n\t" - "vle8.v v14, (%[tmp])\n\t" - "vle8.v v15, (%[t1])\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v18, v1, v9\n\t" - "vwmul.vv v20, v2, v10\n\t" - "vwmul.vv v22, v3, v11\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vwmul.vv v26, v5, v13\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmul.vv v30, v7, v15\n\t" - "vsetivli zero, 8, e16, m1\n\t" - "vmv.v.x v0, zero\n\t" - "lbu %[tmp], 0(%[scale])\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "lbu %[t1], 1(%[scale])\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "lbu %[t2], 2(%[scale])\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "lbu %[t3], 3(%[scale])\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "lbu %[t4], 4(%[scale])\n\t" - "vwredsum.vs v8, v17, v8\n\t" - "vwredsum.vs v9, v19, v9\n\t" - "lbu %[t5], 5(%[scale])\n\t" - "vwredsum.vs v10, v21, v10\n\t" - "vwredsum.vs v11, v23, v11\n\t" - "lbu %[t6], 6(%[scale])\n\t" - "vwredsum.vs v12, v25, v12\n\t" - "vwredsum.vs v13, v27, v13\n\t" - "lbu %[t7], 7(%[scale])\n\t" - "vwredsum.vs v14, v29, v14\n\t" - "vwredsum.vs v15, v31, v15\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [isum] "+&r" (isum) - : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q2 += 32; q8 += 128; patmp += 8; - } - - sumf += dall * isum; + q2 += 32; q8 += 128; patmp += 8; } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += dall * isum; } *s = sumf; - -#else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - - ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif } +#endif // ggml_vec_dot_q2_K_q8_K + void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1);