From ea450209d05acc72f1adc61b4478803f234bfed2 Mon Sep 17 00:00:00 2001 From: chraac Date: Sat, 20 Dec 2025 23:06:02 +0800 Subject: [PATCH] opt: use qf32 internal precision for vec_dot_f16_f32 --- ggml/src/ggml-hexagon/htp/matmul-ops.c | 38 ++++++++++++++------------ 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 6e2d46b112..bf0edd7fa9 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -915,17 +915,20 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri uint32_t nv0 = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nv1 = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_Vh_vsplat_R(0x3C00); // 1.0 in fp16 + const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in fp16 + const HVX_Vector zero = Q6_V_vsplat_R(0); // 0.0 in fp16 HVX_Vector rsum = Q6_V_vsplat_R(0); uint32_t i = 0; for (i = 0; i < nv0; i++) { - HVX_VectorPair yp = vy[i]; - HVX_Vector x = vx[i]; - HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), zero); // mul by 1.0 + HVX_VectorPair yp = vy[i]; + HVX_Vector x = vx[i]; + HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), one); // mul by 1.0 + HVX_Vector y_hi = Q6_Vqf32_vadd_VsfVsf(Q6_V_hi_W(yp), zero); // convert to qf32 + HVX_Vector y_lo = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(yp), zero); // convert to qf32 - HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); + HVX_Vector hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(xp), y_hi); + HVX_Vector lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(xp), y_lo); HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo); rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); @@ -934,24 +937,25 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri if (nv1) { HVX_VectorPair yp = vy[i]; HVX_Vector x = vx[i]; - HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), zero); // mul by 1.0 + HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), one); // mul by 1.0 - HVX_Vector l_x; - HVX_Vector l_y; + HVX_Vector leftover_x; + HVX_Vector leftover_y; if (nv1 >= VLEN_FP32) { - HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); + HVX_Vector y_lo = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(yp), zero); // convert to qf32 + HVX_Vector lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(xp), y_lo); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, lo); nv1 -= VLEN_FP32; - l_x = Q6_V_hi_W(xp); - l_y = Q6_V_hi_W(yp); + leftover_x = Q6_V_hi_W(xp); + leftover_y = Q6_Vqf32_vadd_VsfVsf(Q6_V_hi_W(yp), zero); // convert to qf32 } else { - l_x = Q6_V_lo_W(xp); - l_y = Q6_V_lo_W(yp); + leftover_x = Q6_V_lo_W(xp); + leftover_y = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(yp), zero); // convert to qf32 } if (nv1) { - HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(l_x), l_y); - HVX_Vector sum = Q6_V_valign_VVR(lo, Q6_V_vzero(), nv1 * sizeof(float)); + HVX_Vector lo = Q6_Vqf32_vmpy_Vqf32Vqf32(leftover_x, leftover_y); + HVX_Vector sum = Q6_V_valign_VVR(lo, zero, nv1 * sizeof(float)); rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); }