hmx flash-attn: refine cost model coefficients based on profiling data

This commit is contained in:
Yiwei Shao 2026-04-23 16:01:05 -07:00
parent a5a4d3c370
commit ae78a998c8
1 changed files with 14 additions and 12 deletions

View File

@ -90,12 +90,17 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
// overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc
//
// Cost model (minimization objective):
// ceil(qo_len / Br) * br_block_cost + ceil(kv_len / Bc) * bc_block_cost
// Q * (c_q_fixed + K * c_iter_fixed), where Q = ceil(qo/Br), K = ceil(kv/Bc)
//
// br_block_cost: each extra Q block re-iterates all KV blocks (DMA + interleave +
// HMX QK/PV + softmax). Scales with kv_len.
// bc_block_cost: each extra KV block adds one pipeline iteration per Q block
// (3 barriers + DMA + interleave). Scales with qo_len.
// Rationale: partial Q blocks only allocate HMX row tiles for actual rows
// (g_br_actual = align_up(n_q_rows*G, 32)), so the data-proportional HMX/HVX
// work is invariant in (Br, Bc) — it equals O(qo*kv) regardless of blocking.
// Only two terms depend on (Br, Bc):
// c_q_fixed — per-Q-block overhead (q_load + epilogue o_update + o_norm +
// o_store + HMX queue flush between Q blocks), ~1400us.
// c_iter_fixed — per-KV-iter pipeline overhead (HMX queue push/pop + DMA pop
// + HVX barriers), ~200us.
// Coefficients are absolute (independent of qo_len/kv_len).
//
// Pipeline constraint: when kv_len >= MIN_KV_BLOCKS * bc_unit and n_threads >= 2,
// enforce Bc <= kv_len / MIN_KV_BLOCKS so that n_kv_blocks >= MIN_KV_BLOCKS.
@ -187,12 +192,9 @@ static int hmx_fa_find_chunk_size(size_t * Br_out,
const size_t Bc_limit = can_pipeline
? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit)
: (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
// Cost coefficients: extra Q block re-scans all KV; extra KV block adds one
// pipeline iteration per Q block. Weight br_block_cost higher because it
// involves full KV re-traversal (DMA + interleave + HMX + softmax per block).
const size_t br_block_cost = kv_len * 3; // scales with kv_len (expensive)
const size_t bc_block_cost = qo_len * 2; // scales with qo_len (cheaper with pipeline overlap)
// Cost coefficients calibrated from profiling
const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store
const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers
size_t best_cost = SIZE_MAX, best_mn = 0;
size_t best_Br = 0, best_Bc = 0;
@ -229,7 +231,7 @@ static int hmx_fa_find_chunk_size(size_t * Br_out,
const size_t q_blocks = (qo_len + Br - 1) / Br;
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
const size_t cost = q_blocks * br_block_cost + kv_blocks * bc_block_cost;
const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed);
const size_t mn = Br * Bc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {