diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2c5d395798..cea92ce9d8 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -386,7 +386,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { }, [GGML_TYPE_Q2_0C] = { .from_float = quantize_row_q2_0c, - .vec_dot = NULL, // TODO: We should have the fallback kernel when KleidiAI is not used + .vec_dot = ggml_vec_dot_q2_0c_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index 659e205776..5fd585fc3a 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -580,6 +580,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, + /* .run_kernel_lut_ex = */ nullptr, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, @@ -600,6 +601,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, + /* .run_kernel_lut_ex = */ nullptr, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, @@ -613,6 +615,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .packed_size_ex = */ &rhs_ps_fn5, /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, + /* .pack_func_lut_ex = */ nullptr, }, /* .required_cpu = */ CPU_FEATURE_SVE | CPU_FEATURE_I8MM | CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, @@ -635,6 +638,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, + /* 
.run_kernel_lut_ex = */ nullptr, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, @@ -655,6 +659,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, + /* .run_kernel_lut_ex = */ nullptr, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, @@ -668,6 +673,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .packed_size_ex = */ &rhs_ps_fn5, /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, + /* .pack_func_lut_ex = */ nullptr, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, @@ -690,6 +696,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, + /* .run_kernel_lut_ex = */ nullptr, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, @@ -710,6 +717,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, + /* .run_kernel_lut_ex = */ nullptr, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, @@ -723,6 +731,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .packed_size_ex = */ &rhs_ps_fn5, /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, + /* .pack_func_lut_ex = */ nullptr, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 091d1f698d..b3d4c0adc8 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -422,6 +422,68 @@ 
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
     *s = sumf;
 }
 
+void ggml_vec_dot_q2_0c_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { // plain-C dot product of n elements: q2_0c blocks at vx against q8_K blocks at vy; the result is stored in *s
+    assert(nrc == 1); // exactly one result per call is supported
+    UNUSED(nrc);
+    UNUSED(bx); // bs, bx and by are accepted but never read by this kernel
+    UNUSED(by);
+    UNUSED(bs);
+    // View the untyped row pointers as arrays of quantized blocks.
+    const block_q2_0c * GGML_RESTRICT x = vx;
+    const block_q8_K * GGML_RESTRICT y = vy;
+    // n must cover a whole number of q2_0c blocks.
+    GGML_ASSERT(n % QKQ2_0C == 0);
+    const int nb = n / QKQ2_0C; // number of q2_0c blocks to process
+
+    float sumf = 0.0f; // running result accumulated over all blocks
+    // Decode table: a 2-bit code c stands for the integer q2_0c_vals[c].
+    static const int8_t q2_0c_vals[4] = { -3, -1, 1, 3 };
+    const int bytes_per_block = QKQ2_0C / 4; // four 2-bit codes per byte of xb->qs
+    const int bytes_per_half = QK_K / 4; // bytes of xb->qs that pair with one q8_K block (QK_K values)
+    // NOTE(review): pairing each x block with exactly two y blocks assumes QKQ2_0C == 2 * QK_K - confirm against the block definitions.
+    for (int i = 0; i < nb; ++i) {
+        const block_q2_0c * xb = x + i;
+        const block_q8_K * y0 = y + (i * 2 + 0); // x block i is matched with y blocks 2i and 2i+1
+        const block_q8_K * y1 = y + (i * 2 + 1);
+        // Integer partial sums, one per y block; scales are applied after the inner loops.
+        int32_t sum0 = 0;
+        int32_t sum1 = 0;
+        // Bytes [0, bytes_per_half) of xb->qs are multiplied against y0.
+        for (int j = 0; j < bytes_per_half; ++j) {
+            const uint8_t byte = xb->qs[j];
+            const int8_t q0 = q2_0c_vals[(byte >> 0) & 0x03]; // bits 0-1 -> element base + 0
+            const int8_t q1 = q2_0c_vals[(byte >> 2) & 0x03]; // bits 2-3 -> element base + 1
+            const int8_t q2 = q2_0c_vals[(byte >> 4) & 0x03]; // bits 4-5 -> element base + 2
+            const int8_t q3 = q2_0c_vals[(byte >> 6) & 0x03]; // bits 6-7 -> element base + 3
+            // Byte j holds the four elements starting at index j * 4 of y0->qs.
+            const int base = j * 4;
+            sum0 += q0 * y0->qs[base + 0];
+            sum0 += q1 * y0->qs[base + 1];
+            sum0 += q2 * y0->qs[base + 2];
+            sum0 += q3 * y0->qs[base + 3];
+        }
+        // Bytes [bytes_per_half, bytes_per_block) of xb->qs are multiplied against y1.
+        for (int j = bytes_per_half; j < bytes_per_block; ++j) {
+            const uint8_t byte = xb->qs[j];
+            const int8_t q0 = q2_0c_vals[(byte >> 0) & 0x03];
+            const int8_t q1 = q2_0c_vals[(byte >> 2) & 0x03];
+            const int8_t q2 = q2_0c_vals[(byte >> 4) & 0x03];
+            const int8_t q3 = q2_0c_vals[(byte >> 6) & 0x03];
+            // Rebase the byte index so that y1->qs is indexed from 0.
+            const int base = (j - bytes_per_half) * 4;
+            sum1 += q0 * y1->qs[base + 0];
+            sum1 += q1 * y1->qs[base + 1];
+            sum1 += q2 * y1->qs[base + 2];
+            sum1 += q3 * y1->qs[base + 3];
+        }
+        // Scale each integer partial sum by its y block's d, then by the x block's d (converted with GGML_CPU_FP16_TO_FP32).
+        const float d = GGML_CPU_FP16_TO_FP32(xb->d);
+        sumf += d * ((float) sum0 * y0->d + (float) sum1 * y1->d);
+    }
+
+    *s = sumf; // single output value
+}
+
 void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s,
size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 5bc022f1d4..c565b0fd39 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -53,6 +53,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q2_0c_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 0533471935..c9cf2a1ac8 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -2203,10 +2203,16 @@ static inline uint8_t map_int8_to_uint2_idx(int32_t v0) { switch(v0) { case -3: return 0; + case -2: + return 1; case -1: return 1; + case 0: + return 1; case 1: return 2; + case 2: + return 2; case 3: return 3; default: @@ -2300,7 +2306,6 @@ void quantize_row_q2_0c_ref(const float * GGML_RESTRICT x, block_q2_0c * GGML_RE if (qi3 < qmin) qi3 = qmin; if (qi3 > qmax) qi3 = qmax; - // TODO: What if we have -2 or +2? 
const uint8_t v0_u8 = map_int8_to_uint2_idx(qi0); const uint8_t v1_u8 = map_int8_to_uint2_idx(qi1); const uint8_t v2_u8 = map_int8_to_uint2_idx(qi2); diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 037c0582bb..3c06319528 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -146,6 +146,7 @@ int main(int argc, char * argv[]) { type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : + type == GGML_TYPE_Q2_0C ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : @@ -167,7 +168,7 @@ int main(int argc, char * argv[]) { const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S ? MAX_DOT_PRODUCT_ERROR_LOWBIT - : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0 + : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0 || type == GGML_TYPE_Q2_0C ? MAX_DOT_PRODUCT_ERROR_TERNARY : MAX_DOT_PRODUCT_ERROR; failed = !(vec_dot_error < max_allowed_error);