Merge 0e373e245b into 05a6f0e894
This commit is contained in:
commit
103659e553
|
|
@ -441,6 +441,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|||
elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64")
|
||||
message(STATUS "riscv64 detected")
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu/arch/riscv/dispatch.cpp
|
||||
ggml-cpu/arch/riscv/quants.c
|
||||
ggml-cpu/arch/riscv/repack.cpp
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,112 @@
|
|||
#include <asm/hwprobe.h>
|
||||
#include <asm/unistd.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "ggml-cpu.h"
|
||||
#include "quants.h"
|
||||
|
||||
extern "C" {
|
||||
#include "kernels.inc"
|
||||
}
|
||||
|
||||
#if defined(__riscv_v) && __riscv_v >= 1000000
|
||||
|
||||
// helper macros for runtime kernel dispatch
|
||||
|
||||
#define RVV_VEC_DOT_DISPATCH_PAIR(func_name, MINVLEN, SUFFIX) \
|
||||
if (vlenb >= MINVLEN) { \
|
||||
return func_name##SUFFIX; \
|
||||
}
|
||||
|
||||
#define RVV_VEC_DOT_DISPATCH_2(func_name, c1, s1) \
|
||||
RVV_VEC_DOT_DISPATCH_PAIR(func_name, c1, s1)
|
||||
|
||||
#define RVV_VEC_DOT_DISPATCH_4(func_name, c1, s1, ...) \
|
||||
RVV_VEC_DOT_DISPATCH_PAIR(func_name, c1, s1) \
|
||||
RVV_VEC_DOT_DISPATCH_2(func_name, __VA_ARGS__)
|
||||
|
||||
#define RVV_VEC_DOT_DISPATCH_6(func_name, c1, s1, ...) \
|
||||
RVV_VEC_DOT_DISPATCH_PAIR(func_name, c1, s1) \
|
||||
RVV_VEC_DOT_DISPATCH_4(func_name, __VA_ARGS__)
|
||||
// add more if needed
|
||||
|
||||
#define GET_RVV_VEC_DOT_DISPATCH_MACRO(_1, _2, _3, _4, _5, _6, NAME, ...) NAME
|
||||
|
||||
#define RVV_VEC_DOT_DISPATCH_CHECKS(func_name, ...) \
|
||||
GET_RVV_VEC_DOT_DISPATCH_MACRO(__VA_ARGS__, RVV_VEC_DOT_DISPATCH_6, \
|
||||
SKIP, RVV_VEC_DOT_DISPATCH_4, \
|
||||
SKIP, RVV_VEC_DOT_DISPATCH_2)(func_name, __VA_ARGS__)
|
||||
|
||||
#define RVV_VEC_DOT_DISPATCH(func_name, ...) \
|
||||
static ggml_vec_dot_t func_name##_kernel_sel() { \
|
||||
int vlenb = dispatch_vlenb; \
|
||||
RVV_VEC_DOT_DISPATCH_CHECKS(func_name, __VA_ARGS__) \
|
||||
return func_name##_generic; \
|
||||
} \
|
||||
static ggml_vec_dot_t func_name##_kernel = func_name##_kernel_sel(); \
|
||||
void func_name(int n, float * GGML_RESTRICT s, size_t bs, \
|
||||
const void * GGML_RESTRICT vx, size_t bx, \
|
||||
const void * GGML_RESTRICT vy, size_t by, int nrc) { \
|
||||
(func_name##_kernel)(n, s, bs, vx, bx, vy, by, nrc); \
|
||||
}
|
||||
|
||||
#include <riscv_vector.h>
|
||||
|
||||
static bool probe_rvv() {
|
||||
bool has_rvv = false;
|
||||
|
||||
struct riscv_hwprobe probe;
|
||||
probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
|
||||
probe.value = 0;
|
||||
|
||||
int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
|
||||
|
||||
if (0 == ret) {
|
||||
has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
|
||||
}
|
||||
|
||||
return has_rvv;
|
||||
}
|
||||
|
||||
static int probe_vlenb() {
|
||||
if (probe_rvv()) {
|
||||
return __riscv_vlenb();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
static int dispatch_vlenb = probe_vlenb();
|
||||
|
||||
#elif defined(__riscv_xtheadvector)
|
||||
|
||||
void ggml_vec_dot_q5_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
|
||||
}
|
||||
|
||||
#define RVV_VEC_DOT_DISPATCH(func_name, ...) \
|
||||
void func_name(int n, float * GGML_RESTRICT s, size_t bs, \
|
||||
const void * GGML_RESTRICT vx, size_t bx, \
|
||||
const void * GGML_RESTRICT vy, size_t by, int nrc) { \
|
||||
(func_name##_071)(n, s, bs, vx, bx, vy, by, nrc); \
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#define RVV_VEC_DOT_DISPATCH(func_name, ...) \
|
||||
void func_name(int n, float * GGML_RESTRICT s, size_t bs, \
|
||||
const void * GGML_RESTRICT vx, size_t bx, \
|
||||
const void * GGML_RESTRICT vy, size_t by, int nrc) { \
|
||||
(func_name##_generic)(n, s, bs, vx, bx, vy, by, nrc); \
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
extern "C" {
|
||||
|
||||
RVV_VEC_DOT_DISPATCH(ggml_vec_dot_q2_K_q8_K, 32, _256, 16, _128)
|
||||
RVV_VEC_DOT_DISPATCH(ggml_vec_dot_q3_K_q8_K, 32, _256, 16, _128)
|
||||
RVV_VEC_DOT_DISPATCH(ggml_vec_dot_q4_K_q8_K, 32, _256, 16, _128)
|
||||
RVV_VEC_DOT_DISPATCH(ggml_vec_dot_q5_K_q8_K, 16, _128)
|
||||
RVV_VEC_DOT_DISPATCH(ggml_vec_dot_q6_K_q8_K, 32, _256, 16, _128)
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
void ggml_vec_dot_q2_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q2_K_q8_K_256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q2_K_q8_K_128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q3_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q3_K_q8_K_256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q3_K_q8_K_128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_K_q8_K_256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_K_q8_K_128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_K_q8_K_128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q6_K_q8_K_071(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q6_K_q8_K_256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q6_K_q8_K_128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -43,7 +43,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v
|
||||
#if defined(__riscv_v) && __riscv_v >= 1000000
|
||||
if (__riscv_vlenb() >= QK4_0) {
|
||||
const size_t vl = QK4_0;
|
||||
|
||||
|
|
@ -135,7 +135,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined __riscv_v
|
||||
#if defined(__riscv_v) && __riscv_v >= 1000000
|
||||
if (__riscv_vlenb() >= QK4_0) {
|
||||
const size_t vl = QK4_0;
|
||||
|
||||
|
|
|
|||
|
|
@ -391,6 +391,8 @@ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) {
|
|||
template <> inline vfloat16m4_t load(const ggml_fp16_t *p) {
|
||||
return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4());
|
||||
}
|
||||
#endif
|
||||
#if defined(__riscv_v) && __riscv_v >= 1000000
|
||||
template <> inline vfloat32m1_t load(const float *p) {
|
||||
return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1());
|
||||
}
|
||||
|
|
@ -432,6 +434,8 @@ template <> inline vfloat16m2_t set_zero() {
|
|||
template <> inline vfloat16m4_t set_zero() {
|
||||
return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4());
|
||||
}
|
||||
#endif
|
||||
#if defined(__riscv_v) && __riscv_v >= 1000000
|
||||
template <> inline vfloat32m1_t set_zero() {
|
||||
return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1());
|
||||
}
|
||||
|
|
@ -446,7 +450,7 @@ template <> inline vfloat32m8_t set_zero() {
|
|||
}
|
||||
#endif
|
||||
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
#if defined(__riscv_v) && __riscv_v >= 1000000
|
||||
template <typename T> size_t vlmax() {
|
||||
if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); }
|
||||
else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); }
|
||||
|
|
@ -633,7 +637,7 @@ class tinyBLAS {
|
|||
const int64_t ldc;
|
||||
};
|
||||
|
||||
#if defined(__riscv_v_intrinsic)
|
||||
#if defined(__riscv_v) && __riscv_v >= 1000000
|
||||
template <typename D, typename V, typename TA, typename TB, typename TC>
|
||||
class tinyBLAS_RVV {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -93,6 +93,9 @@ extern "C" {
|
|||
return r;
|
||||
}
|
||||
#elif defined(__riscv) && defined(__riscv_zfhmin)
|
||||
// suppress _Float16 warnings
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wpedantic"
|
||||
static inline float riscv_compute_fp16_to_fp32(ggml_fp16_t h) {
|
||||
_Float16 hf;
|
||||
memcpy(&hf, &h, sizeof(ggml_fp16_t));
|
||||
|
|
@ -105,6 +108,7 @@ extern "C" {
|
|||
memcpy(&res, &hf, sizeof(ggml_fp16_t));
|
||||
return res;
|
||||
}
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#define GGML_CPU_COMPUTE_FP16_TO_FP32(x) riscv_compute_fp16_to_fp32(x)
|
||||
#define GGML_CPU_COMPUTE_FP32_TO_FP16(x) riscv_compute_fp32_to_fp16(x)
|
||||
|
|
@ -1209,7 +1213,7 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
|
|||
#define GGML_F16_VEC_MUL GGML_F32x4_MUL
|
||||
#define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE
|
||||
|
||||
#elif defined(__riscv_v_intrinsic)
|
||||
#elif defined(__riscv_v) && __riscv_v >= 1000000
|
||||
|
||||
// compatible with vlen >= 128
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue