diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index ef99c3c1eb..db8d91f2fe 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -38,6 +38,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
 #include
 #include
 #include
+#include <bitset>
 
 #if defined(_MSC_VER)
 # define NOMINMAX 1
@@ -561,12 +562,19 @@ struct vk_device_struct {
     size_t idx;
 
-    bool mul_mat_l[GGML_TYPE_COUNT];
-    bool mul_mat_m[GGML_TYPE_COUNT];
-    bool mul_mat_s[GGML_TYPE_COUNT];
-    bool mul_mat_id_l[GGML_TYPE_COUNT];
-    bool mul_mat_id_m[GGML_TYPE_COUNT];
-    bool mul_mat_id_s[GGML_TYPE_COUNT];
+    std::bitset<GGML_TYPE_COUNT> mul_mat_l;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_m;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_s;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_id_l;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_id_m;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_id_s;
+
+    std::bitset<GGML_TYPE_COUNT> mul_mat_int_l;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_int_m;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_int_s;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_l;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_m;
+    std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_s;
 
     vk::DescriptorSetLayout dsl;
@@ -2526,7 +2534,38 @@ static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type
     return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
 }
 
-static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
+static uint32_t mmq_shmem_struct_size(const vk_device& device, ggml_type type) {
+    const uint32_t float_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
+    constexpr uint32_t int32_size = sizeof(uint32_t);
+
+    switch (type) {
+    case GGML_TYPE_Q4_0:
+        return 4 * int32_size + float_size;
+    case GGML_TYPE_Q4_1:
+        return 4 * int32_size + 2 * float_size;
+    case GGML_TYPE_Q5_0:
+        return 5 * int32_size + float_size;
+    case GGML_TYPE_Q5_1:
+        return 5 * int32_size + 2 * float_size;
+    case GGML_TYPE_Q8_0:
+    case GGML_TYPE_MXFP4:
+        return 8 * int32_size + float_size;
+    case GGML_TYPE_Q8_1:
+        return 8 * int32_size + 2 * float_size;
+    case GGML_TYPE_Q2_K:
+        return 2 * int32_size + 2 + 2 * float_size;
+    case GGML_TYPE_Q3_K:
+    case GGML_TYPE_Q4_K:
+        return 4 * int32_size + 2 * float_size;
+    case GGML_TYPE_Q5_K:
+    case GGML_TYPE_Q6_K:
+        return 8 * int32_size + 2 * float_size;
+    default:
+        return 0;
+    }
+}
+
+static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type, ggml_type src1_type, uint32_t bk_step) {
     uint32_t lut_size = 0;
 
     switch (src0_type) {
@@ -2559,11 +2598,18 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     }
 
     // Needs to be kept up to date on shader changes
-    const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
-    const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
     const uint32_t warps = warptile[0] / warptile[10];
 
-    const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
+    uint32_t load_bufs;
+
+    if (src1_type != GGML_TYPE_Q8_1) {
+        const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
+        const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
+
+        load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
+    } else {
+        load_bufs = (warptile[1] + warptile[2]) * bk_step * mmq_shmem_struct_size(device, src0_type);
+    }
     const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
     const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
     const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
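
As a sanity check on the new Q8_1 branch above, here is a minimal standalone sketch of the same budget computation, using the large integer tile (BM = BN = 128) defined later in this patch. The Q4_0 struct size mirrors mmq_shmem_struct_size() with fp16 scales; the 48 KiB limit is an assumed shared-memory size chosen for illustration, not a value from the patch.

#include <cstdint>
#include <cstdio>

int main() {
    // Per-block shared memory footprint of Q4_0 with fp16 scales,
    // mirroring mmq_shmem_struct_size(): 4 x uint32 quants + one fp16 scale.
    const uint32_t q4_0_struct_size = 4 * 4 + 2;

    const uint32_t bm = 128, bn = 128; // warptile[1], warptile[2] of l_warptile_mmq_int
    const uint32_t bk_step = 4;        // quant blocks cached along the k-direction

    // Same formula as the Q8_1 path above: (BM + BN) * bk_step * struct size
    const uint32_t load_bufs = (bm + bn) * bk_step * q4_0_struct_size;

    const uint32_t shmem_limit = 48 * 1024; // assumed maxComputeSharedMemorySize
    std::printf("load_bufs = %u bytes, fits: %s\n", load_bufs,
                load_bufs <= shmem_limit ? "yes" : "no"); // 18432 bytes, fits
}
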
@@ -2644,6 +2690,28 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
     return 0; // If no matching configuration is found
 }
 
+static uint32_t get_default_bk_step(const vk_device& device, ggml_type src0_type, bool mul_mat_id) {
+    const uint32_t bk_struct_size = mmq_shmem_struct_size(device, src0_type);
+
+    const uint32_t q5_0_struct_size = mmq_shmem_struct_size(device, GGML_TYPE_Q5_0);
+
+    // Smaller struct means we can fit more in shared memory
+    if (bk_struct_size < q5_0_struct_size) {
+        return 4;
+    }
+
+    // GCN likes large caches
+    if (device->architecture == vk_device_architecture::AMD_GCN) {
+        return 4;
+    }
+
+    if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+        return 1;
+    }
+
+    return mul_mat_id ? 1 : 2;
+}
+
 static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
 
@@ -2676,6 +2744,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
         l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
         l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
 
+    std::array<uint32_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_s;
+    std::array<uint32_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_m;
+    std::array<uint32_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_l;
+    std::array<uint32_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_s;
+    std::array<uint32_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_m;
+    std::array<uint32_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_l;
+
     uint32_t l_align, m_align, s_align;
     if (device->coopmat2) {
         // spec constants and tile sizes for non-quant matmul/matmul_id
@@ -2734,14 +2809,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
         // Integer MMQ has a smaller shared memory profile, but heavier register use
-        l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
-        m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
-        s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
+        l_warptile_mmq_int = { 128, 128, 128, 0, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int = { 128, 64, 64, 0, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int = { subgroup_size_32, 32, 32, 0, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
 
         // K-quants use even more registers, mitigate by setting WMITER to 1
-        l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
-        m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
-        s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
+        l_warptile_mmq_int_k = { 128, 128, 128, 0, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int_k = { 128, 64, 64, 0, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 0, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
 
         l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
         m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
@@ -2751,13 +2826,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
         m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
         s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
 
-        l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
-        m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
-        s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
+        l_warptile_mmqid_int = { 128, 128, 128, 0, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
+        m_warptile_mmqid_int = { 128, 64, 64, 0, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
+        s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 0, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
 
-        l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
-        m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
-        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
+        l_warptile_mmqid_int_k = { 128, 128, 128, 0, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
+        m_warptile_mmqid_int_k = { 128, 64, 64, 0, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
+        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 0, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
 
         // chip specific tuning
         if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
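
For orientation while reading the warptile edits above: the vector slots map one-to-one onto the spec constants of mul_mmq.comp (constant_id 0 through 10, see the shader hunk at the end of this patch), which is why slot 3 — formerly BK — can be zeroed here and patched in per type once the bk_step search below has run. The slot names in this sketch are my own labels inferred from that mapping, not identifiers from llama.cpp.

// Hedged reading of the warptile layout; the enum names are mine.
enum warptile_slot : int {
    WT_BLOCK_SIZE = 0, // workgroup size (shader BLOCK_SIZE)
    WT_BM         = 1, // output rows per workgroup tile
    WT_BN         = 2, // output columns per workgroup tile
    WT_BK         = 3, // k-depth; reused as BK_STEP (0 here, filled in later) for integer MMQ
    WT_WM         = 4,
    WT_WN         = 5,
    WT_WMITER     = 6,
    WT_TM         = 7,
    WT_TN         = 8,
    WT_TK         = 9,
    WT_WARP       = 10 // subgroup size; warps = warptile[0] / warptile[10]
};
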
@@ -2773,31 +2848,81 @@ static void ggml_vk_load_shaders(vk_device& device) {
         s_align = 32;
 
         for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
-            ggml_type t = (ggml_type)i;
+            const ggml_type t = (ggml_type)i;
             // Disable medium and large matrix multiplication if not enough shared memory is available
             // Check mmq warptiles as the largest configuration
             // Throw an error if not enough for any matrix multiplication is available
-            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
+            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
                 std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
                 throw std::runtime_error("Shared memory size too small for matrix multiplication.");
-            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
+            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
                 device->mul_mat_m[i] = false;
                 device->mul_mat_l[i] = false;
-            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
+            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
                 device->mul_mat_l[i] = false;
             }
 
             // Disable mul_mat_id if not enough shared memory is available
-            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
+            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
                 device->mul_mat_id_s[i] = false;
                 device->mul_mat_id_m[i] = false;
                 device->mul_mat_id_l[i] = false;
-            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
+            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
                 device->mul_mat_id_m[i] = false;
                 device->mul_mat_id_l[i] = false;
-            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
+            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
                 device->mul_mat_id_l[i] = false;
             }
+
+            // Integer dot matmul has different shared memory requirements.
+            // bk_step is how many blocks in the k-direction are stored in shared memory at once.
+            // Each value starts at a device-specific default, halving until it fits, or 0 if unsupported.
+
+            mul_mat_int_bk_step_s[i] = device->mul_mat_int_s[i] ? get_default_bk_step(device, t, false) : 0;
+            mul_mat_int_bk_step_m[i] = device->mul_mat_int_m[i] ? get_default_bk_step(device, t, false) : 0;
+            mul_mat_int_bk_step_l[i] = device->mul_mat_int_l[i] ? get_default_bk_step(device, t, false) : 0;
+            mul_mat_int_bk_step_id_s[i] = device->mul_mat_int_id_s[i] ? get_default_bk_step(device, t, true) : 0;
+            mul_mat_int_bk_step_id_m[i] = device->mul_mat_int_id_m[i] ? get_default_bk_step(device, t, true) : 0;
+            mul_mat_int_bk_step_id_l[i] = device->mul_mat_int_id_l[i] ? get_default_bk_step(device, t, true) : 0;
+
+            for (uint32_t bk_step : { 4, 2, 1 }) {
+                if (mul_mat_int_bk_step_s[i] == bk_step && !ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
+                    mul_mat_int_bk_step_s[i] >>= 1;
+                    mul_mat_int_bk_step_m[i] >>= 1;
+                    mul_mat_int_bk_step_l[i] >>= 1;
+                } else if (mul_mat_int_bk_step_m[i] == bk_step && !ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
+                    mul_mat_int_bk_step_m[i] >>= 1;
+                    mul_mat_int_bk_step_l[i] >>= 1;
+                } else if (mul_mat_int_bk_step_l[i] == bk_step && !ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
+                    mul_mat_int_bk_step_l[i] >>= 1;
+                }
+
+                if (mul_mat_int_bk_step_id_s[i] == bk_step && !ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
+                    mul_mat_int_bk_step_id_s[i] >>= 1;
+                    mul_mat_int_bk_step_id_m[i] >>= 1;
+                    mul_mat_int_bk_step_id_l[i] >>= 1;
+                } else if (mul_mat_int_bk_step_id_m[i] == bk_step && !ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
+                    mul_mat_int_bk_step_id_m[i] >>= 1;
+                    mul_mat_int_bk_step_id_l[i] >>= 1;
+                } else if (mul_mat_int_bk_step_id_l[i] == bk_step && !ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
+                    mul_mat_int_bk_step_id_l[i] >>= 1;
+                }
+            }
+
+            // std::cerr << "ggml_vulkan: Info: Integer dot-product matmul support for type " << ggml_type_name(t) << ": "
+            //           << "small bk_step=" << (int)mul_mat_int_bk_step_s[i] << ", "
+            //           << "medium bk_step=" << (int)mul_mat_int_bk_step_m[i] << ", "
+            //           << "large bk_step=" << (int)mul_mat_int_bk_step_l[i] << "; "
+            //           << "matmul_id small bk_step=" << (int)mul_mat_int_bk_step_id_s[i] << ", "
+            //           << "medium bk_step=" << (int)mul_mat_int_bk_step_id_m[i] << ", "
+            //           << "large bk_step=" << (int)mul_mat_int_bk_step_id_l[i] << std::endl;
+
+            device->mul_mat_int_s[i] = mul_mat_int_bk_step_s[i] > 0;
+            device->mul_mat_int_m[i] = mul_mat_int_bk_step_m[i] > 0;
+            device->mul_mat_int_l[i] = mul_mat_int_bk_step_l[i] > 0;
+            device->mul_mat_int_id_s[i] = mul_mat_int_bk_step_id_s[i] > 0;
+            device->mul_mat_int_id_m[i] = mul_mat_int_bk_step_id_m[i] > 0;
+            device->mul_mat_int_id_l[i] = mul_mat_int_bk_step_id_l[i] > 0;
         }
     }
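
Ignoring the cascade between size classes (a smaller tile that fails also halves the larger ones, since its shared-memory footprint is a lower bound on theirs), the per-class search above boils down to the helper below. fits() stands in for ggml_vk_matmul_shmem_support(device, warptile, mul_mat_id, t, GGML_TYPE_Q8_1, bk_step) and is an assumption of this sketch.

#include <cstdint>
#include <functional>

// Hedged restatement of the bk_step search: halve from the device default
// until the tile fits in shared memory; 0 disables integer MMQ for the type.
static uint32_t pick_bk_step(uint32_t default_bk_step,
                             const std::function<bool(uint32_t)> &fits) {
    for (uint32_t bk_step = default_bk_step; bk_step > 0; bk_step >>= 1) {
        if (fits(bk_step)) {
            return bk_step;
        }
    }
    return 0;
}
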
@@ -2865,6 +2990,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
             align, disable_robustness, require_full_subgroups, required_subgroup_size);
     };
 
+    auto const &s_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
+        std::vector<uint32_t> warptile_copy = warptile;
+        warptile_copy[3] = mul_mat_int_bk_step_s[type];
+        return warptile_copy;
+    };
+    auto const &m_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
+        std::vector<uint32_t> warptile_copy = warptile;
+        warptile_copy[3] = mul_mat_int_bk_step_m[type];
+        return warptile_copy;
+    };
+    auto const &l_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
+        std::vector<uint32_t> warptile_copy = warptile;
+        warptile_copy[3] = mul_mat_int_bk_step_l[type];
+        return warptile_copy;
+    };
+
     auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
         return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
     };
@@ -3150,14 +3291,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 
 #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
-    if (device->mul_mat ## ID ## _l[TYPE]) { \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
+    if (device->mul_mat_int ## ID ## _l[TYPE]) { \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_mmq_warptile_bk_step(l_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
     } \
-    if (device->mul_mat ## ID ## _m[TYPE]) { \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
+    if (device->mul_mat_int ## ID ## _m[TYPE]) { \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_mmq_warptile_bk_step(m_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
     } \
-    if (device->mul_mat ## ID ## _s[TYPE]) { \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
+    if (device->mul_mat_int ## ID ## _s[TYPE]) { \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_mmq_warptile_bk_step(s_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
     } \
 
 // Create 2 variants, {f16,f32} accumulator
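
One detail worth spelling out, since the token pasting in CREATE_MMQ is easy to misread: with the ID argument left empty the guard expands to device->mul_mat_int_l[TYPE], and with ID = _id it expands to device->mul_mat_int_id_l[TYPE]. A minimal compilable illustration follows; the struct and macro below are stand-ins for this sketch, not the real llama.cpp types.

#include <bitset>
#include <cstdio>

struct fake_device {
    std::bitset<4> mul_mat_int_l;
    std::bitset<4> mul_mat_int_id_l;
};

// Same pasting pattern as CREATE_MMQ's guard.
#define MMQ_ENABLED_L(dev, ID, TYPE) ((dev).mul_mat_int ## ID ## _l[TYPE])

int main() {
    fake_device d;
    d.mul_mat_int_l[1] = true;
    std::printf("plain: %d, id: %d\n",
                (int)MMQ_ENABLED_L(d, , 1),     // expands to d.mul_mat_int_l[1]
                (int)MMQ_ENABLED_L(d, _id, 1)); // expands to d.mul_mat_int_id_l[1]
}
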
@@ -3321,12 +3462,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
 
 #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
-    if (device->mul_mat ## ID ## _l[TYPE]) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
-    if (device->mul_mat ## ID ## _m[TYPE]) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
-    if (device->mul_mat ## ID ## _s[TYPE]) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+    if (device->mul_mat_int ## ID ## _l[TYPE]) \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_mmq_warptile_bk_step(l_ ## WARPTILE, TYPE), 1); \
+    if (device->mul_mat_int ## ID ## _m[TYPE]) \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_mmq_warptile_bk_step(m_ ## WARPTILE, TYPE), 1); \
+    if (device->mul_mat_int ## ID ## _s[TYPE]) \
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_mmq_warptile_bk_step(s_ ## WARPTILE, TYPE), 1); \
 
     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
@@ -4693,8 +4834,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
                 device->mul_mat_id_s[i] = true;
                 break;
             }
-        }
 
+            device->mul_mat_int_l[i] = device->mul_mat_l[i];
+            device->mul_mat_int_m[i] = device->mul_mat_m[i];
+            device->mul_mat_int_s[i] = device->mul_mat_s[i];
+            device->mul_mat_int_id_l[i] = device->mul_mat_id_l[i];
+            device->mul_mat_int_id_m[i] = device->mul_mat_id_m[i];
+            device->mul_mat_int_id_s[i] = device->mul_mat_id_s[i];
+        }
 
     std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
     std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
@@ -6130,6 +6277,16 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
         return aligned ? mmp->a_s : mmp->s;
     }
 
+    if (src1_type == GGML_TYPE_Q8_1) {
+        if ((ctx->device->mul_mat_int_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_int_m[src0_type] && !ctx->device->mul_mat_int_l[src0_type])) {
+            return aligned ? mmp->a_s : mmp->s;
+        }
+        if ((ctx->device->mul_mat_int_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_int_l[src0_type]) {
+            return aligned ? mmp->a_m : mmp->m;
+        }
+        return aligned ? mmp->a_l : mmp->l;
+    }
+
     if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
         return aligned ? mmp->a_s : mmp->s;
     }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
index 5266e523b9..11eb253f83 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -67,7 +67,7 @@ layout (push_constant) uniform parameter
 layout (constant_id = 0) const uint BLOCK_SIZE = 64;
 layout (constant_id = 1) const uint BM = 64;
 layout (constant_id = 2) const uint BN = 64;
-// layout (constant_id = 3) const uint BK = 32;
+layout (constant_id = 3) const uint BK_STEP = 1; // Number of quant blocks stored in shared memory
 layout (constant_id = 4) const uint WM = 32;
 layout (constant_id = 5) const uint WN = 32;
 layout (constant_id = 6) const uint WMITER = 2;
@@ -82,14 +82,6 @@ layout (constant_id = 10) const uint WARP = 32;
 
 #include "mul_mmq_shmem_types.glsl"
 
-#ifdef MUL_MAT_ID
-#define BK_STEP 1
-#else
-#ifndef BK_STEP
-#define BK_STEP 4
-#endif
-#endif
-
 // Shared memory cache
 shared block_a_cache buf_a[BM * BK_STEP];
 shared block_b_cache buf_b[BN * BK_STEP];
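
Finally, a condensed restatement of the tile-class selection that the ggml_vk_guess_matmul_pipeline hunk above adds for Q8_1 sources: pick the smallest class whose m/n threshold matches, falling back to whatever the device supports. The function below is a sketch ('s'/'m'/'l' stand in for the mmp->s/m/l pipeline handles), not code from the patch.

#include <cstdint>

static char pick_size_class(uint32_t m, uint32_t n,
                            bool int_s, bool int_m, bool int_l) {
    // Small tile: small matrices, or nothing larger is supported.
    if ((int_s && (m <= 32 || n <= 32)) || (!int_m && !int_l)) {
        return 's';
    }
    // Medium tile: medium matrices, or the large tile is unsupported.
    if ((int_m && (m <= 64 || n <= 64)) || !int_l) {
        return 'm';
    }
    return 'l';
}
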