parent
8783ed4e3c
commit
f23e4b9f15
|
|
@ -551,9 +551,6 @@ static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_vie
|
|||
{ 4, 0, 3 }, // set_rows->src[0] == view
|
||||
};
|
||||
|
||||
struct vk_matrix_dimension {
|
||||
uint32_t m, n, k;
|
||||
};
|
||||
|
||||
struct vk_device_struct {
|
||||
std::recursive_mutex mutex;
|
||||
|
|
@ -619,10 +616,14 @@ struct vk_device_struct {
|
|||
bool coopmat_support_16x16x16_f16acc {};
|
||||
bool coopmat_support_16x16x16_f32acc {};
|
||||
bool coopmat1_fa_support {};
|
||||
vk_matrix_dimension coopmat;
|
||||
uint32_t coopmat_m;
|
||||
uint32_t coopmat_n;
|
||||
uint32_t coopmat_k;
|
||||
|
||||
bool coopmat_int_support;
|
||||
vk_matrix_dimension coopmat_int;
|
||||
uint32_t coopmat_int_m;
|
||||
uint32_t coopmat_int_n;
|
||||
uint32_t coopmat_int_k;
|
||||
|
||||
bool coopmat2;
|
||||
|
||||
|
|
@ -3045,31 +3046,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
s_align = 32;
|
||||
} else {
|
||||
// Matrix cores require different warp group sizes
|
||||
const vk_matrix_dimension l_t = {
|
||||
device->coopmat_support ? device->coopmat.m : 4,
|
||||
device->coopmat_support ? device->coopmat.n : 4,
|
||||
device->coopmat_support ? device->coopmat.k : 1,
|
||||
};
|
||||
const vk_matrix_dimension m_t = {
|
||||
device->coopmat_support ? device->coopmat.m : 4,
|
||||
device->coopmat_support ? device->coopmat.n : 2,
|
||||
device->coopmat_support ? device->coopmat.k : 1,
|
||||
};
|
||||
const vk_matrix_dimension s_t = {
|
||||
device->coopmat_support ? device->coopmat.m : 2,
|
||||
device->coopmat_support ? device->coopmat.n : 2,
|
||||
device->coopmat_support ? device->coopmat.k : 1,
|
||||
};
|
||||
const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
|
||||
const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
|
||||
const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
|
||||
const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
|
||||
const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
|
||||
const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
|
||||
const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
|
||||
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
|
||||
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
|
||||
|
||||
const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
|
||||
|
||||
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, l_t.m, l_t.n, l_t.k, subgroup_size_8 };
|
||||
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
|
||||
s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, s_t.m, s_t.n, s_t.k, subgroup_size_8 };
|
||||
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
|
||||
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
||||
|
||||
l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, l_t.m, l_t.n, l_t.k, subgroup_size_8 };
|
||||
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
|
||||
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, s_t.m, s_t.n, s_t.k, subgroup_size_8 };
|
||||
l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
|
||||
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
||||
|
||||
// Integer MMQ has a smaller shared memory profile, but heavier register use
|
||||
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
||||
|
|
@ -3081,13 +3076,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
|
||||
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, subgroup_size_8 };
|
||||
|
||||
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, l_t.m, l_t.n, l_t.k, mul_mat_subgroup_size_16 };
|
||||
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, m_t.m, m_t.n, m_t.k, mul_mat_subgroup_size_16 };
|
||||
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, s_t.m, s_t.n, s_t.k, mul_mat_subgroup_size_16 };
|
||||
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
|
||||
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
|
||||
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
|
||||
|
||||
l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, l_t.m, l_t.n, l_t.k, mul_mat_subgroup_size_8 };
|
||||
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, mul_mat_subgroup_size_8 };
|
||||
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, s_t.m, s_t.n, s_t.k, mul_mat_subgroup_size_8 };
|
||||
l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
|
||||
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
|
||||
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
|
||||
|
||||
l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
|
||||
m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
|
||||
|
|
@ -3103,13 +3098,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_AMD && device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary) {
|
||||
// This is intentionally using tx_m values, slight performance increase
|
||||
l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
|
||||
l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
|
||||
l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
|
||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2_ONWARD) {
|
||||
// Xe2/Xe3 with coopmat enabled - warptile performance tuning
|
||||
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
|
||||
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
|
||||
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
||||
}
|
||||
|
||||
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
||||
|
|
@ -4643,9 +4638,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
|
||||
!getenv("GGML_VK_DISABLE_COOPMAT")) {
|
||||
device->coopmat_support = true;
|
||||
device->coopmat.m = 0;
|
||||
device->coopmat.n = 0;
|
||||
device->coopmat.k = 0;
|
||||
device->coopmat_m = 0;
|
||||
device->coopmat_n = 0;
|
||||
device->coopmat_k = 0;
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
||||
|
|
@ -5139,12 +5134,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
|
||||
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
|
||||
// coopmat sizes not set yet
|
||||
if (device->coopmat.m == 0) {
|
||||
if (device->coopmat_m == 0) {
|
||||
device->coopmat_acc_f32_support = true;
|
||||
device->coopmat.m = prop.MSize;
|
||||
device->coopmat.n = prop.NSize;
|
||||
device->coopmat.k = prop.KSize;
|
||||
} else if (device->coopmat.m == prop.MSize && device->coopmat.n == prop.NSize && device->coopmat.k == prop.KSize) {
|
||||
device->coopmat_m = prop.MSize;
|
||||
device->coopmat_n = prop.NSize;
|
||||
device->coopmat_k = prop.KSize;
|
||||
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
||||
// Only enable if shape is identical
|
||||
device->coopmat_acc_f32_support = true;
|
||||
}
|
||||
|
|
@ -5154,12 +5149,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
|
||||
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
|
||||
// coopmat sizes not set yet
|
||||
if (device->coopmat.m == 0) {
|
||||
if (device->coopmat_m == 0) {
|
||||
device->coopmat_acc_f16_support = true;
|
||||
device->coopmat.m = prop.MSize;
|
||||
device->coopmat.n = prop.NSize;
|
||||
device->coopmat.k = prop.KSize;
|
||||
} else if (device->coopmat.m == prop.MSize && device->coopmat.n == prop.NSize && device->coopmat.k == prop.KSize) {
|
||||
device->coopmat_m = prop.MSize;
|
||||
device->coopmat_n = prop.NSize;
|
||||
device->coopmat_k = prop.KSize;
|
||||
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
||||
// Only enable if shape is identical
|
||||
device->coopmat_acc_f16_support = true;
|
||||
}
|
||||
|
|
@ -5172,12 +5167,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
|
||||
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
|
||||
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
|
||||
device->coopmat_int.m == 0
|
||||
device->coopmat_int_m == 0
|
||||
) {
|
||||
device->coopmat_int_support = true;
|
||||
device->coopmat_int.m = prop.MSize;
|
||||
device->coopmat_int.n = prop.NSize;
|
||||
device->coopmat_int.k = prop.KSize;
|
||||
device->coopmat_int_m = prop.MSize;
|
||||
device->coopmat_int_n = prop.NSize;
|
||||
device->coopmat_int_k = prop.KSize;
|
||||
}
|
||||
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
||||
if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
||||
|
|
@ -5187,12 +5182,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
|
||||
) {
|
||||
// coopmat sizes not set yet
|
||||
if (device->coopmat.m == 0) {
|
||||
if (device->coopmat_m == 0) {
|
||||
device->coopmat_bf16_support = true;
|
||||
device->coopmat.m = prop.MSize;
|
||||
device->coopmat.n = prop.NSize;
|
||||
device->coopmat.k = prop.KSize;
|
||||
} else if (device->coopmat.m == prop.MSize && device->coopmat.n == prop.NSize && device->coopmat.k == prop.KSize) {
|
||||
device->coopmat_m = prop.MSize;
|
||||
device->coopmat_n = prop.NSize;
|
||||
device->coopmat_k = prop.KSize;
|
||||
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
||||
// Only enable if shape is identical
|
||||
device->coopmat_bf16_support = true;
|
||||
}
|
||||
|
|
@ -5200,7 +5195,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|||
#endif
|
||||
}
|
||||
|
||||
if (device->coopmat.m == 0 || !device->coopmat_acc_f32_support) {
|
||||
if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
|
||||
// No suitable matmul mode found
|
||||
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
|
||||
device->coopmat_support = false;
|
||||
|
|
|
|||
Loading…
Reference in New Issue