Use wave32 on AMD RDNA for scalar FA

This commit is contained in:
Ruben Ortlam 2026-02-12 18:44:26 +01:00
parent 16cb912442
commit 3ae5466aaf
1 changed files with 15 additions and 9 deletions

View File

@ -2790,23 +2790,28 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
return 0xFFFFFFFF;
}
if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
return 32;
}
return device->subgroup_size;
}
static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) {
const uint32_t D = std::max(hsk, hsv);
const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path);
switch (path) {
case FA_COOPMAT2:
return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
case FA_COOPMAT1:
return (Bc / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc
default:
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
return 128;
} else if (device->subgroup_size > 32 && Br < 4) {
return device->subgroup_size * 2;
} else if (subgroup_size > 32 && Br < 4) {
return subgroup_size * 2;
} else {
return device->subgroup_size * 4;
return subgroup_size * 4;
}
}
}
@ -2815,7 +2820,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
GGML_UNUSED(clamp);
if (path == FA_SCALAR) {
if (rows == FA_ROWS_1 && ((hsk|hsv) & 8)) {
if (rows == FA_ROWS_1 || ((hsk|hsv) & 8)) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
// But this only applies to row_split=1, meaning FA_ROWS_1
@ -3255,18 +3260,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
bool f32acc = fa.first.f32acc; \
uint32_t flags = fa.first.flags; \
bool fa_ds = path == FA_SCALAR && disable_subgroups; \
uint32_t fa_sgs = fa_subgroup_size(device, path); \
if (path == FAPATH) { \
if (aligned) { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} \
} else { \
if (f32acc) { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} else { \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? device->subgroup_size : 0)); \
ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
} \
} \
} \