kleidiai: add and integrate SVE 256-bit vector-length kernel (#18458)
* kleidiai: add and integrate SVE 256-bit vector-length kernel * updated for review comments
This commit is contained in:
parent
d77d7c5c06
commit
2d6c00a9b8
|
|
@ -401,8 +401,8 @@ if (GGML_CPU_ALL_VARIANTS)
|
|||
ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
|
||||
ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
|
||||
ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
|
||||
ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SME)
|
||||
ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
|
||||
ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
|
||||
ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME)
|
||||
elseif (APPLE)
|
||||
ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
|
||||
ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)
|
||||
|
|
|
|||
|
|
@ -561,9 +561,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|||
|
||||
# Fetch KleidiAI sources:
|
||||
include(FetchContent)
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.14.0")
|
||||
set(KLEIDIAI_COMMIT_TAG "v1.16.0")
|
||||
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc")
|
||||
set(KLEIDIAI_ARCHIVE_MD5 "0a9e9008adb6031f9e8cf70dff4a3321")
|
||||
|
||||
if (POLICY CMP0135)
|
||||
cmake_policy(SET CMP0135 NEW)
|
||||
|
|
@ -615,6 +615,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|||
string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
|
||||
string(FIND "${ARCH_FLAGS_TEMP}" "+sve" SVE_ENABLED)
|
||||
|
||||
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
|
||||
|
||||
|
|
@ -659,6 +660,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
|||
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
|
||||
endif()
|
||||
|
||||
if (NOT SVE_ENABLED MATCHES -1)
|
||||
list(APPEND GGML_KLEIDIAI_SOURCES
|
||||
${KLEIDIAI_SRC}/kai/kai_common_sve_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.c
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm_asm.S
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.c)
|
||||
endif()
|
||||
|
||||
set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
|
||||
list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@
|
|||
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
|
||||
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
|
||||
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
|
||||
|
||||
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
|
||||
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
|
||||
|
|
@ -69,9 +71,9 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
|||
|
||||
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
|
||||
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
|
||||
const void* lhs, const void* rhs, void* dst,
|
||||
size_t dst_stride_row, size_t dst_stride_col,
|
||||
float clamp_min, float clamp_max) {
|
||||
const void* lhs, const void* rhs, void* dst,
|
||||
size_t dst_stride_row, size_t dst_stride_col,
|
||||
float clamp_min, float clamp_max) {
|
||||
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
|
||||
}
|
||||
|
||||
|
|
@ -152,8 +154,8 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
|
|||
|
||||
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
|
||||
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
|
||||
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
|
||||
void* rhs_packed, size_t extra_bytes, const void* params) {
|
||||
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
|
||||
void* rhs_packed, size_t extra_bytes, const void* params) {
|
||||
Fn(num_groups, n, k, nr, kr, sr,
|
||||
static_cast<const int8_t*>(rhs),
|
||||
static_cast<const float*>(bias),
|
||||
|
|
@ -524,6 +526,61 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
},
|
||||
#endif
|
||||
#else
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
{
|
||||
/* SVE i8mm GEMM */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
|
||||
},
|
||||
/* .gemm_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
|
||||
},
|
||||
/* SVE dotprod GEMV */
|
||||
/* .kern_info = */ {
|
||||
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
|
||||
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
|
||||
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
|
||||
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
|
||||
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
|
||||
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
|
||||
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
|
||||
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
|
||||
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
|
||||
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
|
||||
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
|
||||
},
|
||||
/* .gemv_lhs_info = */ {
|
||||
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
|
||||
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
|
||||
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
|
||||
},
|
||||
/* .rhs_info = */ {
|
||||
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
|
||||
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
|
||||
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
/* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
|
||||
},
|
||||
/* .required_cpu = */ CPU_FEATURE_SVE | CPU_FEATURE_I8MM | CPU_FEATURE_DOTPROD,
|
||||
/* .lhs_type = */ GGML_TYPE_F32,
|
||||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
{
|
||||
/* i8mm GEMM */
|
||||
|
|
@ -578,7 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
|
|||
/* .rhs_type = */ GGML_TYPE_Q4_0,
|
||||
/* .op_type = */ GGML_TYPE_F32,
|
||||
},
|
||||
#endif
|
||||
#endif // __ARM_FEATURE_MATMUL_INT8
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
{
|
||||
/* DOTPROD GEMM */
|
||||
|
|
@ -811,26 +868,27 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
|||
ggml_kleidiai_kernels * kernel = nullptr;
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||
if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
|
||||
gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
|
||||
gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
|
||||
gemm_gemv_kernels[i].op_type == tensor->type) {
|
||||
kernel = &gemm_gemv_kernels[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!kernel) {
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
|
||||
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
|
||||
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
|
||||
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
|
||||
gemm_gemv_kernels_q8[i].op_type == tensor->type) {
|
||||
kernel = &gemm_gemv_kernels_q8[i];
|
||||
break;
|
||||
#if defined(__ARM_FEATURE_SME) || \
|
||||
defined(__ARM_FEATURE_DOTPROD) || \
|
||||
defined(__ARM_FEATURE_MATMUL_INT8) || \
|
||||
defined(__ARM_FEATURE_SVE)
|
||||
auto try_table = [&](auto & table) {
|
||||
for (size_t i = 0; i < NELEMS(table) - 1; ++i) {
|
||||
if ((cpu_features & table[i].required_cpu) == table[i].required_cpu &&
|
||||
table[i].lhs_type == tensor->src[1]->type &&
|
||||
table[i].rhs_type == tensor->src[0]->type &&
|
||||
table[i].op_type == tensor->type) {
|
||||
kernel = &table[i];
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if (tensor->src[0]->type == GGML_TYPE_Q8_0) {
|
||||
try_table(gemm_gemv_kernels_q8);
|
||||
} else {
|
||||
try_table(gemm_gemv_kernels);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(gemm_gemv_kernels);
|
||||
|
|
@ -845,7 +903,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
|
|||
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
|
||||
ggml_kleidiai_kernels * kernels = nullptr;
|
||||
|
||||
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
#if defined(__ARM_FEATURE_SME) || \
|
||||
defined(__ARM_FEATURE_DOTPROD) || \
|
||||
defined(__ARM_FEATURE_MATMUL_INT8) || \
|
||||
defined(__ARM_FEATURE_SVE)
|
||||
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
|
||||
if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
|
||||
kernels = &gemm_gemv_kernels[i];
|
||||
|
|
|
|||
|
|
@ -46,13 +46,20 @@ struct ggml_kleidiai_context {
|
|||
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
|
||||
|
||||
static const char* cpu_feature_to_string(cpu_feature f) {
|
||||
switch (f) {
|
||||
case CPU_FEATURE_NONE: return "NONE";
|
||||
case CPU_FEATURE_DOTPROD: return "DOTPROD";
|
||||
case CPU_FEATURE_I8MM: return "I8MM";
|
||||
case CPU_FEATURE_SVE: return "SVE";
|
||||
case CPU_FEATURE_SME: return "SME";
|
||||
default: return "UNKNOWN";
|
||||
if (f == CPU_FEATURE_NONE) {
|
||||
return "NONE";
|
||||
} else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
|
||||
return "SME";
|
||||
} else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
|
||||
return "SVE";
|
||||
}
|
||||
else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
|
||||
return "I8MM";
|
||||
} else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
|
||||
return "DOTPROD";
|
||||
}
|
||||
else {
|
||||
return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -68,7 +75,7 @@ static void init_kleidiai_context(void) {
|
|||
|
||||
ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
|
||||
(ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
||||
((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
|
||||
|
||||
if (env_var) {
|
||||
sme_enabled = atoi(env_var);
|
||||
|
|
|
|||
Loading…
Reference in New Issue