diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index ec00ebafe3..48406ef5c9 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -90,12 +90,17 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, // overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc // // Cost model (minimization objective): -// ceil(qo_len / Br) * br_block_cost + ceil(kv_len / Bc) * bc_block_cost +// Q * (c_q_fixed + K * c_iter_fixed), where Q = ceil(qo/Br), K = ceil(kv/Bc) // -// br_block_cost: each extra Q block re-iterates all KV blocks (DMA + interleave + -// HMX QK/PV + softmax). Scales with kv_len. -// bc_block_cost: each extra KV block adds one pipeline iteration per Q block -// (3 barriers + DMA + interleave). Scales with qo_len. +// Rationale: partial Q blocks only allocate HMX row tiles for actual rows +// (g_br_actual = align_up(n_q_rows*G, 32)), so the data-proportional HMX/HVX +// work is invariant in (Br, Bc) — it equals O(qo*kv) regardless of blocking. +// Only two terms depend on (Br, Bc): +// c_q_fixed — per-Q-block overhead (q_load + epilogue o_update + o_norm + +// o_store + HMX queue flush between Q blocks), ~1400us. +// c_iter_fixed — per-KV-iter pipeline overhead (HMX queue push/pop + DMA pop +// + HVX barriers), ~200us. +// Coefficients are absolute (independent of qo_len/kv_len). // // Pipeline constraint: when kv_len >= MIN_KV_BLOCKS * bc_unit and n_threads >= 2, // enforce Bc <= kv_len / MIN_KV_BLOCKS so that n_kv_blocks >= MIN_KV_BLOCKS. @@ -187,12 +192,9 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); - - // Cost coefficients: extra Q block re-scans all KV; extra KV block adds one - // pipeline iteration per Q block. Weight br_block_cost higher because it - // involves full KV re-traversal (DMA + interleave + HMX + softmax per block). - const size_t br_block_cost = kv_len * 3; // scales with kv_len (expensive) - const size_t bc_block_cost = qo_len * 2; // scales with qo_len (cheaper with pipeline overlap) + // Cost coefficients calibrated from profiling + const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store + const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers size_t best_cost = SIZE_MAX, best_mn = 0; size_t best_Br = 0, best_Bc = 0; @@ -229,7 +231,7 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, const size_t q_blocks = (qo_len + Br - 1) / Br; const size_t kv_blocks = (kv_len + Bc - 1) / Bc; - const size_t cost = q_blocks * br_block_cost + kv_blocks * bc_block_cost; + const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed); const size_t mn = Br * Bc; if (cost < best_cost || (cost == best_cost && mn > best_mn)) {