diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
index d2f2def8bd..7aeee449ef 100644
--- a/ggml/src/ggml-cuda/cumsum.cu
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -149,9 +149,32 @@ static __global__ void cumsum_kernel(
     }
 }
 
+template <typename T>
+static void cumsum_cub(ggml_cuda_pool & pool,
+        const T * src,
+        T * dst,
+        int64_t ne,
+        cudaStream_t stream) {
+    size_t tmp_size = 0;
+
+    // Query how much temp storage CUDA UnBound (CUB) needs
+    cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size)
+        tmp_size,                          // reference to size (will be set by CUB)
+        src,                               // input pointer
+        dst,                               // output pointer
+        ne,                                // number of elements
+        stream                             // CUDA stream to use
+    );
+
+    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+
+    // Perform the inclusive scan
+    cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
+}
+
 template <typename T>
 static void cumsum_cuda(
-    const T * src, T * dst,
+    [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
     const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
@@ -165,6 +188,15 @@ static void cumsum_cuda(
 
     if (is_contiguous) {
         use_cub = true;
+        int64_t nrows = ne01 * ne02 * ne03;
+        // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
+        // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
+        if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
+            for (int i = 0; i < nrows; ++i) {
+                cumsum_cub(ctx.pool(), src + i*ne00, dst + i*ne00, ne00, stream);
+            }
+            return;
+        }
     }
@@ ... @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     cumsum_cuda(
-        (const float *)src0->data, (float *)dst->data,
+        ctx, (const float *)src0->data, (float *)dst->data,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
         src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
         dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 6a1b063e5a..ec1fcd8549 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7736,6 +7736,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20481, 4, 1, 1 }));
     test_cases.emplace_back(new test_xielu());
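
For reference, outside the patch: a minimal, self-contained sketch of the two-phase CUB pattern the new cumsum_cub helper relies on — call cub::DeviceScan::InclusiveSum once with a null temp-storage pointer so CUB only reports the scratch size it needs, allocate that much, then call it again to run the scan. This standalone version is an assumption-laden illustration: it uses plain cudaMalloc and the default stream instead of ggml's pool allocator and explicit stream.

// Illustrative sketch only (not part of the patch): two-phase CUB inclusive scan
// over a small float array, using cudaMalloc instead of ggml_cuda_pool_alloc.
#include <cub/device/device_scan.cuh>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    const int n = 8;
    std::vector<float> h_in(n, 1.0f), h_out(n);

    float * d_in  = nullptr;
    float * d_out = nullptr;
    cudaMalloc(&d_in,  n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

    // Phase 1: null temp-storage pointer -> CUB only sets tmp_size, no work is done.
    size_t tmp_size = 0;
    cub::DeviceScan::InclusiveSum(nullptr, tmp_size, d_in, d_out, n);

    // Phase 2: allocate the scratch space and run the scan for real.
    void * d_tmp = nullptr;
    cudaMalloc(&d_tmp, tmp_size);
    cub::DeviceScan::InclusiveSum(d_tmp, tmp_size, d_in, d_out, n);

    cudaMemcpy(h_out.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) {
        printf("%g ", h_out[i]); // expected: 1 2 3 4 5 6 7 8
    }
    printf("\n");

    cudaFree(d_tmp);
    cudaFree(d_out);
    cudaFree(d_in);
    return 0;
}

CUB device-wide algorithms never allocate device memory themselves, which is why the query/allocate/run split exists; in the patch this is what lets the scratch buffer come from ggml's stream-ordered pool (ggml_cuda_pool_alloc) rather than a raw cudaMalloc.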