hmx: optimize FA softmax mask phase (no-ALiBi fast path + GQA dedup)
This commit is contained in:
parent
5ce4ad9db0
commit
4f42c8a939
|
|
@ -602,6 +602,7 @@ typedef struct {
|
||||||
uint32_t kv_start;
|
uint32_t kv_start;
|
||||||
uint32_t q_start;
|
uint32_t q_start;
|
||||||
uint32_t ib3;
|
uint32_t ib3;
|
||||||
|
bool has_alibi; // true when max_bias != 0 (need slope * mask + add)
|
||||||
|
|
||||||
// ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g))
|
// ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g))
|
||||||
// slope[r] = 1.0 when max_bias == 0 (no ALiBi)
|
// slope[r] = 1.0 when max_bias == 0 (no ALiBi)
|
||||||
|
|
@ -690,13 +691,27 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply mask (with ALiBi slope scaling) & compute rowmax(S)
|
// Apply mask & compute rowmax(S)
|
||||||
// CPU reference: S[i] += slope * mask[i]
|
//
|
||||||
// slope == 1.0 when max_bias == 0 (no ALiBi)
|
// Optimizations over baseline:
|
||||||
const HVX_Vector v_slope0 = hvx_vec_splat_f16((__fp16) args->slopes[r + 0]);
|
// A. No-ALiBi fast path: when max_bias==0 (slope≡1.0), mask values are
|
||||||
const HVX_Vector v_slope1 = (r + 1 < (int) n_rows_g) ?
|
// 0 (unmasked) or -inf (masked). S += 1.0*mask is a no-op for
|
||||||
hvx_vec_splat_f16((__fp16) args->slopes[r + 1]) :
|
// unmasked positions; masked ones are overwritten by vmux. So we
|
||||||
Q6_V_vzero();
|
// skip the mul+add entirely and just do vmux — same as htp-ops-lib.
|
||||||
|
// B. GQA mask row dedup: G consecutive Q rows share one mask row
|
||||||
|
// (qi = r / G). Reuse mask vector when qi is unchanged between
|
||||||
|
// row0 and row1 (saves ~75% of VTCM loads for G=4).
|
||||||
|
|
||||||
|
// ALiBi slopes — only needed when has_alibi (scheme A)
|
||||||
|
HVX_Vector v_slope0, v_slope1;
|
||||||
|
if (args->has_alibi) {
|
||||||
|
v_slope0 = hvx_vec_splat_f16((__fp16) args->slopes[r + 0]);
|
||||||
|
v_slope1 = (r + 1 < (int) n_rows_g) ?
|
||||||
|
hvx_vec_splat_f16((__fp16) args->slopes[r + 1]) :
|
||||||
|
Q6_V_vzero();
|
||||||
|
}
|
||||||
|
|
||||||
|
const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c)
|
||||||
|
|
||||||
HVX_Vector v_s_rowmax0 = v_neg_inf;
|
HVX_Vector v_s_rowmax0 = v_neg_inf;
|
||||||
HVX_Vector v_s_rowmax1 = v_neg_inf;
|
HVX_Vector v_s_rowmax1 = v_neg_inf;
|
||||||
|
|
@ -710,18 +725,22 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
|
||||||
|
|
||||||
if (args->mask_vtcm) {
|
if (args->mask_vtcm) {
|
||||||
// Read mask from VTCM buffer (DMA'd per KV block).
|
// Read mask from VTCM buffer (DMA'd per KV block).
|
||||||
// GQA mapping: G consecutive Q rows share the same mask row.
|
// GQA dedup (scheme B): skip load when qi unchanged.
|
||||||
const size_t qi0 = (r + 0) / G;
|
const size_t qi0 = (r + 0) / G;
|
||||||
v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c);
|
v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c);
|
||||||
v_mask1 = v_neg_inf;
|
v_mask1 = v_neg_inf;
|
||||||
if (r + 1 < (int) n_rows_g) {
|
if (r + 1 < (int) n_rows_g) {
|
||||||
const size_t qi1 = (r + 1) / G;
|
const size_t qi1 = (r + 1) / G;
|
||||||
v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c);
|
if (qi1 == qi0) {
|
||||||
|
v_mask1 = v_mask0; // scheme B: reuse — same mask row
|
||||||
|
} else {
|
||||||
|
v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Fallback: read mask directly from DDR (when mask->ne[2] > 1).
|
// Fallback: read mask directly from DDR (when mask->ne[2] > 1).
|
||||||
const struct htp_tensor * mask = args->mask;
|
const struct htp_tensor * mask = args->mask;
|
||||||
const size_t q_idx0 = args->q_start + (r + 0) / G;
|
const size_t q_idx0 = args->q_start + ((r + 0) / G);
|
||||||
const size_t h_idx0 = args->kv_head * G + (r + 0) % G;
|
const size_t h_idx0 = args->kv_head * G + (r + 0) % G;
|
||||||
const uint32_t im2_0 = h_idx0 % mask->ne[2];
|
const uint32_t im2_0 = h_idx0 % mask->ne[2];
|
||||||
const uint32_t im3_0 = args->ib3 % mask->ne[3];
|
const uint32_t im3_0 = args->ib3 % mask->ne[3];
|
||||||
|
|
@ -733,35 +752,45 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) {
|
||||||
|
|
||||||
v_mask1 = v_neg_inf;
|
v_mask1 = v_neg_inf;
|
||||||
if (r + 1 < (int) n_rows_g) {
|
if (r + 1 < (int) n_rows_g) {
|
||||||
const size_t q_idx1 = args->q_start + (r + 1) / G;
|
const size_t q_idx1 = args->q_start + ((r + 1) / G);
|
||||||
const size_t h_idx1 = args->kv_head * G + (r + 1) % G;
|
if (q_idx1 == q_idx0) {
|
||||||
const uint32_t im2_1 = h_idx1 % mask->ne[2];
|
// scheme B: same mask row in DDR path
|
||||||
const uint32_t im3_1 = args->ib3 % mask->ne[3];
|
v_mask1 = v_mask0;
|
||||||
const __fp16 * m1_ptr =
|
} else {
|
||||||
(const __fp16 *) ((const uint8_t *) mask->data +
|
const size_t h_idx1 = args->kv_head * G + (r + 1) % G;
|
||||||
q_idx1 * mask->nb[1] + im2_1 * mask->nb[2] +
|
const uint32_t im2_1 = h_idx1 % mask->ne[2];
|
||||||
im3_1 * mask->nb[3]) +
|
const uint32_t im3_1 = args->ib3 % mask->ne[3];
|
||||||
args->kv_start + c;
|
const __fp16 * m1_ptr =
|
||||||
v_mask1 = *(const HVX_UVector *) m1_ptr;
|
(const __fp16 *) ((const uint8_t *) mask->data +
|
||||||
|
q_idx1 * mask->nb[1] + im2_1 * mask->nb[2] +
|
||||||
|
im3_1 * mask->nb[3]) +
|
||||||
|
args->kv_start + c;
|
||||||
|
v_mask1 = *(const HVX_UVector *) m1_ptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply slope * mask: threshold first, then add to S
|
// Threshold: mask values below -16.0 are treated as -inf (causal mask).
|
||||||
// Mask values below -16.0 (0xcc00) are treated as -inf (causal mask).
|
HVX_VectorPred q_keep0 =
|
||||||
// For non-inf mask values, scale by ALiBi slope and add to S.
|
|
||||||
const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00);
|
|
||||||
HVX_VectorPred q_keep0 =
|
|
||||||
Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep);
|
Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep);
|
||||||
HVX_VectorPred q_keep1 =
|
HVX_VectorPred q_keep1 =
|
||||||
Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep);
|
Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep);
|
||||||
|
|
||||||
// S += slope * mask (only for non-inf positions)
|
if (args->has_alibi) {
|
||||||
HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0);
|
// ALiBi path: S += slope * mask (full mul + add)
|
||||||
HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1);
|
HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0);
|
||||||
my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0,
|
HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1);
|
||||||
hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf);
|
my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0,
|
||||||
my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1,
|
hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf);
|
||||||
hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf);
|
my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1,
|
||||||
|
hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf);
|
||||||
|
} else {
|
||||||
|
// No-ALiBi fast path (scheme A): slope≡1.0, mask is 0 or -inf.
|
||||||
|
// S += 1.0 * mask is a no-op for unmasked (mask=0); vmux handles masked.
|
||||||
|
// Same approach as htp-ops-lib flash_attn.c — skip mul + add entirely.
|
||||||
|
my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, my_row_buf0[ci], v_neg_inf);
|
||||||
|
my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, my_row_buf1[ci], v_neg_inf);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (ne < 64) {
|
if (ne < 64) {
|
||||||
my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf);
|
my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf);
|
||||||
|
|
@ -1472,6 +1501,7 @@ int op_hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||||
sargs.kv_start = kv_start;
|
sargs.kv_start = kv_start;
|
||||||
sargs.q_start = q_start;
|
sargs.q_start = q_start;
|
||||||
sargs.ib3 = ib3;
|
sargs.ib3 = ib3;
|
||||||
|
sargs.has_alibi = (factx.max_bias != 0.0f);
|
||||||
sargs.slopes = factx.slopes;
|
sargs.slopes = factx.slopes;
|
||||||
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
|
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
|
||||||
sargs.mask = mask;
|
sargs.mask = mask;
|
||||||
|
|
@ -1583,6 +1613,7 @@ int op_hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||||
sargs.kv_start = kv_start;
|
sargs.kv_start = kv_start;
|
||||||
sargs.q_start = q_start;
|
sargs.q_start = q_start;
|
||||||
sargs.ib3 = ib3;
|
sargs.ib3 = ib3;
|
||||||
|
sargs.has_alibi = (factx.max_bias != 0.0f);
|
||||||
sargs.slopes = factx.slopes;
|
sargs.slopes = factx.slopes;
|
||||||
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
|
fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g);
|
||||||
sargs.mask = mask;
|
sargs.mask = mask;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue