ggml-hexagon: optimize flash attention by changing slope vector type to F16
This commit is contained in:
parent
764423f774
commit
e307f68c3b
|
|
@ -359,7 +359,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
|
||||
}
|
||||
|
||||
const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
|
||||
const HVX_Vector slope_vec = hvx_vec_splat_f16(slope);
|
||||
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
|
||||
|
|
@ -398,13 +398,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
if (mask) {
|
||||
const __fp16 * mp = m_base + ic;
|
||||
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
|
||||
|
||||
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
|
||||
|
||||
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
|
||||
|
||||
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
|
||||
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
|
||||
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
|
||||
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue