From 8af1f5f430baaab1719db8f0a259bcc2a1cfdaa0 Mon Sep 17 00:00:00 2001 From: nullname Date: Sat, 24 Jan 2026 14:02:07 +0800 Subject: [PATCH] ggml-hexagon: flash-attn opt (#19025) * optimize flash attention kernel by improving score computation and online softmax update * wip * Refactor online softmax update in flash attention kernel for improved performance * Optimize flash attention kernel by replacing float array with HVX_Vector for score computation * wip --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 54 +++++++++++++--------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 1de47d0f3d..c7cb2a4e0b 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -2,9 +2,9 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" +#include #include #include - #include #include @@ -111,7 +111,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict hvx_vec_store_u(r, 4, rsum); } -// MAD: y (F32) += x (F16) * v (float) +// MAD: y (F32) += x (F16) * s (float) static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; HVX_Vector * restrict ptr_y = (HVX_Vector *) y; @@ -318,9 +318,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in uint32_t ic = 0; // Process in blocks of 32 (VLEN_FP32) - for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) { + static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); + HVX_Vector_x4 scores_x4; + HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); + for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores - float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE]; for (int j = 0; j < VLEN_FP32; ++j) { const uint32_t cur_ic = ic + j; const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; @@ -356,36 +359,43 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in scores = Q6_Vsf_equals_Vqf32(scores); } - // 4. Online Softmax Update - HVX_Vector v_max = hvx_vec_reduce_max_f32(scores); - float m_block = hvx_vec_get_f32(v_max); + scores_x4.v[iv] = scores; + v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max); + } + { + // 4. Online Softmax Update + v_max = hvx_vec_reduce_max_f32(v_max); + float m_block = hvx_vec_get_f32(v_max); float M_old = M; float M_new = (m_block > M) ? m_block : M; M = M_new; - float ms = expf(M_old - M_new); - + const float ms = expf(M_old - M_new); hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); - S = S * ms; HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new); - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); - HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); + HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); + for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { + HVX_Vector scores = scores_x4.v[iv]; + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); - HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P); - float p_sum = hvx_vec_get_f32(p_sum_vec); - S += p_sum; + p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); - // 5. Accumulate V - float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector*)p_arr = P; + // 5. Accumulate V + float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; + *(HVX_Vector*)p_arr = P; - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic + j; - const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic2 + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + } } + + p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); + S = S * ms + hvx_vec_get_f32(p_sum_vec); } // Leftover