ggml-hexagon: optimize flash attention calculations with improved variable handling

This commit is contained in:
chraac 2026-02-02 23:59:54 +08:00
parent b022069260
commit fe1f3fbc2a
1 changed file with 5 additions and 3 deletions

View File

@@ -409,6 +409,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
const HVX_Vector logit_cap = hvx_vec_splat_f32(logit_softcap);
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
@@ -468,7 +471,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
@@ -504,7 +507,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
// 2. Softcap
if (logit_softcap != 0.0f) {
scores = hvx_vec_tanh_f32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
scores = Q6_Vsf_equals_Vqf32(scores);
}
@@ -518,7 +521,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
scores = Q6_Vsf_equals_Vqf32(scores);