Add further fixes and updates to scalar code

Srihari-mcw 2025-08-12 18:53:34 +05:30 committed by Manogna-Sree
parent c29ac56955
commit 4806d6a8fe
2 changed files with 82 additions and 104 deletions


@@ -6702,7 +6702,7 @@ void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
UNUSED(ncols_interleaved);
UNUSED(blocklen);
-#if defined(__AVX2__)
+#if defined(__AVX2__) || defined(__AVX512F__)
const block_q6_Kx8 * b_ptr_start = (const block_q6_Kx8 * ) vx;
const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy;
int64_t b_nb = n / QK_K;
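The widened guard presumably lets AVX-512 builds take this 256-bit path even when __AVX2__ is not defined by the toolchain or build flags. A minimal sketch of the dispatch pattern, under the assumption that the kernel body uses only __m256i intrinsics, which every AVX-512F-capable target also supports:

    #if defined(__AVX2__) || defined(__AVX512F__)
        // 256-bit interleaved GEMM kernel (AVX2 intrinsics only)
    #else
        // scalar fallback, e.g. ggml_gemm_q6_K_8x8_q8_K_generic
    #endif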
@@ -8797,60 +8797,60 @@ void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
-// Index : 0 -7, 64 - 71
+// Comments indicate the indices of elements from individual super block in non interleaved fashion
+// Index : 0 -7, 64 - 71
-const __m256i rhs_mat_0145_00 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_0, m4), rhs_hbit_0145_00);
-const __m256i rhs_mat_0145_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_0, 4), m4), rhs_hbit_0145_40);
+const __m256i rhs_mat_0145_00 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_0, m4), rhs_hbit_0145_00);
+const __m256i rhs_mat_0145_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_0, 4), m4), rhs_hbit_0145_40);
-const __m256i rhs_mat_2367_00 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_0, m4), rhs_hbit_2367_00);
-const __m256i rhs_mat_2367_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_0, 4), m4), rhs_hbit_2367_40);
+const __m256i rhs_mat_2367_00 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_0, m4), rhs_hbit_2367_00);
+const __m256i rhs_mat_2367_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_0, 4), m4), rhs_hbit_2367_40);
// Index : 8 - 15, 72 - 79
-const __m256i rhs_mat_0145_01 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_1, m4), rhs_hbit_0145_01);
-const __m256i rhs_mat_0145_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_1, 4), m4), rhs_hbit_0145_41);
+const __m256i rhs_mat_0145_01 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_1, m4), rhs_hbit_0145_01);
+const __m256i rhs_mat_0145_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_1, 4), m4), rhs_hbit_0145_41);
-const __m256i rhs_mat_2367_01 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_1, m4), rhs_hbit_2367_01);
-const __m256i rhs_mat_2367_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_1, 4), m4), rhs_hbit_2367_41);
+const __m256i rhs_mat_2367_01 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_1, m4), rhs_hbit_2367_01);
+const __m256i rhs_mat_2367_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_1, 4), m4), rhs_hbit_2367_41);
// Index : 16 - 23, 80 - 87
-const __m256i rhs_mat_0145_10 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_2, m4), rhs_hbit_0145_10);
-const __m256i rhs_mat_0145_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_2, 4), m4), rhs_hbit_0145_50);
+const __m256i rhs_mat_0145_10 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_2, m4), rhs_hbit_0145_10);
+const __m256i rhs_mat_0145_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_2, 4), m4), rhs_hbit_0145_50);
-const __m256i rhs_mat_2367_10 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_2, m4), rhs_hbit_2367_10);
-const __m256i rhs_mat_2367_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_2, 4), m4), rhs_hbit_2367_50);
+const __m256i rhs_mat_2367_10 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_2, m4), rhs_hbit_2367_10);
+const __m256i rhs_mat_2367_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_2, 4), m4), rhs_hbit_2367_50);
// Index : 24 - 31, 88 - 95
-const __m256i rhs_mat_0145_11 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_3, m4), rhs_hbit_0145_11);
-const __m256i rhs_mat_0145_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_3, 4), m4), rhs_hbit_0145_51);
+const __m256i rhs_mat_0145_11 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_3, m4), rhs_hbit_0145_11);
+const __m256i rhs_mat_0145_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_3, 4), m4), rhs_hbit_0145_51);
-const __m256i rhs_mat_2367_11 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_3, m4), rhs_hbit_2367_11);
-const __m256i rhs_mat_2367_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_3, 4), m4), rhs_hbit_2367_51);
+const __m256i rhs_mat_2367_11 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_3, m4), rhs_hbit_2367_11);
+const __m256i rhs_mat_2367_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_3, 4), m4), rhs_hbit_2367_51);
// Index : 32 - 39, 96 - 103
-const __m256i rhs_mat_0145_20 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_4, m4), rhs_hbit_0145_20);
-const __m256i rhs_mat_0145_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_4, 4), m4), rhs_hbit_0145_60);
+const __m256i rhs_mat_0145_20 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_4, m4), rhs_hbit_0145_20);
+const __m256i rhs_mat_0145_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_4, 4), m4), rhs_hbit_0145_60);
-const __m256i rhs_mat_2367_20 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_4, m4), rhs_hbit_2367_20);
-const __m256i rhs_mat_2367_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_4, 4), m4), rhs_hbit_2367_60);
+const __m256i rhs_mat_2367_20 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_4, m4), rhs_hbit_2367_20);
+const __m256i rhs_mat_2367_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_4, 4), m4), rhs_hbit_2367_60);
// Index : 40 - 47, 104 - 111
-const __m256i rhs_mat_0145_21 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_5, m4), rhs_hbit_0145_21);
-const __m256i rhs_mat_0145_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_5, 4), m4), rhs_hbit_0145_61);
+const __m256i rhs_mat_0145_21 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_5, m4), rhs_hbit_0145_21);
+const __m256i rhs_mat_0145_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_5, 4), m4), rhs_hbit_0145_61);
-const __m256i rhs_mat_2367_21 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_5, m4), rhs_hbit_2367_21);
-const __m256i rhs_mat_2367_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_5, 4), m4), rhs_hbit_2367_61);
+const __m256i rhs_mat_2367_21 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_5, m4), rhs_hbit_2367_21);
+const __m256i rhs_mat_2367_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_5, 4), m4), rhs_hbit_2367_61);
// Index : 48 - 55, 112 - 119
-const __m256i rhs_mat_0145_30 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_6, m4), rhs_hbit_0145_30);
-const __m256i rhs_mat_0145_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_6, 4), m4), rhs_hbit_0145_70);
+const __m256i rhs_mat_0145_30 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_6, m4), rhs_hbit_0145_30);
+const __m256i rhs_mat_0145_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_6, 4), m4), rhs_hbit_0145_70);
-const __m256i rhs_mat_2367_30 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_6, m4), rhs_hbit_2367_30);
-const __m256i rhs_mat_2367_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_6, 4), m4), rhs_hbit_2367_70);
+const __m256i rhs_mat_2367_30 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_6, m4), rhs_hbit_2367_30);
+const __m256i rhs_mat_2367_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_6, 4), m4), rhs_hbit_2367_70);
// Index : 56 - 63, 120 - 127
-const __m256i rhs_mat_0145_31 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_7, m4), rhs_hbit_0145_31);
-const __m256i rhs_mat_0145_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_7, 4), m4), rhs_hbit_0145_71);
+const __m256i rhs_mat_0145_31 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_7, m4), rhs_hbit_0145_31);
+const __m256i rhs_mat_0145_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_7, 4), m4), rhs_hbit_0145_71);
-const __m256i rhs_mat_2367_31 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_7, m4), rhs_hbit_2367_31);
-const __m256i rhs_mat_2367_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_7, 4), m4), rhs_hbit_2367_71);
+const __m256i rhs_mat_2367_31 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_7, m4), rhs_hbit_2367_31);
+const __m256i rhs_mat_2367_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_7, 4), m4), rhs_hbit_2367_71);
// Shuffle pattern one - right side input
const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)
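The immediate 136 in the shuffle above is 0b10001000, which replicates dwords 0 and 2 within each 128-bit lane; that is exactly the B00/B01 repetition the trailing comment describes. A small standalone check (not part of the patch):

    #include <immintrin.h>
    #include <stdio.h>

    int main(void) {
        __m256i v = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
        // imm8 = 136 = 0b10001000 -> per 128-bit lane picks dwords {0, 2, 0, 2}
        __m256i s = _mm256_shuffle_epi32(v, 136);
        int out[8];
        _mm256_storeu_si256((__m256i *) out, s);
        for (int i = 0; i < 8; i++) printf("%d ", out[i]); // 0 2 0 2 4 6 4 6
        printf("\n");
        return 0;
    }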
@@ -9609,60 +9609,60 @@ void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
-// Index : 0 -7, 64 - 71
+// Comments indicate the indices of elements from individual super block in non interleaved fashion
+// Index : 0 -7, 64 - 71
-const __m256i rhs_mat_0145_00 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_0, m4), rhs_hbit_0145_00);
-const __m256i rhs_mat_0145_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_0, 4), m4), rhs_hbit_0145_40);
+const __m256i rhs_mat_0145_00 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_0, m4), rhs_hbit_0145_00);
+const __m256i rhs_mat_0145_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_0, 4), m4), rhs_hbit_0145_40);
-const __m256i rhs_mat_2367_00 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_0, m4), rhs_hbit_2367_00);
-const __m256i rhs_mat_2367_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_0, 4), m4), rhs_hbit_2367_40);
+const __m256i rhs_mat_2367_00 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_0, m4), rhs_hbit_2367_00);
+const __m256i rhs_mat_2367_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_0, 4), m4), rhs_hbit_2367_40);
// Index : 8 - 15, 72 - 79
-const __m256i rhs_mat_0145_01 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_1, m4), rhs_hbit_0145_01);
-const __m256i rhs_mat_0145_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_1, 4), m4), rhs_hbit_0145_41);
+const __m256i rhs_mat_0145_01 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_1, m4), rhs_hbit_0145_01);
+const __m256i rhs_mat_0145_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_1, 4), m4), rhs_hbit_0145_41);
-const __m256i rhs_mat_2367_01 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_1, m4), rhs_hbit_2367_01);
-const __m256i rhs_mat_2367_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_1, 4), m4), rhs_hbit_2367_41);
+const __m256i rhs_mat_2367_01 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_1, m4), rhs_hbit_2367_01);
+const __m256i rhs_mat_2367_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_1, 4), m4), rhs_hbit_2367_41);
// Index : 16 - 23, 80 - 87
-const __m256i rhs_mat_0145_10 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_2, m4), rhs_hbit_0145_10);
-const __m256i rhs_mat_0145_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_2, 4), m4), rhs_hbit_0145_50);
+const __m256i rhs_mat_0145_10 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_2, m4), rhs_hbit_0145_10);
+const __m256i rhs_mat_0145_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_2, 4), m4), rhs_hbit_0145_50);
-const __m256i rhs_mat_2367_10 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_2, m4), rhs_hbit_2367_10);
-const __m256i rhs_mat_2367_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_2, 4), m4), rhs_hbit_2367_50);
+const __m256i rhs_mat_2367_10 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_2, m4), rhs_hbit_2367_10);
+const __m256i rhs_mat_2367_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_2, 4), m4), rhs_hbit_2367_50);
// Index : 24 - 31, 88 - 95
-const __m256i rhs_mat_0145_11 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_3, m4), rhs_hbit_0145_11);
-const __m256i rhs_mat_0145_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_3, 4), m4), rhs_hbit_0145_51);
+const __m256i rhs_mat_0145_11 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_3, m4), rhs_hbit_0145_11);
+const __m256i rhs_mat_0145_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_3, 4), m4), rhs_hbit_0145_51);
-const __m256i rhs_mat_2367_11 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_3, m4), rhs_hbit_2367_11);
-const __m256i rhs_mat_2367_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_3, 4), m4), rhs_hbit_2367_51);
+const __m256i rhs_mat_2367_11 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_3, m4), rhs_hbit_2367_11);
+const __m256i rhs_mat_2367_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_3, 4), m4), rhs_hbit_2367_51);
// Index : 32 - 39, 96 - 103
-const __m256i rhs_mat_0145_20 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_4, m4), rhs_hbit_0145_20);
-const __m256i rhs_mat_0145_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_4, 4), m4), rhs_hbit_0145_60);
+const __m256i rhs_mat_0145_20 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_4, m4), rhs_hbit_0145_20);
+const __m256i rhs_mat_0145_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_4, 4), m4), rhs_hbit_0145_60);
-const __m256i rhs_mat_2367_20 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_4, m4), rhs_hbit_2367_20);
-const __m256i rhs_mat_2367_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_4, 4), m4), rhs_hbit_2367_60);
+const __m256i rhs_mat_2367_20 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_4, m4), rhs_hbit_2367_20);
+const __m256i rhs_mat_2367_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_4, 4), m4), rhs_hbit_2367_60);
// Index : 40 - 47, 104 - 111
-const __m256i rhs_mat_0145_21 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_5, m4), rhs_hbit_0145_21);
-const __m256i rhs_mat_0145_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_5, 4), m4), rhs_hbit_0145_61);
+const __m256i rhs_mat_0145_21 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_5, m4), rhs_hbit_0145_21);
+const __m256i rhs_mat_0145_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_5, 4), m4), rhs_hbit_0145_61);
-const __m256i rhs_mat_2367_21 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_5, m4), rhs_hbit_2367_21);
-const __m256i rhs_mat_2367_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_5, 4), m4), rhs_hbit_2367_61);
+const __m256i rhs_mat_2367_21 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_5, m4), rhs_hbit_2367_21);
+const __m256i rhs_mat_2367_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_5, 4), m4), rhs_hbit_2367_61);
// Index : 48 - 55, 112 - 119
-const __m256i rhs_mat_0145_30 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_6, m4), rhs_hbit_0145_30);
-const __m256i rhs_mat_0145_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_6, 4), m4), rhs_hbit_0145_70);
+const __m256i rhs_mat_0145_30 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_6, m4), rhs_hbit_0145_30);
+const __m256i rhs_mat_0145_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_6, 4), m4), rhs_hbit_0145_70);
-const __m256i rhs_mat_2367_30 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_6, m4), rhs_hbit_2367_30);
-const __m256i rhs_mat_2367_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_6, 4), m4), rhs_hbit_2367_70);
+const __m256i rhs_mat_2367_30 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_6, m4), rhs_hbit_2367_30);
+const __m256i rhs_mat_2367_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_6, 4), m4), rhs_hbit_2367_70);
// Index : 56 - 63, 120 - 127
-const __m256i rhs_mat_0145_31 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_7, m4), rhs_hbit_0145_31);
-const __m256i rhs_mat_0145_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_7, 4), m4), rhs_hbit_0145_71);
+const __m256i rhs_mat_0145_31 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_7, m4), rhs_hbit_0145_31);
+const __m256i rhs_mat_0145_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_7, 4), m4), rhs_hbit_0145_71);
-const __m256i rhs_mat_2367_31 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_7, m4), rhs_hbit_2367_31);
-const __m256i rhs_mat_2367_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_7, 4), m4), rhs_hbit_2367_71);
+const __m256i rhs_mat_2367_31 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_7, m4), rhs_hbit_2367_31);
+const __m256i rhs_mat_2367_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_7, 4), m4), rhs_hbit_2367_71);
// Shuffle pattern one - right side input
const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3)

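Both AVX2 hunks above patch the same reconstruction step: each 6-bit weight is rebuilt by ORing a masked 4-bit low nibble (now taken from the renamed rhs_raw_lbit_* registers) with a pre-shifted 2-bit high part. The scalar gemv/gemm paths in the second changed file below do the same per element and also apply the explicit -32 recentring; the vector path presumably folds that bias in elsewhere. A hedged per-element sketch, assuming the usual q6_K convention of storing values biased by 32:

    #include <stdint.h>

    // qh_crumb: the 2 high bits of one q6_K weight, ql_nibble: its 4 low bits
    static inline int8_t q6k_element(uint8_t qh_crumb, uint8_t ql_nibble) {
        return (int8_t) ((((qh_crumb & 3) << 4) | (ql_nibble & 0xF)) - 32); // -32..31
    }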

@@ -647,10 +647,10 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
-const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
-const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
-const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
-const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
+const int8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64;
+const int8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
+const int8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
+const int8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
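The switch from uint8_t to int8_t matters because q6_K block scales are signed (block_q6_K declares int8_t scales[QK_K/16]); reading them through an unsigned pointer silently turns a negative scale into a large positive one. A two-line demonstration:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        uint8_t raw = 0xF4;                                   // one stored scale byte
        printf("%d vs %d\n", (int) raw, (int) (int8_t) raw);  // 244 vs -12
        return 0;
    }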
@@ -659,22 +659,10 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
sumi = 0;
int offset = ((k / 2) % 2) + j * 2;
for (int i = 0; i < blocklen; ++i) {
const int hbits_index = k * ncols_interleaved * blocklen + j * blocklen + i;
const int lbits_index = (hbits_index / 32) * 64 + (hbits_index % 32);
-const int v0_hbits = (int8_t) ((b_ptr[l].qh[hbits_index] & 3) << 4);
-const int v1_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4);
-const int v2_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4);
-const int v3_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4);
-const int v0_lbits = (int8_t) (b_ptr[l].qh[lbits_index] & 0xF);
-const int v1_lbits = (int8_t) (b_ptr[l].qh[lbits_index + 32] & 0xF);
-const int v2_lbits = (int8_t) ((b_ptr[l].qh[lbits_index] >> 4) & 0xF);
-const int v3_lbits = (int8_t) ((b_ptr[l].qh[lbits_index + 32] >> 4) & 0xF);
-const int v0 = v0_hbits | v0_lbits;
-const int v1 = v1_hbits | v1_lbits;
-const int v2 = v2_hbits | v2_lbits;
-const int v3 = v3_hbits | v3_lbits;
+int8_t v0 = (int8_t) (((b_ptr[l].qh[hbits_index] & 3) << 4 | (b_ptr[l].ql[lbits_index] & 0xF)) - 32);
+int8_t v1 = (int8_t) ((((b_ptr[l].qh[hbits_index] >> 2) & 3) << 4 | (b_ptr[l].ql[lbits_index + 32] & 0xF)) - 32);
+int8_t v2 = (int8_t) ((((b_ptr[l].qh[hbits_index] >> 4) & 3) << 4 | ((b_ptr[l].ql[lbits_index] >> 4) & 0xF)) - 32);
+int8_t v3 = (int8_t) ((((b_ptr[l].qh[hbits_index] >> 6) & 3) << 4 | ((b_ptr[l].ql[lbits_index + 32] >> 4) & 0xF)) - 32);
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
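The rewritten v0..v3 lines parenthesise the OR before subtracting 32: since '-' binds tighter than '|' in C, hb | lb - 32 parses as hb | (lb - 32), which yields the wrong value whenever the combined 6-bit weight has bit 5 set. A quick check of the two groupings:

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        int hb = 3 << 4; // pre-shifted high 2 bits (48)
        int lb = 0xF;    // low 4 bits (15)
        printf("%d vs %d\n",
               (int8_t) (hb | lb - 32),    // -1 (wrong grouping)
               (int8_t) ((hb | lb) - 32)); // 31 (intended value)
        return 0;
    }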
@@ -1226,20 +1214,19 @@ void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
-const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
+const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
sum_minf[m][j] = 0.0;
}
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / (4 * blocklen)); k++) {
-const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
-const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
-const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
-const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
+const int8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64;
+const int8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
+const int8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
+const int8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
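The q2_Kx8 -> q6_Kx8 cast fix above is more than cosmetic: indexing b_ptr[l] strides by the size of the pointed-to struct, so the old cast would advance by sizeof(block_q2_Kx8) through a buffer laid out in block_q6_Kx8 units and read from mid-block for every l > 0. An illustration with stand-in sizes (the real struct sizes are not shown in this diff):

    #include <stdio.h>

    struct small { char bytes[84];  };  // stand-in for the wrong, smaller block type
    struct big   { char bytes[210]; };  // stand-in for block_q6_Kx8

    int main(void) {
        struct big buf[2];
        const struct small *wrong = (const struct small *) buf;
        // element 1 starts 84 bytes in instead of 210
        printf("%td vs %td\n",
               (const char *) &wrong[1] - (const char *) buf,
               (const char *) &buf[1]   - (const char *) buf);
        return 0;
    }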
@@ -1251,20 +1238,11 @@ void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
for (int i = 0; i < blocklen; ++i){
const int hbits_index = k * ncols_interleaved * blocklen + j * blocklen + i;
const int lbits_index = (hbits_index / 32) * 64 + (hbits_index % 32);
-const int v0_hbits = (int8_t) ((b_ptr[l].qh[hbits_index] & 3) << 4);
-const int v1_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4);
-const int v2_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4);
-const int v3_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4);
-const int v0_lbits = (int8_t) (b_ptr[l].qh[lbits_index] & 0xF);
-const int v1_lbits = (int8_t) (b_ptr[l].qh[lbits_index + 32] & 0xF);
-const int v2_lbits = (int8_t) ((b_ptr[l].qh[lbits_index] >> 4) & 0xF);
-const int v3_lbits = (int8_t) ((b_ptr[l].qh[lbits_index + 32] >> 4) & 0xF);
-const int v0 = v0_hbits | v0_lbits;
-const int v1 = v1_hbits | v1_lbits;
-const int v2 = v2_hbits | v2_lbits;
-const int v3 = v3_hbits | v3_lbits;
+int8_t v0 = (int8_t) (((b_ptr[l].qh[hbits_index] & 3) << 4 | (b_ptr[l].ql[lbits_index] & 0xF)) - 32);
+int8_t v1 = (int8_t) ((((b_ptr[l].qh[hbits_index] >> 2) & 3) << 4 | (b_ptr[l].ql[lbits_index + 32] & 0xF)) - 32);
+int8_t v2 = (int8_t) ((((b_ptr[l].qh[hbits_index] >> 4) & 3) << 4 | ((b_ptr[l].ql[lbits_index] >> 4) & 0xF)) - 32);
+int8_t v3 = (int8_t) ((((b_ptr[l].qh[hbits_index] >> 6) & 3) << 4 | ((b_ptr[l].ql[lbits_index + 32] >> 4) & 0xF)) - 32);
sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
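The activation index in these sumi lines can be read as structured addressing into the interleaved q8_Kx4 buffer. A hedged helper that only names the terms of the expression above (the 512 and 128 strides come straight from the source; 'part' generalises the +0/+128 offsets of sumi1/sumi2):

    // (k >> 2): which group of four k-steps; each group spans 512 int8 values
    // (k % 4) * 4 * blocklen: position of the k-step inside the group
    // m * blocklen + i: row of the 4-row tile and element within the block
    // part * 128: offset between the partial sums (0 for sumi1, 1 for sumi2, ...)
    static inline int q8_Kx4_qs_index(int k, int m, int i, int blocklen, int part) {
        return (k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + part * 128;
    }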