From 4a99793eac83a8edccae519a085bf0e3a9e4206c Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Fri, 5 Dec 2025 16:21:06 +0500 Subject: [PATCH 1/3] ggml-cpu: add repack GEMM and GEMV for floating-point --- ggml/src/ggml-cpu/arch-fallback.h | 150 ++++- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 274 +++++++++ ggml/src/ggml-cpu/repack.cpp | 706 ++++++++++++++++++------ ggml/src/ggml-cpu/repack.h | 61 ++ 4 files changed, 992 insertions(+), 199 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 427c1146e4..c61fe80bdb 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -36,6 +36,8 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 +#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -46,8 +48,14 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 -#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 +#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 +#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 +#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 +#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 +#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 +#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 +#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -58,36 +66,72 @@ # define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 -#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 +#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 +#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 +#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 +#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 +#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 +#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 +#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define 
ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 +#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 +#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 +#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 +#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 +#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 +#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 +#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 +#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 +#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 +#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 +#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 +#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 +#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 +#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 +#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 +#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 -#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 +#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 +#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 +#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 +#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 +#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 +#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 +#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 -#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 +#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 +#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 +#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 +#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 +#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 +#define ggml_gemm_f32_7x1x64_f32_generic 
ggml_gemm_f32_7x1x64_f32 +#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 #elif defined(__POWERPC__) || defined(__powerpc__) // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679 // quants.c @@ -100,6 +144,8 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 +#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -110,8 +156,14 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 -#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 +#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 +#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 +#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 +#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 +#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 +#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 +#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -122,8 +174,14 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 -#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 +#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 +#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 +#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 +#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 +#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 +#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 +#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 #elif defined(__loongarch64) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -136,6 +194,8 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 +#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -146,8 +206,14 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define 
ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 -#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 +#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 +#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 +#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 +#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 +#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 +#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 +#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -158,8 +224,14 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 -#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 +#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 +#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 +#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 +#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 +#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 +#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 +#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 #elif defined(__riscv) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -180,6 +252,8 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 +#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -220,6 +294,8 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 +#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -230,8 +306,14 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 -#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 +#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 +#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 +#define ggml_gemv_f16_1x128_f16_generic 
ggml_gemv_f16_1x128_f16 +#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 +#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 +#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 +#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -242,8 +324,14 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 -#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 +#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 +#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 +#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 +#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 +#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 +#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 +#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 #elif defined(__wasm__) // quants.c #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 @@ -264,6 +352,8 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 +#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -274,8 +364,14 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 -#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 +#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 +#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 +#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 +#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 +#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 +#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 +#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 +#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -286,6 +382,12 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 -#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 +#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 +#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 
+#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 +#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 +#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 +#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 +#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 +#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 #endif diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 2a35ff9ad8..358e43f8d2 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -340,3 +340,277 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } + +template +static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int nb = n / 1; + + assert (nr == 1); + assert(n % 1 == 0); + assert(nc % ncols_interleaved == 0); + + const _Float16 * a_ptr = (const _Float16 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_f16 * b_ptr = (const block_f16 *) vx + (x * nb); + + // Accumulators + vfloat32m4_t sumf_0 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vfloat16m2_t b_0 = __riscv_vle16_v_f16m2((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved); + + sumf_0 = __riscv_vfwmacc_vf_f32m4(sumf_0, *(const _Float16*)(&a_ptr[l]), b_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m4(&s[x * ncols_interleaved], sumf_0, ncols_interleaved); + } + + return; +} + +void ggml_gemv_f16_1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemv_f16_1xM_f16<16>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemv_f16_1x16_f16_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f16_1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemv_f16_1xM_f16<32>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemv_f16_1x32_f16_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f16_1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemv_f16_1xM_f16<64>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemv_f16_1x64_f16_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f16_1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemv_f16_1xM_f16<128>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemv_f16_1x128_f16_generic(n, s, bs, vx, vy, nr, nc); +} + +template +static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int nb = n / 1; + + assert (nr == 1); + assert(n % 1 == 0); + assert(nc % ncols_interleaved == 0); + + const float * a_ptr = (const float *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_f32 * b_ptr = (const block_f32 *) vx + (x * nb); + + // Accumulators + vfloat32m4_t sumf_0 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; 
l++) { + vfloat32m4_t b_0 = __riscv_vle32_v_f32m4((const float *)&b_ptr[l].d[0], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vf_f32m4(sumf_0, *(const float*)(&a_ptr[l]), b_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m4(&s[x * ncols_interleaved], sumf_0, ncols_interleaved); + } + + return; +} + +void ggml_gemv_f32_1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemv_f32_1xM_f32<16>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemv_f32_1x16_f32_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f32_1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemv_f32_1xM_f32<32>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemv_f32_1x32_f32_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f32_1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemv_f32_1xM_f32<64>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemv_f32_1x64_f32_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f32_1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemv_f32_1xM_f32<128>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemv_f32_1x128_f32_generic(n, s, bs, vx, vy, nr, nc); +} + +template +static inline void ggml_gemm_f16_7x1xM_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int nb = n / 1; + + assert (nr % 7 == 0); + assert(n % 1 == 0); + assert(nc % ncols_interleaved == 0); + + for (int y = 0; y < nr / 7; y++) { + const block_f16_7x1 * a_ptr = (const block_f16_7x1*) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_f16 * b_ptr = (const block_f16 *) vx + (x * nb); + + // Accumulators + vfloat32m4_t sumf_0 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_1 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_2 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_3 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_4 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_5 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_6 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vfloat16m2_t b_0 = __riscv_vle16_v_f16m2((const _Float16 *)&b_ptr[l].d[0], ncols_interleaved); + + sumf_0 = __riscv_vfwmacc_vf_f32m4(sumf_0, *(const _Float16*)&a_ptr[l].d[0], b_0, ncols_interleaved); + sumf_1 = __riscv_vfwmacc_vf_f32m4(sumf_1, *(const _Float16*)&a_ptr[l].d[1], b_0, ncols_interleaved); + sumf_2 = __riscv_vfwmacc_vf_f32m4(sumf_2, *(const _Float16*)&a_ptr[l].d[2], b_0, ncols_interleaved); + sumf_3 = __riscv_vfwmacc_vf_f32m4(sumf_3, *(const _Float16*)&a_ptr[l].d[3], b_0, ncols_interleaved); + sumf_4 = __riscv_vfwmacc_vf_f32m4(sumf_4, *(const _Float16*)&a_ptr[l].d[4], b_0, ncols_interleaved); + sumf_5 = __riscv_vfwmacc_vf_f32m4(sumf_5, *(const _Float16*)&a_ptr[l].d[5], b_0, ncols_interleaved); + sumf_6 = __riscv_vfwmacc_vf_f32m4(sumf_6, *(const _Float16*)&a_ptr[l].d[6], b_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m4(&s[(y * 7 + 0) * bs + x * 
ncols_interleaved], sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 1) * bs + x * ncols_interleaved], sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 2) * bs + x * ncols_interleaved], sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 3) * bs + x * ncols_interleaved], sumf_3, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 4) * bs + x * ncols_interleaved], sumf_4, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 5) * bs + x * ncols_interleaved], sumf_5, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 6) * bs + x * ncols_interleaved], sumf_6, ncols_interleaved); + } + } + return; +} + +void ggml_gemm_f16_7x1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemm_f16_7x1xM_f16<16>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemm_f16_7x1x16_f16_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_f16_7x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemm_f16_7x1xM_f16<32>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemm_f16_7x1x32_f16_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_f16_7x1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemm_f16_7x1xM_f16<64>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemm_f16_7x1x64_f16_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_f16_7x1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemm_f16_7x1xM_f16<128>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemm_f16_7x1x128_f16_generic(n, s, bs, vx, vy, nr, nc); +} + +template +static inline void ggml_gemm_f32_7x1xM_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int nb = n / 1; + + assert (nr % 7 == 0); + assert(n % 1 == 0); + assert(nc % ncols_interleaved == 0); + + for (int y = 0; y < nr / 7; y++) { + const block_f32_7x1 * a_ptr = (const block_f32_7x1*) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_f32 * b_ptr = (const block_f32 *) vx + (x * nb); + + // Accumulators + vfloat32m4_t sumf_0 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_1 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_2 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_3 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_4 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_5 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + vfloat32m4_t sumf_6 = __riscv_vfmv_v_f_f32m4(0.0f, ncols_interleaved); + + for (int l = 0; l < nb; l++) { + vfloat32m4_t b_0 = __riscv_vle32_v_f32m4((const float*)&b_ptr[l].d[0], ncols_interleaved); + + sumf_0 = __riscv_vfmacc_vf_f32m4(sumf_0, *(const float*)&a_ptr[l].d[0], b_0, ncols_interleaved); + sumf_1 = __riscv_vfmacc_vf_f32m4(sumf_1, *(const float*)&a_ptr[l].d[1], b_0, ncols_interleaved); + sumf_2 = __riscv_vfmacc_vf_f32m4(sumf_2, *(const float*)&a_ptr[l].d[2], b_0, ncols_interleaved); + sumf_3 = __riscv_vfmacc_vf_f32m4(sumf_3, *(const float*)&a_ptr[l].d[3], b_0, ncols_interleaved); + 
sumf_4 = __riscv_vfmacc_vf_f32m4(sumf_4, *(const float*)&a_ptr[l].d[4], b_0, ncols_interleaved); + sumf_5 = __riscv_vfmacc_vf_f32m4(sumf_5, *(const float*)&a_ptr[l].d[5], b_0, ncols_interleaved); + sumf_6 = __riscv_vfmacc_vf_f32m4(sumf_6, *(const float*)&a_ptr[l].d[6], b_0, ncols_interleaved); + } + + __riscv_vse32_v_f32m4(&s[(y * 7 + 0) * bs + x * ncols_interleaved], sumf_0, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 1) * bs + x * ncols_interleaved], sumf_1, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 2) * bs + x * ncols_interleaved], sumf_2, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 3) * bs + x * ncols_interleaved], sumf_3, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 4) * bs + x * ncols_interleaved], sumf_4, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 5) * bs + x * ncols_interleaved], sumf_5, ncols_interleaved); + __riscv_vse32_v_f32m4(&s[(y * 7 + 6) * bs + x * ncols_interleaved], sumf_6, ncols_interleaved); + } + } + return; +} + +void ggml_gemm_f32_7x1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemm_f32_7x1xM_f32<16>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemm_f32_7x1x16_f32_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_f32_7x1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemm_f32_7x1xM_f32<32>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemm_f32_7x1x32_f32_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_f32_7x1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemm_f32_7x1xM_f32<64>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemm_f32_7x1x64_f32_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_f32_7x1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined __riscv_v_intrinsic + ggml_gemm_f32_7x1xM_f32<128>(n, s, bs, vx, vy, nr, nc); + return; +#endif + ggml_gemm_f32_7x1x128_f32_generic(n, s, bs, vx, vy, nr, nc); +} diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 24e8ab4618..cee953d23b 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -31,6 +31,40 @@ static inline int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } +// Helper template functions for `fp16` and `fp32`. 
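+//
+// A minimal sketch of the layout these helpers produce, inferred from the
+// templates below (the concrete shape shipped by this patch is 7 rows with an
+// interleave size of 1): `nrows_interleaved` source rows are regrouped so that
+// output block i holds `interleave_size` consecutive elements from each row.
+// For example, with nrows_interleaved = 2 and interleave_size = 1:
+//
+//   row 0: a0 a1 a2 ...     row 1: b0 b1 b2 ...
+//   repacked blocks: [a0 b0] [a1 b1] [a2 b2] ...
+//
+// which lets the GEMM/GEMV kernels read one contiguous block per step of the
+// inner loop.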
+ +template +static inline void ggml_repack_mat_f16_NxK_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % interleave_size == 0); + const int nb = k / interleave_size; + + block_f16 * GGML_RESTRICT y = (block_f16 *) vy; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + for (int l = 0; l < interleave_size; l++) { + y[i].d[j * interleave_size + l] = GGML_CPU_FP32_TO_FP16(x[j * k + i * interleave_size + l]); + } + } + } +} + +template +static inline void ggml_repack_mat_f32_NxK_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % interleave_size == 0); + const int nb = k / interleave_size; + + block_f32 * GGML_RESTRICT y = (block_f32 *) vy; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + for (int l = 0; l < interleave_size; l++) { + y[i].d[j * interleave_size + l] = x[j * k + i * interleave_size + l]; + } + } + } +} + // Functions to create the interleaved data layout formats // interleave 4 block_q4_0s in blocks of blck_size_interleave @@ -46,6 +80,7 @@ static inline int nearest_int(float fval) { // operations durin unpacking) // + extern "C" { void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { @@ -227,35 +262,189 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG } } +void ggml_repack_mat_f16_7x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + ggml_repack_mat_f16_NxK_generic<7, 1>(x, vy, k); +} + +void ggml_repack_mat_f32_7x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + ggml_repack_mat_f32_NxK_generic<7, 1>(x, vy, k); +} + } // extern "C" -template -void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row); +template +void ggml_repack_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row); -template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_repack_mat_t<4, 4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_repack_mat_t<4, 8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_repack_mat_t<4, 4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_repack_mat_t<4, 8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } +template <> void 
ggml_repack_mat_t<7, 1, GGML_TYPE_F16>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 7); + UNUSED(nrow); + ggml_repack_mat_f16_7x1(x, vy, n_per_row); +} + +template <> void ggml_repack_mat_t<7, 1, GGML_TYPE_F32>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 7); + UNUSED(nrow); + ggml_repack_mat_f32_7x1(x, vy, n_per_row); +} + +template +static inline void ggml_gemv_f16_KxM_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int nb = n / interleave_size; + + assert(nr == 1); + assert(n % interleave_size == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + + const ggml_half * a_ptr = (const ggml_half *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_f16 * b_ptr = + (const block_f16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { sumf[j] = 0.0f; } + for (int l = 0; l < nb; l++) { + for (int j = 0; j < ncols_interleaved; j++) { + for (int k = 0; k < interleave_size; k++) { + sumf[j] += GGML_FP16_TO_FP32(b_ptr[l].d[j * interleave_size + k]) * GGML_FP16_TO_FP32(a_ptr[l + k]); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { s[x * ncols_interleaved + j] = sumf[j]; } + } +} + +template +static inline void ggml_gemv_f32_KxM_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int nb = n / interleave_size; + + assert(nr == 1); + assert(n % interleave_size == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + + const float * a_ptr = (const float *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_f32 * b_ptr = + (const block_f32 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { sumf[j] = 0.0f; } + for (int l = 0; l < nb; l++) { + for (int j = 0; j < ncols_interleaved; j++) { + for (int k = 0; k < interleave_size; k++) { + sumf[j] += b_ptr[l].d[j * interleave_size + k] * a_ptr[l + k]; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { s[x * ncols_interleaved + j] = sumf[j]; } + } +} + +template +static inline void ggml_gemm_f16_NxKxM_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int nb = n / interleave_size; + + assert (nr % nrows == 0); + assert(n % interleave_size == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[nrows][ncols_interleaved]; + + for (int y = 0; y < nr / nrows; y++) { + const block_f16 * a_ptr = + (const block_f16 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_f16 * b_ptr = + (const block_f16 *) vx + (x * nb); + + for (int m = 0; m < nrows; m++) { + for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0f; } + } + for (int l = 0; l < nb; l++) { + for (int m = 0; m < nrows; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + for (int k = 0; k < interleave_size; k++) { + sumf[m][j] += b_ptr[l].d[j * interleave_size + k] * a_ptr[l].d[m * interleave_size + k]; + } + } + } + } + for (int m = 0; m < nrows; m++) { + for (int j = 0; j < ncols_interleaved; j++) + { s[(y * nrows + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; } + } + } + } +} + +template +static inline void ggml_gemm_f32_NxKxM_f32_generic(int n, float 
* GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int nb = n / interleave_size; + + assert (nr % nrows == 0); + assert(n % interleave_size == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[nrows][ncols_interleaved]; + + for (int y = 0; y < nr / nrows; y++) { + const block_f32 * a_ptr = + (const block_f32 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_f32 * b_ptr = + (const block_f32 *) vx + (x * nb); + + for (int m = 0; m < nrows; m++) { + for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0f; } + } + for (int l = 0; l < nb; l++) { + for (int m = 0; m < nrows; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + for (int k = 0; k < interleave_size; k++) { + sumf[m][j] += b_ptr[l].d[j * interleave_size + k] * a_ptr[l].d[m * interleave_size + k]; + } + } + } + } + for (int m = 0; m < nrows; m++) { + for (int j = 0; j < ncols_interleaved; j++) + { s[(y * nrows + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; } + } + } + } +} + extern "C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -794,51 +983,107 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; - assert(nr == 1); assert(n % qk == 0); assert(nc % ncols_interleaved == 0); UNUSED(bs); UNUSED(nr); - float sumf[4]; - int sumi; + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; + + const int qh_shift = (k / 4) * 2; for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; + sumi1 = 0; + sumi2 = 0; + sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * 
ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * 8 + i) % 32; + const int qh_chunk = qh_idx / 8; + const int qh_pos = qh_idx % 8; + const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; } } } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } } } -void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; + +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; const int nb = n / qk; const int ncols_interleaved = 8; const int blocklen = 8; - assert(nr == 1); assert(n % qk == 0); assert(nc % ncols_interleaved == 0); @@ -846,15 +1091,35 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs UNUSED(nr); float sumf[8]; - int sumi; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { + + + for (int k = 0; k < 16; k++) { + // k = 0.. 
7 weights 0-63 low, 64-127 high + // k = 8..15 weights 128-191 low, 192-255 high + const int base_l = (k / 8) * 128 + (k % 8) * 8; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + // qh_half: offset to the correct 32-byte half (0 or 32) + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + for (int j = 0; j < ncols_interleaved; j++) { sumi = 0; for (int i = 0; i < blocklen; ++i) { @@ -1690,108 +1955,36 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -void ggml_gemm_q8_0_4x4_q8_0_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; - } - sumf[m][j] += - sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } +void ggml_gemm_f16_7x1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_f16_NxKxM_f16_generic<7, 1, 16>(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q8_0_4x8_q8_0_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; +void ggml_gemm_f16_7x1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_f16_NxKxM_f16_generic<7, 1, 32>(n, s, bs, vx, vy, nr, nc); +} - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); +void ggml_gemm_f16_7x1x64_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_f16_NxKxM_f16_generic<7, 1, 64>(n, s, bs, vx, vy, nr, nc); +} - float sumf[4][4]; - int sumi; +void ggml_gemm_f16_7x1x128_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_f16_NxKxM_f16_generic<7, 1, 128>(n, s, bs, vx, vy, nr, nc); +} - for (int y = 0; y < nr / 4; y++) { - 
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; - } - sumf[m][j] += - sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } +void ggml_gemm_f32_7x1x16_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_f32_NxKxM_f32_generic<7, 1, 16>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_f32_7x1x32_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_f32_NxKxM_f32_generic<7, 1, 32>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_f32_7x1x64_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_f32_NxKxM_f32_generic<7, 1, 64>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_f32_7x1x128_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_f32_NxKxM_f32_generic<7, 1, 128>(n, s, bs, vx, vy, nr, nc); } } // extern "C" @@ -2477,6 +2670,78 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } +template +static int repack_f16_to_f16_N_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_F16); + + const ggml_half * src = (const ggml_half *)data; + block_f16 * dst = ( block_f16 *)t->data; + + ggml_half dst_tmp[nrows_interleaved * interleave_size]; + + int nrow = ggml_nrows(t); + int row_size = t->ne[0]; + int nblocks = row_size / interleave_size; + + GGML_ASSERT(data_size == nrow * nblocks * interleave_size * sizeof(ggml_half)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % interleave_size != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int i = 0; i < nblocks; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + for (int k = 0; k < interleave_size; k++) { + dst_tmp[j * interleave_size + k] = src[(j + b) * row_size + i * interleave_size + k]; + } + } + block_f16 out; + memcpy(&out.d, dst_tmp, sizeof(ggml_half) * nrows_interleaved * interleave_size); + *dst = out; + dst++; + } + } + + return 0; +} + +template +static int repack_f32_to_f32_N_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_F32); + + const float * src = (const float *)data; + block_f32 * dst = ( block_f32 *)t->data; + + float dst_tmp[nrows_interleaved * interleave_size]; + + int nrow = ggml_nrows(t); + int row_size = t->ne[0]; + int nblocks = row_size / interleave_size; + + GGML_ASSERT(data_size == nrow * nblocks * interleave_size * 
sizeof(float)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % interleave_size != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int i = 0; i < nblocks; i++) { + for (int j = 0; j < nrows_interleaved; j++) { + for (int k = 0; k < interleave_size; k++) { + dst_tmp[j * interleave_size + k] = src[(j + b) * row_size + i * interleave_size + k]; + } + } + block_f32 out; + memcpy(&out.d, dst_tmp, sizeof(float) * nrows_interleaved * interleave_size); + *dst = out; + dst++; + } + } + + return 0; +} + namespace ggml::cpu::repack { // repack template @@ -2528,12 +2793,30 @@ template <> int repack(struct ggml_tensor * t, const void * return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size); +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_f16_to_f16_N_bl<16, 1>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_f16_to_f16_N_bl<32, 1>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_f16_to_f16_N_bl<64, 1>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_f16_to_f16_N_bl<128, 1>(t, data, data_size); } -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_f32_to_f32_N_bl<16, 1>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_f32_to_f32_N_bl<32, 1>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_f32_to_f32_N_bl<64, 1>(t, data, data_size); +} +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_f32_to_f32_N_bl<128, 1>(t, data, data_size); } // gemv @@ -2587,23 +2870,47 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_f16_1x16_f16(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_f16_1x32_f16(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_f16_1x64_f16(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_f16_1x128_f16(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_f32_1x16_f32(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const 
void * vy, int nr, int nc) { + ggml_gemv_f32_1x32_f32(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_f32_1x64_f32(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_f32_1x128_f32(n, s, bs, vx, vy, nr, nc); } // gemm -template +template void gemm(int, float *, size_t, const void *, const void *, int, int); -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -2626,11 +2933,11 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2638,20 +2945,44 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_f16_7x1x16_f16(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_f16_7x1x32_f16(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_f16_7x1x64_f16(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_f16_7x1x128_f16(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_f32_7x1x16_f32(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { 
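+    // Thin dispatch wrapper: this specialization covers the f32 path with
+    // 7-row interleaving and 32-wide column blocks (inferred from the kernel
+    // it forwards to).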
+ ggml_gemm_f32_7x1x32_f32(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_f32_7x1x64_f32(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_f32_7x1x128_f32(n, s, bs, vx, vy, nr, nc); } class tensor_traits_base : public ggml::cpu::tensor_traits { @@ -2659,7 +2990,7 @@ class tensor_traits_base : public ggml::cpu::tensor_traits { virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; }; -template class tensor_traits : public tensor_traits_base { +template class tensor_traits : public tensor_traits_base { bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { // not realy a GGML_TYPE_Q8_0 but same size. @@ -2742,12 +3073,12 @@ template wdata + params->wsize); // If there are more than three rows in src1, use gemm; otherwise, use gemv. - if (nrows > 3) { - gemm(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0, + if (nrows > (NB_ROWS - 1)) { + gemm(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0, src0_ptr + src0_start * nb01, src1_ptr, - nrows - (nrows % 4), ncols); + nrows - (nrows % NB_ROWS), ncols); } - for (int iter = nrows - (nrows % 4); iter < nrows; iter++) { + for (int iter = nrows - (nrows % NB_ROWS); iter < nrows; iter++) { gemv(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start, ne01, src0_ptr + src0_start * nb01, src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols); @@ -2801,12 +3132,12 @@ template data + i12 * nb12; char * wdata_ptr = wdata + i12 * nbw2; - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - ggml_quantize_mat_t((float *) (data_ptr + i11 * nb11), - (void *) (wdata_ptr + i11 * nbw1), 4, ne10); + for (int64_t i11 = ith * NB_ROWS; i11 < ne11 - ne11 % NB_ROWS; i11 += nth * NB_ROWS) { + ggml_repack_mat_t((float *) (data_ptr + i11 * nb11), + (void *) (wdata_ptr + i11 * nbw1), NB_ROWS, ne10); } - const int64_t i11_processed = ne11 - ne11 % 4; + const int64_t i11_processed = ne11 - ne11 % NB_ROWS; for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10); } @@ -2818,7 +3149,7 @@ template src[0]); - int nth_scaled = nth * 4; + int nth_scaled = nth * NB_ROWS; int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled; int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0; @@ -3031,13 +3362,13 @@ template q4_0_4x4_q8_0; - static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0; - static const ggml::cpu::repack::tensor_traits q4_0_8x8_q8_0; + static const ggml::cpu::repack::tensor_traits q4_0_4x4_q8_0; + static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0; + static const ggml::cpu::repack::tensor_traits q4_0_8x8_q8_0; // instance for Q4_K - static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; - static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; + static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; + static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; // instance for Q5_K static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; @@ -3046,15 +3377,27 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K; // instance for Q2 - static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K; + static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K; // instance 
for IQ4 - static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; - static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; + static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; - // instance for Q8_0 - static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; - static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; + // instance for F16 +#if defined __riscv_zvfh + static const ggml::cpu::repack::tensor_traits f16_7x16x1_f16; + static const ggml::cpu::repack::tensor_traits f16_7x32x1_f16; + static const ggml::cpu::repack::tensor_traits f16_7x64x1_f16; + static const ggml::cpu::repack::tensor_traits f16_7x128x1_f16; +#endif + + // instance for F32 +#if defined __riscv_zvfh + static const ggml::cpu::repack::tensor_traits f32_7x16x1_f32; + static const ggml::cpu::repack::tensor_traits f32_7x32x1_f32; + static const ggml::cpu::repack::tensor_traits f32_7x64x1_f32; + static const ggml::cpu::repack::tensor_traits f32_7x128x1_f32; +#endif if (cur->type == GGML_TYPE_Q4_0) { if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) @@ -3118,16 +3461,29 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &iq4_nl_4x4_q8_0; } } - } else if (cur->type == GGML_TYPE_Q8_0) { - if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { - if (cur->ne[1] % 4 == 0) { - return &q8_0_4x8_q8_0; + } else if (cur->type == GGML_TYPE_F16) { + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { if (cur->ne[1] % 16 == 0) { return &f16_7x16x1_f16; } break; } + case 256: { if (cur->ne[1] % 32 == 0) { return &f16_7x32x1_f16; } break; } + case 512: { if (cur->ne[1] % 64 == 0) { return &f16_7x64x1_f16; } break; } + case 1024: { if (cur->ne[1] % 128 == 0) { return &f16_7x128x1_f16; } break; } + default: return nullptr; } + #endif } - if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { - if (cur->ne[1] % 4 == 0) { - return &q8_0_4x4_q8_0; + } else if (cur->type == GGML_TYPE_F32) { + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { if (cur->ne[1] % 16 == 0) { return &f32_7x16x1_f32; } break; } + case 256: { if (cur->ne[1] % 32 == 0) { return &f32_7x32x1_f32; } break; } + case 512: { if (cur->ne[1] % 64 == 0) { return &f32_7x64x1_f32; } break; } + case 1024: { if (cur->ne[1] % 128 == 0) { return &f32_7x128x1_f32; } break; } + default: return nullptr; } + #endif } } diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 855320eeeb..b2044ff7b5 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -97,6 +97,23 @@ struct block_iq4_nlx8 { static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding"); +template +struct block_f16 { + ggml_half d[N * K]; +}; + +using block_f16_32x1 = block_f16<32, 1>; +using block_f16_7x1 = block_f16<7, 1>; +using block_f16_4x1 = block_f16<4, 1>; + +template +struct block_f32 { + float d[N * K]; +}; + +using block_f32_32x1 = block_f32<32, 1>; +using block_f32_7x1 = block_f32<7, 1>; + #if defined(__cplusplus) extern "C" { #endif @@ -160,6 +177,50 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void 
ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +// FP16 +void ggml_repack_mat_f16_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_repack_mat_f16_7x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_f16_1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f16_1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f16_1x64_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f16_1x128_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_4x1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_7x1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_7x1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_7x1x64_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_7x1x128_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_repack_mat_f16_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_repack_mat_f16_7x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_f16_1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f16_1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f16_1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f16_1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_4x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_7x1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_7x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_7x1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f16_7x1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); + +// FP32 +void ggml_repack_mat_f32_7x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_f32_1x16_f32_generic(int n, float * 
GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f32_1x32_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f32_1x64_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f32_1x128_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f32_7x1x16_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f32_7x1x32_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f32_7x1x64_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f32_7x1x128_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_repack_mat_f32_7x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_f32_1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f32_1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f32_1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_f32_1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f32_7x1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f32_7x1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f32_7x1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_f32_7x1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); + #if defined(__cplusplus) } // extern "C" #endif From 0d9caadc2a3531a0a0a1d0e757c29e6d3ccf996a Mon Sep 17 00:00:00 2001 From: Taimur Ahmad Date: Tue, 23 Dec 2025 15:13:09 +0500 Subject: [PATCH 2/3] ggml-cpu: add repack GEMM and GEMV for floating-point (#4) --- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 100 +++++++++-------- ggml/src/ggml-cpu/repack.cpp | 140 +++++++++++------------- ggml/src/ggml-cpu/repack.h | 6 +- 3 files changed, 115 insertions(+), 131 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 358e43f8d2..c1541d1c03 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -343,6 +343,8 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo template static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + GGML_UNUSED(bs); + const int nb = n 
/ 1; assert (nr == 1); @@ -369,39 +371,41 @@ static inline void ggml_gemv_f16_1xM_f16(int n, float * GGML_RESTRICT s, size_t } void ggml_gemv_f16_1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemv_f16_1xM_f16<16>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemv_f16_1x16_f16_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemv_f16_1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemv_f16_1xM_f16<32>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemv_f16_1x32_f16_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemv_f16_1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemv_f16_1xM_f16<64>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemv_f16_1x64_f16_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemv_f16_1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemv_f16_1xM_f16<128>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemv_f16_1x128_f16_generic(n, s, bs, vx, vy, nr, nc); +#endif } template static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + GGML_UNUSED(bs); + const int nb = n / 1; assert (nr == 1); @@ -428,35 +432,35 @@ static inline void ggml_gemv_f32_1xM_f32(int n, float * GGML_RESTRICT s, size_t } void ggml_gemv_f32_1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemv_f32_1xM_f32<16>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemv_f32_1x16_f32_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemv_f32_1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemv_f32_1xM_f32<32>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemv_f32_1x32_f32_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemv_f32_1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemv_f32_1xM_f32<64>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemv_f32_1x64_f32_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemv_f32_1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemv_f32_1xM_f32<128>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemv_f32_1x128_f32_generic(n, s, bs, vx, vy, nr, nc); +#endif } template @@ -506,35 +510,35 @@ static inline void ggml_gemm_f16_7x1xM_f16(int n, float * GGML_RESTRICT s, size_ } void ggml_gemm_f16_7x1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * 
GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemm_f16_7x1xM_f16<16>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemm_f16_7x1x16_f16_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemm_f16_7x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemm_f16_7x1xM_f16<32>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemm_f16_7x1x32_f16_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemm_f16_7x1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemm_f16_7x1xM_f16<64>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemm_f16_7x1x64_f16_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemm_f16_7x1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemm_f16_7x1xM_f16<128>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemm_f16_7x1x128_f16_generic(n, s, bs, vx, vy, nr, nc); +#endif } template @@ -584,33 +588,33 @@ static inline void ggml_gemm_f32_7x1xM_f32(int n, float * GGML_RESTRICT s, size_ } void ggml_gemm_f32_7x1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemm_f32_7x1xM_f32<16>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemm_f32_7x1x16_f32_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemm_f32_7x1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemm_f32_7x1xM_f32<32>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemm_f32_7x1x32_f32_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemm_f32_7x1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemm_f32_7x1xM_f32<64>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemm_f32_7x1x64_f32_generic(n, s, bs, vx, vy, nr, nc); +#endif } void ggml_gemm_f32_7x1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_zvfh ggml_gemm_f32_7x1xM_f32<128>(n, s, bs, vx, vy, nr, nc); - return; -#endif +#else ggml_gemm_f32_7x1x128_f32_generic(n, s, bs, vx, vy, nr, nc); +#endif } diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index cee953d23b..a40f2ce801 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -31,7 +31,7 @@ static inline int nearest_int(float fval) { return (i & 0x007fffff) - 0x00400000; } -// Helper template functions for `fp16` and `fp32`. +// Helper functions for `fp16` and `fp32`. 
template static inline void ggml_repack_mat_f16_NxK_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { @@ -262,6 +262,7 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG } } +#if defined __riscv_zvfh void ggml_repack_mat_f16_7x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { ggml_repack_mat_f16_NxK_generic<7, 1>(x, vy, k); } @@ -269,6 +270,7 @@ void ggml_repack_mat_f16_7x1_generic(const float * GGML_RESTRICT x, void * GGML_ void ggml_repack_mat_f32_7x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { ggml_repack_mat_f32_NxK_generic<7, 1>(x, vy, k); } +#endif } // extern "C" @@ -299,6 +301,7 @@ template <> void ggml_repack_mat_t<4, 8, GGML_TYPE_Q8_K>(const float * GGML_REST ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } +#if defined __riscv_zvfh template <> void ggml_repack_mat_t<7, 1, GGML_TYPE_F16>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 7); UNUSED(nrow); @@ -310,6 +313,7 @@ template <> void ggml_repack_mat_t<7, 1, GGML_TYPE_F32>(const float * GGML_RESTR UNUSED(nrow); ggml_repack_mat_f32_7x1(x, vy, n_per_row); } +#endif template static inline void ggml_gemv_f16_KxM_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -333,7 +337,7 @@ static inline void ggml_gemv_f16_KxM_f16_generic(int n, float * GGML_RESTRICT s, for (int l = 0; l < nb; l++) { for (int j = 0; j < ncols_interleaved; j++) { for (int k = 0; k < interleave_size; k++) { - sumf[j] += GGML_FP16_TO_FP32(b_ptr[l].d[j * interleave_size + k]) * GGML_FP16_TO_FP32(a_ptr[l + k]); + sumf[j] += GGML_FP16_TO_FP32(b_ptr[l].d[j * interleave_size + k]) * GGML_FP16_TO_FP32(a_ptr[l * interleave_size + k]); } } } @@ -363,7 +367,7 @@ static inline void ggml_gemv_f32_KxM_f32_generic(int n, float * GGML_RESTRICT s, for (int l = 0; l < nb; l++) { for (int j = 0; j < ncols_interleaved; j++) { for (int k = 0; k < interleave_size; k++) { - sumf[j] += b_ptr[l].d[j * interleave_size + k] * a_ptr[l + k]; + sumf[j] += b_ptr[l].d[j * interleave_size + k] * a_ptr[l * interleave_size + k]; } } } @@ -375,7 +379,7 @@ template static inline void ggml_gemm_f16_NxKxM_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int nb = n / interleave_size; - assert (nr % nrows == 0); + assert(nr % nrows == 0); assert(n % interleave_size == 0); assert(nc % ncols_interleaved == 0); @@ -395,7 +399,7 @@ static inline void ggml_gemm_f16_NxKxM_f16_generic(int n, float * GGML_RESTRICT for (int m = 0; m < nrows; m++) { for (int j = 0; j < ncols_interleaved; j++) { for (int k = 0; k < interleave_size; k++) { - sumf[m][j] += b_ptr[l].d[j * interleave_size + k] * a_ptr[l].d[m * interleave_size + k]; + sumf[m][j] += GGML_FP16_TO_FP32(b_ptr[l].d[j * interleave_size + k]) * GGML_FP16_TO_FP32(a_ptr[l].d[m * interleave_size + k]); } } } @@ -412,7 +416,7 @@ template static inline void ggml_gemm_f32_NxKxM_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int nb = n / interleave_size; - assert (nr % nrows == 0); + assert(nr % nrows == 0); assert(n % interleave_size == 0); assert(nc % ncols_interleaved == 0); @@ -1135,7 +1139,7 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void 
ggml_gemv_q8_0_4x4_q8_0_generic(int n, +#if defined __riscv_zvfhvoid ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, @@ -1182,53 +1186,23 @@ void ggml_gemv_q8_0_4x4_q8_0_generic(int n, } } -void ggml_gemv_q8_0_4x8_q8_0_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * blocklen + i]; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j]; - } - } +void ggml_gemv_f32_1x16_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_f32_KxM_f32_generic<1, 16>(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_f32_1x32_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_f32_KxM_f32_generic<1, 32>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f32_1x64_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_f32_KxM_f32_generic<1, 64>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f32_1x128_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_f32_KxM_f32_generic<1, 128>(n, s, bs, vx, vy, nr, nc); +} +#endif + void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1955,6 +1929,7 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +#if defined __riscv_zvfh void ggml_gemm_f16_7x1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_f16_NxKxM_f16_generic<7, 1, 16>(n, s, bs, vx, vy, nr, nc); } @@ -1986,6 +1961,7 @@ void ggml_gemm_f32_7x1x64_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_f32_7x1x128_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_f32_NxKxM_f32_generic<7, 1, 128>(n, s, bs, vx, vy, nr, nc); } +#endif } // extern "C" @@ -2670,14 +2646,14 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } -template -static int repack_f16_to_f16_N_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { 
+template +static int repack_f16_to_f16_MxK_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_F16); const ggml_half * src = (const ggml_half *)data; - block_f16 * dst = ( block_f16 *)t->data; + block_f16 * dst = ( block_f16 *)t->data; - ggml_half dst_tmp[nrows_interleaved * interleave_size]; + ggml_half dst_tmp[ncols_interleaved * interleave_size]; int nrow = ggml_nrows(t); int row_size = t->ne[0]; @@ -2685,19 +2661,19 @@ static int repack_f16_to_f16_N_bl(struct ggml_tensor * t, const void * GGML_REST GGML_ASSERT(data_size == nrow * nblocks * interleave_size * sizeof(ggml_half)); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % interleave_size != 0) { + if (t->ne[1] % ncols_interleaved != 0 || t->ne[0] % interleave_size != 0) { return -1; } - for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int b = 0; b < nrow; b += ncols_interleaved) { for (int i = 0; i < nblocks; i++) { - for (int j = 0; j < nrows_interleaved; j++) { + for (int j = 0; j < ncols_interleaved; j++) { for (int k = 0; k < interleave_size; k++) { dst_tmp[j * interleave_size + k] = src[(j + b) * row_size + i * interleave_size + k]; } } - block_f16 out; - memcpy(&out.d, dst_tmp, sizeof(ggml_half) * nrows_interleaved * interleave_size); + block_f16 out; + memcpy(&out.d, dst_tmp, sizeof(ggml_half) * ncols_interleaved * interleave_size); *dst = out; dst++; } @@ -2706,14 +2682,14 @@ static int repack_f16_to_f16_N_bl(struct ggml_tensor * t, const void * GGML_REST return 0; } -template -static int repack_f32_to_f32_N_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { +template +static int repack_f32_to_f32_MxK_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_F32); const float * src = (const float *)data; - block_f32 * dst = ( block_f32 *)t->data; + block_f32 * dst = ( block_f32 *)t->data; - float dst_tmp[nrows_interleaved * interleave_size]; + float dst_tmp[ncols_interleaved * interleave_size]; int nrow = ggml_nrows(t); int row_size = t->ne[0]; @@ -2721,19 +2697,19 @@ static int repack_f32_to_f32_N_bl(struct ggml_tensor * t, const void * GGML_REST GGML_ASSERT(data_size == nrow * nblocks * interleave_size * sizeof(float)); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % interleave_size != 0) { + if (t->ne[1] % ncols_interleaved != 0 || t->ne[0] % interleave_size != 0) { return -1; } - for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int b = 0; b < nrow; b += ncols_interleaved) { for (int i = 0; i < nblocks; i++) { - for (int j = 0; j < nrows_interleaved; j++) { + for (int j = 0; j < ncols_interleaved; j++) { for (int k = 0; k < interleave_size; k++) { dst_tmp[j * interleave_size + k] = src[(j + b) * row_size + i * interleave_size + k]; } } - block_f32 out; - memcpy(&out.d, dst_tmp, sizeof(float) * nrows_interleaved * interleave_size); + block_f32 out; + memcpy(&out.d, dst_tmp, sizeof(float) * ncols_interleaved * interleave_size); *dst = out; dst++; } @@ -2793,31 +2769,33 @@ template <> int repack(struct ggml_tensor * t, const void * return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } +#if defined __riscv_zvfh template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_f16_to_f16_N_bl<16, 1>(t, data, data_size); + return repack_f16_to_f16_MxK_bl<16, 1>(t, data, data_size); } template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_f16_to_f16_N_bl<32, 1>(t, 
data, data_size); + return repack_f16_to_f16_MxK_bl<32, 1>(t, data, data_size); } template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_f16_to_f16_N_bl<64, 1>(t, data, data_size); + return repack_f16_to_f16_MxK_bl<64, 1>(t, data, data_size); } template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_f16_to_f16_N_bl<128, 1>(t, data, data_size); + return repack_f16_to_f16_MxK_bl<128, 1>(t, data, data_size); } template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_f32_to_f32_N_bl<16, 1>(t, data, data_size); + return repack_f32_to_f32_MxK_bl<16, 1>(t, data, data_size); } template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_f32_to_f32_N_bl<32, 1>(t, data, data_size); + return repack_f32_to_f32_MxK_bl<32, 1>(t, data, data_size); } template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_f32_to_f32_N_bl<64, 1>(t, data, data_size); + return repack_f32_to_f32_MxK_bl<64, 1>(t, data, data_size); } template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_f32_to_f32_N_bl<128, 1>(t, data, data_size); + return repack_f32_to_f32_MxK_bl<128, 1>(t, data, data_size); } +#endif // gemv template @@ -2870,6 +2848,7 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_f16_1x16_f16(n, s, bs, vx, vy, nr, nc); } @@ -2901,6 +2880,7 @@ template <> void gemv(int n, float * s, size_t bs, template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_f32_1x128_f32(n, s, bs, vx, vy, nr, nc); } +#endif // gemm template @@ -2953,6 +2933,7 @@ template <> void gemm(int n, float * s, s ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_f16_7x1x16_f16(n, s, bs, vx, vy, nr, nc); } @@ -2984,6 +2965,7 @@ template <> void gemm(int n, float * s, size_t b template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_f32_7x1x128_f32(n, s, bs, vx, vy, nr, nc); } +#endif class tensor_traits_base : public ggml::cpu::tensor_traits { public: @@ -3072,7 +3054,7 @@ template wdata + params->wsize); - // If there are more than three rows in src1, use gemm; otherwise, use gemv. + // If there are more than `NB_ROWS` rows in src1, use gemm; otherwise, use gemv. 
if (nrows > (NB_ROWS - 1)) { gemm(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0, src0_ptr + src0_start * nb01, src1_ptr, diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index b2044ff7b5..1badee46a4 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -177,25 +177,22 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#ifdef __riscv_zvfh // FP16 -void ggml_repack_mat_f16_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_repack_mat_f16_7x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_f16_1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_f16_1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_f16_1x64_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_f16_1x128_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_f16_4x1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_f16_7x1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_f16_7x1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_f16_7x1x64_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_f16_7x1x128_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_repack_mat_f16_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_repack_mat_f16_7x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_f16_1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_f16_1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_f16_1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_f16_1x128_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_f16_4x1x32_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_f16_7x1x16_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_f16_7x1x32_f16(int n, float * GGML_RESTRICT s, 
size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_f16_7x1x64_f16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -220,6 +217,7 @@ void ggml_gemm_f32_7x1x16_f32(int n, float * GGML_RESTRICT s, size_t bs, const v void ggml_gemm_f32_7x1x32_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_f32_7x1x64_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_f32_7x1x128_f32(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#endif #if defined(__cplusplus) } // extern "C" From 28e07aad92e646adc4a9f3f77f99d129185d07f8 Mon Sep 17 00:00:00 2001 From: taimur-10x Date: Tue, 23 Dec 2025 15:20:05 +0500 Subject: [PATCH 3/3] ggml-cpu: refactor repack, format --- ggml/src/ggml-cpu/arch-fallback.h | 148 ++--------- ggml/src/ggml-cpu/repack.cpp | 398 ++++++++++++++++++++---------- 2 files changed, 288 insertions(+), 258 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index c61fe80bdb..0d0342f6eb 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -36,8 +36,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 -#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -48,14 +46,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 -#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 -#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 -#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 -#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 -#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 -#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 -#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -66,72 +58,36 @@ # define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 -#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 -#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 -#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 -#define ggml_gemm_f32_7x1x16_f32_generic 
ggml_gemm_f32_7x1x16_f32 -#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 -#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 -#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K -#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 -#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 -#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 -#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 -#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 -#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 -#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 -#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K -#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 -#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 -#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 -#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 -#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 -#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 -#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 -#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 -#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 -#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 -#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 -#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 -#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 -#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 -#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 -#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 -#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 -#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic 
ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 -#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 -#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 -#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 -#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 -#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 -#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 -#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 -#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679 // quants.c @@ -144,8 +100,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 -#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -156,14 +110,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 -#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 -#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 -#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 -#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 -#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 -#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 -#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -174,14 +122,8 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 -#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 -#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 -#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 -#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 -#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 -#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 -#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__loongarch64) // quants.c #define quantize_row_q8_K_generic 
quantize_row_q8_K @@ -194,8 +136,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 -#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -206,14 +146,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 -#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 -#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 -#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 -#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 -#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 -#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 -#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -224,14 +158,8 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 -#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 -#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 -#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 -#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 -#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 -#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 -#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__riscv) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -294,8 +222,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 -#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -306,14 +232,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 -#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 -#define ggml_gemv_f16_1x64_f16_generic 
ggml_gemv_f16_1x64_f16 -#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 -#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 -#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 -#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 -#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -324,14 +244,8 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 -#define ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 -#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 -#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 -#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 -#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 -#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 -#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__wasm__) // quants.c #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 @@ -352,8 +266,6 @@ #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 -#define ggml_repack_mat_f16_7x1_generic ggml_repack_mat_f16_7x1 -#define ggml_repack_mat_f32_7x1_generic ggml_repack_mat_f32_7x1 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 @@ -364,14 +276,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 -#define ggml_gemv_f16_1x16_f16_generic ggml_gemv_f16_1x16_f16 -#define ggml_gemv_f16_1x32_f16_generic ggml_gemv_f16_1x32_f16 -#define ggml_gemv_f16_1x64_f16_generic ggml_gemv_f16_1x64_f16 -#define ggml_gemv_f16_1x128_f16_generic ggml_gemv_f16_1x128_f16 -#define ggml_gemv_f32_1x16_f32_generic ggml_gemv_f32_1x16_f32 -#define ggml_gemv_f32_1x32_f32_generic ggml_gemv_f32_1x32_f32 -#define ggml_gemv_f32_1x64_f32_generic ggml_gemv_f32_1x64_f32 -#define ggml_gemv_f32_1x128_f32_generic ggml_gemv_f32_1x128_f32 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 @@ -382,12 +288,6 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 -#define ggml_gemm_f16_7x1x16_f16_generic ggml_gemm_f16_7x1x16_f16 -#define 
ggml_gemm_f16_7x1x32_f16_generic ggml_gemm_f16_7x1x32_f16 -#define ggml_gemm_f16_7x1x64_f16_generic ggml_gemm_f16_7x1x64_f16 -#define ggml_gemm_f16_7x1x128_f16_generic ggml_gemm_f16_7x1x128_f16 -#define ggml_gemm_f32_7x1x16_f32_generic ggml_gemm_f32_7x1x16_f32 -#define ggml_gemm_f32_7x1x32_f32_generic ggml_gemm_f32_7x1x32_f32 -#define ggml_gemm_f32_7x1x64_f32_generic ggml_gemm_f32_7x1x64_f32 -#define ggml_gemm_f32_7x1x128_f32_generic ggml_gemm_f32_7x1x128_f32 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #endif diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index a40f2ce801..8c21150ff0 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -32,7 +32,7 @@ static inline int nearest_int(float fval) { } // Helper functions for `fp16` and `fp32`. - +// template static inline void ggml_repack_mat_f16_NxK_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(k % interleave_size == 0); @@ -80,7 +80,6 @@ static inline void ggml_repack_mat_f32_NxK_generic(const float * GGML_RESTRICT x // operations durin unpacking) // - extern "C" { void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { @@ -987,143 +986,29 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemv_q5_K_8x8_q8_K_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[8]; - float sum_minf[8]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; - - const block_q8_K * a_ptr = (const block_q8_K *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - sum_minf[j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; - } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; - uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; - - const int qh_shift = (k / 4) * 2; - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; - - const int qh_idx = (k * 8 + i) % 32; - const int qh_chunk = qh_idx / 8; - const int qh_pos = qh_idx % 8; - const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; - - const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; - const uint8_t h0 = (qh_val >> qh_shift) & 1; - const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; - - const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); - const int 
v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); - - const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i; - - sumi1 = (v0 * a_ptr[l].qs[q8_offset]); - sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * - GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; - } - } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; - } - } -} - - -void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - constexpr int qk = QK_K; +void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; + const int ncols_interleaved = 4; + const int blocklen = 4; + assert(nr == 1); assert(n % qk == 0); assert(nc % ncols_interleaved == 0); UNUSED(bs); UNUSED(nr); - float sumf[8]; + float sumf[4]; + int sumi; - const block_q8_K * a_ptr = (const block_q8_K *) vy; + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0f; - } + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; for (int l = 0; l < nb; l++) { - - - for (int k = 0; k < 16; k++) { - // k = 0.. 
7 weights 0-63 low, 64-127 high - // k = 8..15 weights 128-191 low, 192-255 high - const int base_l = (k / 8) * 128 + (k % 8) * 8; - const int base_h = base_l + 64; - - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; - - // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; - - // qh_half: offset to the correct 32-byte half (0 or 32) - const int qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; - + for (int k = 0; k < (qk / (2 * blocklen)); k++) { for (int j = 0; j < ncols_interleaved; j++) { sumi = 0; for (int i = 0; i < blocklen; ++i) { @@ -1139,7 +1024,45 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -#if defined __riscv_zvfhvoid ggml_gemv_q8_0_4x4_q8_0_generic(int n, +void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, @@ -1186,6 +1109,70 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q8_0_4x8_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * 
GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +#if defined __riscv_zvfh +void ggml_gemv_f16_1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_f16_KxM_f16_generic<1, 16>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f16_1x32_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_f16_KxM_f16_generic<1, 32>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f16_1x64_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_f16_KxM_f16_generic<1, 64>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_f16_1x128_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_f16_KxM_f16_generic<1, 128>(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_f32_1x16_f32_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemv_f32_KxM_f32_generic<1, 16>(n, s, bs, vx, vy, nr, nc); } @@ -1929,6 +1916,110 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemm_q8_0_4x4_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; + } + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +void ggml_gemm_q8_0_4x8_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 
0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; + } + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + #if defined __riscv_zvfh void ggml_gemm_f16_7x1x16_f16_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { ggml_gemm_f16_NxKxM_f16_generic<7, 1, 16>(n, s, bs, vx, vy, nr, nc); @@ -2769,6 +2860,14 @@ template <> int repack(struct ggml_tensor * t, const void * return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); +} + #if defined __riscv_zvfh template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_f16_to_f16_MxK_bl<16, 1>(t, data, data_size); @@ -2848,6 +2947,14 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + #if defined __riscv_zvfh template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_f16_1x16_f16(n, s, bs, vx, vy, nr, nc); @@ -2895,7 +3002,7 @@ template <> void gemm(int n, float * s, siz } template <> -void gemm(int n, +void gemm(int n, float * s, size_t bs, const void * vx, @@ -2905,11 +3012,11 @@ void gemm(int n, ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2921,7 +3028,7 @@ template <> void gemm(int n, float * s, siz ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2933,6 +3040,14 @@ template <> void gemm(int n, float * s, s ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + 
ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + #if defined __riscv_zvfh template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_f16_7x1x16_f16(n, s, bs, vx, vy, nr, nc); @@ -3353,10 +3468,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; // instance for Q5_K - static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; + static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; // instance for Q6_K - static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K; + static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K; // instance for Q2 static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K; @@ -3365,6 +3480,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; + // instance for Q8_0 + static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; + static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; + // instance for F16 #if defined __riscv_zvfh static const ggml::cpu::repack::tensor_traits f16_7x16x1_f16; @@ -3443,6 +3562,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &iq4_nl_4x4_q8_0; } } + } else if (cur->type == GGML_TYPE_Q8_0) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 4 == 0) { + return &q8_0_4x8_q8_0; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 4 == 0) { + return &q8_0_4x4_q8_0; + } + } } else if (cur->type == GGML_TYPE_F16) { if (ggml_cpu_has_riscv_v()) { #if defined __riscv_zvfh
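Note: the repacked floating-point path above ultimately reduces to a blocked dot product over a column-interleaved weight buffer. The standalone C++ sketch below illustrates that idea for the 1xNB case; the names, the NB = 16 interleave width and the exact packing order are illustrative assumptions for this note, not the actual ggml-cpu kernels or layout from this patch.

// Illustrative only: a 1xNB "repack + GEMV" pair for float32.
// Assumptions (not taken from the patch): the weight matrix W is nc x k,
// row-major, and repacking interleaves NB consecutive output rows so that
// for each input element e the NB partial products are stored contiguously.
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int64_t NB = 16;  // assumed interleave width (cf. the 1x16 variant)

// Repack: Wp[(g*k + e)*NB + j] = W[(g*NB + j)*k + e] for each row group g.
static void repack_f32_1xNB(const float * W, float * Wp, int64_t k, int64_t nc) {
    for (int64_t g = 0; g < nc / NB; ++g)
        for (int64_t e = 0; e < k; ++e)
            for (int64_t j = 0; j < NB; ++j)
                Wp[(g * k + e) * NB + j] = W[(g * NB + j) * k + e];
}

// GEMV: s[r] = dot(x, row r of W), reading the repacked buffer sequentially.
static void gemv_f32_1xNB(int64_t k, float * s, const float * Wp, const float * x, int64_t nc) {
    for (int64_t g = 0; g < nc / NB; ++g) {
        float acc[NB] = { 0.0f };
        const float * wp = Wp + g * k * NB;
        for (int64_t e = 0; e < k; ++e)
            for (int64_t j = 0; j < NB; ++j)   // contiguous NB-wide update
                acc[j] += x[e] * wp[e * NB + j];
        for (int64_t j = 0; j < NB; ++j)
            s[g * NB + j] = acc[j];
    }
}

int main() {
    const int64_t k = 64, nc = 32;
    std::vector<float> W(nc * k), Wp(nc * k), x(k), s(nc);
    for (int64_t i = 0; i < nc * k; ++i) W[i] = float(i % 7) - 3.0f;
    for (int64_t e = 0; e < k; ++e) x[e] = 0.25f * float(e);
    repack_f32_1xNB(W.data(), Wp.data(), k, nc);
    gemv_f32_1xNB(k, s.data(), Wp.data(), x.data(), nc);
    std::printf("s[0] = %g, s[%d] = %g\n", s[0], int(nc - 1), s[nc - 1]);
    return 0;
}

Because each step of the k-loop touches one contiguous NB-wide strip of the repacked buffer, a vector implementation (for example the RVV path guarded by __riscv_zvfh above) can replace the inner j-loop with a small number of vector fused multiply-adds.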