//
// MIT license
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: MIT
//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#include <cstring>  // memcpy

#include <sycl/sycl.hpp>

#include "dpct/helper.hpp"
#include "common.hpp"
#include "fattn-common.hpp"
#include "fattn-tile.hpp"
#include "fattn-vec.hpp"
#include "fattn.hpp"

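// FATTN_VEC_CASE emits a guarded dispatch: if Q has head size D and the K/V
// types match the case, it calls the templated vec kernel and returns from the
// enclosing function. Note that F32 K/V tensors are accepted by the F16 cases,
// so an F32 KV cache is routed through the F16 instantiation.
//
// For example, FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16) expands
// (up to the type checks) to:
//
//     if (Q->ne[0] == 128 && type_K_okay && type_V_okay) {
//         ggml_sycl_flash_attn_ext_vec_case<128, GGML_TYPE_Q4_0, GGML_TYPE_F16>(ctx, dst);
//         return;
//     }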
#define FATTN_VEC_CASE(D, type_K, type_V)                                                                        \
    {                                                                                                            \
        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                     \
            ggml_sycl_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                      \
            return;                                                                                              \
        }                                                                                                        \
    }

#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
    FATTN_VEC_CASE( 64, type_K, type_V)       \
    FATTN_VEC_CASE(128, type_K, type_V)       \
    FATTN_VEC_CASE(256, type_K, type_V)
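
// Dispatch table for the vec kernel. With GGML_SYCL_FA_ALL_QUANTS defined, all
// 36 K/V type combinations (F16/Q4_0/Q4_1/Q5_0/Q5_1/Q8_0 for each of K and V)
// are tried for head sizes 64/128/256. Without it, only F16 and the
// matching-quantization Q4_0/Q8_0 pairs are compiled, which keeps compile time
// and binary size down.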
static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_SYCL_FA_ALL_QUANTS
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)

    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#else
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
#endif // GGML_SYCL_FA_ALL_QUANTS

    GGML_ABORT("no matching FlashAttention vec kernel for this KV type combination");
}

// Best FlashAttention kernel for a specific GPU:
enum best_fattn_kernel {
    BEST_FATTN_KERNEL_NONE = 0,
    BEST_FATTN_KERNEL_VEC  = 100,
    BEST_FATTN_KERNEL_TILE = 200,
};
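
// Kernel selection proceeds in stages: availability (compile-time flag and
// runtime toggle), head-size support, K/V type support, mask shape, and
// finally a batch-size heuristic that prefers the vec kernel for small batches
// and the tile kernel otherwise. The gaps between the enum values presumably
// leave room for more specialized kernels (e.g. an XMX-based one, see the
// TODO below).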
static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
    GGML_UNUSED(device);
#ifndef SYCL_FLASH_ATTN
    GGML_UNUSED(dst);
    return BEST_FATTN_KERNEL_NONE;
#endif // SYCL_FLASH_ATTN

    if (!g_ggml_sycl_enable_flash_attention) {
        return BEST_FATTN_KERNEL_NONE;
    }

    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    const int gqa_ratio = Q->ne[2] / K->ne[2];
    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
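
    // op_params for FLASH_ATTN_EXT is laid out as [scale, max_bias, logit_softcap],
    // hence the offset of 1 above. A non-zero max_bias means ALiBi is in use,
    // which rules out the GQA optimization below.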
    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
    for (const ggml_tensor * t : {Q, K, V, mask}) {
        if (t == nullptr || ggml_is_quantized(t->type)) {
            continue;
        }
        // The GQA optimization additionally requires that the strides nb[1..3]
        // of the non-quantized tensors be multiples of 16 bytes.
        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
            if (t->nb[i] % 16 != 0) {
                gqa_opt_applies = false;
                break;
            }
        }
    }
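
    // Supported head sizes. K and V must agree on the head size except for the
    // 576/512 pair, which matches MLA-style models (e.g. DeepSeek) and
    // additionally requires the GQA optimization.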
    switch (K->ne[0]) {
        case 40:
        case 64:
        case 72:
        case 80:
        case 96:
        case 112:
        case 128:
        case 256:
            if (V->ne[0] != K->ne[0]) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        case 576:
            if (V->ne[0] != 512) {
                return BEST_FATTN_KERNEL_NONE;
            }
            if (!gqa_opt_applies) {
                return BEST_FATTN_KERNEL_NONE;
            }
            break;
        default:
            return BEST_FATTN_KERNEL_NONE;
    }

#ifndef GGML_SYCL_FA_ALL_QUANTS
    if (K->type != V->type) {
        return BEST_FATTN_KERNEL_NONE;
    }
#endif // GGML_SYCL_FA_ALL_QUANTS
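
    // Without GGML_SYCL_FA_ALL_QUANTS mixed K/V types were already rejected
    // above, so only the matching-type fast paths remain. Note the deliberate
    // fall-through below: with all quants compiled in, Q4_1/Q5_0/Q5_1 fall
    // through and are accepted alongside Q4_0/Q8_0.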
    switch (K->type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
            break;
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
#ifndef GGML_SYCL_FA_ALL_QUANTS
            return BEST_FATTN_KERNEL_NONE;
#endif // GGML_SYCL_FA_ALL_QUANTS
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
            break;
        default:
            return BEST_FATTN_KERNEL_NONE;
    }

    if (mask && mask->ne[2] != 1) {
        return BEST_FATTN_KERNEL_NONE;
    }

    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;

    // TODO: use the XMX kernel if possible.

    // If there are no tensor cores available, use the generic tile kernel:
    if (can_use_vector_kernel) {
        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
            // Unquantized K/V: pick the vec kernel only for batch size 1, and
            // only when the GQA optimization does not apply (presumably because
            // the tile path benefits from it):
            if (Q->ne[1] == 1) {
                if (!gqa_opt_applies) {
                    return BEST_FATTN_KERNEL_VEC;
                }
            }
        } else {
            // Quantized K/V: the vec kernel is also used for batch size 2:
            if (Q->ne[1] <= 2) {
                return BEST_FATTN_KERNEL_VEC;
            }
        }
    }
    return BEST_FATTN_KERNEL_TILE;
}
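
// Typical usage (a sketch; the actual wiring lives in the backend's op
// dispatch): gate on ggml_sycl_flash_attn_ext_supported() before calling
// ggml_sycl_flash_attn_ext(), since the latter aborts when no kernel applies:
//
//     if (ggml_sycl_flash_attn_ext_supported(ctx.device, dst)) {
//         ggml_sycl_flash_attn_ext(ctx, dst);
//     }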
void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_sycl_set_device(ctx.device);
    switch (ggml_sycl_get_best_fattn_kernel(ggml_sycl_get_device(), dst)) {
        case BEST_FATTN_KERNEL_NONE:
            GGML_ABORT("FlashAttention is not supported for this combination of inputs");
        case BEST_FATTN_KERNEL_TILE:
            ggml_sycl_flash_attn_ext_tile(ctx, dst);
            break;
        case BEST_FATTN_KERNEL_VEC:
            ggml_sycl_flash_attn_ext_vec(ctx, dst);
            break;
    }
}

bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
    return ggml_sycl_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
}