Merge 0f2d806829 into 18ddaea2ae
commit 9dc07a8b72
@@ -1396,7 +1396,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
const int blocklen = 8;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
static const uint32_t kmask3 = 0x30303030;

assert (n % qk == 0);
assert (nc % ncols_interleaved == 0);
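Note (added for review, not part of the diff): the new gemv path replaces the old ((x >> 6) & kmask3) << 4 extraction with (x >> 2) & kmask3 and changes kmask3 from 0x03030303 to 0x30303030. The two forms are bit-for-bit identical, because shifting the mask left by 4 together with the value preserves the result. A minimal standalone C check of that identity (the sweep and driver below are illustrative only):

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        // Sampled sweep over 32-bit inputs: both forms place the top 2 bits of each
        // packed 6-bit scale/min into bits 4..5 of the corresponding byte lane.
        for (uint64_t x = 0; x <= 0xFFFFFFFFull; x += 0x10001ull) {
            const uint32_t v        = (uint32_t) x;
            const uint32_t old_form = ((v >> 6) & 0x03030303u) << 4;   // old kmask3 usage
            const uint32_t new_form = (v >> 2) & 0x30303030u;          // new kmask3 usage
            assert(old_form == new_form);
        }
        return 0;
    }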
@@ -1412,17 +1412,14 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
UNUSED(blocklen);

#if defined(__AVX2__)
// Lookup table to convert signed nibbles to signed bytes
__m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0));
signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
// Shuffle masks to rearrange delta and scale values to multiply with appropriate scales
__m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
__m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
static const __m128i deltamask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0);
static const __m128i scalemask = _mm_set_epi8(7, 7, 3, 3, 6, 6, 2, 2, 5, 5, 1, 1, 4, 4, 0, 0);
// Permute mask used for easier vector processing at later stages
__m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
static const __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);

// Mask to extract nibbles from bytes
const __m256i m4b = _mm256_set1_epi8(0x0F);
static const __m256i m4b = _mm256_set1_epi8(0x0F);

int64_t b_nb = n / QK_K;
@@ -1459,133 +1456,120 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
__m256i iacc_min_b = _mm256_setzero_si256();

const __m256i q8sums = _mm256_loadu_si256((const __m256i * )(a_ptr[b].bsums));
__m256i q8s = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1)));
q8s = _mm256_permute2f128_si256(q8s, q8s, 0);
__m128i q8s = _mm_hadd_epi16(_mm256_castsi256_si128(q8sums), _mm256_extracti128_si256(q8sums, 1));
const uint32_t* q8s_ptr = (uint32_t*) &q8s;

const int sbCount = QK_K / 64;

const uint32_t *utmp = (const uint32_t*) (b_ptr[b].scales);

// Processes two sub blocks from each Q4_K in each iteration
for (int sb = 0; sb < QK_K / 64; sb++) {

// Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + sb * 256));
const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 32 + sb * 256));
const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 64 + sb * 256));
const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 96 + sb * 256));
const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 128 + sb * 256));
const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 160 + sb * 256));
const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 192 + sb * 256));
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr[b].qs + 224 + sb * 256));

// 4-bit -> 8-bit
// Values of the first sub block of eight block_q4_K structures for the sb loop
const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);

// Values of the second sub block of eight block_q4_K structures when sb = 1
const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);
const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);
const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);
const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);
const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);
const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);
const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);
const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);

uint32_t utmp_0[4], utmp_1[4];
for (int sb = 0; sb < sbCount; sb++) {

// Scales and Mins of corresponding sub blocks from different Q8_K structures are stored together
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp_0[1] & kmask1;
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
utmp_0[2] = uaux_0;
utmp_0[0] &= kmask1;

// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
const uint32_t uaux_1 = utmp_1[1] & kmask1;
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
utmp_1[2] = uaux_1;
utmp_1[0] &= kmask1;
const uint32_t utmp_03 = ((utmp[2] >> 4) & kmask2) | ((utmp[1] >> 2) & kmask3);
const uint32_t utmp_02 = utmp[1] & kmask1;
const uint32_t utmp_01 = (utmp[2] & kmask2) | ((utmp[0] >> 2) & kmask3);
const uint32_t utmp_00 = utmp[0] & kmask1;

// Scales of first sub block in the sb loop
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
__m128i mins_and_scales_0 = _mm_set_epi32(utmp_03, utmp_02, utmp_01, utmp_00);
__m128i scales_rearrange_0 = _mm_shuffle_epi8(mins_and_scales_0, scalemask);
__m256i scales_0 = _mm256_cvtepu8_epi16(scales_rearrange_0);

// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
const uint32_t utmp_13 = ((utmp[5] >> 4) & kmask2) | ((utmp[4] >> 2) & kmask3);
const uint32_t utmp_12 = utmp[4] & kmask1;
const uint32_t utmp_11 = (utmp[5] & kmask2) | ((utmp[3] >> 2) & kmask3);
const uint32_t utmp_10 = utmp[3] & kmask1;

// Scales of second sub block in the sb loop
__m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
__m128i mins_and_scales_1 = _mm_set_epi32(utmp_13, utmp_12, utmp_11, utmp_10);
__m128i scales_rearrange_1 = _mm_shuffle_epi8(mins_and_scales_1, scalemask);
__m256i scales_1 = _mm256_cvtepu8_epi16(scales_rearrange_1);

// Mins of first and second sub block of Q4_K block are arranged side by side
__m256i mins_01 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(_mm_shuffle_epi32(mins_and_scales_0, 78), _mm_shuffle_epi32(mins_and_scales_1, 78)));

// Load the two sub block values corresponding to sb in block_q8_K in batches of 16 bytes and replicate the same across 256 bit vector
__m256i lhs_vec_00 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + sb * 64)));
__m256i lhs_vec_01 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 16 + sb * 64)));
__m256i lhs_vec_10 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 32 + sb * 64)));
__m256i lhs_vec_11 = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(a_ptr[b].qs + 48 + sb * 64)));
utmp += 6;

lhs_vec_00 = _mm256_permute2f128_si256(lhs_vec_00, lhs_vec_00, 0);
lhs_vec_01 = _mm256_permute2f128_si256(lhs_vec_01, lhs_vec_01, 0);
lhs_vec_10 = _mm256_permute2f128_si256(lhs_vec_10, lhs_vec_10, 0);
lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0);
const uint32_t* a_values = (const uint32_t*) (a_ptr[b].qs + sb * 64);
const uint32_t* b_values = (const uint32_t*) (b_ptr[b].qs + sb * 256);
// Dot product done within 32 bit lanes and accumulated in the same vector
// First done for first sub block and then for second sub block in each sb
// B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3)
// B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7)
// ...........................................................................
// B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31)
// Load the eight block_q4_K for two sub blocks quantized values interleaved with each other in chunks of eight - B0,B1 ....B6,B7
const __m256i rhs_raw_vec_0123_0 = _mm256_loadu_si256((const __m256i * )(b_values));
const __m256i rhs_raw_vec_4567_0 = _mm256_loadu_si256((const __m256i * )(b_values + 8));

const __m256i rhs_vec_0123_00 = _mm256_and_si256(rhs_raw_vec_0123_0, m4b);
const __m256i rhs_vec_0123_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_0, 4), m4b);

const __m256i rhs_vec_4567_00 = _mm256_and_si256(rhs_raw_vec_4567_0, m4b);
const __m256i rhs_vec_4567_10 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_0, 4), m4b);

const __m256i iacc_01 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_set1_epi32(a_values[0])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_set1_epi32(a_values[1])));

const __m256i iacc_11 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_set1_epi32(a_values[8])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_set1_epi32(a_values[9])));

__m256i iacc_0 = _mm256_setzero_si256();
__m256i iacc_1 = _mm256_setzero_si256();
const __m256i rhs_raw_vec_0123_1 = _mm256_loadu_si256((const __m256i * )(b_values + 16));
const __m256i rhs_raw_vec_4567_1 = _mm256_loadu_si256((const __m256i * )(b_values + 24));

iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_00 ,_mm256_shuffle_epi32(rhs_vec_4567_00, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 0)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_00, 177) ,rhs_vec_4567_00, 170), _mm256_shuffle_epi32(lhs_vec_00, 85)));
const __m256i rhs_vec_0123_01 = _mm256_and_si256(rhs_raw_vec_0123_1, m4b);
const __m256i rhs_vec_0123_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_1, 4), m4b);

iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_shuffle_epi32(lhs_vec_00, 170)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_shuffle_epi32(lhs_vec_00, 255)));
const __m256i rhs_vec_4567_01 = _mm256_and_si256(rhs_raw_vec_4567_1, m4b);
const __m256i rhs_vec_4567_11 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_1, 4), m4b);

iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 0)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_shuffle_epi32(lhs_vec_01, 85)));
const __m256i iacc_02 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_01 ,_mm256_shuffle_epi32(rhs_vec_4567_01, 177), 170), _mm256_set1_epi32(a_values[2])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_01, 177) ,rhs_vec_4567_01, 170), _mm256_set1_epi32(a_values[3])));

iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_shuffle_epi32(lhs_vec_01, 170)));
iacc_0 = _mm256_add_epi16(iacc_0, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_shuffle_epi32(lhs_vec_01, 255)));
const __m256i iacc_12 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_set1_epi32(a_values[10])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_set1_epi32(a_values[11])));

iacc_0 = _mm256_madd_epi16(iacc_0, scales_0);

iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_10 ,_mm256_shuffle_epi32(rhs_vec_4567_10, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 0)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_10, 177) ,rhs_vec_4567_10, 170), _mm256_shuffle_epi32(lhs_vec_10, 85)));
const __m256i rhs_raw_vec_0123_2 = _mm256_loadu_si256((const __m256i * )(b_values + 32));
const __m256i rhs_raw_vec_4567_2 = _mm256_loadu_si256((const __m256i * )(b_values + 40));

iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_11 ,_mm256_shuffle_epi32(rhs_vec_4567_11, 177), 170), _mm256_shuffle_epi32(lhs_vec_10, 170)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_11, 177) ,rhs_vec_4567_11, 170), _mm256_shuffle_epi32(lhs_vec_10, 255)));
const __m256i rhs_vec_0123_02 = _mm256_and_si256(rhs_raw_vec_0123_2, m4b);
const __m256i rhs_vec_0123_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_2, 4), m4b);

iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 0)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_shuffle_epi32(lhs_vec_11, 85)));
const __m256i rhs_vec_4567_02 = _mm256_and_si256(rhs_raw_vec_4567_2, m4b);
const __m256i rhs_vec_4567_12 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_2, 4), m4b);

iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_shuffle_epi32(lhs_vec_11, 170)));
iacc_1 = _mm256_add_epi16(iacc_1, _mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_shuffle_epi32(lhs_vec_11, 255)));
const __m256i iacc_03 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_02 ,_mm256_shuffle_epi32(rhs_vec_4567_02, 177), 170), _mm256_set1_epi32(a_values[4])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_02, 177) ,rhs_vec_4567_02, 170), _mm256_set1_epi32(a_values[5])));

const __m256i iacc_13 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_12 ,_mm256_shuffle_epi32(rhs_vec_4567_12, 177), 170), _mm256_set1_epi32(a_values[12])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_12, 177) ,rhs_vec_4567_12, 170), _mm256_set1_epi32(a_values[13])));

const __m256i rhs_raw_vec_0123_3 = _mm256_loadu_si256((const __m256i * )(b_values + 48));
const __m256i rhs_raw_vec_4567_3 = _mm256_loadu_si256((const __m256i * )(b_values + 56));

const __m256i rhs_vec_0123_03 = _mm256_and_si256(rhs_raw_vec_0123_3, m4b);
const __m256i rhs_vec_0123_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_0123_3, 4), m4b);

const __m256i rhs_vec_4567_03 = _mm256_and_si256(rhs_raw_vec_4567_3, m4b);
const __m256i rhs_vec_4567_13 = _mm256_and_si256(_mm256_srli_epi16(rhs_raw_vec_4567_3, 4), m4b);

__m256i iacc_04 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_03 ,_mm256_shuffle_epi32(rhs_vec_4567_03, 177), 170), _mm256_set1_epi32(a_values[6])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_03, 177) ,rhs_vec_4567_03, 170), _mm256_set1_epi32(a_values[7])));

__m256i iacc_14 = _mm256_add_epi16( _mm256_maddubs_epi16(_mm256_blend_epi32(rhs_vec_0123_13 ,_mm256_shuffle_epi32(rhs_vec_4567_13, 177), 170), _mm256_set1_epi32(a_values[14])),
_mm256_maddubs_epi16(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_13, 177) ,rhs_vec_4567_13, 170), _mm256_set1_epi32(a_values[15])));

iacc_1 = _mm256_madd_epi16(iacc_1, scales_1);

// Accumulate the iacc value for one sb
__m256i iacc_sb = _mm256_add_epi32(iacc_0, iacc_1);
__m256i iacc_sb = _mm256_add_epi32( _mm256_madd_epi16( _mm256_add_epi16(_mm256_add_epi16(iacc_01, iacc_02), _mm256_add_epi16(iacc_03, iacc_04)), scales_0),
_mm256_madd_epi16( _mm256_add_epi16(_mm256_add_epi16(iacc_11, iacc_12), _mm256_add_epi16(iacc_13, iacc_14)), scales_1));

// Broadcast the bsums of the two sub blocks of the iteration of Q8_K across the vector
// Multiply-Add with corresponding mins of Q4_Kx8 with bsums
__m256i q8s_sb = _mm256_shuffle_epi32(q8s, 0);
__m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);
q8s = _mm256_bsrli_epi128(q8s, 4);
const __m256i q8s_sb = _mm256_set1_epi32(q8s_ptr[sb]);
const __m256i iacc_min_sb = _mm256_madd_epi16(q8s_sb, mins_01);

// Accumulate for the complete block
iacc_b = _mm256_add_epi32(iacc_b, iacc_sb);
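Note (added for review, not part of the diff): stripped of the x8 column interleaving and the AVX2 shuffles, each (super block, sub block) step above computes the standard Q4_K x Q8_K dot product: per 32-value sub block, an unsigned 4-bit by signed 8-bit dot product weighted by the 6-bit scale, minus the 6-bit min times the sub block's bsum, all scaled by the super block deltas. A scalar sketch over already-unpacked arrays (the unpacked-argument signature is an assumption for illustration, not an API in this file):

    #include <stdint.h>

    // One Q4_K super block (256 values = 8 sub blocks of 32) against one Q8_K block.
    // q4, sc and mn are assumed to be already unpacked from their 4-bit / 6-bit packing.
    static float q4_K_q8_K_dot_scalar(const uint8_t q4[256],
                                      const uint8_t sc[8], const uint8_t mn[8],
                                      float d, float dmin,
                                      const int8_t q8[256], float d8,
                                      const int16_t bsums[16]) {   // sums of q8 per 16 values
        int32_t sumi = 0;
        int32_t summ = 0;
        for (int j = 0; j < 8; j++) {
            int32_t dot = 0;
            for (int i = 0; i < 32; i++) {
                dot += (int32_t) q4[32*j + i] * (int32_t) q8[32*j + i];
            }
            // Two 16-wide bsums make one 32-wide sub block sum; this pairing is what
            // the _mm_hadd_epi16 over a_ptr[b].bsums performs in the vector code.
            const int32_t bsum = bsums[2*j] + bsums[2*j + 1];
            sumi += sc[j] * dot;    // scales_0 / scales_1 weighting of the iacc terms
            summ += mn[j] * bsum;   // mins_01 weighting of the q8s terms
        }
        return d * d8 * (float) sumi - dmin * d8 * (float) summ;
    }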
@@ -1962,6 +1946,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
static const uint32_t kmask_3 = 0x30303030;

assert (n % qk == 0);
assert (nr % 4 == 0);
@@ -2764,6 +2749,8 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
// dmin values - Load the eight dmin values of block_q4_kx8
const __m256 col_dmin_f32 = GGML_F32Cx8_LOAD(b_ptr[b].dmin);

const uint32_t *utmp = (const uint32_t*) (b_ptr[b].scales);

// Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration
for (int sb = 0; sb < QK_K / 64; sb++) {
@@ -2865,31 +2852,25 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
const __m256i rhs_mat_0145_13_sp2 = _mm256_shuffle_epi32(rhs_mat_0145_13, 221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31)
const __m256i rhs_mat_2367_13_sp2 = _mm256_shuffle_epi32(rhs_mat_2367_13, 221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31)

uint32_t utmp_0[4], utmp_1[4];

// Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together
// Scales and Mins of corresponding sub blocks from different Q8_K structures are stored together
// The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop
memcpy(utmp_0, b_ptr[b].scales + 24 * sb, 12);
utmp_0[3] = ((utmp_0[2] >> 4) & kmask2) | (((utmp_0[1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp_0[1] & kmask1;
utmp_0[1] = (utmp_0[2] & kmask2) | (((utmp_0[0] >> 6) & kmask3) << 4);
utmp_0[2] = uaux_0;
utmp_0[0] &= kmask1;

// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
memcpy(utmp_1, b_ptr[b].scales + 12 + sb * 24, 12);
utmp_1[3] = ((utmp_1[2] >> 4) & kmask2) | (((utmp_1[1] >> 6) & kmask3) << 4);
const uint32_t uaux_1 = utmp_1[1] & kmask1;
utmp_1[1] = (utmp_1[2] & kmask2) | (((utmp_1[0] >> 6) & kmask3) << 4);
utmp_1[2] = uaux_1;
utmp_1[0] &= kmask1;
const uint32_t utmp_03 = ((utmp[2] >> 4) & kmask2) | ((utmp[1] >> 2) & kmask_3);
const uint32_t utmp_02 = utmp[1] & kmask1;
const uint32_t utmp_01 = (utmp[2] & kmask2) | ((utmp[0] >> 2) & kmask_3);
const uint32_t utmp_00 = utmp[0] & kmask1;

// Scales of first sub block in the sb loop
const __m128i mins_and_scales_0 = _mm_set_epi32(utmp_0[3], utmp_0[2], utmp_0[1], utmp_0[0]);
__m128i mins_and_scales_0 = _mm_set_epi32(utmp_03, utmp_02, utmp_01, utmp_00);
const __m256i scales_0 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0));

// The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop
const uint32_t utmp_13 = ((utmp[5] >> 4) & kmask2) | ((utmp[4] >> 2) & kmask_3);
const uint32_t utmp_12 = utmp[4] & kmask1;
const uint32_t utmp_11 = (utmp[5] & kmask2) | ((utmp[3] >> 2) & kmask_3);
const uint32_t utmp_10 = utmp[3] & kmask1;

// Scales of second sub block in the sb loop
const __m128i mins_and_scales_1 = _mm_set_epi32(utmp_1[3], utmp_1[2], utmp_1[1], utmp_1[0]);
__m128i mins_and_scales_1 = _mm_set_epi32(utmp_13, utmp_12, utmp_11, utmp_10);
const __m256i scales_1 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1));

// Mins of first and second sub block of Q4_K block are arranged side by side
@@ -2901,39 +2882,45 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
const __m256i scale_0145_1 = _mm256_shuffle_epi32(scales_1, 68);
const __m256i scale_2367_1 = _mm256_shuffle_epi32(scales_1, 238);

utmp += 6;

for (int rp = 0; rp < 4; rp++) {

// Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3
// Loaded as set of 128 bit vectors and repeated into a 256 bit vector
__m256i lhs_mat_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb)));
__m256i lhs_mat_01_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 0);
__m256i lhs_mat_23_00 = _mm256_permute2f128_si256(lhs_mat_0123_00, lhs_mat_0123_00, 17);
__m256i lhs_mat_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb)));
__m256i lhs_mat_01_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 0);
__m256i lhs_mat_23_01 = _mm256_permute2f128_si256(lhs_mat_0123_01, lhs_mat_0123_01, 17);
__m256i lhs_mat_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb)));
__m256i lhs_mat_01_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 0);
__m256i lhs_mat_23_02 = _mm256_permute2f128_si256(lhs_mat_0123_02, lhs_mat_0123_02, 17);
__m256i lhs_mat_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb)));
__m256i lhs_mat_01_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 0);
__m256i lhs_mat_23_03 = _mm256_permute2f128_si256(lhs_mat_0123_03, lhs_mat_0123_03, 17);
__m256i lhs_mat_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb)));
__m256i lhs_mat_01_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 0);
__m256i lhs_mat_23_10 = _mm256_permute2f128_si256(lhs_mat_0123_10, lhs_mat_0123_10, 17);
__m256i lhs_mat_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb)));
__m256i lhs_mat_01_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 0);
__m256i lhs_mat_23_11 = _mm256_permute2f128_si256(lhs_mat_0123_11, lhs_mat_0123_11, 17);
__m256i lhs_mat_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb)));
__m256i lhs_mat_01_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 0);
__m256i lhs_mat_23_12 = _mm256_permute2f128_si256(lhs_mat_0123_12, lhs_mat_0123_12, 17);
__m256i lhs_mat_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb)));
__m256i lhs_mat_01_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 0);
__m256i lhs_mat_23_13 = _mm256_permute2f128_si256(lhs_mat_0123_13, lhs_mat_0123_13, 17);

// Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
__m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);
__m256i lhs_mat_01_00 = _mm256_blend_epi32(lhs_mat_0123_00, _mm256_permutevar8x32_epi32(lhs_mat_0123_00, requiredOrder), 0xF0);
__m256i lhs_mat_23_00 = _mm256_blend_epi32(lhs_mat_0123_00, _mm256_permutevar8x32_epi32(lhs_mat_0123_00, requiredOrder), 0x0F);

__m256i lhs_mat_01_01 = _mm256_blend_epi32(lhs_mat_0123_01, _mm256_permutevar8x32_epi32(lhs_mat_0123_01, requiredOrder), 0xF0);
__m256i lhs_mat_23_01 = _mm256_blend_epi32(lhs_mat_0123_01, _mm256_permutevar8x32_epi32(lhs_mat_0123_01, requiredOrder), 0x0F);

__m256i lhs_mat_01_02 = _mm256_blend_epi32(lhs_mat_0123_02, _mm256_permutevar8x32_epi32(lhs_mat_0123_02, requiredOrder), 0xF0);
__m256i lhs_mat_23_02 = _mm256_blend_epi32(lhs_mat_0123_02, _mm256_permutevar8x32_epi32(lhs_mat_0123_02, requiredOrder), 0x0F);

__m256i lhs_mat_01_03 = _mm256_blend_epi32(lhs_mat_0123_03, _mm256_permutevar8x32_epi32(lhs_mat_0123_03, requiredOrder), 0xF0);
__m256i lhs_mat_23_03 = _mm256_blend_epi32(lhs_mat_0123_03, _mm256_permutevar8x32_epi32(lhs_mat_0123_03, requiredOrder), 0x0F);

__m256i lhs_mat_01_10 = _mm256_blend_epi32(lhs_mat_0123_10, _mm256_permutevar8x32_epi32(lhs_mat_0123_10, requiredOrder), 0xF0);
__m256i lhs_mat_23_10 = _mm256_blend_epi32(lhs_mat_0123_10, _mm256_permutevar8x32_epi32(lhs_mat_0123_10, requiredOrder), 0x0F);

__m256i lhs_mat_01_11 = _mm256_blend_epi32(lhs_mat_0123_11, _mm256_permutevar8x32_epi32(lhs_mat_0123_11, requiredOrder), 0xF0);
__m256i lhs_mat_23_11 = _mm256_blend_epi32(lhs_mat_0123_11, _mm256_permutevar8x32_epi32(lhs_mat_0123_11, requiredOrder), 0x0F);

__m256i lhs_mat_01_12 = _mm256_blend_epi32(lhs_mat_0123_12, _mm256_permutevar8x32_epi32(lhs_mat_0123_12, requiredOrder), 0xF0);
__m256i lhs_mat_23_12 = _mm256_blend_epi32(lhs_mat_0123_12, _mm256_permutevar8x32_epi32(lhs_mat_0123_12, requiredOrder), 0x0F);

__m256i lhs_mat_01_13 = _mm256_blend_epi32(lhs_mat_0123_13, _mm256_permutevar8x32_epi32(lhs_mat_0123_13, requiredOrder), 0xF0);
__m256i lhs_mat_23_13 = _mm256_blend_epi32(lhs_mat_0123_13, _mm256_permutevar8x32_epi32(lhs_mat_0123_13, requiredOrder), 0x0F);

// Shuffle pattern one - left side input
const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)
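Note (added for review, not part of the diff): the blend/permutevar pairs introduced in this hunk reproduce the 128-bit-half duplication the removed _mm256_permute2f128_si256 calls performed, provided requiredOrder is the half-swapping index vector (assumed here to be _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); it is defined outside this excerpt). A standalone sketch of the equivalence for the 0xF0 case:

    #include <immintrin.h>
    #include <stdio.h>
    #include <string.h>

    int main(void) {
        const __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); // swaps the two 128-bit halves
        const __m256i v             = _mm256_set_epi32(8, 7, 6, 5, 4, 3, 2, 1);

        // Old form: both 128-bit halves of the result hold the low half of v.
        const __m256i lo_old = _mm256_permute2f128_si256(v, v, 0);
        // New form: lanes 0..3 come from v, lanes 4..7 from the half-swapped copy.
        const __m256i lo_new = _mm256_blend_epi32(v, _mm256_permutevar8x32_epi32(v, requiredOrder), 0xF0);

        int32_t a[8], b[8];
        _mm256_storeu_si256((__m256i *) a, lo_old);
        _mm256_storeu_si256((__m256i *) b, lo_new);
        printf("halves match: %d\n", memcmp(a, b, sizeof(a)) == 0);
        return 0;
    }

With blend immediate 0x0F the same construction keeps the high half in both positions, matching the old _mm256_permute2f128_si256(v, v, 17).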
@@ -3051,6 +3038,11 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]);
acc_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_3), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]);

// Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks
__m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb)));
__m256i lhs_bsums_hsum_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1)));
lhs_bsums_hsum_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_0123_01, lhs_bsums_hsum_0123_01, 0);

__m256i iacc_row_min_0 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 0), mins_01);
__m256i iacc_row_min_1 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 85), mins_01);
__m256i iacc_row_min_2 = _mm256_madd_epi16(_mm256_shuffle_epi32(lhs_bsums_hsum_0123_01, 170), mins_01);
@@ -3060,7 +3052,6 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
acc_min_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_1), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]);
acc_min_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_2), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]);
acc_min_rows[rp * 4 + 3] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_min_3), _mm256_mul_ps(col_dmin_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]);

}
}
}
@@ -3070,6 +3061,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
}
}
}

for (; y < nr / 4; y++) {

const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb);