56 lines
1.9 KiB
C++
56 lines
1.9 KiB
C++
#include <sycl/sycl.hpp>
|
|
#include <sycl/ext/oneapi/work_group_static.hpp>
|
|
#include "dpct/helper.hpp"
|
|
#include "common.hpp"
|
|
#include "fattn-common.hpp"
|
|
#include "fattn-tile.hpp"
|
|
#include <cmath>
|
|
#include <float.h>
|
|
namespace syclex = sycl::ext::oneapi::experimental;
|
|
|
|
// Entry point for the tile flash-attention kernel: picks the compile-time
// specialization of ggml_sycl_flash_attn_ext_tile_case that matches the
// runtime sizes of the K and V tensors.
//
//   ctx - SYCL backend context, forwarded unchanged to the selected case.
//   dst - output tensor; K is read from dst->src[1] and V from dst->src[2].
//
// The selection key is K->ne[0] (referred to as the "head size" by the abort
// message below). For every supported size except 576, V->ne[0] must equal
// K->ne[0]; for 576, V->ne[0] must be 512. A mismatch trips GGML_ASSERT, and
// any K->ne[0] outside the table ends in GGML_ABORT.
void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    // The size check is kept inside each case (not hoisted above the switch)
    // so that an unsupported K->ne[0] reaches the abort without first being
    // compared against V.
    switch (K->ne[0]) {
        case 40:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_sycl_flash_attn_ext_tile_case<40, 40>(ctx, dst);
            return;

        case 64:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_sycl_flash_attn_ext_tile_case<64, 64>(ctx, dst);
            return;

        case 72:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_sycl_flash_attn_ext_tile_case<72, 72>(ctx, dst);
            return;

        case 80:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_sycl_flash_attn_ext_tile_case<80, 80>(ctx, dst);
            return;

        case 96:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_sycl_flash_attn_ext_tile_case<96, 96>(ctx, dst);
            return;

        case 112:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_sycl_flash_attn_ext_tile_case<112, 112>(ctx, dst);
            return;

        case 128:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_sycl_flash_attn_ext_tile_case<128, 128>(ctx, dst);
            return;

        case 256:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
            return;

        case 576:
            // The only asymmetric entry: V is required to be 512 here.
            GGML_ASSERT(V->ne[0] == 512);
            ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
            return;

        default:
            GGML_ABORT("Unsupported head size");
            return;
    }
}
|