From 271191906c3ff0a02916622f703166b6891fce0e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 20 Jan 2026 12:21:28 +0200
Subject: [PATCH] metal : enable FA for MLA heads (#18950)

---
 ggml/src/ggml-metal/ggml-metal-device.m |  8 ++------
 ggml/src/ggml-metal/ggml-metal-ops.cpp  |  2 +-
 ggml/src/ggml-metal/ggml-metal.metal    | 13 ++++++++-----
 3 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index c418afe9c3..eb4e2c209c 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1078,12 +1078,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 op->src[0]->ne[0] != 112 &&
                 op->src[0]->ne[0] != 128 &&
                 op->src[0]->ne[0] != 192 &&
-                op->src[0]->ne[0] != 256) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 576) {
-                // DeepSeek sizes
-                // TODO: disabled for now, until optmized
+                op->src[0]->ne[0] != 256 &&
+                op->src[0]->ne[0] != 576) {
                 return false;
             }
             if (op->src[1]->type != op->src[2]->type) {
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 3d97d3dfdc..7f4cfbba22 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -2520,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
     // simdgroups per threadgroup (a.k.a. warps)
   //int32_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
-    int32_t nsg = 4;
+    int32_t nsg = ne00 >= 512 ? 8 : 4;
 
     const size_t smem = FATTN_SMEM(nsg);
 
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index a4e1cafe55..17e358d1a8 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl(
 
         constexpr short NC = (C/8)/NSG;
 
-        // note: do not unroll for large heads
-        #pragma unroll (DK <= 64 ? NC : 1)
-        for (short cc = 0; cc < NC; ++cc) {
+        FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
            qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8, 8>((qk_t) 0.0f);
 
            if (DK % 16 != 0) {
@@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl(
                k8x8_t mk[2];
                q8x8_t mq[2];
 
-               FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
+               // note: too much unroll can tank the performance for large heads
+               #pragma unroll (MIN(DK8/2, 4*NSG))
+               for (short i = 0; i < DK8/2; ++i) {
                    simdgroup_barrier(mem_flags::mem_none);
 
                    simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
@@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl(
                    pv += 8*NS20;
                }
            } else {
-               FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
+               constexpr short NC = (C/8)/2;
+
+               FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
                    s8x8_t vs[2];
 
                    simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext(
   //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
   //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
     case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
+    case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
 }
 #undef FWD_TMPL
 #undef FWD_ARGS
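
Quick sanity check of how the two head-size regimes interact with the new
unroll bound in the K*Q loop (a standalone sketch, not code from the patch;
DK8 stands for DK/8 as in the kernel, and choose_nsg is a hypothetical helper
mirroring the host-side heuristic "nsg = ne00 >= 512 ? 8 : 4"):

    #include <stdio.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    // Mirrors the new host-side heuristic in ggml_metal_op_flash_attn_ext:
    // the 576-wide MLA/DeepSeek K head gets 8 simdgroups per threadgroup,
    // all other supported head sizes keep the previous 4.
    static int choose_nsg(int head_size_k) {
        return head_size_k >= 512 ? 8 : 4;
    }

    int main(void) {
        const int dks[2] = { 128, 576 };
        for (int i = 0; i < 2; ++i) {
            const int DK  = dks[i];
            const int NSG = choose_nsg(DK);
            const int DK8 = DK/8;
            // Same bound as "#pragma unroll (MIN(DK8/2, 4*NSG))" in the kernel.
            printf("DK=%3d NSG=%d unroll bound = MIN(%2d, %2d) = %2d\n",
                   DK, NSG, DK8/2, 4*NSG, MIN(DK8/2, 4*NSG));
        }
        return 0;
    }

which prints:

    DK=128 NSG=4 unroll bound = MIN( 8, 16) =  8
    DK=576 NSG=8 unroll bound = MIN(36, 32) = 32

That is, regular heads still unroll all DK8/2 iterations fully, while the
576-wide MLA head gets a bounded unroll factor instead of the old
all-or-nothing "#pragma unroll (DK <= 64 ? NC : 1)" that this patch removes.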