Patch summary (free text ahead of the first "diff --git" line; patch tools skip it):

get_mmvq_mmid_max_batch(): the raw comparisons of cc against GGML_CUDA_CC_VOLTA,
GGML_CUDA_CC_ADA_LOVELACE and GGML_CUDA_CC_TURING used to run before any vendor
check. This patch moves them inside the GGML_CUDA_CC_IS_NVIDIA(cc) branch and
wraps the RDNA4 / RDNA3 / RDNA1+RDNA2 / CDNA / GCN lookups in a new
GGML_CUDA_CC_IS_AMD(cc) branch. Any cc that matches neither vendor branch, or an
AMD cc that matches none of the listed architectures, still falls through to
MMVQ_MAX_BATCH_SIZE.

NOTE(review): presumably the motivation is that non-NVIDIA cc values could satisfy
the unguarded ">=" comparisons and return early - confirm against the
GGML_CUDA_CC_* definitions.
NOTE(review): line breaks were restored from a single-line copy of this patch and
indentation was rebuilt at four spaces per level; run "git apply --check" against
the target tree before applying.

diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 8d80d1dd9a..07b10167bc 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -235,30 +235,33 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type
 // Host function: returns the max batch size for the current arch+type at runtime.
 int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
     // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
-    if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
-        return MMVQ_MAX_BATCH_SIZE;
-    }
-    if (cc >= GGML_CUDA_CC_TURING) {
-        return get_mmvq_mmid_max_batch_turing_plus(type);
-    }
     if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+        if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+            return MMVQ_MAX_BATCH_SIZE;
+        }
+        if (cc >= GGML_CUDA_CC_TURING) {
+            return get_mmvq_mmid_max_batch_turing_plus(type);
+        }
         return get_mmvq_mmid_max_batch_pascal_older(type);
     }
+
     // AMD
-    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
-        return get_mmvq_mmid_max_batch_rdna4(type);
-    }
-    if (GGML_CUDA_CC_IS_RDNA3(cc)) {
-        return get_mmvq_mmid_max_batch_rdna3(type);
-    }
-    if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
-        return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
-    }
-    if (GGML_CUDA_CC_IS_CDNA(cc)) {
-        return get_mmvq_mmid_max_batch_cdna(type);
-    }
-    if (GGML_CUDA_CC_IS_GCN(cc)) {
-        return get_mmvq_mmid_max_batch_gcn(type);
+    if (GGML_CUDA_CC_IS_AMD(cc)) {
+        if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+            return get_mmvq_mmid_max_batch_rdna4(type);
+        }
+        if (GGML_CUDA_CC_IS_RDNA3(cc)) {
+            return get_mmvq_mmid_max_batch_rdna3(type);
+        }
+        if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
+            return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
+        }
+        if (GGML_CUDA_CC_IS_CDNA(cc)) {
+            return get_mmvq_mmid_max_batch_cdna(type);
+        }
+        if (GGML_CUDA_CC_IS_GCN(cc)) {
+            return get_mmvq_mmid_max_batch_gcn(type);
+        }
     }
     return MMVQ_MAX_BATCH_SIZE;
 }