From 49a5ff40e2410787fe4586d5449f16ebcad4a240 Mon Sep 17 00:00:00 2001 From: kangletian Date: Tue, 10 Feb 2026 05:20:10 +0000 Subject: [PATCH] mmvq: add RDNA4-specific parameter table (nwarps=8, rows=1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a dedicated MMVQ_PARAMETERS_RDNA4 entry separate from RDNA2/RDNA3. For bs=1 decode on RDNA4 (gfx1201), optimal config is nwarps=8 rows=1: - 8 warps × 32 threads = 256 threads per block - blocks_per_iter = vdr*nwarps*warp_size/qi = 2*8*32/4 = 128 - For K=4096: blocks_per_row=128, entire K dimension in single iteration - Maximizes memory-level parallelism on RDNA4 Benchmark (Llama 2 7B Q4_0, AMD Radeon AI PRO R9700): Master: 95.05 tok/s (tg128) nwarps=8: 104.82 tok/s (tg128) → +10.3% pp512: no regression (1448 vs 1449 tok/s) --- ggml/src/ggml-cuda/mmvq.cu | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index ce25ccf427..d18675650f 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -5,6 +5,11 @@ #include +// RDNA4 tunable parameters for bs=1 +#ifndef MMVQ_RDNA4_NWARPS +#define MMVQ_RDNA4_NWARPS 8 +#endif + typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { @@ -60,11 +65,14 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { enum mmvq_parameter_table_id { MMVQ_PARAMETERS_GENERIC = 0, MMVQ_PARAMETERS_GCN, - MMVQ_PARAMETERS_RDNA2 + MMVQ_PARAMETERS_RDNA2, + MMVQ_PARAMETERS_RDNA4 }; static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { -#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) +#if defined(RDNA4) + return MMVQ_PARAMETERS_RDNA4; +#elif defined(RDNA2) || defined(RDNA3) return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; @@ -74,7 +82,10 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { } static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { - if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return MMVQ_PARAMETERS_RDNA4; + } + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) { return MMVQ_PARAMETERS_RDNA2; } if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { @@ -114,6 +125,14 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet return 1; } } + if (table_id == MMVQ_PARAMETERS_RDNA4) { + switch (ncols_dst) { + case 1: + return MMVQ_RDNA4_NWARPS; + default: + return 1; + } + } return 1; }