From bbac6a26b2bd7f7c1f0831cb1e7b52734c66673b Mon Sep 17 00:00:00 2001 From: leejet Date: Mon, 27 Oct 2025 02:13:31 +0800 Subject: [PATCH] ggml: fix cuda kernel launch configuration for k_compute_batched_ptrs to support large batch (#16744) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix k_compute_batched_ptrs * add backend ops test * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler * reduce the batch size --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/ggml-cuda.cu | 11 +++++++++-- tests/test-backend-ops.cpp | 3 +++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 19f72975c0..6b688bfecd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1957,8 +1957,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct size_t src1_stride_size = sizeof(cuda_t); - dim3 block_dims(ne13, ne12); - k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( + const int threads_x = 16; + const int threads_y = 16; + dim3 block_dims(threads_x, threads_y); + + dim3 grid_dims( + (ne13 + threads_x - 1) / threads_x, + (ne12 + threads_y - 1) / threads_y + ); + k_compute_batched_ptrs<<>>( src0_ptr, src1_ptr, dst_t, ptrs_src.get(), ptrs_dst.get(), ne12, ne13, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 33ac27ff5c..0ad73e944e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6697,6 +6697,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1})); + + // test cases with large batch size + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {1536, 1}, {1, 1})); } } for (ggml_type type_a : other_types) {