diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 15b40f77f5..b47a203c0d 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -359,7 +359,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16 } - const HVX_Vector slope_vec = hvx_vec_splat_f32(slope); + const HVX_Vector slope_vec = hvx_vec_splat_f16(slope); for (uint32_t ib = 0; ib < n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); @@ -398,13 +398,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in if (mask) { const __fp16 * mp = m_base + ic; HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; - - HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); - HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16); - - HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); - - HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec); + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores); scores = Q6_Vsf_equals_Vqf32(scores); }