diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index af32c1e2c7..fc7ecb9041 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,121 +17,6 @@ #include "htp-msg.h" #include "htp-ops.h" -static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) { - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements - return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x_hf = Q6_V_vand_QV(bmask, x_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); - hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r, - const void * restrict y, - const void * restrict x0, - const void * restrict x1, - unsigned int n, - float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 - const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(2) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x0_hf = Q6_V_vand_QV(bmask, x0_hf); - x1_hf = Q6_V_vand_QV(bmask, x1_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); - hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); -} - // Dot product of two F16 vectors, accumulating to float static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16