Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs
This commit is contained in:
parent
99c3df8219
commit
d1fd632ab8
|
|
@ -748,7 +748,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
*dst = dst_val / rowsum;
|
||||
}
|
||||
|
||||
// Fallback fixup kernel for the cases where nblocks_stream_k < 2 * ntiles_dst
|
||||
// Fallback fixup kernel for the cases where nblocks_stream_k < 4 * ntiles_dst
|
||||
// (blocks_num.x not a multiple of ntiles_dst)
|
||||
template <int D, int ncols1, int ncols2>
|
||||
__launch_bounds__(D, 1)
|
||||
|
|
@ -1054,8 +1054,8 @@ void launch_fattn(
|
|||
|
||||
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
|
||||
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks.
|
||||
// do this only if nblocks_stream_k_raw is at least 2x ntiles_dst to avoid excessive loss of occupancy
|
||||
const int nblocks_stream_k = nblocks_stream_k_raw > 2 * ntiles_dst
|
||||
// do this only if nblocks_stream_k_raw is at least 4x ntiles_dst to avoid excessive loss of occupancy
|
||||
const int nblocks_stream_k = nblocks_stream_k_raw > 4 * ntiles_dst
|
||||
? (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst
|
||||
: nblocks_stream_k_raw;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue