From 1c77ccf01d7201ce376a30111ffcb336f8175aed Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 20 Jan 2026 13:03:43 +0100 Subject: [PATCH] CUDA: use mmvq for mul-mat-id for small batch sizes --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- ggml/src/ggml-cuda/mmvq.cu | 108 ++++++++++++++++++-------------- 2 files changed, 63 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 08383edb40..e4d0870653 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2279,7 +2279,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - if (ne2 == 1) { + if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); } else { diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index d671551c17..8d5c4dcd31 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -145,7 +145,8 @@ static __global__ void mul_mat_vec_q( const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, - const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const uint32_t ids_stride) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; @@ -162,11 +163,12 @@ static __global__ void mul_mat_vec_q( const int blocks_per_row_x = ncols_x / qk; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; - // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. + // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const uint32_t channel_dst = blockIdx.y; - const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - const uint32_t sample_dst = blockIdx.z; + const uint32_t token_idx = ids ? blockIdx.z : 0; + const uint32_t channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv(channel_dst, channel_ratio); + const uint32_t channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst; + const uint32_t sample_dst = ids ? 0 : blockIdx.z; const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; @@ -221,7 +223,7 @@ static __global__ void mul_mat_vec_q( float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; - const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; + const block_q8_1 * y = ((const block_q8_1 *) vy) + token_idx*stride_col_y + sample_y*stride_sample_y + channel_y*stride_channel_y; const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -273,7 +275,7 @@ static __global__ void mul_mat_vec_q( return; } - dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0; // sum up partial sums and write back result #pragma unroll @@ -335,10 +337,10 @@ static __global__ void mul_mat_vec_q( } static std::pair calc_launch_params( - const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, + const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, const int warp_size, const mmvq_parameter_table_id table_id) { const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); - const dim3 block_nums(nblocks, nchannels_y, nsamples_y); + const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); return {block_nums, block_dims}; } @@ -350,7 +352,8 @@ static void mul_mat_vec_q_switch_fusion( const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, - const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) { + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, + const uint32_t ids_stride, cudaStream_t stream) { const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { @@ -358,7 +361,7 @@ static void mul_mat_vec_q_switch_fusion( mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; } } @@ -368,7 +371,7 @@ static void mul_mat_vec_q_switch_fusion( mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } template @@ -379,7 +382,7 @@ static void mul_mat_vec_q_switch_ncols_dst( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); @@ -393,8 +396,19 @@ static void mul_mat_vec_q_switch_ncols_dst( const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + const bool has_ids = ids != nullptr; + + if (has_ids) { + // note: batching ncols_dst is not possible because token use different experts, so we use ncols_dst = 1 and iterate via blockIdx.z + constexpr int c_ncols_dst = 1; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + return; + } - GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; @@ -402,7 +416,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 2: { constexpr int c_ncols_dst = 2; @@ -410,7 +424,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 3: { constexpr int c_ncols_dst = 3; @@ -418,7 +432,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 4: { constexpr int c_ncols_dst = 4; @@ -426,7 +440,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 5: { constexpr int c_ncols_dst = 5; @@ -434,7 +448,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 6: { constexpr int c_ncols_dst = 6; @@ -442,7 +456,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 7: { constexpr int c_ncols_dst = 7; @@ -450,7 +464,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 8: { constexpr int c_ncols_dst = 8; @@ -458,7 +472,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; default: GGML_ABORT("fatal error"); @@ -474,127 +488,127 @@ static void mul_mat_vec_q_switch_type( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { switch (type_x) { case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_MXFP4: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q3_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q6_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_M: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_NL: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_XS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; default: GGML_ABORT("fatal error"); @@ -622,7 +636,7 @@ void ggml_cuda_mul_mat_vec_q( GGML_ASSERT( nb0 == ts_dst); GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE); const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; @@ -693,11 +707,13 @@ void ggml_cuda_mul_mat_vec_q( const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + mul_mat_vec_q_switch_type( src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, stream); + ne03, ne3, s03, s13, s3, ids_stride, stream); } void ggml_cuda_op_mul_mat_vec_q( @@ -726,7 +742,7 @@ void ggml_cuda_op_mul_mat_vec_q( ggml_cuda_mm_fusion_args_device fusion_local{}; mul_mat_vec_q_switch_type( src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream); GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); }