gemm finished

This commit is contained in:
parent 28fb08937a
commit 8a5e84cb5b
@@ -806,7 +806,7 @@ void ggml_gemv_q5_K_8x4_q8_K(int n,
     UNUSED(blocklen);

 #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
     const uint8x16_t m4b = vdupq_n_u8(0x0f);
     const uint8x16_t mone = vdupq_n_u8(1);
     const uint8x16_t mtwo = vdupq_n_u8(2);
@@ -844,6 +844,7 @@ void ggml_gemv_q5_K_8x4_q8_K(int n,
         int16_t bsums_arr[8];
         vst1q_s16(bsums_arr, bsums);

+        // Preload to maximize qh reuse
         uint8x16_t qh[col_groups][8];
         for (int c = 0; c < col_groups; c++) {
             for (int i = 0; i < 8; i++) {
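
Note on the new comment: the high-bit vectors are hoisted into qh[col_groups][8] once per super-block and then consumed in registers by the later sub-block iterations instead of being reloaded. A rough scalar sketch of that access pattern, mirroring the consumption idiom visible in the gemm hunk at 3300 below (load_qh_byte and the loop bounds are illustrative assumptions, not the repacked q5_K layout):

    #include <cstdint>

    // Hypothetical stand-in for the 16-byte vld1q_u8 loads of the interleaved qh data.
    static uint8_t load_qh_byte(int col_group, int i) { return (uint8_t) (17 * col_group + i); }

    void superblock_sketch() {
        uint8_t qh[2][8];                            // col_groups x 8, loaded exactly once
        for (int c = 0; c < 2; ++c)
            for (int i = 0; i < 8; ++i)
                qh[c][i] = load_qh_byte(c, i);

        for (int sb = 0; sb < 4; ++sb) {             // QK_K / 64 sub-block pairs
            for (int c = 0; c < 2; ++c)
                for (int i = 0; i < 8; ++i) {
                    uint8_t hbit_lo = qh[c][i] & 1;  // 5th bit for the low nibbles
                    uint8_t hbit_hi = qh[c][i] & 2;  // 5th bit for the high nibbles
                    qh[c][i] >>= 2;                  // advance two bit-planes; no reload needed
                    (void) hbit_lo; (void) hbit_hi;  // combined with the qs nibbles in the real kernel
                }
        }
    }
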
@@ -878,6 +879,8 @@ void ggml_gemv_q5_K_8x4_q8_K(int n,
             uint8x16_t hbit_hi[8];
             int8x16_t q5_lo[8];
             int8x16_t q5_hi[8];
+            // Already tried unrolling this loop, no perf difference
+            // Compiler seems to be able to unroll and schedule well enough
             for (int i = 0; i < 8; i++) {
                 q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
                 hbit_lo[i] = vandq_u8(qh[c][i], mone);
@@ -3175,7 +3178,13 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }

-void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+void ggml_gemm_q5_K_8x4_q8_K(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
     constexpr int qk = QK_K;
     const int nb = n / qk;

@@ -3192,13 +3201,13 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo

 #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
     constexpr int q8_k_blocklen = 4;
-    constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs
+    constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs
     constexpr int col_groups = ncols_interleaved / 4;
     const uint8x16_t m4b = vdupq_n_u8(0x0f);
     const uint8x16_t mone = vdupq_n_u8(1);
     const uint8x16_t mtwo = vdupq_n_u8(2);

-    // 8 accumulators: 2 row pairs × 4 col pairs
+    // 8 accumulators: 2 row pairs, 4 col pairs
     float32x4_t acc_f32[acc_size];

     for (int y = 0; y < nr / q8_k_blocklen; y++) {
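
Shape check for the accumulator comment above: q8_k_blocklen = 4 activation rows per y iteration and, judging from the 0123/4567 column groups, ncols_interleaved = 8 weight columns, so one output tile is 4 x 8 = 32 floats, exactly what the 8 float32x4 accumulators hold:

    // Bookkeeping only; the constants mirror the ones in the kernel.
    constexpr int q8_k_blocklen     = 4;      // rows handled per outer y iteration
    constexpr int ncols_interleaved = 8;      // columns 0123 and 4567
    constexpr int acc_size          = 2 * 4;  // 2 row pairs, 4 col pairs
    constexpr int f32_lanes         = 4;      // lanes in one float32x4_t
    static_assert(acc_size * f32_lanes == q8_k_blocklen * ncols_interleaved,
                  "8 accumulators x 4 lanes cover the full 4 x 8 output tile");
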
@@ -3212,7 +3221,7 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
         }

         for (int b = 0; b < nb; b++) {
-            // d4 0 1 2 3, 4 5 6 7
+            // d5 0 1 2 3, 4 5 6 7
             float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d));
             float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4));
             // d8 0 1 2 3
@@ -3273,7 +3282,7 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             }

             for (int sb = 0; sb < QK_K / 64; sb++) {
-                // Int accumulators for qs vecdot (4 row x 2 col quartets)
+                // Int accumulators for qs vecdot (4 row * 2 col quartets)
                 int32x4_t acc_lo[acc_size];
                 int32x4_t acc_hi[acc_size];
                 for (int i = 0; i < acc_size; i++) {
@@ -3300,21 +3309,19 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k);
                     const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16);

+                    // NOTE: This is the only difference with q4_K
                     const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone);
                     const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3);
                     qh[0][k] = vshrq_n_u8(qh[0][k], 2);
                     const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone);
                     const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3);
                     qh[1][k] = vshrq_n_u8(qh[1][k], 2);
+                    // From here, same as q4_K

                     const int8x16_t q5_0123_lo =
                         vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4));
                     const int8x16_t q5_0123_hi =
                         vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123));
-                    const int8x16_t q5_4567_lo =
-                        vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
-                    const int8x16_t q5_4567_hi =
-                        vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));

                     acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
                     acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
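
To follow the bit twiddling above: each qs byte holds two 4-bit quants (low and high nibble) and qh supplies their 5th bits, two bit-planes per k iteration. A scalar model of what the vandq/vshlq/vshrq/vsliq/vorrq sequence computes for a single byte (the real code does 16 bytes per vector):

    #include <cstdint>

    // Scalar model of the per-byte q5_K reconstruction performed by the NEON code above.
    static inline void q5_from_byte(uint8_t qs, uint8_t & qh, int8_t & lo, int8_t & hi) {
        const uint8_t hbit_lo = qh & 1;                    // vandq_u8(qh, mone)
        const uint8_t hbit_hi = (uint8_t) ((qh & 2) << 3); // vshlq_n_u8(vandq_u8(qh, mtwo), 3): bit lands at position 4
        qh >>= 2;                                          // vshrq_n_u8(qh, 2): consume two bit-planes
        lo = (int8_t) ((qs & 0x0f) | (hbit_lo << 4));      // vsliq_n_u8(qs & m4b, hbit_lo, 4)
        hi = (int8_t) ((qs >> 4) | hbit_hi);               // vorrq_u8(qs >> 4, hbit_hi)
    }
    // Both results land in 0..31, so reinterpreting the u8 vectors as int8x16_t is lossless;
    // q4_K skips the hbit_* terms, which is why the comment calls this the only difference.
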
@@ -3326,6 +3333,11 @@ void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
                     acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123

+                    const int8x16_t q5_4567_lo =
+                        vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4));
+                    const int8x16_t q5_4567_hi =
+                        vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567));
+
                     acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
                     acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
                     acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
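
The reordering above moves the q5_4567 reconstruction down next to the dot products that consume it, presumably to give the compiler independent work to interleave between the two accumulation groups. For reference, each vdotq_laneq_s32 call accumulates four 4-byte dot products against one 4-byte lane of the q8 vector; a scalar model (array and index names are illustrative):

    #include <cstdint>

    // Scalar model of acc = vdotq_laneq_s32(acc, a, b, lane):
    // acc[j] += a[4j+0]*b[4*lane+0] + a[4j+1]*b[4*lane+1] + a[4j+2]*b[4*lane+2] + a[4j+3]*b[4*lane+3]
    static inline void sdot_lane(int32_t acc[4], const int8_t a[16], const int8_t b[16], int lane) {
        for (int j = 0; j < 4; ++j)
            for (int i = 0; i < 4; ++i)
                acc[j] += (int32_t) a[4 * j + i] * (int32_t) b[4 * lane + i];
    }
    // In the kernel, a packs 4 bytes from each of four interleaved columns and b (q8_blk0/q8_blk1)
    // packs 4 bytes from each of the four activation rows, so one call yields partial sums for
    // one row against four columns, matching the "r0 c0123" style comments.
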