From 8a5e84cb5bceaddbdb660353c44560e5118a476a Mon Sep 17 00:00:00 2001
From: Alberto Cabrera
Date: Wed, 4 Feb 2026 19:12:10 +0000
Subject: [PATCH] gemm finished

---
 ggml/src/ggml-cpu/arch/arm/repack.cpp | 32 ++++++++++++++++++++++----------
 1 file changed, 22 insertions(+), 10 deletions(-)

diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp
index 48f6f3e4bb..4801abb13c 100644
--- a/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -806,7 +806,7 @@ void ggml_gemv_q5_K_8x4_q8_K(int n,
     UNUSED(blocklen);
 
 #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
-    constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
+    constexpr int col_groups = ncols_interleaved / 4;  // 0123 and 4567
     const uint8x16_t m4b  = vdupq_n_u8(0x0f);
     const uint8x16_t mone = vdupq_n_u8(1);
     const uint8x16_t mtwo = vdupq_n_u8(2);
@@ -844,6 +844,7 @@ void ggml_gemv_q5_K_8x4_q8_K(int n,
         int16_t bsums_arr[8];
         vst1q_s16(bsums_arr, bsums);
 
+        // Preload to maximize qh reuse
         uint8x16_t qh[col_groups][8];
         for (int c = 0; c < col_groups; c++) {
             for (int i = 0; i < 8; i++) {
@@ -878,6 +879,8 @@ void ggml_gemv_q5_K_8x4_q8_K(int n,
            uint8x16_t hbit_hi[8];
            int8x16_t  q5_lo[8];
            int8x16_t  q5_hi[8];
+           // Already tried unrolling this loop, no perf difference
+           // Compiler seems to be able to unroll and schedule well enough
            for (int i = 0; i < 8; i++) {
                q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
                hbit_lo[i] = vandq_u8(qh[c][i], mone);
@@ -3175,7 +3178,13 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
-void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+void ggml_gemm_q5_K_8x4_q8_K(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
     constexpr int qk = QK_K;
     const int nb = n / qk;
 
@@ -3192,13 +3201,13 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 
 #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     constexpr int q8_k_blocklen = 4;
-    constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
+    constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs
     constexpr int col_groups = ncols_interleaved / 4;
     const uint8x16_t m4b  = vdupq_n_u8(0x0f);
     const uint8x16_t mone = vdupq_n_u8(1);
     const uint8x16_t mtwo = vdupq_n_u8(2);
 
-    // 8 accumulators: 2 row pairs × 4 col pairs
+    // 8 accumulators: 2 row pairs, 4 col pairs
     float32x4_t acc_f32[acc_size];
 
     for (int y = 0; y < nr / q8_k_blocklen; y++) {
@@ -3212,7 +3221,7 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
         }
 
         for (int b = 0; b < nb; b++) {
-            // d4 0 1 2 3, 4 5 6 7
+            // d5 0 1 2 3, 4 5 6 7
             float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
             float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
             // d8 0 1 2 3
@@ -3273,7 +3282,7 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             }
 
             for (int sb = 0; sb < QK_K / 64; sb++) {
-                // Int accumulators for qs vecdot (4 row x 2 col quartets)
+                // Int accumulators for qs vecdot (4 row * 2 col quartets)
                 int32x4_t acc_lo[acc_size];
                 int32x4_t acc_hi[acc_size];
                 for (int i = 0; i < acc_size; i++) {
@@ -3300,21 +3309,19 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
                     const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
 
+                    // NOTE: This is the only difference from q4_K
                     const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
                     const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
                     qh[0][k] = vshrq_n_u8(qh[0][k], 2);
 
                     const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
                     const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
                     qh[1][k] = vshrq_n_u8(qh[1][k], 2);
 
+                    // From here on, same as q4_K
                     const int8x16_t q5_0123_lo =
                         vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
                     const int8x16_t q5_0123_hi = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
-                    const int8x16_t q5_4567_lo =
-                        vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
-                    const int8x16_t q5_4567_hi =
-                        vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
 
                     acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
                     acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
@@ -3326,6 +3333,11 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
                     acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
 
+                    const int8x16_t q5_4567_lo =
+                        vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
+                    const int8x16_t q5_4567_hi =
+                        vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
+
                     acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
                     acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
                     acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
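
Reviewer note, not part of the patch: the qh handling flagged by the NOTE above is the entire q5_K-specific delta versus the q4_K kernel. As a sanity reference, the scalar model below (a hypothetical helper with illustrative names, not code from repack.cpp) computes the same per-byte reconstruction as the vsliq_n_u8 / vorrq_u8 / vshrq_n_u8 sequence in the hunks above.

#include <cstdint>

// Scalar model of the per-byte q5_K reconstruction done in the NEON path:
// lo mirrors vsliq_n_u8(qs & m4b, qh & mone, 4), hi mirrors
// vorrq_u8(qs >> 4, (qh & mtwo) << 3), and qh advances like vshrq_n_u8(qh, 2).
static inline void q5_reconstruct_byte(uint8_t qs, uint8_t & qh, int8_t & lo, int8_t & hi) {
    const uint8_t hbit_lo = qh & 1;         // 5th bit of the low-nibble value
    const uint8_t hbit_hi = (qh & 2) << 3;  // 5th bit of the high-nibble value, pre-shifted to bit 4
    lo = (int8_t) ((qs & 0x0f) | (hbit_lo << 4));  // low nibble + high bit -> 0..31
    hi = (int8_t) ((qs >> 4) | hbit_hi);           // high nibble + high bit -> 0..31
    qh >>= 2;                               // expose the next sub-block's pair of high bits
}

Since both outputs stay in [0, 31], the u8-to-s8 reinterpret feeding vdotq_laneq_s32 cannot change any value.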