From d611fb43e8bc92c1672212bab2307ff712a64c11 Mon Sep 17 00:00:00 2001 From: Swetha B S Date: Thu, 23 Oct 2025 06:15:06 -0700 Subject: [PATCH] Resolve PR comments --- ggml/src/ggml-cpu/repack.cpp | 47 +++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 336c3df354..17502af627 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1809,8 +1809,7 @@ template src[0]; switch (src0->type) { @@ -1823,9 +1822,8 @@ template src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -1877,13 +1875,14 @@ template = 0 && row_idx_in_group < 8); @@ -1904,7 +1904,7 @@ template ql; const uint8_t * ptr_qh_base = current_block->qh; - uint8_t * ptr_repacked_scales = (uint8_t *)current_block->scales; // 16*8 scales repacked - 2bytes of each super block stored together + uint8_t * ptr_repacked_scales = (uint8_t *) current_block->scales; // 16 * 8 scales repacked for (int n = 0; n < QK_K; n += 128) { for (int l = 0; l < 32; ++l) { int is = l/16; @@ -1920,22 +1920,26 @@ template > 0) & 3) << 4)) - 32; - const int8_t q2 = (int8_t)((ql_l32 & 0xF) | (((qh_l >> 2) & 3) << 4)) - 32; - const int8_t q3 = (int8_t)((ql_l0 >> 4) | (((qh_l >> 4) & 3) << 4)) - 32; - const int8_t q4 = (int8_t)((ql_l32 >> 4) | (((qh_l >> 6) & 3) << 4)) - 32; + const int8_t q1 = (int8_t) ((ql_l0 & 0xF) | (((qh_l >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t) ((ql_l32 & 0xF) | (((qh_l >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t) ((ql_l0 >> 4) | (((qh_l >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t) ((ql_l32 >> 4) | (((qh_l >> 6) & 3) << 4)) - 32; - y[l + 0] = d_super_block * sc0 * q1; + y[l] = d_super_block * sc0 * q1; y[l + 32] = d_super_block * sc1 * q2; y[l + 64] = d_super_block * sc2 * q3; y[l + 96] = d_super_block * sc3 * q4; } y += 128; - ptr_repacked_scales = (uint8_t *)current_block->scales + 64; + ptr_repacked_scales = (uint8_t *) current_block->scales + 64; } } } + + /** + * Read the scales from the repacked ptr_repacked_scales + */ static inline int8_t read_scale_from_repacked(const uint8_t* ptr_repacked_scales, int row_idx_in_group, int scale_idx) { const int pair_group_idx = scale_idx / 2; const int sub_idx_in_pair = scale_idx % 2; @@ -1943,12 +1947,15 @@ template