diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp
index 7dda9eea0c..84b8b83b00 100644
--- a/ggml/src/ggml-cpu/arch/x86/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp
@@ -1396,7 +1396,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     const int blocklen = 8;
     static const uint32_t kmask1 = 0x3f3f3f3f;
     static const uint32_t kmask2 = 0x0f0f0f0f;
-    static const uint32_t kmask3 = 0x03030303;
+    static const uint32_t kmask3 = 0x30303030;

     assert (n % qk == 0);
     assert (nc % ncols_interleaved == 0);
@@ -1412,17 +1412,14 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     UNUSED(blocklen);

 #if defined(__AVX2__)
-    // Lookup table to convert signed nibbles to signed bytes
-    __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
-    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
     // Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
-    __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
-    __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
+    static const __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
+    static const __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
     // Permute mask used for easier vector processing at later stages
-    __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+    static const __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
     // Mask to extract nibbles from bytes
-    const __m256i m4b = _mm256_set1_epi8(0x0F);
+    static const __m256i m4b = _mm256_set1_epi8(0x0F);

     int64_t b_nb = n / QK_K;

@@ -1459,133 +1456,120 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
             __m256i iacc_min_b = _mm256_setzero_si256();

             const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
-            __m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
-            q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
+            __m128i q8s = _mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1));
+            const uint32_t* q8s_ptr = (uint32_t*) &q8s;
+
+            const int sbCount = QK_K / 64;
+
+            const uint32_t *utmp = (const uint32_t*) (b_ptr[b].scales);

             // Processes two sub blocks from each Q4_K in each iteration
-            for (int sb = 0; sb < QK_K / 64; sb++) {
-
-                // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
-                const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
-                const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
-                const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
-                const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
-                const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
-                const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
-                const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
-                const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
-
-                // 4-bit -> 8-bit
-                // Values of the first sub block of eight block_q4_K structures for the sb loop
-                const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
-                const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
-                const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
-                const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
-                const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
-                const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
-                const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
-                const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
-
-                // Values of the second sub block of eight block_q4_K structures when sb = 1
-                const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
-                const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
-                const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
-                const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
-                const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
-                const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
-                const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
-                const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
-
-                uint32_t utmp_0[4], utmp_1[4];
+            for (int sb = 0; sb < sbCount; sb++) {

                 // Scales and Mins of corresponding sub blocks from different Q8_K structures are stored together
                 // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
-                memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
-                utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
-                const uint32_t uaux_0 = utmp_0[1] & kmask1;
-                utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
-                utmp_0[2] = uaux_0;
-                utmp_0[0] &= kmask1;
-
-                // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
-                memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
-                utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
-                const uint32_t uaux_1 = utmp_1[1] & kmask1;
-                utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
-                utmp_1[2] = uaux_1;
-                utmp_1[0] &= kmask1;
+                const uint32_t utmp_03 = ((utmp[2] >> 4) & kmask2) | ((utmp[1] >> 2) & kmask3);
+                const uint32_t utmp_02 = utmp[1] & kmask1;
+                const uint32_t utmp_01 = (utmp[2] & kmask2) | ((utmp[0] >> 2) & kmask3);
+                const uint32_t utmp_00 = utmp[0] & kmask1;

                 // Scales of first sub block in the sb loop
-                const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
+                __m128i mins_and_scales_0 = _mm_set_epi32(utmp_03, utmp_02, utmp_01, utmp_00);
                 __m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
                 __m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);

+                // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
+                const uint32_t utmp_13 = ((utmp[5] >> 4) & kmask2) | ((utmp[4] >> 2) & kmask3);
+                const uint32_t utmp_12 = utmp[4] & kmask1;
+                const uint32_t utmp_11 = (utmp[5] & kmask2) | ((utmp[3] >> 2) & kmask3);
+                const uint32_t utmp_10 = utmp[3] & kmask1;
+
                 // Scales of second sub block in the sb loop
-                __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
+                __m128i mins_and_scales_1 = _mm_set_epi32(utmp_13, utmp_12, utmp_11, utmp_10);
                 __m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
                 __m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);

                 // Mins of first and second sub block of Q4_K block are arranged side by side
                 __m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));

-                // Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
-                __m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
-                __m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
-                __m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
-                __m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
+                utmp += 6;

-                lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
-                lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
-                lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
-                lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
+                const uint32_t* a_values = (const uint32_t*) (a_ptr[b].qs + sb * 64);
+                const uint32_t* b_values = (const uint32_t*) (b_ptr[b].qs + sb * 256);

-                // Dot product done within 32 bit lanes and accumulated in the same vector
-                // First done for first sub block and thenn for second sub block in each sb
-                // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
-                // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
-                // ...........................................................................
-                // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
+                // Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
+                const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_values));
+                const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_values + 8));
+
+                const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
+                const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
+
+                const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
+                const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
+
+                const __m256i iacc_01 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_set1_epi32(a_values[0])),
+                                                          _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_set1_epi32(a_values[1])));
+
+                const __m256i iacc_11 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_set1_epi32(a_values[8])),
+                                                          _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_set1_epi32(a_values[9])));

-                __m256i iacc_0 = _mm256_setzero_si256();
-                __m256i iacc_1 = _mm256_setzero_si256();
+                const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_values + 16));
+                const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_values + 24));

-                iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
-                iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
+                const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
+                const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);

-                iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
-                iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
+                const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
+                const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);

-                iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
-                iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
+                const __m256i iacc_02 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_set1_epi32(a_values[2])),
+                                                          _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_set1_epi32(a_values[3])));

-                iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
-                iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
+                const __m256i iacc_12 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_set1_epi32(a_values[10])),
+                                                          _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_set1_epi32(a_values[11])));

-                iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
-                iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
-                iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
+                const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_values + 32));
+                const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_values + 40));

-                iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
-                iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
+                const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
+                const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);

-                iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
-                iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
+                const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
+                const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);

-                iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
-                iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
+                const __m256i iacc_03 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_set1_epi32(a_values[4])),
+                                                          _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_set1_epi32(a_values[5])));
+
+                const __m256i iacc_13 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_set1_epi32(a_values[12])),
+                                                          _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_set1_epi32(a_values[13])));
+
+
+                const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_values + 48));
+                const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_values + 56));
+
+                const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
+                const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
+
+                const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
+                const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
+
+                __m256i iacc_04 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_set1_epi32(a_values[6])),
+                                                    _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_set1_epi32(a_values[7])));
+
+                __m256i iacc_14 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_set1_epi32(a_values[14])),
+                                                    _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_set1_epi32(a_values[15])));

-                iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
                 // Accumulate the iacc value for one sb
-                __m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
+                __m256i iacc_sb = _mm256_add_epi32( _mm256_madd_epi16( _mm256_add_epi16(_mm256_add_epi16(iacc_01, iacc_02), _mm256_add_epi16(iacc_03, iacc_04)), scales_0),
+                                                    _mm256_madd_epi16( _mm256_add_epi16(_mm256_add_epi16(iacc_11, iacc_12), _mm256_add_epi16(iacc_13, iacc_14)), scales_1));

                 // Broadcast the bsums of the two sub blocks of the iteration of Q8_K across the vector
                 // Multiply-Add with corresponding mins of Q4_Kx8 with bsums
-                __m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
-                __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
-                q8s = _mm256_bsrli_epi128(q8s, 4);
+                const __m256i q8s_sb = _mm256_set1_epi32(q8s_ptr[sb]);
+                const __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);

                 // Accumulate for the complete block
                 iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
@@ -1962,6 +1946,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
     static const uint32_t kmask1 = 0x3f3f3f3f;
     static const uint32_t kmask2 = 0x0f0f0f0f;
     static const uint32_t kmask3 = 0x03030303;
+    static const uint32_t kmask_3 = 0x30303030;

     assert (n % qk == 0);
     assert (nr % 4 == 0);
@@ -2764,6 +2749,8 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                 // dmin values - Load the eight dmin values of block_q4_kx8
                 const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);

+                const uint32_t *utmp = (const uint32_t*) (b_ptr[b].scales);
+
                 // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
                 for (int sb = 0; sb < QK_K / 64; sb++) {
@@ -2865,31 +2852,25 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
                     const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)

-                    uint32_t utmp_0[4], utmp_1[4];
-
-                    // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
+                    // Scales and Mins of corresponding sub blocks from different Q8_K structures are stored together
                     // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
-                    utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_0 = utmp_0[1] & kmask1;
-                    utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
-                    utmp_0[2] = uaux_0;
-                    utmp_0[0] &= kmask1;
-
-                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
-                    memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
-                    utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
-                    const uint32_t uaux_1 = utmp_1[1] & kmask1;
-                    utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
-                    utmp_1[2] = uaux_1;
-                    utmp_1[0] &= kmask1;
+                    const uint32_t utmp_03 = ((utmp[2] >> 4) & kmask2) | ((utmp[1] >> 2) & kmask_3);
+                    const uint32_t utmp_02 = utmp[1] & kmask1;
+                    const uint32_t utmp_01 = (utmp[2] & kmask2) | ((utmp[0] >> 2) & kmask_3);
+                    const uint32_t utmp_00 = utmp[0] & kmask1;

                     // Scales of first sub block in the sb loop
-                    const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
+                    __m128i mins_and_scales_0 = _mm_set_epi32(utmp_03, utmp_02, utmp_01, utmp_00);
                     const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));

+                    // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
+                    const uint32_t utmp_13 = ((utmp[5] >> 4) & kmask2) | ((utmp[4] >> 2) & kmask_3);
+                    const uint32_t utmp_12 = utmp[4] & kmask1;
+                    const uint32_t utmp_11 = (utmp[5] & kmask2) | ((utmp[3] >> 2) & kmask_3);
+                    const uint32_t utmp_10 = utmp[3] & kmask1;
+
                     // Scales of second sub block in the sb loop
-                    const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
+                    __m128i mins_and_scales_1 = _mm_set_epi32(utmp_13, utmp_12, utmp_11, utmp_10);
                     const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));

                     // Mins of first and second sub block of Q4_K block are arranged side by side
@@ -2901,39 +2882,45 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                     const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
                     const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);

+                    utmp += 6;
+
                     for (int rp = 0; rp < 4; rp++) {

                         // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
                         // Loaded as set of 128 bit vectors and repeated into a 256 bit vector
                         __m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
-                        __m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
-                        __m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
                         __m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
-                        __m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
-                        __m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
                         __m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
-                        __m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
-                        __m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
                         __m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
-                        __m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
-                        __m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
                         __m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
-                        __m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
-                        __m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
                         __m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
-                        __m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
-                        __m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
                         __m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
-                        __m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
-                        __m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
                         __m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
-                        __m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
-                        __m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
-                        // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
-                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
-                        __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
-                        lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
+                        __m256i lhs_mat_01_00 = _mm256_blend_epi32(lhs_mat_0123_00, _mm256_permutevar8x32_epi32(lhs_mat_0123_00, requiredOrder), 0xF0);
+                        __m256i lhs_mat_23_00 = _mm256_blend_epi32(lhs_mat_0123_00, _mm256_permutevar8x32_epi32(lhs_mat_0123_00, requiredOrder), 0x0F);
+
+                        __m256i lhs_mat_01_01 = _mm256_blend_epi32(lhs_mat_0123_01, _mm256_permutevar8x32_epi32(lhs_mat_0123_01, requiredOrder), 0xF0);
+                        __m256i lhs_mat_23_01 = _mm256_blend_epi32(lhs_mat_0123_01, _mm256_permutevar8x32_epi32(lhs_mat_0123_01, requiredOrder), 0x0F);
+
+                        __m256i lhs_mat_01_02 = _mm256_blend_epi32(lhs_mat_0123_02, _mm256_permutevar8x32_epi32(lhs_mat_0123_02, requiredOrder), 0xF0);
+                        __m256i lhs_mat_23_02 = _mm256_blend_epi32(lhs_mat_0123_02, _mm256_permutevar8x32_epi32(lhs_mat_0123_02, requiredOrder), 0x0F);
+
+                        __m256i lhs_mat_01_03 = _mm256_blend_epi32(lhs_mat_0123_03, _mm256_permutevar8x32_epi32(lhs_mat_0123_03, requiredOrder), 0xF0);
+                        __m256i lhs_mat_23_03 = _mm256_blend_epi32(lhs_mat_0123_03, _mm256_permutevar8x32_epi32(lhs_mat_0123_03, requiredOrder), 0x0F);
+
+                        __m256i lhs_mat_01_10 = _mm256_blend_epi32(lhs_mat_0123_10, _mm256_permutevar8x32_epi32(lhs_mat_0123_10, requiredOrder), 0xF0);
+                        __m256i lhs_mat_23_10 = _mm256_blend_epi32(lhs_mat_0123_10, _mm256_permutevar8x32_epi32(lhs_mat_0123_10, requiredOrder), 0x0F);
+
+                        __m256i lhs_mat_01_11 = _mm256_blend_epi32(lhs_mat_0123_11, _mm256_permutevar8x32_epi32(lhs_mat_0123_11, requiredOrder), 0xF0);
+                        __m256i lhs_mat_23_11 = _mm256_blend_epi32(lhs_mat_0123_11, _mm256_permutevar8x32_epi32(lhs_mat_0123_11, requiredOrder), 0x0F);
+
+                        __m256i lhs_mat_01_12 = _mm256_blend_epi32(lhs_mat_0123_12, _mm256_permutevar8x32_epi32(lhs_mat_0123_12, requiredOrder), 0xF0);
+                        __m256i lhs_mat_23_12 = _mm256_blend_epi32(lhs_mat_0123_12, _mm256_permutevar8x32_epi32(lhs_mat_0123_12, requiredOrder), 0x0F);
+
+                        __m256i lhs_mat_01_13 = _mm256_blend_epi32(lhs_mat_0123_13, _mm256_permutevar8x32_epi32(lhs_mat_0123_13, requiredOrder), 0xF0);
+                        __m256i lhs_mat_23_13 = _mm256_blend_epi32(lhs_mat_0123_13, _mm256_permutevar8x32_epi32(lhs_mat_0123_13, requiredOrder), 0x0F);
+
                         // Shuffle pattern one - left side input

                         const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
@@ -3051,6 +3038,11 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
                         acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);

+                        // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
+                        __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
+                        __m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
+                        lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
+
                         __m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
                         __m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
                         __m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
@@ -3060,7 +3052,6 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
                         acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
                         acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
                         acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
-
                     }
                 }
             }
@@ -3070,6 +3061,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
            }
        }
    }
+
    for (; y < nr / 4; y++) {

        const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);
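Note on the scale unpacking (reviewer sketch, not part of the patch): the rewritten Q4_K scale/min extraction replaces the old (((x >> 6) & 0x03030303) << 4) step with ((x >> 2) & 0x30303030), which is why kmask3 becomes 0x30303030 in ggml_gemv_q4_K_8x8_q8_K and a separate kmask_3 is introduced next to the untouched kmask3 in ggml_gemm_q4_K_8x8_q8_K. The two expressions are bit-for-bit identical, since shifting the mask left by 4 absorbs the trailing << 4. Below is a small standalone C++ check of that identity; the file, names, and test harness are illustrative assumptions, not code from repack.cpp.

// scale_unpack_check.cpp (hypothetical, not part of the patch)
// Verifies that ((x >> 6) & 0x03030303) << 4 == (x >> 2) & 0x30303030 for arbitrary
// 32-bit inputs, i.e. that the new single-shift formulation matches the old one.
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t kmask3_old = 0x03030303;  // old mask: applied after x >> 6, result then shifted left by 4
    const uint32_t kmask3_new = 0x30303030;  // new mask: applied directly to x >> 2

    uint32_t x = 0x12345678u;
    for (int i = 0; i < 1000000; i++) {
        const uint32_t old_way = ((x >> 6) & kmask3_old) << 4;
        const uint32_t new_way = (x >> 2) & kmask3_new;
        assert(old_way == new_way);
        x = x * 1664525u + 1013904223u;  // LCG step to sweep many bit patterns
    }
    printf("old and new scale-unpack expressions agree\n");
    return 0;
}

The same reasoning covers each of the utmp_0x/utmp_1x lines, which only differ in which packed scale word they read. Likewise, broadcasting q8s_ptr[sb] with _mm256_set1_epi32 appears to reproduce what the removed _mm256_shuffle_epi32 / _mm256_bsrli_epi128 sequence computed, one 32-bit bsum pair per sb iteration.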