diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index a21c536104..addf93205e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -340,7 +340,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const case 128: case 112: case 256: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 512: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; }