diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 3f8946ac70..dbf654b340 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -38,6 +38,7 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -48,6 +49,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -58,7 +60,6 @@ #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 -#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -69,12 +70,14 @@ #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -94,6 +97,7 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -104,6 +108,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -126,6 +131,7 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define 
ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -136,6 +142,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -165,6 +172,7 @@ #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -174,6 +182,7 @@ #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -202,6 +211,7 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -212,6 +222,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K @@ -242,6 +253,7 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_4x8_q8_K_generic ggml_gemv_q4_K_4x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -252,6 +264,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_4x8_q8_K_generic ggml_gemm_q4_K_4x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define 
ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index b61220a189..ba86886282 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -210,6 +210,147 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR #endif } +void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) { + assert(QK_K == 256); + assert(k % QK_K == 0); + UNUSED(nc); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + if (nc % 8 == 0) { + UNUSED(nb); + UNUSED(y); + ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc); + } else if (nc % 4 == 0) { + const int blck_size_interleave = 8; + float32x4_t srcv[4][64]; // 64 = QK_K/4 + float iscale[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[64]; + float32x4_t amaxv[64]; + + // d: + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j); + for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]); + for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]); + for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]); + + const float amax = vmaxvq_f32(amaxv[0]); + + // Check if exists: orig == amax + float32x4_t amax_vec = vdupq_n_f32(amax); + uint32x4_t mask_all = vdupq_n_u32(0); + for (int j = 0; j < 64; j++) { + uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]); + mask_all = vorrq_u32(mask_all, mask_curr); + } + + // Assume that none == amax, then check mask_all to reverse + iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f; + uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu)); + if (vmaxvq_u32(cmp) != 0) { + iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f; + } + + y[i].d[row_iter] = amax ? 
1/iscale[row_iter] : 0; + } + + // qs: 8 byte interleave over 4 rows, loop = QK_K/8 + // bsums: simply generated one by one, row_i is calculated before row_i+1 + // loops = 16 + for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) { + // Process row0 and row1 + float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0])); + float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0])); + float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0])); + float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0])); + int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3); + int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7); + int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11); + int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15); + int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1])); + float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1])); + float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1])); + float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1])); + int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3); + int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7); + int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11); + int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15); + int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + // Calculate and store qs + int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16 + int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7); + vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15); + // Calculate and store bsum + int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + y[i].bsums[j] = vaddlvq_s8(i0_0_15); + y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15); + + // Process row2 and row3 + f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2])); + f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2])); + f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2])); + f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2])); + i0_0_3 = vcvtnq_s32_f32(f0_0_3); + i0_4_7 = vcvtnq_s32_f32(f0_4_7); + i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + i0_8_11 = vcvtnq_s32_f32(f0_8_11); + i0_12_15 = vcvtnq_s32_f32(f0_12_15); + i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3])); + f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3])); + f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3])); + f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3])); + i1_0_3 = vcvtnq_s32_f32(f1_0_3); + i1_4_7 = vcvtnq_s32_f32(f1_4_7); + i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + i1_8_11 = 
vcvtnq_s32_f32(f1_8_11);
+                i1_12_15 = vcvtnq_s32_f32(f1_12_15);
+                i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
+
+                // Calculate and store qs
+                i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
+                vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
+                // Calculate and store bsum
+                i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
+                y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
+                y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
+            }
+        }
+    }
+    return;
+#endif
+
+    // NOTE:
+    // The current C implementation of Q8_K quantization was originally designed to work with block_q4_Kx8
+    // in the x86 AVX design, and differs from the AArch64 NEON Q8_K quantization logic above, which is
+    // built around block_q4_Kx4. The main difference is their "[bsums] layout". However, we can still reuse
+    // the x86 C implementation for AArch64, as long as we access the "[bsums] layout" correctly in
+    // ggml_gemm_q4_K_4x8_q8_K_generic().
+    UNUSED(nb);
+    UNUSED(y);
+    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
+}
+
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -499,6 +640,125 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+        const block_q4_Kx4 *GGML_RESTRICT q4 = (const block_q4_Kx4*) vx;
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+        for (int c = 0; c < nc; c += ncols_interleaved) {
+            const block_q8_K *GGML_RESTRICT q8 = (const block_q8_K *) vy;
+            float32x4_t res = vdupq_n_f32(0);
+            for (int i = 0; i < nb; i++) {
+                float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) q4->d)); // d0 d1 d2 d3
+                float32x4_t q8_d = vdupq_n_f32(q8->d);
+                float32x4_t g_d = vmulq_f32(q4_d, q8_d);
+                float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4->dmin)); // dmin0 dmin1 dmin2 dmin3
+                float32x4_t g_dmin = vmulq_f32(q4_dmin, q8_d);
+                const uint8_t * GGML_RESTRICT q4_ptr = q4->qs;
+                const int8_t * GGML_RESTRICT q8_ptr = q8->qs;
+                int32x4_t prod = vdupq_n_s32(0);
+                const int16x8_t q8_sums = vpaddq_s16(vld1q_s16(q8->bsums), vld1q_s16(q8->bsums + 8));
+                // When using vgetq_lane_s16, its index must be a constant, which cannot be used in a loop, so use vst1q_s16 instead.
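+                // q8_sums lane j is the sum of the 32 q8 quants of subblock j (a pairwise sum of the two
+                // 16-element bsums entries); spilling it to a stack array lets the loop below pick the
+                // per-subblock sum with a runtime index.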
+ int16_t tmp_arry[8]; + vst1q_s16(tmp_arry, q8_sums); + for (int j = 0; j < QK_K / 32; ++j) { + int32x4_t sum0 = vdupq_n_s32(0); + int32x4_t sum1 = vdupq_n_s32(0); + // Each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3 + int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *)q4->scales + 8 * j)); + prod = vmlal_s16(prod, vdup_n_s16(tmp_arry[j]), vget_high_s16(scales_mins)); + uint8x16_t q4_0 = vld1q_u8((const uint8_t *) q4_ptr); + uint8x16_t q4_1 = vld1q_u8((const uint8_t *) q4_ptr + 16); + uint8x16_t q4_2 = vld1q_u8((const uint8_t *) q4_ptr + 32); + uint8x16_t q4_3 = vld1q_u8((const uint8_t *) q4_ptr + 48); + q4_ptr += 64; + int8x16_t q8_0 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr); + int8x16_t q8_1 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 1); + int8x16_t q8_2 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 2); + int8x16_t q8_3 = (int8x16_t) vld1q_dup_s64((const int64_t *) q8_ptr + 3); + q8_ptr += 32; + + /* low bits + (1) sum0 + b0_000 b0_001 b0_002 b0_003 b0_004 b0_005 b0_006 b0_007 | b1_000 b1_001 b1_002 b1_003 b1_004 b1_005 b1_006 b1_007 + * a0 a1 a2 a3 a4 a5 a6 a7 | a0 a1 a2 a3 a4 a5 a6 a7 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (2) sum1 + b2_000 b2_001 b2_002 b2_003 b2_004 b2_005 b2_006 b2_007 | b3_000 b3_001 b3_002 b3_003 b3_004 b3_005 b3_006 b3_007 + * a0 a1 a2 a3 a4 a5 a6 a7 | a0 a1 a2 a3 a4 a5 a6 a7 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (3) sum0 + b0_008 b0_009 b0_010 b0_011 b0_012 b0_013 b0_014 b0_015 | b1_008 b1_009 b1_010 b1_011 b1_012 b1_013 b1_014 b1_015 + * a8 a9 a10 a11 a12 a13 a14 a15 | a8 a9 a10 a11 a12 a13 a14 a15 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (4) sum1 + b2_008 b2_009 b2_010 b2_011 b2_012 b2_013 b2_014 b2_015 | b3_008 b3_009 b3_010 b3_011 b3_012 b3_013 b3_014 b3_015 + * a8 a9 a10 a11 a12 a13 a14 a15 | a8 a9 a10 a11 a12 a13 a14 a15 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + */ + sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vandq_u8(q4_0, m4b)), q8_0); + sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vandq_u8(q4_1, m4b)), q8_0); + sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vandq_u8(q4_2, m4b)), q8_1); + sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vandq_u8(q4_3, m4b)), q8_1); + + /* high bits + (1) sum0 + b0_016 b0_017 b0_018 b0_019 b0_020 b0_021 b0_022 b0_023 | b1_016 b1_017 b1_018 b1_019 b1_020 b1_021 b1_022 b1_023 + * a16 a17 a18 a19 a20 a21 a22 a23 | a16 a17 a18 a19 a20 a21 a22 a23 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (2) sum1 + b2_016 b2_017 b2_018 b2_019 b2_020 b2_021 b2_022 b2_023 | b3_016 b3_017 b3_018 b3_019 b3_020 b3_021 b3_022 b3_023 + * a16 a17 a18 a19 a20 a21 a22 a23 | a16 a17 a18 a19 a20 a21 a22 a23 + |------------dot-------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (3) sum0 + b_024 b0_025 b0_026 b0_027 b0_028 b0_029 b0_030 b0_031 | b1_024 b1_025 b1_026 b1_027 b1_028 b1_029 b1_030 b1_031 + * a24 a25 a26 a27 a28 a29 a30 a31 | a24 a25 a26 a27 a28 a29 a30 a31 + |------------dot------------| |------------dot-------------| | |------------dot-------------| |------------dot-------------| + (4) sum1 + 
b2_024 b2_025 b2_026 b2_027 b2_028 b2_029 b2_030 b2_031 | b3_024 b3_025 b3_026 b3_027 b3_028 b3_029 b3_030 b3_031 + * a24 a25 a26 a27 a28 a29 a30 a31 | a24 a25 a26 a27 a28 a29 a30 a31 + |------------dot------------ | |------------dot-------------| | |------------dot-------------| |------------dot-------------| + */ + sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vshrq_n_u8(q4_0, 4)), q8_2); + sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vshrq_n_u8(q4_1, 4)), q8_2); + sum0 = vdotq_s32(sum0, vreinterpretq_s8_u8(vshrq_n_u8(q4_2, 4)), q8_3); + sum1 = vdotq_s32(sum1, vreinterpretq_s8_u8(vshrq_n_u8(q4_3, 4)), q8_3); + float32x4_t sumf = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), vpaddq_s32(sum0, sum1))); + res = vfmaq_f32(res, g_d, sumf); + } + res -= vmulq_f32(g_dmin, vcvtq_f32_s32(prod)); + q4++; + q8++; + } + vst1q_f32(s, res); + s += ncols_interleaved; + } + return; + } +#else + // C implementation + ggml_gemv_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { constexpr int qk = QK_K; const int nb = n / qk; @@ -2329,6 +2589,299 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); // row + UNUSED(nc); // column + UNUSED(nb); // block + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + const block_q8_Kx4 * GGML_RESTRICT q8_ptr_start = (const block_q8_Kx4 *) vy; + const block_q4_Kx4 * GGML_RESTRICT q4_ptr_start = (const block_q4_Kx4 *) vx; + + const uint8x16_t m4b = vdupq_n_u8(0x0f); + float32x4_t zeros = vdupq_n_f32(0.0f); + int anr = nr - nr % 16; + int row = 0; + // Row loop + for (; row < anr / 4; row += 4) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptrs[4]; + q8_ptrs[0] = q8_ptr_start + (row * nb); + for (int i = 0; i < 3; ++i) { + q8_ptrs[i + 1] = q8_ptrs[i] + nb; + } + // Column loop + for (int col = 0; col < nc / ncols_interleaved; col++) { + const block_q4_Kx4 * GGML_RESTRICT q4_ptr = q4_ptr_start + (col * nb); + // init output + float32x4_t res[16]; // final result + for (int i = 0; i < 16; i++) { + res[i] = zeros; + } + // Block loop + for (int64_t b = 0; b < nb; b++) { + float32x4_t g_d[16]; + float32x4_t g_dmin[16]; + int16x8_t q8_bsums[16]; + int32x4_t prod[16]; // store bsums*mins + for (int i = 0; i < 16; i++) { + g_d[i] = zeros; + g_dmin[i] = zeros; + q8_bsums[i] = vdupq_n_s16(0); + prod[i] = vdupq_n_s32(0); + } + // Get global d and dmin + float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // col0 col1 col2 col3 + float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin0 dmin1 dmin2 dmin3 + int16_t tmp_q8_bsums_array[16][8]; + for (int iter = 0; iter < 4; iter++) { + // Calculation when four lines are grouped together + for (int in = 0; in < 4; in++) { + float32x4_t scalar_q8_d = vdupq_n_f32(q8_ptrs[iter][b].d[in]); + g_d[in + 4 * iter] = vmulq_f32(q4_d, 
scalar_q8_d); + g_dmin[in + 4 * iter] = vmulq_f32(q4_dmin, scalar_q8_d); + // The 16 elements in each row are merged into 8 elements. No loop expansion is performed here + q8_bsums[in + 4 * iter] = vpaddq_s16(vld1q_s16(q8_ptrs[iter][b].bsums + 16 * in), vld1q_s16(q8_ptrs[iter][b].bsums + 16 * in + 8)); + vst1q_s16(tmp_q8_bsums_array[in + 4 * iter], q8_bsums[in + 4 * iter]); + } + } + // The 256 elements in the superblock are processed in 8 steps + for (int sb = 0; sb < QK_K / 32; sb++) { + int32x4_t acc_rows[16]; // the calculated value of qs + int32x4_t sum[16]; // the value of qs after rearranging + for (int i = 0; i < 16; i++) { + acc_rows[i] = vdupq_n_s32(0); + sum[i] = vdupq_n_s32(0); + } + // each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3 + int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *) q4_ptr[b].scales + 8 * sb)); + uint8x16_t q4_qs_raw_01_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + sb * 64); + uint8x16_t q4_qs_raw_23_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 16 + sb * 64); + uint8x16_t q4_qs_raw_01_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 32 + sb * 64); + uint8x16_t q4_qs_raw_23_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 48 + sb * 64); + + int8x16_t q4_qs_01_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_0, m4b)); // B0(0-7) B1(0-7) + int8x16_t q4_qs_23_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_0, m4b)); // B2(0-7) B3(0-7) + int8x16_t q4_qs_01_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_1, m4b)); // B0(8-15) B1(8-15) + int8x16_t q4_qs_23_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_1, m4b)); // B2(8-15) B3(8-15) + + int8x16_t q4_qs_01_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_0, 4)); // B0(16-23) B1(16-23) + int8x16_t q4_qs_23_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_0, 4)); // B2(16-23) B3(16-23) + int8x16_t q4_qs_01_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_1, 4)); // B0(24-31) B1(24-31) + int8x16_t q4_qs_23_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_1, 4)); // B2(24-31) B3(24-31) + + // The 16 rows of the left matrix are expanded four times + for (int iter = 0; iter < 4; iter++) { + // Direct loop unrolling + prod[0 + 4 * iter] = vmlal_s16(prod[0 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[0 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter): bsums*mins(0-3) + prod[1 + 4 * iter] = vmlal_s16(prod[1 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[1 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter+1): bsums*mins(0-3) + prod[2 + 4 * iter] = vmlal_s16(prod[2 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[2 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter+2): bsums*mins(0-3) + prod[3 + 4 * iter] = vmlal_s16(prod[3 + 4 * iter], vdup_n_s16(tmp_q8_bsums_array[3 + 4 * iter][sb]), vget_high_s16(scales_mins)); // row(iter+3): bsums*mins(0-3) + + int8x16_t q8_qs_01_00 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 128 * sb); // A0(0-7) A1(0-7) + int8x16_t q8_qs_23_00 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 16 + 128 * sb); // A2(0-7) A3(0-7) + + acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_00, q4_qs_01_l0); // A0*B0(0-7) A0*B1(0-7) A1*B0(0-7) A1*B1(0-7) + acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_00, q4_qs_23_l0); // A0*B2(0-7) A0*B3(0-7) A1*B2(0-7) A1*B3(0-7) + acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_00, q4_qs_01_l0); // A2*B0(0-7) A2*B1(0-7) A3*B0(0-7) A3*B1(0-7) + acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_00, q4_qs_23_l0); // A2*B2(0-7) A2*B3(0-7) A3*B2(0-7) A3*B3(0-7) + 
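+                        // Each vmmlaq_s32 above accumulates one 2x2 i8mm tile over 8-element dot products:
+                        // the four lanes hold [rowA*colB, rowA*colB', rowA'*colB, rowA'*colB'], exactly as the
+                        // annotations state. This tiled layout is why the accumulators are regrouped with
+                        // vget_low/vget_high into one vector of 4 column results per output row further below,
+                        // before the per-subblock scales are applied.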
+ int8x16_t q8_qs_01_01 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 32 + 128 * sb); // A0(8-15) A1(8-15) + int8x16_t q8_qs_23_01 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 48 + 128 * sb); // A2(8-15) A3(8-15) + + acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_01, q4_qs_01_l1); // A0*B0(8-15) A0*B1(8-15) A1*B0(8-15) A1*B1(8-15) + acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_01, q4_qs_23_l1); // A0*B2(8-15) A0*B3(8-15) A1*B2(8-15) A1*B3(8-15) + acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_01, q4_qs_01_l1); // A2*B0(8-15) A2*B1(8-15) A3*B0(8-15) A3*B1(8-15) + acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_01, q4_qs_23_l1); // A2*B2(8-15) A2*B3(8-15) A3*B2(8-15) A3*B3(8-15) + + int8x16_t q8_qs_01_02 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 64 + 128 * sb); // A0(16-23) A1(16-23) + int8x16_t q8_qs_23_02 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 80 + 128 * sb); // A2(16-23) A3(16-23) + + acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_02, q4_qs_01_h0); // A0*B0(16-23) A0*B1(16-23) A1*B0(16-23) A1*B1(16-23) + acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_02, q4_qs_23_h0); // A0*B2(16-23) A0*B3(16-23) A1*B2(16-23) A1*B3(16-23) + acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_02, q4_qs_01_h0); // A2*B0(16-23) A2*B1(16-23) A3*B0(16-23) A3*B1(16-23) + acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_02, q4_qs_23_h0); // A2*B2(16-23) A2*B3(16-23) A3*B2(16-23) A3*B3(16-23) + + int8x16_t q8_qs_01_03 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 96 + 128 * sb); // A0(24-31) A1(24-31) + int8x16_t q8_qs_23_03 = vld1q_s8((const int8_t *) q8_ptrs[iter][b].qs + 112 + 128 * sb); // A2(24-31) A3(24-31) + + acc_rows[0 + 4 * iter] = vmmlaq_s32(acc_rows[0 + 4 * iter], q8_qs_01_03, q4_qs_01_h1); // A0*B0(24-31) A0*B1(24-31) A1*B0(24-31) A1*B1(24-31) + acc_rows[1 + 4 * iter] = vmmlaq_s32(acc_rows[1 + 4 * iter], q8_qs_01_03, q4_qs_23_h1); // A0*B2(24-31) A0*B3(24-31) A1*B2(24-31) A1*B3(24-31) + acc_rows[2 + 4 * iter] = vmmlaq_s32(acc_rows[2 + 4 * iter], q8_qs_23_03, q4_qs_01_h1); // A2*B0(24-31) A2*B1(24-31) A3*B0(24-31) A3*B1(24-31) + acc_rows[3 + 4 * iter] = vmmlaq_s32(acc_rows[3 + 4 * iter], q8_qs_23_03, q4_qs_23_h1); // A2*B2(24-31) A2*B3(24-31) A3*B2(24-31) A3*B3(24-31) + + // rearranging vectors + sum[0 + 4 * iter] = vcombine_s32(vget_low_s32(acc_rows[0 + 4 * iter]), vget_low_s32(acc_rows[1 + 4 * iter])); + sum[1 + 4 * iter] = vcombine_s32(vget_high_s32(acc_rows[0 + 4 * iter]), vget_high_s32(acc_rows[1 + 4 * iter])); + sum[2 + 4 * iter] = vcombine_s32(vget_low_s32(acc_rows[2 + 4 * iter]), vget_low_s32(acc_rows[3 + 4 * iter])); + sum[3 + 4 * iter] = vcombine_s32(vget_high_s32(acc_rows[2 + 4 * iter]), vget_high_s32(acc_rows[3 + 4 * iter])); + + float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[0 + 4 * iter])); // scales *qs + float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[1 + 4 * iter])); + float32x4_t sumf_2 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[2 + 4 * iter])); + float32x4_t sumf_3 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[3 + 4 * iter])); + + res[0 + 4 * iter] = vfmaq_f32(res[0 + 4 * iter], g_d[0 + 4 * iter], sumf_0); + res[1 + 4 * iter] = vfmaq_f32(res[1 + 4 * iter], g_d[1 + 4 * iter], sumf_1); + res[2 + 4 * iter] = vfmaq_f32(res[2 + 4 * iter], g_d[2 + 4 * iter], sumf_2); 
+ res[3 + 4 * iter] = vfmaq_f32(res[3 + 4 * iter], g_d[3 + 4 * iter], sumf_3); + } + } + for (int iter = 0; iter < 4; iter++) { + res[0 + 4 * iter] -= vmulq_f32(g_dmin[0 + 4 * iter], vcvtq_f32_s32(prod[0 + 4 * iter])); + res[1 + 4 * iter] -= vmulq_f32(g_dmin[1 + 4 * iter], vcvtq_f32_s32(prod[1 + 4 * iter])); + res[2 + 4 * iter] -= vmulq_f32(g_dmin[2 + 4 * iter], vcvtq_f32_s32(prod[2 + 4 * iter])); + res[3 + 4 * iter] -= vmulq_f32(g_dmin[3 + 4 * iter], vcvtq_f32_s32(prod[3 + 4 * iter])); + } + } + // store result + for (int i = 0; i < 16; i++) { + vst1q_f32((float *) (s + ((row * 4 + i) * bs + col * 4)), res[i]); + } + } + } + // Handling tail parts that are less than 16 lines + for (; row < nr / 4; row++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = q8_ptr_start + (row * nb); + // Column loop + for (int col = 0; col < nc / ncols_interleaved; col++) { + const block_q4_Kx4 * GGML_RESTRICT q4_ptr = q4_ptr_start + (col * nb); + // init output + float32x4_t res[4]; + for (int i = 0; i < 4; i++) { + res[i] = zeros; + } + // Block loop + for (int64_t b = 0; b < nb; b++) { + float32x4_t g_d[4]; + float32x4_t g_dmin[4]; + int16x8_t q8_bsums[4]; + int32x4_t prod[4]; + for (int i = 0; i < 4; i++) { + g_d[i] = zeros; + g_dmin[i] = zeros; + q8_bsums[i] = vdupq_n_s16(0); + prod[i] = vdupq_n_s32(0); + } + float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // col0 col1 col2 col3 + float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin0 dmin1 dmin2 dmin3 + int16_t tmp_q8_bsums_array[4][8]; + for (int in = 0; in < 4; in++) { + float32x4_t scalar_q8_d = vdupq_n_f32(q8_ptr[b].d[in]); + g_d[in] = vmulq_f32(q4_d, scalar_q8_d); + g_dmin[in] = vmulq_f32(q4_dmin, scalar_q8_d); + q8_bsums[in] = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * in), vld1q_s16(q8_ptr[b].bsums + 16 * in + 8)); + vst1q_s16(tmp_q8_bsums_array[in], q8_bsums[in]); + } + for (int sb = 0; sb < QK_K / 32; sb++) { + int32x4_t acc_rows[4]; + int32x4_t sum[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = vdupq_n_s32(0); + sum[i] = vdupq_n_s32(0); + } + // each block: scales0 scales1 scales2 scales3 mins0 mins1 mins2 mins3 + int16x8_t scales_mins = vmovl_s8(vld1_s8((const int8_t *) q4_ptr[b].scales + 8 * sb)); + uint8x16_t q4_qs_raw_01_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + sb * 64); + uint8x16_t q4_qs_raw_23_0 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 16 + sb * 64); + uint8x16_t q4_qs_raw_01_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 32 + sb * 64); + uint8x16_t q4_qs_raw_23_1 = vld1q_u8((const uint8_t *) q4_ptr[b].qs + 48 + sb * 64); + + int8x16_t q4_qs_01_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_0, m4b)); // B0(0-7) B1(0-7) + int8x16_t q4_qs_23_l0 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_0, m4b)); // B2(0-7) B3(0-7) + int8x16_t q4_qs_01_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_01_1, m4b)); // B0(8-15) B1(8-15) + int8x16_t q4_qs_23_l1 = vreinterpretq_s8_u8(vandq_u8(q4_qs_raw_23_1, m4b)); // B2(8-15) B3(8-15) + + prod[0] = vmlal_s16(prod[0], vdup_n_s16(tmp_q8_bsums_array[0][sb]), vget_high_s16(scales_mins)); // row(iter): bsums*mins(0-3) + prod[1] = vmlal_s16(prod[1], vdup_n_s16(tmp_q8_bsums_array[1][sb]), vget_high_s16(scales_mins)); // row(iter+1): bsums*mins(0-3) + prod[2] = vmlal_s16(prod[2], vdup_n_s16(tmp_q8_bsums_array[2][sb]), vget_high_s16(scales_mins)); // row(iter+2): bsums*mins(0-3) + prod[3] = vmlal_s16(prod[3], vdup_n_s16(tmp_q8_bsums_array[3][sb]), vget_high_s16(scales_mins)); // row(iter+3): bsums*mins(0-3) + + int8x16_t q8_qs_01_00 = vld1q_s8((const 
int8_t *) q8_ptr[b].qs + 128 * sb); // A0(0-7) A1(0-7) + int8x16_t q8_qs_23_00 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 16 + 128 * sb); // A2(0-7) A3(0-7) + + acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_00, q4_qs_01_l0); // A0*B0(0-7) A0*B1(0-7) A1*B0(0-7) A1*B1(0-7) + acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_00, q4_qs_23_l0); // A0*B2(0-7) A0*B3(0-7) A1*B2(0-7) A1*B3(0-7) + acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_00, q4_qs_01_l0); // A2*B0(0-7) A2*B1(0-7) A3*B0(0-7) A3*B1(0-7) + acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_00, q4_qs_23_l0); // A2*B2(0-7) A2*B3(0-7) A3*B2(0-7) A3*B3(0-7) + + int8x16_t q8_qs_01_01 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 32 + 128 * sb); // A0(8-15) A1(8-15) + int8x16_t q8_qs_23_01 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 48 + 128 * sb); // A2(8-15) A3(8-15) + + acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_01, q4_qs_01_l1); // A0*B0(8-15) A0*B1(8-15) A1*B0(8-15) A1*B1(8-15) + acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_01, q4_qs_23_l1); // A0*B2(8-15) A0*B3(8-15) A1*B2(8-15) A1*B3(8-15) + acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_01, q4_qs_01_l1); // A2*B0(8-15) A2*B1(8-15) A3*B0(8-15) A3*B1(8-15) + acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_01, q4_qs_23_l1); // A2*B2(8-15) A2*B3(8-15) A3*B2(8-15) A3*B3(8-15) + + int8x16_t q4_qs_01_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_0, 4)); // B0(16-23) B1(16-23) + int8x16_t q4_qs_23_h0 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_0, 4)); // B2(16-23) B3(16-23) + int8x16_t q4_qs_01_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_01_1, 4)); // B0(24-31) B1(24-31) + int8x16_t q4_qs_23_h1 = vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_raw_23_1, 4)); // B2(24-31) B3(24-31) + + int8x16_t q8_qs_01_02 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 64 + 128 * sb); // A0(16-23) A1(16-23) + int8x16_t q8_qs_23_02 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 80 + 128 * sb); // A2(16-23) A3(16-23) + + acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_02, q4_qs_01_h0); // A0*B0(16-23) A0*B1(16-23) A1*B0(16-23) A1*B1(16-23) + acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_02, q4_qs_23_h0); // A0*B2(16-23) A0*B3(16-23) A1*B2(16-23) A1*B3(16-23) + acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_02, q4_qs_01_h0); // A2*B0(16-23) A2*B1(16-23) A3*B0(16-23) A3*B1(16-23) + acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_02, q4_qs_23_h0); // A2*B2(16-23) A2*B3(16-23) A3*B2(16-23) A3*B3(16-23) + + int8x16_t q8_qs_01_03 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 96 + 128 * sb); // A0(24-31) A1(24-31) + int8x16_t q8_qs_23_03 = vld1q_s8((const int8_t *) q8_ptr[b].qs + 112 + 128 * sb); // A2(24-31) A3(24-31) + + acc_rows[0] = vmmlaq_s32(acc_rows[0], q8_qs_01_03, q4_qs_01_h1); // A0*B0(24-31) A0*B1(24-31) A1*B0(24-31) A1*B1(24-31) + acc_rows[1] = vmmlaq_s32(acc_rows[1], q8_qs_01_03, q4_qs_23_h1); // A0*B2(24-31) A0*B3(24-31) A1*B2(24-31) A1*B3(24-31) + acc_rows[2] = vmmlaq_s32(acc_rows[2], q8_qs_23_03, q4_qs_01_h1); // A2*B0(24-31) A2*B1(24-31) A3*B0(24-31) A3*B1(24-31) + acc_rows[3] = vmmlaq_s32(acc_rows[3], q8_qs_23_03, q4_qs_23_h1); // A2*B2(24-31) A2*B3(24-31) A3*B2(24-31) A3*B3(24-31) + + // rearranging vectors + sum[0] = vcombine_s32(vget_low_s32(acc_rows[0]), vget_low_s32(acc_rows[1])); + sum[1] = vcombine_s32(vget_high_s32(acc_rows[0]), vget_high_s32(acc_rows[1])); + sum[2] = vcombine_s32(vget_low_s32(acc_rows[2]), vget_low_s32(acc_rows[3])); + sum[3] = vcombine_s32(vget_high_s32(acc_rows[2]), vget_high_s32(acc_rows[3])); + + float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), 
sum[0])); // scales *qs + float32x4_t sumf_1 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[1])); + float32x4_t sumf_2 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[2])); + float32x4_t sumf_3 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(vget_low_s16(scales_mins)), sum[3])); + + res[0] = vfmaq_f32(res[0], g_d[0], sumf_0); + res[1] = vfmaq_f32(res[1], g_d[1], sumf_1); + res[2] = vfmaq_f32(res[2], g_d[2], sumf_2); + res[3] = vfmaq_f32(res[3], g_d[3], sumf_3); + } + res[0] -= vmulq_f32(g_dmin[0], vcvtq_f32_s32(prod[0])); + res[1] -= vmulq_f32(g_dmin[1], vcvtq_f32_s32(prod[1])); + res[2] -= vmulq_f32(g_dmin[2], vcvtq_f32_s32(prod[2])); + res[3] -= vmulq_f32(g_dmin[3], vcvtq_f32_s32(prod[3])); + } + // store result + for (int i = 0; i < 4; i++) { + vst1q_f32((float *) (s + ((row * 4 + i) * bs + col * 4)), res[i]); + } + } + } + return; +#else + // C implementation + ggml_gemm_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +#endif +} + void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { constexpr int qk = QK_K; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index 7dda9eea0c..22a9b03989 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -287,9 +287,10 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR #endif } -void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) { assert(QK_K == 256); assert(k % QK_K == 0); + UNUSED(nc); const int nb = k / QK_K; block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; @@ -507,7 +508,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR #else UNUSED(nb); UNUSED(y); - ggml_quantize_mat_q8_K_4x8_generic(x, vy, k); + ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc); #endif } diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index fbf7ed9432..7f9e9e2172 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -176,9 +176,10 @@ void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GG } } -void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) { assert(QK_K == 256); assert(k % QK_K == 0); + UNUSED(nc); const int nb = k / QK_K; block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; @@ -230,30 +231,33 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG } // extern "C" template -void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row); +void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols); -template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) { assert(nrow == 4); UNUSED(nrow); + UNUSED(ncols); ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row); } -template <> void 
ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) { assert(nrow == 4); UNUSED(nrow); + UNUSED(ncols); ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) { assert(nrow == 4); UNUSED(nrow); + UNUSED(ncols); ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row); } -template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) { assert(nrow == 4); UNUSED(nrow); - ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); + ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row, ncols); } extern "C" { @@ -391,6 +395,86 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + int8_t scales[4][8]; // scales for 8 subblocks of 4 q4_k unit (4 cols) + int8_t mins[4][8]; // mins for 8 subblocks of 4 q4_k unit (4 cols) + float sumf[4]; // 1x4 unit: final result + float sum_minf[4]; // 1x4 unit: final minus result + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + + // loop on n dimension, each iteration works on 4 columns + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb); + + // initialize results for 4 cols + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + + // loop on k dimension, each iteration works on 1 q4_kx4 block + for (int n = 0; n < nb; n++) { + // prepare scales and mins for 4 cols + for (int j = 0; j < ncols_interleaved; j++) { + for (int i = 0; i < 8; i++) { + scales[j][i] = b_ptr[n].scales[i * 8 + j]; + mins[j][i] = b_ptr[n].scales[i * 8 + j + ncols_interleaved]; + } + } + // core loop: each iteration works on an interleaved unit (four 8-byte segments from 4 cols) + for (int k = 0; k < qk / (2 * blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + int8_t scale = scales[j][k / 2]; + for (int i = 0; i < blocklen; i++) { + const int v0 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf); + const int v1 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = v0 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i]; + sumi2 = v1 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i + 16]; + sumi += scale * (sumi1 + sumi2); + } + sumf[j] += sumi * 
GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d; + } + } + + // prepare partial results for tail work + for (int j = 0; j < ncols_interleaved; j++) { + for (int i = 0; i < QK_K / 32; i++) { + sum_minf[j] += mins[j][i] * (a_ptr[n].bsums[i * 2] + a_ptr[n].bsums[i * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d; + } + } + } + + // save results + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } + +} + void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -950,6 +1034,98 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + int8_t scales[4][8]; // scales for 8 subblocks of 4 q4_k unit (4 cols) + int8_t mins[4][8]; // mins for 8 subblocks of 4 q4_k unit (4 cols) + float sumf[4][4]; // 4x4 unit: final result + float sum_minf[4][4]; // 4x4 unit: final minus result + int sumi1; + int sumi2; + int sumi; + + // loop on m dimension, each iteration works on 4 rows + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + // loop on n dimension, each iteration works on 4 columns + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb); + + // initialize results for 4 cols + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + + // loop on k dimension, each iteration works on 1 q4_kx4 block + for (int n = 0; n < nb; n++) { + // prepare scales and mins for 4 cols + for (int j = 0; j < ncols_interleaved; j++) { + for (int i = 0; i < 8; i++) { + scales[j][i] = b_ptr[n].scales[i * 8 + j]; + mins[j][i] = b_ptr[n].scales[i * 8 + j + ncols_interleaved]; + } + } + + // core loop: each iteration works on an interleaved unit (four 8-byte segments from 4 cols) + for (int k = 0; k < qk / (2 * blocklen); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + int8_t scale = scales[j][k / 2]; + for (int i = 0; i < blocklen; i++) { + const int v0 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf); + const int v1 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = v0 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i]; + sumi2 = v1 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i + 64]; + sumi += scale * (sumi1 + sumi2); + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d[m]; + } + } + } + + // prepare partial results for tail work + // + // NOTE: + // the "[bsums] layout" here is from ggml_quantize_mat_q8_K_4x8_generic(). 
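+            // In that layout the two 16-element partial sums for row m, subblock i live at
+            // bsums[i * 8 - (i % 2) * 6 + m * 4] and the entry right after it, e.g.
+            // row 0 / subblock 0 -> bsums[0..1], row 1 / subblock 3 -> bsums[22..23].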
+ for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + for (int i = 0; i < QK_K / 32; i++) { + const int16_t *bsums = a_ptr[n].bsums + (i * 8) - ((i % 2) * 6) + (m * 4); + sum_minf[m][j] += mins[j][i] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d[m]; + } + } + } + } + + // save results + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -1505,6 +1681,90 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } +static void make_block_q4_Kx4(block_q4_K * in, unsigned int blck_size_interleave, block_q4_Kx4 * out) { + int nrow = 4; + int nloop = 4; + + // d and dmin values of the 4 Q4_K are copied directly. + for (int i = 0; i < nrow; i++) { + out->d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < nrow; i++) { + out->dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + // For qs, 2 things need to be done: + // 1. Recover from Q4_K storage tyle to Q4_0 style + // 2. Interleave quants by taking 8 bytes at a time + + // 1. + const uint64_t lo_mask = 0x0f0f0f0f0f0f0f0fULL; + const uint64_t hi_mask = 0xf0f0f0f0f0f0f0f0ULL; + for (int i = 0; i < nrow; i++) { + uint64_t *q = (uint64_t *)(in[i].qs); + for (int j = 0; j < nloop; j++) { + uint64_t q0, q1, q2, q3; + q0 = q[0]; + q1 = q[1]; + q2 = q[2]; + q3 = q[3]; + + uint64_t hi1, hi2, lo3, lo4; + hi1 = q0 & hi_mask; + hi2 = q1 & hi_mask; + lo3 = q2 & lo_mask; + lo4 = q3 & lo_mask; + q[0] = (q0 & lo_mask) | (lo3 << 4); + q[1] = (q1 & lo_mask) | (lo4 << 4); + q[2] = (q2 & hi_mask) | (hi1 >> 4); + q[3] = (q3 & hi_mask) | (hi2 >> 4); + + q += 4; + } + } + + // 2. + // Calculate total number of interleaved subblocks + const int end = QK_K * 2 / blck_size_interleave; + uint64_t *src, *dst; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + src = (uint64_t *)(&in[src_id].qs[src_offset]); + dst = (uint64_t *)(&out->qs[dst_offset]); + *dst = *src; + } + + // For scales & mins of each subblock. (8 subblocks in one Q4_K, 32 in total) + // A special requirement to meet: expand to 8-bit from 6-bit. 
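+    // Q4_K packs the 8 per-subblock scales and 8 per-subblock mins as 6-bit values in 12 bytes.
+    // The kmask shuffle below is the usual Q4_K unpack: it leaves the 8 scale bytes in utmp[0..1]
+    // and the 8 min bytes in utmp[2..3]. All values are < 64, so they fit in int8 and the kernels
+    // can widen them directly with vmovl_s8.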
+ static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + uint32_t utmp[4]; + for (int i = 0; i < nrow; i++) { + // rearrange as d|d|...|d|min|min|...|min + // expand to 8-bit from 6-bit + memset(utmp, 0, 16); + memcpy(utmp, in[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // move to Q4_K + const uint8_t * d_ptr = (const uint8_t*)&utmp[0]; + const uint8_t * m_ptr = (const uint8_t*)&utmp[2]; + for (int j = 0; j < 8; j++) { + out->scales[j * 8 + i] = *d_ptr++; + out->scales[j * 8 + i + nrow] = *m_ptr++; + } + } +} + static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { block_q4_Kx8 out; //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure @@ -1656,6 +1916,46 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q4_K_to_q4_K_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 4; + + block_q4_Kx4 * dst = (block_q4_Kx4 *)t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_K dst_tmp[4]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + make_block_q4_Kx4(dst_tmp, interleave_block, dst++); + } + src += nrows_interleaved * nblocks; + } + + // change tensor shape as block_q4_kx4 brings space size change + //t->nb[0] = ggml_type_size(type); + t->nb[0] = sizeof(block_q4_Kx4) / 4; + t->nb[1] = t->nb[0] * (t->ne[0] / ggml_blck_size(t->type)); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + t->nb[i] = t->nb[i - 1] * t->ne[i - 1]; + } + + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_K); GGML_ASSERT(interleave_block == 8 || interleave_block == 4); @@ -1924,6 +2224,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_4_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size); } @@ -1973,6 +2277,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2013,14 +2321,18 @@ template <> void gemm(int n, float * s, 
size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); -} - template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_4x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2194,7 +2506,7 @@ template ((float *) (data_ptr + i11 * nb11), - (void *) (wdata_ptr + i11 * nbw1), 4, ne10); + (void *) (wdata_ptr + i11 * nbw1), 4, ne10, ne01); } const int64_t i11_processed = ne11 - ne11 % 4; @@ -2292,7 +2604,7 @@ template from_float; // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + //GGML_ASSERT(nb00 == ggml_type_size(src0->type)); GGML_ASSERT(nb10 == ggml_type_size(src1->type)); // dst cannot be transposed or permuted @@ -2429,6 +2741,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_0_8x8_q8_0; // instance for Q4_K + static const ggml::cpu::repack::tensor_traits q4_K_4x8_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; @@ -2470,6 +2783,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons if (cur->ne[1] % 8 == 0) { return &q4_K_8x8_q8_K; } + if (cur->ne[1] % 4 == 0) { + return &q4_K_4x8_q8_K; + } } if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { if (cur->ne[1] % 8 == 0) { @@ -2555,6 +2871,30 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf GGML_UNUSED(buft); } +// Below func is for the aarch64 q4_K_4x8_q8_K repack case only: +// Tensor storage after repacking is a bit larger than before -- sizeof(block_q4_Kx4) > sizeof(block_q4_K)*4 +// This is due to member "scales" are pre-decoded in repacking stage, not in execution stage. 
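+// For example, with QK_K == 256: four block_q4_K take 4 * 144 = 576 bytes, while one block_q4_Kx4 takes
+// 4*2 (d) + 4*2 (dmin) + 64 (8-bit scales/mins) + 512 (qs) = 592 bytes, i.e. 148 bytes per row per 256
+// weights -- this is the type_size used below instead of ggml_type_size(GGML_TYPE_Q4_K) == 144.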
+static inline size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
+    size_t nbytes;
+    const size_t blck_size = 256;
+    const size_t type_size = sizeof(block_q4_Kx4) / 4;
+    nbytes = ((tensor->ne[0] * type_size) / blck_size) * tensor->ne[1] * tensor->ne[2] * tensor->ne[3];
+    return nbytes;
+}
+
+static size_t ggml_backend_cpu_aarch64_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    if (tensor->type == GGML_TYPE_Q4_K) {
+        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            if (tensor->ne[1] % 4 == 0) {
+                return ggml_nbytes_q4_kx4(tensor); // for q4_K_4x8_q8_K only
+            }
+        }
+    }
+    return ggml_nbytes(tensor);
+
+    GGML_UNUSED(buft);
+}
+
 namespace ggml::cpu::repack {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
@@ -2611,7 +2951,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) {
 /* .alloc_buffer = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer,
 /* .get_alignment = */ ggml_backend_cpu_repack_buffer_type_get_alignment,
 /* .get_max_size = */ nullptr, // defaults to SIZE_MAX
-/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
+/* .get_alloc_size = */ ggml_backend_cpu_aarch64_buffer_type_get_alloc_size, // ggml_nbytes, except for the aarch64 q4_K_4x8 repack case
 /* .is_host = */ nullptr,
 },
 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h
index af98e70344..ea5268c451 100644
--- a/ggml/src/ggml-cpu/repack.h
+++ b/ggml/src/ggml-cpu/repack.h
@@ -36,6 +36,13 @@ using block_q4_0x8 = block<4, 8>;
 using block_q8_0x4 = block<8, 4>;
 using block_q8_0x8 = block<8, 8>;
+struct block_q4_Kx4 {
+    ggml_half d[4];      // super-block scale for quantized scales
+    ggml_half dmin[4];   // super-block scale for quantized mins
+    int8_t scales[64];   // scales and mins, quantized with 8 bits (recovered from 6-bit during repack)
+    uint8_t qs[512];     // 4-bit quants
+};
+
 struct block_q4_Kx8 {
     ggml_half d[8];      // super-block scale for quantized scales
     ggml_half dmin[8];   // super-block scale for quantized mins
@@ -81,10 +88,11 @@ extern "C" {
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
-void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -93,6 +101,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -107,18 +116,22 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
-void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
+// generic gemv implementations
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+// generic gemm implementations
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
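
The size relationship behind ggml_nbytes_q4_kx4 -- sizeof(block_q4_Kx4) > sizeof(block_q4_K)*4 -- can be checked in isolation. The standalone C sketch below is not part of the patch: it re-declares the two layouts (assuming upstream's QK_K = 256 and K_SCALE_SIZE = 12 for block_q4_K, plus the block_q4_Kx4 struct added in this diff) purely to show where the extra bytes come from.

// sanity_check_q4_kx4_size.c -- illustration only, not part of the patch
#include <stdint.h>
#include <stdio.h>

typedef uint16_t ggml_half;         // fp16 storage type, 2 bytes

#define QK_K         256
#define K_SCALE_SIZE 12

// upstream block_q4_K: one 256-element super-block (144 bytes)
typedef struct {
    ggml_half d;                    // super-block scale for quantized scales
    ggml_half dmin;                 // super-block scale for quantized mins
    uint8_t   scales[K_SCALE_SIZE]; // scales and mins, packed with 6 bits
    uint8_t   qs[QK_K / 2];         // 4-bit quants
} block_q4_K;

// repacked layout from this diff: 4 interleaved super-blocks (592 bytes)
typedef struct {
    ggml_half d[4];
    ggml_half dmin[4];
    int8_t    scales[64];           // 6-bit scales/mins pre-decoded to 8 bits
    uint8_t   qs[512];
} block_q4_Kx4;

int main(void) {
    // 592 > 4 * 144 = 576: the 16 extra bytes are exactly the growth of the
    // scales from 4 * 12 packed bytes to 64 int8 values.
    printf("sizeof(block_q4_K)       = %zu\n", sizeof(block_q4_K));
    printf("4 * sizeof(block_q4_K)   = %zu\n", 4 * sizeof(block_q4_K));
    printf("sizeof(block_q4_Kx4)     = %zu\n", sizeof(block_q4_Kx4));
    printf("extra bytes per 4 blocks = %zu\n", sizeof(block_q4_Kx4) - 4 * sizeof(block_q4_K));
    return 0;
}

Per 4 super-blocks the repacked tensor is 16 bytes (about 2.8 %) larger, which is why the diff overrides get_alloc_size with ggml_backend_cpu_aarch64_buffer_type_get_alloc_size instead of relying on the ggml_nbytes default.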