gemm finished

This commit is contained in:
Alberto Cabrera 2026-02-04 19:12:10 +00:00
parent 28fb08937a
commit 8a5e84cb5b
1 changed files with 22 additions and 10 deletions

View File

@ -806,7 +806,7 @@ void ggml_gemv_q5_K_8x4_q8_K(int n,
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2);
@ -844,6 +844,7 @@ void ggml_gemv_q5_K_8x4_q8_K(int n,
int16_t bsums_arr[8];
vst1q_s16(bsums_arr, bsums);
// Preload to maximize qh reuse
uint8x16_t qh[col_groups][8];
for (int c = 0; c < col_groups; c++) {
for (int i = 0; i < 8; i++) {
@ -878,6 +879,8 @@ void ggml_gemv_q5_K_8x4_q8_K(int n,
uint8x16_t hbit_hi[8];
int8x16_t q5_lo[8];
int8x16_t q5_hi[8];
// Already tried unrolling this loop, no perf difference
// Compiler seems to be able to unroll and schedule well enough
for (int i = 0; i < 8; i++) {
q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
hbit_lo[i] = vandq_u8(qh[c][i], mone);
@ -3175,7 +3178,13 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
void ggml_gemm_q5_K_8x4_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
@ -3192,13 +3201,13 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int q8_k_blocklen = 4;
constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs
constexpr int col_groups = ncols_interleaved / 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mone = vdupq_n_u8(1);
const uint8x16_t mtwo = vdupq_n_u8(2);
// 8 accumulators: 2 row pairs × 4 col pairs
// 8 accumulators: 2 row pairs, 4 col pairs
float32x4_t acc_f32[acc_size];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
@ -3212,7 +3221,7 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
}
for (int b = 0; b < nb; b++) {
// d4 0 1 2 3, 4 5 6 7
// d5 0 1 2 3, 4 5 6 7
float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
// d8 0 1 2 3
@ -3273,7 +3282,7 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
}
for (int sb = 0; sb < QK_K / 64; sb++) {
// Int accumulators for qs vecdot (4 row x 2 col quartets)
// Int accumulators for qs vecdot (4 row * 2 col quartets)
int32x4_t acc_lo[acc_size];
int32x4_t acc_hi[acc_size];
for (int i = 0; i < acc_size; i++) {
@ -3300,21 +3309,19 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);
// NOTE: This is the only difference with q4_K
const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
qh[0][k] = vshrq_n_u8(qh[0][k], 2);
const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
qh[1][k] = vshrq_n_u8(qh[1][k], 2);
// From here, same as q4_K
const int8x16_t q5_0123_lo =
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
const int8x16_t q5_0123_hi =
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
const int8x16_t q5_4567_lo =
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
const int8x16_t q5_4567_hi =
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
@ -3326,6 +3333,11 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
const int8x16_t q5_4567_lo =
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
const int8x16_t q5_4567_hi =
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567