From 2bbe4c2cf8298114e3908e285125b9d0d1c5bb42 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Mon, 12 Jan 2026 05:32:13 -0600
Subject: [PATCH] vulkan: Use VK_EXT_shader_64bit_indexing to handle large mat_mul(_id) (#18678)

This fixes incoherent output in Llama-4-Maverick-17B-128E-PAB-Q8_0, which has a mul_mat_id with an A matrix that is Q8_0 8192 x 5120 x 128. This works as long as the number of blocks in the A matrix is less than 2^32 (for mul_mat_vec and mul_mm_cm2); for mul_mm the limit is roughly 2^32*LOAD_VEC_A elements. (A worked example of these sizes follows the patch.)

- Divide batch_stride by QUANT_K earlier, so the block index calculation fits in 32 bits.
- Each vk_pipeline_struct now has a linked list of pipelines that allows it to handle compilation variants. So far this change adds a single use case: compiling with the e64BitIndexingEXT flag.
- Use the 64-bit indexing variant when the A matrix is larger than maxStorageBufferRange.

64-bit indexing has some cost (around 3-5% in MoE models), so it's worth the effort to avoid enabling it unconditionally.
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 157 ++++++++++++++----
 .../vulkan-shaders/mul_mat_vec.comp | 1 -
 .../vulkan-shaders/mul_mat_vec_base.glsl | 4 +-
 .../vulkan-shaders/mul_mat_vec_iq1_m.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_iq1_s.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_iq2_s.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_iq2_xs.comp | 4 +-
 .../vulkan-shaders/mul_mat_vec_iq2_xxs.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_iq3_s.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_iq3_xxs.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_q2_k.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_q3_k.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_q4_k.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_q5_k.comp | 2 +-
 .../vulkan-shaders/mul_mat_vec_q6_k.comp | 2 +-
 .../vulkan-shaders/mul_mat_vecq.comp | 2 +-
 .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 8 +-
 .../vulkan-shaders/mul_mm_cm2.comp | 4 +-
 .../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 8 +-
 tests/test-backend-ops.cpp | 5 +
 20 files changed, 156 insertions(+), 59 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ba5252b814..4b337cb931 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -119,6 +119,8 @@ struct ggml_backend_vk_context; // Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT. #define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3) +typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline; + struct vk_pipeline_struct { std::string name; vk::ShaderModule shader_module; @@ -136,9 +138,15 @@ struct vk_pipeline_struct { std::atomic<bool> compiled {}; // number of registers used, extracted from pipeline executable properties uint32_t register_count {}; + +#if defined(VK_EXT_shader_64bit_indexing) + bool is_64b_indexing {}; +#endif + // linked list of pipelines for multiple compilation variants. + // currently only used to compile a 64-bit indexing variant.
+ vk_pipeline next; }; -typedef std::shared_ptr<vk_pipeline_struct> vk_pipeline; typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref; static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); @@ -584,6 +592,8 @@ struct vk_device_struct { bool add_rms_fusion; uint32_t partials_binding_alignment; + bool shader_64b_indexing; + bool integer_dot_product; // 0: default, 1: force mmvq, -1: disable mmvq int32_t mmvq_mode; @@ -2080,6 +2090,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin compute_pipeline_create_info.setPNext(&rci); } +#if defined(VK_EXT_shader_64bit_indexing) + vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo; + if (pipeline->is_64b_indexing) + { + pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT; + if (device->pipeline_executable_properties_support) { + pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR; + } + pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext); + compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo); + } +#endif + try { pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; } catch (const vk::SystemError& e) { @@ -3066,7 +3089,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } std::vector<std::future<void>> compiles; - auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { @@ -3074,35 +3097,49 @@ static void ggml_vk_load_shaders(vk_device& device) { required_subgroup_size = get_subgroup_size(name, device->architecture); } - if (!pipeline) { - pipeline = std::make_shared<vk_pipeline_struct>(); - } - if (!pipeline->initialized) { - pipeline->name = name; - pipeline->parameter_count = parameter_count; - pipeline->push_constant_size = push_constant_size; - pipeline->wg_denoms = wg_denoms; - pipeline->align = align; - pipeline->initialized = true; - } + vk_pipeline *ptr = &base_pipeline; - if (!pipeline->needed || pipeline->compiled) { - return; + int num_pipelines = 1; +#if defined(VK_EXT_shader_64bit_indexing) + if (device->shader_64b_indexing) { + num_pipelines = 2; } - // TODO: We're no longer benefitting from the async compiles (shaders are - compiled individually, as needed) and this complexity can be removed.
- { - // wait until fewer than N compiles are in progress - uint32_t N = std::max(1u, std::thread::hardware_concurrency()); - std::unique_lock<std::mutex> guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); +#endif + for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) { + vk_pipeline &pipeline = *ptr; + if (!pipeline) { + pipeline = std::make_shared<vk_pipeline_struct>(); + } + if (!pipeline->initialized) { + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + pipeline->initialized = true; +#if defined(VK_EXT_shader_64bit_indexing) + pipeline->is_64b_indexing = (i == 1); +#endif } - compile_count++; - } - compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, - parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + if (!pipeline->needed || pipeline->compiled) { + continue; + } + // TODO: We're no longer benefitting from the async compiles (shaders are + // compiled individually, as needed) and this complexity can be removed. + { + // wait until fewer than N compiles are in progress + uint32_t N = std::max(1u, std::thread::hardware_concurrency()); + std::unique_lock<std::mutex> guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, + parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + } }; auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint, @@ -4480,6 +4517,7 @@ static vk_device ggml_vk_get_device(size_t idx) { bool pipeline_executable_properties_support = false; device->coopmat_support = false; device->integer_dot_product = false; + device->shader_64b_indexing = false; bool bfloat16_support = false; for (const auto& properties : ext_props) { @@ -4527,6 +4565,10 @@ static vk_device ggml_vk_get_device(size_t idx) { device->memory_priority = true; } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) { device->external_memory_host = true; +#if defined(VK_EXT_shader_64bit_indexing) + } else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) { + device->shader_64b_indexing = true; +#endif } } @@ -4817,6 +4859,16 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_EXT_external_memory_host"); } +#if defined(VK_EXT_shader_64bit_indexing) + VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {}; + shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT; + if (device->shader_64b_indexing) { + last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features; + last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features; + device_extensions.push_back("VK_EXT_shader_64bit_indexing"); + } +#endif + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->pipeline_executable_properties_support = pipeline_executable_properties_support; @@ -6902,6 +6954,20 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& 
sub ggml_vk_sync_buffers(ctx, subctx); } +static vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) { + GGML_UNUSED(ctx); +#if defined(VK_EXT_shader_64bit_indexing) + vk_pipeline *ptr = &pipeline; + while (*ptr) { + if ((*ptr)->is_64b_indexing) { + return *ptr; + } + ptr = &(*ptr)->next; + } +#endif + return pipeline; +} + static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) { VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -6985,6 +7051,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); + } + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? 
ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; const uint64_t x_ne = ggml_nelements(src0); @@ -7294,6 +7364,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); } + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv); + } + const bool qx_needs_dequant = x_non_contig; const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); @@ -7489,9 +7563,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c gqa_ratio = 1; } + vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1]; + + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); + } + { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true); @@ -7533,7 +7613,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c workgroups_z /= gqa_ratio; } - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { d_Qx, d_Qy, @@ -7583,9 +7663,14 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); const uint32_t channel_stride_y = nb12 / sizeof(float); + vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32; + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); + } + { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true); @@ -7622,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { d_Qx, d_Qy, @@ -7641,8 +7726,9 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases // where the M dimension is very large. // Split_k doesn't work with M splitting. + // This only supports batchsize == 1. 
const size_t nbytes = ggml_nbytes(src0); - const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange; + const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange; if (needs_split) { // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets) const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]); @@ -7784,6 +7870,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); + } // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; const uint64_t x_ne = ggml_nelements(src0); @@ -8045,6 +8134,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const bool qx_needs_dequant = x_non_contig; const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv); + } + // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index b3c96576de..2271be4021 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -87,7 +87,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); - a_offset /= QUANT_K; y_offset = QUANT_R == 1 ? 
1 : QUANT_K/2; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index cfc8b0c7f4..dfb7865936 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -65,9 +65,9 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { a_offset = #ifdef MUL_MAT_ID - expert_id * p.batch_stride_a; + expert_id * (p.batch_stride_a / QUANT_K); #else - batch_idx_a * p.batch_stride_a; + batch_idx_a * (p.batch_stride_a / QUANT_K); #endif b_offset = #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp index e5cc7ff862..3ea24a76ce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { // Compute starting index in matrix B for this superblock const uint y_idx = i * QUANT_K + 32 * ib32; - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; // Precompute indices for quantization lookup tables const uint qh_base = 2 * ib32; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp index c5f5e9cbb2..fd953c8fad 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -17,7 +17,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]); // index for data_a - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp index e424af12c5..b4f6d1d6b6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -12,7 +12,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint nibble_shift = 4 * (itid & 1); const uint ib32 = itid / 2; // 0..7 - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp index 7ec2e04f58..d8dafe5f70 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + 16 * itid; const uint nibble_shift = 4 * (itid & 1); const uint ib32 = itid / 2; // 0..7 - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; // Precompute db 
multiplication factors float db_vals[NUM_ROWS]; [[unroll]] for (uint n = 0; n < num_rows; ++n) { @@ -22,7 +22,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, db_vals[n] = d * (0.125f + float(scale) * 0.25f); ibi += num_blocks_per_row; } - ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { // Preload grid and sign data for all l values vec4 grid0_vals[2], grid1_vals[2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp index 71bd72d17e..f75dcf8331 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + 16 * itid; const uint ib32 = itid / 2; // 0..7 - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); const uint signscale = pack32(u16vec2( diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp index a4b9ab1f94..5cdf2a89d0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -10,7 +10,7 @@ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { const uint y_idx = i * QUANT_K + 32 * ib32; - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp index 40849c691f..a88898109a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + 16 * itid; const uint ib32 = itid / 2; // 0..7 - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); const uint signscale = pack32(u16vec2( diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 14093c0de5..619de054cb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; csel ^= 1; if (!all_threads) { // when we don't have enough 
blocks to use all threads diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 528f224d86..93e48b7901 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 49d91ad591..6af5a81587 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint y2_idx = y1_idx + 128; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 0d61b4966e..3695b47b98 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint y2_idx = y1_idx + 128; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index d7a7f6426e..3e89d91cbb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index ff5f43979d..6fe3e2dc04 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -79,7 +79,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); - a_offset /= QUANT_K_Q8_1; + a_offset *= QUANT_K / QUANT_K_Q8_1; b_offset /= QUANT_K_Q8_1; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; diff --git 
a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index c0c00d28fc..775e9a70f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -234,13 +234,13 @@ void main() { const uint end_k = min(p.K, (ik + 1) * p.k_split); #endif - uint pos_a = ( + uint pos_a = #ifdef MUL_MAT_ID - expert_idx * p.batch_stride_a + + expert_idx * (p.batch_stride_a / LOAD_VEC_A) + #else - batch_idx_a * p.batch_stride_a + + batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) + #endif - ir * BM * p.stride_a + start_k) / LOAD_VEC_A; + (ir * BM * p.stride_a + start_k) / LOAD_VEC_A; #ifdef MUL_MAT_ID uint pos_b = 0; #else diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index d0d1d8ef72..b6614d2fc5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -250,10 +250,10 @@ void main() { #endif #ifdef MUL_MAT_ID - uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; + uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K); uint pos_b = 0; #else - uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; + uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K); uint pos_b = batch_idx * p.batch_stride_b; uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index cd36e270ab..335d7f6a68 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -189,13 +189,13 @@ void main() { const uint end_k = min(p.K, (ik + 1) * p.k_split); #endif - uint pos_a_ib = ( + uint pos_a_ib = #ifdef MUL_MAT_ID - expert_idx * p.batch_stride_a + + expert_idx * (p.batch_stride_a / BK) + #else - batch_idx_a * p.batch_stride_a + + batch_idx_a * (p.batch_stride_a / BK) + #endif - ir * BM * p.stride_a + start_k) / BK; + (ir * BM * p.stride_a + start_k) / BK; #ifdef MUL_MAT_ID uint pos_b_ib = 0; #else diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 56d277e167..19ef58404e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7560,6 +7560,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 96, 2592, {1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 3, 2592, {1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 1700000, 1, 2592, {1, 1}, {1, 1})); + + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 2, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0 + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q8_0, GGML_TYPE_F32, 128, 128, false, 8192, 1, 5120)); // Llama-4-Maverick-17B-128E-PAB-Q8_0 + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 1, 5120, {128, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 8192, 512, 5120, {128, 1}, {1, 1})); #endif for (ggml_type type_a : all_types) {
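Worked example (not part of the patch): a minimal back-of-the-envelope sketch of the sizes for the Llama-4-Maverick mul_mat_id case described in the commit message. It assumes a Q8_0 block of 34 bytes (32 int8 quants plus an fp16 scale, QUANT_K = 32) and a maxStorageBufferRange of about 4 GiB; the actual limit is device dependent.

#include <cstdint>
#include <cstdio>

int main() {
    // A matrix dimensions from the commit message: Q8_0, 8192 x 5120 x 128
    const uint64_t ne0 = 8192, ne1 = 5120, ne2 = 128;
    const uint64_t quant_k     = 32;            // Q8_0: quants per block
    const uint64_t block_bytes = 34;            // Q8_0 block: 32 int8 quants + fp16 scale
    const uint64_t max_range   = 4294967295ull; // assumed maxStorageBufferRange (~4 GiB)

    const uint64_t elems  = ne0 * ne1 * ne2;      // 5,368,709,120 elements
    const uint64_t blocks = elems / quant_k;      // 167,772,160 blocks, well below 2^32
    const uint64_t bytes  = blocks * block_bytes; // 5,704,253,440 bytes (~5.3 GiB)

    // Block indices fit in 32 bits (which is why dividing batch_stride by QUANT_K
    // earlier keeps the index math in 32-bit), but the total byte size exceeds
    // maxStorageBufferRange, so the 64-bit indexing pipeline variant is selected.
    std::printf("elements=%llu blocks=%llu bytes=%llu use_64b_indexing=%d\n",
                (unsigned long long) elems, (unsigned long long) blocks,
                (unsigned long long) bytes, bytes > max_range ? 1 : 0);
    return 0;
}

This mirrors the selection condition in the patch, where ggml_vk_get_64b_indexing_pipeline is used whenever ggml_nbytes(src0) exceeds maxStorageBufferRange.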