Merge 41780fe07e into 18ddaea2ae

commit 779fc296bc
@@ -38,6 +38,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
@@ -48,6 +49,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
@@ -58,7 +60,6 @@
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// repack.cpp
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
@@ -69,12 +70,14 @@
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@@ -94,6 +97,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
@@ -104,6 +108,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
@@ -126,6 +131,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
@@ -136,6 +142,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
@@ -165,6 +172,7 @@
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
@@ -174,6 +182,7 @@
#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
@@ -202,6 +211,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
@@ -212,6 +222,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
@@ -242,6 +253,7 @@
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
@@ -252,6 +264,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
@@ -210,6 +210,147 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
UNUSED(nc);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
||||
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
if (nc % 8 == 0) {
|
||||
UNUSED(nb);
|
||||
UNUSED(y);
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
|
||||
} else if (nc % 4 == 0) {
|
||||
const int blck_size_interleave = 8;
|
||||
float32x4_t srcv[4][64]; // 64 = QK_K/4
|
||||
float iscale[4];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float32x4_t asrcv[64];
|
||||
float32x4_t amaxv[64];
|
||||
|
||||
// d:
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
|
||||
for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
|
||||
|
||||
for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
|
||||
for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
|
||||
for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
|
||||
for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
|
||||
for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
|
||||
for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
|
||||
|
||||
const float amax = vmaxvq_f32(amaxv[0]);
|
||||
|
||||
// Check if exists: orig == amax
|
||||
float32x4_t amax_vec = vdupq_n_f32(amax);
|
||||
uint32x4_t mask_all = vdupq_n_u32(0);
|
||||
for (int j = 0; j < 64; j++) {
|
||||
uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
|
||||
mask_all = vorrq_u32(mask_all, mask_curr);
|
||||
}
|
||||
|
||||
// Assume that none == amax, then check mask_all to reverse
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
|
||||
uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
|
||||
if (vmaxvq_u32(cmp) != 0) {
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
|
||||
}
|
||||
|
||||
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
|
||||
}
|
||||
|
||||
// qs: 8 byte interleave over 4 rows, loop = QK_K/8
|
||||
// bsums: simply generated one by one, row_i is calculated before row_i+1
|
||||
// loops = 16
|
||||
for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
|
||||
// Process row0 and row1
|
||||
float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
|
||||
float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
|
||||
float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
|
||||
float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
|
||||
int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
|
||||
float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
|
||||
float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
|
||||
float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
|
||||
int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
|
||||
|
||||
// Process row2 and row3
|
||||
f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
|
||||
f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
|
||||
f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
|
||||
f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
|
||||
i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
|
||||
f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
|
||||
f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
|
||||
f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
|
||||
i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
|
||||
// NOTE:
// The current C implementation of Q8_K quantization was originally designed to work with block_q4_Kx8 in the x86 AVX
// path, and it differs from the AArch64 NEON Q8_K quantization above, which is designed to work with block_q4_Kx4.
// The main difference is their "[bsums]" layout. However, we can still reuse the x86 C implementation for AArch64,
// as long as we access the "[bsums]" layout correctly in ggml_gemm_q4_K_4x8_q8_K_generic().
|
||||
UNUSED(nb);
|
||||
UNUSED(y);
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
|
||||
}
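// Illustrative sketch (not part of the patch): with the reused x86-style layout, block_q8_Kx4 interleaves the
// per-sub-block sums of the four rows, and ggml_gemm_q4_K_4x8_q8_K_generic() later in this commit reads them
// at the offset computed here (i = 32-element sub-block index 0..7, m = row 0..3); the two adjacent int16
// sums are then added together.
static inline int q8_Kx4_bsums_base(int i, int m) {
    return (i * 8) - ((i % 2) * 6) + (m * 4);
}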
|
||||
|
||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
|
@@ -499,6 +640,125 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
const block_q4_Kx4 *GGML_RESTRICT q4 = (const block_q4_Kx4*) vx;
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||
for (int c = 0; c < nc; c += ncols_interleaved) {
|
||||
const block_q8_K *GGML_RESTRICT q8 = (const block_q8_K *) vy;
|
||||
float32x4_t res = vdupq_n_f32(0);
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) q4->d)); // d0 d1 d2 d3
|
||||
float32x4_t q8_d = vdupq_n_f32(q8->d);
|
||||
float32x4_t g_d = vmulq_f32 (q4_d, q8_d);
|
||||
float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4->dmin)); // dmin0 dmin1 dmin2 dmin3
|
||||
float32x4_t g_dmin = vmulq_f32(q4_dmin, q8_d);
|
||||
const uint8_t * GGML_RESTRICT q4_ptr = q4->qs;
|
||||
const int8_t * GGML_RESTRICT q8_ptr = q8->qs;
|
||||
int32x4_t prod = vdupq_n_s32(0);
|
||||
const int16x8_t q8_sums = vpaddq_s16(vld1q_s16(q8->bsums), vld1q_s16(q8->bsums + 8));
|
||||
// vgetq_lane_s16 requires a constant lane index, so it cannot be used inside this loop; spill the vector to a temporary array with vst1q_s16 instead.
int16_t tmp_array[8];
vst1q_s16(tmp_array, q8_sums);
|
||||
for (int j = 0; j < QK_K / 32; ++j) {
|
||||
int32x4_t sum0 = vdupq_n_s32(0);
|
||||
int32x4_t sum1 = vdupq_n_s32(0);
|
||||
// Each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3
|
||||
int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *)q4->scales + 8 * j));
|
||||
prod = vmlal_s16(prod, vdup_n_s16(tmp_array[j]), vget_high_s16(scales_mins));
|
||||
uint8x16_t q4_0 = vld1q_u8((const uint8_t *) q4_ptr);
|
||||
uint8x16_t q4_1 = vld1q_u8((const uint8_t *) q4_ptr + 16);
|
||||
uint8x16_t q4_2 = vld1q_u8((const uint8_t *) q4_ptr + 32);
|
||||
uint8x16_t q4_3 = vld1q_u8((const uint8_t *) q4_ptr + 48);
|
||||
q4_ptr += 64;
|
||||
int8x16_t q8_0 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr);
|
||||
int8x16_t q8_1 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 1);
|
||||
int8x16_t q8_2 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 2);
|
||||
int8x16_t q8_3 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 3);
|
||||
q8_ptr += 32;
|
||||
|
||||
/* low bits
|
||||
(1) sum0
|
||||
b0_000 b0_001 b0_002 b0_003 b0_004 b0_005 b0_006 b0_007 | b1_000 b1_001 b1_002 b1_003 b1_004 b1_005 b1_006 b1_007
|
||||
* a0 a1 a2 a3 a4 a5 a6 a7 | a0 a1 a2 a3 a4 a5 a6 a7
|
||||
|------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
|
||||
(2) sum1
|
||||
b2_000 b2_001 b2_002 b2_003 b2_004 b2_005 b2_006 b2_007 | b3_000 b3_001 b3_002 b3_003 b3_004 b3_005 b3_006 b3_007
|
||||
* a0 a1 a2 a3 a4 a5 a6 a7 | a0 a1 a2 a3 a4 a5 a6 a7
|
||||
|------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
|
||||
(3) sum0
|
||||
b0_008 b0_009 b0_010 b0_011 b0_012 b0_013 b0_014 b0_015 | b1_008 b1_009 b1_010 b1_011 b1_012 b1_013 b1_014 b1_015
|
||||
* a8 a9 a10 a11 a12 a13 a14 a15 | a8 a9 a10 a11 a12 a13 a14 a15
|
||||
|------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
|
||||
(4) sum1
|
||||
b2_008 b2_009 b2_010 b2_011 b2_012 b2_013 b2_014 b2_015 | b3_008 b3_009 b3_010 b3_011 b3_012 b3_013 b3_014 b3_015
|
||||
* a8 a9 a10 a11 a12 a13 a14 a15 | a8 a9 a10 a11 a12 a13 a14 a15
|
||||
|------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
|
||||
*/
|
||||
sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vandq_u8(q4_0, m4b)), q8_0);
|
||||
sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vandq_u8(q4_1, m4b)), q8_0);
|
||||
sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vandq_u8(q4_2, m4b)), q8_1);
|
||||
sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vandq_u8(q4_3, m4b)), q8_1);
|
||||
|
||||
/* high bits
|
||||
(1) sum0
|
||||
b0_016 b0_017 b0_018 b0_019 b0_020 b0_021 b0_022 b0_023 | b1_016 b1_017 b1_018 b1_019 b1_020 b1_021 b1_022 b1_023
|
||||
* a16 a17 a18 a19 a20 a21 a22 a23 | a16 a17 a18 a19 a20 a21 a22 a23
|
||||
|------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
|
||||
(2) sum1
|
||||
b2_016 b2_017 b2_018 b2_019 b2_020 b2_021 b2_022 b2_023 | b3_016 b3_017 b3_018 b3_019 b3_020 b3_021 b3_022 b3_023
|
||||
* a16 a17 a18 a19 a20 a21 a22 a23 | a16 a17 a18 a19 a20 a21 a22 a23
|
||||
|------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
|
||||
(3) sum0
|
||||
b0_024 b0_025 b0_026 b0_027 b0_028 b0_029 b0_030 b0_031 | b1_024 b1_025 b1_026 b1_027 b1_028 b1_029 b1_030 b1_031
|
||||
* a24 a25 a26 a27 a28 a29 a30 a31 | a24 a25 a26 a27 a28 a29 a30 a31
|
||||
|------------dot------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------|
|
||||
(4) sum1
|
||||
b2_024 b2_025 b2_026 b2_027 b2_028 b2_029 b2_030 b2_031 | b3_024 b3_025 b3_026 b3_027 b3_028 b3_029 b3_030 b3_031
|
||||
* a24 a25 a26 a27 a28 a29 a30 a31 | a24 a25 a26 a27 a28 a29 a30 a31
|
||||
|------------dot------------ | |------------dot-------------| | |------------dot-------------| |------------dot-------------|
|
||||
*/
|
||||
sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vshrq_n_u8(q4_0, 4)), q8_2);
|
||||
sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vshrq_n_u8(q4_1, 4)), q8_2);
|
||||
sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vshrq_n_u8(q4_2, 4)), q8_3);
|
||||
sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vshrq_n_u8(q4_3, 4)), q8_3);
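// Each vdotq_s32 lane is a 4-element dot product, so after the eight updates above sum0 holds
// [B0 lo, B0 hi, B1 lo, B1 hi] and sum1 holds [B2 lo, B2 hi, B3 lo, B3 hi]; the vpaddq_s32 below
// folds each lane pair into the full per-column sums [B0, B1, B2, B3] before applying the scales.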
|
||||
float32x4_t sumf = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), vpaddq_s32(sum0, sum1)));
|
||||
res = vfmaq_f32(res, g_d, sumf);
|
||||
}
|
||||
res -= vmulq_f32(g_dmin, vcvtq_f32_s32(prod));
|
||||
q4++;
|
||||
q8++;
|
||||
}
|
||||
vst1q_f32(s, res);
|
||||
s += ncols_interleaved;
|
||||
}
|
||||
return;
|
||||
}
|
||||
#else
|
||||
// C implementation
|
||||
ggml_gemv_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
|
@@ -2329,6 +2589,299 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
|
|||
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert(n % qk == 0);
|
||||
assert(nr % 4 == 0);
|
||||
assert(nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr); // row
|
||||
UNUSED(nc); // column
|
||||
UNUSED(nb); // block
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr_start = (const block_q8_Kx4 *) vy;
|
||||
const block_q4_Kx4 * GGML_RESTRICT q4_ptr_start = (const block_q4_Kx4 *) vx;
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x0f);
|
||||
float32x4_t zeros = vdupq_n_f32(0.0f);
|
||||
int anr = nr - nr % 16;
|
||||
int row = 0;
|
||||
// Row loop
|
||||
for (; row < anr / 4; row += 4) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptrs[4];
|
||||
q8_ptrs[0] = q8_ptr_start + (row * nb);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
q8_ptrs[i + 1] = q8_ptrs[i] + nb;
|
||||
}
|
||||
// Column loop
|
||||
for (int col = 0; col < nc / ncols_interleaved; col++) {
|
||||
const block_q4_Kx4 * GGML_RESTRICT q4_ptr = q4_ptr_start + (col * nb);
|
||||
// init output
|
||||
float32x4_t res[16]; // final result
|
||||
for (int i = 0; i < 16; i++) {
|
||||
res[i] = zeros;
|
||||
}
|
||||
// Block loop
|
||||
for (int64_t b = 0; b < nb; b++) {
|
||||
float32x4_t g_d[16];
|
||||
float32x4_t g_dmin[16];
|
||||
int16x8_t q8_bsums[16];
|
||||
int32x4_t prod[16]; // store bsums*mins
|
||||
for (int i = 0; i < 16; i++) {
|
||||
g_d[i] = zeros;
|
||||
g_dmin[i] = zeros;
|
||||
q8_bsums[i] = vdupq_n_s16(0);
|
||||
prod[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// Get global d and dmin
|
||||
float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // col0 col1 col2 col3
|
||||
float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin0 dmin1 dmin2 dmin3
|
||||
int16_t tmp_q8_bsums_array[16][8];
|
||||
for (int iter = 0; iter < 4; iter++) {
|
||||
// Process this group of four rows together
|
||||
for (int in = 0; in < 4; in++) {
|
||||
float32x4_t scalar_q8_d = vdupq_n_f32(q8_ptrs[iter][b].d[in]);
|
||||
g_d[in + 4 * iter] = vmulq_f32(q4_d, scalar_q8_d);
|
||||
g_dmin[in + 4 * iter] = vmulq_f32(q4_dmin, scalar_q8_d);
|
||||
// The 16 bsums of each row are pairwise-added down to 8 values; no loop unrolling is done here
|
||||
q8_bsums[in + 4 * iter] = vpaddq_s16(vld1q_s16(q8_ptrs[iter][b].bsums + 16 * in), vld1q_s16(q8_ptrs[iter][b].bsums + 16 * in + 8));
|
||||
vst1q_s16(tmp_q8_bsums_array[in + 4 * iter], q8_bsums[in + 4 * iter]);
|
||||
}
|
||||
}
|
||||
// The 256 elements in the superblock are processed in 8 steps
|
||||
for (int sb = 0; sb < QK_K / 32; sb++) {
|
||||
int32x4_t acc_rows[16]; // the calculated value of qs
|
||||
int32x4_t sum[16]; // the value of qs after rearranging
|
||||
for (int i = 0; i < 16; i++) {
|
||||
acc_rows[i] = vdupq_n_s32(0);
|
||||
sum[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3
|
||||
int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *) q4_ptr[b].scales + 8 * sb));
|
||||
uint8x16_t q4_qs_raw_01_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + sb * 64);
|
||||
uint8x16_t q4_qs_raw_23_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 16 + sb * 64);
|
||||
uint8x16_t q4_qs_raw_01_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 32 + sb * 64);
|
||||
uint8x16_t q4_qs_raw_23_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 48 + sb * 64);
|
||||
|
||||
int8x16_t q4_qs_01_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_0, m4b)); // B0(0-7) B1(0-7)
|
||||
int8x16_t q4_qs_23_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_0, m4b)); // B2(0-7) B3(0-7)
|
||||
int8x16_t q4_qs_01_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_1, m4b)); // B0(8-15) B1(8-15)
|
||||
int8x16_t q4_qs_23_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_1, m4b)); // B2(8-15) B3(8-15)
|
||||
|
||||
int8x16_t q4_qs_01_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_0, 4)); // B0(16-23) B1(16-23)
|
||||
int8x16_t q4_qs_23_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_0, 4)); // B2(16-23) B3(16-23)
|
||||
int8x16_t q4_qs_01_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_1, 4)); // B0(24-31) B1(24-31)
|
||||
int8x16_t q4_qs_23_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_1, 4)); // B2(24-31) B3(24-31)
|
||||
|
||||
// The 16 rows of the left matrix are processed in four groups of four
|
||||
for (int iter = 0; iter < 4; iter++) {
|
||||
// Direct loop unrolling
|
||||
prod[0 + 4 * iter] = vmlal_s16(prod[0 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[0 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row 4*iter+0: bsums*mins(0-3)
prod[1 + 4 * iter] = vmlal_s16(prod[1 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[1 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row 4*iter+1: bsums*mins(0-3)
prod[2 + 4 * iter] = vmlal_s16(prod[2 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[2 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row 4*iter+2: bsums*mins(0-3)
prod[3 + 4 * iter] = vmlal_s16(prod[3 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[3 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row 4*iter+3: bsums*mins(0-3)
|
||||
|
||||
int8x16_t q8_qs_01_00 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 128 * sb); // A0(0-7) A1(0-7)
|
||||
int8x16_t q8_qs_23_00 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 16 + 128 * sb); // A2(0-7) A3(0-7)
|
||||
|
||||
acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_00, q4_qs_01_l0); // A0*B0(0-7) A0*B1(0-7) A1*B0(0-7) A1*B1(0-7)
|
||||
acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_00, q4_qs_23_l0); // A0*B2(0-7) A0*B3(0-7) A1*B2(0-7) A1*B3(0-7)
|
||||
acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_00, q4_qs_01_l0); // A2*B0(0-7) A2*B1(0-7) A3*B0(0-7) A3*B1(0-7)
|
||||
acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_00, q4_qs_23_l0); // A2*B2(0-7) A2*B3(0-7) A3*B2(0-7) A3*B3(0-7)
|
||||
|
||||
int8x16_t q8_qs_01_01 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 32 + 128 * sb); // A0(8-15) A1(8-15)
|
||||
int8x16_t q8_qs_23_01 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 48 + 128 * sb); // A2(8-15) A3(8-15)
|
||||
|
||||
acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_01, q4_qs_01_l1); // A0*B0(8-15) A0*B1(8-15) A1*B0(8-15) A1*B1(8-15)
|
||||
acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_01, q4_qs_23_l1); // A0*B2(8-15) A0*B3(8-15) A1*B2(8-15) A1*B3(8-15)
|
||||
acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_01, q4_qs_01_l1); // A2*B0(8-15) A2*B1(8-15) A3*B0(8-15) A3*B1(8-15)
|
||||
acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_01, q4_qs_23_l1); // A2*B2(8-15) A2*B3(8-15) A3*B2(8-15) A3*B3(8-15)
|
||||
|
||||
int8x16_t q8_qs_01_02 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 64 + 128 * sb); // A0(16-23) A1(16-23)
|
||||
int8x16_t q8_qs_23_02 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 80 + 128 * sb); // A2(16-23) A3(16-23)
|
||||
|
||||
acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_02, q4_qs_01_h0); // A0*B0(16-23) A0*B1(16-23) A1*B0(16-23) A1*B1(16-23)
|
||||
acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_02, q4_qs_23_h0); // A0*B2(16-23) A0*B3(16-23) A1*B2(16-23) A1*B3(16-23)
|
||||
acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_02, q4_qs_01_h0); // A2*B0(16-23) A2*B1(16-23) A3*B0(16-23) A3*B1(16-23)
|
||||
acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_02, q4_qs_23_h0); // A2*B2(16-23) A2*B3(16-23) A3*B2(16-23) A3*B3(16-23)
|
||||
|
||||
int8x16_t q8_qs_01_03 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 96 + 128 * sb); // A0(24-31) A1(24-31)
|
||||
int8x16_t q8_qs_23_03 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 112 + 128 * sb); // A2(24-31) A3(24-31)
|
||||
|
||||
acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_03, q4_qs_01_h1); // A0*B0(24-31) A0*B1(24-31) A1*B0(24-31) A1*B1(24-31)
|
||||
acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_03, q4_qs_23_h1); // A0*B2(24-31) A0*B3(24-31) A1*B2(24-31) A1*B3(24-31)
|
||||
acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_03, q4_qs_01_h1); // A2*B0(24-31) A2*B1(24-31) A3*B0(24-31) A3*B1(24-31)
|
||||
acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_03, q4_qs_23_h1); // A2*B2(24-31) A2*B3(24-31) A3*B2(24-31) A3*B3(24-31)
|
||||
|
||||
// rearranging vectors
|
||||
sum[0 + 4 * iter] = vcombine_s32(vget_low_s32(acc_rows[0 + 4 * iter]), vget_low_s32(acc_rows[1 + 4 * iter]));
|
||||
sum[1 + 4 * iter] = vcombine_s32(vget_high_s32(acc_rows[0 + 4 * iter]), vget_high_s32(acc_rows[1 + 4 * iter]));
|
||||
sum[2 + 4 * iter] = vcombine_s32(vget_low_s32(acc_rows[2 + 4 * iter]), vget_low_s32(acc_rows[3 + 4 * iter]));
|
||||
sum[3 + 4 * iter] = vcombine_s32(vget_high_s32(acc_rows[2 + 4 * iter]), vget_high_s32(acc_rows[3 + 4 * iter]));
|
||||
|
||||
float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[0 + 4 * iter])); // scales *qs
|
||||
float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[1 + 4 * iter]));
|
||||
float32x4_t sumf_2 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[2 + 4 * iter]));
|
||||
float32x4_t sumf_3 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[3 + 4 * iter]));
|
||||
|
||||
res[0 + 4 * iter] = vfmaq_f32(res[0 + 4 * iter], g_d[0 + 4 * iter], sumf_0);
|
||||
res[1 + 4 * iter] = vfmaq_f32(res[1 + 4 * iter], g_d[1 + 4 * iter], sumf_1);
|
||||
res[2 + 4 * iter] = vfmaq_f32(res[2 + 4 * iter], g_d[2 + 4 * iter], sumf_2);
|
||||
res[3 + 4 * iter] = vfmaq_f32(res[3 + 4 * iter], g_d[3 + 4 * iter], sumf_3);
|
||||
}
|
||||
}
|
||||
for (int iter = 0; iter < 4; iter++) {
|
||||
res[0 + 4 * iter] -= vmulq_f32(g_dmin[0 + 4 * iter], vcvtq_f32_s32(prod[0 + 4 * iter]));
|
||||
res[1 + 4 * iter] -= vmulq_f32(g_dmin[1 + 4 * iter], vcvtq_f32_s32(prod[1 + 4 * iter]));
|
||||
res[2 + 4 * iter] -= vmulq_f32(g_dmin[2 + 4 * iter], vcvtq_f32_s32(prod[2 + 4 * iter]));
|
||||
res[3 + 4 * iter] -= vmulq_f32(g_dmin[3 + 4 * iter], vcvtq_f32_s32(prod[3 + 4 * iter]));
|
||||
}
|
||||
}
|
||||
// store result
|
||||
for (int i = 0; i < 16; i++) {
|
||||
vst1q_f32((float *) (s + ((row * 4 + i) * bs + col * 4)), res[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle the remaining tail of fewer than 16 rows
|
||||
for (; row < nr / 4; row++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = q8_ptr_start + (row * nb);
|
||||
// Column loop
|
||||
for (int col = 0; col < nc / ncols_interleaved; col++) {
|
||||
const block_q4_Kx4 * GGML_RESTRICT q4_ptr = q4_ptr_start + (col * nb);
|
||||
// init output
|
||||
float32x4_t res[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res[i] = zeros;
|
||||
}
|
||||
// Block loop
|
||||
for (int64_t b = 0; b < nb; b++) {
|
||||
float32x4_t g_d[4];
|
||||
float32x4_t g_dmin[4];
|
||||
int16x8_t q8_bsums[4];
|
||||
int32x4_t prod[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
g_d[i] = zeros;
|
||||
g_dmin[i] = zeros;
|
||||
q8_bsums[i] = vdupq_n_s16(0);
|
||||
prod[i] = vdupq_n_s32(0);
|
||||
}
|
||||
float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // col0 col1 col2 col3
|
||||
float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin0 dmin1 dmin2 dmin3
|
||||
int16_t tmp_q8_bsums_array[4][8];
|
||||
for (int in = 0; in < 4; in++) {
|
||||
float32x4_t scalar_q8_d = vdupq_n_f32(q8_ptr[b].d[in]);
|
||||
g_d[in] = vmulq_f32(q4_d, scalar_q8_d);
|
||||
g_dmin[in] = vmulq_f32(q4_dmin, scalar_q8_d);
|
||||
q8_bsums[in] = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * in), vld1q_s16(q8_ptr[b].bsums + 16 * in + 8));
|
||||
vst1q_s16(tmp_q8_bsums_array[in], q8_bsums[in]);
|
||||
}
|
||||
for (int sb = 0; sb < QK_K / 32; sb++) {
|
||||
int32x4_t acc_rows[4];
|
||||
int32x4_t sum[4];
|
||||
for (int i = 0; i < 4; i++) {
|
||||
acc_rows[i] = vdupq_n_s32(0);
|
||||
sum[i] = vdupq_n_s32(0);
|
||||
}
|
||||
// each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3
|
||||
int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *) q4_ptr[b].scales + 8 * sb));
|
||||
uint8x16_t q4_qs_raw_01_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + sb * 64);
|
||||
uint8x16_t q4_qs_raw_23_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 16 + sb * 64);
|
||||
uint8x16_t q4_qs_raw_01_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 32 + sb * 64);
|
||||
uint8x16_t q4_qs_raw_23_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 48 + sb * 64);
|
||||
|
||||
int8x16_t q4_qs_01_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_0, m4b)); // B0(0-7) B1(0-7)
|
||||
int8x16_t q4_qs_23_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_0, m4b)); // B2(0-7) B3(0-7)
|
||||
int8x16_t q4_qs_01_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_1, m4b)); // B0(8-15) B1(8-15)
|
||||
int8x16_t q4_qs_23_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_1, m4b)); // B2(8-15) B3(8-15)
|
||||
|
||||
prod[0] = vmlal_s16(prod[0], vdup_n_s16(tmp_q8_bsums_array[0][sb]), vget_high_s16(scales_mins)); // row 0: bsums*mins(0-3)
prod[1] = vmlal_s16(prod[1], vdup_n_s16(tmp_q8_bsums_array[1][sb]), vget_high_s16(scales_mins)); // row 1: bsums*mins(0-3)
prod[2] = vmlal_s16(prod[2], vdup_n_s16(tmp_q8_bsums_array[2][sb]), vget_high_s16(scales_mins)); // row 2: bsums*mins(0-3)
prod[3] = vmlal_s16(prod[3], vdup_n_s16(tmp_q8_bsums_array[3][sb]), vget_high_s16(scales_mins)); // row 3: bsums*mins(0-3)
|
||||
|
||||
int8x16_t q8_qs_01_00 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 128 * sb); // A0(0-7) A1(0-7)
|
||||
int8x16_t q8_qs_23_00 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 16 + 128 * sb); // A2(0-7) A3(0-7)
|
||||
|
||||
acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_00, q4_qs_01_l0); // A0*B0(0-7) A0*B1(0-7) A1*B0(0-7) A1*B1(0-7)
|
||||
acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_00, q4_qs_23_l0); // A0*B2(0-7) A0*B3(0-7) A1*B2(0-7) A1*B3(0-7)
|
||||
acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_00, q4_qs_01_l0); // A2*B0(0-7) A2*B1(0-7) A3*B0(0-7) A3*B1(0-7)
|
||||
acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_00, q4_qs_23_l0); // A2*B2(0-7) A2*B3(0-7) A3*B2(0-7) A3*B3(0-7)
|
||||
|
||||
int8x16_t q8_qs_01_01 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 32 + 128 * sb); // A0(8-15) A1(8-15)
|
||||
int8x16_t q8_qs_23_01 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 48 + 128 * sb); // A2(8-15) A3(8-15)
|
||||
|
||||
acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_01, q4_qs_01_l1); // A0*B0(8-15) A0*B1(8-15) A1*B0(8-15) A1*B1(8-15)
|
||||
acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_01, q4_qs_23_l1); // A0*B2(8-15) A0*B3(8-15) A1*B2(8-15) A1*B3(8-15)
|
||||
acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_01, q4_qs_01_l1); // A2*B0(8-15) A2*B1(8-15) A3*B0(8-15) A3*B1(8-15)
|
||||
acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_01, q4_qs_23_l1); // A2*B2(8-15) A2*B3(8-15) A3*B2(8-15) A3*B3(8-15)
|
||||
|
||||
int8x16_t q4_qs_01_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_0, 4)); // B0(16-23) B1(16-23)
|
||||
int8x16_t q4_qs_23_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_0, 4)); // B2(16-23) B3(16-23)
|
||||
int8x16_t q4_qs_01_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_1, 4)); // B0(24-31) B1(24-31)
|
||||
int8x16_t q4_qs_23_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_1, 4)); // B2(24-31) B3(24-31)
|
||||
|
||||
int8x16_t q8_qs_01_02 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 64 + 128 * sb); // A0(16-23) A1(16-23)
|
||||
int8x16_t q8_qs_23_02 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 80 + 128 * sb); // A2(16-23) A3(16-23)
|
||||
|
||||
acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_02, q4_qs_01_h0); // A0*B0(16-23) A0*B1(16-23) A1*B0(16-23) A1*B1(16-23)
|
||||
acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_02, q4_qs_23_h0); // A0*B2(16-23) A0*B3(16-23) A1*B2(16-23) A1*B3(16-23)
|
||||
acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_02, q4_qs_01_h0); // A2*B0(16-23) A2*B1(16-23) A3*B0(16-23) A3*B1(16-23)
|
||||
acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_02, q4_qs_23_h0); // A2*B2(16-23) A2*B3(16-23) A3*B2(16-23) A3*B3(16-23)
|
||||
|
||||
int8x16_t q8_qs_01_03 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 96 + 128 * sb); // A0(24-31) A1(24-31)
|
||||
int8x16_t q8_qs_23_03 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 112 + 128 * sb); // A2(24-31) A3(24-31)
|
||||
|
||||
acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_03, q4_qs_01_h1); // A0*B0(24-31) A0*B1(24-31) A1*B0(24-31) A1*B1(24-31)
|
||||
acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_03, q4_qs_23_h1); // A0*B2(24-31) A0*B3(24-31) A1*B2(24-31) A1*B3(24-31)
|
||||
acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_03, q4_qs_01_h1); // A2*B0(24-31) A2*B1(24-31) A3*B0(24-31) A3*B1(24-31)
|
||||
acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_03, q4_qs_23_h1); // A2*B2(24-31) A2*B3(24-31) A3*B2(24-31) A3*B3(24-31)
|
||||
|
||||
// rearranging vectors
|
||||
sum[0] = vcombine_s32(vget_low_s32(acc_rows[0]), vget_low_s32(acc_rows[1]));
|
||||
sum[1] = vcombine_s32(vget_high_s32(acc_rows[0]), vget_high_s32(acc_rows[1]));
|
||||
sum[2] = vcombine_s32(vget_low_s32(acc_rows[2]), vget_low_s32(acc_rows[3]));
|
||||
sum[3] = vcombine_s32(vget_high_s32(acc_rows[2]), vget_high_s32(acc_rows[3]));
|
||||
|
||||
float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[0])); // scales *qs
|
||||
float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[1]));
|
||||
float32x4_t sumf_2 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[2]));
|
||||
float32x4_t sumf_3 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[3]));
|
||||
|
||||
res[0] = vfmaq_f32(res[0], g_d[0], sumf_0);
|
||||
res[1] = vfmaq_f32(res[1], g_d[1], sumf_1);
|
||||
res[2] = vfmaq_f32(res[2], g_d[2], sumf_2);
|
||||
res[3] = vfmaq_f32(res[3], g_d[3], sumf_3);
|
||||
}
|
||||
res[0] -= vmulq_f32(g_dmin[0], vcvtq_f32_s32(prod[0]));
|
||||
res[1] -= vmulq_f32(g_dmin[1], vcvtq_f32_s32(prod[1]));
|
||||
res[2] -= vmulq_f32(g_dmin[2], vcvtq_f32_s32(prod[2]));
|
||||
res[3] -= vmulq_f32(g_dmin[3], vcvtq_f32_s32(prod[3]));
|
||||
}
|
||||
// store result
|
||||
for (int i = 0; i < 4; i++) {
|
||||
vst1q_f32((float *) (s + ((row * 4 + i) * bs + col * 4)), res[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
#else
|
||||
// C implementation
|
||||
ggml_gemm_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
constexpr int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
|
|
|||
|
|
@@ -287,9 +287,10 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
UNUSED(nc);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
|
@@ -507,7 +508,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
#else
|
||||
UNUSED(nb);
|
||||
UNUSED(y);
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@@ -176,9 +176,10 @@ void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GG
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
UNUSED(nc);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
|
@@ -230,30 +231,33 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
|||
} // extern "C"
|
||||
|
||||
template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
|
||||
void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
|
||||
void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols);
|
||||
|
||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
UNUSED(ncols);
|
||||
ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
UNUSED(ncols);
|
||||
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
UNUSED(ncols);
|
||||
ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
|
||||
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row, ncols);
|
||||
}
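// The new ncols argument is threaded through so that ggml_quantize_mat_q8_K_4x8() can pick its layout at
// runtime (the AArch64 NEON path earlier in this commit branches on nc % 8 vs nc % 4).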
|
||||
|
||||
extern "C" {
|
||||
|
|
@@ -391,6 +395,86 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 8;
|
||||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
int8_t scales[4][8]; // scales for 8 subblocks of 4 q4_k unit (4 cols)
|
||||
int8_t mins[4][8]; // mins for 8 subblocks of 4 q4_k unit (4 cols)
|
||||
float sumf[4]; // 1x4 unit: final result
|
||||
float sum_minf[4]; // 1x4 unit: final minus result
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
|
||||
// loop on n dimension, each iteration works on 4 columns
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb);
|
||||
|
||||
// initialize results for 4 cols
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
|
||||
// loop on k dimension, each iteration works on 1 q4_kx4 block
|
||||
for (int n = 0; n < nb; n++) {
|
||||
// prepare scales and mins for 4 cols
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
scales[j][i] = b_ptr[n].scales[i * 8 + j];
|
||||
mins[j][i] = b_ptr[n].scales[i * 8 + j + ncols_interleaved];
|
||||
}
|
||||
}
|
||||
// core loop: each iteration works on an interleaved unit (four 8-byte segments from 4 cols)
|
||||
for (int k = 0; k < qk / (2 * blocklen); k++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
int8_t scale = scales[j][k / 2];
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int v0 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf);
|
||||
const int v1 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
|
||||
sumi1 = v0 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i];
|
||||
sumi2 = v1 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i + 16];
|
||||
sumi += scale * (sumi1 + sumi2);
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d;
|
||||
}
|
||||
}
|
||||
|
||||
// prepare partial results for tail work
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
for (int i = 0; i < QK_K / 32; i++) {
|
||||
sum_minf[j] += mins[j][i] * (a_ptr[n].bsums[i * 2] + a_ptr[n].bsums[i * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// save results
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
|
||||
}
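// Reference sketch (illustrative, not part of the patch): the per-sub-block update that both this generic
// path and the NEON path perform for a single output column, with the min correction folded in through the
// Q8_K bsums: result += d*q8_d*scale*dot(q4, q8) - dmin*q8_d*min*bsum.
static inline float q4_K_subblock_contrib(float d, float dmin, float q8_d, int sc, int mn, int dot, int bsum) {
    return d * q8_d * (float)(sc * dot) - dmin * q8_d * (float)(mn * bsum);
}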
|
||||
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
|
@@ -950,6 +1034,98 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nr % 4 == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
int8_t scales[4][8]; // scales for 8 subblocks of 4 q4_k unit (4 cols)
|
||||
int8_t mins[4][8]; // mins for 8 subblocks of 4 q4_k unit (4 cols)
|
||||
float sumf[4][4]; // 4x4 unit: final result
|
||||
float sum_minf[4][4]; // 4x4 unit: final minus result
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
// loop on m dimension, each iteration works on 4 rows
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
// loop on n dimension, each iteration works on 4 columns
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb);
|
||||
|
||||
// initialize results for 4 cols
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// loop on k dimension, each iteration works on 1 q4_kx4 block
|
||||
for (int n = 0; n < nb; n++) {
|
||||
// prepare scales and mins for 4 cols
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
scales[j][i] = b_ptr[n].scales[i * 8 + j];
|
||||
mins[j][i] = b_ptr[n].scales[i * 8 + j + ncols_interleaved];
|
||||
}
|
||||
}
|
||||
|
||||
// core loop: each iteration works on an interleaved unit (four 8-byte segments from 4 cols)
|
||||
for (int k = 0; k < qk / (2 * blocklen); k++) {
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
int8_t scale = scales[j][k / 2];
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int v0 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf);
|
||||
const int v1 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
|
||||
sumi1 = v0 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i];
|
||||
sumi2 = v1 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i + 64];
|
||||
sumi += scale * (sumi1 + sumi2);
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prepare partial results for tail work
|
||||
//
|
||||
// NOTE:
|
||||
// the "[bsums] layout" here is from ggml_quantize_mat_q8_K_4x8_generic().
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
for (int i = 0; i < QK_K / 32; i++) {
|
||||
const int16_t *bsums = a_ptr[n].bsums + (i * 8) - ((i % 2) * 6) + (m * 4);
|
||||
sum_minf[m][j] += mins[j][i] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// save results
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
|
@@ -1505,6 +1681,90 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in
|
|||
return out;
|
||||
}
|
||||
|
||||
static void make_block_q4_Kx4(block_q4_K * in, unsigned int blck_size_interleave, block_q4_Kx4 * out) {
|
||||
int nrow = 4;
|
||||
int nloop = 4;
|
||||
|
||||
// d and dmin values of the 4 Q4_K are copied directly.
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
out->d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
|
||||
}
|
||||
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
out->dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
|
||||
}
|
||||
|
||||
// For qs, 2 things need to be done:
|
||||
// 1. Recover from Q4_K storage style to Q4_0 style
|
||||
// 2. Interleave quants by taking 8 bytes at a time
|
||||
|
||||
// 1.
|
||||
const uint64_t lo_mask = 0x0f0f0f0f0f0f0f0fULL;
|
||||
const uint64_t hi_mask = 0xf0f0f0f0f0f0f0f0ULL;
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
uint64_t *q = (uint64_t *)(in[i].qs);
|
||||
for (int j = 0; j < nloop; j++) {
|
||||
uint64_t q0, q1, q2, q3;
|
||||
q0 = q[0];
|
||||
q1 = q[1];
|
||||
q2 = q[2];
|
||||
q3 = q[3];
|
||||
|
||||
uint64_t hi1, hi2, lo3, lo4;
|
||||
hi1 = q0 & hi_mask;
|
||||
hi2 = q1 & hi_mask;
|
||||
lo3 = q2 & lo_mask;
|
||||
lo4 = q3 & lo_mask;
|
||||
q[0] = (q0 & lo_mask) | (lo3 << 4);
|
||||
q[1] = (q1 & lo_mask) | (lo4 << 4);
|
||||
q[2] = (q2 & hi_mask) | (hi1 >> 4);
|
||||
q[3] = (q3 & hi_mask) | (hi2 >> 4);
|
||||
|
||||
q += 4;
|
||||
}
|
||||
}
|
||||
|
||||
// 2.
|
||||
// Calculate total number of interleaved subblocks
|
||||
const int end = QK_K * 2 / blck_size_interleave;
|
||||
uint64_t *src, *dst;
|
||||
for (int i = 0; i < end; ++i) {
|
||||
int src_id = i % 4;
|
||||
int src_offset = (i / 4) * blck_size_interleave;
|
||||
int dst_offset = i * blck_size_interleave;
|
||||
|
||||
src = (uint64_t *)(&in[src_id].qs[src_offset]);
|
||||
dst = (uint64_t *)(&out->qs[dst_offset]);
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
// For scales & mins of each subblock. (8 subblocks in one Q4_K, 32 in total)
|
||||
// A special requirement to meet: expand to 8-bit from 6-bit.
|
||||
static const uint32_t kmask1 = 0x3f3f3f3f;
|
||||
static const uint32_t kmask2 = 0x0f0f0f0f;
|
||||
static const uint32_t kmask3 = 0x03030303;
|
||||
uint32_t utmp[4];
|
||||
for (int i = 0; i < nrow; i++) {
|
||||
// rearrange as d|d|...|d|min|min|...|min
|
||||
// expand to 8-bit from 6-bit
|
||||
memset(utmp, 0, 16);
|
||||
memcpy(utmp, in[i].scales, 12);
|
||||
utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
|
||||
const uint32_t uaux = utmp[1] & kmask1;
|
||||
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
|
||||
utmp[2] = uaux;
|
||||
utmp[0] &= kmask1;
|
||||
|
||||
// move to Q4_K
|
||||
const uint8_t * d_ptr = (const uint8_t*)&utmp[0];
|
||||
const uint8_t * m_ptr = (const uint8_t*)&utmp[2];
|
||||
for (int j = 0; j < 8; j++) {
|
||||
out->scales[j * 8 + i] = *d_ptr++;
|
||||
out->scales[j * 8 + i + nrow] = *m_ptr++;
|
||||
}
|
||||
}
|
||||
}
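// Scalar reference (illustrative helper, not part of the patch) for the kmask shuffle above: unpack the 6-bit
// scale and 6-bit min of sub-block j (0..7) from the 12 packed scale bytes of one block_q4_K. The shuffle
// produces exactly these values in the d|d|...|d|min|min|...|min order before they are interleaved per row.
static inline void q4_K_unpack_scale_min(const uint8_t * s12, int j, uint8_t * sc, uint8_t * mn) {
    if (j < 4) {
        *sc = s12[j] & 63;
        *mn = s12[j + 4] & 63;
    } else {
        *sc = (s12[j + 4] & 0x0F) | ((s12[j - 4] >> 6) << 4);
        *mn = (s12[j + 4] >>  4) | ((s12[j] >> 6) << 4);
    }
}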
|
||||
|
||||
static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
|
||||
block_q4_Kx8 out;
|
||||
//Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure
|
||||
|
|
@@ -1656,6 +1916,46 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
    GGML_UNUSED(data_size);
}

static int repack_q4_K_to_q4_K_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
    GGML_ASSERT(interleave_block == 8);
    constexpr int nrows_interleaved = 4;

    block_q4_Kx4 * dst = (block_q4_Kx4 *)t->data;
    const block_q4_K * src = (const block_q4_K *) data;
    block_q4_K dst_tmp[4];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK_K;

    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));

    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }
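
    // Interleave groups of 4 consecutive rows: for each of the nblocks column blocks,
    // gather the corresponding block_q4_K from each of the 4 rows and pack them into
    // one block_q4_Kx4.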
    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            make_block_q4_Kx4(dst_tmp, interleave_block, dst++);
        }
        src += nrows_interleaved * nblocks;
    }

    // adjust the tensor strides, since block_q4_Kx4 changes the per-row storage size
    //t->nb[0] = ggml_type_size(type);
    t->nb[0] = sizeof(block_q4_Kx4) / 4;
    t->nb[1] = t->nb[0] * (t->ne[0] / ggml_blck_size(t->type));
    for (int i = 2; i < GGML_MAX_DIMS; i++) {
        t->nb[i] = t->nb[i - 1] * t->ne[i - 1];
    }
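    // Net effect of the strides set above: each row's share of one interleaved block
    // is sizeof(block_q4_Kx4) / 4 bytes, and nb[1] is that share times the number of
    // QK_K blocks per row, matching the accounting used by ggml_nbytes_q4_kx4().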

    return 0;

    GGML_UNUSED(data_size);
}

static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
    GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
@ -1924,6 +2224,10 @@ template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * da
    return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
}

template <> int repack<block_q4_K, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_K_to_q4_K_4_bl(t, 8, data, data_size);
}

template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
}

@ -1973,6 +2277,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
    ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemv<block_q4_K, 8, 4, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
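
// Note: the template parameters are <BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>, while the
// kernel names read NB_COLS x INTER_SIZE, so <block_q4_K, 8, 4> dispatches to the 4x8 kernels
// and <block_q4_K, 4, 8> to the 8x4 ones.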

template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}

@ -2013,14 +2321,18 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
    ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 8, 4, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}

template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}

@ -2194,7 +2506,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR

    for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
        ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
                                                    (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
                                                    (void *) (wdata_ptr + i11 * nbw1), 4, ne10, ne01);
    }

    const int64_t i11_processed = ne11 - ne11 % 4;

@ -2292,7 +2604,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
    const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
    //GGML_ASSERT(nb00 == ggml_type_size(src0->type));
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

    // dst cannot be transposed or permuted

@ -2429,6 +2741,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;

    // instances for Q4_K
    static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 4, GGML_TYPE_Q8_K> q4_K_4x8_q8_K;
    static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
    static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;

@ -2470,6 +2783,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
            if (cur->ne[1] % 8 == 0) {
                return &q4_K_8x8_q8_K;
            }
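            // fall back to the 4x8 q4_K kernel when ne[1] is only a multiple of 4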
            if (cur->ne[1] % 4 == 0) {
                return &q4_K_4x8_q8_K;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 8 == 0) {
@ -2555,6 +2871,30 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf
    GGML_UNUSED(buft);
}

// The function below is for the aarch64 q4_K_4x8_q8_K repack case only:
// tensor storage after repacking is slightly larger than before, since sizeof(block_q4_Kx4) > sizeof(block_q4_K) * 4.
// This is because the "scales" member is pre-decoded during repacking rather than at execution time.
static inline size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
    size_t nbytes;
    const size_t blck_size = 256;
    const size_t type_size = sizeof(block_q4_Kx4) / 4;
    nbytes = ((tensor->ne[0] * type_size) / blck_size) * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
    return nbytes;
}
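
// Example for ggml_nbytes_q4_kx4 (hypothetical shape): with ne[0] = 4096 there are 16 QK_K
// blocks per row, so each row's share is 16 * (sizeof(block_q4_Kx4) / 4) = 16 * 148 = 2368 bytes,
// versus 16 * sizeof(block_q4_K) = 16 * 144 = 2304 bytes reported by ggml_nbytes() for plain Q4_K.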

static size_t ggml_backend_cpu_aarch64_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
    if (tensor->type == GGML_TYPE_Q4_K) {
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (tensor->ne[1] % 4 == 0) {
                return ggml_nbytes_q4_kx4(tensor); // for q4_K_4x8_q8_K only
            }
        }
    }
    return ggml_nbytes(tensor);

    GGML_UNUSED(buft);
}

namespace ggml::cpu::repack {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {

@ -2611,7 +2951,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) {
            /* .alloc_buffer   = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_cpu_repack_buffer_type_get_alignment,
            /* .get_max_size   = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
            /* .get_alloc_size = */ ggml_backend_cpu_aarch64_buffer_type_get_alloc_size, // defaults to ggml_nbytes except for the aarch64 q4_K_4x8_q8_K repack case
            /* .is_host        = */ nullptr,
        },
        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
@ -36,6 +36,13 @@ using block_q4_0x8 = block<4, 8>;
using block_q8_0x4 = block<8, 4>;
using block_q8_0x8 = block<8, 8>;

struct block_q4_Kx4 {
    ggml_half d[4];      // super-block scale for quantized scales
    ggml_half dmin[4];   // super-block scale for quantized mins
    int8_t scales[64];   // scales and mins, quantized with 8 bits (recovered from 6-bit during repack)
    uint8_t qs[512];     // 4-bit quants
};
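// With 2-byte ggml_half this struct is 592 bytes, slightly more than the 4 * 144 = 576 bytes of the
// four block_q4_K it replaces, because the 6-bit scales/mins are stored pre-expanded to 8 bits;
// this is why the repack buffer type overrides get_alloc_size for this layout.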

struct block_q4_Kx8 {
    ggml_half d[8];      // super-block scale for quantized scales
    ggml_half dmin[8];   // super-block scale for quantized mins
@ -81,10 +88,11 @@ extern "C" {
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -93,6 +101,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -107,18 +116,22 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
// generic gemv variants
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
// generic gemm variants
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);