diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c
index 74d699f633..cf3f2133ae 100644
--- a/ggml/src/ggml-cpu/arch/x86/quants.c
+++ b/ggml/src/ggml-cpu/arch/x86/quants.c
@@ -2369,6 +2369,35 @@ static const int8_t keven_signs_q2xs[1024] = {
 };
 #endif
 
+#if defined(__AVX2__)
+static inline __m256i unpack_ksigns(const uint32_t packed) {
+    __m128i x = _mm_set1_epi32(packed);
+
+    // shift each 32-bit lane right by its 7-bit field offset. bits 0-6 are ok, 7-31 are garbage
+    const __m128i shifts = _mm_setr_epi32(0, 7, 14, 21);
+    x = _mm_srlv_epi32(x, shifts);
+
+    // plut has 0x80 at nibble indices with odd popcount, 0x00 at indices with even popcount
+    const __m128i mask = _mm_set1_epi32(0x0F);
+    const __m128i plut = _mm_setr_epi32(0x00808000, 0x80000080, 0x80000080, 0x00808000);
+
+    const __m128i p_l = _mm_shuffle_epi8(plut, _mm_and_si128(x, mask));
+    const __m128i p_h = _mm_shuffle_epi8(plut, _mm_and_si128(_mm_srli_epi32(x, 4), mask));
+
+    // correct bit 7 via xor: it becomes the parity of bits 0-6, i.e. the implied 8th sign. bits 0-7 now ok, 8-31 still garbage
+    x = _mm_xor_si128(x, p_l);
+    x = _mm_xor_si128(x, p_h);
+
+    // expand to __m256i, broadcast bytes 0, 4, 8, 12 across the four 64-bit lanes
+    const __m256i shf = _mm256_setr_epi64x(0x0000000000000000LL, 0x0404040404040404LL,
+                                           0x0808080808080808LL, 0x0C0C0C0C0C0C0C0CLL);
+    const __m256i y = _mm256_shuffle_epi8(_mm256_broadcastsi128_si256(x), shf);
+
+    const __m256i sel = _mm256_set1_epi64x(0x8040201008040201LL);
+    return _mm256_cmpeq_epi8(_mm256_and_si256(y, sel), sel); // byte j is 0xFF iff sign bit j is set
+}
+#endif
+
 void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
@@ -2384,8 +2413,6 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
 
 #if defined(__AVX2__)
 
-    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
-
     uint32_t aux32[4];
     const uint8_t * aux8 = (const uint8_t *)aux32;
 
@@ -2402,12 +2429,11 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
             memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
             const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
             const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
-            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
-                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
-            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
-                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
-            const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
-            const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
+            const __m256i s2_1 = unpack_ksigns(aux32[1]);
+            const __m256i s2_2 = unpack_ksigns(aux32[3]);
+            const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(q8_1, s2_1), s2_1);
+            const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(q8_2, s2_2), s2_2);
+
             const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
             const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
             const uint16_t ls1 = aux32[1] >> 28;
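
For reviewers: the new helper replaces the 1 KiB `keven_signs_q2xs` lookup with an in-register computation. Below is a scalar sketch of the semantics as I read the patch; `unpack_ksigns_ref` is a hypothetical name used only for illustration and is not part of the change.

```c
#include <stdint.h>

// Hypothetical scalar reference for unpack_ksigns (illustration only, not
// part of the patch). packed holds four 7-bit sign fields in bits 0-27;
// byte 8*k+j of the output is 0xFF if element j of group k must be negated.
static void unpack_ksigns_ref(uint32_t packed, uint8_t out[32]) {
    for (int k = 0; k < 4; ++k) {
        uint32_t field = (packed >> (7*k)) & 127;  // 7 explicit sign bits
        // iq2_xxs stores only 7 signs per group of 8; the 8th is implied so
        // that the total number of negative signs is even, which is the
        // convention encoded in the keven_signs_q2xs table.
        uint32_t eighth = (uint32_t)(__builtin_popcount(field) & 1) << 7;
        uint32_t signs  = field | eighth;
        for (int j = 0; j < 8; ++j) {
            out[8*k + j] = (signs >> j) & 1 ? 0xFF : 0x00;
        }
    }
}
```

The xor/sub pair in the hot loop replaces `_mm256_sign_epi8` because the new sign vector is a 0x00/0xFF byte mask rather than the ±1 bytes stored in the old table: `(x ^ 0xFF) - 0xFF` is two's-complement negation and `(x ^ 0x00) - 0x00` is a no-op, so the result is identical (the old table never produces a zero sign byte, so the zeroing behavior of `_mm256_sign_epi8` is never exercised).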