sycl : handle other FA case (#21377)
This commit is contained in:
parent
25eec6f327
commit
f51fd36d79
|
|
@ -1252,6 +1252,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm
|
|||
return;
|
||||
}
|
||||
|
||||
{
|
||||
constexpr int cols_per_block = ncols2*2;
|
||||
const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
|
||||
const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
|
||||
launch_fattn<DV, cols_per_block/ncols2, ncols2,
|
||||
flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
|
||||
(ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue