From 4806d6a8fef66af2b62eb90c7845f6f7743051cb Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Tue, 12 Aug 2025 18:53:34 +0530 Subject: [PATCH] Add further fixes and updates to scalar code --- ggml/src/ggml-cpu/arch/x86/repack.cpp | 130 +++++++++++++------------- ggml/src/ggml-cpu/repack.cpp | 56 ++++------- 2 files changed, 82 insertions(+), 104 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index 1b196d14df..7e6a375e89 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -6702,7 +6702,7 @@ void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__AVX2__) +#if defined(__AVX2__) || defined(__AVX512F__) const block_q6_Kx8 * b_ptr_start = (const block_q6_Kx8 * ) vx; const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy; int64_t b_nb = n / QK_K; @@ -8797,60 +8797,60 @@ void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo // Index : 0 -7, 64 - 71 // Comments indicate the indices of elements from individual super block in non interleaved fashion // Index : 0 -7, 64 - 71 - const __m256i rhs_mat_0145_00 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_0, m4), rhs_hbit_0145_00); - const __m256i rhs_mat_0145_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_0, 4), m4), rhs_hbit_0145_40); + const __m256i rhs_mat_0145_00 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_0, m4), rhs_hbit_0145_00); + const __m256i rhs_mat_0145_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_0, 4), m4), rhs_hbit_0145_40); - const __m256i rhs_mat_2367_00 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_0, m4), rhs_hbit_2367_00); - const __m256i rhs_mat_2367_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_0, 4), m4), rhs_hbit_2367_40); + const __m256i rhs_mat_2367_00 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_0, m4), rhs_hbit_2367_00); + const __m256i rhs_mat_2367_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_0, 4), m4), rhs_hbit_2367_40); // Index : 8 - 15, 72 - 79 - const __m256i rhs_mat_0145_01 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_1, m4), rhs_hbit_0145_01); - const __m256i rhs_mat_0145_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_1, 4), m4), rhs_hbit_0145_41); + const __m256i rhs_mat_0145_01 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_1, m4), rhs_hbit_0145_01); + const __m256i rhs_mat_0145_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_1, 4), m4), rhs_hbit_0145_41); - const __m256i rhs_mat_2367_01 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_1, m4), rhs_hbit_2367_01); - const __m256i rhs_mat_2367_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_1, 4), m4), rhs_hbit_2367_41); + const __m256i rhs_mat_2367_01 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_1, m4), rhs_hbit_2367_01); + const __m256i rhs_mat_2367_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_1, 4), m4), rhs_hbit_2367_41); // Index : 16 - 23, 80 - 87 - const __m256i rhs_mat_0145_10 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_2, m4), rhs_hbit_0145_10); - const __m256i rhs_mat_0145_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_2, 4), m4), rhs_hbit_0145_50); + const __m256i rhs_mat_0145_10 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_2, m4), rhs_hbit_0145_10); + const __m256i rhs_mat_0145_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_2, 4), m4), rhs_hbit_0145_50); - const __m256i rhs_mat_2367_10 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_2, m4), rhs_hbit_2367_10); - const __m256i rhs_mat_2367_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_2, 4), m4), rhs_hbit_2367_50); + const __m256i rhs_mat_2367_10 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_2, m4), rhs_hbit_2367_10); + const __m256i rhs_mat_2367_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_2, 4), m4), rhs_hbit_2367_50); // Index : 24 - 31, 88 - 95 - const __m256i rhs_mat_0145_11 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_3, m4), rhs_hbit_0145_11); - const __m256i rhs_mat_0145_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_3, 4), m4), rhs_hbit_0145_51); + const __m256i rhs_mat_0145_11 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_3, m4), rhs_hbit_0145_11); + const __m256i rhs_mat_0145_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_3, 4), m4), rhs_hbit_0145_51); - const __m256i rhs_mat_2367_11 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_3, m4), rhs_hbit_2367_11); - const __m256i rhs_mat_2367_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_3, 4), m4), rhs_hbit_2367_51); + const __m256i rhs_mat_2367_11 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_3, m4), rhs_hbit_2367_11); + const __m256i rhs_mat_2367_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_3, 4), m4), rhs_hbit_2367_51); // Index : 32 - 39, 96 - 103 - const __m256i rhs_mat_0145_20 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_4, m4), rhs_hbit_0145_20); - const __m256i rhs_mat_0145_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_4, 4), m4), rhs_hbit_0145_60); + const __m256i rhs_mat_0145_20 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_4, m4), rhs_hbit_0145_20); + const __m256i rhs_mat_0145_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_4, 4), m4), rhs_hbit_0145_60); - const __m256i rhs_mat_2367_20 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_4, m4), rhs_hbit_2367_20); - const __m256i rhs_mat_2367_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_4, 4), m4), rhs_hbit_2367_60); + const __m256i rhs_mat_2367_20 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_4, m4), rhs_hbit_2367_20); + const __m256i rhs_mat_2367_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_4, 4), m4), rhs_hbit_2367_60); // Index : 40 - 47, 104 - 111 - const __m256i rhs_mat_0145_21 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_5, m4), rhs_hbit_0145_21); - const __m256i rhs_mat_0145_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_5, 4), m4), rhs_hbit_0145_61); + const __m256i rhs_mat_0145_21 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_5, m4), rhs_hbit_0145_21); + const __m256i rhs_mat_0145_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_5, 4), m4), rhs_hbit_0145_61); - const __m256i rhs_mat_2367_21 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_5, m4), rhs_hbit_2367_21); - const __m256i rhs_mat_2367_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_5, 4), m4), rhs_hbit_2367_61); + const __m256i rhs_mat_2367_21 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_5, m4), rhs_hbit_2367_21); + const __m256i rhs_mat_2367_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_5, 4), m4), rhs_hbit_2367_61); // Index : 48 - 55, 112 - 119 - const __m256i rhs_mat_0145_30 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_6, m4), rhs_hbit_0145_30); - const __m256i rhs_mat_0145_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_6, 4), m4), rhs_hbit_0145_70); + const __m256i rhs_mat_0145_30 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_6, m4), rhs_hbit_0145_30); + const __m256i rhs_mat_0145_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_6, 4), m4), rhs_hbit_0145_70); - const __m256i rhs_mat_2367_30 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_6, m4), rhs_hbit_2367_30); - const __m256i rhs_mat_2367_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_6, 4), m4), rhs_hbit_2367_70); + const __m256i rhs_mat_2367_30 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_6, m4), rhs_hbit_2367_30); + const __m256i rhs_mat_2367_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_6, 4), m4), rhs_hbit_2367_70); // Index : 56 - 63, 120 - 127 - const __m256i rhs_mat_0145_31 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_7, m4), rhs_hbit_0145_31); - const __m256i rhs_mat_0145_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_7, 4), m4), rhs_hbit_0145_71); + const __m256i rhs_mat_0145_31 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_7, m4), rhs_hbit_0145_31); + const __m256i rhs_mat_0145_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_7, 4), m4), rhs_hbit_0145_71); - const __m256i rhs_mat_2367_31 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_7, m4), rhs_hbit_2367_31); - const __m256i rhs_mat_2367_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_7, 4), m4), rhs_hbit_2367_71); + const __m256i rhs_mat_2367_31 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_7, m4), rhs_hbit_2367_31); + const __m256i rhs_mat_2367_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_7, 4), m4), rhs_hbit_2367_71); // Shuffle pattern one - right side input const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) @@ -9609,60 +9609,60 @@ void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo // Index : 0 -7, 64 - 71 // Comments indicate the indices of elements from individual super block in non interleaved fashion // Index : 0 -7, 64 - 71 - const __m256i rhs_mat_0145_00 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_0, m4), rhs_hbit_0145_00); - const __m256i rhs_mat_0145_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_0, 4), m4), rhs_hbit_0145_40); + const __m256i rhs_mat_0145_00 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_0, m4), rhs_hbit_0145_00); + const __m256i rhs_mat_0145_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_0, 4), m4), rhs_hbit_0145_40); - const __m256i rhs_mat_2367_00 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_0, m4), rhs_hbit_2367_00); - const __m256i rhs_mat_2367_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_0, 4), m4), rhs_hbit_2367_40); + const __m256i rhs_mat_2367_00 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_0, m4), rhs_hbit_2367_00); + const __m256i rhs_mat_2367_40 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_0, 4), m4), rhs_hbit_2367_40); // Index : 8 - 15, 72 - 79 - const __m256i rhs_mat_0145_01 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_1, m4), rhs_hbit_0145_01); - const __m256i rhs_mat_0145_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_1, 4), m4), rhs_hbit_0145_41); + const __m256i rhs_mat_0145_01 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_1, m4), rhs_hbit_0145_01); + const __m256i rhs_mat_0145_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_1, 4), m4), rhs_hbit_0145_41); - const __m256i rhs_mat_2367_01 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_1, m4), rhs_hbit_2367_01); - const __m256i rhs_mat_2367_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_1, 4), m4), rhs_hbit_2367_41); + const __m256i rhs_mat_2367_01 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_1, m4), rhs_hbit_2367_01); + const __m256i rhs_mat_2367_41 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_1, 4), m4), rhs_hbit_2367_41); // Index : 16 - 23, 80 - 87 - const __m256i rhs_mat_0145_10 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_2, m4), rhs_hbit_0145_10); - const __m256i rhs_mat_0145_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_2, 4), m4), rhs_hbit_0145_50); + const __m256i rhs_mat_0145_10 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_2, m4), rhs_hbit_0145_10); + const __m256i rhs_mat_0145_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_2, 4), m4), rhs_hbit_0145_50); - const __m256i rhs_mat_2367_10 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_2, m4), rhs_hbit_2367_10); - const __m256i rhs_mat_2367_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_2, 4), m4), rhs_hbit_2367_50); + const __m256i rhs_mat_2367_10 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_2, m4), rhs_hbit_2367_10); + const __m256i rhs_mat_2367_50 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_2, 4), m4), rhs_hbit_2367_50); // Index : 24 - 31, 88 - 95 - const __m256i rhs_mat_0145_11 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_3, m4), rhs_hbit_0145_11); - const __m256i rhs_mat_0145_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_3, 4), m4), rhs_hbit_0145_51); + const __m256i rhs_mat_0145_11 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_3, m4), rhs_hbit_0145_11); + const __m256i rhs_mat_0145_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_3, 4), m4), rhs_hbit_0145_51); - const __m256i rhs_mat_2367_11 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_3, m4), rhs_hbit_2367_11); - const __m256i rhs_mat_2367_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_3, 4), m4), rhs_hbit_2367_51); + const __m256i rhs_mat_2367_11 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_3, m4), rhs_hbit_2367_11); + const __m256i rhs_mat_2367_51 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_3, 4), m4), rhs_hbit_2367_51); // Index : 32 - 39, 96 - 103 - const __m256i rhs_mat_0145_20 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_4, m4), rhs_hbit_0145_20); - const __m256i rhs_mat_0145_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_4, 4), m4), rhs_hbit_0145_60); + const __m256i rhs_mat_0145_20 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_4, m4), rhs_hbit_0145_20); + const __m256i rhs_mat_0145_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_4, 4), m4), rhs_hbit_0145_60); - const __m256i rhs_mat_2367_20 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_4, m4), rhs_hbit_2367_20); - const __m256i rhs_mat_2367_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_4, 4), m4), rhs_hbit_2367_60); + const __m256i rhs_mat_2367_20 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_4, m4), rhs_hbit_2367_20); + const __m256i rhs_mat_2367_60 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_4, 4), m4), rhs_hbit_2367_60); // Index : 40 - 47, 104 - 111 - const __m256i rhs_mat_0145_21 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_5, m4), rhs_hbit_0145_21); - const __m256i rhs_mat_0145_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_5, 4), m4), rhs_hbit_0145_61); + const __m256i rhs_mat_0145_21 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_5, m4), rhs_hbit_0145_21); + const __m256i rhs_mat_0145_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_5, 4), m4), rhs_hbit_0145_61); - const __m256i rhs_mat_2367_21 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_5, m4), rhs_hbit_2367_21); - const __m256i rhs_mat_2367_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_5, 4), m4), rhs_hbit_2367_61); + const __m256i rhs_mat_2367_21 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_5, m4), rhs_hbit_2367_21); + const __m256i rhs_mat_2367_61 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_5, 4), m4), rhs_hbit_2367_61); // Index : 48 - 55, 112 - 119 - const __m256i rhs_mat_0145_30 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_6, m4), rhs_hbit_0145_30); - const __m256i rhs_mat_0145_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_6, 4), m4), rhs_hbit_0145_70); + const __m256i rhs_mat_0145_30 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_6, m4), rhs_hbit_0145_30); + const __m256i rhs_mat_0145_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_6, 4), m4), rhs_hbit_0145_70); - const __m256i rhs_mat_2367_30 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_6, m4), rhs_hbit_2367_30); - const __m256i rhs_mat_2367_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_6, 4), m4), rhs_hbit_2367_70); + const __m256i rhs_mat_2367_30 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_6, m4), rhs_hbit_2367_30); + const __m256i rhs_mat_2367_70 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_6, 4), m4), rhs_hbit_2367_70); // Index : 56 - 63, 120 - 127 - const __m256i rhs_mat_0145_31 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_0145_7, m4), rhs_hbit_0145_31); - const __m256i rhs_mat_0145_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_0145_7, 4), m4), rhs_hbit_0145_71); + const __m256i rhs_mat_0145_31 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_0145_7, m4), rhs_hbit_0145_31); + const __m256i rhs_mat_0145_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_0145_7, 4), m4), rhs_hbit_0145_71); - const __m256i rhs_mat_2367_31 = _mm256_or_si256(_mm256_and_si256(rhs_lbit_mat_2367_7, m4), rhs_hbit_2367_31); - const __m256i rhs_mat_2367_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_lbit_mat_2367_7, 4), m4), rhs_hbit_2367_71); + const __m256i rhs_mat_2367_31 = _mm256_or_si256(_mm256_and_si256(rhs_raw_lbit_2367_7, m4), rhs_hbit_2367_31); + const __m256i rhs_mat_2367_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_lbit_2367_7, 4), m4), rhs_hbit_2367_71); // Shuffle pattern one - right side input const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 7182479f61..2a3fe71150 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -647,10 +647,10 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } for (int l = 0; l < nb; l++) { for (int k = 0; k < (qk / (4 * blocklen)); k++) { - const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ; - const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; - const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; - const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; + const int8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64; + const int8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; + const int8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; + const int8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; for (int j = 0; j < ncols_interleaved; j++) { sumi1 = 0; sumi2 = 0; @@ -659,22 +659,10 @@ void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, sumi = 0; int offset = ((k / 2) % 2) + j * 2; for (int i = 0; i < blocklen; ++i) { - const int hbits_index = k * ncols_interleaved * blocklen + j * blocklen + i; - const int lbits_index = (hbits_index / 32) * 64 + (hbits_index % 32); - const int v0_hbits = (int8_t) ((b_ptr[l].qh[hbits_index] & 3) << 4); - const int v1_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4); - const int v2_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4); - const int v3_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4); - - const int v0_lbits = (int8_t) (b_ptr[l].qh[lbits_index] & 0xF); - const int v1_lbits = (int8_t) (b_ptr[l].qh[lbits_index + 32] & 0xF); - const int v2_lbits = (int8_t) ((b_ptr[l].qh[lbits_index] >> 4) & 0xF); - const int v3_lbits = (int8_t) ((b_ptr[l].qh[lbits_index + 32] >> 4) & 0xF); - - const int v0 = v0_hbits | v0_lbits; - const int v1 = v1_hbits | v1_lbits; - const int v2 = v2_hbits | v2_lbits; - const int v3 = v3_hbits | v3_lbits; + int8_t v0 = (int8_t)((b_ptr[l].qh[hbits_index] & 3) << 4) | (b_ptr[l].ql[lbits_index] & 0xF) - 32; + int8_t v1 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4) | (b_ptr[l].ql[lbits_index + 32] & 0xF) - 32; + int8_t v2 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index] >> 4) & 0xF) - 32; + int8_t v3 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index + 32] >> 4) & 0xF) - 32; sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]); sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]); @@ -1226,20 +1214,19 @@ void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int y = 0; y < nr / 4; y++) { const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb); + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0; - sum_minf[m][j] = 0.0; } } for (int l = 0; l < nb; l++) { for (int k = 0; k < (qk / (4 * blocklen)); k++) { - const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ; - const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; - const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; - const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; + const int8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64; + const int8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; + const int8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; + const int8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumi1 = 0; @@ -1251,20 +1238,11 @@ void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int i = 0; i < blocklen; ++i){ const int hbits_index = k * ncols_interleaved * blocklen + j * blocklen + i; const int lbits_index = (hbits_index / 32) * 64 + (hbits_index % 32); - const int v0_hbits = (int8_t) ((b_ptr[l].qh[hbits_index] & 3) << 4); - const int v1_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4); - const int v2_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4); - const int v3_hbits = (int8_t) (((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4); - const int v0_lbits = (int8_t) (b_ptr[l].qh[lbits_index] & 0xF); - const int v1_lbits = (int8_t) (b_ptr[l].qh[lbits_index + 32] & 0xF); - const int v2_lbits = (int8_t) ((b_ptr[l].qh[lbits_index] >> 4) & 0xF); - const int v3_lbits = (int8_t) ((b_ptr[l].qh[lbits_index + 32] >> 4) & 0xF); - - const int v0 = v0_hbits | v0_lbits; - const int v1 = v1_hbits | v1_lbits; - const int v2 = v2_hbits | v2_lbits; - const int v3 = v3_hbits | v3_lbits; + int8_t v0 = (int8_t)((b_ptr[l].qh[hbits_index] & 3) << 4) | (b_ptr[l].ql[lbits_index] & 0xF) - 32; + int8_t v1 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 2 ) & 3) << 4) | (b_ptr[l].ql[lbits_index + 32] & 0xF) - 32; + int8_t v2 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 4 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index] >> 4) & 0xF) - 32; + int8_t v3 = (int8_t)(((b_ptr[l].qh[hbits_index] >> 6 ) & 3) << 4) | ((b_ptr[l].ql[lbits_index + 32] >> 4) & 0xF) - 32; sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]); sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);