diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d8a2dc098e..ba2eb211a3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -401,21 +401,27 @@ enum FaCodePath { FA_COOPMAT1, FA_COOPMAT2, }; +enum FaRows { + FA_ROWS_1, + FA_ROWS_SMALL, + FA_ROWS_LARGE, +}; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) - : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, FaRows rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) + : HSK(HSK), HSV(HSV), rows(rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} uint32_t HSK, HSV; - bool small_rows, small_cache; + FaRows rows; + bool small_cache; FaCodePath path; bool aligned; bool f32acc; uint32_t flags; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) < - std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); + return std::tie(HSK, HSV, rows, small_cache, path, aligned, f32acc, flags) < + std::tie(b.HSK, b.HSV, b.rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); } }; @@ -2755,16 +2761,21 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; -static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { +static uint32_t get_fa_scalar_num_rows(uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { + if (rows == FA_ROWS_1) { + return 1; + } else if (rows == FA_ROWS_SMALL) { + return 4; + } + if (hsv >= 192) { return 8; } else if ((hsv | hsk) & 8 || small_cache) { return 8; - } else { - return 16; } + + return 16; } // The FA coopmat1 shader assumes 16x16x16 matrix multiply support. @@ -2774,36 +2785,20 @@ static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; static constexpr uint32_t scalar_flash_attention_Bc = 64; static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; -static uint32_t get_fa_num_small_rows(FaCodePath path) { - if (path == FA_COOPMAT2) { - return flash_attention_num_small_rows; - } else { - return scalar_flash_attention_num_small_rows; - } -} - -static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) { +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) { GGML_UNUSED(clamp); if (path == FA_SCALAR) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, 64}; - } else { - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; - } + return {get_fa_scalar_num_rows(hsk, hsv, rows, small_cache), 64}; } if (path == FA_COOPMAT1) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; - } else { - return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; - } + return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; } // small rows, large cols - if (small_rows) { - return {get_fa_num_small_rows(FA_COOPMAT2), 32}; + if (rows != FA_ROWS_LARGE) { + return {flash_attention_num_small_rows, 32}; } // small cols to reduce register count @@ -2817,8 +2812,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 return {64, 64}; } -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) { - return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1]; +static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, FaRows rows, bool small_cache) { + return fa_rows_cols(path, hsk, hsv, 0, type, rows, small_cache)[1]; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3185,23 +3180,23 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array { - return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, uint32_t flags) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. // For scalar, use 128 (arbitrary) // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. const uint32_t D = (hsk|hsv); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache); uint32_t wg_size; switch (path) { case FA_COOPMAT2: - wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128); + wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128); break; case FA_COOPMAT1: wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc @@ -3232,7 +3227,7 @@ static void ggml_vk_load_shaders(vk_device& device) { for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ uint32_t HSK = fa.first.HSK; \ uint32_t HSV = fa.first.HSV; \ - bool small_rows = fa.first.small_rows; \ + FaRows rows = fa.first.rows; \ bool small_cache = fa.first.small_cache; \ FaCodePath path = fa.first.path; \ bool aligned = fa.first.aligned; \ @@ -3241,15 +3236,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,rows,small_cache), true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,rows,small_cache,flags), 1, true, FAPATH!=FA_COOPMAT2, (FAPATH!=FA_COOPMAT2 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -8396,11 +8391,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache); + const uint32_t Br = get_fa_scalar_num_rows(hsk, hsv, rows, small_cache); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -8421,7 +8416,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); - const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false); + const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, FA_ROWS_LARGE, false); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1]; @@ -8547,10 +8542,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache); + max_gqa = get_fa_scalar_num_rows(HSK, HSV, FA_ROWS_LARGE, small_cache); break; case FA_COOPMAT2: - max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + max_gqa = flash_attention_num_small_rows; break; default: GGML_ASSERT(0); @@ -8566,23 +8561,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } - bool small_rows = N <= get_fa_num_small_rows(path); + FaRows rows; + if (N == 1) { + rows = FA_ROWS_1; + } else if (N <= 8) { + rows = FA_ROWS_SMALL; + } else { + rows = FA_ROWS_LARGE; + } // coopmat1 does not actually support "small rows" (it needs 16 rows). // So use scalar instead. - if (small_rows && path == FA_COOPMAT1) { + if (rows != FA_ROWS_LARGE && path == FA_COOPMAT1) { path = FA_SCALAR; } // scalar is faster than coopmat2 when N==1 - if (N == 1 && path == FA_COOPMAT2) { + if (rows == FA_ROWS_1 && path == FA_COOPMAT2) { path = FA_SCALAR; } - // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory - if (path == FA_SCALAR && - !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) { - small_rows = true; + // with large hsk/hsv, scalar path may need to use small rows to fit in shared memory + if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache)) { + rows = FA_ROWS_SMALL; } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); @@ -8597,7 +8598,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache); + uint32_t alignment = fa_align(path, HSK, HSV, k->type, rows, small_cache); bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; @@ -8628,7 +8629,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx (mask != nullptr ? 2 : 0) | (logit_softcap != 0 ? 4 : 0); - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags); + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, rows, small_cache, path, aligned, f32acc, flags); vk_pipeline pipeline = nullptr; @@ -8679,7 +8680,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_preallocate_buffers(ctx, subctx); } - auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache); + auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache); const uint32_t Br = rows_cols[0]; const uint32_t Bc = rows_cols[1];