ggml-hexagon: optimize flash attention by changing slope vector type to F16

This commit is contained in:
chraac 2026-02-07 12:20:41 +08:00
parent 764423f774
commit e307f68c3b
1 changed files with 3 additions and 8 deletions

View File

@ -359,7 +359,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
}
const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
@ -398,13 +398,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
scores = Q6_Vsf_equals_Vqf32(scores);
}