metal : minor cleanup
parent 41ea26144e
commit 3479aacf8c
@@ -81,10 +81,10 @@
#define FC_COUNT_EQUAL 1000

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8
#define OP_FLASH_ATTN_EXT_NQPSG 8
#define OP_FLASH_ATTN_EXT_NCPSG 64

#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32

// kernel argument structs
@@ -2295,7 +2295,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
// return res;
//}

const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG;
const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;

const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
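The last line of this hunk is the usual integer ceiling division: ne1 is the number of query blocks needed to cover ne01 query rows when each threadgroup handles nqptg of them. A minimal standalone sketch of that arithmetic, with a made-up row count:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne01  = 100; // number of query rows (made-up value)
        const int     nqptg = 8;   // queries per threadgroup (non-vec path)

        // integer ceiling division: number of query blocks covering ne01 rows
        const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; // 100 rows, 8 per block -> 13 blocks

        printf("ne1 = %lld\n", (long long) ne1);
        return 0;
    }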
@@ -2411,7 +2411,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {

if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
// half8x8 kernel
const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup
const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup

GGML_ASSERT(nqptg <= 32);
@@ -2578,9 +2578,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
#undef FATTN_SMEM
} else {
// half4x4 kernel
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup
const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
const int nkpsg = 1*ncpsg;
const int nhptg = 1; // heads per threadgroup

GGML_ASSERT(nqptg <= 32);
GGML_ASSERT(nqptg % 1 == 0);
@@ -2632,6 +2632,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);
}

// note: for simplicity assume the K is larger or equal than V
GGML_ASSERT(ne10 >= ne20);

// ne00 + 2*ncpsg*(nsg)
// for each query, we load it as f16 in shared memory (ne00)
// and store the soft_max values and the mask
@@ -2639,28 +2642,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// ne20*(nsg)
// each simdgroup has a full f32 head vector in shared mem to accumulate results
//
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16))

int64_t nsgmax = 2;
while (true) {
const size_t smem = FATTN_SMEM(nsgmax);
// avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
if (smem > props_dev->max_theadgroup_memory_size/2) {
break;
}
nsgmax *= 2;
}
nsgmax /= 2;

// simdgroups per threadgroup (a.k.a. warps)
//const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32)));
#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))

int64_t nsg = 1;
while (nsg <= nsgt) {
nsg *= 2;
}
nsg /= 2;

// workgroups
// each workgroup handles nsg*nkpsg cache values
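To get a feel for the sizes produced by the simpler of the two FATTN_SMEM variants in this hunk, here is a standalone sketch with GGML_PAD reimplemented locally; the head sizes are example values, not taken from a particular model:

    #include <cstddef>
    #include <cstdio>

    // local stand-in for GGML_PAD: round x up to a multiple of n
    static size_t pad(size_t x, size_t n) { return ((x + n - 1)/n)*n; }

    int main() {
        const size_t ne00  = 128; // K head size (example value)
        const size_t ne20  = 128; // V head size (example value)
        const size_t ncpsg = 32;  // cache values per simdgroup (vec kernel constant)

        // per-threadgroup shared memory in bytes, following the simpler formula:
        // GGML_PAD((GGML_PAD(ne00,128) + 4*ncpsg + 2*GGML_PAD(ne20,128)) * nsg * sizeof(half), 16)
        for (size_t nsg = 1; nsg <= 4; nsg *= 2) {
            const size_t smem = pad((pad(ne00, 128) + 4*ncpsg + 2*pad(ne20, 128))*nsg*2, 16);
            printf("nsg = %zu -> smem = %zu bytes\n", nsg, smem);
        }
        return 0;
    }

For these example sizes the footprint stays at a few kilobytes per simdgroup, which is consistent with the surrounding comment about staying below half of the device's threadgroup memory limit.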
@@ -2673,7 +2657,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
} else {
nwg = 32;
nsg = 1;
while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) {
while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
nsg *= 2;
}
}
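The loop in this hunk doubles nsg while the combined coverage of all workgroups and simdgroups per pass (nwg*nsg*ncpsg cache values) is still less than half of ne11, capping at 4. A small sketch of the same selection with an illustrative cache length:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t ne11  = 8192; // length of the K cache (illustrative)
        const int     nwg   = 32;   // workgroups along the cache dimension
        const int     ncpsg = 32;   // cache values per simdgroup (vec kernel constant)

        // double nsg while one pass of all workgroups covers less than half of
        // the cache (2*nwg*nsg*ncpsg < ne11), capped at 4 simdgroups per threadgroup
        int nsg = 1;
        while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
            nsg *= 2;
        }

        printf("ne11 = %lld -> nsg = %d\n", (long long) ne11, nsg); // 8192 -> 4
        return 0;
    }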
@@ -2739,7 +2723,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {

ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
} else {
// sanity checks
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
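The dispatch call in this hunk appears to encode a 3D threadgroup grid: query blocks along x, head blocks along y, and batch times workgroups along z, with the trailing (32, nsg, 1) triple giving the threads per threadgroup. A standalone sketch of that arithmetic with made-up tensor extents:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // made-up tensor extents: query rows, heads, batch
        const int64_t ne01 = 512, ne02 = 32, ne03 = 1;
        const int nqptg = 1, nhptg = 1, nwg = 1, nsg = 2;

        const int64_t tg_x = (ne01 + nqptg - 1)/nqptg; // threadgroups over query rows
        const int64_t tg_y = (ne02 + nhptg - 1)/nhptg; // threadgroups over heads
        const int64_t tg_z = ne03*nwg;                 // batch times workgroups

        printf("grid = %lld x %lld x %lld, threads per threadgroup = %d\n",
               (long long) tg_x, (long long) tg_y, (long long) tg_z, 32*nsg);
        return 0;
    }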
@@ -2752,7 +2736,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);

ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);

// sync the 2 kernels
ggml_metal_op_concurrency_reset(ctx);
@@ -5931,7 +5931,7 @@ template<
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
short DK, // K head size
short DV, // V head size
short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup
short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
constant ggml_metal_kargs_flash_attn_ext & args,
@@ -6141,11 +6141,10 @@ template<
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
short DK, // K head size
short DV, // V head size
short NE, // head elements per thread
short Q, // queries per threadgroup
short C, // cache items per threadgroup
short NSG> // number of simd groups
void kernel_flash_attn_ext_vec_impl(
short NE = 4, // head elements per thread
short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext_vec & args,
device const char * q,
device const char * k,
@@ -6162,6 +6161,7 @@ void kernel_flash_attn_ext_vec_impl(
static_assert(DV % 32 == 0, "DV must be divisible by 32");

#define NWG (FC_flash_attn_ext_vec_nwg)
#define NSG (FC_flash_attn_ext_vec_nsg)

#define NS10 (FC_flash_attn_ext_vec_ns10)
#define NS20 (FC_flash_attn_ext_vec_ns20)
@@ -6190,12 +6190,12 @@ void kernel_flash_attn_ext_vec_impl(

const short T = PK + NSG*SH; // shared memory size per query in (half)

//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t
threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results

// store the result for all queries in shared memory (the O matrix from the paper)
so4 += tiisg;
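The pointer definitions in this hunk carve one flat threadgroup buffer into per-simdgroup regions purely by offset arithmetic. The toy program below mirrors the offsets of the second set of definitions; PK, PV, SH and NSG are placeholder values chosen only to make the printout readable, not the kernel's actual sizes:

    #include <cstdio>

    // Toy model of the per-simdgroup shared memory offsets (in half elements).
    // PK/PV/SH/NSG are placeholders, not the values the kernel actually uses.
    int main() {
        const int PK  = 128; // assumed: padded K head size
        const int PV  = 128; // assumed: padded V head size
        const int C   = 32;  // cache items per threadgroup (vec kernel constant)
        const int SH  = 4*C; // assumed: per-simdgroup scratch size
        const int NSG = 2;   // simdgroups per threadgroup (example)

        for (int sgitg = 0; sgitg < NSG; ++sgitg) {
            const int off_ss = sgitg*SH + NSG*PK;            // attention scores scratch
            const int off_sm = sgitg*SH + 2*C + NSG*PK;      // mask scratch
            const int off_so = 2*sgitg*PV + NSG*PK + NSG*SH; // output accumulator
            printf("sgitg %d: ss @ %d, sm @ %d, so4 @ %d\n", sgitg, off_ss, off_sm, off_so);
        }
        return 0;
    }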
@@ -6213,11 +6213,13 @@ void kernel_flash_attn_ext_vec_impl(
// load heads from Q to shared memory
device const float4 * q4 = (device const float4 *) ((device const char *) q);

for (short i = tiisg; i < PK4; i += NW) {
if (iq1 < args.ne01 && i < DK4) {
sq4[i] = (q4_t) q4[i];
} else {
sq4[i] = (q4_t) 0.0f;
if (iq1 < args.ne01) {
for (short i = tiisg; i < PK4; i += NW) {
if (i < DK4) {
sq4[i] = (q4_t) q4[i];
} else {
sq4[i] = (q4_t) 0.0f;
}
}
}
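The restructured loop hoists the row check (iq1 < args.ne01) out of the copy loop and, as a side effect, no longer zero-fills shared memory for out-of-range query rows; in-range rows are still zero-padded past DK4. A generic sketch of that shape with placeholder names:

    #include <cstdio>

    // before: the row-validity check is re-evaluated on every loop iteration,
    // and out-of-range rows are zero-filled
    //   for (int i = lane; i < n; i += stride) {
    //       if (row_valid && i < d) dst[i] = src[i];
    //       else                    dst[i] = 0.0f;
    //   }

    // after: the check is hoisted; invalid rows skip the copy entirely and
    // in-range rows are zero-padded past the head size d
    void load_row(const float *src, float *dst, int n, int d,
                  int lane, int stride, bool row_valid) {
        if (row_valid) {
            for (int i = lane; i < n; i += stride) {
                dst[i] = (i < d) ? src[i] : 0.0f;
            }
        }
    }

    int main() {
        float src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
        float dst[8] = {0};
        load_row(src, dst, 8, 6, 0, 1, true); // copies 6 values, zero-pads the rest
        printf("dst[5] = %g, dst[6] = %g\n", dst[5], dst[6]);
        return 0;
    }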
@@ -6295,7 +6297,7 @@ void kernel_flash_attn_ext_vec_impl(
}

// skip -INF blocks
if (simd_max(sm[tiisg]) == -INFINITY) {
if (simd_max(sm[tiisg]) <= -MAXHALF) {
continue;
}
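simd_max() reduces the per-lane mask values across the simdgroup, so this branch skips a whole block of cache entries only when even the largest mask value in it is effectively minus infinity; comparing against -MAXHALF rather than -INFINITY presumably also catches masks built from large finite negative values. A host-side sketch of the same test, with std::max_element standing in for the simdgroup reduction:

    #include <algorithm>
    #include <array>
    #include <cstdio>

    int main() {
        // 32 mask values, one per lane of a simdgroup (a fully masked block)
        std::array<float, 32> mask;
        mask.fill(-65504.0f); // largest finite half magnitude, used here as "minus infinity"

        const float max_half = 65504.0f; // stand-in for MAXHALF

        // the block can be skipped when even its largest mask value is <= -MAXHALF
        const float block_max = *std::max_element(mask.begin(), mask.end());
        if (block_max <= -max_half) {
            printf("block fully masked -> skip\n");
        }
        return 0;
    }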
@@ -6569,57 +6571,11 @@ void kernel_flash_attn_ext_vec_impl(
}

#undef NWG
#undef NSG
#undef NS10
#undef NS20
}

template<
typename q4_t, // query types in shared memory
typename k4_t, // key types in shared memory
typename v4_t, // value types in shared memory
typename qk_t, // Q*K types
typename s_t, // soft-max types
typename s4_t,
typename o4_t, // attention accumulation types
typename kd4_t, // key type in device memory
short nl_k,
void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
typename vd4_t, // value type in device memory
short nl_v,
void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
short DK, // K head size
short DV, // V head size
short NE = 4, // head elements per thread
short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup
short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
constant ggml_metal_kargs_flash_attn_ext_vec & args,
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device const char * sinks,
device const char * pad,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C
#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg
switch (FC_flash_attn_ext_vec_nsg) {
// note: disabled cases to reduce library load time
case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break;
case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break;
//case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break;
//case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break;
//case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS
}

// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
//
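The removed wrapper in this hunk is an instance of a common pattern: a value known only when the pipeline is built (the FC_flash_attn_ext_vec_nsg function constant) is turned into a compile-time template argument by switching over the supported values. A plain C++ sketch of that pattern with generic names:

    #include <cstdio>

    template <int NSG>
    void impl() {
        // inside the template body NSG is a compile-time constant
        printf("running with NSG = %d\n", NSG);
    }

    // map a value chosen at pipeline-build time onto the supported instantiations
    void run(int nsg) {
        switch (nsg) {
            case 1: impl<1>(); break;
            case 2: impl<2>(); break;
            case 4: impl<4>(); break;
            default: break; // unsupported values are simply ignored here
        }
    }

    int main() {
        run(2);
        return 0;
    }

Since Metal function constants are already resolved when the pipeline state is created, reading NSG directly from the constant (as the new #define NSG does) gives the compiler the same specialization opportunity without the extra template layer, which looks like the motivation for dropping the switch.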