diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 10dea421bb..f48d354f10 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3037,12 +3037,25 @@ static void ggml_vk_load_shaders(vk_device& device) { // Xe2/Xe3 with coopmat enabled - warptile performance tuning l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + } else if (device->vendor_id == VK_VENDOR_ID_QUALCOMM) { + uint32_t s = device->subgroup_size; + m_warptile = m_warptile_mmq = m_warptile_mmq_int = m_warptile_mmq_k = m_warptile_mmq_int_k = + m_warptile_id = m_warptile_mmqid = m_warptile_mmqid_int = m_warptile_mmqid_int_k = + { 64, 64, 64, s, s, 16, 8, 8, 1, 1, s }; + s_warptile = s_warptile_mmq = s_warptile_mmq_int = s_warptile_mmq_k = s_warptile_mmq_int_k = + s_warptile_id = s_warptile_mmqid = s_warptile_mmqid_int = s_warptile_mmqid_int_k = + { 32, 32, 64, s, s, 16, 8, 8, 1, 1, s }; } else if (device->vendor_id == VK_VENDOR_ID_ARM && device->subgroup_size >= 16) { - uint32_t wm_iter = 32 / device->subgroup_size; - uint32_t wm_tile = device->subgroup_size * 2; - m_warptile_mmq = m_warptile_mmq_int = { 64, 64, 64, 16, wm_tile, 32, wm_iter, 2, 2, 1, device->subgroup_size }; - m_warptile = { 64, 64, 64, 16, wm_tile, 32, wm_iter, 2, 2, 1, device->subgroup_size }; - m_warptile_id = m_warptile_mmqid = { 64, 64, 64, 16, wm_tile, 32, wm_iter, 2, 2, 1, device->subgroup_size }; + uint32_t s = device->subgroup_size; + s_warptile = s_warptile_mmq = s_warptile_mmq_int = s_warptile_mmq_k = s_warptile_mmq_int_k = + s_warptile_id = s_warptile_mmqid = s_warptile_mmqid_int = s_warptile_mmqid_int_k = + { 32, 32, 64, s, s, 16, 4, 4, 1, 1, s }; + m_warptile = m_warptile_mmq = m_warptile_mmq_int = m_warptile_mmq_k = m_warptile_mmq_int_k = + m_warptile_id = m_warptile_mmqid = m_warptile_mmqid_int = m_warptile_mmqid_int_k = + { 64, 64, 64, s, s, 16, 4, 4, 1, 1, s }; + l_warptile = l_warptile_mmq = l_warptile_mmq_int = l_warptile_mmq_k = l_warptile_mmq_int_k = + l_warptile_id = l_warptile_mmqid = l_warptile_mmqid_int = l_warptile_mmqid_int_k = + { 64, 64, 64, s, s, 16, 4, 4, 1, 1, s }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -3052,6 +3065,10 @@ static void ggml_vk_load_shaders(vk_device& device) { m_align = 64; s_align = 32; + if (device->vendor_id == VK_VENDOR_ID_QUALCOMM) { + m_align = 128; + } + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { ggml_type t = (ggml_type)i; // Disable medium and large matrix multiplication if not enough shared memory is available @@ -3067,10 +3084,7 @@ static void ggml_vk_load_shaders(vk_device& device) { device->mul_mat_l[i] = false; } - if (device->vendor_id == VK_VENDOR_ID_ARM || device->vendor_id == VK_VENDOR_ID_QUALCOMM) { - device->mul_mat_l[i] = false; - device->mul_mat_id_l[i] = false; - } + // Disable mul_mat_id if not enough shared memory is available if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) { @@ -5182,6 +5196,15 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; + case VK_VENDOR_ID_ARM: + case VK_VENDOR_ID_QUALCOMM: + device->mul_mat_l[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; case VK_VENDOR_ID_APPLE: device->mul_mat_l[i] = false; device->mul_mat_m[i] = true; @@ -13756,6 +13779,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // (and scaled down based on model size, so smaller models submit earlier). // Also submit at least every 100 nodes, in case there are workloads without as much matmul. int nodes_per_submit = 100; + if (ctx->device->vendor_id == VK_VENDOR_ID_QUALCOMM) { + nodes_per_submit = 1000; + } int submitted_nodes = 0; int submit_count = 0; uint64_t mul_mat_bytes = 0;