cuda : add RDNA4-specific MMVQ parameter table for bs=1 decode (#19478)

* mmvq: add RDNA3/RDNA4-specific parameter table (nwarps=8, rows=1)

* mmvq: add dedicated RDNA3 parameter table

* mmvq: exclude RDNA3.5 (gfx1150/1151) from RDNA3 table
This commit is contained in:
PikaPikachu 2026-03-15 15:33:39 +08:00 committed by GitHub
parent 1a3d8edbba
commit 617db241aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 16 deletions

View File

@ -60,11 +60,17 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
enum mmvq_parameter_table_id {
MMVQ_PARAMETERS_GENERIC = 0,
MMVQ_PARAMETERS_GCN,
MMVQ_PARAMETERS_RDNA2
MMVQ_PARAMETERS_RDNA2,
MMVQ_PARAMETERS_RDNA3_0,
MMVQ_PARAMETERS_RDNA4
};
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
#if defined(RDNA4)
return MMVQ_PARAMETERS_RDNA4;
#elif defined(RDNA3_0)
return MMVQ_PARAMETERS_RDNA3_0;
#elif defined(RDNA2) || defined(RDNA3_5)
return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
return MMVQ_PARAMETERS_GCN;
@ -74,7 +80,13 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
}
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return MMVQ_PARAMETERS_RDNA4;
}
if (GGML_CUDA_CC_IS_RDNA3_0(cc)) {
return MMVQ_PARAMETERS_RDNA3_0;
}
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) {
return MMVQ_PARAMETERS_RDNA2;
}
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
@ -83,7 +95,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
return MMVQ_PARAMETERS_GENERIC;
}
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_dst) {
case 1:
@ -114,6 +126,50 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
return 1;
}
}
if (table_id == MMVQ_PARAMETERS_RDNA4) {
// nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1).
// Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register
// pressure and lookup table contention at higher thread counts.
if (ncols_dst == 1) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
return 8;
default:
return 1;
}
}
return 1;
}
if (table_id == MMVQ_PARAMETERS_RDNA3_0) {
// RDNA3 (W7900): stricter whitelist than RDNA4.
// Q2_K / Q5_K / IQ4_XS regress in full quant sweeps.
if (ncols_dst == 1) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ4_NL:
return 8;
default:
return 1;
}
}
return 1;
}
return 1;
}
@ -138,7 +194,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
}
template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@ -151,7 +207,7 @@ static __global__ void mul_mat_vec_q(
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
@ -355,12 +411,13 @@ static __global__ void mul_mat_vec_q(
}
}
template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
return {block_nums, block_dims};
}
@ -420,7 +477,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -431,7 +488,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -439,7 +496,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 2: {
constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -447,7 +504,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 3: {
constexpr int c_ncols_dst = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -455,7 +512,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 4: {
constexpr int c_ncols_dst = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -463,7 +520,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 5: {
constexpr int c_ncols_dst = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -471,7 +528,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 6: {
constexpr int c_ncols_dst = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -479,7 +536,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 7: {
constexpr int c_ncols_dst = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
@ -487,7 +544,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
} break;
case 8: {
constexpr int c_ncols_dst = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,

View File

@ -207,6 +207,14 @@
#define RDNA3
#endif // defined(__GFX11__)
#if defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3_5
#endif // defined(__gfx1150__) || defined(__gfx1151__)
#if defined(RDNA3) && !defined(RDNA3_5)
#define RDNA3_0
#endif // defined(RDNA3) && !defined(RDNA3_5)
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2