From 8710e5f9b9bd7246608808ccd3626bde8abf6ff9 Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Wed, 1 Apr 2026 21:13:08 +0530 Subject: [PATCH] hexagon: improve RMS_NORM and DIV accuracy (#21251) * hexagon-rms_norm: fix RMS_NORM for non-aligned tensor sizes Co-authored-by: Krishna Sridhar * hexagon-div: perform DIV in fp16 domain for lower dsp archs --------- Co-authored-by: Krishna Sridhar --- ggml/src/ggml-hexagon/htp/hvx-div.h | 86 ++++++++++++++++++++------- ggml/src/ggml-hexagon/htp/unary-ops.c | 41 ++++++++++--- 2 files changed, 97 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h index 05cefea039..53ee304e74 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-div.h +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -16,8 +16,10 @@ #if __HVX_ARCH__ < 79 #define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)) #else #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b) #endif // Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32. 
@@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX return res; } -#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \ - do { \ - dst_type * restrict vdst = (dst_type *) dst; \ - src_type * restrict vsrc = (src_type *) src; \ - HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ - \ - const uint32_t nvec = n / VLEN_FP16; \ - const uint32_t nloe = n % VLEN_FP16; \ - \ - uint32_t i = 0; \ - \ - _Pragma("unroll(4)") \ - for (; i < nvec; i++) { \ - HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ - vdst[i] = res; \ - } \ - if (nloe) { \ - HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ - vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ - } \ +// Variant for >= v79 +static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + // For older architectures, use f16 reciprocal to avoid NaN/-inf issues + HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask); + return HVX_OP_MUL_F16(vec1, vec2_inv); +#else + return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0); +#endif +} + #define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src0_type * restrict vsrc0 = (src0_type *) src0; \ src1_type * restrict vsrc1 = (src1_type *) src1; \ \ - const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \ const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ \ const uint32_t nvec = n / VLEN_FP16; \ @@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v \ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ - HVX_Vector res = 
hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ vdst[i] = res; \ } \ if (nloe) { \ - HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ } \ } while(0) @@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32) HVX_DIV_DISPATCHER(hvx_div_f16) #undef HVX_OP_MUL_F32 +#undef HVX_OP_MUL_F16 #endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 3d0928d4dc..13d28317d5 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -67,34 +67,61 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, uint8_t * restrict pad, const int num_elems, float epsilon) { + (void)pad; + const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; - HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares for full vectors + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); - int step_of_1 = num_elems >> 5; #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); - sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = 
Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Reduce HVX sum + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + // Scale full vectors HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); - v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v2); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); } }