Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs
This commit is contained in:
parent
99c3df8219
commit
d1fd632ab8
|
|
@ -748,7 +748,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||||
*dst = dst_val / rowsum;
|
*dst = dst_val / rowsum;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback fixup kernel for the cases where nblocks_stream_k < 2 * ntiles_dst
|
// Fallback fixup kernel for the cases where nblocks_stream_k < 4 * ntiles_dst
|
||||||
// (blocks_num.x not a multiple of ntiles_dst)
|
// (blocks_num.x not a multiple of ntiles_dst)
|
||||||
template <int D, int ncols1, int ncols2>
|
template <int D, int ncols1, int ncols2>
|
||||||
__launch_bounds__(D, 1)
|
__launch_bounds__(D, 1)
|
||||||
|
|
@ -1054,8 +1054,8 @@ void launch_fattn(
|
||||||
|
|
||||||
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
|
const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
|
||||||
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks.
|
// Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks.
|
||||||
// do this only if nblocks_stream_k_raw is at least 2x ntiles_dst to avoid excessive loss of occupancy
|
// do this only if nblocks_stream_k_raw is at least 4x ntiles_dst to avoid excessive loss of occupancy
|
||||||
const int nblocks_stream_k = nblocks_stream_k_raw > 2 * ntiles_dst
|
const int nblocks_stream_k = nblocks_stream_k_raw > 4 * ntiles_dst
|
||||||
? (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst
|
? (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst
|
||||||
: nblocks_stream_k_raw;
|
: nblocks_stream_k_raw;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue