avx2: inline unpack_ksigns, fix up 8 sign bytes at once

This commit is contained in:
David Friehs 2026-02-16 06:46:45 +01:00
parent bd7b45e165
commit 7fe317f662
1 changed files with 36 additions and 32 deletions

View File

@ -2370,34 +2370,18 @@ static const int8_t keven_signs_q2xs[1024] = {
#endif
#if defined(__AVX2__)
static inline __m256i unpack_ksigns(const uint32_t packed) {
__m128i x = _mm_set1_epi32(packed);
// shift each 32 bit value to their offset. bits 0-6 are ok, 7-31 are garbage
const __m128i shifts = _mm_setr_epi32(0, 7, 14, 21);
x = _mm_srlv_epi32(x, shifts);
// popc_odd has 0x80 at locations that have odd bitcount, 0x00 at even bitcount
const __m128i mask_nib = _mm_set1_epi32(0x0F);
const __m128i popc_odd = _mm_setr_epi32(0x00808000, 0x80000080, 0x80000080, 0x00808000);
// xor bit 4-7 into the lower bit 0-3. this does not change if the set bit count is odd
__m128i p = _mm_srli_epi32(x, 4);
p = _mm_xor_si128(p, x);
p = _mm_and_si128(p, mask_nib);
p = _mm_shuffle_epi8(popc_odd, p);
// correct bit 7 via xor. bits 0-7 now ok, 8-31 still garbage
x = _mm_xor_si128(x, p);
// expand to __m256i, broadcast bytes 0, 4, 8, 12
const __m256i shf = _mm256_setr_epi64x(0x0000000000000000LL, 0x0404040404040404LL,
0x0808080808080808LL, 0x0C0C0C0C0C0C0C0CLL);
const __m256i y = _mm256_shuffle_epi8(_mm256_broadcastsi128_si256(x), shf);
const __m256i sel = _mm256_set1_epi64x(0x8040201008040201LL);
return _mm256_cmpeq_epi8(_mm256_and_si256(y, sel), sel);
}
// shifts to 7 bit signs in xxs quantizations
static const uint32_t ksigns_shift_xxs[8] = {0, 7, 14, 21, 0, 7, 14, 21};
// for _mm256_shuffle_epi8, has 0x80 at indices that are encoded with odd bit counts
static const uint32_t ksigns_popc_odd[8] = {
0x00808000, 0x80000080, 0x80000080, 0x00808000,
0x00808000, 0x80000080, 0x80000080, 0x00808000,
};
// for _mm256_shuffle_epi8, broadcasts bytes 0, 4, 8, 12
static const uint64_t ksigns_bcast_xxs[4] = {
0x0000000000000000ULL, 0x0404040404040404ULL,
0x0808080808080808ULL, 0x0C0C0C0C0C0C0C0CULL,
};
#endif
void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
@ -2418,6 +2402,12 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;
const __m256i ks_shift = _mm256_loadu_si256((const __m256i *)ksigns_shift_xxs);
const __m256i ks_bcast = _mm256_loadu_si256((const __m256i *)ksigns_bcast_xxs);
const __m256i popc_odd = _mm256_loadu_si256((const __m256i *)ksigns_popc_odd);
const __m256i mask_nib = _mm256_set1_epi32(0x0F);
const __m256i ks_bsel = _mm256_set1_epi64x(0x8040201008040201LL);
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
@ -2431,10 +2421,24 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
const __m256i s2_1 = unpack_ksigns(aux32[1]);
const __m256i s2_2 = unpack_ksigns(aux32[3]);
const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(q8_1, s2_1), s2_1);
const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(q8_2, s2_2), s2_2);
const __m256i s_raw = MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1]));
// shift each value to their offset. bits 0-6 now ok, 7-31 are garbage
const __m256i s_7 = _mm256_srlv_epi32(s_raw, ks_shift);
// count the bits via xor+lut, correct bit 8
const __m256i nib = _mm256_xor_si256(_mm256_srli_epi32(s_7, 4), s_7);
const __m256i popc = _mm256_shuffle_epi8(popc_odd, _mm256_and_si256(nib, mask_nib));
const __m256i s_8 = _mm256_xor_si256(s_7, popc);
// extract into two __m256i, broadcast bytes 0, 4, 8, 12
const __m256i s1_e = _mm256_permute2x128_si256(s_8, s_8, 0x00);
const __m256i s2_e = _mm256_permute2x128_si256(s_8, s_8, 0x11);
const __m256i s1_b = _mm256_shuffle_epi8(s1_e, ks_bcast);
const __m256i s2_b = _mm256_shuffle_epi8(s2_e, ks_bcast);
// set 0xFF in bytes that contain bit, then invert via xor+sub
const __m256i s1 = _mm256_cmpeq_epi8(_mm256_and_si256(s1_b, ks_bsel), ks_bsel);
const __m256i s2 = _mm256_cmpeq_epi8(_mm256_and_si256(s2_b, ks_bsel), ks_bsel);
const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(q8_1, s1), s1);
const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(q8_2, s2), s2);
const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);