DOTPROD GEMV
parent aaec188927
commit 643c4d9f0a

@@ -68,7 +68,6 @@
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
-#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K

@@ -785,6 +785,164 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q5_K_8x4_q8_K(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
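+    // DOTPROD path: vx holds groups of 8 interleaved q5_K columns (block_q5_Kx8)
+    // and vy a single q8_K activation row, so each pass over one group below
+    // produces 8 output values at once.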
+    constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    const uint8x16_t mone = vdupq_n_u8(1);
+    const uint8x16_t mtwo = vdupq_n_u8(2);
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[col_groups];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));     // d0 d1 d2 d3
+            float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7
+            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d);
+            float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d);
+            float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin));     // dmin 0..3
+            float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7
+            float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d);
+            float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            int32x4_t acc_lo[col_groups];
+            int32x4_t acc_hi[col_groups];
+
+            // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+
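+            // q5_K stores its quants unsigned with per-subblock mins, so the bias
+            // sum(q8) * min is accumulated per subblock and subtracted after the sb loop.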
+            uint8x16_t qh[col_groups][8];
+            for (int c = 0; c < col_groups; c++) {
+                for (int i = 0; i < 8; i++) {
+                    qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c);
+                }
+            }
+
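+            // Two high bits per byte of qh are consumed per subblock (one for each
+            // nibble half); qh is shifted right by 2 as each subblock is processed.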
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_groups; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles
+                // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+                int16x8_t q5sb_mins[2];
+                int16x8_t q5sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t aux_q5sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb);
+                    q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb));
+                }
+
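+                // The decoded 6-bit scales are widened to 16 bits here; they are
+                // widened once more to 32 bits when multiplying the dot accumulators.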
+                int8x16_t q8_qs[4];
+                for (int i = 0; i < 4; i++) {
+                    q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
+                }
+
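+                // The 64 activation bytes of this subblock are shared by all 8
+                // columns, so they are loaded once into four 16-byte registers.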
+                for (int c = 0; c < col_groups; c++) {
+                    uint8x16_t q5_cols[8];
+                    uint8x16_t hbit_lo[8];
+                    uint8x16_t hbit_hi[8];
+                    int8x16_t q5_lo[8];
+                    int8x16_t q5_hi[8];
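+                    // Rebuild the 5-bit quants (0..31): the low nibble gets its
+                    // 5th bit inserted via vsli, the high nibble via shift+or.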
+                    for (int i = 0; i < 8; i++) {
+                        q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
+                        hbit_lo[i] = vandq_u8(qh[c][i], mone);
+                        hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3);
+                        qh[c][i] = vshrq_n_u8(qh[c][i], 2);
+                        q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4));
+                        q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i]));
+                    }
+
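+                    // Each vdotq_laneq_s32 call dots 4 consecutive quants from 4
+                    // interleaved columns against the same 4 activation bytes
+                    // (one lane of q8_qs), accumulating into 4 per-column sums.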
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2);
+                    acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3);
+
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2);
+                    acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3);
+                }
+
+                // Scales
+                // row c0123 blk0 and blk1
+                const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]);
+                const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]);
+                const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
+                                                                      vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
+                acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
+                // row c4567 blk0 and blk1
+                const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]);
+                const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]);
+                const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
+                                                                      vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
+                acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
+
+                // Bias Correction
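+                // bsums_arr[2*sb] / [2*sb+1] are the activation sums for the low and
+                // high 32 values of this subblock; multiply-accumulate them against
+                // the per-subblock mins, to be subtracted scaled by dmin * d8 below.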
+                const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1]));
+            } // for sb
+
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
+        } // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    } // for x
+    return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q5_K_8x8_q8_K(int n,
                              float * GGML_RESTRICT s,
                              size_t bs,