diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b3212ff139..025ac04ca9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2762,12 +2762,16 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; -static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { +static uint32_t get_fa_scalar_num_rows(const vk_device& device, uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { if (rows == FA_ROWS_1) { return 1; } - if (rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache) { + if ( + rows == FA_ROWS_SMALL || hsv >= 192 || (hsv | hsk) & 8 || small_cache || + (device->architecture == AMD_GCN && hsk <= 64) || + (device->vendor_id == VK_VENDOR_ID_INTEL) + ) { return 8; } @@ -2785,12 +2789,12 @@ static bool fa_disable_subgroups(const vk_device& device, FaCodePath path) { return device->vendor_id == VK_VENDOR_ID_INTEL && path == FA_SCALAR; } -static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) { +static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path, FaRows rows) { if (fa_disable_subgroups(device, path)) { return 0xFFFFFFFF; } - if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { + if (path == FA_SCALAR && device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && rows == FA_ROWS_1) { return 32; } @@ -2799,14 +2803,14 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) { static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) { const uint32_t D = std::max(hsk, hsv); - const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path); + const uint32_t subgroup_size = fa_disable_subgroups(device, path) ? 32 : fa_subgroup_size(device, path, rows); switch (path) { case FA_COOPMAT2: return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); case FA_COOPMAT1: return (Bc / 16) * subgroup_size; // enough subgroups for Bc/MatBc default: - if (subgroup_size > 32 && Br < 4) { + if (subgroup_size > 32 && (Br < 4 || hsk < 64)) { return subgroup_size * 2; } else { return subgroup_size * 4; @@ -2814,7 +2818,7 @@ static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint } } -static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) { +static std::array fa_rows_cols(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) { GGML_UNUSED(clamp); if (path == FA_SCALAR) { @@ -2822,9 +2826,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. // But this only applies to row_split=1, meaning FA_ROWS_1 - return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 64}; + return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 64}; } else { - return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 32}; + return {get_fa_scalar_num_rows(device, hsk, hsv, rows, small_cache), 32}; } } @@ -2848,8 +2852,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 return {64, 64}; } -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) { - return fa_rows_cols(path, hsk, hsv, 0, type, rows, small_cache)[1]; +static uint32_t fa_align(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) { + return fa_rows_cols(device, path, hsk, hsv, 0, type, rows, small_cache)[1]; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3217,7 +3221,7 @@ static void ggml_vk_load_shaders(vk_device& device) { }; auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) -> std::array { - return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1}; + return {fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1}; }; auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector { @@ -3227,10 +3231,10 @@ static void ggml_vk_load_shaders(vk_device& device) { // For scalar, use 128 (arbitrary) // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. const uint32_t D = (hsk|hsv); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache); + auto rows_cols = fa_rows_cols(device, path, hsk, hsv, clamp, type, rows, small_cache); const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]); - const uint32_t subgroup_size = fa_subgroup_size(device, path); + const uint32_t subgroup_size = fa_subgroup_size(device, path, rows); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. @@ -3256,13 +3260,13 @@ static void ggml_vk_load_shaders(vk_device& device) { bool f32acc = fa.first.f32acc; \ uint32_t flags = fa.first.flags; \ bool fa_ds = fa_disable_subgroups(device, path); \ - uint32_t fa_sgs = fa_subgroup_size(device, path); \ + uint32_t fa_sgs = fa_subgroup_size(device, path, rows); \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(device, FAPATH,HSK,HSV,TYPE,rows,small_cache), true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } else { \ if (f32acc) { \ @@ -8420,7 +8424,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) { // Needs to be kept up to date on shader changes - const std::array rows_cols = fa_rows_cols(FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache); + const std::array rows_cols = fa_rows_cols(device, FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc); @@ -8451,7 +8455,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); - const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false); + const auto rows_cols = fa_rows_cols(device, FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; @@ -8578,7 +8582,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_rows(HSK, HSV, FA_ROWS_LARGE, small_cache); + max_gqa = get_fa_scalar_num_rows(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache); break; case FA_COOPMAT2: max_gqa = flash_attention_num_small_rows; @@ -8606,14 +8610,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx rows = FA_ROWS_LARGE; } - // coopmat1 does not actually support "small rows" (it needs 16 rows). - // So use scalar instead. - if (rows != FA_ROWS_LARGE && path == FA_COOPMAT1) { - path = FA_SCALAR; - } - // scalar is faster than coopmat2 when N==1 - if (rows == FA_ROWS_1 && path == FA_COOPMAT2) { + if (rows == FA_ROWS_1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) { path = FA_SCALAR; } @@ -8634,7 +8632,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, rows, small_cache); + uint32_t alignment = fa_align(ctx->device, path, HSK, HSV, k->type, rows, small_cache); bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; @@ -8691,7 +8689,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16; - auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache); + auto rows_cols = fa_rows_cols(ctx->device, path, HSK, HSV, !aligned, k->type, rows, small_cache); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1];