hmx flash-attn: refine cost model coefficients based on profiling data
This commit is contained in:
parent
a5a4d3c370
commit
ae78a998c8
|
|
@ -90,12 +90,17 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
|
||||||
// overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc
|
// overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc
|
||||||
//
|
//
|
||||||
// Cost model (minimization objective):
|
// Cost model (minimization objective):
|
||||||
// ceil(qo_len / Br) * br_block_cost + ceil(kv_len / Bc) * bc_block_cost
|
// Q * (c_q_fixed + K * c_iter_fixed), where Q = ceil(qo_len / Br), K = ceil(kv_len / Bc)
|
||||||
//
|
//
|
||||||
// br_block_cost: each extra Q block re-iterates all KV blocks (DMA + interleave +
|
// Rationale: partial Q blocks only allocate HMX row tiles for actual rows
|
||||||
// HMX QK/PV + softmax). Scales with kv_len.
|
// (g_br_actual = align_up(n_q_rows*G, 32)), so the data-proportional HMX/HVX
|
||||||
// bc_block_cost: each extra KV block adds one pipeline iteration per Q block
|
// work is invariant in (Br, Bc) — it is O(qo_len * kv_len) regardless of blocking.
|
||||||
// (3 barriers + DMA + interleave). Scales with qo_len.
|
// Only two terms depend on (Br, Bc):
|
||||||
|
// c_q_fixed — per-Q-block overhead (q_load + epilogue o_update + o_norm +
|
||||||
|
// o_store + HMX queue flush between Q blocks), ~1400us.
|
||||||
|
// c_iter_fixed — per-KV-iter pipeline overhead (HMX queue push/pop + DMA pop
|
||||||
|
// + HVX barriers), ~200us.
|
||||||
|
// Coefficients are absolute (independent of qo_len/kv_len).
|
||||||
//
|
//
|
||||||
// Pipeline constraint: when kv_len >= MIN_KV_BLOCKS * bc_unit and n_threads >= 2,
|
// Pipeline constraint: when kv_len >= MIN_KV_BLOCKS * bc_unit and n_threads >= 2,
|
||||||
// enforce Bc <= kv_len / MIN_KV_BLOCKS so that n_kv_blocks >= MIN_KV_BLOCKS.
|
// enforce Bc <= kv_len / MIN_KV_BLOCKS so that n_kv_blocks >= MIN_KV_BLOCKS.
|
||||||
|
|
@ -187,12 +192,9 @@ static int hmx_fa_find_chunk_size(size_t * Br_out,
|
||||||
const size_t Bc_limit = can_pipeline
|
const size_t Bc_limit = can_pipeline
|
||||||
? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit)
|
? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit)
|
||||||
: (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
|
: (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
|
||||||
|
// Cost coefficients calibrated from profiling
|
||||||
// Cost coefficients: extra Q block re-scans all KV; extra KV block adds one
|
const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store
|
||||||
// pipeline iteration per Q block. Weight br_block_cost higher because it
|
const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers
|
||||||
// involves full KV re-traversal (DMA + interleave + HMX + softmax per block).
|
|
||||||
const size_t br_block_cost = kv_len * 3; // scales with kv_len (expensive)
|
|
||||||
const size_t bc_block_cost = qo_len * 2; // scales with qo_len (cheaper with pipeline overlap)
|
|
||||||
|
|
||||||
size_t best_cost = SIZE_MAX, best_mn = 0;
|
size_t best_cost = SIZE_MAX, best_mn = 0;
|
||||||
size_t best_Br = 0, best_Bc = 0;
|
size_t best_Br = 0, best_Bc = 0;
|
||||||
|
|
@ -229,7 +231,7 @@ static int hmx_fa_find_chunk_size(size_t * Br_out,
|
||||||
|
|
||||||
const size_t q_blocks = (qo_len + Br - 1) / Br;
|
const size_t q_blocks = (qo_len + Br - 1) / Br;
|
||||||
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
|
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
|
||||||
const size_t cost = q_blocks * br_block_cost + kv_blocks * bc_block_cost;
|
const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed);
|
||||||
const size_t mn = Br * Bc;
|
const size_t mn = Br * Bc;
|
||||||
|
|
||||||
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
|
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue