fix: update supported function
Co-authored-by: safranowith <bsh155762@gmail.com> Co-authored-by: ye-NX <y8703470@gmail.com>
This commit is contained in:
parent
7bade66d0d
commit
15c48eb069
|
|
@ -15,39 +15,55 @@ bool ggml_sycl_flash_attn_ext_supported(const ggml_tensor * dst) {
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
const ggml_tensor * K = dst->src[1];
|
const ggml_tensor * K = dst->src[1];
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
const ggml_tensor * mask = dst->src[3];
|
||||||
|
|
||||||
|
float scale, max_bias, logit_softcap;
|
||||||
|
|
||||||
|
std::memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
||||||
|
std::memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
||||||
|
std::memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
|
if( max_bias != 0.0f || logit_softcap != 0.0f){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (Q == nullptr || K == nullptr || V == nullptr) {
|
if (Q == nullptr || K == nullptr || V == nullptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (Q->type == GGML_TYPE_F32 && K->type == GGML_TYPE_F32 && V->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// if (Q->type == GGML_TYPE_F16 && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
|
||||||
// return true;
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
if (mask != 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (Q->type != GGML_TYPE_F32 || K->type != GGML_TYPE_F32 || V->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t DQK = Q->ne[0];
|
||||||
|
int64_t DV = V->ne[0];
|
||||||
|
|
||||||
|
if (DQK != DV){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DV != 32 && DV != 64 && DV != 80 && DV != 96 && DV != 112 && DV != 128 && DV != 256 && DV != 512){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
//not support multi-head yet
|
||||||
|
if (Q->ne[2] != 1 || K->ne[2] != 1 || V->ne[2] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
template<int64_t DQK, int64_t DV>
|
template<int64_t DQK, int64_t DV>
|
||||||
void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
void ggml_sycl_op_flash_attn_2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
const ggml_tensor * K = dst->src[1];
|
const ggml_tensor * K = dst->src[1];
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
GGML_ASSERT(Q != nullptr);
|
|
||||||
GGML_ASSERT(K != nullptr);
|
|
||||||
GGML_ASSERT(V != nullptr);
|
|
||||||
GGML_ASSERT(dst != nullptr);
|
|
||||||
|
|
||||||
//not support KV_Cache yet
|
|
||||||
GGML_ASSERT(K->ne[1] == V->ne[1]);
|
|
||||||
|
|
||||||
//not support multi head and gqa yet
|
|
||||||
GGML_ASSERT(Q->ne[2] == 1);
|
|
||||||
GGML_ASSERT(K->ne[2] == 1);
|
|
||||||
GGML_ASSERT(V->ne[2] == 1);
|
|
||||||
|
|
||||||
const float * Q_d = (const float *) Q->data;
|
const float * Q_d = (const float *) Q->data;
|
||||||
const float * K_d = (const float *) K->data;
|
const float * K_d = (const float *) K->data;
|
||||||
const float * V_d = (const float *) V->data;
|
const float * V_d = (const float *) V->data;
|
||||||
|
|
@ -180,6 +196,10 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
const ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
switch (Q->ne[0]) {
|
switch (Q->ne[0]) {
|
||||||
|
case 32:
|
||||||
|
GGML_ASSERT(V->ne[0] == 32);
|
||||||
|
ggml_sycl_op_flash_attn_2< 32, 32>(ctx, dst);
|
||||||
|
break;
|
||||||
case 64:
|
case 64:
|
||||||
GGML_ASSERT(V->ne[0] == 64);
|
GGML_ASSERT(V->ne[0] == 64);
|
||||||
ggml_sycl_op_flash_attn_2< 64, 64>(ctx, dst);
|
ggml_sycl_op_flash_attn_2< 64, 64>(ctx, dst);
|
||||||
|
|
@ -206,10 +226,10 @@ void ggml_sycl_op_flash_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
|
||||||
break;
|
break;
|
||||||
case 576:
|
case 576:
|
||||||
GGML_ASSERT(V->ne[0] == 512);
|
GGML_ASSERT(V->ne[0] == 512);
|
||||||
ggml_sycl_op_flash_attn_2<576, 512>(ctx, dst);
|
ggml_sycl_op_flash_attn_2<512, 512>(ctx, dst);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("Unsupported head size");
|
fprintf(stderr, "Warning: Unsupported head size %ld — skipping op\n", Q->ne[0]);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue