GEMM implementations

This commit is contained in:
Alberto Cabrera 2025-12-12 12:15:33 +00:00
parent fbe5fd4025
commit 0d14b67763
1 changed files with 105 additions and 1 deletions

View File

@ -2757,7 +2757,49 @@ void ggml_gemm_q8_0_4x4_q8_0(int n,
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
// TODO: Implement ARM NEON DOTPROD kernel for q8_0 × q8_0 GEMM
for (int y = 0; y < nr / 4; y++) {
const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
float32x4_t sumf[4];
for (int m = 0; m < 4; m++) {
sumf[m] = vdupq_n_f32(0);
}
for (int l = 0; l < nb; l++) {
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
int32x4_t sumi_0 = vdupq_n_s32(0);
int32x4_t sumi_1 = vdupq_n_s32(0);
int32x4_t sumi_2 = vdupq_n_s32(0);
int32x4_t sumi_3 = vdupq_n_s32(0);
for (int k_group = 0; k_group < 8; k_group += 4) {
int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
for (int k = 0; k < 4; k++) {
sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
}
}
sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
}
for (int m = 0; m < 4; m++) {
vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
}
}
}
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
@ -2783,6 +2825,68 @@ void ggml_gemm_q8_0_4x8_q8_0(int n,
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
for (int y = 0; y < nr; y += 4) {
const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
for (int x = 0; x < nc; x += ncols_interleaved) {
const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
const block_q8_0x4 * a_ptr = a_ptr_base;
float32x4_t acc_f32[4];
for (int i = 0; i < 4; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
int32x4_t acc[4];
for (int i = 0; i < 4; i++) {
acc[i] = vdupq_n_s32(0);
}
// Process 4 chunks of 8 positions each
for (int chunk = 0; chunk < 4; chunk++) {
int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
acc[0] = vmmlaq_s32(acc[0], a01, b01);
acc[1] = vmmlaq_s32(acc[1], a01, b23);
acc[2] = vmmlaq_s32(acc[2], a23, b01);
acc[3] = vmmlaq_s32(acc[3], a23, b23);
}
// Reorder outputs from 2×2 tiles to row-major
// acc[0] = [r0c0, r0c1, r1c0, r1c1]
// acc[1] = [r0c2, r0c3, r1c2, r1c3]
// acc[2] = [r2c0, r2c1, r3c0, r3c1]
// acc[3] = [r2c2, r2c3, r3c2, r3c3]
int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
// Scales
float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
a_ptr++;
b_ptr++;
}
for (int row = 0; row < 4; row++) {
vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
}
}
}
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}