diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 74abeb4b53..e590dd537d 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2803,6 +2803,26 @@ static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
     return device->subgroup_size;
 }
 
+static uint32_t fa_workgroup_size(const vk_device& device, FaCodePath path, uint32_t hsk, uint32_t hsv, FaRows rows, uint32_t Br, uint32_t Bc) {
+    const uint32_t D = std::max(hsk, hsv);
+    switch (path) {
+    case FA_COOPMAT2:
+        return ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
+    case FA_COOPMAT1:
+        return (Bc / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
+    case FA_VECTOR:
+        return device->vendor_id == VK_VENDOR_ID_AMD ? 256 : 128;
+    default:
+        if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+            return 128;
+        } else if (device->subgroup_size > 32 && Br < 4) {
+            return device->subgroup_size * 2;
+        } else {
+            return device->subgroup_size * 4;
+        }
+    }
+}
+
 static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) {
     GGML_UNUSED(clamp);
 
@@ -3220,29 +3240,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         const uint32_t D = (hsk|hsv);
         auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, rows, small_cache);
-        uint32_t wg_size;
-        switch (path) {
-        case FA_COOPMAT2:
-            wg_size = ((rows != FA_ROWS_LARGE && (D % 32) == 0) ? 256 : 128);
-            break;
-        case FA_COOPMAT1:
-            if (disable_subgroups) {
-                wg_size = 128;
-            } else {
-                wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
-            }
-            break;
-        default:
-            if (disable_subgroups) {
-                wg_size = 128;
-            } else if (device->subgroup_size > 32 && rows_cols[0] < 4) {
-                wg_size = device->subgroup_size * 2;
-            } else {
-                wg_size = device->subgroup_size * 4;
-            }
-            break;
-        }
-
+        const uint32_t wg_size = fa_workgroup_size(device, path, hsk, hsv, rows, rows_cols[0], rows_cols[1]);
         const uint32_t subgroup_size = fa_subgroup_size(device, path);
 
         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
@@ -8426,21 +8424,29 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, FaRows rows, bool small_cache) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache, bool fp32acc) {
     // Needs to be kept up to date on shader changes
-    GGML_UNUSED(hsv);
-    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = get_fa_scalar_num_rows(hsk, hsv, rows, small_cache);
-    const uint32_t Bc = scalar_flash_attention_Bc;
+    const std::array<uint32_t, 2> rows_cols = fa_rows_cols(FA_SCALAR, hsk, hsv, clamp, type, rows, small_cache);
+    const uint32_t Br = rows_cols[0];
+    const uint32_t Bc = rows_cols[1];
+    const uint32_t wg_size = fa_workgroup_size(device, FA_SCALAR, hsk, hsv, rows, Br, Bc);
+    const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
+    const uint32_t acc_type_size = !fp32acc ? sizeof(ggml_fp16_t) : sizeof(float);
+
+    // tmpsh is overestimated slightly
     const uint32_t tmpsh = wg_size * sizeof(float);
-    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
+    const uint32_t tmpshv4 = wg_size * 4 * acc_type_size;
 
-    const uint32_t masksh = Bc * Br * sizeof(float);
+    const uint32_t masksh = Bc * (Br + 1) * float_type_size;
 
-    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
+    const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
+    const uint32_t D = std::max(hsk, hsv);
+    const bool shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256;
+    const uint32_t kvsh = shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
+
+    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
 
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
@@ -8569,6 +8575,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t workgroups_z = (uint32_t)neq3;
 
     const bool small_cache = nek1 < 1024;
+    const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
 
     // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
     // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
@@ -8617,8 +8624,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 
     // with large hsk/hsv, scalar path may need to use small rows to fit in shared memory
-    if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, FA_ROWS_LARGE, small_cache)) {
-        rows = FA_ROWS_SMALL;
+    if (path == FA_SCALAR && rows == FA_ROWS_LARGE && !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, 0, k->type, FA_ROWS_LARGE, small_cache, f32acc)) {
+        rows = FA_ROWS_8;
     }
 
     const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
@@ -8643,8 +8650,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         aligned = false;
     }
 
-    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
-
     float scale = 1.0f;
    float max_bias = 0.0f;
     float logit_softcap = 0.0f;