vulkan: allow MMQ bk_step tuning
This commit is contained in:
parent
416e7c7f47
commit
7e8eb9ba0a
|
|
@ -38,6 +38,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
|
|||
#include <mutex>
|
||||
#include <future>
|
||||
#include <thread>
|
||||
#include <bitset>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
# define NOMINMAX 1
|
||||
|
|
@ -561,12 +562,19 @@ struct vk_device_struct {
|
|||
|
||||
size_t idx;
|
||||
|
||||
bool mul_mat_l[GGML_TYPE_COUNT];
|
||||
bool mul_mat_m[GGML_TYPE_COUNT];
|
||||
bool mul_mat_s[GGML_TYPE_COUNT];
|
||||
bool mul_mat_id_l[GGML_TYPE_COUNT];
|
||||
bool mul_mat_id_m[GGML_TYPE_COUNT];
|
||||
bool mul_mat_id_s[GGML_TYPE_COUNT];
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_l;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_m;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_s;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_id_l;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_id_m;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_id_s;
|
||||
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_int_l;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_int_m;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_int_s;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_l;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_m;
|
||||
std::bitset<GGML_TYPE_COUNT> mul_mat_int_id_s;
|
||||
|
||||
vk::DescriptorSetLayout dsl;
|
||||
|
||||
|
|
@ -2526,7 +2534,38 @@ static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type
|
|||
return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1];
|
||||
}
|
||||
|
||||
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
|
||||
static uint32_t mmq_shmem_struct_size(const vk_device& device, ggml_type type) {
|
||||
const uint32_t float_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
constexpr uint32_t int32_size = sizeof(uint32_t);
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
return 4 * int32_size + float_size;
|
||||
case GGML_TYPE_Q4_1:
|
||||
return 4 * int32_size + 2 * float_size;
|
||||
case GGML_TYPE_Q5_0:
|
||||
return 5 * int32_size + float_size;
|
||||
case GGML_TYPE_Q5_1:
|
||||
return 5 * int32_size + 2 * float_size;
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_MXFP4:
|
||||
return 8 * int32_size + float_size;
|
||||
case GGML_TYPE_Q8_1:
|
||||
return 8 * int32_size + 2 * float_size;
|
||||
case GGML_TYPE_Q2_K:
|
||||
return 2 * int32_size + 2 + 2 * float_size;
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q4_K:
|
||||
return 4 * int32_size + 2 * float_size;
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
return 8 * int32_size + 2 * float_size;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type, ggml_type src1_type, uint32_t bk_step) {
|
||||
|
||||
uint32_t lut_size = 0;
|
||||
switch (src0_type) {
|
||||
|
|
@ -2559,11 +2598,18 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|||
}
|
||||
|
||||
// Needs to be kept up to date on shader changes
|
||||
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
|
||||
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
const uint32_t warps = warptile[0] / warptile[10];
|
||||
uint32_t load_bufs;
|
||||
|
||||
if (src1_type != GGML_TYPE_Q8_1) {
|
||||
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
|
||||
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
|
||||
|
||||
load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
||||
} else {
|
||||
load_bufs = (warptile[1] + warptile[2]) * bk_step * mmq_shmem_struct_size(device, src0_type);
|
||||
}
|
||||
|
||||
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
||||
const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
|
||||
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
||||
const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
|
||||
|
|
@ -2644,6 +2690,28 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
|
|||
return 0; // If no matching configuration is found
|
||||
}
|
||||
|
||||
static uint32_t get_default_bk_step(const vk_device& device, ggml_type src0_type, bool mul_mat_id) {
|
||||
const uint32_t bk_struct_size = mmq_shmem_struct_size(device, src0_type);
|
||||
|
||||
const uint32_t q5_0_struct_size = mmq_shmem_struct_size(device, GGML_TYPE_Q5_0);
|
||||
|
||||
// Smaller struct means we can fit more in shared memory
|
||||
if (bk_struct_size < q5_0_struct_size) {
|
||||
return 4;
|
||||
}
|
||||
|
||||
// GCN likes large caches
|
||||
if (device->architecture == vk_device_architecture::AMD_GCN) {
|
||||
return 4;
|
||||
}
|
||||
|
||||
if (device->vendor_id == VK_VENDOR_ID_INTEL) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return mul_mat_id ? 1 : 2;
|
||||
}
|
||||
|
||||
static void ggml_vk_load_shaders(vk_device& device) {
|
||||
VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
|
||||
|
||||
|
|
@ -2676,6 +2744,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
|
||||
l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
|
||||
|
||||
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_s;
|
||||
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_m;
|
||||
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_l;
|
||||
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_s;
|
||||
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_m;
|
||||
std::array<uint8_t, GGML_TYPE_COUNT> mul_mat_int_bk_step_id_l;
|
||||
|
||||
uint32_t l_align, m_align, s_align;
|
||||
if (device->coopmat2) {
|
||||
// spec constants and tile sizes for non-quant matmul/matmul_id
|
||||
|
|
@ -2734,14 +2809,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
||||
|
||||
// Integer MMQ has a smaller shared memory profile, but heavier register use
|
||||
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
||||
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
|
||||
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
|
||||
l_warptile_mmq_int = { 128, 128, 128, 0, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
||||
m_warptile_mmq_int = { 128, 64, 64, 0, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
|
||||
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 0, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
|
||||
|
||||
// K-quants use even more registers, mitigate by setting WMITER to 1
|
||||
l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
|
||||
m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
|
||||
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
|
||||
l_warptile_mmq_int_k = { 128, 128, 128, 0, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
|
||||
m_warptile_mmq_int_k = { 128, 64, 64, 0, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
|
||||
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 0, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
|
||||
|
||||
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
|
||||
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
|
||||
|
|
@ -2751,13 +2826,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
|
||||
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
|
||||
|
||||
l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
|
||||
m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
|
||||
s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
|
||||
l_warptile_mmqid_int = { 128, 128, 128, 0, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
|
||||
m_warptile_mmqid_int = { 128, 64, 64, 0, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
|
||||
s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 0, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
|
||||
|
||||
l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
|
||||
m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
|
||||
s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
|
||||
l_warptile_mmqid_int_k = { 128, 128, 128, 0, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
|
||||
m_warptile_mmqid_int_k = { 128, 64, 64, 0, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
|
||||
s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 0, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
|
||||
|
||||
// chip specific tuning
|
||||
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
|
||||
|
|
@ -2773,31 +2848,81 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
s_align = 32;
|
||||
|
||||
for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
|
||||
ggml_type t = (ggml_type)i;
|
||||
const ggml_type t = (ggml_type)i;
|
||||
// Disable medium and large matrix multiplication if not enough shared memory is available
|
||||
// Check mmq warptiles as the largest configuration
|
||||
// Throw an error if not enough for any matrix multiplication is available
|
||||
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
|
||||
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
|
||||
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
|
||||
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
|
||||
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
|
||||
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
|
||||
device->mul_mat_m[i] = false;
|
||||
device->mul_mat_l[i] = false;
|
||||
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
|
||||
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t, GGML_TYPE_F32, 1)) {
|
||||
device->mul_mat_l[i] = false;
|
||||
}
|
||||
|
||||
// Disable mul_mat_id if not enough shared memory is available
|
||||
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
|
||||
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
|
||||
device->mul_mat_id_s[i] = false;
|
||||
device->mul_mat_id_m[i] = false;
|
||||
device->mul_mat_id_l[i] = false;
|
||||
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
|
||||
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
|
||||
device->mul_mat_id_m[i] = false;
|
||||
device->mul_mat_id_l[i] = false;
|
||||
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
|
||||
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t, GGML_TYPE_F32, 1)) {
|
||||
device->mul_mat_id_l[i] = false;
|
||||
}
|
||||
|
||||
// Integer dot matmul has different shared memory requirements
|
||||
// bk_step is how many blocks in k-direction are stored in shared memory at once
|
||||
// Each value is initialized to 4, then reduced until it fits, or 0 if it is not supported
|
||||
|
||||
mul_mat_int_bk_step_s[i] = device->mul_mat_int_s[i] ? get_default_bk_step(device, t, false) : 0;
|
||||
mul_mat_int_bk_step_m[i] = device->mul_mat_int_m[i] ? get_default_bk_step(device, t, false) : 0;
|
||||
mul_mat_int_bk_step_l[i] = device->mul_mat_int_l[i] ? get_default_bk_step(device, t, false) : 0;
|
||||
mul_mat_int_bk_step_id_s[i] = device->mul_mat_int_id_s[i] ? get_default_bk_step(device, t, true) : 0;
|
||||
mul_mat_int_bk_step_id_m[i] = device->mul_mat_int_id_m[i] ? get_default_bk_step(device, t, true) : 0;
|
||||
mul_mat_int_bk_step_id_l[i] = device->mul_mat_int_id_l[i] ? get_default_bk_step(device, t, true) : 0;
|
||||
|
||||
for (uint32_t bk_step : { 4, 2, 1 }) {
|
||||
if (mul_mat_int_bk_step_s[i] == bk_step && !ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
|
||||
mul_mat_int_bk_step_s[i] >>= 1;
|
||||
mul_mat_int_bk_step_m[i] >>= 1;
|
||||
mul_mat_int_bk_step_l[i] >>= 1;
|
||||
} else if (mul_mat_int_bk_step_m[i] == bk_step && !ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
|
||||
mul_mat_int_bk_step_m[i] >>= 1;
|
||||
mul_mat_int_bk_step_l[i] >>= 1;
|
||||
} else if (mul_mat_int_bk_step_l[i] == bk_step && !ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t, GGML_TYPE_Q8_1, bk_step)) {
|
||||
mul_mat_int_bk_step_l[i] >>= 1;
|
||||
}
|
||||
|
||||
if (mul_mat_int_bk_step_id_s[i] == bk_step && !ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
|
||||
mul_mat_int_bk_step_id_s[i] >>= 1;
|
||||
mul_mat_int_bk_step_id_m[i] >>= 1;
|
||||
mul_mat_int_bk_step_id_l[i] >>= 1;
|
||||
} else if (mul_mat_int_bk_step_id_m[i] == bk_step && !ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
|
||||
mul_mat_int_bk_step_id_m[i] >>= 1;
|
||||
mul_mat_int_bk_step_id_l[i] >>= 1;
|
||||
} else if (mul_mat_int_bk_step_id_l[i] == bk_step && !ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t, GGML_TYPE_Q8_1, bk_step)) {
|
||||
mul_mat_int_bk_step_id_l[i] >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// std::cerr << "ggml_vulkan: Info: Integer dot-product matmul support for type " << ggml_type_name(t) << ": "
|
||||
// << "small bk_step=" << (int)mul_mat_int_bk_step_s[i] << ", "
|
||||
// << "medium bk_step=" << (int)mul_mat_int_bk_step_m[i] << ", "
|
||||
// << "large bk_step=" << (int)mul_mat_int_bk_step_l[i] << "; "
|
||||
// << "matmul_id small bk_step=" << (int)mul_mat_int_bk_step_id_s[i] << ", "
|
||||
// << "medium bk_step=" << (int)mul_mat_int_bk_step_id_m[i] << ", "
|
||||
// << "large bk_step=" << (int)mul_mat_int_bk_step_id_l[i] << std::endl;
|
||||
|
||||
device->mul_mat_int_s[i] = mul_mat_int_bk_step_s[i] > 0;
|
||||
device->mul_mat_int_m[i] = mul_mat_int_bk_step_m[i] > 0;
|
||||
device->mul_mat_int_l[i] = mul_mat_int_bk_step_l[i] > 0;
|
||||
device->mul_mat_int_id_s[i] = mul_mat_int_bk_step_id_s[i] > 0;
|
||||
device->mul_mat_int_id_m[i] = mul_mat_int_bk_step_id_m[i] > 0;
|
||||
device->mul_mat_int_id_l[i] = mul_mat_int_bk_step_id_l[i] > 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2865,6 +2990,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
align, disable_robustness, require_full_subgroups, required_subgroup_size);
|
||||
};
|
||||
|
||||
auto const &s_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
|
||||
std::vector<uint32_t> warptile_copy = warptile;
|
||||
warptile_copy[3] = mul_mat_int_bk_step_s[type];
|
||||
return warptile_copy;
|
||||
};
|
||||
auto const &m_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
|
||||
std::vector<uint32_t> warptile_copy = warptile;
|
||||
warptile_copy[3] = mul_mat_int_bk_step_m[type];
|
||||
return warptile_copy;
|
||||
};
|
||||
auto const &l_mmq_warptile_bk_step = [&](const std::vector<uint32_t>& warptile, ggml_type type) -> std::vector<uint32_t> {
|
||||
std::vector<uint32_t> warptile_copy = warptile;
|
||||
warptile_copy[3] = mul_mat_int_bk_step_l[type];
|
||||
return warptile_copy;
|
||||
};
|
||||
|
||||
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
||||
return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1};
|
||||
};
|
||||
|
|
@ -3150,14 +3291,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat_int ## ID ## _l[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_mmq_warptile_bk_step(l_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
} \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat_int ## ID ## _m[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_mmq_warptile_bk_step(m_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
} \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat_int ## ID ## _s[TYPE]) { \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_mmq_warptile_bk_step(s_ ## WARPTILE, TYPE), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
} \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
|
|
@ -3321,12 +3462,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
||||
if (device->mul_mat_int ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_mmq_warptile_bk_step(l_ ## WARPTILE, TYPE), 1); \
|
||||
if (device->mul_mat_int ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_mmq_warptile_bk_step(m_ ## WARPTILE, TYPE), 1); \
|
||||
if (device->mul_mat_int ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_mmq_warptile_bk_step(s_ ## WARPTILE, TYPE), 1); \
|
||||
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
||||
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
|
||||
|
|
@ -4693,8 +4834,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
device->mul_mat_id_s[i] = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
device->mul_mat_int_l[i] = device->mul_mat_l[i];
|
||||
device->mul_mat_int_m[i] = device->mul_mat_m[i];
|
||||
device->mul_mat_int_s[i] = device->mul_mat_s[i];
|
||||
device->mul_mat_int_id_l[i] = device->mul_mat_id_l[i];
|
||||
device->mul_mat_int_id_m[i] = device->mul_mat_id_m[i];
|
||||
device->mul_mat_int_id_s[i] = device->mul_mat_id_s[i];
|
||||
}
|
||||
|
||||
std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
|
||||
std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
|
||||
|
|
@ -6130,6 +6277,16 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
|
|||
return aligned ? mmp->a_s : mmp->s;
|
||||
}
|
||||
|
||||
if (src1_type == GGML_TYPE_Q8_1) {
|
||||
if ((ctx->device->mul_mat_int_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_int_m[src0_type] && !ctx->device->mul_mat_int_l[src0_type])) {
|
||||
return aligned ? mmp->a_s : mmp->s;
|
||||
}
|
||||
if ((ctx->device->mul_mat_int_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_int_l[src0_type]) {
|
||||
return aligned ? mmp->a_m : mmp->m;
|
||||
}
|
||||
return aligned ? mmp->a_l : mmp->l;
|
||||
}
|
||||
|
||||
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
|
||||
return aligned ? mmp->a_s : mmp->s;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ layout (push_constant) uniform parameter
|
|||
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
|
||||
layout (constant_id = 1) const uint BM = 64;
|
||||
layout (constant_id = 2) const uint BN = 64;
|
||||
// layout (constant_id = 3) const uint BK = 32;
|
||||
layout (constant_id = 3) const uint BK_STEP = 1; // Amount of quant blocks stored in shared memory
|
||||
layout (constant_id = 4) const uint WM = 32;
|
||||
layout (constant_id = 5) const uint WN = 32;
|
||||
layout (constant_id = 6) const uint WMITER = 2;
|
||||
|
|
@ -82,14 +82,6 @@ layout (constant_id = 10) const uint WARP = 32;
|
|||
|
||||
#include "mul_mmq_shmem_types.glsl"
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#define BK_STEP 1
|
||||
#else
|
||||
#ifndef BK_STEP
|
||||
#define BK_STEP 4
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Shared memory cache
|
||||
shared block_a_cache buf_a[BM * BK_STEP];
|
||||
shared block_b_cache buf_b[BN * BK_STEP];
|
||||
|
|
|
|||
Loading…
Reference in New Issue