Peter Engler 2026-01-02 23:47:03 +02:00 committed by GitHub
commit 9dc07a8b72
1 changed file with 131 additions and 139 deletions


@@ -1396,7 +1396,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
const int blocklen = 8;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
static const uint32_t kmask3 = 0x30303030;
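// Sketch of why the kmask3 change is safe: the old code extracted the top two
// bits of each byte with ((x >> 6) & 0x03030303) << 4, while the new mask does
// the same extraction with a single 2-bit shift. A standalone check of the
// identity, which holds for any 32-bit x:
#include <assert.h>
#include <stdint.h>
static void check_kmask3_equivalence(uint32_t x) {
    assert((((x >> 6) & 0x03030303u) << 4) == ((x >> 2) & 0x30303030u));
}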
assert (n % qk == 0);
assert (nc % ncols_interleaved == 0);
@@ -1412,17 +1412,14 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
UNUSED(blocklen);
#if defined(__AVX2__)
// Lookup table to convert signed nibbles to signed bytes
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
// Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
__m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
__m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
static const __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
static const __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
// Permute mask used for easier vector processing at later stages
__m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
static const __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
// Mask to extract nibbles from bytes
const __m256i m4b = _mm256_set1_epi8(0x0F);
static const __m256i m4b = _mm256_set1_epi8(0x0F);
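// Minimal sketch of the nibble unpack that m4b drives below: each byte packs
// two 4-bit weights; AND extracts the low nibble, and a 16-bit shift plus the
// same AND extracts the high one (the AND also discards bits the epi16 shift
// drags across byte boundaries, which is why the cheaper 16-bit shift is safe).
#include <immintrin.h>
static inline void unpack_nibbles_sketch(__m256i packed, __m256i *lo, __m256i *hi) {
    const __m256i mask = _mm256_set1_epi8(0x0F);
    *lo = _mm256_and_si256(packed, mask);
    *hi = _mm256_and_si256(_mm256_srli_epi16(packed, 4), mask);
}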
int64_t b_nb = n / QK_K;
@@ -1459,133 +1456,120 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
__m256i iacc_min_b = _mm256_setzero_si256();
const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
__m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
__m128i q8s = _mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1));
const uint32_t* q8s_ptr = (uint32_t*) &q8s;
const int sbCount = QK_K / 64;
const uint32_t *utmp = (const uint32_t*) (b_ptr[b].scales);
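// Scalar sketch of the _mm_hadd_epi16 above, assuming block_q8_K keeps 16
// bsums of 16 quants each: adjacent pairs collapse into 8 sums of 32 quants,
// so the uint32_t at q8s_ptr[sb] holds the two 16-bit bsums of the two
// 32-element sub blocks processed in iteration sb.
#include <stdint.h>
static void fold_bsums_sketch(const int16_t bsums[16], int16_t folded[8]) {
    for (int i = 0; i < 8; i++) {
        folded[i] = (int16_t)(bsums[2 * i] + bsums[2 * i + 1]);
    }
}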
// Processes two sub blocks from each Q4_K in each iteration
for (int sb = 0; sb < QK_K / 64; sb++) {
// Load the quantized values of two sub blocks from each of the eight block_q4_K, interleaved with each other in chunks of eight bytes - B0,B1, ..., B6,B7
const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));
// 4-bit -> 8-bit
// Values of the first sub block of eight block_q4_K structures for the sb loop
const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
// Values of the second sub block of eight block_q4_K structures for the sb loop
const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
uint32_t utmp_0[4], utmp_1[4];
for (int sb = 0; sb < sbCount; sb++) {
// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
// For example, the below block extracts the first sub block's scales and mins from the different Q4_K structures for the sb loop
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp_0[1] & kmask1;
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
utmp_0[2] = uaux_0;
utmp_0[0] &= kmask1;
// For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
const uint32_t uaux_1 = utmp_1[1] & kmask1;
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
utmp_1[2] = uaux_1;
utmp_1[0] &= kmask1;
const uint32_t utmp_03 = ((utmp[2] >> 4) & kmask2) | ((utmp[1] >> 2) & kmask3);
const uint32_t utmp_02 = utmp[1] & kmask1;
const uint32_t utmp_01 = (utmp[2] & kmask2) | ((utmp[0] >> 2) & kmask3);
const uint32_t utmp_00 = utmp[0] & kmask1;
// Scales of first sub block in the sb loop
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
__m128i mins_and_scales_0 = _mm_set_epi32(utmp_03, utmp_02, utmp_01, utmp_00);
__m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
__m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);
// For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop
const uint32_t utmp_13 = ((utmp[5] >> 4) & kmask2) | ((utmp[4] >> 2) & kmask3);
const uint32_t utmp_12 = utmp[4] & kmask1;
const uint32_t utmp_11 = (utmp[5] & kmask2) | ((utmp[3] >> 2) & kmask3);
const uint32_t utmp_10 = utmp[3] & kmask1;
// Scales of second sub block in the sb loop
__m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
__m128i mins_and_scales_1 = _mm_set_epi32(utmp_13, utmp_12, utmp_11, utmp_10);
__m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
__m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);
// Mins of first and second sub block of Q4_K block are arranged side by side
__m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));
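// Scalar sketch of the 12-byte Q4_K scale/min unpack that the kmask
// arithmetic above performs four bytes at a time (standard get_scale_min_k4
// layout: eight 6-bit scales and eight 6-bit mins per 12-byte group):
#include <stdint.h>
static void unpack_scales_mins_sketch(const uint8_t sc12[12], uint8_t scales[8], uint8_t mins[8]) {
    for (int i = 0; i < 8; i++) {
        if (i < 4) {
            scales[i] = sc12[i] & 63;
            mins[i]   = sc12[i + 4] & 63;
        } else {
            scales[i] = (sc12[i + 4] & 0x0F) | ((sc12[i - 4] >> 6) << 4);
            mins[i]   = (sc12[i + 4] >>   4) | ((sc12[i    ] >> 6) << 4);
        }
    }
}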
// Load the values of the two sub blocks corresponding to sb in block_q8_K in batches of 16 bytes and replicate them across a 256-bit vector
__m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
__m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
__m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
__m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
utmp += 6;
lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
const uint32_t* a_values = (const uint32_t*) (a_ptr[b].qs + sb * 64);
const uint32_t* b_values = (const uint32_t*) (b_ptr[b].qs + sb * 256);
// Dot product done within 32 bit lanes and accumulated in the same vector
// First done for the first sub block and then for the second sub block in each sb
// B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
// B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
// ...........................................................................
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
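// Scalar model of one 32-bit lane of the _mm256_maddubs_epi16 calls below:
// adjacent unsigned(4-bit weight) * signed(q8) byte products are summed
// pairwise into int16 (the real instruction saturates; plain arithmetic here
// for clarity).
#include <stdint.h>
static void maddubs_lane_sketch(const uint8_t u[4], const int8_t s[4], int16_t out[2]) {
    out[0] = (int16_t)(u[0] * s[0] + u[1] * s[1]);
    out[1] = (int16_t)(u[2] * s[2] + u[3] * s[3]);
}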
// Load the quantized values of two sub blocks from each of the eight block_q4_K, interleaved with each other in chunks of eight bytes - B0,B1, ..., B6,B7
const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_values));
const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_values + 8));
const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
const __m256i iacc_01 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_set1_epi32(a_values[0])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_set1_epi32(a_values[1])));
const __m256i iacc_11 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_set1_epi32(a_values[8])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_set1_epi32(a_values[9])));
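// Sketch of the shuffle/blend pairing used in these accumulations:
// imm 177 = 0b10110001 swaps adjacent dwords within each 128-bit lane, and
// blend imm 170 = 0b10101010 takes the odd dwords from the second operand,
// interleaving the two column groups as B0 B4 B1 B5 ... (values hypothetical):
#include <immintrin.h>
#include <stdio.h>
int main(void) {
    __m256i b0123 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
    __m256i b4567 = _mm256_setr_epi32(10, 11, 12, 13, 14, 15, 16, 17);
    int out[8];
    _mm256_storeu_si256((__m256i *)out,
        _mm256_blend_epi32(b0123, _mm256_shuffle_epi32(b4567, 177), 170));
    for (int i = 0; i < 8; i++) printf("%d ", out[i]); // 0 10 2 12 4 14 6 16
    printf("\n");
    return 0;
}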
__m256i iacc_0 = _mm256_setzero_si256();
__m256i iacc_1 = _mm256_setzero_si256();
const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_values + 16));
const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_values + 24));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
const __m256i iacc_02 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_set1_epi32(a_values[2])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_set1_epi32(a_values[3])));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
const __m256i iacc_12 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_set1_epi32(a_values[10])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_set1_epi32(a_values[11])));
iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_values + 32));
const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_values + 40));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
const __m256i iacc_03 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_set1_epi32(a_values[4])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_set1_epi32(a_values[5])));
const __m256i iacc_13 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_set1_epi32(a_values[12])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_set1_epi32(a_values[13])));
const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_values + 48));
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_values + 56));
const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);
__m256i iacc_04 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_set1_epi32(a_values[6])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_set1_epi32(a_values[7])));
__m256i iacc_14 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_set1_epi32(a_values[14])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_set1_epi32(a_values[15])));
iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);
// Accumulate the iacc value for one sb
__m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
__m256i iacc_sb = _mm256_add_epi32( _mm256_madd_epi16( _mm256_add_epi16(_mm256_add_epi16(iacc_01, iacc_02), _mm256_add_epi16(iacc_03, iacc_04)), scales_0),
_mm256_madd_epi16( _mm256_add_epi16(_mm256_add_epi16(iacc_11, iacc_12), _mm256_add_epi16(iacc_13, iacc_14)), scales_1));
// Broadcast the bsums of the two sub blocks of this iteration's Q8_K across the vector
// Multiply-add the corresponding mins of Q4_Kx8 with the bsums
__m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
__m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
q8s = _mm256_bsrli_epi128(q8s, 4);
const __m256i q8s_sb = _mm256_set1_epi32(q8s_ptr[sb]);
const __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
// Accumulate for the complete block
iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
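// Scalar reference (sketch, hypothetical names) for one Q4_K x Q8_K block:
// a dequantized weight is d*scale[j]*q4 - dmin*min[j], so the dot product
// splits into a scale-weighted sum plus a bsum-weighted min correction,
// mirroring the iacc_b / iacc_min_b pair accumulated above.
#include <stdint.h>
static float block_dot_sketch(float d, float dmin, float d8,
                              const uint8_t scales[8], const uint8_t mins[8],
                              const uint8_t q4[256], const int8_t q8[256],
                              const int16_t bsums[8]) { // one bsum per 32 elements, after the pairwise fold
    int32_t acc = 0, acc_min = 0;
    for (int j = 0; j < 8; j++) {
        int32_t sum = 0;
        for (int i = 0; i < 32; i++) {
            sum += q4[32 * j + i] * q8[32 * j + i];
        }
        acc     += (int32_t) scales[j] * sum;
        acc_min += (int32_t) mins[j] * bsums[j];
    }
    return d * d8 * (float) acc - dmin * d8 * (float) acc_min;
}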
@@ -1962,6 +1946,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
static const uint32_t kmask_3 = 0x30303030;
assert (n % qk == 0);
assert (nr % 4 == 0);
@@ -2764,6 +2749,8 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
// dmin values - Load the eight dmin values of block_q4_kx8
const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);
const uint32_t *utmp = (const uint32_t*) (b_ptr[b].scales);
// Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
for (int sb = 0; sb < QK_K / 64; sb++) {
@@ -2865,31 +2852,25 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)
uint32_t utmp_0[4], utmp_1[4];
// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
// For example, the below block extracts the first sub block's scales and mins from the different Q4_K structures for the sb loop
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp_0[1] & kmask1;
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
utmp_0[2] = uaux_0;
utmp_0[0] &= kmask1;
// For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
const uint32_t uaux_1 = utmp_1[1] & kmask1;
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
utmp_1[2] = uaux_1;
utmp_1[0] &= kmask1;
const uint32_t utmp_03 = ((utmp[2] >> 4) & kmask2) | ((utmp[1] >> 2) & kmask_3);
const uint32_t utmp_02 = utmp[1] & kmask1;
const uint32_t utmp_01 = (utmp[2] & kmask2) | ((utmp[0] >> 2) & kmask_3);
const uint32_t utmp_00 = utmp[0] & kmask1;
// Scales of first sub block in the sb loop
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
__m128i mins_and_scales_0 = _mm_set_epi32(utmp_03, utmp_02, utmp_01, utmp_00);
const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));
// For example, the below block extracts the second sub block's scales and mins from the different Q4_K structures for the sb loop
const uint32_t utmp_13 = ((utmp[5] >> 4) & kmask2) | ((utmp[4] >> 2) & kmask_3);
const uint32_t utmp_12 = utmp[4] & kmask1;
const uint32_t utmp_11 = (utmp[5] & kmask2) | ((utmp[3] >> 2) & kmask_3);
const uint32_t utmp_10 = utmp[3] & kmask1;
// Scales of second sub block in the sb loop
const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
__m128i mins_and_scales_1 = _mm_set_epi32(utmp_13, utmp_12, utmp_11, utmp_10);
const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));
// Mins of first and second sub block of Q4_K block are arranged side by side
@@ -2901,39 +2882,45 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);
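// Quick check (sketch) of the imm8 selectors above: per 128-bit lane,
// 68 = 0b01000100 replicates dwords {0,1} and 238 = 0b11101110 replicates
// dwords {2,3}, pairing each scale pair with its column group.
#include <immintrin.h>
#include <stdio.h>
int main(void) {
    __m256i v = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
    int lo[8], hi[8];
    _mm256_storeu_si256((__m256i *)lo, _mm256_shuffle_epi32(v, 68));  // 0 1 0 1 4 5 4 5
    _mm256_storeu_si256((__m256i *)hi, _mm256_shuffle_epi32(v, 238)); // 2 3 2 3 6 7 6 7
    for (int i = 0; i < 8; i++) printf("%d %d\n", lo[i], hi[i]);
    return 0;
}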
utmp += 6;
for (int rp = 0; rp < 4; rp++) {
// Load the quantized values of the four block_q8_K, interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
// Loaded as a set of 128-bit vectors and repeated into a 256-bit vector
__m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
__m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
__m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
__m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
__m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
__m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
__m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
__m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
__m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
__m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
__m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
__m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
__m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
__m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
__m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
__m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
__m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
__m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
__m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
__m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
__m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
__m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
__m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
__m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);
// Load the bsums - four bsums (covering the two sub blocks) from each of the different Q8_K blocks
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
__m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
__m256i lhs_mat_01_00 = _mm256_blend_epi32(lhs_mat_0123_00, _mm256_permutevar8x32_epi32(lhs_mat_0123_00, requiredOrder), 0xF0);
__m256i lhs_mat_23_00 = _mm256_blend_epi32(lhs_mat_0123_00, _mm256_permutevar8x32_epi32(lhs_mat_0123_00, requiredOrder), 0x0F);
__m256i lhs_mat_01_01 = _mm256_blend_epi32(lhs_mat_0123_01, _mm256_permutevar8x32_epi32(lhs_mat_0123_01, requiredOrder), 0xF0);
__m256i lhs_mat_23_01 = _mm256_blend_epi32(lhs_mat_0123_01, _mm256_permutevar8x32_epi32(lhs_mat_0123_01, requiredOrder), 0x0F);
__m256i lhs_mat_01_02 = _mm256_blend_epi32(lhs_mat_0123_02, _mm256_permutevar8x32_epi32(lhs_mat_0123_02, requiredOrder), 0xF0);
__m256i lhs_mat_23_02 = _mm256_blend_epi32(lhs_mat_0123_02, _mm256_permutevar8x32_epi32(lhs_mat_0123_02, requiredOrder), 0x0F);
__m256i lhs_mat_01_03 = _mm256_blend_epi32(lhs_mat_0123_03, _mm256_permutevar8x32_epi32(lhs_mat_0123_03, requiredOrder), 0xF0);
__m256i lhs_mat_23_03 = _mm256_blend_epi32(lhs_mat_0123_03, _mm256_permutevar8x32_epi32(lhs_mat_0123_03, requiredOrder), 0x0F);
__m256i lhs_mat_01_10 = _mm256_blend_epi32(lhs_mat_0123_10, _mm256_permutevar8x32_epi32(lhs_mat_0123_10, requiredOrder), 0xF0);
__m256i lhs_mat_23_10 = _mm256_blend_epi32(lhs_mat_0123_10, _mm256_permutevar8x32_epi32(lhs_mat_0123_10, requiredOrder), 0x0F);
__m256i lhs_mat_01_11 = _mm256_blend_epi32(lhs_mat_0123_11, _mm256_permutevar8x32_epi32(lhs_mat_0123_11, requiredOrder), 0xF0);
__m256i lhs_mat_23_11 = _mm256_blend_epi32(lhs_mat_0123_11, _mm256_permutevar8x32_epi32(lhs_mat_0123_11, requiredOrder), 0x0F);
__m256i lhs_mat_01_12 = _mm256_blend_epi32(lhs_mat_0123_12, _mm256_permutevar8x32_epi32(lhs_mat_0123_12, requiredOrder), 0xF0);
__m256i lhs_mat_23_12 = _mm256_blend_epi32(lhs_mat_0123_12, _mm256_permutevar8x32_epi32(lhs_mat_0123_12, requiredOrder), 0x0F);
__m256i lhs_mat_01_13 = _mm256_blend_epi32(lhs_mat_0123_13, _mm256_permutevar8x32_epi32(lhs_mat_0123_13, requiredOrder), 0xF0);
__m256i lhs_mat_23_13 = _mm256_blend_epi32(lhs_mat_0123_13, _mm256_permutevar8x32_epi32(lhs_mat_0123_13, requiredOrder), 0x0F);
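// Sketch of the permutevar8x32 + blend idiom above, which replaces the old
// vperm2f128-based half duplication (requiredOrder is defined earlier in the
// file; assumed here to be the 128-bit half swap _mm256_set_epi32(3,2,1,0,7,6,5,4)):
#include <immintrin.h>
static inline __m256i broadcast_low_half_sketch(__m256i v, __m256i requiredOrder) {
    __m256i swapped = _mm256_permutevar8x32_epi32(v, requiredOrder); // 128-bit halves exchanged
    return _mm256_blend_epi32(v, swapped, 0xF0); // low half from v, high half from swapped
}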
// Shuffle pattern one - left side input
const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
@@ -3051,6 +3038,11 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);
// Load the bsums - four bsums (covering the two sub blocks) from each of the different Q8_K blocks
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
__m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
__m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
__m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
__m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
@@ -3060,7 +3052,6 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);
}
}
}
@@ -3070,6 +3061,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
}
}
}
for (; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);