device tuning

Ruben Ortlam 2026-02-14 09:05:15 +01:00
parent dd92b1f8d5
commit 0b4b0d2e57
1 changed file with 27 additions and 29 deletions

@@ -2762,12 +2762,16 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
 
-static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) {
+static uint32_t get_fa_scalar_num_rows(const vk_device& device, uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) {
     if (rows == FA_ROWS_1) {
         return 1;
     }
-    if (rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache) {
+    if (
+        rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache ||
+        (device->architecture == AMD_GCN && hsk <= 64) ||
+        (device->vendor_id == VK_VENDOR_ID_INTEL)
+    ) {
         return 8;
     }
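As a reference for the new condition, here is a self-contained sketch of the row-count selection. vk_device, the architecture enum, and the vendor constant are mocked here, and the final fall-through value is an assumption (the hunk ends before the last return, presumably flash_attention_num_small_rows):

#include <cstdint>
#include <cstdio>

enum FaRows { FA_ROWS_1, FA_ROWS_SMALL, FA_ROWS_LARGE };
enum Arch   { AMD_GCN, AMD_RDNA, ARCH_OTHER };          // mock architecture enum
struct Dev  { Arch architecture; uint32_t vendor_id; }; // mock vk_device
constexpr uint32_t VENDOR_INTEL = 0x8086;               // Intel PCI vendor ID

static uint32_t scalar_num_rows(const Dev& d, uint32_t hsk, uint32_t hsv,
                                FaRows rows, bool small_cache) {
    if (rows == FA_ROWS_1) {
        return 1;                                       // single-row decode
    }
    if (rows == FA_ROWS_SMALL || hsv >= 192 || ((hsv | hsk) & 8) || small_cache ||
        (d.architecture == AMD_GCN && hsk <= 64) ||     // new: GCN with small K heads
        (d.vendor_id == VENDOR_INTEL)) {                // new: all Intel GPUs
        return 8;
    }
    return 32; // assumed fall-through (flash_attention_num_small_rows)
}

int main() {
    Dev gcn{AMD_GCN, 0x1002}, rdna{AMD_RDNA, 0x1002};
    std::printf("GCN:  %u\n", scalar_num_rows(gcn,  64, 64, FA_ROWS_LARGE, false)); // 8
    std::printf("RDNA: %u\n", scalar_num_rows(rdna, 64, 64, FA_ROWS_LARGE, false)); // 32
}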
@@ -2785,12 +2789,12 @@ static bool fa_disable_subgroups(const vk_device& device, FaCodePath path) {
     return device->vendor_id == VK_VENDOR_ID_INTEL && path == FA_SCALAR;
 }
 
-static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
+static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path, FaRows rows) {
     if (fa_disable_subgroups(device, path)) {
         return 0xFFFFFFFF;
     }
-    if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
+    if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && rows == FA_ROWS_1) {
         return 32;
     }
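A sketch of the narrowed wave32 selection, with mocked types; the return past the shown lines is assumed to be the device's default subgroup size:

#include <cstdint>

enum FaCodePath { FA_SCALAR, FA_COOPMAT1, FA_COOPMAT2 };
enum FaRows     { FA_ROWS_1, FA_ROWS_SMALL, FA_ROWS_LARGE };
enum Arch       { AMD_GCN, AMD_RDNA, ARCH_OTHER };
struct Dev      { Arch architecture; uint32_t vendor_id; uint32_t subgroup_size_default; };
constexpr uint32_t VENDOR_AMD = 0x1002, VENDOR_INTEL = 0x8086;

static bool disable_subgroups(const Dev& d, FaCodePath path) {
    return d.vendor_id == VENDOR_INTEL && path == FA_SCALAR;
}

static uint32_t subgroup_size(const Dev& d, FaCodePath path, FaRows rows) {
    if (disable_subgroups(d, path)) {
        return 0xFFFFFFFF;                  // sentinel: no subgroup size forced
    }
    // RDNA previously forced wave32 for every scalar pipeline; after this
    // commit it does so only for the single-row (FA_ROWS_1) variants.
    if (path == FA_SCALAR && d.vendor_id == VENDOR_AMD &&
        d.architecture != AMD_GCN && rows == FA_ROWS_1) {
        return 32;
    }
    return d.subgroup_size_default;         // assumed fall-through
}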
@@ -2799,14 +2803,14 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
 static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) {
     const uint32_t D = std::max(hsk, hsv);
-    const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path);
+    const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path, rows);
 
     switch (path) {
     case FA_COOPMAT2:
         return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
     case FA_COOPMAT1:
         return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc
     default:
-        if (subgroup_size > 32 && Br < 4) {
+        if (subgroup_size > 32 && (Br < 4 || hsk < 64)) {
            return subgroup_size * 2;
        } else {
            return subgroup_size * 4;
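Concretely, on a wave64 device (subgroup_size == 64) the scalar default branch now also halves the workgroup for narrow K heads: Br < 4 or hsk < 64 gives 64 * 2 = 128 threads, otherwise 64 * 4 = 256. A wave32 device is unaffected by the new hsk < 64 term, since subgroup_size > 32 is already false, and stays at 32 * 4 = 128.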
@@ -2814,7 +2818,7 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint
     }
 }
 
-static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) {
+static std::array<uint32_t, 2> fa_rows_cols(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) {
     GGML_UNUSED(clamp);
 
     if (path == FA_SCALAR) {
@@ -2822,9 +2826,9 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
             // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
             // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
             // But this only applies to row_split=1, meaning FA_ROWS_1
-            return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 64};
+            return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 64};
         } else {
-            return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 32};
+            return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 32};
         }
     }
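A worked case, assuming the elided branch condition matches the comment above: hsk = hsv = 72 gives (72 | 72) & 8 == 8, i.e. a head size that is not a multiple of 16, so get_fa_scalar_num_rows returns 8 and the scalar path yields {8, 64}; a multiple-of-16 head such as 128 presumably takes the 32-column else branch instead.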
@@ -2848,8 +2852,8 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
     return {64, 64};
 }
 
-static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) {
-    return fa_rows_cols(path, hsk, hsv, 0, type, rows, small_cache)[1];
+static uint32_t fa_align(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) {
+    return fa_rows_cols(device, path, hsk, hsv, 0, type, rows, small_cache)[1];
 }
 
 static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
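Note that fa_align simply forwards to fa_rows_cols and returns its second element (Bc), so the KV-length alignment requirement now picks up the device-specific tuning as well: for the scalar path it is 64 or 32 depending on the head-size branches above.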
@@ -3217,7 +3221,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     };
 
     auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
+        return {fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1};
     };
 
     auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
@@ -3227,10 +3231,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // For scalar, use 128 (arbitrary)
         // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
         const uint32_t D = (hsk|hsv);
-        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache);
+        auto rows_cols = fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache);
         const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]);
-        const uint32_t subgroup_size = fa_subgroup_size(device, path);
+        const uint32_t subgroup_size = fa_subgroup_size(device, path, rows);
 
         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
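To make the two D_split bounds concrete, a small worked sketch; the clamping code itself sits past this hunk, and the subgroup size is an example value:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t hsk = 128, hsv = 80;
    const uint32_t D   = hsk | hsv;        // 208; union of the LSBs, as in the comment above
    const uint32_t lsb = D & (~D + 1u);    // lowest set bit: 16
    const uint32_t subgroup_size = 32;     // example wave32 device
    // D_split is capped by the subgroup (subgroupShuffle reduction) and by
    // lsb / 4 (the shader loads 4-wide vectors).
    const uint32_t max_d_split = std::min(subgroup_size, lsb / 4);
    std::printf("D=%u lsb=%u max D_split=%u\n", D, lsb, max_d_split); // D=208 lsb=16 max D_split=4
}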
@@ -3256,13 +3260,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
         bool f32acc = fa.first.f32acc; \
         uint32_t flags = fa.first.flags; \
         bool fa_ds = fa_disable_subgroups(device, path); \
-        uint32_t fa_sgs = fa_subgroup_size(device, path); \
+        uint32_t fa_sgs = fa_subgroup_size(device, path, rows); \
         if (path == FAPATH) { \
             if (aligned) { \
                 if (f32acc) { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
                 } else { \
-                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
+                    ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
                 } \
             } else { \
                 if (f32acc) { \
@@ -8420,7 +8424,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
 static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) {
     // Needs to be kept up to date on shader changes
-    const std::array<uint32_t, 2> rows_cols = fa_rows_cols(FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
+    const std::array<uint32_t, 2> rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
     const uint32_t Br = rows_cols[0];
     const uint32_t Bc = rows_cols[1];
     const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc);
@@ -8451,7 +8455,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
 static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
     // Needs to be kept up to date on shader changes
     GGML_UNUSED(hsv);
-    const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false);
+    const auto rows_cols = fa_rows_cols(device, FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false);
     const uint32_t Br = rows_cols[0];
     const uint32_t Bc = rows_cols[1];
@@ -8578,7 +8582,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     case FA_SCALAR:
     case FA_COOPMAT1:
         // We may switch from coopmat1 to scalar, so use the scalar limit for both
-        max_gqa = get_fa_scalar_num_rows(HSK, HSV, FA_ROWS_LARGE, small_cache);
+        max_gqa = get_fa_scalar_num_rows(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache);
         break;
     case FA_COOPMAT2:
         max_gqa = flash_attention_num_small_rows;
@@ -8606,14 +8610,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         rows = FA_ROWS_LARGE;
     }
 
-    // coopmat1 does not actually support "small rows" (it needs 16 rows).
-    // So use scalar instead.
-    if (rows != FA_ROWS_LARGE && path == FA_COOPMAT1) {
-        path = FA_SCALAR;
-    }
-
     // scalar is faster than coopmat2 when N==1
-    if (rows == FA_ROWS_1 && path == FA_COOPMAT2) {
+    if (rows == FA_ROWS_1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
         path = FA_SCALAR;
     }
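The dedicated coopmat1 small-rows redirect is dropped in favor of this single check. A minimal sketch of the consolidated fallback, with mocked enums:

enum FaCodePath { FA_SCALAR, FA_COOPMAT1, FA_COOPMAT2 };
enum FaRows     { FA_ROWS_1, FA_ROWS_SMALL, FA_ROWS_LARGE };

// Both cooperative-matrix paths now fall back to the scalar shader for
// single-row (N==1) dispatches.
static FaCodePath resolve_path(FaCodePath path, FaRows rows) {
    if (rows == FA_ROWS_1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
        return FA_SCALAR;
    }
    return path;
}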
@@ -8634,7 +8632,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         v_stride /= 4;
     }
 
-    uint32_t alignment = fa_align(path, HSK, HSV, k->type, rows, small_cache);
+    uint32_t alignment = fa_align(ctx->device, path, HSK, HSV, k->type, rows, small_cache);
     bool aligned = (KV % alignment) == 0 &&
                    // the "aligned" shader variant will forcibly align strides, for performance
                    (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
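Restated as a standalone predicate (the helper name here is hypothetical): the aligned shader variant needs the KV length to be a multiple of the block returned by fa_align, which is now device-dependent, plus 8-element-aligned strides.

#include <cstdint>

// Hypothetical helper mirroring the test above.
static bool can_use_aligned_variant(uint32_t KV, uint32_t alignment,
                                    uint32_t q_stride, uint32_t k_stride,
                                    uint32_t v_stride) {
    return (KV % alignment) == 0 &&
           (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
}
// e.g. KV = 4096 with alignment 64 and 8-aligned strides selects the aligned
// variant; KV = 4100 falls back to the unaligned pipeline.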
@@ -8691,7 +8689,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     // Use a placeholder core count if one isn't available. split_k is a big help for perf.
     const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
 
-    auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache);
+    auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache);
     const uint32_t Br = rows_cols[0];
     const uint32_t Bc = rows_cols[1];