From 126ce2c9e4f73203c1b455ceb14927f4a02bc51c Mon Sep 17 00:00:00 2001
From: hongyang
Date: Thu, 18 Dec 2025 14:00:40 +0800
Subject: [PATCH] fix compatibility with other q4_k repacking models

---
 ggml/src/ggml-cpu/arch/arm/repack.cpp | 221 +++++++++++++-------------
 ggml/src/ggml-cpu/repack.cpp          |  28 ++--
 ggml/src/ggml-cpu/repack.h            |   4 +-
 3 files changed, 131 insertions(+), 122 deletions(-)
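What the patch does, in short: ggml_quantize_mat_q8_K_4x8() used to emit only the AArch64-interleaved Q8_K layout, which mismatches weights that were repacked 8 blocks wide (block_q4_Kx8). The function now receives the interleaved row count of the weight tensor (ne01, named nc below) and picks the matching activation layout. A minimal standalone sketch of that dispatch; the helper names are hypothetical stand-ins for the real functions changed in the hunks below:

    #include <cstdint>

    // Stand-ins for the two quantizers the patched function dispatches between:
    inline void quantize_q8_K_generic_layout(const float *, void *, int64_t, int64_t) { /* ggml_quantize_mat_q8_K_4x8_generic */ }
    inline void quantize_q8_K_neon_layout(const float *, void *, int64_t)             { /* the NEON loop in the first hunk */ }

    void quantize_mat_q8_K_4x8_sketch(const float * x, void * vy, int64_t k, int64_t nc) {
        if (nc % 8 == 0) {
            quantize_q8_K_generic_layout(x, vy, k, nc); // weights repacked as block_q4_Kx8
        } else if (nc % 4 == 0) {
            quantize_q8_K_neon_layout(x, vy, k);        // weights repacked as block_q4_Kx4
        }
    }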
diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp
index b3df909f8c..ba86886282 100644
--- a/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -210,129 +210,137 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTR
 #endif
 }
 
-void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
+    UNUSED(nc);
     const int nb = k / QK_K;
     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
 
 #if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    const int blck_size_interleave = 8;
-    float32x4_t srcv[4][64]; // 64 = QK_K/4
-    float iscale[4];
+    if (nc % 8 == 0) {
+        UNUSED(nb);
+        UNUSED(y);
+        ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
+    } else if (nc % 4 == 0) {
+        const int blck_size_interleave = 8;
+        float32x4_t srcv[4][64]; // 64 = QK_K/4
+        float iscale[4];
 
-    for (int i = 0; i < nb; i++) {
-        float32x4_t asrcv[64];
-        float32x4_t amaxv[64];
+        for (int i = 0; i < nb; i++) {
+            float32x4_t asrcv[64];
+            float32x4_t amaxv[64];
 
-        // d:
-        for (int row_iter = 0; row_iter < 4; row_iter++) {
-            for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
-            for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
+            // d:
+            for (int row_iter = 0; row_iter < 4; row_iter++) {
+                for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
+                for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
 
-            for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
-            for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
-            for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
-            for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
-            for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
-            for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
+                for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
+                for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
+                for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
+                for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
+                for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
+                for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
 
-            const float amax = vmaxvq_f32(amaxv[0]);
+                const float amax = vmaxvq_f32(amaxv[0]);
 
-            // Check if exists: orig == amax
-            float32x4_t amax_vec = vdupq_n_f32(amax);
-            uint32x4_t mask_all = vdupq_n_u32(0);
-            for (int j = 0; j < 64; j++) {
-                uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
-                mask_all = vorrq_u32(mask_all, mask_curr);
-            }
+                // Check if exists: orig == amax
+                float32x4_t amax_vec = vdupq_n_f32(amax);
+                uint32x4_t mask_all = vdupq_n_u32(0);
+                for (int j = 0; j < 64; j++) {
+                    uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
+                    mask_all = vorrq_u32(mask_all, mask_curr);
+                }
 
-            // Assume that none == amax, then check mask_all to reverse
-            iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
-            uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
-            if (vmaxvq_u32(cmp) != 0) {
-                iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
-            }
+                // Assume that none == amax, then check mask_all to reverse
+                iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
+                uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
+                if (vmaxvq_u32(cmp) != 0) {
+                    iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
+                }
 
-            y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
-        }
+                y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
+            }
 
-        // qs: 8 byte interleave over 4 rows, loop = QK_K/8
-        // bsums: simply generated one by one, row_i is calculated before row_i+1
-        // loops = 16
-        for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
-            // Process row0 and row1
-            float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
-            float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
-            float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
-            float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
-            int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
-            int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
-            int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
-            int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
-            int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+            // qs: 8 byte interleave over 4 rows, loop = QK_K/8
+            // bsums: simply generated one by one, row_i is calculated before row_i+1
+            // loops = 16
+            for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
+                // Process row0 and row1
+                float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
+                float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
+                float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
+                float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
+                int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
+                int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
+                int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
+                int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
+                int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
 
-            float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
-            float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
-            float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
-            float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
-            int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
-            int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
-            int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
-            int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
-            int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
+                float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
+                float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
+                float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
+                int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
+                int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
+                int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
+                int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
+                int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
 
-            // Calculate and store qs
-            int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
-            vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
-            // Calculate and store bsum
-            int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            y[i].bsums[j] = vaddlvq_s8(i0_0_15);
-            y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
+                // Calculate and store qs
+                int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
+                vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
+                // Calculate and store bsum
+                int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                y[i].bsums[j] = vaddlvq_s8(i0_0_15);
+                y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
 
-            // Process row2 and row3
-            f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
-            f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
-            f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
-            f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
-            i0_0_3 = vcvtnq_s32_f32(f0_0_3);
-            i0_4_7 = vcvtnq_s32_f32(f0_4_7);
-            i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            i0_8_11 = vcvtnq_s32_f32(f0_8_11);
-            i0_12_15 = vcvtnq_s32_f32(f0_12_15);
-            i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                // Process row2 and row3
+                f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
+                f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
+                f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
+                f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
+                i0_0_3 = vcvtnq_s32_f32(f0_0_3);
+                i0_4_7 = vcvtnq_s32_f32(f0_4_7);
+                i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                i0_8_11 = vcvtnq_s32_f32(f0_8_11);
+                i0_12_15 = vcvtnq_s32_f32(f0_12_15);
+                i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
 
-            f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
-            f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
-            f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
-            f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
-            i1_0_3 = vcvtnq_s32_f32(f1_0_3);
-            i1_4_7 = vcvtnq_s32_f32(f1_4_7);
-            i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
-            i1_8_11 = vcvtnq_s32_f32(f1_8_11);
-            i1_12_15 = vcvtnq_s32_f32(f1_12_15);
-            i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
+                f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
+                f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
+                f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
+                i1_0_3 = vcvtnq_s32_f32(f1_0_3);
+                i1_4_7 = vcvtnq_s32_f32(f1_4_7);
+                i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
+                i1_8_11 = vcvtnq_s32_f32(f1_8_11);
+                i1_12_15 = vcvtnq_s32_f32(f1_12_15);
+                i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
 
-            // Calculate and store qs
-            i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
-            vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
-            // Calculate and store bsum
-            i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
-            y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
-            y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
-        }
-    }
+                // Calculate and store qs
+                i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
+                vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
+                // Calculate and store bsum
+                i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
+                y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
+            }
+        }
+    }
+    return;
+#endif
 
-#else
     // NOTE:
     // Current C impl of Q8_K quanti is originally designed to work with block_q4_Kx8 in x86 AVX design, and differs from
     // above Q8_K quanti logic in AArch64 NEON design, which is designed to work with block_q4_Kx4. The main difference is in
@@ -340,8 +348,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTR
     // "[bsums] layout" correctly in ggml_gemm_q4_K_4x8_q8_K_generic().
     UNUSED(nb);
     UNUSED(y);
-    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
-#endif
+    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
 }
 
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
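For reference, the per-row scale selection that the NEON block above performs with vceqq_f32/vmaxvq_u32 can be written in scalar form roughly as follows. This is a sketch, not code from the patch; it only restates the "Assume that none == amax, then check mask_all to reverse" logic:

    #include <algorithm>
    #include <cmath>

    // Default to +127/amax; flip to -127/amax when the maximum magnitude
    // occurs as a positive value (some v[j] == +amax). d_out receives the
    // value that the real code stores in y[i].d[row_iter].
    static float row_iscale_sketch(const float * v, int n, float * d_out) {
        float amax = 0.0f;
        for (int j = 0; j < n; j++) {
            amax = std::max(amax, std::fabs(v[j]));
        }
        bool has_pos_amax = false;
        for (int j = 0; j < n; j++) {
            has_pos_amax = has_pos_amax || (v[j] == amax); // orig == amax
        }
        const float iscale = (amax != 0.0f) ? (has_pos_amax ? -127.f : 127.f) / amax : 0.0f;
        *d_out = amax ? 1/iscale : 0;
        return iscale;
    }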
diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp
index 7e2d0bf4ef..7f9e9e2172 100644
--- a/ggml/src/ggml-cpu/repack.cpp
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -176,9 +176,10 @@ void ggml_quantize_mat_q8_K_4x4_generic(const float * GG
     }
 }
 
-void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
+    UNUSED(nc);
     const int nb = k / QK_K;
     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
 
@@ -230,30 +231,33 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GG
 } // extern "C"
 
 template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
-void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
+void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols);
 
-template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
     assert(nrow == 4);
     UNUSED(nrow);
+    UNUSED(ncols);
     ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
 }
 
-template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
     assert(nrow == 4);
     UNUSED(nrow);
+    UNUSED(ncols);
    ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
 }
 
-template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
     assert(nrow == 4);
     UNUSED(nrow);
+    UNUSED(ncols);
     ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
 }
 
-template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
+template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
     assert(nrow == 4);
     UNUSED(nrow);
-    ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
+    ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row, ncols);
 }
 
 extern "C" {
@@ -2502,7 +2506,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, ggml_type PARAM_TYPE>
             ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
-                (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
+                (void *) (wdata_ptr + i11 * nbw1), 4, ne10, ne01);
         }
 
         const int64_t i11_processed = ne11 - ne11 % 4;
@@ -2775,15 +2779,13 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
             return &q4_K_8x8_q8_K;
         }
     }
-    if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { // new for ARM N2
-        if (cur->ne[1] % 4 == 0) {
-            return &q4_K_4x8_q8_K;
-        }
-    }
     if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
         if (cur->ne[1] % 8 == 0) {
             return &q4_K_8x8_q8_K;
         }
+        if (cur->ne[1] % 4 == 0) {
+            return &q4_K_4x8_q8_K;
+        }
     }
     if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
         if (cur->ne[1] % 8 == 0) {
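The reordering in ggml_repack_get_optimal_repack_type() above is the other half of the fix: both branches test the same NEON and int8-matmul features, so ordering alone decides, and the old standalone "% 4" block captured tensors whose row count is also a multiple of 8. After the patch the 8-row repack is preferred and the 4-row NEON layout is only a fallback, which keeps the choice consistent with the nc % 8 / nc % 4 dispatch added in arch/arm/repack.cpp. A self-contained illustration of the resulting precedence (the enum and helper here are hypothetical stand-ins for the tensor_traits pointers):

    #include <cstdint>

    enum class q4K_repack { none, x8, x4 };

    q4K_repack pick_q4K_repack(int64_t ne1) {
        if (ne1 % 8 == 0) return q4K_repack::x8; // &q4_K_8x8_q8_K, checked first
        if (ne1 % 4 == 0) return q4K_repack::x4; // &q4_K_4x8_q8_K, fallback
        return q4K_repack::none;                 // leave the tensor unpacked
    }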
diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h
index dd467f479a..ea5268c451 100644
--- a/ggml/src/ggml-cpu/repack.h
+++ b/ggml/src/ggml-cpu/repack.h
@@ -88,7 +88,7 @@ extern "C" {
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
-void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -116,7 +116,7 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
-void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
 // gemv_generic ???
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
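Finally, the layout contract behind all of this, per the NOTE comment kept in arch/arm/repack.cpp: the x86-oriented generic path (block_q4_Kx8) and the AArch64 path (block_q4_Kx4) lay out block_q8_Kx4 differently, mainly in bsums. The NEON loop above stores 16 partial sums per row, row-major. A scalar sketch of that bsums layout, derived from the vaddlvq_s8 stores in the loop (assumption: q[r] holds the 256 already-quantized int8 values of row r):

    #include <cstdint>

    // bsums[16*r + j] = sum of elements [16*j, 16*j + 16) of row r.
    void compute_bsums_sketch(const int8_t q[4][256], int16_t bsums[64]) {
        for (int r = 0; r < 4; r++) {
            for (int j = 0; j < 16; j++) {
                int sum = 0;
                for (int e = 0; e < 16; e++) {
                    sum += q[r][16 * j + e];
                }
                bsums[16 * r + j] = (int16_t) sum;
            }
        }
    }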