diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 0b8cac6f93..8ee4167adb 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -602,6 +602,7 @@ typedef struct { uint32_t kv_start; uint32_t q_start; uint32_t ib3; + bool has_alibi; // true when max_bias != 0 (need slope * mask + add) // ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g)) // slope[r] = 1.0 when max_bias == 0 (no ALiBi) @@ -690,13 +691,27 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { } } - // Apply mask (with ALiBi slope scaling) & compute rowmax(S) - // CPU reference: S[i] += slope * mask[i] - // slope == 1.0 when max_bias == 0 (no ALiBi) - const HVX_Vector v_slope0 = hvx_vec_splat_f16((__fp16) args->slopes[r + 0]); - const HVX_Vector v_slope1 = (r + 1 < (int) n_rows_g) ? - hvx_vec_splat_f16((__fp16) args->slopes[r + 1]) : - Q6_V_vzero(); + // Apply mask & compute rowmax(S) + // + // Optimizations over baseline: + // A. No-ALiBi fast path: when max_bias==0 (slope≡1.0), mask values are + // 0 (unmasked) or -inf (masked). S += 1.0*mask is a no-op for + // unmasked positions; masked ones are overwritten by vmux. So we + // skip the mul+add entirely and just do vmux — same as htp-ops-lib. + // B. GQA mask row dedup: G consecutive Q rows share one mask row + // (qi = r / G). Reuse mask vector when qi is unchanged between + // row0 and row1 (saves ~75% of VTCM loads for G=4). + + // ALiBi slopes — only needed when has_alibi (scheme A) + HVX_Vector v_slope0, v_slope1; + if (args->has_alibi) { + v_slope0 = hvx_vec_splat_f16((__fp16) args->slopes[r + 0]); + v_slope1 = (r + 1 < (int) n_rows_g) ? + hvx_vec_splat_f16((__fp16) args->slopes[r + 1]) : + Q6_V_vzero(); + } + + const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c) HVX_Vector v_s_rowmax0 = v_neg_inf; HVX_Vector v_s_rowmax1 = v_neg_inf; @@ -710,18 +725,22 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { if (args->mask_vtcm) { // Read mask from VTCM buffer (DMA'd per KV block). - // GQA mapping: G consecutive Q rows share the same mask row. + // GQA dedup (scheme B): skip load when qi unchanged. const size_t qi0 = (r + 0) / G; v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { const size_t qi1 = (r + 1) / G; - v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c); + if (qi1 == qi0) { + v_mask1 = v_mask0; // scheme B: reuse — same mask row + } else { + v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c); + } } } else { // Fallback: read mask directly from DDR (when mask->ne[2] > 1). const struct htp_tensor * mask = args->mask; - const size_t q_idx0 = args->q_start + (r + 0) / G; + const size_t q_idx0 = args->q_start + ((r + 0) / G); const size_t h_idx0 = args->kv_head * G + (r + 0) % G; const uint32_t im2_0 = h_idx0 % mask->ne[2]; const uint32_t im3_0 = args->ib3 % mask->ne[3]; @@ -733,35 +752,45 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t q_idx1 = args->q_start + (r + 1) / G; - const size_t h_idx1 = args->kv_head * G + (r + 1) % G; - const uint32_t im2_1 = h_idx1 % mask->ne[2]; - const uint32_t im3_1 = args->ib3 % mask->ne[3]; - const __fp16 * m1_ptr = - (const __fp16 *) ((const uint8_t *) mask->data + - q_idx1 * mask->nb[1] + im2_1 * mask->nb[2] + - im3_1 * mask->nb[3]) + - args->kv_start + c; - v_mask1 = *(const HVX_UVector *) m1_ptr; + const size_t q_idx1 = args->q_start + ((r + 1) / G); + if (q_idx1 == q_idx0) { + // scheme B: same mask row in DDR path + v_mask1 = v_mask0; + } else { + const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const uint32_t im2_1 = h_idx1 % mask->ne[2]; + const uint32_t im3_1 = args->ib3 % mask->ne[3]; + const __fp16 * m1_ptr = + (const __fp16 *) ((const uint8_t *) mask->data + + q_idx1 * mask->nb[1] + im2_1 * mask->nb[2] + + im3_1 * mask->nb[3]) + + args->kv_start + c; + v_mask1 = *(const HVX_UVector *) m1_ptr; + } } } - // Apply slope * mask: threshold first, then add to S - // Mask values below -16.0 (0xcc00) are treated as -inf (causal mask). - // For non-inf mask values, scale by ALiBi slope and add to S. - const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); - HVX_VectorPred q_keep0 = + // Threshold: mask values below -16.0 are treated as -inf (causal mask). + HVX_VectorPred q_keep0 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep); - HVX_VectorPred q_keep1 = + HVX_VectorPred q_keep1 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep); - // S += slope * mask (only for non-inf positions) - HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0); - HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1); - my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, - hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf); - my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, - hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf); + if (args->has_alibi) { + // ALiBi path: S += slope * mask (full mul + add) + HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0); + HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1); + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, + hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, + hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf); + } else { + // No-ALiBi fast path (scheme A): slope≡1.0, mask is 0 or -inf. + // S += 1.0 * mask is a no-op for unmasked (mask=0); vmux handles masked. + // Same approach as htp-ops-lib flash_attn.c — skip mul + add entirely. + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, my_row_buf0[ci], v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, my_row_buf1[ci], v_neg_inf); + } } else { if (ne < 64) { my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf); @@ -1472,6 +1501,7 @@ int op_hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.kv_start = kv_start; sargs.q_start = q_start; sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); sargs.slopes = factx.slopes; fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); sargs.mask = mask; @@ -1583,6 +1613,7 @@ int op_hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.kv_start = kv_start; sargs.q_start = q_start; sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); sargs.slopes = factx.slopes; fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); sargs.mask = mask;