vulkan: Use VK_EXT_shader_64bit_indexing to handle large mat_mul(_id) (#18678)

This fixes incoherent output in Llama-4-Maverick-17B-128E-PAB-Q8_0, which
has a mul_mat_id with an A matrix that's Q8_0 8192 x 5120 x 128.

This should work when the number of blocks in the A matrix is less than 2^32
(for mul_mat_vec or mul_mm_cm2), or for mul_mm I think the limit is like
2^32*LOAD_VEC_A elements.

- Divide batch_stride by QUANT_K earlier, so the block index calculation works in 32b.
- Each vk_pipeline_struct has a linked list of pipelines that will allow it to handle
variants. So far this change just adds a single use case for this, compiling with the
e64BitIndexingEXT flag.
- Use the 64b indexing variant when the A matrix is larger than maxStorageBufferRange.

64-bit indexing has some cost - around 3-5% in MoE models, so it's worth the effort
to avoid enabling it unconditionally.
This commit is contained in:
Jeff Bolz 2026-01-12 05:32:13 -06:00 committed by GitHub
parent 1051ecd289
commit 2bbe4c2cf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 156 additions and 59 deletions

View File

@ -119,6 +119,8 @@ struct ggml_backend_vk_context;
// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
typedef std::shared_ptr<struct vk_pipeline_struct> vk_pipeline;
struct vk_pipeline_struct {
std::string name;
vk::ShaderModule shader_module;
@ -136,9 +138,15 @@ struct vk_pipeline_struct {
std::atomic<bool> compiled {};
// number of registers used, extracted from pipeline executable properties
uint32_t register_count {};
#if defined(VK_EXT_shader_64bit_indexing)
bool is_64b_indexing {};
#endif
// linked list of pipelines for multiple compilation variants.
// currently only used to compile a 64-bit indexing variant.
vk_pipeline next;
};
typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline;
typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
@ -584,6 +592,8 @@ struct vk_device_struct {
bool add_rms_fusion;
uint32_t partials_binding_alignment;
bool shader_64b_indexing;
bool integer_dot_product;
// 0: default, 1: force mmvq, -1: disable mmvq
int32_t mmvq_mode;
@ -2080,6 +2090,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
compute_pipeline_create_info.setPNext(&rci);
}
#if defined(VK_EXT_shader_64bit_indexing)
vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo;
if (pipeline->is_64b_indexing)
{
pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT;
if (device->pipeline_executable_properties_support) {
pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR;
}
pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext);
compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo);
}
#endif
try {
pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
} catch (const vk::SystemError& e) {
@ -3066,7 +3089,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
}
std::vector<std::future<void>> compiles;
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
@ -3074,35 +3097,49 @@ static void ggml_vk_load_shaders(vk_device& device) {
required_subgroup_size = get_subgroup_size(name, device->architecture);
}
if (!pipeline) {
pipeline = std::make_shared<vk_pipeline_struct>();
}
if (!pipeline->initialized) {
pipeline->name = name;
pipeline->parameter_count = parameter_count;
pipeline->push_constant_size = push_constant_size;
pipeline->wg_denoms = wg_denoms;
pipeline->align = align;
pipeline->initialized = true;
}
vk_pipeline *ptr = &base_pipeline;
if (!pipeline->needed || pipeline->compiled) {
return;
int num_pipelines = 1;
#if defined(VK_EXT_shader_64bit_indexing)
if (device->shader_64b_indexing) {
num_pipelines = 2;
}
// TODO: We're no longer benefitting from the async compiles (shaders are
// compiled individually, as needed) and this complexity can be removed.
{
// wait until fewer than N compiles are in progress
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
std::unique_lock<std::mutex> guard(compile_count_mutex);
while (compile_count >= N) {
compile_count_cond.wait(guard);
#endif
for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) {
vk_pipeline &pipeline = *ptr;
if (!pipeline) {
pipeline = std::make_shared<vk_pipeline_struct>();
}
if (!pipeline->initialized) {
pipeline->name = name;
pipeline->parameter_count = parameter_count;
pipeline->push_constant_size = push_constant_size;
pipeline->wg_denoms = wg_denoms;
pipeline->align = align;
pipeline->initialized = true;
#if defined(VK_EXT_shader_64bit_indexing)
pipeline->is_64b_indexing = (i == 1);
#endif
}
compile_count++;
}
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
if (!pipeline->needed || pipeline->compiled) {
continue;
}
// TODO: We're no longer benefitting from the async compiles (shaders are
// compiled individually, as needed) and this complexity can be removed.
{
// wait until fewer than N compiles are in progress
uint32_t N = std::max(1u, std::thread::hardware_concurrency());
std::unique_lock<std::mutex> guard(compile_count_mutex);
while (compile_count >= N) {
compile_count_cond.wait(guard);
}
compile_count++;
}
compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
}
};
auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
@ -4480,6 +4517,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
bool pipeline_executable_properties_support = false;
device->coopmat_support = false;
device->integer_dot_product = false;
device->shader_64b_indexing = false;
bool bfloat16_support = false;
for (const auto& properties : ext_props) {
@ -4527,6 +4565,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->memory_priority = true;
} else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
device->external_memory_host = true;
#if defined(VK_EXT_shader_64bit_indexing)
} else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) {
device->shader_64b_indexing = true;
#endif
}
}
@ -4817,6 +4859,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_EXT_external_memory_host");
}
#if defined(VK_EXT_shader_64bit_indexing)
VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {};
shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT;
if (device->shader_64b_indexing) {
last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features;
last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features;
device_extensions.push_back("VK_EXT_shader_64bit_indexing");
}
#endif
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
device->pipeline_executable_properties_support = pipeline_executable_properties_support;
@ -6902,6 +6954,20 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_sync_buffers(ctx, subctx);
}
static vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) {
GGML_UNUSED(ctx);
#if defined(VK_EXT_shader_64bit_indexing)
vk_pipeline *ptr = &pipeline;
while (*ptr) {
if ((*ptr)->is_64b_indexing) {
return *ptr;
}
ptr = &(*ptr)->next;
}
#endif
return pipeline;
}
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) {
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@ -6985,6 +7051,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
}
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
const uint64_t x_ne = ggml_nelements(src0);
@ -7294,6 +7364,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
}
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
}
const bool qx_needs_dequant = x_non_contig;
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
@ -7489,9 +7563,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
gqa_ratio = 1;
}
vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1];
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
}
{
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
@ -7533,7 +7613,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
workgroups_z /= gqa_ratio;
}
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1],
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
d_Qx,
d_Qy,
@ -7583,9 +7663,14 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
const uint32_t channel_stride_y = nb12 / sizeof(float);
vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32;
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
}
{
// Request descriptor sets
ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
}
vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
@ -7622,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
d_Qx,
d_Qy,
@ -7641,8 +7726,9 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
// Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases
// where the M dimension is very large.
// Split_k doesn't work with M splitting.
// This only supports batchsize == 1.
const size_t nbytes = ggml_nbytes(src0);
const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange;
const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange;
if (needs_split) {
// Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
@ -7784,6 +7870,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
}
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
const uint64_t x_ne = ggml_nelements(src0);
@ -8045,6 +8134,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const bool qx_needs_dequant = x_non_contig;
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
}
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT

View File

@ -87,7 +87,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;

View File

@ -65,9 +65,9 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
a_offset =
#ifdef MUL_MAT_ID
expert_id * p.batch_stride_a;
expert_id * (p.batch_stride_a / QUANT_K);
#else
batch_idx_a * p.batch_stride_a;
batch_idx_a * (p.batch_stride_a / QUANT_K);
#endif
b_offset =
#ifdef MUL_MAT_ID

View File

@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
// Compute starting index in matrix B for this superblock
const uint y_idx = i * QUANT_K + 32 * ib32;
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
uint ibi = a_offset + first_row * num_blocks_per_row + i;
// Precompute indices for quantization lookup tables
const uint qh_base = 2 * ib32;

View File

@ -17,7 +17,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32,
const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]);
// index for data_a
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
uint ibi = a_offset + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);

View File

@ -12,7 +12,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
uint ibi = a_offset + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;

View File

@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const uint y_idx = i * QUANT_K + 16 * itid;
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
uint ibi = a_offset + first_row * num_blocks_per_row + i;
// Precompute db multiplication factors
float db_vals[NUM_ROWS];
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
@ -22,7 +22,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
db_vals[n] = d * (0.125f + float(scale) * 0.25f);
ibi += num_blocks_per_row;
}
ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
ibi = a_offset + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
// Preload grid and sign data for all l values
vec4 grid0_vals[2], grid1_vals[2];

View File

@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const uint y_idx = i * QUANT_K + 16 * itid;
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
uint ibi = a_offset + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint signscale = pack32(u16vec2(

View File

@ -10,7 +10,7 @@ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 32 * ib32;
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
uint ibi = a_offset + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF;

View File

@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const uint y_idx = i * QUANT_K + 16 * itid;
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
uint ibi = a_offset + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint signscale = pack32(u16vec2(

View File

@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads

View File

@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads

View File

@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];

View File

@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im,
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm);
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];

View File

@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row;
csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads

View File

@ -79,7 +79,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K_Q8_1;
a_offset *= QUANT_K / QUANT_K_Q8_1;
b_offset /= QUANT_K_Q8_1;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

View File

@ -234,13 +234,13 @@ void main() {
const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif
uint pos_a = (
uint pos_a =
#ifdef MUL_MAT_ID
expert_idx * p.batch_stride_a +
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
#else
batch_idx_a * p.batch_stride_a +
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
#endif
ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
#ifdef MUL_MAT_ID
uint pos_b = 0;
#else

View File

@ -250,10 +250,10 @@ void main() {
#endif
#ifdef MUL_MAT_ID
uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K);
uint pos_b = 0;
#else
uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K);
uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

View File

@ -189,13 +189,13 @@ void main() {
const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif
uint pos_a_ib = (
uint pos_a_ib =
#ifdef MUL_MAT_ID
expert_idx * p.batch_stride_a +
expert_idx * (p.batch_stride_a / BK) +
#else
batch_idx_a * p.batch_stride_a +
batch_idx_a * (p.batch_stride_a / BK) +
#endif
ir * BM * p.stride_a + start_k) / BK;
(ir * BM * p.stride_a + start_k) / BK;
#ifdef MUL_MAT_ID
uint pos_b_ib = 0;
#else

View File

@ -7560,6 +7560,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 96, 2592, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 3, 2592, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 1, 2592, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 2, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 1, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 1, 5120, {128, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 512, 5120, {128, 1}, {1, 1}));
#endif
for (ggml_type type_a : all_types) {