diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 9b1cae1c38..1793409820 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -210,6 +210,141 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR #endif } +void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + +#if defined(__ARM_NEON) + const int blck_size_interleave = 8; + float32x4_t srcv[4][64]; // 64 = QK_K/4 + float iscale[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[64]; + float32x4_t amaxv[64]; + + // d: + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j); + for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]); + for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]); + for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]); + + const float amax = vmaxvq_f32(amaxv[0]); + + // Check if exists: orig == amax + float32x4_t amax_vec = vdupq_n_f32(amax); + uint32x4_t mask_all = vdupq_n_u32(0); + for (int j = 0; j < 64; j++) { + uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]); + mask_all = vorrq_u32(mask_all, mask_curr); + } + + // Assume that none == amax, then check mask_all to reverse + iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f; + uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu)); + if (vmaxvq_u32(cmp) != 0) { + iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f; + } + + y[i].d[row_iter] = amax ? 
1/iscale[row_iter] : 0; + } + + // qs: 8 byte interleave over 4 rows, loop = QK_K/8 + // bsums: simply generated one by one, row_i is calculated before row_i+1 + // loops = 16 + for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) { + // Process row0 and row1 + float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0])); + float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0])); + float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0])); + float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0])); + int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3); + int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7); + int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11); + int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15); + int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1])); + float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1])); + float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1])); + float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1])); + int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3); + int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7); + int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11); + int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15); + int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + // Calculate and store qs + int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16 + int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7); + vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15); + // Calculate and store bsum + int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + y[i].bsums[j] = vaddlvq_s8(i0_0_15); + y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15); + + // Process row2 and row3 + f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2])); + f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2])); + f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2])); + f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2])); + i0_0_3 = vcvtnq_s32_f32(f0_0_3); + i0_4_7 = vcvtnq_s32_f32(f0_4_7); + i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + i0_8_11 = vcvtnq_s32_f32(f0_8_11); + i0_12_15 = vcvtnq_s32_f32(f0_12_15); + i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3])); + f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3])); + f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3])); + f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3])); + i1_0_3 = vcvtnq_s32_f32(f1_0_3); + i1_4_7 = vcvtnq_s32_f32(f1_4_7); + i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 + i1_8_11 = 
vcvtnq_s32_f32(f1_8_11); + i1_12_15 = vcvtnq_s32_f32(f1_12_15); + i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 + + // Calculate and store qs + i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16 + i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7); + vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15); + // Calculate and store bsum + i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 + y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15); + y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15); + } + } + +#else + // NOTE: + // Current C implementation is actually aligned with x86 AVX2 design, but differs from above ARM NEON design. + // This is because the [bsums] layout is different in block_q8_Kx4 for the 2 designs. + // As NEON is supported in almost all the modern ARM platforms, this generic path can be rarely arrived nowadays. + // (The exceptional cases aren't suitable for AI work) + // However logically we may still need a corresponding generic version for ARM, called xxx_generic_arm for example. + UNUSED(nb); + UNUSED(y); + ggml_quantize_mat_q8_K_4x8_generic(x, vy, k); +#endif +} + void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -540,7 +675,8 @@ void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo return; } #else - // todo, c implementation + // C implementation + ggml_gemv_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); #endif } @@ -2656,7 +2792,8 @@ void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } return; #else - // todo, c implementation + // C implementation + ggml_gemm_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); #endif } diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index d2caed8a49..7dda9eea0c 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -287,6 +287,230 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR #endif } +void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + +#if defined(__AVX2__) + float iscale[4]; + __m256 srcv[4][32]; + __m256 iscale_vec[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 ); + __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 ); + __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 ); + __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 ); + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 abs0 = _mm256_andnot_ps( signBit, v0 ); + __m256 abs1 = _mm256_andnot_ps( signBit, v1 ); + __m256 abs2 = _mm256_andnot_ps( signBit, v2 ); + __m256 abs3 = _mm256_andnot_ps( signBit, v3 ); + + __m256 maxAbs = _mm256_max_ps( abs0, abs1 ); + maxAbs = _mm256_max_ps( maxAbs, abs2 ); + maxAbs = _mm256_max_ps( 
maxAbs, abs3 ); + + __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ ); + __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ ); + __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ ); + __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ ); + + __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3)); + + srcv[row_iter][0] = v0; + srcv[row_iter][1] = v1; + srcv[row_iter][2] = v2; + srcv[row_iter][3] = v3; + + for (int sb = 1; sb < 8; sb++) { + // Temporarily stores absolute quant values + __m256 tempAbs = maxAbs; + + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32); + __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 ); + __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 ); + __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 ); + + // Compute max(abs(e)) for the block + __m256 abs0 = _mm256_andnot_ps( signBit, v0 ); + __m256 abs1 = _mm256_andnot_ps( signBit, v1 ); + __m256 abs2 = _mm256_andnot_ps( signBit, v2 ); + __m256 abs3 = _mm256_andnot_ps( signBit, v3 ); + + maxAbs = _mm256_max_ps( maxAbs, abs0 ); + maxAbs = _mm256_max_ps( maxAbs, abs1 ); + maxAbs = _mm256_max_ps( maxAbs, abs2 ); + maxAbs = _mm256_max_ps( maxAbs, abs3 ); + + __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ ); + maskAbs = _mm256_and_ps( maskAbs, mask_prev ); + + mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ ); + mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ ); + mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ ); + mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ ); + + __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3)); + maskAbs = _mm256_or_ps(maskAbs, mask_curr); + + srcv[row_iter][sb * 4] = v0; + srcv[row_iter][sb * 4 + 1] = v1; + srcv[row_iter][sb * 4 + 2] = v2; + srcv[row_iter][sb * 4 + 3] = v3; + } + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + __m256 maxScalarVec = _mm256_set1_ps(maxScalar); + + __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ ); + __m256 finalMask = _mm256_and_ps(maskAbs, mask_next); + + const int mask = _mm256_movemask_ps(finalMask); + iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + + if(mask) { + iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f; + } + + y[i].d[row_iter] = maxScalar ? 
1/iscale[row_iter] : 0; + iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]); + } + + __m256i quants_interleaved[32]; + for (int j = 0; j < 32; j++) { + // Apply the multiplier + __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]); + __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]); + __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]); + __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); + i2 = _mm256_packs_epi32( i2, i3 ); + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); + + // Permute and store the quantized weights in the required order after the pack instruction + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); + quants_interleaved[j] = i0; + } + + // Masks to shuffle the quants of corresonding sub blocks for rearraning quants for vectorized bsums computation + __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15)); + shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0); + __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15)); + shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0); + __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9)); + shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0); + + for (int k = 0; k < 4; k++) { + // Quants from four different sub blocks are taken + __m256i q0 = quants_interleaved[k * 8 + 0]; + __m256i q1 = quants_interleaved[k * 8 + 1]; + __m256i q2 = quants_interleaved[k * 8 + 2]; + __m256i q3 = quants_interleaved[k * 8 + 3]; + __m256i q4 = quants_interleaved[k * 8 + 4]; + __m256i q5 = quants_interleaved[k * 8 + 5]; + __m256i q6 = quants_interleaved[k * 8 + 6]; + __m256i q7 = quants_interleaved[k * 8 + 7]; + + + // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time + __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2); + __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34); + __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3); + sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68); + __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4); + sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136); + + __m256i one = _mm256_set1_epi8(1); + __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved); + + for (int l = 0; l < 3; l++) { + // Quants value shifted to process next two values from each sub block + q0 = _mm256_srli_epi64(q0, 16); + q2 = _mm256_srli_epi64(q2, 16); + q4 = _mm256_srli_epi64(q4, 16); + q6 = _mm256_srli_epi64(q6, 16); + + sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2); + sb_h1_interleaved = 
_mm256_blend_epi16(q0, sb2_h1_shuffled, 34); + sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3); + sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68); + sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4); + sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136); + + bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved)); + } + + // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time + __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2); + __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34); + __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3); + sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68); + __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4); + sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136); + + __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved); + + for (int l = 0; l < 3; l++) { + // Quants value shifted to process next two values from each sub block + q1 = _mm256_srli_epi64(q1, 16); + q3 = _mm256_srli_epi64(q3, 16); + q5 = _mm256_srli_epi64(q5, 16); + q7 = _mm256_srli_epi64(q7, 16); + + sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2); + sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34); + sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3); + sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68); + sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4); + sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136); + + bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved)); + } + + // Overall bsums in interleaved fashion computed by adding results of both halves + __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2); + _mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r); + } + } + +#else + UNUSED(nb); + UNUSED(y); + ggml_quantize_mat_q8_K_4x8_generic(x, vy, k); +#endif +} + // // GEMV/GEMM templates // diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 338bb0e8d2..eed54473ea 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -227,346 +227,6 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG } } -void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(QK_K == 256); - assert(k % QK_K == 0); - const int nb = k / QK_K; - - block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; - -#if defined(__ARM_NEON) - const int blck_size_interleave = 8; - float32x4_t srcv[4][64]; // 64 = QK_K/4 - float iscale[4]; - - for (int i = 0; i < nb; i++) { - float32x4_t asrcv[64]; - float32x4_t amaxv[64]; - - // d: - for (int row_iter = 0; row_iter < 4; row_iter++) { - for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j); - for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); - - for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); - for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); - for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); - for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]); - for (int j = 0; j < 2; 
j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]); - for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]); - - const float amax = vmaxvq_f32(amaxv[0]); - - // Check if exists: orig == amax - float32x4_t amax_vec = vdupq_n_f32(amax); - uint32x4_t mask_all = vdupq_n_u32(0); - for (int j = 0; j < 64; j++) { - uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]); - mask_all = vorrq_u32(mask_all, mask_curr); - } - - // Assume that none == amax, then check mask_all to reverse - iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f; - uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu)); - if (vmaxvq_u32(cmp) != 0) { - iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f; - } - - y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0; - } - - // qs: 8 byte interleave over 4 rows, loop = QK_K/8 - // bsums: simply generated one by one, row_i is calculated before row_i+1 - // loops = 16 - for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) { - // Process row0 and row1 - float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0])); - float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0])); - float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0])); - float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0])); - int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3); - int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7); - int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 - int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11); - int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15); - int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 - - float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1])); - float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1])); - float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1])); - float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1])); - int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3); - int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7); - int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 - int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11); - int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15); - int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 - - // Calculate and store qs - int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16 - int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 - vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7); - vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15); - // Calculate and store bsum - int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 - int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 - y[i].bsums[j] = vaddlvq_s8(i0_0_15); - y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15); - - // Process row2 and row3 - f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2])); - f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2])); - f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2])); - f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2])); - i0_0_3 = vcvtnq_s32_f32(f0_0_3); - i0_4_7 
= vcvtnq_s32_f32(f0_4_7); - i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 - i0_8_11 = vcvtnq_s32_f32(f0_8_11); - i0_12_15 = vcvtnq_s32_f32(f0_12_15); - i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 - - f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3])); - f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3])); - f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3])); - f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3])); - i1_0_3 = vcvtnq_s32_f32(f1_0_3); - i1_4_7 = vcvtnq_s32_f32(f1_4_7); - i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8 - i1_8_11 = vcvtnq_s32_f32(f1_8_11); - i1_12_15 = vcvtnq_s32_f32(f1_12_15); - i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8 - - // Calculate and store qs - i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16 - i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 - vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7); - vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15); - // Calculate and store bsum - i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 - i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16 - y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15); - y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15); - } - } -#elif defined(__AVX2__) - float iscale[4]; - __m256 srcv[4][32]; - __m256 iscale_vec[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 ); - __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 ); - __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 ); - __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 ); - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 abs0 = _mm256_andnot_ps( signBit, v0 ); - __m256 abs1 = _mm256_andnot_ps( signBit, v1 ); - __m256 abs2 = _mm256_andnot_ps( signBit, v2 ); - __m256 abs3 = _mm256_andnot_ps( signBit, v3 ); - - __m256 maxAbs = _mm256_max_ps( abs0, abs1 ); - maxAbs = _mm256_max_ps( maxAbs, abs2 ); - maxAbs = _mm256_max_ps( maxAbs, abs3 ); - - __m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ ); - __m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ ); - __m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ ); - __m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ ); - - __m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3)); - - srcv[row_iter][0] = v0; - srcv[row_iter][1] = v1; - srcv[row_iter][2] = v2; - srcv[row_iter][3] = v3; - - for (int sb = 1; sb < 8; sb++) { - // Temporarily stores absolute quant values - __m256 tempAbs = maxAbs; - - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32); - __m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 ); - __m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 ); - __m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 ); - - // Compute max(abs(e)) for the block - __m256 abs0 = _mm256_andnot_ps( signBit, v0 ); - __m256 abs1 = _mm256_andnot_ps( signBit, v1 ); - __m256 abs2 = _mm256_andnot_ps( 
signBit, v2 ); - __m256 abs3 = _mm256_andnot_ps( signBit, v3 ); - - maxAbs = _mm256_max_ps( maxAbs, abs0 ); - maxAbs = _mm256_max_ps( maxAbs, abs1 ); - maxAbs = _mm256_max_ps( maxAbs, abs2 ); - maxAbs = _mm256_max_ps( maxAbs, abs3 ); - - __m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ ); - maskAbs = _mm256_and_ps( maskAbs, mask_prev ); - - mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ ); - mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ ); - mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ ); - mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ ); - - __m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3)); - maskAbs = _mm256_or_ps(maskAbs, mask_curr); - - srcv[row_iter][sb * 4] = v0; - srcv[row_iter][sb * 4 + 1] = v1; - srcv[row_iter][sb * 4 + 2] = v2; - srcv[row_iter][sb * 4 + 3] = v3; - } - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - __m256 maxScalarVec = _mm256_set1_ps(maxScalar); - - __m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ ); - __m256 finalMask = _mm256_and_ps(maskAbs, mask_next); - - const int mask = _mm256_movemask_ps(finalMask); - iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - - if(mask) { - iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f; - } - - y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0; - iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]); - } - - __m256i quants_interleaved[32]; - for (int j = 0; j < 32; j++) { - // Apply the multiplier - __m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]); - __m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]); - __m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]); - __m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); - i2 = _mm256_packs_epi32( i2, i3 ); - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); - - // Permute and store the quantized weights in the required order after the pack instruction - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0); - quants_interleaved[j] = i0; - } - - // Masks to shuffle the quants of corresonding sub blocks for rearraning quants for vectorized bsums computation - __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15)); - shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0); - __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15)); - shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0); - __m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9)); - shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, 
shuffle_mask_sb4, 0); - - for (int k = 0; k < 4; k++) { - // Quants from four different sub blocks are taken - __m256i q0 = quants_interleaved[k * 8 + 0]; - __m256i q1 = quants_interleaved[k * 8 + 1]; - __m256i q2 = quants_interleaved[k * 8 + 2]; - __m256i q3 = quants_interleaved[k * 8 + 3]; - __m256i q4 = quants_interleaved[k * 8 + 4]; - __m256i q5 = quants_interleaved[k * 8 + 5]; - __m256i q6 = quants_interleaved[k * 8 + 6]; - __m256i q7 = quants_interleaved[k * 8 + 7]; - - - // The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time - __m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2); - __m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34); - __m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3); - sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68); - __m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4); - sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136); - - __m256i one = _mm256_set1_epi8(1); - __m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved); - - for (int l = 0; l < 3; l++) { - // Quants value shifted to process next two values from each sub block - q0 = _mm256_srli_epi64(q0, 16); - q2 = _mm256_srli_epi64(q2, 16); - q4 = _mm256_srli_epi64(q4, 16); - q6 = _mm256_srli_epi64(q6, 16); - - sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2); - sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34); - sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3); - sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68); - sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4); - sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136); - - bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved)); - } - - // The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time - __m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2); - __m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34); - __m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3); - sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68); - __m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4); - sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136); - - __m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved); - - for (int l = 0; l < 3; l++) { - // Quants value shifted to process next two values from each sub block - q1 = _mm256_srli_epi64(q1, 16); - q3 = _mm256_srli_epi64(q3, 16); - q5 = _mm256_srli_epi64(q5, 16); - q7 = _mm256_srli_epi64(q7, 16); - - sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2); - sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34); - sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3); - sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68); - sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4); - sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136); - - bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved)); - } - - // Overall bsums in interleaved fashion computed by adding results of both halves - __m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2); - _mm256_storeu_si256((__m256i 
*)(y[i].bsums + 16 * k), bsums_r);
-        }
-    }
-
-#else
-    // NOTE: This default c implementation is aligned with AVX2 implemantation, but differs from arm implementation.
-    // especially in bsums arrangement.
-    UNUSED(nb);
-    UNUSED(y);
-    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
-#endif
-}
-
 } // extern "C"
 
 template
@@ -731,6 +391,87 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemv_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+    uint8_t scales[4][8];   // sub-block scales of the 4 interleaved columns (8 sub-blocks each)
+    uint8_t mins[4][8];     // sub-block mins of the 4 interleaved columns (8 sub-blocks each)
+    float sumf[4];          // 1x4 tile: accumulated result
+    float sum_minf[4];      // 1x4 tile: accumulated mins correction (subtracted at the end)
+    int sumi1;
+    int sumi2;
+    int sumi;
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+
+    // loop on n dimension, each iteration works on 4 columns
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb);
+
+        // initialize results for 4 cols
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0;
+            sum_minf[j] = 0.0;
+        }
+
+        // loop on k dimension, each iteration works on 1 q4_Kx4 block
+        for (int n = 0; n < nb; n++) {
+            // prepare scales and mins for 4 cols
+            for (int j = 0; j < ncols_interleaved; j++) {
+                for (int i = 0; i < 8; i++) {
+                    scales[j][i] = b_ptr[n].scales[i * 8 + j];
+                    mins[j][i]   = b_ptr[n].scales[i * 8 + j + ncols_interleaved];
+                }
+            }
+
+            // core loop: each iteration works on one interleaved unit (four 8-byte segments, one per column)
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi1 = 0;
+                    sumi2 = 0;
+                    sumi  = 0;
+                    for (int i = 0; i < blocklen; i++) {
+                        const int v0 = (int8_t) (b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf);
+                        const int v1 = (int8_t) (b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
+                        sumi1 = (v0 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i]);
+                        sumi2 = (v1 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i + 16]);
+                        uint8_t scale = scales[j][k / 2];
+                        sumi += sumi1 * scale + sumi2 * scale;
+                    }
+                    sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d;
+                }
+            }
+
+            // accumulate the mins correction, subtracted from the result at the end
+            for (int j = 0; j < ncols_interleaved; j++) {
+                for (int i = 0; i < QK_K / 32; i++) {
+                    sum_minf[j] += mins[j][i] * (a_ptr[n].bsums[i * 2] + a_ptr[n].bsums[i * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d;
+                }
+            }
+        }
+
+        // save results
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
+        }
+    }
+}
+
 void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -1290,6 +1031,95 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     }
 }
 
+void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK_K;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+    uint8_t scales[4][8];   // sub-block scales of the 4 interleaved columns (8 sub-blocks each)
+    uint8_t mins[4][8];     // sub-block mins of the 4 interleaved columns (8 sub-blocks each)
+    float sumf[4][4];       // 4x4 tile: accumulated result
+    float sum_minf[4][4];   // 4x4 tile: accumulated mins correction (subtracted at the end)
+    int sumi1;
+    int sumi2;
+    int sumi;
+
+    // loop on m dimension, each iteration works on 4 rows
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        // loop on n dimension, each iteration works on 4 columns
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb);
+
+            // initialize results for the 4x4 tile
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0;
+                    sum_minf[m][j] = 0.0;
+                }
+            }
+
+            // loop on k dimension, each iteration works on 1 q4_Kx4 block
+            for (int n = 0; n < nb; n++) {
+                // prepare scales and mins for 4 cols
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    for (int i = 0; i < 8; i++) {
+                        scales[j][i] = b_ptr[n].scales[i * 8 + j];
+                        mins[j][i]   = b_ptr[n].scales[i * 8 + j + ncols_interleaved];
+                    }
+                }
+
+                // core loop: each iteration works on one interleaved unit (four 8-byte segments, one per column)
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi1 = 0;
+                            sumi2 = 0;
+                            sumi  = 0;
+                            for (int i = 0; i < blocklen; i++) {
+                                const int v0 = (int8_t) (b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf);
+                                const int v1 = (int8_t) (b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
+                                sumi1 = (v0 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i]);
+                                sumi2 = (v1 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i + 64]);
+                                uint8_t scale = scales[j][k / 2];
+                                sumi += sumi1 * scale + sumi2 * scale;
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d[m];
+                        }
+                    }
+                }
+
+                // accumulate the mins correction, subtracted from the result at the end
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        for (int i = 0; i < QK_K / 32; i++) {
+                            const int16_t * bsums = a_ptr[n].bsums + (i * 8) - ((i % 2) * 6) + (m * 4);
+                            sum_minf[m][j] += mins[j][i] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d[m];
+                        }
+                    }
+                }
+            }
+
+            // save results
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
+                }
+            }
+        }
+    }
+}
+
 void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -3037,8 +2867,10 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf
     GGML_UNUSED(buft);
 }
 
-// size calculation after q4_kx4 repacking, it's different from traditional type
-size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
+// For the aarch64 q4_K repack implementation only:
+// computes the tensor storage size after q4_Kx4 repacking,
+// because sizeof(block_q4_Kx4) is slightly larger than 4 * sizeof(block_q4_K).
+static inline size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) { size_t nbytes; const size_t blck_size = 256; const size_t type_size = sizeof(block_q4_Kx4) / 4;
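
Reviewer note: below is a minimal standalone sketch of the index math that the generic ggml_gemv/ggml_gemm_q4_K_4x8_q8_K fallbacks rely on. It assumes the block_q8_Kx4 layout produced by the generic/AVX2 ggml_quantize_mat_q8_K_4x8 path above (each 32-byte segment of qs holds 8 consecutive elements of rows 0..3). The helper name q8_Kx4_qs_offset is illustrative only and not part of the patch.

#include <assert.h>
#include <stdio.h>

// Offset of element e (0..255) of row m (0..3) inside block_q8_Kx4::qs,
// assuming the 4x8 interleaving of the generic/AVX2 quantize path:
// each 32-byte segment holds 8 consecutive elements of rows 0,1,2,3.
static int q8_Kx4_qs_offset(int m, int e) {
    return (e / 8) * 32 + m * 8 + (e % 8);
}

int main(void) {
    const int blocklen = 8;
    // The generic GEMM indexes the activations as
    //   (k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i
    // for the low nibbles (element (k / 2) * 32 + (k % 2) * 8 + i of row m),
    // and the same offset + 64 for the high nibbles (the element 16 positions
    // later within the same 32-element sub-block).
    for (int k = 0; k < 16; k++) {
        for (int m = 0; m < 4; m++) {
            for (int i = 0; i < blocklen; i++) {
                const int e_lo   = (k / 2) * 32 + (k % 2) * blocklen + i;
                const int off_lo = (k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i;
                assert(off_lo      == q8_Kx4_qs_offset(m, e_lo));
                assert(off_lo + 64 == q8_Kx4_qs_offset(m, e_lo + 16));
            }
        }
    }
    printf("q8_Kx4 index math checks out\n");
    return 0;
}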