ggml-hexagon: streamline flash attention operations by removing redundant checks for FP32
parent fe1f3fbc2a
commit c2fe8a12bb
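In short: instead of branching on is_q_fp32 inside every score computation, the Q row is converted from f32 to f16 once, in place, right after it is popped from the DMA queue (hvx_copy_f16_f32_aa), so both dot-product loops can call the f16 kernels (hvx_dot_f16_f16_aa / hvx_dot_f16_f16_aa_rx2) unconditionally. The sketch below shows the same hoist-the-conversion pattern in plain portable C; it is illustrative only, not the HTP/HVX code, and the names (compute_scores, dot_f32, q_scratch) plus the double/float stand-ins for f32/f16 are made up for the example.

#include <stddef.h>

/* Single dot-product path over the normalized element type. */
static float dot_f32(const float * a, const float * b, size_t n) {
    float acc = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        acc += a[i] * b[i];
    }
    return acc;
}

/* Normalize Q once, before the score loop, instead of branching on its
 * format inside every iteration (double stands in for f32, float for f16). */
static void compute_scores(float * scores, const void * q, int q_is_wide,
                           const float * k, size_t n_rows, size_t dk,
                           float * q_scratch /* dk elements */) {
    if (q_is_wide) {
        const double * qw = (const double *) q;
        for (size_t i = 0; i < dk; ++i) {
            q_scratch[i] = (float) qw[i];   /* one conversion, outside the hot loop */
        }
    } else {
        const float * qn = (const float *) q;
        for (size_t i = 0; i < dk; ++i) {
            q_scratch[i] = qn[i];
        }
    }
    for (size_t r = 0; r < n_rows; ++r) {
        scores[r] = dot_f32(q_scratch, k + r * dk, dk);   /* no per-row format check */
    }
}

The hunks below apply the same idea with the existing HVX helpers.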
@@ -470,9 +470,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
             }
         }

-        const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
-        const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
-
+        uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
+        if (is_q_fp32) {
+            hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16
+        }
+
+        const HVX_Vector slope_vec = hvx_vec_splat_f32(slope);
         for (uint32_t ib = 0; ib < n_blocks; ++ib) {
             const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
             const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
@@ -495,11 +498,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
             for (int j = 0; j < VLEN_FP32; j += 2) {
                 const uint32_t cur_ic = ic + j;
                 const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
-                if (is_q_fp32) {
-                    hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
-                } else {
-                    hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
-                }
+                hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
             }

             HVX_Vector scores = *(HVX_Vector *) scores_arr;
@@ -569,13 +568,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
             for (; ic < current_block_size; ++ic) {
                 float s_val;
                 const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
-
-                if (is_q_fp32) {
-                    hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
-                } else {
-                    hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
-                }
-
+                hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
                 if (logit_softcap != 0.0f) {
                     s_val = logit_softcap * tanhf(s_val);
                 }