Revert "refactored matrix dimension"

This reverts commit edccd26d0f.
This commit is contained in:
Nakasaka, Masato 2026-01-18 23:24:28 -08:00
parent 8783ed4e3c
commit f23e4b9f15
1 changed file with 54 additions and 59 deletions

View File

@ -551,9 +551,6 @@ static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_vie
{ 4, 0, 3 }, // set_rows->src[0] == view
};
// Plain aggregate bundling the three extents of a matrix-multiply tile.
// Member order (m, n, k) is significant: the struct is brace-initialised
// positionally and filled from cooperative-matrix MSize / NSize / KSize.
struct vk_matrix_dimension {
    uint32_t m;  // M extent
    uint32_t n;  // N extent
    uint32_t k;  // K extent
};
struct vk_device_struct {
std::recursive_mutex mutex;
@ -619,10 +616,14 @@ struct vk_device_struct {
bool coopmat_support_16x16x16_f16acc {};
bool coopmat_support_16x16x16_f32acc {};
bool coopmat1_fa_support {};
vk_matrix_dimension coopmat;
uint32_t coopmat_m;
uint32_t coopmat_n;
uint32_t coopmat_k;
bool coopmat_int_support;
vk_matrix_dimension coopmat_int;
uint32_t coopmat_int_m;
uint32_t coopmat_int_n;
uint32_t coopmat_int_k;
bool coopmat2;
@ -3045,31 +3046,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_align = 32;
} else {
// Matrix cores require different warp group sizes
const vk_matrix_dimension l_t = {
device->coopmat_support ? device->coopmat.m : 4,
device->coopmat_support ? device->coopmat.n : 4,
device->coopmat_support ? device->coopmat.k : 1,
};
const vk_matrix_dimension m_t = {
device->coopmat_support ? device->coopmat.m : 4,
device->coopmat_support ? device->coopmat.n : 2,
device->coopmat_support ? device->coopmat.k : 1,
};
const vk_matrix_dimension s_t = {
device->coopmat_support ? device->coopmat.m : 2,
device->coopmat_support ? device->coopmat.n : 2,
device->coopmat_support ? device->coopmat.k : 1,
};
const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, l_t.m, l_t.n, l_t.k, subgroup_size_8 };
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, s_t.m, s_t.n, s_t.k, subgroup_size_8 };
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, l_t.m, l_t.n, l_t.k, subgroup_size_8 };
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, s_t.m, s_t.n, s_t.k, subgroup_size_8 };
l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
// Integer MMQ has a smaller shared memory profile, but heavier register use
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
@ -3081,13 +3076,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, subgroup_size_8 };
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, l_t.m, l_t.n, l_t.k, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, m_t.m, m_t.n, m_t.k, mul_mat_subgroup_size_16 };
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, s_t.m, s_t.n, s_t.k, mul_mat_subgroup_size_16 };
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, l_t.m, l_t.n, l_t.k, mul_mat_subgroup_size_8 };
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, s_t.m, s_t.n, s_t.k, mul_mat_subgroup_size_8 };
l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
@ -3103,13 +3098,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
} else if (device->vendor_id == VK_VENDOR_ID_AMD && device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary) {
// This is intentionally using tx_m values, slight performance increase
l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2_ONWARD) {
// Xe2/Xe3 with coopmat enabled - warptile performance tuning
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, m_t.m, m_t.n, m_t.k, subgroup_size_8 };
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
}
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@ -4643,9 +4638,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
} else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
!getenv("GGML_VK_DISABLE_COOPMAT")) {
device->coopmat_support = true;
device->coopmat.m = 0;
device->coopmat.n = 0;
device->coopmat.k = 0;
device->coopmat_m = 0;
device->coopmat_n = 0;
device->coopmat_k = 0;
#endif
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
@ -5139,12 +5134,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
// coopmat sizes not set yet
if (device->coopmat.m == 0) {
if (device->coopmat_m == 0) {
device->coopmat_acc_f32_support = true;
device->coopmat.m = prop.MSize;
device->coopmat.n = prop.NSize;
device->coopmat.k = prop.KSize;
} else if (device->coopmat.m == prop.MSize && device->coopmat.n == prop.NSize && device->coopmat.k == prop.KSize) {
device->coopmat_m = prop.MSize;
device->coopmat_n = prop.NSize;
device->coopmat_k = prop.KSize;
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
// Only enable if shape is identical
device->coopmat_acc_f32_support = true;
}
@ -5154,12 +5149,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
// coopmat sizes not set yet
if (device->coopmat.m == 0) {
if (device->coopmat_m == 0) {
device->coopmat_acc_f16_support = true;
device->coopmat.m = prop.MSize;
device->coopmat.n = prop.NSize;
device->coopmat.k = prop.KSize;
} else if (device->coopmat.m == prop.MSize && device->coopmat.n == prop.NSize && device->coopmat.k == prop.KSize) {
device->coopmat_m = prop.MSize;
device->coopmat_n = prop.NSize;
device->coopmat_k = prop.KSize;
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
// Only enable if shape is identical
device->coopmat_acc_f16_support = true;
}
@ -5172,12 +5167,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
device->coopmat_int.m == 0
device->coopmat_int_m == 0
) {
device->coopmat_int_support = true;
device->coopmat_int.m = prop.MSize;
device->coopmat_int.n = prop.NSize;
device->coopmat_int.k = prop.KSize;
device->coopmat_int_m = prop.MSize;
device->coopmat_int_n = prop.NSize;
device->coopmat_int_k = prop.KSize;
}
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
@ -5187,12 +5182,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
) {
// coopmat sizes not set yet
if (device->coopmat.m == 0) {
if (device->coopmat_m == 0) {
device->coopmat_bf16_support = true;
device->coopmat.m = prop.MSize;
device->coopmat.n = prop.NSize;
device->coopmat.k = prop.KSize;
} else if (device->coopmat.m == prop.MSize && device->coopmat.n == prop.NSize && device->coopmat.k == prop.KSize) {
device->coopmat_m = prop.MSize;
device->coopmat_n = prop.NSize;
device->coopmat_k = prop.KSize;
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
// Only enable if shape is identical
device->coopmat_bf16_support = true;
}
@ -5200,7 +5195,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
#endif
}
if (device->coopmat.m == 0 || !device->coopmat_acc_f32_support) {
if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
// No suitable matmul mode found
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
device->coopmat_support = false;