mmvq: add RDNA4-specific parameter table (nwarps=8, rows=1)

Add a dedicated MMVQ_PARAMETERS_RDNA4 entry separate from RDNA2/RDNA3.
For bs=1 decode on RDNA4 (gfx1201), optimal config is nwarps=8 rows=1:
- 8 warps × 32 threads = 256 threads per block
- blocks_per_iter = vdr*nwarps*warp_size/qi = 2*8*32/4 = 128
- For K=4096: blocks_per_row=128, entire K dimension in single iteration
- Maximizes memory-level parallelism on RDNA4

Benchmark (Llama 2 7B Q4_0, AMD Radeon AI PRO R9700):
  Master:   95.05 tok/s (tg128)
  nwarps=8: 104.82 tok/s (tg128) → +10.3%
  pp512: no regression (1448 vs 1449 tok/s)
This commit is contained in:
kangletian 2026-02-10 05:20:10 +00:00
parent 98e57ca422
commit 49a5ff40e2
1 changed files with 22 additions and 3 deletions

View File

@ -5,6 +5,11 @@
#include <cstdint>
// RDNA4 tunable parameters for bs=1
#ifndef MMVQ_RDNA4_NWARPS
#define MMVQ_RDNA4_NWARPS 8
#endif
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
@ -60,11 +65,14 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
enum mmvq_parameter_table_id {
MMVQ_PARAMETERS_GENERIC = 0,
MMVQ_PARAMETERS_GCN,
MMVQ_PARAMETERS_RDNA2
MMVQ_PARAMETERS_RDNA2,
MMVQ_PARAMETERS_RDNA4
};
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
#if defined(RDNA4)
return MMVQ_PARAMETERS_RDNA4;
#elif defined(RDNA2) || defined(RDNA3)
return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
return MMVQ_PARAMETERS_GCN;
@ -74,7 +82,10 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
}
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return MMVQ_PARAMETERS_RDNA4;
}
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
return MMVQ_PARAMETERS_RDNA2;
}
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
@ -114,6 +125,14 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
return 1;
}
}
if (table_id == MMVQ_PARAMETERS_RDNA4) {
switch (ncols_dst) {
case 1:
return MMVQ_RDNA4_NWARPS;
default:
return 1;
}
}
return 1;
}