This commit is contained in:
chraac 2026-02-01 22:16:50 +08:00
parent 080db98920
commit b022069260
1 changed files with 2 additions and 3 deletions

View File

@ -468,7 +468,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
}
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
const HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
@ -482,8 +482,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
// Inner loop processing the block from VTCM
uint32_t ic = 0;
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
// Process in blocks of 32 (VLEN_FP32)
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
HVX_Vector_x4 scores_x4;
@ -515,6 +513,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00);
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16);
HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair));