templatize multi_token_path

This commit is contained in:
Aman Gupta 2026-01-31 05:56:47 +01:00
parent 3183b72b16
commit 459b75b3cd
2 changed files with 82 additions and 38 deletions

View File

@ -4,7 +4,7 @@
#include "mmvf.cuh"
#include "convert.cuh"
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>
static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
@ -14,19 +14,38 @@ static __global__ void mul_mat_vec_f(
const int row = blockIdx.x;
// for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
const int channel_dst = blockIdx.y;
const int token_idx = ids ? blockIdx.z : 0;
const int channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio);
const int channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
const int sample_dst = ids ? 0 : blockIdx.z;
const int tid = threadIdx.x;
int token_idx;
int channel_x;
int channel_y;
int sample_dst;
if constexpr (is_multi_token_id) {
// Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
token_idx = blockIdx.z;
channel_x = ids[channel_dst + token_idx * ids_stride];
channel_y = fastmodulo(channel_dst, nchannels_y);
sample_dst = 0;
} else {
token_idx = ids ? blockIdx.z : 0;
channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio);
channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
sample_dst = ids ? 0 : blockIdx.z;
}
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
const int sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y + token_idx*stride_col_y2*2;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst + token_idx*stride_col_dst;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
if constexpr (is_multi_token_id) {
y += token_idx*stride_col_y2*2;
dst += token_idx*stride_col_dst;
}
bool use_gate = false;
bool use_bias = false;
@ -354,7 +373,7 @@ static __global__ void mul_mat_vec_f(
}
}
template<typename T, typename type_acc, int ncols_dst, int block_size>
template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>
static void mul_mat_vec_f_switch_fusion(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const uint3 nchannels_y,
@ -366,7 +385,7 @@ static void mul_mat_vec_f_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -376,14 +395,14 @@ static void mul_mat_vec_f_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}
template <typename T, typename type_acc, int ncols_dst>
template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>
void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows,
@ -425,49 +444,49 @@ void launch_mul_mat_vec_f_cuda(
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 64: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 96: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 128: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 160: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 192: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 224: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 256: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
@ -490,8 +509,19 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
const bool has_ids = ids != nullptr;
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
constexpr int c_ncols_dst = 1;
launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
ncols_dst, ids_stride, stream);
return;
}
if (has_ids) {
// note: batching ncols_dst is not possible because tokens use different experts, so we use ncols_dst = 1 and iterate via blockIdx.z
// Single-token MUL_MAT_ID path
constexpr int c_ncols_dst = 1;
launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,

View File

@ -137,8 +137,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1;
}
// tell the compiler to use as many registers as it wants, see nwarps definition below
template <ggml_type type, int ncols_dst, bool has_fusion>
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
@ -163,15 +162,23 @@ static __global__ void mul_mat_vec_q(
const int blocks_per_row_x = ncols_x / qk;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
const uint32_t channel_dst = blockIdx.y;
const uint32_t token_idx = blockIdx.z;
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst + token_idx * ids_stride] : fastdiv(channel_dst, channel_ratio);
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
uint32_t sample_dst = blockIdx.z;
if constexpr (ncols_dst == 1) {
sample_dst *= !ids_stride; // sample_dst for ids is 0
uint32_t token_idx = 0;
uint32_t channel_x;
uint32_t channel_y;
uint32_t sample_dst;
if constexpr (is_multi_token_id) {
// Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
token_idx = blockIdx.z;
channel_x = ids[channel_dst + token_idx * ids_stride];
channel_y = fastmodulo(channel_dst, nchannels_y);
sample_dst = 0;
} else {
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
sample_dst = blockIdx.z;
}
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
@ -228,7 +235,10 @@ static __global__ void mul_mat_vec_q(
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + token_idx*stride_col_y + sample_y*stride_sample_y + channel_y*stride_channel_y;
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
if constexpr (is_multi_token_id) {
y += token_idx*stride_col_y;
}
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@ -280,7 +290,11 @@ static __global__ void mul_mat_vec_q(
return;
}
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0;
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
if constexpr (is_multi_token_id) {
dst += token_idx*stride_col_dst;
}
// sum up partial sums and write back result
#pragma unroll
@ -350,7 +364,7 @@ static std::pair<dim3, dim3> calc_launch_params(
return {block_nums, block_dims};
}
template<ggml_type type, int c_ncols_dst>
template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@ -363,7 +377,7 @@ static void mul_mat_vec_q_switch_fusion(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -373,7 +387,7 @@ static void mul_mat_vec_q_switch_fusion(
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@ -403,11 +417,11 @@ static void mul_mat_vec_q_switch_ncols_dst(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const bool has_ids = ids != nullptr;
if (has_ids) {
// note: batching ncols_dst is not possible because token use different experts, so we use ncols_dst = 1 and iterate via blockIdx.z
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);