Fix launch logic when supports_cooperative_launch=false

Oliver Simons 2025-12-09 19:03:47 +01:00
parent 3f0594ad0b
commit a25fda5290
1 changed file with 1 addition and 3 deletions

@@ -407,15 +407,13 @@ static void soft_max_f32_cuda(const float * x,
     // Parallelize across SMs for top-p/dist-smapling
     // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and
     // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
-    if (ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) {
-        if (ggml_cuda_info().devices[id].supports_cooperative_launch) {
+    if (ggml_cuda_info().devices[id].supports_cooperative_launch && ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) {
         ggml_cuda_pool_alloc<float> tmp_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
         void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_alloc.ptr, (void *) const_cast<soft_max_params *>(& params)};
         CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
                                                dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
                                                dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
-        }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
         soft_max_f32<false, 0, 0>