hmx: optimize FA softmax mask phase (no-ALiBi fast path + GQA dedup)

Yiwei Shao 2026-04-15 20:22:07 -07:00
parent 5ce4ad9db0
commit 4f42c8a939
1 changed file with 64 additions and 33 deletions


@@ -602,6 +602,7 @@ typedef struct {
uint32_t kv_start;
uint32_t q_start;
uint32_t ib3;
bool has_alibi; // true when max_bias != 0 (need slope * mask + add)
// ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g))
// slope[r] = 1.0 when max_bias == 0 (no ALiBi)
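
For context on the new field: fa_compute_slopes itself is not part of this diff. Below is a minimal scalar sketch of how ALiBi per-head slopes are usually derived (mirroring the ggml CPU reference), including the degenerate slope == 1.0 case the comment above describes. The function name and signature are illustrative assumptions, not code from this file.

#include <math.h>
#include <stdint.h>

// Illustrative sketch only -- not the actual fa_compute_slopes.
static float alibi_slope_sketch(uint32_t head, uint32_t n_head, float max_bias) {
    if (max_bias <= 0.0f) {
        return 1.0f;  // no ALiBi: slope degenerates to 1.0
    }
    // largest power of two <= n_head, per the standard ALiBi recipe
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias / (float) n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / (float) n_head_log2);
    return head < n_head_log2
        ? powf(m0, (float) (head + 1))
        : powf(m1, (float) (2 * (head - n_head_log2) + 1));
}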
@@ -690,13 +691,27 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
}
}
// Apply mask (with ALiBi slope scaling) & compute rowmax(S)
// CPU reference: S[i] += slope * mask[i]
// slope == 1.0 when max_bias == 0 (no ALiBi)
const HVX_Vector v_slope0 = hvx_vec_splat_f16((__fp16) args->slopes[r + 0]);
const HVX_Vector v_slope1 = (r + 1 < (int) n_rows_g) ?
hvx_vec_splat_f16((__fp16) args->slopes[r + 1]) :
Q6_V_vzero();
// Apply mask & compute rowmax(S)
//
// Optimizations over baseline:
// A. No-ALiBi fast path: when max_bias==0 (slope≡1.0), mask values are
// 0 (unmasked) or -inf (masked). S += 1.0*mask is a no-op for
// unmasked positions; masked ones are overwritten by vmux. So we
// skip the mul+add entirely and just do vmux — same as htp-ops-lib.
// B. GQA mask row dedup: G consecutive Q rows share one mask row
// (qi = r / G). Reuse mask vector when qi is unchanged between
// row0 and row1 (saves ~50% of VTCM mask loads for G=4).
// ALiBi slopes — only needed when has_alibi (scheme A)
HVX_Vector v_slope0, v_slope1;
if (args->has_alibi) {
v_slope0 = hvx_vec_splat_f16((__fp16) args->slopes[r + 0]);
v_slope1 = (r + 1 < (int) n_rows_g) ?
hvx_vec_splat_f16((__fp16) args->slopes[r + 1]) :
Q6_V_vzero();
}
const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c)
HVX_Vector v_s_rowmax0 = v_neg_inf;
HVX_Vector v_s_rowmax1 = v_neg_inf;
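
A scalar restatement of the two paths described above may help. This is a minimal sketch, not the HVX kernel (which vectorizes over 64 fp16 lanes and two rows at a time), and the names s_row, mask_row, and n_kv are hypothetical:

#include <math.h>
#include <stdbool.h>

// Scalar model of the per-row mask phase (rowmax omitted for brevity).
static void mask_phase_sketch(float * s_row, const float * mask_row,
                              float slope, bool has_alibi, int n_kv) {
    for (int c = 0; c < n_kv; c++) {
        if (mask_row[c] <= -16.0f) {
            s_row[c] = -INFINITY;             // masked: overwritten outright (the vmux)
        } else if (has_alibi) {
            s_row[c] += slope * mask_row[c];  // ALiBi path: full mul + add
        }
        // no-ALiBi path (scheme A): kept positions have mask == 0, so
        // S += 1.0 * 0 changes nothing and the mul + add is dropped
    }
}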
@@ -710,18 +725,22 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
if (args->mask_vtcm) {
// Read mask from VTCM buffer (DMA'd per KV block).
// GQA mapping: G consecutive Q rows share the same mask row.
// GQA dedup (scheme B): skip load when qi unchanged.
const size_t qi0 = (r + 0) / G;
v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c);
v_mask1 = v_neg_inf;
if (r + 1 < (int) n_rows_g) {
const size_t qi1 = (r + 1) / G;
v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c);
if (qi1 == qi0) {
v_mask1 = v_mask0; // scheme B: reuse — same mask row
} else {
v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c);
}
}
} else {
// Fallback: read mask directly from DDR (when mask->ne[2] > 1).
const struct htp_tensor * mask = args->mask;
const size_t q_idx0 = args->q_start + (r + 0) / G;
const size_t q_idx0 = args->q_start + ((r + 0) / G);
const size_t h_idx0 = args->kv_head * G + (r + 0) % G;
const uint32_t im2_0 = h_idx0 % mask->ne[2];
const uint32_t im3_0 = args->ib3 % mask->ne[3];
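
The dedup hinges on the qi = r / G mapping. A concrete worked example, assuming G = 4 and rows processed in pairs as in this kernel:

// qi = r / G for G = 4:
//   r:  0 1 2 3 4 5 6 7
//   qi: 0 0 0 0 1 1 1 1
// pairs (0,1), (2,3), (4,5), (6,7) each land on one qi, so v_mask1 is
// reused in every pair and the per-pair mask-load count is halved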
@@ -733,35 +752,45 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
v_mask1 = v_neg_inf;
if (r + 1 < (int) n_rows_g) {
const size_t q_idx1 = args->q_start + (r + 1) / G;
const size_t h_idx1 = args->kv_head * G + (r + 1) % G;
const uint32_t im2_1 = h_idx1 % mask->ne[2];
const uint32_t im3_1 = args->ib3 % mask->ne[3];
const __fp16 * m1_ptr =
(const __fp16 *) ((const uint8_t *) mask->data +
q_idx1 * mask->nb[1] + im2_1 * mask->nb[2] +
im3_1 * mask->nb[3]) +
args->kv_start + c;
v_mask1 = *(const HVX_UVector *) m1_ptr;
const size_t q_idx1 = args->q_start + ((r + 1) / G);
if (q_idx1 == q_idx0) {
// scheme B: same mask row in DDR path
v_mask1 = v_mask0;
} else {
const size_t h_idx1 = args->kv_head * G + (r + 1) % G;
const uint32_t im2_1 = h_idx1 % mask->ne[2];
const uint32_t im3_1 = args->ib3 % mask->ne[3];
const __fp16 * m1_ptr =
(const __fp16 *) ((const uint8_t *) mask->data +
q_idx1 * mask->nb[1] + im2_1 * mask->nb[2] +
im3_1 * mask->nb[3]) +
args->kv_start + c;
v_mask1 = *(const HVX_UVector *) m1_ptr;
}
}
}
// Apply slope * mask: threshold first, then add to S
// Mask values below -16.0 (0xcc00) are treated as -inf (causal mask).
// For non-inf mask values, scale by ALiBi slope and add to S.
const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00);
HVX_VectorPred q_keep0 =
// Threshold: mask values below -16.0 are treated as -inf (causal mask).
HVX_VectorPred q_keep0 =
Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep);
HVX_VectorPred q_keep1 =
HVX_VectorPred q_keep1 =
Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep);
// S += slope * mask (only for non-inf positions)
HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0);
HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1);
my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0,
hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf);
my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1,
hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf);
if (args->has_alibi) {
// ALiBi path: S += slope * mask (full mul + add)
HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0);
HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1);
my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0,
hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf);
my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1,
hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf);
} else {
// No-ALiBi fast path (scheme A): slope≡1.0, mask is 0 or -inf.
// S += 1.0 * mask is a no-op for unmasked (mask=0); vmux handles masked.
// Same approach as htp-ops-lib flash_attn.c — skip mul + add entirely.
my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, my_row_buf0[ci], v_neg_inf);
my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, my_row_buf1[ci], v_neg_inf);
}
} else {
if (ne < 64) {
my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf);
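
For reference, the hoisted threshold constant decodes as claimed in the comments. A host-side sketch (the helper name is hypothetical; subnormal and inf handling included only for completeness):

#include <math.h>
#include <stdint.h>

// Decode an IEEE fp16 bit pattern. 0xcc00 = sign 1, exponent 0b10011 (19),
// mantissa 0  =>  -1.0 * 2^(19-15) = -16.0, matching the threshold comment.
static float fp16_bits_to_float(uint16_t h) {
    const float sign = (h & 0x8000) ? -1.0f : 1.0f;
    const int   exp  = (h >> 10) & 0x1f;
    const float frac = (float) (h & 0x3ff) / 1024.0f;
    if (exp == 0)  return sign * ldexpf(frac, -14);        // subnormal
    if (exp == 31) return frac ? NAN : sign * INFINITY;    // inf / NaN
    return sign * ldexpf(1.0f + frac, exp - 15);           // normal
}
// fp16_bits_to_float(0xcc00) == -16.0f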
@@ -1472,6 +1501,7 @@ int op_hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.kv_start = kv_start;
sargs.q_start = q_start;
sargs.ib3 = ib3;
sargs.has_alibi = (factx.max_bias != 0.0f);
sargs.slopes = factx.slopes;
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
sargs.mask = mask;
@@ -1583,6 +1613,7 @@ int op_hmx_flash_attn_ext(struct htp_ops_context * octx) {
sargs.kv_start = kv_start;
sargs.q_start = q_start;
sargs.ib3 = ib3;
sargs.has_alibi = (factx.max_bias != 0.0f);
sargs.slopes = factx.slopes;
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
sargs.mask = mask;
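
Finally, a small host-side check of the scheme A argument: with slope == 1.0 and mask restricted to {0, -inf}, the baseline (mul + add, then select) and the fast path (select only) must agree. Minimal sketch; all names are illustrative:

#include <assert.h>
#include <math.h>

int main(void) {
    const float masks[] = { 0.0f, -INFINITY };
    const float s = 1.25f;  // arbitrary pre-mask score
    for (int i = 0; i < 2; i++) {
        const int keep = masks[i] > -16.0f;
        // baseline: S += 1.0 * mask, then select (the vmux)
        const float base = keep ? (s + 1.0f * masks[i]) : -INFINITY;
        // scheme A fast path: select only, mul + add skipped
        const float fast = keep ? s : -INFINITY;
        assert(base == fast);
    }
    return 0;
}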