From 4630b5187e0682d444d105197b9ab6f8dba444ad Mon Sep 17 00:00:00 2001 From: Manogna-Sree Date: Mon, 11 Aug 2025 02:47:20 -0700 Subject: [PATCH] Fix for inaccuracy of GEMM Q6K --- ggml/src/ggml-cpu/arch/x86/repack.cpp | 65 +++++++++++++-------------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index d4a29058ba..fb82584550 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -6900,7 +6900,6 @@ void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m256i rhs_mat_2367_31 = _mm256_or_si256(_mm256_and_si256(rhs_raw_mat_2367_7, m4), rhs_hbit_2367_31); const __m256i rhs_mat_2367_71 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(rhs_raw_mat_2367_7, 4), m4), rhs_hbit_2367_71); - // Shuffle pattern one - right side input const __m256i rhs_mat_0145_00_sp1 = _mm256_shuffle_epi32(rhs_mat_0145_00, 136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) const __m256i rhs_mat_2367_00_sp1 = _mm256_shuffle_epi32(rhs_mat_2367_00, 136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) @@ -7094,38 +7093,38 @@ void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo __m256i lhs_mat_01_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 0); __m256i lhs_mat_23_71 = _mm256_permute2f128_si256(lhs_mat_0123_71, lhs_mat_0123_71, 17); - __m256i lhs_mat_s_01_00 = _mm256_maddubs_epi16(lhs_mat_01_00, m32s); - __m256i lhs_mat_s_23_00 = _mm256_maddubs_epi16(lhs_mat_23_00, m32s); - __m256i lhs_mat_s_01_01 = _mm256_maddubs_epi16(lhs_mat_01_01, m32s); - __m256i lhs_mat_s_23_01 = _mm256_maddubs_epi16(lhs_mat_23_01, m32s); - __m256i lhs_mat_s_01_10 = _mm256_maddubs_epi16(lhs_mat_01_10, m32s); - __m256i lhs_mat_s_23_10 = _mm256_maddubs_epi16(lhs_mat_23_10, m32s); - __m256i lhs_mat_s_01_11 = _mm256_maddubs_epi16(lhs_mat_01_11, m32s); - __m256i lhs_mat_s_23_11 = _mm256_maddubs_epi16(lhs_mat_23_11, m32s); - __m256i lhs_mat_s_01_20 = _mm256_maddubs_epi16(lhs_mat_01_20, m32s); - __m256i lhs_mat_s_23_20 = _mm256_maddubs_epi16(lhs_mat_23_20, m32s); - __m256i lhs_mat_s_01_21 = _mm256_maddubs_epi16(lhs_mat_01_21, m32s); - __m256i lhs_mat_s_23_21 = _mm256_maddubs_epi16(lhs_mat_23_21, m32s); - __m256i lhs_mat_s_01_30 = _mm256_maddubs_epi16(lhs_mat_01_30, m32s); - __m256i lhs_mat_s_23_30 = _mm256_maddubs_epi16(lhs_mat_23_30, m32s); - __m256i lhs_mat_s_01_31 = _mm256_maddubs_epi16(lhs_mat_01_31, m32s); - __m256i lhs_mat_s_23_31 = _mm256_maddubs_epi16(lhs_mat_23_31, m32s); - __m256i lhs_mat_s_01_40 = _mm256_maddubs_epi16(lhs_mat_01_40, m32s); - __m256i lhs_mat_s_23_40 = _mm256_maddubs_epi16(lhs_mat_23_40, m32s); - __m256i lhs_mat_s_01_41 = _mm256_maddubs_epi16(lhs_mat_01_41, m32s); - __m256i lhs_mat_s_23_41 = _mm256_maddubs_epi16(lhs_mat_23_41, m32s); - __m256i lhs_mat_s_01_50 = _mm256_maddubs_epi16(lhs_mat_01_50, m32s); - __m256i lhs_mat_s_23_50 = _mm256_maddubs_epi16(lhs_mat_23_50, m32s); - __m256i lhs_mat_s_01_51 = _mm256_maddubs_epi16(lhs_mat_01_51, m32s); - __m256i lhs_mat_s_23_51 = _mm256_maddubs_epi16(lhs_mat_23_51, m32s); - __m256i lhs_mat_s_01_60 = _mm256_maddubs_epi16(lhs_mat_01_60, m32s); - __m256i lhs_mat_s_23_60 = _mm256_maddubs_epi16(lhs_mat_23_60, m32s); - __m256i lhs_mat_s_01_61 = _mm256_maddubs_epi16(lhs_mat_01_61, m32s); - __m256i lhs_mat_s_23_61 = _mm256_maddubs_epi16(lhs_mat_23_61, m32s); - __m256i lhs_mat_s_01_70 = _mm256_maddubs_epi16(lhs_mat_01_70, m32s); - __m256i lhs_mat_s_23_70 = _mm256_maddubs_epi16(lhs_mat_23_70, m32s); - __m256i lhs_mat_s_01_71 = _mm256_maddubs_epi16(lhs_mat_01_71, m32s); - __m256i lhs_mat_s_23_71 = _mm256_maddubs_epi16(lhs_mat_23_71, m32s); + __m256i lhs_mat_s_01_00 = _mm256_maddubs_epi16(m32s, lhs_mat_01_00); + __m256i lhs_mat_s_23_00 = _mm256_maddubs_epi16(m32s, lhs_mat_23_00); + __m256i lhs_mat_s_01_01 = _mm256_maddubs_epi16(m32s, lhs_mat_01_01); + __m256i lhs_mat_s_23_01 = _mm256_maddubs_epi16(m32s, lhs_mat_23_01); + __m256i lhs_mat_s_01_10 = _mm256_maddubs_epi16(m32s, lhs_mat_01_10); + __m256i lhs_mat_s_23_10 = _mm256_maddubs_epi16(m32s, lhs_mat_23_10); + __m256i lhs_mat_s_01_11 = _mm256_maddubs_epi16(m32s, lhs_mat_01_11); + __m256i lhs_mat_s_23_11 = _mm256_maddubs_epi16(m32s, lhs_mat_23_11); + __m256i lhs_mat_s_01_20 = _mm256_maddubs_epi16(m32s, lhs_mat_01_20); + __m256i lhs_mat_s_23_20 = _mm256_maddubs_epi16(m32s, lhs_mat_23_20); + __m256i lhs_mat_s_01_21 = _mm256_maddubs_epi16(m32s, lhs_mat_01_21); + __m256i lhs_mat_s_23_21 = _mm256_maddubs_epi16(m32s, lhs_mat_23_21); + __m256i lhs_mat_s_01_30 = _mm256_maddubs_epi16(m32s, lhs_mat_01_30); + __m256i lhs_mat_s_23_30 = _mm256_maddubs_epi16(m32s, lhs_mat_23_30); + __m256i lhs_mat_s_01_31 = _mm256_maddubs_epi16(m32s, lhs_mat_01_31); + __m256i lhs_mat_s_23_31 = _mm256_maddubs_epi16(m32s, lhs_mat_23_31); + __m256i lhs_mat_s_01_40 = _mm256_maddubs_epi16(m32s, lhs_mat_01_40); + __m256i lhs_mat_s_23_40 = _mm256_maddubs_epi16(m32s, lhs_mat_23_40); + __m256i lhs_mat_s_01_41 = _mm256_maddubs_epi16(m32s, lhs_mat_01_41); + __m256i lhs_mat_s_23_41 = _mm256_maddubs_epi16(m32s, lhs_mat_23_41); + __m256i lhs_mat_s_01_50 = _mm256_maddubs_epi16(m32s, lhs_mat_01_50); + __m256i lhs_mat_s_23_50 = _mm256_maddubs_epi16(m32s, lhs_mat_23_50); + __m256i lhs_mat_s_01_51 = _mm256_maddubs_epi16(m32s, lhs_mat_01_51); + __m256i lhs_mat_s_23_51 = _mm256_maddubs_epi16(m32s, lhs_mat_23_51); + __m256i lhs_mat_s_01_60 = _mm256_maddubs_epi16(m32s, lhs_mat_01_60); + __m256i lhs_mat_s_23_60 = _mm256_maddubs_epi16(m32s, lhs_mat_23_60); + __m256i lhs_mat_s_01_61 = _mm256_maddubs_epi16(m32s, lhs_mat_01_61); + __m256i lhs_mat_s_23_61 = _mm256_maddubs_epi16(m32s, lhs_mat_23_61); + __m256i lhs_mat_s_01_70 = _mm256_maddubs_epi16(m32s, lhs_mat_01_70); + __m256i lhs_mat_s_23_70 = _mm256_maddubs_epi16(m32s, lhs_mat_23_70); + __m256i lhs_mat_s_01_71 = _mm256_maddubs_epi16(m32s, lhs_mat_01_71); + __m256i lhs_mat_s_23_71 = _mm256_maddubs_epi16(m32s, lhs_mat_23_71); // Shuffle pattern one - left side input const __m256i lhs_mat_01_00_sp1 = _mm256_shuffle_epi32(lhs_mat_01_00, 160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3)