improve code quality
parent 49aa628d6b
commit da606bd736
@@ -217,7 +217,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
    block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;

#if defined(__ARM_NEON)
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    const int blck_size_interleave = 8;
    float32x4_t srcv[4][64]; // 64 = QK_K/4
    float iscale[4];
@@ -334,11 +334,10 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
#else
    // NOTE:
    // Current C implementation is actually aligned with x86 AVX2 design, but differs from above ARM NEON design.
    // This is because the [bsums] layout is different in block_q8_Kx4 for the 2 designs.
    // As NEON is supported in almost all the modern ARM platforms, this generic path can be rarely arrived nowadays.
    // (The exceptional cases aren't suitable for AI work)
    // However logically we may still need a corresponding generic version for ARM, called xxx_generic_arm for example.
    // The generic C implementation of Q8_K quantization was originally designed to work with block_q4_Kx8 in the x86
    // AVX design, and differs from the AArch64 NEON Q8_K quantization logic above, which is designed to work with
    // block_q4_Kx4. The main difference is the "[bsums] layout". However, we can still reuse the x86 C implementation
    // for AArch64, as long as we access the "[bsums] layout" correctly in ggml_gemm_q4_K_4x8_q8_K_generic().
    UNUSED(nb);
    UNUSED(y);
    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
@@ -564,6 +563,76 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
    ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
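    // 16-entry IQ4_NL codebook; every 4-bit weight below is expanded to int8 by indexing it with vqtbl1q_s8.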
    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    float * res_ptr = s;

    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);

        float32x4_t sumf = vdupq_n_f32(0);
        for (int l = 0; l < nb; l++) {
            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
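
            // Decode the packed nibbles through the codebook: each *_lo vector holds the low-nibble weights and is
            // dotted against a_0 (activations 0..15) below, each *_hi vector pairs with a_1 (activations 16..31).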
            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);

            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);

            int32x4_t sumi = vdupq_n_s32(0);
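            // Each of the 4 lanes of sumi accumulates one interleaved column; the lane argument of
            // vdotq_laneq_s32 selects which group of 4 activations is multiplied in.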
            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
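
            // Scale back to float: a_d broadcasts the activation block scale, b_d holds the 4 per-column weight scales.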
            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
            float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
            float32x4_t d = a_d * b_d;

            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
        }

        vst1q_f32(res_ptr + x * 4, sumf);
    }
    return;
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
    ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
@@ -583,7 +652,7 @@ void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
        const block_q4_Kx4 *GGML_RESTRICT q4 = (const block_q4_Kx4*) vx;
        const uint8x16_t m4b = vdupq_n_u8(0xf);
@@ -683,76 +752,6 @@ void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
#endif
}

void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert (n % qk == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
    float * res_ptr = s;

    for (int x = 0; x < nc / ncols_interleaved; x++) {
        const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);

        float32x4_t sumf = vdupq_n_f32(0);
        for (int l = 0; l < nb; l++) {
            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);

            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);

            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);

            int32x4_t sumi = vdupq_n_s32(0);
            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);

            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
            float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
            float32x4_t d = a_d * b_d;

            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
        }

        vst1q_f32(res_ptr + x * 4, sumf);
    }
    return;
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
    ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    constexpr int qk = QK_K;
    const int nb = n / qk;
@@ -2507,6 +2506,82 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
    ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
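
            // Each (y, x) pair produces a 4x4 tile: the 4 rows of the block_q8_0x4 activations against the 4
            // interleaved weight columns.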
            float32x4_t sumf[4];
            for (int m = 0; m < 4; m++) {
                sumf[m] = vdupq_n_f32(0);
            }

            for (int l = 0; l < nb; l++) {
                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
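                // a_d: the 4 per-row activation scales; b_d: the 4 per-column weight scales.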

                int32x4_t sumi_0 = vdupq_n_s32(0);
                int32x4_t sumi_1 = vdupq_n_s32(0);
                int32x4_t sumi_2 = vdupq_n_s32(0);
                int32x4_t sumi_3 = vdupq_n_s32(0);

                for (int k = 0; k < 4; k++) {
                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);

                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
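
                    // sumi_m accumulates row m (one column per lane): the low-nibble weights are dotted with a_0,
                    // the high-nibble weights with a_1.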
                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
                }

                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
            }

            for (int m = 0; m < 4; m++) {
                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
            }
        }
    }
    return;
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK_K;
    const int nb = n / qk;
@@ -2527,7 +2602,7 @@ void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if defined(__ARM_FEATURE_MATMUL_INT8)
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
    const block_q8_Kx4 * GGML_RESTRICT q8_ptr_start = (const block_q8_Kx4 *) vy;
    const block_q4_Kx4 * GGML_RESTRICT q4_ptr_start = (const block_q4_Kx4 *) vx;
@@ -2800,82 +2875,6 @@ void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
#endif
}

void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    const int qk = QK8_0;
    const int nb = n / qk;
    const int ncols_interleaved = 4;
    const int blocklen = 4;

    assert (n % qk == 0);
    assert (nr % 4 == 0);
    assert (nc % ncols_interleaved == 0);

    UNUSED(s);
    UNUSED(bs);
    UNUSED(vx);
    UNUSED(vy);
    UNUSED(nr);
    UNUSED(nc);
    UNUSED(nb);
    UNUSED(ncols_interleaved);
    UNUSED(blocklen);

#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
    const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);

    for (int y = 0; y < nr / 4; y++) {
        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
        for (int x = 0; x < nc / ncols_interleaved; x++) {
            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);

            float32x4_t sumf[4];
            for (int m = 0; m < 4; m++) {
                sumf[m] = vdupq_n_f32(0);
            }

            for (int l = 0; l < nb; l++) {
                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));

                int32x4_t sumi_0 = vdupq_n_s32(0);
                int32x4_t sumi_1 = vdupq_n_s32(0);
                int32x4_t sumi_2 = vdupq_n_s32(0);
                int32x4_t sumi_3 = vdupq_n_s32(0);

                for (int k = 0; k < 4; k++) {
                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);

                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);

                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
                }

                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
            }

            for (int m = 0; m < 4; m++) {
                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
            }
        }
    }
    return;
#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
    ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}

void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
    constexpr int qk = QK_K;
    const int nb = n / qk;
@@ -1099,11 +1099,14 @@ void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
        }

        // prepare partial results for tail work
        //
        // NOTE:
        // the "[bsums] layout" here is from ggml_quantize_mat_q8_K_4x8_generic().
        for (int m = 0; m < 4; m++) {
            for (int j = 0; j < ncols_interleaved; j++) {
                for (int i = 0; i < QK_K / 32; i++) {
                    const int16_t bsums = a_ptr[n].bsums[i * 2 + m * 16] + a_ptr[n].bsums[i * 2 + 1 + m * 16];
                    sum_minf[m][j] += mins[j][i] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d[m];
                    const int16_t *bsums = a_ptr[n].bsums + (i * 8) - ((i % 2) * 6) + (m * 4);
                    sum_minf[m][j] += mins[j][i] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d[m];
                }
            }
        }
@@ -2314,10 +2317,6 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
    ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
@@ -2326,6 +2325,10 @@ template <> void gemm<block_q4_K, 8, 4, GGML_TYPE_Q8_K>(int n, float * s, size_t
    ggml_gemm_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@@ -2866,9 +2869,9 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf
    GGML_UNUSED(buft);
}

// For the aarch64 q4_k repack implementation only:
// New tensor storage size calculation after q4_kx4 repacking.
// Because sizeof(block_q4_Kx4) is a bit larger than default sizeof(block_q4_K)*4.
// This function is for the aarch64 q4_K_4x8_q8_K repack case only:
// tensor storage after repacking is slightly larger than before -- sizeof(block_q4_Kx4) > sizeof(block_q4_K) * 4 --
// because the "scales" member is pre-decoded during the repack stage rather than at execution time.
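// For reference (assuming the standard 144-byte block_q4_K: 2 + 2 + 12 + 128 bytes), the block_q4_Kx4 layout shown
// later in this diff takes 8 + 8 + 64 + 512 = 592 bytes per group of four super-blocks, versus 4 * 144 = 576 bytes
// before repacking.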
static inline size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
    size_t nbytes;
    const size_t blck_size = 256;
@@ -2880,7 +2883,9 @@ static inline size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
static size_t ggml_backend_cpu_aarch64_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    if (tensor->type == GGML_TYPE_Q4_K) {
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            return ggml_nbytes_q4_kx4(tensor);
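            // The ne[1] % 4 check added below presumably mirrors the 4-row interleaving of block_q4_Kx4:
            // the larger repacked size only applies when the number of rows is a multiple of 4.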
            if (tensor->ne[1] % 4 == 0) {
                return ggml_nbytes_q4_kx4(tensor); // for q4_K_4x8_q8_K only
            }
        }
    }
    return ggml_nbytes(tensor);
@@ -39,35 +39,8 @@ using block_q8_0x8 = block<8, 8>;
struct block_q4_Kx4 {
    ggml_half d[4];      // super-block scale for quantized scales
    ggml_half dmin[4];   // super-block scale for quantized mins
    int8_t scales[64];   // scales and mins, quantized with 8 bits (recovered from 6-bit during repack) TODO: consider if uint8_t?
    int8_t scales[64];   // scales and mins, quantized with 8 bits (recovered from 6-bit during repack)
    uint8_t qs[512];     // 4-bit quants

    /********************************************************layout***************************************************************/
    // low <-------------------------------------------------------------------------------------> high
    //
    // d:      |s0|s1|s2|s3|
    //
    // dmin:   |s0|s1|s2|s3|
    //
    // scales: |-------- d --------|-------- m --------|
    //         |s0b0|s1b0|s2b0|s3b0|s0b0|s1b0|s2b0|s3b0|
    //         |s0b1|s1b1|s2b1|s3b1|s0b1|s1b1|s2b1|s3b1|
    //         ......
    //         |s0b7|s1b7|s2b7|s3b7|s0b7|s1b7|s2b7|s3b7|
    //
    // qs:     <from block0 of all>
    //         |s0w0 |s0w16|s0w1 |s0w17|s0w2 |s0w18|s0w3 |s0w19|s0w4 |s0w20|s0w5 |s0w21|s0w6 |s0w22|s0w7 |s0w23| --- 8B from s0
    //         |s1w0 |s1w16|s1w1 |s1w17|s1w2 |s1w18|s1w3 |s1w19|s1w4 |s1w20|s1w5 |s1w21|s1w6 |s1w22|s1w7 |s1w23| --- 8B from s1
    //         |s2w0 |s2w16|s2w1 |s2w17|s2w2 |s2w18|s2w3 |s2w19|s2w4 |s2w20|s2w5 |s2w21|s2w6 |s2w22|s2w7 |s2w23| --- 8B from s2
    //         |s3w0 |s3w16|s3w1 |s3w17|s3w2 |s3w18|s3w3 |s3w19|s3w4 |s3w20|s3w5 |s3w21|s3w6 |s3w22|s3w7 |s3w23| --- 8B from s3
    //         |s0w8 |s0w24|s0w9 |s0w25|s0w10|s0w26|s0w11|s0w27|s0w12|s0w28|s0w13|s0w29|s0w14|s0w30|s0w15|s0w31| --- 8B from s0
    //         |s1w8 |s1w24|s1w9 |s1w25|s1w10|s1w26|s1w11|s1w27|s1w12|s1w28|s1w13|s1w29|s1w14|s1w30|s1w15|s1w31| --- 8B from s1
    //         |s2w8 |s2w24|s2w9 |s2w25|s2w10|s2w26|s2w11|s2w27|s2w12|s2w28|s2w13|s2w29|s2w14|s2w30|s2w15|s2w31| --- 8B from s2
    //         |s3w8 |s3w24|s3w9 |s3w25|s3w10|s3w26|s3w11|s3w27|s3w12|s3w28|s3w13|s3w29|s3w14|s3w30|s3w15|s3w31| --- 8B from s3
    //         <from block1 of all>
    //         ......
    //         <from block7 of all>
    /*****************************************************************************************************************************/
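    // e.g. (reading the scales diagram above): scales[8 * b + s] holds the decoded scale of column s for
    // sub-block b, and scales[8 * b + 4 + s] the corresponding min.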
};

struct block_q4_Kx8 {