From 7a99dc85e2d26a9b9c540cad887322eae8924f03 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 4 Mar 2026 21:55:29 -0800 Subject: [PATCH] hexagon: Flash Attention optimizations (dma, mpyacc, multi-row) and MatMul updates (#20118) * ggml-hexagon: enhance hvx_dot_f16_f16_aa_rx4 for improved performance by expanding vector handling and optimizing accumulation # Conflicts: # ggml/src/ggml-hexagon/htp/flash-attn-ops.c * ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx4 and enhance hvx_vec_reduce_sum_f32x4 for improved performance and reduced complexity * ggml-hexagon: add hvx_dot_f16_f16_aa_rx32 for enhanced vector processing in flash attention # Conflicts: # ggml/src/ggml-hexagon/htp/flash-attn-ops.c * optimize hvx_dot_f16_f16_aa_rx4 and hvx_dot_f16_f16_aa_rx32 by removing unused scale parameter and improving vector accumulation # Conflicts: # ggml/src/ggml-hexagon/htp/flash-attn-ops.c * ggml-hexagon: refactor hvx_dot_f16_f16_aa_rx4 for improved readability and return HVX_Vector for better integration # Conflicts: # ggml/src/ggml-hexagon/htp/flash-attn-ops.c * ggml-hexagon: initialize sums variable in hvx_dot_f16_f16_aa_rx32 for clarity * ggml-hexagon: fix compiling error * fix hvx_dot_f16_f16_aa_rx4 to handle leftover elements correctly using masking * refactor hvx_dot_f16_f16_aa_rx4 to accept vector and leftover element counts as parameters for improved clarity and flexibility * wip * fa: instrumentation and dma reordering * hex-fa: use block-size 64 to improve DMA pipelining * hex-fa: optimize vec-dot for v79 and above * hex-fa: use block size 64 * hex-fa: avoid scalar fp32->fp16 conversions * hex-fa: simplify dot_f16 functions using optimized vec_mpyacc * hex-fa: rewrite mad_f32_f16 using hvx_vec_mpyacc * hex-mm: use mpyacc in matmul dot functions --------- Co-authored-by: chraac --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 371 ++++++++++++--------- ggml/src/ggml-hexagon/htp/hvx-base.h | 21 +- ggml/src/ggml-hexagon/htp/hvx-copy.h | 4 +- ggml/src/ggml-hexagon/htp/hvx-reduce.h | 30 ++ ggml/src/ggml-hexagon/htp/matmul-ops.c | 88 ++--- 5 files changed, 293 insertions(+), 221 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 74c777d4c3..6dc978dd68 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -10,6 +10,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hvx-dump.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -17,6 +18,16 @@ #include "htp-msg.h" #include "htp-ops.h" +// Must be multiple of 32 +#define FLASH_ATTN_BLOCK_SIZE (32 * 2) + +// This is a bit of a hack because the compiler is strugling to properly inline +// the default hvx_vec_f32_to_f16 with output into the local array. +static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) +{ + *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1); +} + // Dot product of two F16 vectors, accumulating to float static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 @@ -25,175 +36,184 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(4) for (i = 0; i < nvec; i++) { - HVX_Vector y_hf = vy[i]; - HVX_Vector x_hf = vx[i]; - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]); } if (nloe) { - // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); - hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); + HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum))); + hvx_vec_store_u(r, 4, rsum); } -static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, - const void * restrict y, - const void * restrict x0, - const void * restrict x1, - unsigned int n, - float s) { - const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 - const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 +static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y, + const uint8_t * restrict x, + const size_t stride_x, + const size_t nvec, + const size_t nloe) { + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16 + const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16 + const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); + HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; - #pragma unroll(4) for (i = 0; i < nvec; i++) { HVX_Vector y_hf = vy[i]; HVX_Vector x0_hf = vx0[i]; HVX_Vector x1_hf = vx1[i]; + HVX_Vector x2_hf = vx2[i]; + HVX_Vector x3_hf = vx3[i]; - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); + rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf); + rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); } if (nloe) { // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); - HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); - HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]); + HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]); - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); + rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf); + rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); } - HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); - hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); + HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); + HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); + HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p))); + HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p))); + + HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } }; + return hvx_vec_reduce_sum_f32x4(rsum0123); } -// MAD: y (F32) += x (F16) * s (F32) -static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { - const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; - HVX_Vector * restrict ptr_y = (HVX_Vector *) y; +static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, + const uint8_t * restrict x, + const size_t stride_x, + const size_t n, + float s) { + + const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + const size_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector sums; // initialize at j = 0 + const size_t stride_x_4 = stride_x * 4; + for (uint32_t j = 0; j < VLEN_FP32; j += 4) { + HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe); + HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32); + sums = Q6_V_vmux_QVV(pred, sums, sums_x4); + x += stride_x_4; + } + + sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums); + return Q6_Vsf_equals_Vqf32(sums); +} + +// MAD: y (F32) += x (F16) * s (F16) +static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) { + const HVX_Vector * restrict vx0 = (const HVX_Vector *) x; + + HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y; + HVX_Vector * restrict vy = (HVX_Vector *) y; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector S = hvx_vec_splat_f16(s); + HVX_Vector S0 = hvx_vec_splat_f16(*s); uint32_t i = 0; - #pragma unroll(4) + + #pragma unroll(2) for (i = 0; i < nvec; ++i) { - // Multiply x * s -> pair of F32 vectors - HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); - ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2])); - ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1])); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0); } if (nloe) { - HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + HVX_VectorPair xy_p = vy_p[i]; + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0); - HVX_Vector xs = Q6_V_lo_W(xs_p); - i = 2 * i; // index for ptr_y + HVX_Vector xy = Q6_V_lo_W(xy_p); + i = 2 * i; // index for vy - if (nloe >= 32) { - ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p); + if (nloe >= VLEN_FP32) { + vy[i] = xy; + nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p); } if (nloe) { - HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + hvx_vec_store_a(&vy[i], nloe * 4, xy); } } } -// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32) -static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, - const void * restrict x0, - const void * restrict x1, - float s0, - float s1, - int n) { - const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0; - const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1; - HVX_Vector * restrict ptr_y = (HVX_Vector *) y; +// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16) +static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1, + const __fp16 * restrict s0, const __fp16 * restrict s1, int n) { + const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0; + const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1; + + HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y; + HVX_Vector * restrict vy = (HVX_Vector *) y; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector S0 = hvx_vec_splat_f16(s0); - HVX_Vector S1 = hvx_vec_splat_f16(s1); + HVX_Vector S0 = hvx_vec_splat_f16(*s0); + HVX_Vector S1 = hvx_vec_splat_f16(*s1); uint32_t i = 0; + #pragma unroll(2) for (i = 0; i < nvec; ++i) { - // Multiply x * s -> pair of F32 vectors - HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); - HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); - - HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); - HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); - - ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2])); - ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1])); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1); } if (nloe) { - HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); - HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + HVX_VectorPair xy_p = vy_p[i]; + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0); + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1); - HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); - HVX_Vector xs = xs_p_lo; - i = 2 * i; // index for ptr_y + HVX_Vector xy = Q6_V_lo_W(xy_p); + i = 2 * i; // index for vy - if (nloe >= 32) { - ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - nloe -= 32; ++i; - xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + if (nloe >= VLEN_FP32) { + vy[i] = xy; + nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p); } if (nloe) { - HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + hvx_vec_store_a(&vy[i], nloe * 4, xy); } } } -#define FLASH_ATTN_BLOCK_SIZE 128 - struct htp_fa_context { const struct htp_ops_context * octx; @@ -226,7 +246,12 @@ struct htp_fa_context { size_t size_v_block; size_t size_m_block; + uint32_t qrows; + uint32_t qrows_per_thread; + bool is_q_fp32; + + uint64_t t_start; }; static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) { @@ -296,9 +321,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const uint32_t nb3 = dst->nb[3]; // total rows in q - const uint32_t nr = neq1*neq2*neq3; - - const uint32_t dr = (nr + nth - 1) / nth; + const uint32_t nr = factx->qrows; + const uint32_t dr = factx->qrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = MIN(ir0 + dr, nr); @@ -337,15 +361,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1); - const uint32_t h = iq2; // head index - const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; - - HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); - HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); - - // Clear accumulator - hvx_splat_f32_a(spad_a, 0, DV); - float * VKQ32 = (float *) spad_a; + // FARF(HIGH, "fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); const __fp16 * mp_base = NULL; if (mask) { @@ -376,8 +393,23 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask is 1D contiguous for this row dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } + + // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", + // ith, ir, ib, iq1, iq2, iq3, + // size_k_row, size_v_row, current_block_size, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); } + const uint32_t h = iq2; // head index + const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; + + HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); + HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); + + // Clear accumulator + hvx_splat_f32_a(spad_a, 0, DV); + float * VKQ32 = (float *) (spad_a + 0); + uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; if (factx->is_q_fp32) { hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16 @@ -393,23 +425,19 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * uint8_t * v_base = dma_queue_pop(dma).dst; // V __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M + // FARF(HIGH, "fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u", + // ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); + // Inner loop processing the block from VTCM uint32_t ic = 0; - // Process in blocks of 32 (VLEN_FP32) - static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); - HVX_Vector_x4 scores_x4; + // Process in sub-blocks of 32 (VLEN_FP32) + HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32]; HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores - float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; - for (uint32_t j = 0; j < VLEN_FP32; j += 2) { - const uint32_t cur_ic = ic + j; - const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded; - hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale); - } - - HVX_Vector scores = *(HVX_Vector *) scores_arr; + HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale); // 2. Softcap if (factx->logit_softcap != 0.0f) { @@ -428,35 +456,35 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * scores = Q6_Vsf_equals_Vqf32(scores); } - scores_x4.v[iv] = scores; + sb_scores[iv] = scores; v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max } { // 4. Online Softmax Update HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); - HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec); - HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec)); + HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec)); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); M_vec = M_new_vec; hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { - HVX_Vector scores = scores_x4.v[iv]; + HVX_Vector scores = sb_scores[iv]; HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec); HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); // 5. Accumulate V - float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector *) p_arr = P; + __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16]; + hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0)); for (uint32_t j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic2 + j; const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; - hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV); + hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV); } } @@ -464,47 +492,50 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec)); } - // Sync scalars for leftover/next block if needed - float M = hvx_vec_get_f32(M_vec); - float S = hvx_vec_get_f32(S_vec); + if (ic < current_block_size) { + // Sync scalars for leftover/next block if needed + float M = hvx_vec_get_f32(M_vec); + float S = hvx_vec_get_f32(S_vec); - // Leftover - for (; ic < current_block_size; ++ic) { - float s_val; - const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); - if (factx->logit_softcap != 0.0f) { - s_val = factx->logit_softcap * tanhf(s_val); + // Leftover + for (; ic < current_block_size; ++ic) { + float s_val; + const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); + if (factx->logit_softcap != 0.0f) { + s_val = factx->logit_softcap * tanhf(s_val); + } + + if (mask) { + const float m_val = m_base[ic]; + s_val += slope * m_val; + } + + const float Mold = M; + __fp16 vs = 1.0f; + + if (s_val > M) { + M = s_val; + HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); + + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; + } else { + HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; + } + + const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; + + hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV); } - if (mask) { - const float m_val = m_base[ic]; - s_val += slope * m_val; - } - - const float Mold = M; - float vs = 1.0f; - - if (s_val > M) { - M = s_val; - HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); - HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); - hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - - float ms = hvx_vec_get_f32(ms_vec); - S = S * ms + vs; - } else { - HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); - vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); - S += vs; - } - - const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; - - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); + M_vec = hvx_vec_splat_f32(M); + S_vec = hvx_vec_splat_f32(S); } - M_vec = hvx_vec_splat_f32(M); - S_vec = hvx_vec_splat_f32(S); // Issue DMA for next+1 block (if exists) if (ib + 2 < factx->n_blocks) { @@ -525,6 +556,11 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); } + + // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", + // ith, ir, next_ib, iq1, iq2, iq3, + // size_k_row, size_v_row, next_block_size, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); } } @@ -586,6 +622,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { struct htp_fa_context factx; factx.octx = octx; + factx.t_start = HAP_perf_get_qtimer_count(); + factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); factx.src0_div1 = init_fastdiv_values(q->ne[1]); @@ -632,6 +670,15 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + // total rows in q + const uint32_t neq0 = q->ne[0]; + const uint32_t neq1 = q->ne[1]; + const uint32_t neq2 = q->ne[2]; + const uint32_t neq3 = q->ne[3]; + + factx.qrows = neq1*neq2*neq3; + factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads; + size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 octx->src0_spad.size_per_thread = size_q_block * 1; diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 12a1b7f128..701637f22b 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -38,7 +38,7 @@ static inline HVX_Vector hvx_vec_splat_f32(float v) { return Q6_V_vsplat_R(u.i); } -static inline HVX_Vector hvx_vec_splat_f16(float v) { +static inline HVX_Vector hvx_vec_splat_f16(_Float16 v) { union { __fp16 f; uint16_t i; } u = { .f = v }; return Q6_Vh_vsplat_R(u.i); } @@ -170,4 +170,23 @@ static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0); } +#if __HVX_ARCH__ < 79 + +static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y) +{ + HVX_VectorPair m = Q6_Wqf32_vmpy_VhfVhf(x, y); + HVX_Vector a0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(m), Q6_V_lo_W(acc))); + HVX_Vector a1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(m), Q6_V_hi_W(acc))); + return Q6_W_vcombine_VV(a1, a0); +} + +#else + +static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y) +{ + return Q6_Wsf_vmpyacc_WsfVhfVhf(acc, x, y); +} + +#endif + #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index ae0dbed030..851482e01b 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -42,11 +42,11 @@ static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f16_a(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } -static inline void hvx_splat_f16_u(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } diff --git a/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/ggml/src/ggml-hexagon/htp/hvx-reduce.h index 1ca7c05d98..3c0073ef6d 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-reduce.h +++ b/ggml/src/ggml-hexagon/htp/hvx-reduce.h @@ -46,6 +46,21 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) { #if __HVX_ARCH__ > 75 +static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) { + HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4); + HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4); + HVX_Vector sum_sf01 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01)); + HVX_Vector sum_sf23 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23)); + + HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(sum_sf23, sum_sf01, 8); + HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123)); + + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8)); + return sum_sf; +} + static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); @@ -72,6 +87,21 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) #else +static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) { + HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4); + HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4); + HVX_Vector sum_qf01 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01)); + HVX_Vector sum_qf23 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23)); + + HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(sum_qf23), Q6_Vsf_equals_Vqf32(sum_qf01), 8); + HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123)); + + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8)); + return Q6_Vsf_equals_Vqf32(sum_qf); +} + static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 6f6f51f01f..9ca74aedfe 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -1234,27 +1234,24 @@ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(4) for (i = 0; i < nvec; i++) { - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]); } if (nloe) { HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); - hvx_vec_store_u(&s[0], 4, rsum); + HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); + hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum)); } static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, @@ -1267,35 +1264,30 @@ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); + HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(2) for (i = 0; i < nvec; i++) { HVX_Vector y_hf = y[i]; - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf); - - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf); } if (nloe) { HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]); HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]); - HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); } - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); + HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); + HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); hvx_vec_store_u(s0, 8, rsum); } @@ -1311,10 +1303,10 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res uint32_t nloe = n % VLEN_FP16; // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + HVX_VectorPair r0_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r0_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r1_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r1_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; @@ -1326,20 +1318,10 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res HVX_Vector c1_hf = y1[i]; // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); - HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); - HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); - HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); - - HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); - HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); - HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); - HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf); + r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf); + r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf); + r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf); } if (nloe) { @@ -1350,23 +1332,17 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]); HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]); - HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); - HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); - HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); - HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); - - HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); - HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); - HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); - HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); - + r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf); + r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf); + r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf); + r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf); } + HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p))); + HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p))); + HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p))); + HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p))); + // Reduce and store results HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);