Optimize MOE GEMV kernel for BS > 1. (#20905)

* Optimize MOE GEMV kernel for BS > 1.

The previous MOE kernel for BS > 1 had too many thread blocks (nrows_x, nchannels_dst, ncols_dst), with very little work per block. block of (32, 4) was doing inner dot product for a single row.

New mul_mat_vec_q_moe kernel is dedicated for MoE multi-token kernel with grid (ceil(nrows_x/rpb), nchannels_dst), block (warp_size, ncols_dst). Each warp handles two rows independently with warp-level reduction only (no shared memory sync).

This change doesn't increase any compilation time as a single template instance is needed per type. This also simplifies the original GEMV kernel and gets rid of `is_multi_token_id` specialization.

* Remove em-dashes

* Cherry-pick changes from @am17an PR https://github.com/ggml-org/llama.cpp/pull/20885 to enable small_k optimization only for cases where it benefits

Increase max batch size for MMVQ kernels for MUL_MAT_ID to 8

* Make the max batch size for MOE GEMV kernel configurable based on GPU arch and datatype

---------

Co-authored-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
Gaurav Garg 2026-03-29 22:05:18 +05:30 committed by GitHub
parent f5d1c4179f
commit ec16a072f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 358 additions and 59 deletions

View File

@ -2343,7 +2343,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) {
if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc);
if (ne2 <= mmvq_mmid_max) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
return;
}
@ -2946,14 +2947,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
}
// [TAG_MUL_MAT_ID_CUDA_GRAPHS]
if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
use_cuda_graph = false;
if (node->op == GGML_OP_MUL_MAT_ID) {
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc);
if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) {
// under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
// TODO: figure out a way to enable for larger batch sizes, without hurting performance
// ref: https://github.com/ggml-org/llama.cpp/pull/18958
use_cuda_graph = false;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
#endif
}
}
if (!use_cuda_graph) {

View File

@ -97,6 +97,194 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
return MMVQ_PARAMETERS_GENERIC;
}
// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID.
// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE.
// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ1_S: return 6;
case GGML_TYPE_IQ1_M: return 6;
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 5;
case GGML_TYPE_IQ2_XXS: return 5;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_IQ4_NL: return 6;
case GGML_TYPE_IQ4_XS: return 5;
case GGML_TYPE_MXFP4: return 4;
case GGML_TYPE_Q2_K: return 4;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 6;
case GGML_TYPE_Q4_1: return 6;
case GGML_TYPE_Q4_K: return 5;
case GGML_TYPE_Q5_0: return 6;
case GGML_TYPE_Q5_1: return 6;
case GGML_TYPE_Q5_K: return 5;
case GGML_TYPE_Q6_K: return 4;
case GGML_TYPE_Q8_0: return 4;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ2_S: return 7;
case GGML_TYPE_IQ3_S: return 6;
case GGML_TYPE_IQ3_XXS: return 7;
case GGML_TYPE_MXFP4: return 7;
case GGML_TYPE_Q2_K: return 7;
case GGML_TYPE_Q3_K: return 5;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ1_S: return 5;
case GGML_TYPE_IQ1_M: return 5;
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 4;
case GGML_TYPE_IQ2_XXS: return 4;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_IQ4_NL: return 6;
case GGML_TYPE_IQ4_XS: return 4;
case GGML_TYPE_Q2_K: return 4;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 5;
case GGML_TYPE_Q4_1: return 5;
case GGML_TYPE_Q4_K: return 4;
case GGML_TYPE_Q5_K: return 4;
case GGML_TYPE_Q6_K: return 4;
case GGML_TYPE_Q8_0: return 4;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ2_S: return 5;
case GGML_TYPE_IQ2_XS: return 5;
case GGML_TYPE_IQ2_XXS: return 5;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 5;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 4;
case GGML_TYPE_IQ2_XXS: return 4;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_Q2_K: return 7;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_K: return 5;
case GGML_TYPE_Q5_K: return 6;
case GGML_TYPE_Q6_K: return 5;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ1_S: return 6;
case GGML_TYPE_IQ1_M: return 6;
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 4;
case GGML_TYPE_IQ2_XXS: return 4;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_IQ4_NL: return 6;
case GGML_TYPE_IQ4_XS: return 6;
case GGML_TYPE_Q4_K: return 4;
case GGML_TYPE_Q5_K: return 4;
case GGML_TYPE_Q6_K: return 4;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) {
switch (type) {
case GGML_TYPE_IQ1_S: return 7;
case GGML_TYPE_IQ1_M: return 7;
case GGML_TYPE_IQ2_S: return 4;
case GGML_TYPE_IQ2_XS: return 4;
case GGML_TYPE_IQ2_XXS: return 4;
case GGML_TYPE_IQ3_S: return 4;
case GGML_TYPE_IQ3_XXS: return 4;
case GGML_TYPE_IQ4_NL: return 7;
case GGML_TYPE_IQ4_XS: return 5;
case GGML_TYPE_MXFP4: return 5;
case GGML_TYPE_Q3_K: return 4;
case GGML_TYPE_Q4_0: return 7;
case GGML_TYPE_Q4_1: return 7;
case GGML_TYPE_Q4_K: return 4;
case GGML_TYPE_Q5_0: return 7;
case GGML_TYPE_Q5_1: return 7;
case GGML_TYPE_Q5_K: return 5;
case GGML_TYPE_Q6_K: return 5;
case GGML_TYPE_Q8_0: return 7;
default: return MMVQ_MAX_BATCH_SIZE;
}
}
// Host function: returns the max batch size for the current arch+type at runtime.
int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
// NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
return MMVQ_MAX_BATCH_SIZE;
}
if (cc >= GGML_CUDA_CC_TURING) {
return get_mmvq_mmid_max_batch_turing_plus(type);
}
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
return get_mmvq_mmid_max_batch_pascal_older(type);
}
// AMD
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return get_mmvq_mmid_max_batch_rdna4(type);
}
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
return get_mmvq_mmid_max_batch_rdna3(type);
}
if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
}
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return get_mmvq_mmid_max_batch_cdna(type);
}
if (GGML_CUDA_CC_IS_GCN(cc)) {
return get_mmvq_mmid_max_batch_gcn(type);
}
return MMVQ_MAX_BATCH_SIZE;
}
// Device constexpr: returns the max batch size for the current arch+type at compile time.
template <ggml_type type>
static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
#if defined(RDNA4)
return get_mmvq_mmid_max_batch_rdna4(type);
#elif defined(RDNA3)
return get_mmvq_mmid_max_batch_rdna3(type);
#elif defined(RDNA2) || defined(RDNA1)
return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
#elif defined(CDNA)
return get_mmvq_mmid_max_batch_cdna(type);
#elif defined(GCN)
return get_mmvq_mmid_max_batch_gcn(type);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE)
return MMVQ_MAX_BATCH_SIZE;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
return get_mmvq_mmid_max_batch_turing_plus(type);
#else
return get_mmvq_mmid_max_batch_pascal_older(type);
#endif
}
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_dst) {
@ -195,7 +383,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1;
}
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
@ -222,22 +410,13 @@ static __global__ void mul_mat_vec_q(
const uint32_t channel_dst = blockIdx.y;
uint32_t token_idx = 0;
uint32_t channel_x;
uint32_t channel_y;
uint32_t sample_dst;
if constexpr (is_multi_token_id) {
// Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
token_idx = blockIdx.z;
channel_x = ids[channel_dst + token_idx * ids_stride];
channel_y = fastmodulo(channel_dst, nchannels_y);
sample_dst = 0;
} else {
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
sample_dst = blockIdx.z;
}
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
sample_dst = blockIdx.z;
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst;
@ -294,9 +473,6 @@ static __global__ void mul_mat_vec_q(
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
if constexpr (is_multi_token_id) {
y += token_idx*stride_col_y;
}
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@ -350,10 +526,6 @@ static __global__ void mul_mat_vec_q(
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
if constexpr (is_multi_token_id) {
dst += token_idx*stride_col_dst;
}
// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
@ -413,6 +585,69 @@ static __global__ void mul_mat_vec_q(
}
}
// Dedicated MoE multi-token kernel.
// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst)
// Block: (warp_size, ncols_dst) - each warp handles one token independently.
// No shared memory reduction needed since each warp works alone.
template <ggml_type type, int c_rows_per_block>
__launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q_moe(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids,
float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
const uint32_t ncols_dst, const uint32_t ids_stride) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
const uint32_t token_idx = threadIdx.y;
const int row0 = c_rows_per_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
constexpr int blocks_per_iter = vdr * warp_size / qi;
const uint32_t channel_dst = blockIdx.y;
if (token_idx >= ncols_dst) {
return;
}
const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride];
const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y);
const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y;
const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x;
// partial sum for each thread
float tmp[c_rows_per_block] = {0.0f};
for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
const int kby = kbx * (qk/QK8_1);
const int kqs = vdr * (threadIdx.x % (qi/vdr));
#pragma unroll
for (int i = 0; i < c_rows_per_block; ++i) {
tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs);
}
}
// Warp-level reduction only - no shared memory needed
#pragma unroll
for (int i = 0; i < c_rows_per_block; ++i) {
tmp[i] = warp_reduce_sum<warp_size>(tmp[i]);
}
// Write results
if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) {
dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x];
}
}
template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
@ -425,7 +660,7 @@ static std::pair<dim3, dim3> calc_launch_params(
return {block_nums, block_dims};
}
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
template<ggml_type type, int c_ncols_dst, bool small_k = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@ -438,7 +673,7 @@ static void mul_mat_vec_q_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, true, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -448,12 +683,33 @@ static void mul_mat_vec_q_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, false, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}
template <ggml_type type>
static void mul_mat_vec_q_moe_launch(
const void * vx, const void * vy, const int32_t * ids, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
const uint32_t ncols_dst, const uint32_t ids_stride,
const int warp_size, const int nchannels_dst, cudaStream_t stream) {
constexpr int rows_per_block = 2; // 2 gives best perf based on tuning
const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block;
const dim3 block_nums(nblocks_rows, nchannels_dst);
const dim3 block_dims(warp_size, ncols_dst);
mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>(
vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x,
stride_row_x, stride_col_y, stride_col_dst,
stride_channel_x, stride_channel_y, stride_channel_dst,
ncols_dst, ids_stride);
}
template <ggml_type type>
static void mul_mat_vec_q_switch_ncols_dst(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
@ -472,20 +728,62 @@ static void mul_mat_vec_q_switch_ncols_dst(
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[device].cc;
const int warp_size = ggml_cuda_info().devices[device].warp_size;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
const mmvq_parameter_table_id table_id = get_device_table_id(cc);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const bool has_ids = ids != nullptr;
const auto should_use_small_k = [&](int c_ncols_dst) {
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
constexpr std::array<ggml_type, 2> iq_slow_turing = {
GGML_TYPE_IQ3_XXS,
GGML_TYPE_IQ3_S,
};
constexpr std::array<ggml_type, 8> iq_slow_other = {
GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
};
constexpr std::array<ggml_type, 3> slow_pascal = {
GGML_TYPE_IQ3_S,
GGML_TYPE_Q2_K,
GGML_TYPE_Q3_K,
};
const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING;
const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA;
if (is_nvidia_turing_plus) {
if (ncols_dst == 1 &&
std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) {
use = false;
}
} else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) ||
(is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) ||
GGML_CUDA_CC_IS_RDNA(cc)) {
use = false;
}
return use;
};
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
// Multi-token MUL_MAT_ID path - dedicated MoE kernel
mul_mat_vec_q_moe_launch<type>(
vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x,
stride_row_x, stride_col_y, stride_col_dst,
stride_channel_x, stride_channel_y, stride_channel_dst,
ncols_dst, ids_stride, warp_size, nchannels_dst, stream);
return;
}
@ -493,31 +791,24 @@ static void mul_mat_vec_q_switch_ncols_dst(
case 1: {
constexpr int c_ncols_dst = 1;
// When K is small, increase rows_per_block to match nwarps so each warp has more work to do
// Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_iter_1warp = vdr * warp_size / qi;
const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
bool use_small_k = should_use_small_k(c_ncols_dst);
if (use_small_k) {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
nsamples_dst, warp_size, table_id, true);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
stream);
} else {
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
stream);
}
} break;
case 2: {

View File

@ -1,7 +1,10 @@
#include "common.cuh"
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID,
// based on the quantization type and GPU architecture (compute capability).
int get_mmvq_mmid_max_batch(ggml_type type, int cc);
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);