vulkan: allow MMQ bk_step tuning

This commit is contained in:
0cc4m 2025-11-16 14:25:24 +01:00
parent 416e7c7f47
commit 7e8eb9ba0a
2 changed files with 200 additions and 51 deletions

View File

@ -38,6 +38,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#include <mutex>
#include <future>
#include <thread>
#include <bitset>
#if defined(_MSC_VER)
# define NOMINMAX 1
@ -561,12 +562,19 @@ struct vk_device_struct {
size_t idx;
bool mul_mat_l[GGML_TYPE_COUNT];
bool mul_mat_m[GGML_TYPE_COUNT];
bool mul_mat_s[GGML_TYPE_COUNT];
bool mul_mat_id_l[GGML_TYPE_COUNT];
bool mul_mat_id_m[GGML_TYPE_COUNT];
bool mul_mat_id_s[GGML_TYPE_COUNT];
std::bitset<GGML_TYPE_COUNT> mul_mat_l;
std::bitset<GGML_TYPE_COUNT> mul_mat_m;
std::bitset<GGML_TYPE_COUNT> mul_mat_s;
std::bitset<GGML_TYPE_COUNT> mul_mat_id_l;
std::bitset<GGML_TYPE_COUNT> mul_mat_id_m;
std::bitset<GGML_TYPE_COUNT> mul_mat_id_s;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_l;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_m;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_s;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_l;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_m;
std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_s;
vk::DescriptorSetLayout dsl;
@ -2526,7 +2534,38 @@ static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
}
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
static uint32_t mmq_shmem_struct_size(const vk_device& device, ggml_type type) {
const uint32_t float_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
constexpr uint32_t int32_size = sizeof(uint32_t);
switch (type) {
case GGML_TYPE_Q4_0:
return 4 * int32_size + float_size;
case GGML_TYPE_Q4_1:
return 4 * int32_size + 2 * float_size;
case GGML_TYPE_Q5_0:
return 5 * int32_size + float_size;
case GGML_TYPE_Q5_1:
return 5 * int32_size + 2 * float_size;
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
return 8 * int32_size + float_size;
case GGML_TYPE_Q8_1:
return 8 * int32_size + 2 * float_size;
case GGML_TYPE_Q2_K:
return 2 * int32_size + 2 + 2 * float_size;
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
return 4 * int32_size + 2 * float_size;
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return 8 * int32_size + 2 * float_size;
default:
return 0;
}
}
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type, ggml_type src1_type, uint32_t bk_step) {
uint32_t lut_size = 0;
switch (src0_type) {
@ -2559,11 +2598,18 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
}
// Needs to be kept up to date on shader changes
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
const uint32_t warps = warptile[0] / warptile[10];
uint32_t load_bufs;
if (src1_type != GGML_TYPE_Q8_1) {
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
} else {
load_bufs = (warptile[1] + warptile[2]) * bk_step * mmq_shmem_struct_size(device, src0_type);
}
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
@ -2644,6 +2690,28 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
return 0; // If no matching configuration is found
}
static uint32_t get_default_bk_step(const vk_device& device, ggml_type src0_type, bool mul_mat_id) {
const uint32_t bk_struct_size = mmq_shmem_struct_size(device, src0_type);
const uint32_t q5_0_struct_size = mmq_shmem_struct_size(device, GGML_TYPE_Q5_0);
// Smaller struct means we can fit more in shared memory
if (bk_struct_size < q5_0_struct_size) {
return 4;
}
// GCN likes large caches
if (device->architecture == vk_device_architecture::AMD_GCN) {
return 4;
}
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
return 1;
}
return mul_mat_id ? 1 : 2;
}
static void ggml_vk_load_shaders(vk_device& device) {
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
@ -2676,6 +2744,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_s;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_m;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_l;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_s;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_m;
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_l;
uint32_t l_align, m_align, s_align;
if (device->coopmat2) {
// spec constants and tile sizes for non-quant matmul/matmul_id
@ -2734,14 +2809,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
// Integer MMQ has a smaller shared memory profile, but heavier register use
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
l_warptile_mmq_int = { 128, 128, 128, 0, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int = { 128, 64, 64, 0, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 0, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
// K-quants use even more registers, mitigate by setting WMITER to 1
l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
l_warptile_mmq_int_k = { 128, 128, 128, 0, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
m_warptile_mmq_int_k = { 128, 64, 64, 0, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 0, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
@ -2751,13 +2826,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int = { 128, 128, 128, 0, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
m_warptile_mmqid_int = { 128, 64, 64, 0, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 0, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
l_warptile_mmqid_int_k = { 128, 128, 128, 0, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
m_warptile_mmqid_int_k = { 128, 64, 64, 0, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 0, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
// chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
@ -2773,31 +2848,81 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_align = 32;
for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
ggml_type t = (ggml_type)i;
const ggml_type t = (ggml_type)i;
// Disable medium and large matrix multiplication if not enough shared memory is available
// Check mmq warptiles as the largest configuration
// Throw an error if not enough for any matrix multiplication is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
device->mul_mat_m[i] = false;
device->mul_mat_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
device->mul_mat_l[i] = false;
}
// Disable mul_mat_id if not enough shared memory is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
device->mul_mat_id_s[i] = false;
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
device->mul_mat_id_l[i] = false;
}
// Integer dot matmul has different shared memory requirements
// bk_step is how many blocks in k-direction are stored in shared memory at once
// Each value is initialized to 4, then reduced until it fits, or 0 if it is not supported
mul_mat_int_bk_step_s[i] = device->mul_mat_int_s[i] ? get_default_bk_step(device, t, false) : 0;
mul_mat_int_bk_step_m[i] = device->mul_mat_int_m[i] ? get_default_bk_step(device, t, false) : 0;
mul_mat_int_bk_step_l[i] = device->mul_mat_int_l[i] ? get_default_bk_step(device, t, false) : 0;
mul_mat_int_bk_step_id_s[i] = device->mul_mat_int_id_s[i] ? get_default_bk_step(device, t, true) : 0;
mul_mat_int_bk_step_id_m[i] = device->mul_mat_int_id_m[i] ? get_default_bk_step(device, t, true) : 0;
mul_mat_int_bk_step_id_l[i] = device->mul_mat_int_id_l[i] ? get_default_bk_step(device, t, true) : 0;
for (uint32_t bk_step : { 4, 2, 1 }) {
if (mul_mat_int_bk_step_s[i] == bk_step && !ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_s[i] >>= 1;
mul_mat_int_bk_step_m[i] >>= 1;
mul_mat_int_bk_step_l[i] >>= 1;
} else if (mul_mat_int_bk_step_m[i] == bk_step && !ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_m[i] >>= 1;
mul_mat_int_bk_step_l[i] >>= 1;
} else if (mul_mat_int_bk_step_l[i] == bk_step && !ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_l[i] >>= 1;
}
if (mul_mat_int_bk_step_id_s[i] == bk_step && !ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_id_s[i] >>= 1;
mul_mat_int_bk_step_id_m[i] >>= 1;
mul_mat_int_bk_step_id_l[i] >>= 1;
} else if (mul_mat_int_bk_step_id_m[i] == bk_step && !ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_id_m[i] >>= 1;
mul_mat_int_bk_step_id_l[i] >>= 1;
} else if (mul_mat_int_bk_step_id_l[i] == bk_step && !ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
mul_mat_int_bk_step_id_l[i] >>= 1;
}
}
// std::cerr << "ggml_vulkan: Info: Integer dot-product matmul support for type " << ggml_type_name(t) << ": "
// << "small bk_step=" << (int)mul_mat_int_bk_step_s[i] << ", "
// << "medium bk_step=" << (int)mul_mat_int_bk_step_m[i] << ", "
// << "large bk_step=" << (int)mul_mat_int_bk_step_l[i] << "; "
// << "matmul_id small bk_step=" << (int)mul_mat_int_bk_step_id_s[i] << ", "
// << "medium bk_step=" << (int)mul_mat_int_bk_step_id_m[i] << ", "
// << "large bk_step=" << (int)mul_mat_int_bk_step_id_l[i] << std::endl;
device->mul_mat_int_s[i] = mul_mat_int_bk_step_s[i] > 0;
device->mul_mat_int_m[i] = mul_mat_int_bk_step_m[i] > 0;
device->mul_mat_int_l[i] = mul_mat_int_bk_step_l[i] > 0;
device->mul_mat_int_id_s[i] = mul_mat_int_bk_step_id_s[i] > 0;
device->mul_mat_int_id_m[i] = mul_mat_int_bk_step_id_m[i] > 0;
device->mul_mat_int_id_l[i] = mul_mat_int_bk_step_id_l[i] > 0;
}
}
@ -2865,6 +2990,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
align, disable_robustness, require_full_subgroups, required_subgroup_size);
};
auto const &s_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
std::vector<uint32_t> warptile_copy = warptile;
warptile_copy[3] = mul_mat_int_bk_step_s[type];
return warptile_copy;
};
auto const &m_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
std::vector<uint32_t> warptile_copy = warptile;
warptile_copy[3] = mul_mat_int_bk_step_m[type];
return warptile_copy;
};
auto const &l_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
std::vector<uint32_t> warptile_copy = warptile;
warptile_copy[3] = mul_mat_int_bk_step_l[type];
return warptile_copy;
};
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
};
@ -3150,14 +3291,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
if (device->mul_mat ## ID ## _l[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat_int ## ID ## _l[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_mmq_warptile_bk_step(l_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _m[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat_int ## ID ## _m[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_mmq_warptile_bk_step(m_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
if (device->mul_mat ## ID ## _s[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
if (device->mul_mat_int ## ID ## _s[TYPE]) { \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_mmq_warptile_bk_step(s_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
} \
// Create 2 variants, {f16,f32} accumulator
@ -3321,12 +3462,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
if (device->mul_mat_int ## ID ## _l[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_mmq_warptile_bk_step(l_ ## WARPTILE, TYPE), 1); \
if (device->mul_mat_int ## ID ## _m[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_mmq_warptile_bk_step(m_ ## WARPTILE, TYPE), 1); \
if (device->mul_mat_int ## ID ## _s[TYPE]) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_mmq_warptile_bk_step(s_ ## WARPTILE, TYPE), 1); \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
@ -4693,8 +4834,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->mul_mat_id_s[i] = true;
break;
}
}
device->mul_mat_int_l[i] = device->mul_mat_l[i];
device->mul_mat_int_m[i] = device->mul_mat_m[i];
device->mul_mat_int_s[i] = device->mul_mat_s[i];
device->mul_mat_int_id_l[i] = device->mul_mat_id_l[i];
device->mul_mat_int_id_m[i] = device->mul_mat_id_m[i];
device->mul_mat_int_id_s[i] = device->mul_mat_id_s[i];
}
std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
@ -6130,6 +6277,16 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
return aligned ? mmp->a_s : mmp->s;
}
if (src1_type == GGML_TYPE_Q8_1) {
if ((ctx->device->mul_mat_int_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_int_m[src0_type] && !ctx->device->mul_mat_int_l[src0_type])) {
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_int_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_int_l[src0_type]) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
return aligned ? mmp->a_s : mmp->s;
}

View File

@ -67,7 +67,7 @@ layout (push_constant) uniform parameter
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
// layout (constant_id = 3) const uint BK = 32;
layout (constant_id = 3) const uint BK_STEP = 1; // Amount of quant blocks stored in shared memory
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
@ -82,14 +82,6 @@ layout (constant_id = 10) const uint WARP = 32;
#include "mul_mmq_shmem_types.glsl"
#ifdef MUL_MAT_ID
#define BK_STEP 1
#else
#ifndef BK_STEP
#define BK_STEP 4
#endif
#endif
// Shared memory cache
shared block_a_cache buf_a[BM * BK_STEP];
shared block_b_cache buf_b[BN * BK_STEP];