diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a99375c088..cb7fa2c9cb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -254,6 +254,7 @@ enum vk_device_architecture { AMD_RDNA3, INTEL_XE2, NVIDIA_PRE_TURING, + NVIDIA_TURING, }; static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { @@ -336,18 +337,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& const std::vector ext_props = device.enumerateDeviceExtensionProperties(); bool cooperative_matrix = false; + bool sm_builtins = false; // Detect "pre-turing" based on lack of coopmat support. for (const auto& properties : ext_props) { if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { cooperative_matrix = true; - break; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; } } if (!cooperative_matrix) { return vk_device_architecture::NVIDIA_PRE_TURING; } + + if (sm_builtins) { + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + + props2.pNext = &sm_props; + + device.getProperties2(&props2); + + // Turing has 32, following architectures have 48 + if (sm_props.shaderWarpsPerSM == 32) { + return vk_device_architecture::NVIDIA_TURING; + } + } } return vk_device_architecture::OTHER; } @@ -8460,6 +8477,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) { + // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 + path = FA_SCALAR; + } + if (path == FA_COOPMAT1) { const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);