diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index cc80eb3ffc..d1239b1c5f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2343,7 +2343,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) { + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc); + if (ne2 <= mmvq_mmid_max) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); return; } @@ -2946,14 +2947,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { } // [TAG_MUL_MAT_ID_CUDA_GRAPHS] - if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) { - // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs - // TODO: figure out a way to enable for larger batch sizes, without hurting performance - // ref: https://github.com/ggml-org/llama.cpp/pull/18958 - use_cuda_graph = false; + if (node->op == GGML_OP_MUL_MAT_ID) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc); + if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) { + // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs + // TODO: figure out a way to enable for larger batch sizes, without hurting performance + // ref: https://github.com/ggml-org/llama.cpp/pull/18958 + use_cuda_graph = false; #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif + } } if (!use_cuda_graph) { diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 66bd8beeae..8d80d1dd9a 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -97,6 +97,194 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { return MMVQ_PARAMETERS_GENERIC; } +// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID. +// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE. +// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 6; + case GGML_TYPE_Q4_1: return 6; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_0: return 6; + case GGML_TYPE_Q5_1: return 6; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 7; + case GGML_TYPE_IQ3_S: return 6; + case GGML_TYPE_IQ3_XXS: return 7; + case GGML_TYPE_MXFP4: return 7; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 5; + case GGML_TYPE_IQ1_M: return 5; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 5; + case GGML_TYPE_Q4_1: return 5; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 5; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_K: return 6; + case GGML_TYPE_Q6_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 6; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 7; + case GGML_TYPE_IQ1_M: return 7; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 7; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 5; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 7; + case GGML_TYPE_Q4_1: return 7; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_0: return 7; + case GGML_TYPE_Q5_1: return 7; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 5; + case GGML_TYPE_Q8_0: return 7; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +// Host function: returns the max batch size for the current arch+type at runtime. +int get_mmvq_mmid_max_batch(ggml_type type, int cc) { + // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID. + if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return MMVQ_MAX_BATCH_SIZE; + } + if (cc >= GGML_CUDA_CC_TURING) { + return get_mmvq_mmid_max_batch_turing_plus(type); + } + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + return get_mmvq_mmid_max_batch_pascal_older(type); + } + // AMD + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return get_mmvq_mmid_max_batch_rdna4(type); + } + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + return get_mmvq_mmid_max_batch_rdna3(type); + } + if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) { + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); + } + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return get_mmvq_mmid_max_batch_cdna(type); + } + if (GGML_CUDA_CC_IS_GCN(cc)) { + return get_mmvq_mmid_max_batch_gcn(type); + } + return MMVQ_MAX_BATCH_SIZE; +} + +// Device constexpr: returns the max batch size for the current arch+type at compile time. +template +static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { +#if defined(RDNA4) + return get_mmvq_mmid_max_batch_rdna4(type); +#elif defined(RDNA3) + return get_mmvq_mmid_max_batch_rdna3(type); +#elif defined(RDNA2) || defined(RDNA1) + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); +#elif defined(CDNA) + return get_mmvq_mmid_max_batch_cdna(type); +#elif defined(GCN) + return get_mmvq_mmid_max_batch_gcn(type); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE) + return MMVQ_MAX_BATCH_SIZE; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + return get_mmvq_mmid_max_batch_turing_plus(type); +#else + return get_mmvq_mmid_max_batch_pascal_older(type); +#endif +} + static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { switch (ncols_dst) { @@ -195,7 +383,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -222,22 +410,13 @@ static __global__ void mul_mat_vec_q( const uint32_t channel_dst = blockIdx.y; - uint32_t token_idx = 0; uint32_t channel_x; uint32_t channel_y; uint32_t sample_dst; - if constexpr (is_multi_token_id) { - // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case - token_idx = blockIdx.z; - channel_x = ids[channel_dst + token_idx * ids_stride]; - channel_y = fastmodulo(channel_dst, nchannels_y); - sample_dst = 0; - } else { - channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - sample_dst = blockIdx.z; - } + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; @@ -294,9 +473,6 @@ static __global__ void mul_mat_vec_q( float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; - if constexpr (is_multi_token_id) { - y += token_idx*stride_col_y; - } const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -350,10 +526,6 @@ static __global__ void mul_mat_vec_q( dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; - if constexpr (is_multi_token_id) { - dst += token_idx*stride_col_dst; - } - // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_dst; ++j) { @@ -413,6 +585,69 @@ static __global__ void mul_mat_vec_q( } } +// Dedicated MoE multi-token kernel. +// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst) +// Block: (warp_size, ncols_dst) - each warp handles one token independently. +// No shared memory reduction needed since each warp works alone. +template +__launch_bounds__(get_mmvq_mmid_max_batch_for_device()*ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mul_mat_vec_q_moe( + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, + float * __restrict__ dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride) { + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + + const uint32_t token_idx = threadIdx.y; + const int row0 = c_rows_per_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + constexpr int blocks_per_iter = vdr * warp_size / qi; + + const uint32_t channel_dst = blockIdx.y; + + if (token_idx >= ncols_dst) { + return; + } + + const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride]; + const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y); + + const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y; + const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x; + + // partial sum for each thread + float tmp[c_rows_per_block] = {0.0f}; + + for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); + const int kqs = vdr * (threadIdx.x % (qi/vdr)); + +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } + + // Warp-level reduction only - no shared memory needed +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] = warp_reduce_sum(tmp[i]); + } + + // Write results + if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) { + dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x]; + } +} + template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, @@ -425,7 +660,7 @@ static std::pair calc_launch_params( return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -438,7 +673,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -448,12 +683,33 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } +template +static void mul_mat_vec_q_moe_launch( + const void * vx, const void * vy, const int32_t * ids, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride, + const int warp_size, const int nchannels_dst, cudaStream_t stream) { + + constexpr int rows_per_block = 2; // 2 gives best perf based on tuning + const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block; + const dim3 block_nums(nblocks_rows, nchannels_dst); + const dim3 block_dims(warp_size, ncols_dst); + + mul_mat_vec_q_moe<<>>( + vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride); +} + template static void mul_mat_vec_q_switch_ncols_dst( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, @@ -472,20 +728,62 @@ static void mul_mat_vec_q_switch_ncols_dst( const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; const int warp_size = ggml_cuda_info().devices[device].warp_size; - const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + const mmvq_parameter_table_id table_id = get_device_table_id(cc); const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_ids = ids != nullptr; + const auto should_use_small_k = [&](int c_ncols_dst) { + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + + constexpr std::array iq_slow_turing = { + GGML_TYPE_IQ3_XXS, + GGML_TYPE_IQ3_S, + }; + constexpr std::array iq_slow_other = { + GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, + GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, + }; + constexpr std::array slow_pascal = { + GGML_TYPE_IQ3_S, + GGML_TYPE_Q2_K, + GGML_TYPE_Q3_K, + }; + + const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING; + const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA; + + if (is_nvidia_turing_plus) { + if (ncols_dst == 1 && + std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) { + use = false; + } + } else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) || + (is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) || + GGML_CUDA_CC_IS_RDNA(cc)) { + use = false; + } + + return use; + }; + if (has_ids && ncols_dst > 1) { - // Multi-token MUL_MAT_ID path only - single-token goes through regular path below - constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + // Multi-token MUL_MAT_ID path - dedicated MoE kernel + mul_mat_vec_q_moe_launch( + vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride, warp_size, nchannels_dst, stream); return; } @@ -493,31 +791,24 @@ static void mul_mat_vec_q_switch_ncols_dst( case 1: { constexpr int c_ncols_dst = 1; - // When K is small, increase rows_per_block to match nwarps so each warp has more work to do - // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int vdr = get_vdr_mmvq(type); - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_iter_1warp = vdr * warp_size / qi; - const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); - const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + bool use_small_k = should_use_small_k(c_ncols_dst); + if (use_small_k) { - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, - warp_size, table_id, true); - mul_mat_vec_q_switch_fusion( + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); } else { - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, - warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion( vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); } } break; case 2: { diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 8a154631f6..6bf0a8e867 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -1,7 +1,10 @@ #include "common.cuh" #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. -#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID + +// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID, +// based on the quantization type and GPU architecture (compute capability). +int get_mmvq_mmid_max_batch(ggml_type type, int cc); void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);