diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c
index 441f619698..af32c1e2c7 100644
--- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c
+++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c
@@ -470,9 +470,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
         }
     }
 
-    const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
-    const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
+    uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
+    if (is_q_fp32) {
+        hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
+    }
+    const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
 
     for (uint32_t ib = 0; ib < n_blocks; ++ib) {
         const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
         const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
@@ -495,11 +498,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
             for (int j = 0; j < VLEN_FP32; j += 2) {
                 const uint32_t cur_ic = ic + j;
                 const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
-                if (is_q_fp32) {
-                    hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
-                } else {
-                    hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
-                }
+                hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
             }
 
             HVX_Vector scores = *(HVX_Vector *) scores_arr;
@@ -569,13 +568,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
         for (; ic < current_block_size; ++ic) {
            float s_val;
            const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
-
-            if (is_q_fp32) {
-                hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
-            } else {
-                hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
-            }
-
+            hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
             if (logit_softcap != 0.0f) {
                 s_val = logit_softcap * tanhf(s_val);
             }
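
A minimal scalar sketch of the idea behind this change, for context. The helpers below are hypothetical stand-ins, not the real HVX routines: hvx_copy_f16_f32_aa and hvx_dot_f16_f16_aa operate on aligned vectors in VTCM, while this models the same data flow in plain C (assuming a compiler with _Float16 support, e.g. recent Clang or GCC). The point of the patch is that q is narrowed from f32 to f16 once, in place, right after the DMA pop, so the per-score is_q_fp32 branch disappears and every dot product takes the single f16 x f16 path.

/* Scalar model of the fp32 -> f16 in-place narrowing and the unified
 * f16 x f16 score path. Names are illustrative stand-ins only. */
#include <stdint.h>
#include <stdio.h>

/* Stand-in for hvx_copy_f16_f32_aa(dst, src, n) with dst == src. Narrowing
 * front-to-back is safe in place: each 2-byte f16 output lands at or below
 * the offset of the 4-byte f32 input it was read from. */
static void copy_f16_f32_inplace(uint8_t * buf, uint32_t n) {
    const float * src = (const float *) buf;
    _Float16    * dst = (_Float16 *) buf;
    for (uint32_t i = 0; i < n; ++i) {
        dst[i] = (_Float16) src[i];
    }
}

/* Stand-in for hvx_dot_f16_f16_aa(out, q, k, n, scale):
 * *out = dot(q, k) * scale, with q and k both holding f16 values. */
static void dot_f16_f16(float * out, const uint8_t * q, const uint8_t * k,
                        uint32_t n, float scale) {
    const _Float16 * qh = (const _Float16 *) q;
    const _Float16 * kh = (const _Float16 *) k;
    float acc = 0.0f;
    for (uint32_t i = 0; i < n; ++i) {
        acc += (float) qh[i] * (float) kh[i];
    }
    *out = acc * scale;
}

int main(void) {
    enum { DK = 8 };
    float    q[DK];
    _Float16 k[DK];
    for (int i = 0; i < DK; ++i) {
        q[i] = 0.25f * (float) (i + 1);
        k[i] = (_Float16) (1.0f / (float) (i + 1));
    }

    /* before: each score branched on is_q_fp32 to pick the f32 x f16 or
     * f16 x f16 kernel; after: convert q once, then always take one path */
    copy_f16_f32_inplace((uint8_t *) q, DK);

    float s_val;
    dot_f16_f16(&s_val, (const uint8_t *) q, (const uint8_t *) k, DK, 1.0f);
    printf("score = %f\n", s_val); /* ~2.0: each term is ~0.25 */
    return 0;
}

The trade-off modeled here: one extra O(DK) pass over q per DMA pop buys the removal of a branch and of the heavier f32 x f16 kernel from loops that run once per K row, which should win whenever nek1 covers more than a handful of rows.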