diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index ce25ccf427..d18675650f 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -5,6 +5,11 @@ #include +// RDNA4 tunable parameters for bs=1 +#ifndef MMVQ_RDNA4_NWARPS +#define MMVQ_RDNA4_NWARPS 8 +#endif + typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { @@ -60,11 +65,14 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { enum mmvq_parameter_table_id { MMVQ_PARAMETERS_GENERIC = 0, MMVQ_PARAMETERS_GCN, - MMVQ_PARAMETERS_RDNA2 + MMVQ_PARAMETERS_RDNA2, + MMVQ_PARAMETERS_RDNA4 }; static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { -#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) +#if defined(RDNA4) + return MMVQ_PARAMETERS_RDNA4; +#elif defined(RDNA2) || defined(RDNA3) return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; @@ -74,7 +82,10 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { } static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { - if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return MMVQ_PARAMETERS_RDNA4; + } + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) { return MMVQ_PARAMETERS_RDNA2; } if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { @@ -114,6 +125,14 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet return 1; } } + if (table_id == MMVQ_PARAMETERS_RDNA4) { + switch (ncols_dst) { + case 1: + return MMVQ_RDNA4_NWARPS; + default: + return 1; + } + } return 1; }