q6_K GET_ROWS support and dequantize function for the repacked q6_Kx8 CPU buffer type

This commit is contained in:
Swetha B S 2025-10-23 02:17:44 -07:00
parent 9de9672adb
commit 8ffdaea39a
1 changed files with 162 additions and 1 deletions

View File

@ -1576,6 +1576,11 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
return true;
}
case GGML_OP_GET_ROWS:
{
size = 0;
return true;
}
default:
@ -1593,6 +1598,9 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
case GGML_OP_MUL_MAT_ID:
forward_mul_mat_id(params, op);
return true;
case GGML_OP_GET_ROWS:
forward_get_rows(params, op);
return true;
default:
// GGML_ABORT("fatal error");
break;
@ -1801,6 +1809,148 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
#undef MMID_MATRIX_ROW
}
void forward_get_rows(const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_Q6_K:
ggml_compute_forward_get_rows_q6_Kx8(params, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
// GET_ROWS for a src0 stored in the repacked q6_Kx8 (8-row interleaved)
// layout: src1 holds int32 row indices, and each selected logical row is
// dequantized to f32 into dst. Work is split across threads by output row.
static void ggml_compute_forward_get_rows_q6_Kx8(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
const int64_t nc = ne00; // elements (columns) per logical row
const int64_t nr = ggml_nelements(src1); // total number of rows to gather
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == ggml_type_size(src0->type));
assert(ggml_nrows(dst) == nr);
const int ith = params->ith;
const int nth = params->nth;
// rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
// 8 logical rows are interleaved into each group of block_q6_Kx8 blocks
constexpr int nrows_interleaved = 8;
const size_t sizeof_one_repacked_block = sizeof(block_q6_Kx8);
const int num_repacked_blocks_per_row_width = nc / QK_K;
// byte stride between consecutive 8-row groups within src0
const size_t stride_between_actual_row_groups = num_repacked_blocks_per_row_width * sizeof_one_repacked_block;
for (int64_t i = ir0; i < ir1; ++i) {
// decompose the flat gather index i into src1 coordinates (i10, i11, i12)
const int64_t i12 = i / (ne11 * ne10);
const int64_t i11 = (i - i12 * ne11 * ne10) / ne10;
const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10);
const int64_t i01 = *(int32_t *)((char *)src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); // original logical row
GGML_ASSERT(i01 >= 0 && i01 < ne01);
// which interleaved 8-row group holds row i01, and the row's slot inside it
const int row_group_idx = i01 / nrows_interleaved;
const int row_idx_in_group = i01 % nrows_interleaved;
const char * base_ptr_for_higher_dims_in_src0 = (const char *)src0->data + i11 * nb02 + i12 * nb03;
// Pointer to the first block_q6_Kx8 of the identified row_group_idx
const block_q6_Kx8 * p_first_repacked_block_of_group_x8 = (const block_q6_Kx8 *)(base_ptr_for_higher_dims_in_src0 + row_group_idx * stride_between_actual_row_groups);
dequantize_row_q6_Kx8(
p_first_repacked_block_of_group_x8,
(float *)((char *)dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3), nc, row_idx_in_group);
}
}
/**
 * Dequantizes a single logical row from the repacked q6_Kx8 (8-row
 * interleaved) data format.
 *
 * @param p_repacked_blocks Pointer to the start of the 'block_q6_Kx8' structures for the entire row group.
 * @param y Output buffer for the dequantized float values (k floats).
 * @param k Total number of elements (columns) in the logical row; must be a multiple of QK_K.
 * @param row_idx_in_group The index (0-7) of the logical row to extract from the interleaved data.
 */
static void dequantize_row_q6_Kx8(
    const void * GGML_RESTRICT p_repacked_blocks,
    float * GGML_RESTRICT y,
    int64_t k,
    int row_idx_in_group) {
    assert(k % QK_K == 0);
    assert(row_idx_in_group >= 0 && row_idx_in_group < 8);
    const int nb = k / QK_K;
    const block_q6_Kx8 * blocks = (const block_q6_Kx8 *)p_repacked_blocks;
    for (int i = 0; i < nb; i++) {
        const block_q6_Kx8 * current_block = &blocks[i];
        // per-row super-block scale
        const float d_super_block = GGML_FP16_TO_FP32(current_block->d[row_idx_in_group]);
        const uint8_t * ptr_ql_base = current_block->ql;
        const uint8_t * ptr_qh_base = current_block->qh;
        // 16*8 scales repacked - 2 bytes of each super block stored together.
        // Read-only: keep the pointer const instead of casting constness away.
        const uint8_t * ptr_repacked_scales = (const uint8_t *)current_block->scales;
        // A q6_K super block is processed as two 128-element chunks; each
        // chunk consumes 8 of the 16 sub-block scales. Index the scales
        // explicitly per chunk instead of mutating the base pointer (the
        // previous pointer reassignment only worked because QK_K/128 == 2).
        for (int n = 0; n < QK_K; n += 128) {
            const int sc_base = (n / 128) * 8;
            for (int l = 0; l < 32; ++l) {
                const int is = l / 16;
                // get the 4 scales needed for q1, q2, q3 and q4
                const int8_t sc0 = read_scale_from_repacked(ptr_repacked_scales, row_idx_in_group, sc_base + is + 0);
                const int8_t sc1 = read_scale_from_repacked(ptr_repacked_scales, row_idx_in_group, sc_base + is + 2);
                const int8_t sc2 = read_scale_from_repacked(ptr_repacked_scales, row_idx_in_group, sc_base + is + 4);
                const int8_t sc3 = read_scale_from_repacked(ptr_repacked_scales, row_idx_in_group, sc_base + is + 6);
                // get the right ql & qh values from the interleaved data
                const uint8_t ql_l0  = read_ql_qh_from_repacked(ptr_ql_base, row_idx_in_group, n/2 + l +  0);
                const uint8_t ql_l32 = read_ql_qh_from_repacked(ptr_ql_base, row_idx_in_group, n/2 + l + 32);
                const uint8_t qh_l   = read_ql_qh_from_repacked(ptr_qh_base, row_idx_in_group, n/4 + l);
                // reassemble the signed 6-bit quants (low 4 bits from ql,
                // high 2 bits from qh, bias of 32)
                const int8_t q1 = (int8_t)((ql_l0  & 0xF) | (((qh_l >> 0) & 3) << 4)) - 32;
                const int8_t q2 = (int8_t)((ql_l32 & 0xF) | (((qh_l >> 2) & 3) << 4)) - 32;
                const int8_t q3 = (int8_t)((ql_l0  >>  4) | (((qh_l >> 4) & 3) << 4)) - 32;
                const int8_t q4 = (int8_t)((ql_l32 >>  4) | (((qh_l >> 6) & 3) << 4)) - 32;
                y[l +  0] = d_super_block * sc0 * q1;
                y[l + 32] = d_super_block * sc1 * q2;
                y[l + 64] = d_super_block * sc2 * q3;
                y[l + 96] = d_super_block * sc3 * q4;
            }
            y += 128;
        }
    }
}
// Fetch one sub-block scale for a given logical row from the repacked
// layout: scales are stored in pairs, and for each pair index the two
// bytes of all 8 interleaved rows sit contiguously (16 bytes per pair).
static inline int8_t read_scale_from_repacked(const uint8_t* ptr_repacked_scales, int row_idx_in_group, int scale_idx) {
    const int pair        = scale_idx / 2;
    const int within_pair = scale_idx % 2;
    return (int8_t) ptr_repacked_scales[pair * 16 + row_idx_in_group * 2 + within_pair];
}
// Fetch one quant byte for a given logical row from the interleaved
// layout: the rows are interleaved in 8-byte units, so a 64-byte stripe
// holds one 8-byte chunk from each of the 8 rows.
static inline uint8_t read_ql_qh_from_repacked(const uint8_t* ptr_ql_base, int row_idx_in_group, int ql_0_idx) {
    enum { INTERLEAVE_UNIT = 8 };
    const int chunk    = ql_0_idx / INTERLEAVE_UNIT;
    const int in_chunk = ql_0_idx - chunk * INTERLEAVE_UNIT;
    const int stripe   = (chunk * 8 + row_idx_in_group) * INTERLEAVE_UNIT;
    return ptr_ql_base[stripe + in_chunk];
}
int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
(int) NB_COLS, (int) INTER_SIZE);
@ -1949,12 +2099,23 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
//if (op->src[1]->type == GGML_TYPE_Q8_0) {
// return true;
//}
} else if (op->op == GGML_OP_GET_ROWS
&& op->src[0]->buffer
&& (ggml_n_dims(op->src[0]) == 2)
&& op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
&& ggml_repack_get_optimal_repack_type(op->src[0])) {
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}
if (op->src[0]->type == GGML_TYPE_Q6_K) {
return true;
}
}
return false;
}
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_GET_ROWS) {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}