metal : skip loading all-zero mask
This commit is contained in:
parent
423bee462b
commit
4815a66990
|
|
@ -5285,6 +5285,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E
|
|||
// scan the blocks of the mask that are not masked
|
||||
// 0 - masked (i.e. full of -INF, skip)
|
||||
// 1 - not masked (i.e. at least one element of the mask is not -INF)
|
||||
// 2 - all zero
|
||||
kernel void kernel_flash_attn_ext_blk(
|
||||
constant ggml_metal_kargs_flash_attn_ext_blk & args,
|
||||
device const char * mask,
|
||||
|
|
@ -5306,27 +5307,29 @@ kernel void kernel_flash_attn_ext_blk(
|
|||
|
||||
device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
|
||||
|
||||
// fast route
|
||||
if (res == 0) {
|
||||
if (simd_max(*mask_src) > -MAXHALF/2) {
|
||||
res = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// detailed check of the elements of the block
|
||||
if ((C > NW || Q > 1) && res == 0) {
|
||||
half m = -MAXHALF;
|
||||
half mmin = MAXHALF;
|
||||
half mmax = -MAXHALF;
|
||||
|
||||
FOR_UNROLL (short j = 0; j < Q; ++j) {
|
||||
FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
|
||||
m = max(m, mask_src[ii*NW]);
|
||||
mmin = min(mmin, mask_src[ii*NW]);
|
||||
mmax = max(mmax, mask_src[ii*NW]);
|
||||
}
|
||||
|
||||
mask_src += args.nb31/2;
|
||||
}
|
||||
|
||||
if (simd_max(m) > -MAXHALF/2) {
|
||||
res = 1;
|
||||
mmin = simd_min(mmin);
|
||||
mmax = simd_max(mmax);
|
||||
|
||||
if (mmax > -MAXHALF) {
|
||||
if (mmin == 0.0 && mmax == 0.0) {
|
||||
res = 2;
|
||||
} else {
|
||||
res = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -5568,9 +5571,13 @@ void kernel_flash_attn_ext_impl(
|
|||
ic = 0;
|
||||
}
|
||||
|
||||
char blk_cur = 1;
|
||||
|
||||
// read the mask into shared mem
|
||||
if (FC_flash_attn_ext_has_mask) {
|
||||
if (blk[ic0] == 0) {
|
||||
blk_cur = blk[ic0];
|
||||
|
||||
if (blk_cur == 0) {
|
||||
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
||||
pm2[jj] += NW;
|
||||
}
|
||||
|
|
@ -5578,16 +5585,24 @@ void kernel_flash_attn_ext_impl(
|
|||
continue;
|
||||
}
|
||||
|
||||
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
||||
const short j = jj*NSG + sgitg;
|
||||
if (blk_cur == 1) {
|
||||
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
||||
const short j = jj*NSG + sgitg;
|
||||
|
||||
if (FC_flash_attn_ext_bc_mask) {
|
||||
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
|
||||
} else {
|
||||
sm2[j*SH + tiisg] = pm2[jj][tiisg];
|
||||
if (FC_flash_attn_ext_bc_mask) {
|
||||
sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
|
||||
} else {
|
||||
sm2[j*SH + tiisg] = pm2[jj][tiisg];
|
||||
}
|
||||
|
||||
pm2[jj] += NW;
|
||||
}
|
||||
} else if (blk_cur == 2) {
|
||||
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
||||
const short j = jj*NSG + sgitg;
|
||||
|
||||
pm2[jj] += NW;
|
||||
pm2[jj] += NW;
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
|
|
@ -5752,10 +5767,12 @@ void kernel_flash_attn_ext_impl(
|
|||
}
|
||||
|
||||
// mqk = mqk + slope*mask
|
||||
if (FC_flash_attn_ext_has_bias) {
|
||||
s2 += s2_t(sm2[j*SH + tiisg])*slope;
|
||||
} else {
|
||||
s2 += s2_t(sm2[j*SH + tiisg]);
|
||||
if (blk_cur != 2) {
|
||||
if (FC_flash_attn_ext_has_bias) {
|
||||
s2 += s2_t(sm2[j*SH + tiisg])*slope;
|
||||
} else {
|
||||
s2 += s2_t(sm2[j*SH + tiisg]);
|
||||
}
|
||||
}
|
||||
|
||||
M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
|
||||
|
|
|
|||
Loading…
Reference in New Issue