kleidiai : add sme fp16 compute path for q4_0 gemm on aarch64 (#20043)

This commit is contained in:
Charles Xu 2026-03-03 10:40:26 +01:00 committed by GitHub
parent 24350fdf9b
commit 137435ff15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 21 deletions

View File

@ -566,9 +566,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# Fetch KleidiAI sources:
include(FetchContent)
set(KLEIDIAI_COMMIT_TAG "v1.16.0")
set(KLEIDIAI_COMMIT_TAG "v1.22.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "0a9e9008adb6031f9e8cf70dff4a3321")
set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1")
if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW)
@ -608,6 +608,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
@ -648,7 +649,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
if (NOT SME_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
@ -656,10 +656,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_f16pmrx2_f32_neon.c
${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2+sme2+fp16")
endif()
if (NOT SVE_ENABLED MATCHES -1)

View File

@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT
//
@ -9,7 +9,6 @@
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h"
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
@ -20,6 +19,7 @@
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
#include "kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.h"
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
@ -31,6 +31,7 @@
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
#include "kai_lhs_pack_f16pmrx2_f32_neon.h"
#include "kai_common.h"
@ -309,24 +310,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
{
/* SME GEMM */
/* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>,
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa,
/* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,
/* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>,
/* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>,
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_f16pmrx2_f32_neon,
/* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_pack_f16pmrx2_f32_neon>,
/* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_pack_f16pmrx2_f32_neon>,
/* .pack_func_ex = */ &lhs_pack_void_fn10<kai_run_lhs_pack_f16pmrx2_f32_neon>,
},
/* SME GEMV */
/* .kern_info = */ {