diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f19271ca87..801da7c7f8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4559,6 +4559,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev); static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -4775,6 +4776,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_core_count = sm_props.shaderSMCount; } else if (amd_shader_core_properties2) { device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) { + device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device); } else { device->shader_core_count = 0; } @@ -8686,8 +8689,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t split_kv = KV; uint32_t split_k = 1; + // Intel Alchemist prefers more workgroups + const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1; + // Use a placeholder core count if one isn't available. split_k is a big help for perf. - const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16; auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, rows, small_cache); const uint32_t Br = rows_cols[0]; @@ -15446,6 +15452,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope } } +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) { + VkPhysicalDeviceProperties2 props = vkdev.getProperties2(); + + if (props.properties.vendorID != VK_VENDOR_ID_INTEL) { + return 0; + } + + const uint32_t device_id = props.properties.deviceID; + + switch (device_id) { + case 0x56A6: // A310 + return 6; + case 0x5693: // A370M + case 0x56A5: // A380 + case 0x56B1: // Pro A40/A50 + return 8; + case 0x5697: // A530M + return 12; + case 0x5692: // A550M + case 0x56B3: // Pro A60 + return 16; + case 0x56A2: // A580 + return 24; + case 0x5691: // A730M + case 0x56A1: // A750 + return 28; + case 0x56A0: // A770 + case 0x5690: // A770M + return 32; + case 0xE212: // Pro B50 + return 16; + case 0xE20C: // B570 + return 18; + case 0xE20B: // B580 + return 20; + default: + return 0; + } +} + // checks #ifdef GGML_VULKAN_CHECK_RESULTS