diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index db8d91f2fe..4adcf559f4 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2690,26 +2690,57 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
     return 0; // If no matching configuration is found
 }
 
+static bool is_k_quant(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            return true;
+        default:
+            return false;
+    }
+}
+
 static uint32_t get_default_bk_step(const vk_device& device, ggml_type src0_type, bool mul_mat_id) {
     const uint32_t bk_struct_size = mmq_shmem_struct_size(device, src0_type);
     const uint32_t q5_0_struct_size = mmq_shmem_struct_size(device, GGML_TYPE_Q5_0);
 
-    // Smaller struct means we can fit more in shared memory
-    if (bk_struct_size < q5_0_struct_size) {
-        return 4;
-    }
+    const bool kq = is_k_quant(src0_type);
 
-    // GCN likes large caches
     if (device->architecture == vk_device_architecture::AMD_GCN) {
+        if (mul_mat_id) {
+            return kq ? 1 : 4;
+        }
         return 4;
-    }
-
-    if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+    } else if (device->vendor_id == VK_VENDOR_ID_AMD) {
         return 1;
     }
 
-    return mul_mat_id ? 1 : 2;
+    if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+        if (mul_mat_id) {
+            if (kq) {
+                return 1;
+            }
+
+            return src0_type != GGML_TYPE_Q8_0 ? 4 : 1;
+        }
+
+        if (kq) {
+            return src0_type == GGML_TYPE_Q4_K ? 4 : 1;
+        }
+
+        return src0_type != GGML_TYPE_Q8_0 ? 4 : 1;
+    }
+
+    // Nvidia/Generic case
+    if (!mul_mat_id && !kq) {
+        return 1;
+    }
+
+    return 4;
 }
 
 static void ggml_vk_load_shaders(vk_device& device) {
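
A minimal standalone sketch (not part of the patch) that mirrors the new per-vendor bk_step selection above, so the decision table can be eyeballed and tested in isolation. The Vendor and Quant enums and the default_bk_step/is_k_quant names below are simplified stand-ins for the real vk_device fields and ggml_type values, not the actual ggml API.

// illustrative sketch only: mirrors the patched selection logic with stand-in types
#include <cstdint>
#include <cstdio>

enum class Vendor { AMD_GCN, AMD, INTEL, OTHER };   // stand-in for architecture/vendor_id checks
enum class Quant  { Q4_0, Q8_0, Q2_K, Q4_K, Q6_K }; // stand-in subset of ggml_type

static bool is_k_quant(Quant q) {
    return q == Quant::Q2_K || q == Quant::Q4_K || q == Quant::Q6_K;
}

static uint32_t default_bk_step(Vendor v, Quant q, bool mul_mat_id) {
    const bool kq = is_k_quant(q);
    if (v == Vendor::AMD_GCN) {
        // GCN: mul_mat_id with k-quants drops to 1, everything else stays at 4
        return mul_mat_id ? (kq ? 1 : 4) : 4;
    }
    if (v == Vendor::AMD) {
        return 1;
    }
    if (v == Vendor::INTEL) {
        if (kq) {
            // plain mul_mat with Q4_K is the only k-quant case that keeps 4
            return (!mul_mat_id && q == Quant::Q4_K) ? 4 : 1;
        }
        return q != Quant::Q8_0 ? 4 : 1;
    }
    // Nvidia/generic: plain mul_mat with legacy quants uses 1, otherwise 4
    return (!mul_mat_id && !kq) ? 1 : 4;
}

int main() {
    std::printf("Intel,  Q4_K, mul_mat:     %u\n", (unsigned) default_bk_step(Vendor::INTEL,   Quant::Q4_K, false));
    std::printf("GCN,    Q6_K, mul_mat_id:  %u\n", (unsigned) default_bk_step(Vendor::AMD_GCN, Quant::Q6_K, true));
    std::printf("Nvidia, Q8_0, mul_mat:     %u\n", (unsigned) default_bk_step(Vendor::OTHER,   Quant::Q8_0, false));
    return 0;
}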