CUDA: limit number of FA stream-k CUDA blocks (#20586)
This commit is contained in:
parent
ceef6b5233
commit
ae40cd27c8
|
|
@ -892,7 +892,7 @@ void launch_fattn(
|
|||
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
|
||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||
const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
|
||||
const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
|
||||
const int ntiles_dst = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
|
||||
|
||||
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
|
||||
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
|
|
@ -919,37 +919,37 @@ void launch_fattn(
|
|||
GGML_ASSERT(max_blocks_per_sm > 0);
|
||||
int parallel_blocks = max_blocks_per_sm;
|
||||
|
||||
const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length.
|
||||
|
||||
dim3 blocks_num;
|
||||
if (stream_k) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int max_blocks = max_blocks_per_sm*nsm;
|
||||
const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
|
||||
const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
|
||||
const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
|
||||
|
||||
const int nblocks_stream_k = max_blocks;
|
||||
const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
|
||||
|
||||
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
|
||||
|
||||
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
|
||||
blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
|
||||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
|
||||
}
|
||||
} else {
|
||||
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
|
||||
|
||||
// parallel_blocks must not be larger than what the tensor size allows:
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
|
||||
parallel_blocks = std::min(parallel_blocks, ntiles_KV);
|
||||
|
||||
// If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
|
||||
// Test whether parallel_blocks can be set to a higher value for better efficiency.
|
||||
const int blocks_per_wave = nsm * max_blocks_per_sm;
|
||||
int nwaves_best = 0;
|
||||
int efficiency_percent_best = 0;
|
||||
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
|
||||
const int nblocks_total = ntiles_total * parallel_blocks_test;
|
||||
for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV; ++parallel_blocks_test) {
|
||||
const int nblocks_total = ntiles_dst * parallel_blocks_test;
|
||||
const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
|
||||
const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
|
||||
|
||||
|
|
@ -1015,7 +1015,7 @@ void launch_fattn(
|
|||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (stream_k) {
|
||||
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue