From ccb87fa3ee1961ec915f77cb447706f471dca6a5 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 22 Mar 2026 14:19:35 +0530 Subject: [PATCH] [CUDA] Increase number of output elements per-thread block if the K-dimension is small (#20635) * Increase per-thread work if the K-dimension is small With tensor parallelism, the K-dimension of the FFN-down matrices is split, which makes it quite small, especially for MOEs. For example, Qwen3-30b-A3B has a K-dimension of 768, and Qwen3235B-A22B has k-dimension of 1536. The current heuristic uses a group of 4 warps irrespective of K-dimension size, resulting in some of the threads being idle. This results in poor performance for these matrices. This change increases the number of output elements per block for such cases. * Limit this change to ncols_dst = 1 * tab to space --- ggml/src/ggml-cuda/mmvq.cu | 56 +++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 632246e43f..024b3d8cf2 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) } } -static constexpr __device__ int get_vdr_mmvq(ggml_type type) { +static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; @@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { switch (ncols_dst) { case 1: - return 1; + return small_k ? nwarps : 1; case 2: case 3: case 4: @@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q( constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q( template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, - const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) { + const int nwarps = calc_nwarps(type, ncols_dst, table_id); + const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); + const int64_t nblocks = (nrows_x + rpb - 1) / rpb; const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); - const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); + const dim3 block_dims(warp_size, nwarps, 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst( switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + if (use_small_k) { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } else { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } } break; case 2: { constexpr int c_ncols_dst = 2;