vulkan: use fewer FA rows for small cache runs (#18280)

This commit is contained in:
Ruben Ortlam 2025-12-24 08:59:14 +01:00 committed by GitHub
parent cf2ffc02bc
commit 7f459c98e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 29 additions and 27 deletions

View File

@ -379,18 +379,18 @@ enum FaCodePath {
}; };
struct vk_fa_pipeline_state { struct vk_fa_pipeline_state {
vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc) vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
: HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {} : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
uint32_t HSK, HSV; uint32_t HSK, HSV;
bool small_rows; bool small_rows, small_cache;
FaCodePath path; FaCodePath path;
bool aligned; bool aligned;
bool f32acc; bool f32acc;
bool operator<(const vk_fa_pipeline_state &b) const { bool operator<(const vk_fa_pipeline_state &b) const {
return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) < return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc); std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
} }
}; };
@ -2582,10 +2582,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
static constexpr uint32_t flash_attention_num_small_rows = 32; static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
if (hsv >= 192) { if (hsv >= 192) {
return 2; return 2;
} else if ((hsv | hsk) & 8) { } else if ((hsv | hsk) & 8 || small_cache) {
return 4; return 4;
} else { } else {
return 8; return 8;
@ -2607,9 +2607,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) {
} }
} }
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
GGML_UNUSED(clamp); GGML_UNUSED(clamp);
GGML_UNUSED(hsv);
if (path == FA_SCALAR) { if (path == FA_SCALAR) {
if (small_rows) { if (small_rows) {
@ -2618,9 +2617,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if ((hsv | hsk) & 8) { if ((hsv | hsk) & 8) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
} else { } else {
return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
} }
} }
} }
@ -2649,8 +2648,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
return {64, 64}; return {64, 64};
} }
static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) { static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1]; return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
} }
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) { static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@ -2992,11 +2991,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
align, disable_robustness, require_full_subgroups, required_subgroup_size); align, disable_robustness, require_full_subgroups, required_subgroup_size);
}; };
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> { auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
}; };
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> { auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
// For large number of rows, 128 invocations seems to work best. // For large number of rows, 128 invocations seems to work best.
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
// can't use 256 for D==80. // can't use 256 for D==80.
@ -3006,7 +3005,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
? scalar_flash_attention_workgroup_size ? scalar_flash_attention_workgroup_size
: ((small_rows && (D % 32) == 0) ? 256 : 128); : ((small_rows && (D % 32) == 0) ? 256 : 128);
auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
@ -3021,21 +3020,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t HSK = fa.first.HSK; \ uint32_t HSK = fa.first.HSK; \
uint32_t HSV = fa.first.HSV; \ uint32_t HSV = fa.first.HSV; \
bool small_rows = fa.first.small_rows; \ bool small_rows = fa.first.small_rows; \
bool small_cache = fa.first.small_cache; \
FaCodePath path = fa.first.path; \ FaCodePath path = fa.first.path; \
bool aligned = fa.first.aligned; \ bool aligned = fa.first.aligned; \
bool f32acc = fa.first.f32acc; \ bool f32acc = fa.first.f32acc; \
if (path == FAPATH) { \ if (path == FAPATH) { \
if (aligned) { \ if (aligned) { \
if (f32acc) { \ if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \ } else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \ } \
} else { \ } else { \
if (f32acc) { \ if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} else { \ } else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
} \ } \
} \ } \
} \ } \
@ -8008,11 +8008,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
} }
} }
static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
// Needs to be kept up to date on shader changes // Needs to be kept up to date on shader changes
GGML_UNUSED(hsv); GGML_UNUSED(hsv);
const uint32_t wg_size = scalar_flash_attention_workgroup_size; const uint32_t wg_size = scalar_flash_attention_workgroup_size;
const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t Bc = scalar_flash_attention_Bc;
const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpsh = wg_size * sizeof(float);
@ -8136,6 +8136,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_y = (uint32_t)neq2;
uint32_t workgroups_z = (uint32_t)neq3; uint32_t workgroups_z = (uint32_t)neq3;
const bool small_cache = nek1 < 1024;
// For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa. // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa). // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
uint32_t max_gqa; uint32_t max_gqa;
@ -8143,7 +8145,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
case FA_SCALAR: case FA_SCALAR:
case FA_COOPMAT1: case FA_COOPMAT1:
// We may switch from coopmat1 to scalar, so use the scalar limit for both // We may switch from coopmat1 to scalar, so use the scalar limit for both
max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
break; break;
case FA_COOPMAT2: case FA_COOPMAT2:
max_gqa = get_fa_num_small_rows(FA_COOPMAT2); max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
@ -8177,7 +8179,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
if (path == FA_SCALAR && if (path == FA_SCALAR &&
!ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
small_rows = true; small_rows = true;
} }
@ -8193,7 +8195,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
v_stride /= 4; v_stride /= 4;
} }
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
bool aligned = (KV % alignment) == 0 && bool aligned = (KV % alignment) == 0 &&
// the "aligned" shader variant will forcibly align strides, for performance // the "aligned" shader variant will forcibly align strides, for performance
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
@ -8205,7 +8207,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc); vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
vk_pipeline pipeline = nullptr; vk_pipeline pipeline = nullptr;