diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp
index ea6730145f..3fd1c16f3a 100644
--- a/ggml/src/ggml-sycl/fattn.cpp
+++ b/ggml/src/ggml-sycl/fattn.cpp
@@ -15,18 +15,47 @@ bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
+
+    float scale, max_bias, logit_softcap;
+
+    std::memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
+    std::memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
+    std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
+
+    if (max_bias != 0.0f || logit_softcap != 0.0f) { // ALiBi bias / logit softcap not supported yet
+        return false;
+    }
 
     if (Q == nullptr || K == nullptr || V == nullptr) {
         return false;
     }
 
-    if (Q->type == GGML_TYPE_F32 && K->type == GGML_TYPE_F32 && V->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        return true;
-    }
-    // if (Q->type == GGML_TYPE_F16 && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-    //     return true;
-    // }
-    return false;
+    if (mask != nullptr) { // masked attention not supported yet
+        return false;
+    }
+
+    if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    int64_t DQK = Q->ne[0];
+    int64_t DV  = V->ne[0];
+
+    if (DQK != DV) {
+        return false;
+    }
+
+    if (DV != 32 && DV != 64 && DV != 80 && DV != 96 && DV != 112 && DV != 128 && DV != 256 && DV != 512) {
+        return false;
+    }
+
+    // multi-head is not supported yet
+    if (Q->ne[2] != 1 || K->ne[2] != 1 || V->ne[2] != 1) {
+        return false;
+    }
+
+    return true;
 }
 
 template <int DQK, int DV>
@@ -35,19 +64,6 @@ void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
-    GGML_ASSERT(Q != nullptr);
-    GGML_ASSERT(K != nullptr);
-    GGML_ASSERT(V != nullptr);
-    GGML_ASSERT(dst != nullptr);
-
-    //not support KV_Cache yet
-    GGML_ASSERT(K->ne[1] == V->ne[1]);
-
-    //not support multi head and gqa yet
-    GGML_ASSERT(Q->ne[2] == 1);
-    GGML_ASSERT(K->ne[2] == 1);
-    GGML_ASSERT(V->ne[2] == 1);
-
     const float * Q_d = (const float *) Q->data;
     const float * K_d = (const float *) K->data;
     const float * V_d = (const float *) V->data;
@@ -180,6 +196,10 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
     const ggml_tensor * V = dst->src[2];
 
     switch (Q->ne[0]) {
+        case 32:
+            GGML_ASSERT(V->ne[0] == 32);
+            ggml_sycl_op_flash_attn_2< 32, 32>(ctx, dst);
+            break;
         case 64:
             GGML_ASSERT(V->ne[0] == 64);
             ggml_sycl_op_flash_attn_2< 64, 64>(ctx, dst);
@@ -206,10 +226,10 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
             break;
         case 576:
             GGML_ASSERT(V->ne[0] == 512);
-            ggml_sycl_op_flash_attn_2<576, 512>(ctx, dst);
+            ggml_sycl_op_flash_attn_2<512, 512>(ctx, dst);
             break;
         default:
-            GGML_ABORT("Unsupported head size");
+            fprintf(stderr, "Warning: unsupported head size %lld, skipping op\n", (long long) Q->ne[0]);
            break;
     }
 }
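
For reference, here is a minimal sketch (not part of the patch) of a graph node that the new ggml_sycl_flash_attn_ext_supported() check should accept, built with the public ggml.h API. The helper name build_supported_fattn, the n_q/n_kv parameters, and the choice of head size 64 are illustrative assumptions, not taken from this diff:

#include <math.h>
#include "ggml.h"

// Illustrative helper (hypothetical name): builds one flash-attention node matching the
// constraints of the new supported() check: F32 Q/K/V, DQK == DV from the whitelist,
// a single head (ne[2] == 1), no mask, max_bias == 0 and logit_softcap == 0.
static struct ggml_tensor * build_supported_fattn(struct ggml_context * ctx, int64_t n_q, int64_t n_kv) {
    const int64_t head_dim = 64; // any of 32/64/80/96/112/128/256/512 passes the whitelist

    // ggml flash-attn layout: ne[0] = head size, ne[1] = tokens, ne[2] = heads, ne[3] = batch
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_q,  1, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_kv, 1, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_kv, 1, 1);

    const float scale = 1.0f / sqrtf((float) head_dim);

    // mask == NULL, max_bias == 0.0f, logit_softcap == 0.0f, so none of the early-out
    // rejections in ggml_sycl_flash_attn_ext_supported() trigger
    return ggml_flash_attn_ext(ctx, q, k, v, /*mask =*/ NULL, scale, /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);
}

Anything outside these constraints (a mask, non-zero max_bias or logit_softcap, more than one head, DQK != DV, or a head size outside the whitelist) makes the check return false, and the SYCL backend reports the op as unsupported.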