From 89f10baad5a1809055d71110dff60e55561b9c62 Mon Sep 17 00:00:00 2001 From: nullname Date: Sat, 31 Jan 2026 13:14:20 +0800 Subject: [PATCH] ggml-hexagon: flash-attention and reduce-sum optimizations (#19141) * wip * ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation * ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations * wip * ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance * ggml-hexagon: refactor dot product functions to use a common loading function for improved readability * optimize vector dot product functions to use unified reduction for improved performance * wip * ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation * ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations * wip * ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance * ggml-hexagon: refactor dot product functions to use a common loading function for improved readability * optimize vector dot product functions to use unified reduction for improved performance * hexagon: optimize reduce-sum for v75+ * hexagon: always keep row_sums in sf/fp32 * ggml-hexagon: enhance directory checks for HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT * fix compiling error after rebase --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/CMakeLists.txt | 16 ++- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 159 ++++++++++++++++++--- ggml/src/ggml-hexagon/htp/hvx-dump.h | 9 +- ggml/src/ggml-hexagon/htp/hvx-reduce.h | 41 ++++++ ggml/src/ggml-hexagon/htp/matmul-ops.c | 107 +++++++------- ggml/src/ggml-hexagon/htp/softmax-ops.c | 4 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 4 +- 7 files changed, 248 insertions(+), 92 deletions(-) diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index 2b69197017..f3a583543c 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -1,8 +1,20 @@ file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) -if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") - message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.") +if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}") + message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.") +endif() + +if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") + message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json") + file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH) + string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path") + message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}") + set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}") + file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) + if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") + message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.") + endif() endif() message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels") diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index c7cb2a4e0b..c184637443 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,6 +17,12 @@ #include "htp-msg.h" #include "htp-ops.h" +static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) { + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements + return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); +} + // Dot product of FP32 and FP16 vectors, accumulating to float static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 @@ -33,23 +39,19 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict #pragma unroll(4) for (i = 0; i < nvec; i++) { // Load y (fp32) and convert into fp16 - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements - HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); // Load x (fp16) HVX_Vector x_hf = vx[i]; HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); } if (nloe) { // Load y (fp32) and convert into fp16 - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements - HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); // Load x (fp16) HVX_Vector x_hf = vx[i]; @@ -62,13 +64,72 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); } - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); + hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); +} - hvx_vec_store_u(r, 4, rsum); +// Dot product of FP32 and FP16 vectors, accumulating to float +static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r, + const void * restrict y, + const void * restrict x0, + const void * restrict x1, + unsigned int n, + float s) { + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + // Load x (fp16) + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + + // Load x (fp16) + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x0_hf = Q6_V_vand_QV(bmask, x0_hf); + x1_hf = Q6_V_vand_QV(bmask, x1_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); + hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } // Dot product of two F16 vectors, accumulating to float @@ -91,7 +152,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); } if (nloe) { @@ -103,12 +164,62 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); } - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); - hvx_vec_store_u(r, 4, rsum); + rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); + hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); +} + +static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, + const void * restrict y, + const void * restrict x0, + const void * restrict x1, + unsigned int n, + float s) { + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + if (nloe) { + HVX_Vector y_hf = vy[i]; + + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); + hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } // MAD: y (F32) += x (F16) * s (float) @@ -317,20 +428,22 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // Inner loop processing the block from VTCM uint32_t ic = 0; + const bool is_q_fp32 = (q->type == HTP_TYPE_F32); + // Process in blocks of 32 (VLEN_FP32) - static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); + static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); HVX_Vector_x4 scores_x4; HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores - float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE]; - for (int j = 0; j < VLEN_FP32; ++j) { + float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + for (int j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic + j; const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; - if (q->type == HTP_TYPE_F32) { - hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + if (is_q_fp32) { + hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); } else { - hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); } } @@ -403,7 +516,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in float s_val; const uint8_t * k_ptr = k_base + ic * size_k_row_padded; - if (q->type == HTP_TYPE_F32) { + if (is_q_fp32) { hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); } else { hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); diff --git a/ggml/src/ggml-hexagon/htp/hvx-dump.h b/ggml/src/ggml-hexagon/htp/hvx-dump.h index e882227893..85201fc345 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-dump.h +++ b/ggml/src/ggml-hexagon/htp/hvx-dump.h @@ -28,19 +28,16 @@ static void hvx_vec_dump_f16(char * pref, HVX_Vector v) { } static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) { - union { - HVX_Vector v; - float d[32]; - } u = { .v = v }; + HVX_VectorAlias u = { .v = v }; const uint32_t n0 = n / 16; const uint32_t n1 = n % 16; int i = 0; for (; i < n0; i++) { - hex_dump_f32_line(pref, u.d + (16 * i), 16); + hex_dump_f32_line(pref, u.fp32 + (16 * i), 16); } if (n1) { - hex_dump_f32_line(pref, u.d + (16 * i), n1); + hex_dump_f32_line(pref, u.fp32 + (16 * i), n1); } } diff --git a/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/ggml/src/ggml-hexagon/htp/hvx-reduce.h index 8845fe73ea..1ca7c05d98 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-reduce.h +++ b/ggml/src/ggml-hexagon/htp/hvx-reduce.h @@ -44,6 +44,45 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) { return hvx_vec_reduce_sum_n_qf32(in, 32); } +#if __HVX_ARCH__ > 75 + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16)); + return sum_sf; +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum + width = width << 1; + } + return sum; +} + +#else + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16)); + return Q6_Vsf_equals_Vqf32(sum_qf); +} + static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { unsigned int total = n * 4; // total vec nbytes unsigned int width = 4; // fp32 nbytes @@ -57,6 +96,8 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) return sum; } +#endif + static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) { return hvx_vec_reduce_sum_n_f32(in, 32); } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 1603ff2b3b..d251eeed33 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -11,6 +11,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hvx-dump.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -320,7 +321,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. @@ -344,7 +345,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -362,14 +363,14 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * // Zero out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -402,7 +403,7 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r1_sum = Q6_V_vsplat_R(0); @@ -432,8 +433,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -456,20 +457,18 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -493,7 +492,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. @@ -517,7 +516,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -535,14 +534,14 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * // Zero out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -605,8 +604,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -629,20 +628,18 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_mxfp4x4x2_q8x4x2(const int n, @@ -669,7 +666,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. @@ -708,7 +705,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers @@ -741,14 +738,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, // Zero-out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -781,13 +778,13 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r1_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + // Apply scale to acc and accumulate into the row sum (f32). const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) @@ -829,8 +826,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers @@ -867,24 +864,22 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - // Zero-out unused scales + // Zero-out unused values HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -913,7 +908,7 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } @@ -957,11 +952,8 @@ static void vec_dot_f16_f16_aa_rx2(const int n, rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); } - rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum1)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -990,7 +982,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } @@ -1042,7 +1034,8 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + // Convert into fp32 and reduce + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 1b6b2eba4a..e91a16d947 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -154,8 +154,8 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, v_pad[i] = v3; } - v = hvx_vec_reduce_sum_qf32(sum_vec); - sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v)); + v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); + sum_vec = hvx_vec_repl4(v); HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index be8be8c4e6..1a27cb6e63 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -57,8 +57,8 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v); - sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum)); + HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + sum_v = hvx_vec_repl4(reduced_sum); HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);