ggml-hexagon: optimize flash attention calculations with improved variable handling
This commit is contained in:
parent
b022069260
commit
fe1f3fbc2a
|
|
@ -409,6 +409,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
const HVX_Vector logit_cap = hvx_vec_splat_f32(logit_softcap);
|
||||
|
||||
for (uint32_t ir = ir0; ir < ir1; ++ir) {
|
||||
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
|
||||
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
|
||||
|
|
@ -468,7 +471,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
}
|
||||
|
||||
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
|
||||
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
|
||||
|
||||
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
|
||||
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
|
||||
|
|
@ -504,7 +507,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
// 2. Softcap
|
||||
if (logit_softcap != 0.0f) {
|
||||
scores = hvx_vec_tanh_f32(scores);
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap));
|
||||
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
}
|
||||
|
||||
|
|
@ -518,7 +521,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
|
||||
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));
|
||||
|
||||
HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
|
||||
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec);
|
||||
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
|
||||
scores = Q6_Vsf_equals_Vqf32(scores);
|
||||
|
|
|
|||
Loading…
Reference in New Issue