ggml-hexagon: streamline flash attention operations by removing redundant checks for FP32

This commit is contained in:
chraac 2026-02-03 00:14:51 +08:00
parent fe1f3fbc2a
commit c2fe8a12bb
1 changed file with 7 additions and 14 deletions

View File

@@ -470,9 +470,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
}
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
if (is_q_fp32) {
hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
}
const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
@@ -495,11 +498,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
for (int j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
} else {
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
}
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
}
HVX_Vector scores = *(HVX_Vector *) scores_arr;
@@ -569,13 +568,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
if (is_q_fp32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
}
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
if (logit_softcap != 0.0f) {
s_val = logit_softcap * tanhf(s_val);
}