ggml-cpu: aarm64: q5_K repack gemm and gemv (and generic) implementations (i8mm) (#18860)

* Boilerplate for q5_Kx8 REPACK on ARM and fallback

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Implements make_block_q5_Kx8 by extending make_block_q4_Kx8

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* q5_K repack gemm and gemv generics

* Gemm and Gemv ARM implementations (i8mm)

* Improved qh manipulation looking at non-repack vec_dot implementation

* Full unroll

* Apply Q5_K Gemv vand and vshl optimizations to gemm. Improve comments.

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Fix wrong fallback definitions of Q5_K

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Fixed comments. Reverted unnecessary formatting

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>

* Fixed typo in generic definitions

* Switching AND + Shift with Shift Insert. Better op interleaving.

* Vectorize + unroll the block scales

* Apply gemm optimizations to gemv

* Improve bias calculation

---------

Signed-off-by: Alberto Cabrera <alberto.cabrera@liquid.ai>
This commit is contained in:
Alberto Cabrera Pérez 2026-01-23 07:55:08 +00:00 committed by GitHub
parent a3e812811d
commit 091a46cb8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 931 additions and 38 deletions

View File

@ -38,9 +38,10 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -48,9 +49,10 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -70,12 +72,14 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
@ -94,9 +98,10 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -104,9 +109,10 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -126,9 +132,10 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -136,9 +143,10 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -165,18 +173,20 @@
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -202,9 +212,10 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -212,9 +223,10 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -242,9 +254,10 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -252,9 +265,10 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0

View File

@ -25,9 +25,8 @@
#define UNUSED GGML_UNUSED
#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
int16x8_t * out_mins,
int8_t * out_scales) {
// Helper for decoding scales and mins of Q4_K and Q5_K block formats
static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) {
constexpr uint32_t kmask1 = 0x3f3f3f3f;
constexpr uint32_t kmask2 = 0x0f0f0f0f;
constexpr uint32_t kmask3 = 0x03030303;
@ -561,7 +560,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
@ -701,7 +700,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
@ -786,6 +785,293 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q5_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_pairs = ncols_interleaved / 2;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[ncols_interleaved / 4];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int i = 0; i < ncols_interleaved / 4; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d);
float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d);
float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3
float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d);
float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d);
// 2 sb each iteration
int32x4_t acc_lo[col_pairs];
int32x4_t acc_hi[col_pairs];
// Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
int16_t bsums_arr[8];
vst1q_s16(bsums_arr, bsums);
// Load qh once per block and shift after each subblock
const uint8_t * qh_base = q5_ptr[b].qh;
uint8x16_t qh[col_pairs][4];
for (int cp = 0; cp < col_pairs; cp++) {
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
}
for (int sb = 0; sb < QK_K / 64; sb++) {
for (int i = 0; i < col_pairs; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
int16x8_t q5sb_scales[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q5sb[8];
const int offset = sb * 24 + i * 12;
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
}
const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K;
// Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns
const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
int8x16_t q8_qs[8];
for (int i = 0; i < 8; i++) {
q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
}
// Q5s column pair loop unrolled
{
// Cols 01
uint8x16_t qs_0 = vld1q_u8(qs_base);
uint8x16_t qs_1 = vld1q_u8(qs_base + 64);
uint8x16_t qs_2 = vld1q_u8(qs_base + 128);
uint8x16_t qs_3 = vld1q_u8(qs_base + 192);
uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone);
uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone);
uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone);
uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone);
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3);
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3);
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3);
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3);
qh[0][0] = vshrq_n_u8(qh[0][0], 2);
qh[0][1] = vshrq_n_u8(qh[0][1], 2);
qh[0][2] = vshrq_n_u8(qh[0][2], 2);
qh[0][3] = vshrq_n_u8(qh[0][3], 2);
acc_lo[0] = ggml_vdotq_s32(
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
acc_lo[0] = ggml_vdotq_s32(
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
acc_lo[0] = ggml_vdotq_s32(
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
acc_lo[0] = ggml_vdotq_s32(
acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
q8_qs[4]);
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
q8_qs[5]);
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
q8_qs[6]);
acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
q8_qs[7]);
// Cols 23
qs_0 = vld1q_u8(qs_base + 16);
qs_1 = vld1q_u8(qs_base + 80);
qs_2 = vld1q_u8(qs_base + 144);
qs_3 = vld1q_u8(qs_base + 208);
hbit_lo_0 = vandq_u8(qh[1][0], mone);
hbit_lo_1 = vandq_u8(qh[1][1], mone);
hbit_lo_2 = vandq_u8(qh[1][2], mone);
hbit_lo_3 = vandq_u8(qh[1][3], mone);
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3);
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3);
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3);
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3);
qh[1][0] = vshrq_n_u8(qh[1][0], 2);
qh[1][1] = vshrq_n_u8(qh[1][1], 2);
qh[1][2] = vshrq_n_u8(qh[1][2], 2);
qh[1][3] = vshrq_n_u8(qh[1][3], 2);
acc_lo[1] = ggml_vdotq_s32(
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
acc_lo[1] = ggml_vdotq_s32(
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
acc_lo[1] = ggml_vdotq_s32(
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
acc_lo[1] = ggml_vdotq_s32(
acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
q8_qs[4]);
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
q8_qs[5]);
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
q8_qs[6]);
acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
q8_qs[7]);
// Cols 45
qs_0 = vld1q_u8(qs_base + 32);
qs_1 = vld1q_u8(qs_base + 96);
qs_2 = vld1q_u8(qs_base + 160);
qs_3 = vld1q_u8(qs_base + 224);
hbit_lo_0 = vandq_u8(qh[2][0], mone);
hbit_lo_1 = vandq_u8(qh[2][1], mone);
hbit_lo_2 = vandq_u8(qh[2][2], mone);
hbit_lo_3 = vandq_u8(qh[2][3], mone);
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3);
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3);
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3);
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3);
qh[2][0] = vshrq_n_u8(qh[2][0], 2);
qh[2][1] = vshrq_n_u8(qh[2][1], 2);
qh[2][2] = vshrq_n_u8(qh[2][2], 2);
qh[2][3] = vshrq_n_u8(qh[2][3], 2);
acc_lo[2] = ggml_vdotq_s32(
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
acc_lo[2] = ggml_vdotq_s32(
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
acc_lo[2] = ggml_vdotq_s32(
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
acc_lo[2] = ggml_vdotq_s32(
acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
q8_qs[4]);
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
q8_qs[5]);
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
q8_qs[6]);
acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
q8_qs[7]);
// Cols 45
qs_0 = vld1q_u8(qs_base + 48);
qs_1 = vld1q_u8(qs_base + 112);
qs_2 = vld1q_u8(qs_base + 176);
qs_3 = vld1q_u8(qs_base + 240);
hbit_lo_0 = vandq_u8(qh[3][0], mone);
hbit_lo_1 = vandq_u8(qh[3][1], mone);
hbit_lo_2 = vandq_u8(qh[3][2], mone);
hbit_lo_3 = vandq_u8(qh[3][3], mone);
hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3);
hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3);
hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3);
hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3);
qh[3][0] = vshrq_n_u8(qh[3][0], 2);
qh[3][1] = vshrq_n_u8(qh[3][1], 2);
qh[3][2] = vshrq_n_u8(qh[3][2], 2);
qh[3][3] = vshrq_n_u8(qh[3][3], 2);
acc_lo[3] = ggml_vdotq_s32(
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]);
acc_lo[3] = ggml_vdotq_s32(
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]);
acc_lo[3] = ggml_vdotq_s32(
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]);
acc_lo[3] = ggml_vdotq_s32(
acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]);
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)),
q8_qs[4]);
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)),
q8_qs[5]);
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)),
q8_qs[6]);
acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)),
q8_qs[7]);
}
// Prepare bsum vectors for bias computation
// Each pair of subblocks share the same bsums
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
// Iterates over a pair of column pairs (4 columns) to use a single 128 register
// p = 0 -> 0123 p2 -> 4567
for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]);
int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]);
int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]);
int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]);
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1;
// 0123 or 4567
float32x4_t sumf_0 =
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
float32x4_t sumf_1 =
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
// FUSED BIAS: Compute and subtract bias immediately
// bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min
int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo);
bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi);
float32x4_t bias_f32 = vcvtq_f32_s32(bias);
acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32);
}
} // for sb
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q8_0_4x4_q8_0(int n,
float * GGML_RESTRICT s,
size_t bs,
@ -2431,7 +2717,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
@ -2595,7 +2881,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later
for (int i = 0; i < 2; i++) {
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
}
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
@ -2738,6 +3024,252 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q5_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
constexpr int q8_k_blocklen = 4;
constexpr int col_pairs = ncols_interleaved / 2;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2);
// 8 accumulators: 2 row pairs × 4 col pairs
float32x4_t acc_f32[blocklen];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int i = 0; i < blocklen; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
// bsums pairs belongs to the same q8_k subblock
const int16x8_t bsums[4]{
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int16_t bsums_arr[4][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
}
int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results
int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7]
int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ...
for (int i = 0; i < 8; i++) {
acc[i] = vdupq_n_s32(0);
bias_acc[i] = vdupq_n_s32(0);
}
// Load qh once per block and shift after each subblock
const uint8_t * qh_base = q5_ptr[b].qh;
uint8x16_t qh[col_pairs][4];
for (int cp = 0; cp < col_pairs; cp++) {
qh[cp][0] = vld1q_u8(qh_base + 16 * cp);
qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64);
qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128);
qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192);
}
for (int sb = 0; sb < QK_K / 64; sb++) {
// Need scales for the low and high nibbles
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
int8_t q5sb_scales[2][8];
int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later
for (int i = 0; i < 2; i++) {
const int offset = sb * 24 + i * 12;
decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]);
}
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
int8x16_t q8_qs_01[8];
int8x16_t q8_qs_23[8];
// Load 32-byte per row pair, 1 subblock each time
for (int i = 0; i < 8; i++) {
const int offset = i * 32; // 16 for row 01, 16 for row 23
q8_qs_01[i] = vld1q_s8(q8_base + offset);
q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
}
const int8x16_t q8s[2][8] = {
{ q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6],
q8_qs_01[7] },
{ q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6],
q8_qs_23[7] },
};
// Q5s columns iterated in pairs (01, 23, 45, 67)
for (int cp = 0; cp < col_pairs; cp++) {
for (int i = 0; i < 4; i++) {
sb_acc[i] = vdupq_n_s32(0);
}
uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39
uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47
uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55
uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63
// This is the only part of the algorithm that differs with Q4_K
// Extract High bits and pack into 5 bit weights
uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone);
uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3);
qh[cp][0] = vshrq_n_u8(qh[cp][0], 2);
// Same as Q4_K, i8mm to dequantize the weights.
const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4));
int32x4_t acc_0 = sb_acc[0];
acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]);
int32x4_t acc_2 = sb_acc[2];
acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]);
const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0));
int32x4_t acc_1 = sb_acc[1];
acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]);
int32x4_t acc_3 = sb_acc[3];
acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]);
// Repeat for the other 3 columns (8..15, 16..23, 24..31)
uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3);
uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone);
qh[cp][1] = vshrq_n_u8(qh[cp][1], 2);
const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4));
acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]);
acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]);
const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1));
acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]);
acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]);
uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3);
uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone);
qh[cp][2] = vshrq_n_u8(qh[cp][2], 2);
const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4));
acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]);
acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]);
const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2));
acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]);
acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]);
uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone);
uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3);
qh[cp][3] = vshrq_n_u8(qh[cp][3], 2);
const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4));
acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]);
sb_acc[0] = acc_0;
acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]);
sb_acc[2] = acc_2;
// Scales[i] corresponds to column i
const int scale_offset = cp * 2;
const int32_t s0 = q5sb_scales[0][scale_offset];
const int32_t s1 = q5sb_scales[0][scale_offset + 1];
const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1));
acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale);
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale);
const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3));
acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]);
sb_acc[1] = acc_1;
acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]);
sb_acc[3] = acc_3;
const int32_t s2 = q5sb_scales[1][scale_offset];
const int32_t s3 = q5sb_scales[1][scale_offset + 1];
const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3));
acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2);
acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2);
}
// Multiply Acc bsum + mins
for (int q8_row = 0; q8_row < 4; q8_row++) {
// Each pair of subblocks share the same bsums
// Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)).
int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
bias_acc[2 * q8_row] =
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
bias_acc[2 * q8_row] =
vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
bias_acc[2 * q8_row + 1] =
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
bias_acc[2 * q8_row + 1] =
vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
}
} // for sb
// Reorder of i8mm output with bias and output layout
for (int i = 0; i < 8; i++) {
int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
}
int32x4_t reorder_acc[8] = {
vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
};
for (int i = 0; i < q8_k_blocklen; i++) {
for (int j = 0; j < 2; j++) {
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4)));
const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d);
float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4)));
const float32x4_t scale = vmulq_f32(q5_d, q8_d);
acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
acc_f32[2 * i + j] =
vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
}
}
} // for b
// With the previous reorder, the tile is already in the correct memory layout.
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q8_0_4x4_q8_0(int n,
float * GGML_RESTRICT s,

View File

@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
assert (n % qk == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(s);
UNUSED(bs);
UNUSED(vx);
UNUSED(vy);
UNUSED(nr);
UNUSED(nc);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
float sumf[8];
float sum_minf[8];
@ -616,6 +609,100 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
float sum_minf[8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0;
sum_minf[j] = 0.0;
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
const int qh_shift = (k / 4) * 2;
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
const int qh_idx = (k * 8 + i) % 32;
const int qh_chunk = qh_idx / 8;
const int qh_pos = qh_idx % 8;
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
const uint8_t h0 = (qh_val >> qh_shift) & 1;
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i;
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) *
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
}
}
}
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@ -1212,6 +1299,108 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
constexpr uint32_t kmask1 = 0x3f3f3f3f;
constexpr uint32_t kmask2 = 0x0f0f0f0f;
constexpr uint32_t kmask3 = 0x03030303;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
float sumf[4][8];
float sum_minf[4][8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
sum_minf[m][j] = 0.0;
}
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16;
const int qh_shift = (k / 4) * 2;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i;
const int qh_idx = (k * 8 + i) % 32;
const int qh_chunk = qh_idx / 8;
const int qh_pos = qh_idx % 8;
const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos;
const uint8_t qh_val = b_ptr[l].qh[b_qh_offset];
const uint8_t h0 = (qh_val >> qh_shift) & 1;
const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1;
const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4));
const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4));
const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i;
sumi1 = (v0 * a_ptr[l].qs[q8_offset]);
sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
}
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int m = 0; m < 4; m++) {
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) *
GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
}
}
}
}
}
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
@ -1622,7 +1811,95 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in
out.scales[i] = in[src1].scales[src2];
}
return out;
}
static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) {
block_q5_Kx8 out;
//Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure
for (int i = 0; i < 8; i++) {
out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
}
for (int i = 0; i < 8; i++) {
out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
}
const int end = QK_K * 4 / blck_size_interleave;
// Interleave Q5_K quants by taking 8 bytes at a time
for (int i = 0; i < end; ++i) {
int src_id = i % 8;
int src_offset = (i / 8) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
uint64_t elems;
memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
}
// Repeat for low bits 8 bytes at a time as well, since
// the high bits are interleaved in Q5_K and the index is
// qh_idx = (qs_idx % 32);
// qh_val = qh[qh_idx] >> (qs_idx / 32);
for (int i = 0; i < end / 4; ++i) {
int src_id = i % 8;
int src_offset = (i / 8) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
uint64_t elems;
memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t));
memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t));
}
// The below logic is copied over from Q4_K
// The point is to unpack all the scales and mins for each sub block every time we load 12 bytes.
// Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)
// The output Q5_Kx8 structure has 96 bytes
// Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure
// For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures
uint8_t s[8], m[8];
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 8; j++) {
s[j] = in[j].scales[i] & 63;
m[j] = in[j].scales[i + 4] & 63;
}
out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
}
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 8; j++) {
s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15);
m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4);
}
out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
}
return out;
}
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
@ -1718,6 +1995,38 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block
GGML_UNUSED(data_size);
}
static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
int interleave_block,
const void * GGML_RESTRICT data,
size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q5_K);
GGML_ASSERT(interleave_block == 8);
constexpr int nrows_interleaved = 8;
block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data;
const block_q5_K * src = (const block_q5_K *) data;
block_q5_K dst_tmp[8];
int nrow = ggml_nrows(t);
int nblocks = t->ne[0] / QK_K;
GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K));
if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
return -1;
}
for (int b = 0; b < nrow; b += nrows_interleaved) {
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++) {
dst_tmp[i] = src[x + i * nblocks];
}
*dst++ = make_block_q5_Kx8(dst_tmp, interleave_block);
}
src += nrows_interleaved * nblocks;
}
return 0;
}
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(interleave_block == 8);
@ -1936,6 +2245,10 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
}
@ -1973,6 +2286,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -1981,8 +2298,8 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
@ -2013,20 +2330,24 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
@ -2432,6 +2753,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for Q5_K
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
// instance for Q2
static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
@ -2482,6 +2806,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &q2_K_8x8_q8_K;
}
}
} else if (cur->type == GGML_TYPE_Q5_K) {
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (cur->ne[1] % 8 == 0) {
return &q5_K_8x8_q8_K;
}
}
} else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_avx2()) {
if (cur->ne[1] % 8 == 0) {

View File

@ -44,6 +44,7 @@ struct block_q4_Kx8 {
};
static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding");
struct block_q2_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
ggml_half dmin[8]; // super-block scale for quantized mins
@ -52,6 +53,18 @@ struct block_q2_Kx8 {
};
static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding");
struct block_q5_Kx8 {
ggml_half d[8]; // super-block scale for quantized scales
ggml_half dmin[8]; // super-block scale for quantized mins
uint8_t scales[96]; // scales and mins, quantized with 6 bits
uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants
uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4)
};
static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5,
"wrong q5_K block size/padding");
struct block_q8_Kx4 {
float d[4]; // delta
int8_t qs[QK_K * 4]; // quants
@ -82,20 +95,22 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -111,17 +126,19 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);