wip
This commit is contained in:
parent
e307f68c3b
commit
274ade5ca2
|
|
@ -379,7 +379,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
|
||||
// 1. Compute scores
|
||||
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
|
||||
for (int j = 0; j < VLEN_FP32; j += 2) {
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic + j;
|
||||
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
|
||||
hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
|
||||
|
|
@ -432,7 +432,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
|
|||
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
|
||||
*(HVX_Vector *) p_arr = P;
|
||||
|
||||
for (int j = 0; j < VLEN_FP32; j += 2) {
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
|
||||
const uint32_t cur_ic = ic2 + j;
|
||||
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
|
||||
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + size_v_row_padded, p_arr[j], p_arr[j + 1], DV);
|
||||
|
|
|
|||
Loading…
Reference in New Issue