vulkan: Use VK_EXT_shader_64bit_indexing to handle large mul_mat(_id) (#18678)

This fixes incoherent output in Llama-4-Maverick-17B-128E-PAB-Q8_0, which
has a mul_mat_id with an A matrix that's Q8_0 8192 x 5120 x 128.

This should work when the number of blocks in the A matrix is less than 2^32
(for mul_mat_vec or mul_mm_cm2); for mul_mm, I think the limit is around
2^32 * LOAD_VEC_A elements.
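
For a rough sense of the numbers, here is a back-of-the-envelope check (not code from
this change; it assumes Q8_0's 32-element / 34-byte block layout and a device whose
maxStorageBufferRange is around 4 GiB):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // A matrix of the Maverick mul_mat_id: Q8_0, 8192 x 5120 x 128 experts
        const uint64_t elements = 8192ull * 5120 * 128; // 5,368,709,120 > 2^32
        const uint64_t blocks   = elements / 32;        //   167,772,160 < 2^32: block indices fit in 32 bits
        const uint64_t bytes    = blocks * 34;          // ~5.3 GiB: byte addressing overflows 32 bits
        printf("elements=%llu blocks=%llu bytes=%llu\n",
               (unsigned long long) elements,
               (unsigned long long) blocks,
               (unsigned long long) bytes);
        return 0;
    }

So block-granular offsets still fit in 32 bits (hence dividing batch_stride by QUANT_K
early), while byte-granular addressing of the whole matrix does not, which is where the
64-bit indexing variant comes in.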

- Divide batch_stride by QUANT_K earlier, so the block index calculation works in 32b.
- Each vk_pipeline_struct now has a linked list of pipelines so it can hold multiple
compilation variants. So far the only variant is one compiled with the
e64BitIndexingEXT pipeline-create flag (see the sketch below).
- Use the 64b indexing variant when the A matrix is larger than maxStorageBufferRange.

64-bit indexing has some cost (around 3-5% in MoE models), so it's worth the effort
to avoid enabling it unconditionally.
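
A minimal, self-contained sketch of the variant lookup described above (stand-in types
only; the real structures and ggml_vk_get_64b_indexing_pipeline are in the diff below):

    #include <cstdint>
    #include <memory>

    // Stand-in for vk_pipeline_struct: each pipeline carries a 'next' pointer,
    // forming a singly linked list of compilation variants of the same shader.
    struct pipeline_struct;
    using pipeline_ptr = std::shared_ptr<pipeline_struct>;
    struct pipeline_struct {
        bool is_64b_indexing = false;
        pipeline_ptr next;
    };

    // Use the 64-bit indexing variant only when the A matrix is too large for
    // 32-bit byte addressing; otherwise keep the cheaper base pipeline.
    pipeline_ptr pick_variant(pipeline_ptr base, uint64_t a_bytes, uint64_t max_storage_buffer_range) {
        if (a_bytes <= max_storage_buffer_range) {
            return base;                 // 32-bit indexing is sufficient
        }
        for (pipeline_ptr p = base; p; p = p->next) {
            if (p->is_64b_indexing) {
                return p;                // variant compiled with e64BitIndexingEXT
            }
        }
        return base;                     // variant not compiled: fall back
    }
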
Author: Jeff Bolz, 2026-01-12 05:32:13 -06:00 (committed by GitHub)
Commit: 2bbe4c2cf8 (parent 1051ecd289)
20 changed files with 156 additions and 59 deletions


@@ -119,6 +119,8 @@ struct ggml_backend_vk_context;
 // Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
 #define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
+typedef std::shared_ptr<struct vk_pipeline_struct> vk_pipeline;
 struct vk_pipeline_struct {
     std::string name;
     vk::ShaderModule shader_module;
@@ -136,9 +138,15 @@ struct vk_pipeline_struct {
     std::atomic<bool> compiled {};
     // number of registers used, extracted from pipeline executable properties
     uint32_t register_count {};
+#if defined(VK_EXT_shader_64bit_indexing)
+    bool is_64b_indexing {};
+#endif
+    // linked list of pipelines for multiple compilation variants.
+    // currently only used to compile a 64-bit indexing variant.
+    vk_pipeline next;
 };
-typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
 typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
 static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
@@ -584,6 +592,8 @@ struct vk_device_struct {
     bool add_rms_fusion;
     uint32_t partials_binding_alignment;
+    bool shader_64b_indexing;
     bool integer_dot_product;
     // 0: default, 1: force mmvq, -1: disable mmvq
     int32_t mmvq_mode;
@@ -2080,6 +2090,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         compute_pipeline_create_info.setPNext(&rci);
     }
+#if defined(VK_EXT_shader_64bit_indexing)
+    vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo;
+    if (pipeline->is_64b_indexing)
+    {
+        pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT;
+        if (device->pipeline_executable_properties_support) {
+            pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR;
+        }
+        pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext);
+        compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo);
+    }
+#endif
     try {
         pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
     } catch (const vk::SystemError& e) {
@@ -3066,7 +3089,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     }
     std::vector<std::future<void>> compiles;
-    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
+    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
         uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
         uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
@@ -3074,35 +3097,49 @@ static void ggml_vk_load_shaders(vk_device& device) {
             required_subgroup_size = get_subgroup_size(name, device->architecture);
         }
-        if (!pipeline) {
-            pipeline = std::make_shared<vk_pipeline_struct>();
-        }
-        if (!pipeline->initialized) {
-            pipeline->name = name;
-            pipeline->parameter_count = parameter_count;
-            pipeline->push_constant_size = push_constant_size;
-            pipeline->wg_denoms = wg_denoms;
-            pipeline->align = align;
-            pipeline->initialized = true;
-        }
-        if (!pipeline->needed || pipeline->compiled) {
-            return;
-        }
-        // TODO: We're no longer benefitting from the async compiles (shaders are
-        // compiled individually, as needed) and this complexity can be removed.
-        {
-            // wait until fewer than N compiles are in progress
-            uint32_t N = std::max(1u, std::thread::hardware_concurrency());
-            std::unique_lock<std::mutex> guard(compile_count_mutex);
-            while (compile_count >= N) {
-                compile_count_cond.wait(guard);
-            }
-            compile_count++;
-        }
-        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
-            parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
+        vk_pipeline *ptr = &base_pipeline;
+        int num_pipelines = 1;
+#if defined(VK_EXT_shader_64bit_indexing)
+        if (device->shader_64b_indexing) {
+            num_pipelines = 2;
+        }
+#endif
+        for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) {
+            vk_pipeline &pipeline = *ptr;
+            if (!pipeline) {
+                pipeline = std::make_shared<vk_pipeline_struct>();
+            }
+            if (!pipeline->initialized) {
+                pipeline->name = name;
+                pipeline->parameter_count = parameter_count;
+                pipeline->push_constant_size = push_constant_size;
+                pipeline->wg_denoms = wg_denoms;
+                pipeline->align = align;
+                pipeline->initialized = true;
+#if defined(VK_EXT_shader_64bit_indexing)
+                pipeline->is_64b_indexing = (i == 1);
+#endif
+            }
+            if (!pipeline->needed || pipeline->compiled) {
+                continue;
+            }
+            // TODO: We're no longer benefitting from the async compiles (shaders are
+            // compiled individually, as needed) and this complexity can be removed.
+            {
+                // wait until fewer than N compiles are in progress
+                uint32_t N = std::max(1u, std::thread::hardware_concurrency());
+                std::unique_lock<std::mutex> guard(compile_count_mutex);
+                while (compile_count >= N) {
+                    compile_count_cond.wait(guard);
+                }
+                compile_count++;
+            }
+            compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
+                parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
+        }
     };
     auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
@@ -4480,6 +4517,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
     bool pipeline_executable_properties_support = false;
     device->coopmat_support = false;
     device->integer_dot_product = false;
+    device->shader_64b_indexing = false;
     bool bfloat16_support = false;
     for (const auto& properties : ext_props) {
@@ -4527,6 +4565,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device->memory_priority = true;
         } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
             device->external_memory_host = true;
+#if defined(VK_EXT_shader_64bit_indexing)
+        } else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) {
+            device->shader_64b_indexing = true;
+#endif
         }
     }
@@ -4817,6 +4859,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_extensions.push_back("VK_EXT_external_memory_host");
     }
+#if defined(VK_EXT_shader_64bit_indexing)
+    VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {};
+    shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT;
+    if (device->shader_64b_indexing) {
+        last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features;
+        last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features;
+        device_extensions.push_back("VK_EXT_shader_64bit_indexing");
+    }
+#endif
     vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
     device->pipeline_executable_properties_support = pipeline_executable_properties_support;
@@ -6902,6 +6954,20 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_sync_buffers(ctx, subctx);
 }
+static vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) {
+    GGML_UNUSED(ctx);
+#if defined(VK_EXT_shader_64bit_indexing)
+    vk_pipeline *ptr = &pipeline;
+    while (*ptr) {
+        if ((*ptr)->is_64b_indexing) {
+            return *ptr;
+        }
+        ptr = &(*ptr)->next;
+    }
+#endif
+    return pipeline;
+}
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -6985,6 +7051,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
     const uint64_t x_ne = ggml_nelements(src0);
@@ -7294,6 +7364,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
     }
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
+    }
     const bool qx_needs_dequant = x_non_contig;
     const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
@@ -7489,9 +7563,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
         gqa_ratio = 1;
     }
+    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1];
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
     {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
     }
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
@@ -7533,7 +7613,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
         workgroups_z /= gqa_ratio;
     }
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1],
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
             d_Qx,
             d_Qy,
@@ -7583,9 +7663,14 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
     const uint32_t channel_stride_y = nb12 / sizeof(float);
+    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32;
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
     {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
     }
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
@@ -7622,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
         {
            d_Qx,
            d_Qy,
@@ -7641,8 +7726,9 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
     // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
     // where the M dimension is very large.
     // Split_k doesn't work with M splitting.
+    // This only supports batchsize == 1.
     const size_t nbytes = ggml_nbytes(src0);
-    const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
+    const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange;
     if (needs_split) {
         // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
         const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
@@ -7784,6 +7870,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
+    }
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
     const uint64_t x_ne = ggml_nelements(src0);
@@ -8045,6 +8134,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     const bool qx_needs_dequant = x_non_contig;
     const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
+    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
+        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
+    }
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT


@@ -87,7 +87,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint tid = gl_LocalInvocationID.x;
     get_offsets(a_offset, b_offset, d_offset);
-    a_offset /= QUANT_K;
     y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;


@@ -65,9 +65,9 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
     a_offset =
 #ifdef MUL_MAT_ID
-        expert_id * p.batch_stride_a;
+        expert_id * (p.batch_stride_a / QUANT_K);
 #else
-        batch_idx_a * p.batch_stride_a;
+        batch_idx_a * (p.batch_stride_a / QUANT_K);
 #endif
     b_offset =
 #ifdef MUL_MAT_ID


@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
                      const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
     // Compute starting index in matrix B for this superblock
     const uint y_idx = i * QUANT_K + 32 * ib32;
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     // Precompute indices for quantization lookup tables
     const uint qh_base = 2 * ib32;


@@ -17,7 +17,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
         const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]);
     // index for data_a
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);


@@ -12,7 +12,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint nibble_shift = 4 * (itid & 1);
     const uint ib32 = itid / 2; // 0..7
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;


@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + 16 * itid;
     const uint nibble_shift = 4 * (itid & 1);
     const uint ib32 = itid / 2; // 0..7
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     // Precompute db multiplication factors
     float db_vals[NUM_ROWS];
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -22,7 +22,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
         db_vals[n] = d * (0.125f + float(scale) * 0.25f);
         ibi += num_blocks_per_row;
     }
-    ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         // Preload grid and sign data for all l values
         vec4 grid0_vals[2], grid1_vals[2];


@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + 16 * itid;
     const uint ib32 = itid / 2; // 0..7
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint signscale = pack32(u16vec2(


@@ -10,7 +10,7 @@ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
     const uint y_idx = i * QUANT_K + 32 * ib32;
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF;


@@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + 16 * itid;
     const uint ib32 = itid / 2; // 0..7
-    uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
+    uint ibi = a_offset + first_row * num_blocks_per_row + i;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const float d = float(data_a[ibi].d);
         const uint signscale = pack32(u16vec2(


@@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + y_offset;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         csel ^= 1;
         if (!all_threads) { // when we don't have enough blocks to use all threads


@@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
     const uint y_idx = i * QUANT_K + y_offset;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         csel ^= 1;
         if (!all_threads) { // when we don't have enough blocks to use all threads


@@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
     const uint y2_idx = y1_idx + 128;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];


@@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
     const uint y2_idx = y1_idx + 128;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];


@@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
     const uint y_idx = i * QUANT_K + y_offset;
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
         csel ^= 1;
         if (!all_threads) { // when we don't have enough blocks to use all threads


@@ -79,7 +79,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     const uint tid = gl_LocalInvocationID.x;
     get_offsets(a_offset, b_offset, d_offset);
-    a_offset /= QUANT_K_Q8_1;
+    a_offset *= QUANT_K / QUANT_K_Q8_1;
     b_offset /= QUANT_K_Q8_1;
     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];


@@ -234,13 +234,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif
-    uint pos_a = (
+    uint pos_a =
 #ifdef MUL_MAT_ID
-        expert_idx * p.batch_stride_a +
+        expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
 #else
-        batch_idx_a * p.batch_stride_a +
+        batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
 #endif
-        ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
+        (ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
 #ifdef MUL_MAT_ID
     uint pos_b = 0;
 #else


@@ -250,10 +250,10 @@ void main() {
 #endif
 #ifdef MUL_MAT_ID
-    uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
+    uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K);
     uint pos_b = 0;
 #else
-    uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
+    uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
     uint pos_b = batch_idx * p.batch_stride_b;
     uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif


@@ -189,13 +189,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif
-    uint pos_a_ib = (
+    uint pos_a_ib =
 #ifdef MUL_MAT_ID
-        expert_idx * p.batch_stride_a +
+        expert_idx * (p.batch_stride_a / BK) +
 #else
-        batch_idx_a * p.batch_stride_a +
+        batch_idx_a * (p.batch_stride_a / BK) +
 #endif
-        ir * BM * p.stride_a + start_k) / BK;
+        (ir * BM * p.stride_a + start_k) / BK;
 #ifdef MUL_MAT_ID
     uint pos_b_ib = 0;
 #else


@@ -7560,6 +7560,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 96, 2592, {1, 1}, {1, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 3, 2592, {1, 1}, {1, 1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 1, 2592, {1, 1}, {1, 1}));
+    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 2, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0
+    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 1, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 1, 5120, {128, 1}, {1, 1}));
+    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 512, 5120, {128, 1}, {1, 1}));
 #endif
     for (ggml_type type_a : all_types) {