From 89bbd505993d78be8b5ba31f625fb989e3d886d2 Mon Sep 17 00:00:00 2001 From: vithulep Date: Thu, 31 Jul 2025 09:44:18 +0530 Subject: [PATCH] Added sve for dequantized_q8_0 --- ggml/src/ggml-cpu/arch/arm/quants.c | 94 ++++++++++------------------- 1 file changed, 32 insertions(+), 62 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 61fde4f162..6457c649e4 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -47,73 +47,43 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i block_q8_0 * GGML_RESTRICT y = vy; -// #if defined(__ARM_FEATURE_SVE) -// // const int sve_register_length = svcntb() * 8; //get the vector length -// // const int ggml_f32_epr = sve_register_length / 32; -// const svfloat32_t inactive1 = svdup_n_f32(0.0f); -// const svbool_t pg = svptrue_b32(); -// svfloat32_t zero = svdup_f32(0.0f); -// svfloat32_t half = svdup_f32(0.5f); -// const svint32_t inactive2 = svdup_n_s32(0); +#if defined(__ARM_FEATURE_SVE) + const int sve_register_length = svcntb() * 8; + const int ggml_f32_epr = sve_register_length / 32; + const svfloat32_t inactive1 = svdup_n_f32(0.0f); + const svbool_t pg = svptrue_b32(); + svfloat32_t zero = svdup_f32(0.0f); + svfloat32_t half = svdup_f32(0.5f); -// for (int i = 0; i < nb; i+=1) { -// svfloat32_t srcv1, srcv2, srcv3, srcv4; -// svfloat32_t asrcv1, asrcv2, asrcv3, asrcv4; -// float32_t amax1 = 0.0; + for (int i = 0; i < nb; i+=1) { + svfloat32_t srcv1, asrcv1; + svfloat32_t sv_max = svdup_n_f32(0.0f); + float32_t amax = 0.0; -// srcv1 = svld1_f32(pg, x + i*32); -// asrcv1 = svabs_f32_m(inactive1, pg, srcv1); + for (int j = 0; j < QK8_0; j+=ggml_f32_epr) { + srcv1 = svld1_f32(pg, x + i*32 + j); + asrcv1 = svabs_f32_m(inactive1, pg, srcv1); + sv_max = svmax_f32_m(pg, sv_max, asrcv1); + } + amax = svmaxv_f32(pg, sv_max); + float32_t d = amax / ((1 << 7) - 1); + float32_t id = d ? 1.0f/d : 0.0f; + y[i].d = GGML_FP32_TO_FP16(d); + for (int j = 0; j < QK8_0; j+=ggml_f32_epr) { + srcv1 = svld1_f32(pg, x + i*32 + j); + const svfloat32_t v1 = svmul_n_f32_m(pg, srcv1, id); -// srcv2 = svld1_f32(pg, x + i*32 + 8); -// asrcv2 = svabs_f32_m(inactive1, pg, srcv2); + svbool_t ge_zero = svcmpge_f32(pg, v1, zero); + svfloat32_t v_pos = svadd_f32_m(pg, v1, half); + svfloat32_t v_neg = svsub_f32_m(pg, v1, half); -// srcv3 = svld1_f32(pg, x + i*32 + 16); -// asrcv3 = svabs_f32_m(inactive1, pg, srcv3); + svfloat32_t v_rounded = svsel_f32(ge_zero, v_pos, v_neg); + svint32_t result = svcvt_s32_f32_x(pg, v_rounded); + svst1b_s32(pg, &y[i].qs[j], result); + } + } -// srcv4 = svld1_f32(pg, x + i*32 + 24); -// asrcv4 = svabs_f32_m(inactive1, pg, srcv4); - -// svfloat32_t max1 = svmax_f32_m(pg, asrcv2, asrcv1); -// svfloat32_t max2 = svmax_f32_m(pg, asrcv4, asrcv3); -// svfloat32_t max3 = svmax_f32_m(pg, max2, max1); -// amax1 = svmaxv_f32(pg, max3); - -// float32_t d1 = amax1 / ((1 << 7) - 1); -// float32_t id1 = d1 ? 1.0f/d1 : 0.0f; -// y[i].d = GGML_FP32_TO_FP16(d1); - -// const svfloat32_t v1 = svmul_n_f32_m(pg, srcv1, id1); -// const svfloat32_t v2 = svmul_n_f32_m(pg, srcv2, id1); -// const svfloat32_t v3 = svmul_n_f32_m(pg, srcv3, id1); -// const svfloat32_t v4 = svmul_n_f32_m(pg, srcv4, id1); - -// svbool_t ge_zero = svcmpge_f32(pg, v1, zero); -// svfloat32_t v_rounded = svsel_f32(ge_zero, svadd_f32_m(pg, v1, half), svsub_f32_m(pg, v1, half)); -// // svint32_t v_rounded = svcvt_s32_f32_m(inactive2, pg, v1); - -// svbool_t ge_zero_2 = svcmpge_f32(pg, v2, zero); -// svfloat32_t v_rounded_2 = svsel_f32(ge_zero_2, svadd_f32_m(pg, v2, half), svsub_f32_m(pg, v2, half)); - -// svbool_t ge_zero_3 = svcmpge_f32(pg, v3, zero); -// svfloat32_t v_rounded_3 = svsel_f32(ge_zero_3, svadd_f32_m(pg, v3, half), svsub_f32_m(pg, v3, half)); - -// svbool_t ge_zero_4 = svcmpge_f32(pg, v4, zero); -// svfloat32_t v_rounded_4 = svsel_f32(ge_zero_4, svadd_f32_m(pg, v4, half), svsub_f32_m(pg, v4, half)); - -// svint32_t result = svcvt_s32_f32_x(pg, v_rounded); -// svst1b_s32(pg, &y[i].qs[0], result); - -// svint32_t result_2 = svcvt_s32_f32_x(pg, v_rounded_2); -// svst1b_s32(pg, &y[i].qs[8], result_2); - -// svint32_t result_3 = svcvt_s32_f32_x(pg, v_rounded_3); -// svst1b_s32(pg, &y[i].qs[16], result_3); - -// svint32_t result_4 = svcvt_s32_f32_x(pg, v_rounded_4); -// svst1b_s32(pg, &y[i].qs[24], result_4); -// } - -#if defined(__ARM_NEON) +#elif defined(__ARM_NEON) for (int i = 0; i < nb; i++) { float32x4_t srcv [8]; float32x4_t asrcv[8];