This commit is contained in:
David Friehs 2026-02-16 22:50:02 +01:00 committed by GitHub
commit 3adb7e4e8f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 46 additions and 8 deletions

View File

@ -2369,6 +2369,22 @@ static const int8_t keven_signs_q2xs[1024] = {
};
#endif
#if defined(__AVX2__)
// Replicates one byte value into all eight bytes of a 64-bit word.
#define KSIGNS_REP8(b) (0x0101010101010101ULL * (uint64_t)(b))

// Per-lane right-shift counts that bring each of the four 7-bit sign fields
// of an xxs sign word down to bit 0.
static const uint32_t ksigns_shift_xxs[4] = { 0 * 7, 1 * 7, 2 * 7, 3 * 7 };

// 16-entry byte lookup table for a byte shuffle: viewed as bytes in
// little-endian memory order, entry i is 0x80 when i has an odd number of
// set bits and 0x00 otherwise.
static const uint32_t ksigns_popc_odd[4] = {
    (0x80u << 8) | (0x80u << 16),  // entries  0..3 : odd popcount at 1, 2
    (0x80u << 0) | (0x80u << 24),  // entries  4..7 : odd popcount at 4, 7
    (0x80u << 0) | (0x80u << 24),  // entries  8..11: odd popcount at 8, 11
    (0x80u << 8) | (0x80u << 16),  // entries 12..15: odd popcount at 13, 14
};

// Byte-shuffle selectors: each 64-bit word names one source byte eight times,
// so the shuffle spreads that byte over eight consecutive output bytes.
// First table picks source bytes 0, 2, 4, 6.
static const uint64_t ksigns_bcast_1[4] = {
    KSIGNS_REP8(0x00), KSIGNS_REP8(0x02), KSIGNS_REP8(0x04), KSIGNS_REP8(0x06),
};
// Second table picks source bytes 8, 10, 12, 14.
static const uint64_t ksigns_bcast_2[4] = {
    KSIGNS_REP8(0x08), KSIGNS_REP8(0x0A), KSIGNS_REP8(0x0C), KSIGNS_REP8(0x0E),
};

#undef KSIGNS_REP8
#endif
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
@ -2384,11 +2400,16 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
#if defined(__AVX2__)
const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
const __m128i ks_shift = _mm_loadu_si128((const __m128i *)ksigns_shift_xxs);
const __m128i ks_mask = _mm_set1_epi32(0x7F);
const __m128i popc_odd = _mm_loadu_si128((const __m128i *)ksigns_popc_odd);
const __m256i ks_bc_1 = _mm256_loadu_si256((const __m256i *)ksigns_bcast_1);
const __m256i ks_bc_2 = _mm256_loadu_si256((const __m256i *)ksigns_bcast_2);
const __m256i ks_bsel = _mm256_set1_epi64x(0x8040201008040201LL);
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
@ -2402,12 +2423,29 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
__m128i s_l = _mm_set1_epi32(aux32[1]);
__m128i s_h = _mm_set1_epi32(aux32[3]);
// shift each value to their offset, then zero out garbage
s_l = _mm_srlv_epi32(s_l, ks_shift);
s_h = _mm_srlv_epi32(s_h, ks_shift);
s_l = _mm_and_si128(s_l, ks_mask);
s_h = _mm_and_si128(s_h, ks_mask);
// pack, count bits via xor+lut, correct bit 8
__m128i signs_128 = _mm_packus_epi32(s_l, s_h);
const __m128i cnt4 = _mm_xor_si128(_mm_srli_epi16(signs_128, 4), signs_128);
const __m128i popc = _mm_shuffle_epi8(popc_odd, cnt4);
signs_128 = _mm_or_si128(signs_128, popc);
// expand to 256 bits, then broadcast to 8 bytes each
__m256i signs_256 = _mm256_broadcastsi128_si256(signs_128);
const __m256i s1_b = _mm256_shuffle_epi8(signs_256, ks_bc_1);
const __m256i s2_b = _mm256_shuffle_epi8(signs_256, ks_bc_2);
// set 0xFF in bytes that contain bit, then invert via xor+sub
const __m256i s1 = _mm256_cmpeq_epi8(_mm256_and_si256(s1_b, ks_bsel), ks_bsel);
const __m256i s2 = _mm256_cmpeq_epi8(_mm256_and_si256(s2_b, ks_bsel), ks_bsel);
const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(q8_1, s1), s1);
const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(q8_2, s2), s2);
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
const uint16_t ls1 = aux32[1] >> 28;