diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh
index b6a7460da8..e9abdf288c 100644
--- a/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ggml/src/ggml-cuda/fattn-common.cuh
@@ -892,7 +892,7 @@ void launch_fattn(
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
     const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
-    const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
+    const int ntiles_dst = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
 
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -919,37 +919,37 @@ void launch_fattn(
     GGML_ASSERT(max_blocks_per_sm > 0);
 
     int parallel_blocks = max_blocks_per_sm;
 
+    const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length.
+
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
         const int max_blocks = max_blocks_per_sm*nsm;
-        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
-        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
+        const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = max_blocks;
+        const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
 
         const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
 
-        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
         }
     } else {
-        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
-
         // parallel_blocks must not be larger than what the tensor size allows:
-        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+        parallel_blocks = std::min(parallel_blocks, ntiles_KV);
 
         // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
         // Test whether parallel_blocks can be set to a higher value for better efficiency.
         const int blocks_per_wave = nsm * max_blocks_per_sm;
         int nwaves_best = 0;
         int efficiency_percent_best = 0;
-        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
-            const int nblocks_total = ntiles_total * parallel_blocks_test;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_dst * parallel_blocks_test;
             const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
             const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
@@ -1015,7 +1015,7 @@ void launch_fattn(
     CUDA_CHECK(cudaGetLastError());
 
     if (stream_k) {
-        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+        if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};