mmvq: add RDNA4-specific parameter table (nwarps=8, rows=1)
Add a dedicated MMVQ_PARAMETERS_RDNA4 entry separate from RDNA2/RDNA3. For bs=1 decode on RDNA4 (gfx1201), optimal config is nwarps=8 rows=1: - 8 warps × 32 threads = 256 threads per block - blocks_per_iter = vdr*nwarps*warp_size/qi = 2*8*32/4 = 128 - For K=4096: blocks_per_row=128, entire K dimension in single iteration - Maximizes memory-level parallelism on RDNA4 Benchmark (Llama 2 7B Q4_0, AMD Radeon AI PRO R9700): Master: 95.05 tok/s (tg128) nwarps=8: 104.82 tok/s (tg128) → +10.3% pp512: no regression (1448 vs 1449 tok/s)
This commit is contained in:
parent
98e57ca422
commit
49a5ff40e2
|
|
@ -5,6 +5,11 @@
|
|||
|
||||
#include <cstdint>
|
||||
|
||||
// RDNA4 tunable parameters for bs=1
|
||||
#ifndef MMVQ_RDNA4_NWARPS
|
||||
#define MMVQ_RDNA4_NWARPS 8
|
||||
#endif
|
||||
|
||||
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
|
||||
|
||||
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
|
||||
|
|
@ -60,11 +65,14 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
|||
enum mmvq_parameter_table_id {
|
||||
MMVQ_PARAMETERS_GENERIC = 0,
|
||||
MMVQ_PARAMETERS_GCN,
|
||||
MMVQ_PARAMETERS_RDNA2
|
||||
MMVQ_PARAMETERS_RDNA2,
|
||||
MMVQ_PARAMETERS_RDNA4
|
||||
};
|
||||
|
||||
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
|
||||
#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
|
||||
#if defined(RDNA4)
|
||||
return MMVQ_PARAMETERS_RDNA4;
|
||||
#elif defined(RDNA2) || defined(RDNA3)
|
||||
return MMVQ_PARAMETERS_RDNA2;
|
||||
#elif defined(GCN) || defined(CDNA)
|
||||
return MMVQ_PARAMETERS_GCN;
|
||||
|
|
@ -74,7 +82,10 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
|
|||
}
|
||||
|
||||
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
|
||||
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||
return MMVQ_PARAMETERS_RDNA4;
|
||||
}
|
||||
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
|
||||
return MMVQ_PARAMETERS_RDNA2;
|
||||
}
|
||||
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
|
||||
|
|
@ -114,6 +125,14 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet
|
|||
return 1;
|
||||
}
|
||||
}
|
||||
if (table_id == MMVQ_PARAMETERS_RDNA4) {
|
||||
switch (ncols_dst) {
|
||||
case 1:
|
||||
return MMVQ_RDNA4_NWARPS;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue