From fe1f3fbc2a1e04f2afca832c5f994e370cccbccc Mon Sep 17 00:00:00 2001 From: chraac Date: Mon, 2 Feb 2026 23:59:54 +0800 Subject: [PATCH] ggml-hexagon: optimize flash attention calculations with improved variable handling --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index a3797bdc1a..441f619698 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -409,6 +409,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const bool is_q_fp32 = (q->type == HTP_TYPE_F32); + const HVX_Vector logit_cap = hvx_vec_splat_f32(logit_softcap); + for (uint32_t ir = ir0; ir < ir1; ++ir) { const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); @@ -468,7 +471,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; - const bool is_q_fp32 = (q->type == HTP_TYPE_F32); + const HVX_Vector slope_vec = hvx_vec_splat_f32(slope); for (uint32_t ib = 0; ib < n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; @@ -504,7 +507,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // 2. Softcap if (logit_softcap != 0.0f) { scores = hvx_vec_tanh_f32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap)); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap); scores = Q6_Vsf_equals_Vqf32(scores); } @@ -518,7 +521,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); - HVX_Vector slope_vec = hvx_vec_splat_f32(slope); HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec); scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); scores = Q6_Vsf_equals_Vqf32(scores);