metal : pad K, V and Mask when needed
This commit is contained in:
parent d8359f5fde
commit 5d0d2d2289
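The flash-attention kernels walk the KV cache in blocks of ncpsg columns and previously asserted ne11 % 32 == 0. This change lifts that requirement: when the KV length is not block-aligned, a new kernel_flash_attn_ext_pad pass copies the trailing partial block of K, V and the mask into a scratch region appended to the op's allocation, zero-filling the padded K/V columns and setting the padded mask columns to -MAXHALF so they drop out of the soft-max. The attention kernels then redirect their final loop iteration to this padded copy. A sketch of the scratch layout implied by the new kernel, with C = ncpsg:

    // k_pad    : nb11*C*ne_12_2*ne_12_3 bytes
    // v_pad    : nb21*C*ne_12_2*ne_12_3 bytes
    // mask_pad : sizeof(half)*C*ne31*ne32*ne33 bytes (present only when the op has a mask)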
@@ -930,6 +930,50 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        bool has_mask,
+        int32_t ncpsg) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+    GGML_UNUSED(op);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_%s",
+            "flash_attn_ext_pad");
+
+    snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
+            base,
+            has_mask,
+            ncpsg);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    ggml_metal_cv_t cv = ggml_metal_cv_init();
+
+    ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
+  //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
+  //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_PAD + 2);
+  //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_PAD + 3);
+
+  //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
+  //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
+  //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_PAD + 22);
+  //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_PAD + 23);
+    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 24);
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+
+    ggml_metal_cv_free(cv);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         ggml_metal_library_t lib,
         const ggml_tensor * op,
@@ -937,6 +981,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
 
@@ -955,12 +1000,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
+    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d",
             base,
             has_mask,
             has_sinks,
             has_bias,
             has_scap,
+            has_kvpad,
             ns10,
             ns20,
             nsg);
@@ -976,6 +1022,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
     ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
     ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT + 2);
     ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);
+    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
 
     ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
     ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
@@ -995,6 +1042,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg,
         int32_t nwg) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
@@ -1014,12 +1062,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
             dk,
             dv);
 
-    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
+    snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
             base,
             has_mask,
             has_sinks,
             has_bias,
             has_scap,
+            has_kvpad,
             ns10,
             ns20,
             nsg, nwg);
@@ -1035,6 +1084,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
     ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
     ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);
     ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);
+    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
 
     ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
     ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
@@ -135,6 +135,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+        ggml_metal_library_t lib,
+        const struct ggml_tensor * op,
+        bool has_mask,
+        int32_t ncpsg);
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
@@ -142,6 +148,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg);
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
@@ -151,6 +158,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         bool has_sinks,
         bool has_bias,
         bool has_scap,
+        bool has_kvpad,
         int32_t nsg,
         int32_t nwg);
 
@@ -72,11 +72,12 @@
 #define N_SG_IQ4_XS 2
 
 // function constants offsets
-#define FC_FLASH_ATTN_EXT            100
-#define FC_FLASH_ATTN_EXT_VEC        200
-#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
-#define FC_MUL_MV                    400
-#define FC_MUL_MM                    500
+#define FC_FLASH_ATTN_EXT_PAD        100
+#define FC_FLASH_ATTN_EXT            200
+#define FC_FLASH_ATTN_EXT_VEC        300
+#define FC_FLASH_ATTN_EXT_VEC_REDUCE 400
+#define FC_MUL_MV                    500
+#define FC_MUL_MM                    600
 
 // kernel argument structs
 //
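The offsets give each kernel family a 100-wide band of function-constant IDs; by convention in this file, boolean feature flags occupy the low slots of a band (+0..+4) and integer parameters start at +20. Host and shader must agree on the absolute ID. A minimal pairing sketch, using calls that appear elsewhere in this diff:

    // host side: bake the constants into the pipeline variant
    ggml_metal_cv_set_bool (cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
    ggml_metal_cv_set_int32(cv, ncpsg,    FC_FLASH_ATTN_EXT_PAD + 24);

    // shader side: the same IDs are bound as Metal function constants
    constant bool    FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
    constant int32_t FC_flash_attn_ext_pad_ncpsg    [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];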
@@ -246,6 +247,24 @@ typedef struct {
     int32_t  sect_3;
 } ggml_metal_kargs_rope;
 
+typedef struct {
+    int32_t  ne11;
+    int32_t  ne_12_2; // assume K and V are same shape
+    int32_t  ne_12_3;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb23;
+    int32_t  ne31;
+    int32_t  ne32;
+    int32_t  ne33;
+    uint64_t nb31;
+    uint64_t nb32;
+    uint64_t nb33;
+} ggml_metal_kargs_flash_attn_ext_pad;
+
 typedef struct {
     int32_t  ne01;
     int32_t  ne02;
@@ -264,6 +283,7 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne31;
     int32_t  ne32;
     int32_t  ne33;
     uint64_t nb31;
@@ -298,6 +318,7 @@ typedef struct {
     uint64_t nb21;
     uint64_t nb22;
     uint64_t nb23;
+    int32_t  ne31;
     int32_t  ne32;
     int32_t  ne33;
     uint64_t nb31;
@@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
     GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
+    GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
+    GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
     GGML_TENSOR_LOCALS( int64_t, ne,  node, ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  node, nb);
 
@@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
         GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                 ggml_is_contiguous(node->src[1]), node->src[1]->name);
     }
+    if (node->src[2]) {
+        GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
+                ggml_is_contiguous(node->src[2]), node->src[2]->name);
+    }
+    if (node->src[3]) {
+        GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
+                ggml_is_contiguous(node->src[3]), node->src[3]->name);
+    }
     if (node) {
         GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
                 node->name);
@@ -1873,20 +1885,69 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
     return (ne01 < 20) && (ne00 % 32 == 0);
 }
 
+size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
+
+    size_t res = 0;
+
+    const bool has_mask = op->src[3] != nullptr;
+
+    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const bool has_kvpad = ne11 % 32 != 0;
+
+        if (has_kvpad) {
+            res += 32*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    } else {
+        const bool has_kvpad = ne11 % 64 != 0;
+
+        if (has_kvpad) {
+            res += 64*(
+                nb11*ne12*ne13 +
+                nb21*ne22*ne23 +
+                (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
+        }
+    }
+
+    return res;
+}
+
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
     assert(op->op == GGML_OP_FLASH_ATTN_EXT);
 
-    const int64_t nwg = 32;
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
+  //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
+  //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
 
-    const int64_t ne01 = op->src[0]->ne[1];
-    const int64_t ne02 = op->src[0]->ne[2];
-    const int64_t ne03 = op->src[0]->ne[3];
-    const int64_t ne20 = op->src[2]->ne[0];
+    size_t res = 0;
 
-    // temp buffer for writing the results from each workgroup
-    // - ne20: the size of the Value head
-    // - + 2:  the S and M values for each intermediate result
-    return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+    if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
+        const int64_t nwg = 32;
+
+        // temp buffer for writing the results from each workgroup
+        // - ne20: the size of the Value head
+        // - + 2:  the S and M values for each intermediate result
+        res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+    }
+
+    return res;
 }
 
 int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
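As a worked example of the reservation (hypothetical shapes, not taken from the diff): an F16 KV cache with head size 128 has row strides nb11 = nb21 = 256 bytes; with ne12 = ne22 = 8 KV heads, ne13 = ne23 = 1, a mask with ne31 = 32 and ne32 = ne33 = 1, and the non-vec block size of 64, the scratch comes to

    res = 64*(256*8 + 256*8 + 2*32*1*1)   // K block + V block + F16 mask block
        = 64*(2048 + 2048 + 64)
        = 266240 bytes (260 KiB)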
@@ -1909,7 +1970,6 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, nb,  op,      nb);
 
     GGML_ASSERT(ne00 % 4  == 0);
-    GGML_ASSERT(ne11 % 32 == 0);
 
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == op->src[2]->type);
@@ -1947,6 +2007,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
     GGML_ASSERT(ne01 < 65536);
 
+    ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_pad = bid_dst;
+    bid_pad.offs += ggml_nbytes(op);
+
     if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
         // half8x8 kernel
         const int64_t nqptg = 8; // queries per threadgroup    !! sync with kernel template arguments !!
@@ -1956,6 +2021,52 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         GGML_ASSERT(nqptg % 8  == 0);
         GGML_ASSERT(ncpsg % 32 == 0);
 
+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ ne11,
+                /*.ne_12_2 =*/ ne12,
+                /*.ne_12_3 =*/ ne13,
+                /*.nb11    =*/ nb11,
+                /*.nb12    =*/ nb12,
+                /*.nb13    =*/ nb13,
+                /*.nb21    =*/ nb21,
+                /*.nb22    =*/ nb22,
+                /*.nb23    =*/ nb23,
+                /*.ne31    =*/ ne31,
+                /*.ne32    =*/ ne32,
+                /*.ne33    =*/ ne33,
+                /*.nb31    =*/ nb31,
+                /*.nb32    =*/ nb32,
+                /*.nb33    =*/ nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
+            ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
+            if (op->src[3]) {
+                ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
+            } else {
+                ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
+            }
+            ggml_metal_encoder_set_buffer  (enc, bid_pad, 4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            ggml_metal_op_concurrency_reset(ctx);
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
         const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
 
         // 2*(2*ncpsg)
@@ -2005,6 +2116,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.nb21          =*/ nb21,
             /*.nb22          =*/ nb22,
             /*.nb23          =*/ nb23,
+            /*.ne31          =*/ ne31,
             /*.ne32          =*/ ne32,
             /*.ne33          =*/ ne33,
             /*.nb31          =*/ nb31,
@@ -2021,7 +2133,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.logit_softcap =*/ logit_softcap,
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg);
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
 
         ggml_metal_encoder_set_pipeline(enc, pipeline);
         ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -2038,7 +2150,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         } else {
             ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
         }
-        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 6);
+        ggml_metal_encoder_set_buffer  (enc, bid_pad, 6);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst, 7);
 
         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -2054,6 +2167,52 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
         GGML_ASSERT(nqptg % 1  == 0);
         GGML_ASSERT(ncpsg % 32 == 0);
 
+        const bool has_kvpad = ne11 % ncpsg != 0;
+
+        if (has_kvpad) {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
+
+            ggml_metal_kargs_flash_attn_ext_pad args0 = {
+                /*.ne11    =*/ ne11,
+                /*.ne_12_2 =*/ ne12,
+                /*.ne_12_3 =*/ ne13,
+                /*.nb11    =*/ nb11,
+                /*.nb12    =*/ nb12,
+                /*.nb13    =*/ nb13,
+                /*.nb21    =*/ nb21,
+                /*.nb22    =*/ nb22,
+                /*.nb23    =*/ nb23,
+                /*.ne31    =*/ ne31,
+                /*.ne32    =*/ ne32,
+                /*.ne33    =*/ ne33,
+                /*.nb31    =*/ nb31,
+                /*.nb32    =*/ nb32,
+                /*.nb33    =*/ nb33,
+            };
+
+            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline0);
+            ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
+            ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
+            ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
+            if (op->src[3]) {
+                ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
+            } else {
+                ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
+            }
+            ggml_metal_encoder_set_buffer  (enc, bid_pad, 4);
+
+            assert(ne12 == ne22);
+            assert(ne13 == ne23);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
+
+            ggml_metal_op_concurrency_reset(ctx);
+        } else {
+            assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0);
+        }
+
         // ne00 + 2*ncpsg*(nsg)
         // for each query, we load it as f16 in shared memory (ne00)
         // and store the soft_max values and the mask
@@ -2118,6 +2277,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.nb21          =*/ nb21,
             /*.nb22          =*/ nb22,
             /*.nb23          =*/ nb23,
+            /*.ne31          =*/ ne31,
             /*.ne32          =*/ ne32,
             /*.ne33          =*/ ne33,
             /*.nb31          =*/ nb31,
@@ -2134,7 +2294,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.logit_softcap =*/ logit_softcap,
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg);
+        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
 
         GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
@@ -2161,7 +2321,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
         if (nwg == 1) {
            // using 1 workgroup -> write the result directly into dst
-            ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6);
+            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@@ -2171,12 +2332,12 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
             GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
 
-            ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
-
             // write the results from each workgroup into a temp buffer
             ggml_metal_buffer_id bid_tmp = bid_dst;
-            bid_tmp.offs += ggml_nbytes(op);
+            bid_tmp.offs += ggml_nbytes(op) + ggml_metal_op_flash_attn_ext_extra_pad(op);
 
-            ggml_metal_encoder_set_buffer(enc, bid_tmp, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
+            ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
@@ -39,6 +39,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
 // return true if we should use the FA vector kernel for this op
 bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
 
+size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
 size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
 
 int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
@@ -193,9 +193,8 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
                 } break;
             case GGML_OP_FLASH_ATTN_EXT:
                 {
-                    if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) {
-                        res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
-                    }
+                    res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
+                    res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
                 } break;
             default:
                 break;
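Note the resulting layout of the op's allocation: the pad scratch is reserved directly after the output, followed by the vec kernel's temp buffer. This is why the dispatch code above offsets bid_pad by ggml_nbytes(op) and bid_tmp by ggml_nbytes(op) + ggml_metal_op_flash_attn_ext_extra_pad(op):

    // [ dst : ggml_nbytes(op) ][ pad : ..._extra_pad(op) ][ tmp : ..._extra_tmp(op) ]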
@@ -4416,10 +4416,79 @@ kernel void kernel_leaky_relu_f32_4(
     dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope);
 }
 
+constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
+
+constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 24)]];
+
+kernel void kernel_flash_attn_ext_pad(
+        constant ggml_metal_kargs_flash_attn_ext_pad & args,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int32_t C = FC_flash_attn_ext_pad_ncpsg;
+
+    device char * k_pad    = dst;
+    device char * v_pad    = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
+    device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
+
+    const int32_t icp = args.ne11 % C;
+    const int32_t ic0 = args.ne11 - icp;
+
+    const int32_t i1 = tgpig[0];
+    const int32_t i2 = tgpig[1];
+    const int32_t i3 = tgpig[2];
+
+    if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
+        device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
+        device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
+
+        device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
+        device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
+
+        if (i1 >= icp) {
+            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
+                k_dst[i] = 0;
+            }
+            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
+                v_dst[i] = 0;
+            }
+        } else {
+            for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
+                k_dst[i] = k_src[i];
+            }
+            for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
+                v_dst[i] = v_src[i];
+            }
+        }
+    }
+
+    if (FC_flash_attn_ext_pad_has_mask) {
+        if (i2 < args.ne32 && i3 < args.ne33) {
+            for (int ib = i1; ib < args.ne31; ib += C) {
+                device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
+                device       half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
+
+                for (int i = tiitg; i < C; i += ntg.x) {
+                    if (i >= icp) {
+                        mask_dst[i] = -MAXHALF;
+                    } else {
+                        mask_dst[i] = mask_src[i];
+                    }
+                }
+            }
+        }
+    }
+}
+
 constant bool FC_flash_attn_ext_has_mask  [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
 constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
 constant bool FC_flash_attn_ext_has_bias  [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
 constant bool FC_flash_attn_ext_has_scap  [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
+constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
 
 //constant float FC_flash_attn_ext_scale    [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
 //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
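The host dispatches this kernel with an ncpsg x max(ne12, ne32) x max(ne13, ne33) grid of threadgroups, so tgpig[0] enumerates the columns of the single padded block while the i2/i3 bounds checks keep the K/V copy and the mask copy independent. The padded mask is stored as C contiguous half values per row, which is how the kvpad branches of the attention kernels later address it, roughly:

    // row (ib, i2, i3) of the padded mask:
    // mask_pad + C*ib + C*ne31*i2 + C*ne31*ne32*i3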
@@ -4466,6 +4535,7 @@ void kernel_flash_attn_ext_impl(
        device const char * v,
        device const char * mask,
        device const char * sinks,
+        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16,
        uint3  tgpig,
@@ -4521,6 +4591,7 @@ void kernel_flash_attn_ext_impl(
 
    // mask storage in shared mem
    threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
+    threadgroup half  * sm  = (threadgroup half  *) (sm2);
 
    // per-query mask pointers
    device const half2 * pm2[NQ];
@@ -4590,7 +4661,44 @@ void kernel_flash_attn_ext_impl(
 
    // loop over the KV cache
    // each simdgroup handles blocks of Q rows and C columns
-    for (int ic = 0; ic < args.ne11; ic += C) {
+    for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
+        int ic = ic0;
+
+        if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
+            k    = pad;
+            v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
+            mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
+
+            const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+            const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+            k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
+            v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
+
+            if (!FC_flash_attn_ext_has_mask) {
+                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                    const short j = jj*NSG + sgitg;
+
+                    for (short i = tiisg; i < C; i += NW) {
+                        if (ic + i >= args.ne11) {
+                            sm[2*j*SH + i] = -MAXHALF;
+                        }
+                    }
+                }
+            } else {
+                FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
+                    const short j = jj*NSG + sgitg;
+
+                    pm2[jj] = (device const half2 *) ((device const half *) mask +
+                            (iq1 + j)*C +
+                            (iq2%args.ne32)*(C*args.ne31) +
+                            (iq3%args.ne33)*(C*args.ne31*args.ne32));
+                }
+            }
+
+            ic = 0;
+        }
+
        // read the mask into shared mem
        if (FC_flash_attn_ext_has_mask) {
            FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
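The pattern keeps the hot block loop untouched: only the final, partial iteration rebases k, v and mask onto the padded copies and resets ic to 0, since the scratch holds exactly one C-wide block; the re-derived ikv2/ikv3 offsets reproduce the usual grouped-query head broadcast inside the scratch. Simplified control flow, with the block body elided:

    for (int ic0 = 0; ic0 < args.ne11; ic0 += C) {
        int ic = ic0;
        if (FC_flash_attn_ext_has_kvpad && ic0 + C > args.ne11) {
            // last partial block: read from the padded copies instead
            k    = pad;
            v    = k + /* size of k_pad */;
            mask = v + /* size of v_pad */;
            ic   = 0; // the scratch holds a single C-wide block
        }
        // the unchanged block body always reads a full C columns at offset ic
    }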
@@ -4624,7 +4732,7 @@ void kernel_flash_attn_ext_impl(
            // this is compile-time check, so it does not have runtime overhead
            if (is_same<kd4x4_t, k4x4_t>::value) {
                // we can read directly from global memory
-                device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11);
+                device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
                threadgroup const q_t * pq = sq;
                threadgroup       s_t * ps = ss;
@@ -4696,7 +4804,7 @@ void kernel_flash_attn_ext_impl(
                    qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
 
                    for (short ii = 0; ii < DK16; ii += 4) {
-                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11));
+                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
 
                        if (DK16%4 == 0) {
                            // the head is evenly divisible by 4*16 = 64, so no need for bound checks
@@ -4818,7 +4926,7 @@ void kernel_flash_attn_ext_impl(
        {
            auto sst = ss;
 
-            device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21);
+            device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
 
            pv += 8*sgitg;
@@ -4860,7 +4968,7 @@ void kernel_flash_attn_ext_impl(
                simdgroup_load(vs, ss + 8*cc, SH, 0, false);
 
                for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
-                    device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21));
+                    device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
 
                    if (DV16%4 == 0) {
                        // no need for bound checks
@@ -5004,13 +5112,14 @@ kernel void kernel_flash_attn_ext(
        device const char * v,
        device const char * mask,
        device const char * sinks,
+        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
 #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
+#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
    switch (FC_flash_attn_ext_nsg) {
        // note: disabled cases to reduce library load time
        //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -5130,6 +5239,7 @@ constant bool FC_flash_attn_ext_vec_has_mask  [[function_constant(FC_FLASH_ATTN_
 constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
 constant bool FC_flash_attn_ext_vec_has_bias  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
 constant bool FC_flash_attn_ext_vec_has_scap  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
+constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
 
 //constant float FC_flash_attn_ext_vec_scale    [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
 //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
@@ -5167,6 +5277,7 @@ void kernel_flash_attn_ext_vec_impl(
        device const char * v,
        device const char * mask,
        device const char * sinks,
+        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
@@ -5273,11 +5384,36 @@ void kernel_flash_attn_ext_vec_impl(
    // loop over the KV cache
    // each simdgroup handles blocks of Q rows and C columns
    for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) {
-        const int ic = ic0 + C*sgitg;
+        int ic = ic0 + C*sgitg;
        if (ic >= args.ne11) {
            break;
        }
 
+        if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
+            k    = pad;
+            v    = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
+            mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
+
+            const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+            const short ikv3 = iq3/(args.ne03/args.ne_12_3);
+
+            k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
+            v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
+
+            if (!FC_flash_attn_ext_vec_has_mask) {
+                if (ic + tiisg >= args.ne11) {
+                    sm[tiisg] = -MAXHALF;
+                }
+            } else {
+                pm = (device const half *) (mask) +
+                        iq1*C +
+                        (iq2%args.ne32)*(C*args.ne31) +
+                        (iq3%args.ne33)*(C*args.ne31*args.ne32);
+            }
+
+            ic = 0;
+        }
+
        if (FC_flash_attn_ext_vec_has_mask) {
            sm[tiisg] = pm[ic + tiisg];
        }
@@ -5289,7 +5425,7 @@ void kernel_flash_attn_ext_vec_impl(
 
        // Q*K^T
        {
-            device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11);
+            device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
            threadgroup const q4_t * pq4 = sq4;
 
            pk4 += ty*NS10/4 + tx;
@@ -5304,7 +5440,7 @@ void kernel_flash_attn_ext_vec_impl(
                        mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
                    }
                } else {
-                    device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11));
+                    device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
 
                    k4_t mk;
 
@@ -5402,7 +5538,7 @@ void kernel_flash_attn_ext_vec_impl(
        }
 
        if (is_same<vd4_t, v4_t>::value) {
-            device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21);
+            device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
 
            pv4 += ty*NS20/4 + tx;
@@ -5415,7 +5551,7 @@ void kernel_flash_attn_ext_vec_impl(
            }
        } else {
            FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
-                device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21));
+                device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
 
                FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
                    const short i = ii*NL + tx;
@@ -5587,13 +5723,14 @@ kernel void kernel_flash_attn_ext_vec(
        device const char * v,
        device const char * mask,
        device const char * sinks,
+        device const char * pad,
        device       char * dst,
        threadgroup  half * shmem_f16 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
 #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
-#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg
+#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
    switch (FC_flash_attn_ext_vec_nsg) {
        // note: disabled cases to reduce library load time
        case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
@@ -6627,7 +6627,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
                for (int nr2 : { 1, 4, 16 }) {
                    if (nr2 == 16 && hsk != 128) continue;
-                    for (int kv : { 512, 1024, }) {
+                    //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) {
+                    for (int kv : { 113, 512, 1024, }) {
                        if (nr2 != 1 && kv != 512) continue;
                        for (int nb : { 1, 3, 32, 35, }) {
                            for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
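The added kv = 113 case exercises the new padding path, since 113 is a multiple of neither the vec block size (32) nor the simd block size (64). Assuming the usual test harness invocation, something like the following runs just these cases:

    ./bin/test-backend-ops test -o FLASH_ATTN_EXT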