diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 9b13614cfd..54acf41d83 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -209,6 +209,7 @@ static __device__ float two_stage_warp_reduce_sum(float val) { } } +// TODO: Template to allow keeping ncols in registers if they fit static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, float * __restrict__ dst, float * __restrict__ tmp_vals, @@ -404,7 +405,9 @@ static void soft_max_f32_cuda(const float * x, launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); } else { // Parallelize across SMs for top-p/dist-smapling - if (ncols_x > 10000 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) { + // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and + // can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution. 
+ if (ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) { if (ggml_cuda_info().devices[id].supports_cooperative_launch) { ggml_cuda_pool_alloc tmp_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float)); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 2777cd2b82..7a02979b3a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8113,6 +8113,12 @@ static std::vector> make_test_cases_perf() { } } + for (int col: {8192, 16384, 32768, 65536, 131072, 262144, 524288}) { + for (int rows: {1, 4, 16}){ + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {col, rows, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + } + } + test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false)); test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));