diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 6238ce7ebd..b9ef371e0d 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -765,7 +765,7 @@ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0 return ne11 <= 8; } else if (GGML_CUDA_CC_IS_AMD(cc)) { if (fp16_mma_hardware_available(cc)) { - if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + if (GGML_CUDA_CC_IS_RDNA3(cc)) { return ne11 <= 5; } return ne11 <= 2; @@ -788,6 +788,9 @@ bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0 return ne11 <= 8; } else if (GGML_CUDA_CC_IS_AMD(cc)) { if (bf16_mma_hardware_available(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return ne11 <= 2; + } return ne11 <= 3; } return ne11 <= 8;