Resolve PR comments

This commit is contained in:
Swetha B S 2025-10-23 06:15:06 -07:00
parent 8ffdaea39a
commit d611fb43e8
1 changed file with 27 additions and 20 deletions

View File

@ -1809,8 +1809,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
#undef MMID_MATRIX_ROW
}
void forward_get_rows(const ggml_compute_params * params,
ggml_tensor * dst) {
void forward_get_rows(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
@ -1823,9 +1822,8 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
}
}
static void ggml_compute_forward_get_rows_q6_Kx8(
const ggml_compute_params * params,
ggml_tensor * dst) {
static void ggml_compute_forward_get_rows_q6_Kx8(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@ -1877,13 +1875,14 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
}
}
/**
* Dequantizes a single logical row from the repacked q6_Kx8 data format.
/**
* Dequantizes a single logical row from quantized data stored in the interleaved (repacked) block_q6_Kx8 layout.
*
* @param p_repacked_blocks Pointer to the start of the 'block_q6_Kx8' structures for the entire row.
* @param y Output buffer for the dequantized float values.
* @param k Total number of elements (columns) in the logical row.
* @param row_idx_in_group The index (0-7) of the logical row to extract from the interleaved data.
* @param p_repacked_group_column_blocks Pointer to the start of 'block_q6_Kx8' for the row group.
* @param y Output buffer for the dequantized float values.
* @param k Total number of elements (columns) in the logical row.
* @param row_idx_in_group Index (0-7) of the logical row to dequantize.
*/
static void dequantize_row_q6_Kx8(
@ -1891,6 +1890,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
float * GGML_RESTRICT y,
int64_t k,
int row_idx_in_group) {
assert(k % QK_K == 0);
assert(row_idx_in_group >= 0 && row_idx_in_group < 8);
@ -1904,7 +1904,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
const uint8_t * ptr_ql_base = current_block->ql;
const uint8_t * ptr_qh_base = current_block->qh;
uint8_t * ptr_repacked_scales = (uint8_t *)current_block->scales; // 16*8 scales repacked - 2bytes of each super block stored together
uint8_t * ptr_repacked_scales = (uint8_t *) current_block->scales; // 16 * 8 scales repacked
for (int n = 0; n < QK_K; n += 128) {
for (int l = 0; l < 32; ++l) {
int is = l/16;
@ -1920,22 +1920,26 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
const uint8_t ql_l32 = read_ql_qh_from_repacked(ptr_ql_base, row_idx_in_group, n/2 + l + 32);
const uint8_t qh_l = read_ql_qh_from_repacked(ptr_qh_base, row_idx_in_group, n/4 + l);
const int8_t q1 = (int8_t)((ql_l0 & 0xF) | (((qh_l >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t)((ql_l32 & 0xF) | (((qh_l >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t)((ql_l0 >> 4) | (((qh_l >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t)((ql_l32 >> 4) | (((qh_l >> 6) & 3) << 4)) - 32;
const int8_t q1 = (int8_t) ((ql_l0 & 0xF) | (((qh_l >> 0) & 3) << 4)) - 32;
const int8_t q2 = (int8_t) ((ql_l32 & 0xF) | (((qh_l >> 2) & 3) << 4)) - 32;
const int8_t q3 = (int8_t) ((ql_l0 >> 4) | (((qh_l >> 4) & 3) << 4)) - 32;
const int8_t q4 = (int8_t) ((ql_l32 >> 4) | (((qh_l >> 6) & 3) << 4)) - 32;
y[l + 0] = d_super_block * sc0 * q1;
y[l] = d_super_block * sc0 * q1;
y[l + 32] = d_super_block * sc1 * q2;
y[l + 64] = d_super_block * sc2 * q3;
y[l + 96] = d_super_block * sc3 * q4;
}
y += 128;
ptr_repacked_scales = (uint8_t *)current_block->scales + 64;
ptr_repacked_scales = (uint8_t *) current_block->scales + 64;
}
}
}
/**
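* Reads one int8 scale for the given row from the repacked scales buffer; scales are addressed in pairs (scale_idx / 2 selects the pair, scale_idx % 2 the entry within it).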
*/
static inline int8_t read_scale_from_repacked(const uint8_t* ptr_repacked_scales, int row_idx_in_group, int scale_idx) {
const int pair_group_idx = scale_idx / 2;
const int sub_idx_in_pair = scale_idx % 2;
@ -1943,12 +1947,15 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
return ptr_repacked_scales[offset];
}
/**
 * Reads one byte of ql or qh data for a single logical row from a repacked
 * (row-interleaved) byte buffer.
 *
 * Layout implied by the index arithmetic: each row contributes runs of 8
 * consecutive bytes, and the runs of the 8 rows in a group are stored back to
 * back. Run c of row r therefore starts at byte c * (8 * 8) + r * 8.
 *
 * @param ptr_qh_ql_base   Base of the repacked ql or qh byte array.
 * @param row_idx_in_group Row within the interleaved group (callers pass 0-7).
 * @param ql_0_idx         Byte index within the row's own, de-interleaved ql or qh stream.
 * @return The byte stored at that logical position.
 */
static inline uint8_t read_ql_qh_from_repacked(const uint8_t * ptr_qh_ql_base, int row_idx_in_group, int ql_0_idx) {
    const int block_size_interleave = 8;  // bytes of one row stored contiguously
    const int rows_per_group        = 8;  // rows interleaved together in one group

    const int chunk_idx       = ql_0_idx / block_size_interleave;
    const int offset_in_chunk = ql_0_idx % block_size_interleave;

    // Skip whole interleaved chunks, then whole rows inside the chunk, then bytes inside the row's run.
    const int offset = chunk_idx * (rows_per_group * block_size_interleave)
                     + row_idx_in_group * block_size_interleave
                     + offset_in_chunk;

    return ptr_qh_ql_base[offset];
}
int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {