From afaf17d76736409f01f1e2a7e99b5e5d535b86c2 Mon Sep 17 00:00:00 2001
From: shaoqi
Date: Fri, 9 Jan 2026 16:44:11 -0800
Subject: [PATCH 1/3] OpenCL: add CUMSUM op support

---
 ggml/src/ggml-opencl/CMakeLists.txt    |  1 +
 ggml/src/ggml-opencl/ggml-opencl.cpp   | 83 ++++++++++++++++++++++++++
 ggml/src/ggml-opencl/kernels/cumsum.cl | 56 +++++++++++++++++
 3 files changed, 140 insertions(+)
 create mode 100644 ggml/src/ggml-opencl/kernels/cumsum.cl

diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt
index fa5fadd112..bb2e2ecd5c 100644
--- a/ggml/src/ggml-opencl/CMakeLists.txt
+++ b/ggml/src/ggml-opencl/CMakeLists.txt
@@ -121,6 +121,7 @@ set(GGML_OPENCL_KERNELS
     ssm_conv
     sub
     sum_rows
+    cumsum
     transpose
     concat
     tsembd
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 508b2b8f03..b3da7e3f30 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -540,6 +540,8 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
     cl_kernel kernel_argsort_f32_i32;
     cl_kernel kernel_sum_rows_f32;
+    cl_kernel kernel_cumsum_f32;
+    cl_kernel kernel_repeat;
     cl_kernel kernel_repeat_f32;
     cl_kernel kernel_pad;
     cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
@@ -1768,6 +1770,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         GGML_LOG_CONT(".");
     }
 
+    // cumsum
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "cumsum.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("cumsum.cl");
+#endif
+        cl_program prog;
+        prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_cumsum_f32 = clCreateKernel(prog, "kernel_cumsum_f32", &err), err));
+        GGML_LOG_CONT(".");
+        CL_CHECK(clReleaseProgram(prog));
+    }
+
     // sigmoid
     {
 #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -3422,6 +3441,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
             return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
         }
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_CUMSUM:
+            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
         case GGML_OP_MEAN:
             return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
         case GGML_OP_FLASH_ATTN_EXT:
@@ -10619,6 +10640,62 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }
 
+static void ggml_cl_cumsum(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_UNUSED(src1);
+
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int ne0 = src0->ne[0];
+    const int ne1 = src0->ne[1];
+    const int ne2 = src0->ne[2];
+    const int ne3 = src0->ne[3];
+
+    const int axis = ggml_get_op_params_i32(dst, 0);
+    const int exclusive = ggml_get_op_params_i32(dst, 1);
+    const int reverse = ggml_get_op_params_i32(dst, 2);
+
+    size_t lines = 1;
+    if (axis != 0) lines *= ne0;
+    if (axis != 1) lines *= ne1;
+    if (axis != 2) lines *= ne2;
+    if (axis != 3) lines *= ne3;
+
+    cl_kernel kernel = backend_ctx->kernel_cumsum_f32;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne1));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne2));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne3));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &axis));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &exclusive));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &reverse));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &lines));
+
+    size_t global_work_size[1] = { (size_t)lines };
+    size_t local_work_val = 256;
+    if ((size_t)lines < local_work_val) local_work_val = (size_t)lines;
+    size_t local_work_size[1] = { local_work_val };
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
+}
+
 static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0);
     GGML_ASSERT(src0->extra);
@@ -11031,6 +11108,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
             }
             func = ggml_cl_sum_rows;
             break;
+        case GGML_OP_CUMSUM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_cumsum;
+            break;
         case GGML_OP_FLASH_ATTN_EXT:
             if (!any_on_device) {
                 return false;
diff --git a/ggml/src/ggml-opencl/kernels/cumsum.cl b/ggml/src/ggml-opencl/kernels/cumsum.cl
new file mode 100644
index 0000000000..ecfeabd393
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/cumsum.cl
@@ -0,0 +1,56 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+//------------------------------------------------------------------------------
+// cumsum
+//------------------------------------------------------------------------------
+kernel void kernel_cumsum_f32(
+    global float * src0,
+    ulong offset0,
+    global float * dst,
+    ulong offsetd,
+    int ne0,
+    int ne1,
+    int ne2,
+    int ne3,
+    int axis,
+    int exclusive,
+    int reverse,
+    int lines
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    const int gid = get_global_id(0);
+
+    int i0 = 0, i1 = 0, i2 = 0, i3 = 0;
+    int t = gid;
+
+    if (axis != 3) { i3 = t % ne3; t /= ne3; }
+    if (axis != 2) { i2 = t % ne2; t /= ne2; }
+    if (axis != 1) { i1 = t % ne1; t /= ne1; }
+    if (axis != 0) { i0 = t % ne0; t /= ne0; }
+
+    const int axis_len = (axis == 0 ? ne0 : axis == 1 ? ne1 : axis == 2 ? ne2 : ne3);
+
+    float acc = 0.0f;
+
+    for (int pos = 0; pos < axis_len; pos++) {
+        const int a = reverse ? (axis_len - 1 - pos) : pos;
+
+        int j0 = i0, j1 = i1, j2 = i2, j3 = i3;
+        if (axis == 0) j0 = a;
+        else if (axis == 1) j1 = a;
+        else if (axis == 2) j2 = a;
+        else j3 = a;
+
+        int idx = j0 + ne0 * (j1 + ne1 * (j2 + ne2 * j3));
+
+        if (exclusive) {
+            dst[idx] = acc;
+            acc += src0[idx];
+        } else {
+            acc += src0[idx];
+            dst[idx] = acc;
+        }
+    }
+}

From c4b57de54c46192b52c92b012f6b1491443466d1 Mon Sep 17 00:00:00 2001
From: shaoqi
Date: Wed, 21 Jan 2026 11:08:34 -0800
Subject: [PATCH 2/3] remove unused argument

---
 ggml/src/ggml-opencl/ggml-opencl.cpp   | 1 -
 ggml/src/ggml-opencl/kernels/cumsum.cl | 3 +--
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index b3da7e3f30..f9e3f46195 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -10686,7 +10686,6 @@ static void ggml_cl_cumsum(ggml_backend_t backend, const ggml_tensor * src0, con
     CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &axis));
     CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &exclusive));
     CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &reverse));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &lines));
 
     size_t global_work_size[1] = { (size_t)lines };
     size_t local_work_val = 256;
diff --git a/ggml/src/ggml-opencl/kernels/cumsum.cl b/ggml/src/ggml-opencl/kernels/cumsum.cl
index ecfeabd393..7f6d539850 100644
--- a/ggml/src/ggml-opencl/kernels/cumsum.cl
+++ b/ggml/src/ggml-opencl/kernels/cumsum.cl
@@ -14,8 +14,7 @@ kernel void kernel_cumsum_f32(
     int ne3,
     int axis,
     int exclusive,
-    int reverse,
-    int lines
+    int reverse
 ) {
     src0 = (global float*)((global char*)src0 + offset0);
     dst = (global float*)((global char*)dst + offsetd);

From dd52e3fd0c28fa6e08a6cdbb75b8062b007bfa86 Mon Sep 17 00:00:00 2001
From: shaoqi
Date: Fri, 6 Feb 2026 14:31:22 -0800
Subject: [PATCH 3/3] opencl: refactor cumsum

---
 ggml/src/ggml-opencl/ggml-opencl.cpp   | 140 ++++++++++++++++++-----
 ggml/src/ggml-opencl/kernels/cumsum.cl | 151 +++++++++++++++++++------
 2 files changed, 225 insertions(+), 66 deletions(-)

diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index f9e3f46195..5ddfc4a95c 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -540,7 +540,8 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
     cl_kernel kernel_argsort_f32_i32;
     cl_kernel kernel_sum_rows_f32;
-    cl_kernel kernel_cumsum_f32;
+    cl_kernel kernel_cumsum_blk;
+    cl_kernel kernel_cumsum_add;
     cl_kernel kernel_repeat;
     cl_kernel kernel_repeat_f32;
     cl_kernel kernel_pad;
@@ -1782,7 +1783,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         cl_program prog;
         prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_cumsum_f32 = clCreateKernel(prog, "kernel_cumsum_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_cumsum_blk = clCreateKernel(prog, "kernel_cumsum_blk", &err), err));
+        CL_CHECK((backend_ctx->kernel_cumsum_add = clCreateKernel(prog, "kernel_cumsum_add", &err), err));
         GGML_LOG_CONT(".");
         CL_CHECK(clReleaseProgram(prog));
     }
@@ -10658,41 +10660,121 @@ static void ggml_cl_cumsum(ggml_backend_t backend, const ggml_tensor * src0, con
     cl_ulong offset0 = extra0->offset + src0->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;
 
-    const int ne0 = src0->ne[0];
-    const int ne1 = src0->ne[1];
-    const int ne2 = src0->ne[2];
-    const int ne3 = src0->ne[3];
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
 
-    const int axis = ggml_get_op_params_i32(dst, 0);
-    const int exclusive = ggml_get_op_params_i32(dst, 1);
-    const int reverse = ggml_get_op_params_i32(dst, 2);
+    // Inclusive prefix sum along dim 0, computed in up to three passes:
+    //   1. kernel_cumsum_blk scans every block of nth elements of a row and
+    //      stores the total of each block in tmp
+    //   2. kernel_cumsum_blk scans the block totals in tmp in place
+    //   3. kernel_cumsum_add adds the scanned total of the preceding blocks
+    //      to the elements of every block but the first
+    // Passes 2 and 3 are skipped when a row fits into a single block.
+    cl_kernel kernel = backend_ctx->kernel_cumsum_blk;
 
-    size_t lines = 1;
-    if (axis != 0) lines *= ne0;
-    if (axis != 1) lines *= ne1;
-    if (axis != 2) lines *= ne2;
-    if (axis != 3) lines *= ne3;
+    // nth is a power of two: it grows until it covers ne00 or reaches the
+    // largest power of two allowed by the kernel work-group size
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= max_workgroup_size) {
+        nth *= 2;
+    }
 
-    cl_kernel kernel = backend_ctx->kernel_cumsum_f32;
+    // pass 2 scans the block totals of a row with a single work-group,
+    // so a row may span at most nth blocks
+    GGML_ASSERT(ne00 <= nth*nth);
+
+    // tmp holds net0 block totals for each of the ne01*ne02*ne03 rows
+    const int net0 = (ne00 + nth - 1) / nth;
+    const int net1 = ne01;
+    const int net2 = ne02;
+
+    const cl_ulong nbt0 = sizeof(float);
+    const cl_ulong nbt1 = net0*nbt0;
+    const cl_ulong nbt2 = net1*nbt1;
+    const cl_ulong nbt3 = net2*nbt2;
+
+    // nbt3 is 64-bit, so the buffer size cannot overflow int arithmetic
+    const size_t tmp_size = (size_t)(nbt3*ne03);
+
+    cl_int status;
+    cl_mem tmp = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, tmp_size, NULL, &status);
+    CL_CHECK(status);
 
+    // pass 1: scan the blocks of src0 into dst
     CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
-    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
-    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne0));
-    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne1));
-    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne2));
-    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne3));
-    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &axis));
-    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &exclusive));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &reverse));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2));
 
-    size_t global_work_size[1] = { (size_t)lines };
-    size_t local_work_val = 256;
-    if ((size_t)lines < local_work_val) local_work_val = (size_t)lines;
-    size_t local_work_size[1] = { local_work_val };
+    size_t global_work_size[] = { (size_t)nth * net0 * ne01, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = { (size_t)nth, 1, 1};
 
-    backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst);
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+
+    if(ne00 > nth){
+        // pass 2: scan the block totals in tmp in place
+        cl_ulong offsett = 0;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsett));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tmp));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsett));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &net0));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01));
+        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02));
+        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03));
+        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nbt0));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nbt1));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nbt2));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nbt3));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0));
+        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1));
+        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2));
+
+        // one work-group per row is enough here since net0 <= nth
+        size_t global_work_size_tmp[] = { (size_t)nth * ne01, (size_t)ne02, (size_t)ne03};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_tmp, local_work_size, dst);
+
+        // pass 3: add the totals of the preceding blocks to dst
+        kernel = backend_ctx->kernel_cumsum_add;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne03));
+        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nbt0));
+        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nbt1));
+        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nbt2));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nbt3));
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+    }
+    // releasing here is fine: OpenCL deletes the buffer only after the enqueued commands using it have finished
+    CL_CHECK(clReleaseMemObject(tmp));
 }
 
 static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
diff --git a/ggml/src/ggml-opencl/kernels/cumsum.cl b/ggml/src/ggml-opencl/kernels/cumsum.cl
index 7f6d539850..3416b96d8b 100644
--- a/ggml/src/ggml-opencl/kernels/cumsum.cl
+++ b/ggml/src/ggml-opencl/kernels/cumsum.cl
@@ -1,55 +1,132 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
-
 //------------------------------------------------------------------------------
 // cumsum
 //------------------------------------------------------------------------------
-kernel void kernel_cumsum_f32(
-    global float * src0,
+// Capacity of the local array that holds one total per sub-group.
+// NOTE(review): kernel_cumsum_blk assumes that get_num_sub_groups() exceeds
+// neither MAX_SUBGROUPS nor the sub-group size - confirm for all target devices.
+#define MAX_SUBGROUPS 16
+
+// Inclusive scan of one block of get_local_size(0) elements of a src0 row.
+// Work-group id 0 encodes the block and the row: ib = id / ne01 is the block
+// index and i01 = id % ne01 is the row. The scanned block is written to dst,
+// which is addressed as contiguous f32. If the row has more than one block,
+// the block total is also written to tmp (net0 floats per row).
+kernel void kernel_cumsum_blk(
+    global char * src0,
     ulong offset0,
-    global float * dst,
+    global char * tmp,
+    global char * dst,
     ulong offsetd,
-    int ne0,
-    int ne1,
-    int ne2,
-    int ne3,
-    int axis,
-    int exclusive,
-    int reverse
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne03,
+    ulong nb00,
+    ulong nb01,
+    ulong nb02,
+    ulong nb03,
+    uint net0,
+    uint net1,
+    uint net2
 ) {
-    src0 = (global float*)((global char*)src0 + offset0);
-    dst = (global float*)((global char*)dst + offsetd);
+    src0 = src0 + offset0;
+    dst = dst + offsetd;
 
-    const int gid = get_global_id(0);
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
 
-    int i0 = 0, i1 = 0, i2 = 0, i3 = 0;
-    int t = gid;
+    const int nth = get_local_size(0);
+    const int tid = get_local_id(0);
 
-    if (axis != 3) { i3 = t % ne3; t /= ne3; }
-    if (axis != 2) { i2 = t % ne2; t /= ne2; }
-    if (axis != 1) { i1 = t % ne1; t /= ne1; }
-    if (axis != 0) { i0 = t % ne0; t /= ne0; }
+    const uint sg_size = get_sub_group_size();
+    const uint sg_id = get_sub_group_id();
+    const uint sg_lid = get_sub_group_local_id();
 
-    const int axis_len = (axis == 0 ? ne0 : axis == 1 ? ne1 : axis == 2 ? ne2 : ne3);
+    const int ib = i1 / ne01;
+    const int i00 = ib * nth;
+    const int i01 = i1 % ne01;
+    const int i02 = i2;
+    const int i03 = i3;
 
-    float acc = 0.0f;
+    global const float * src0_row = (global const float *)(src0 + i03*nb03 + i02*nb02 + i01*nb01);
+    global float * tmp_row = (global float *)tmp + net0 * i01 + net0 * net1 * i02 + net0 * net1 * net2 * i03;
+    global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
-    for (int pos = 0; pos < axis_len; pos++) {
-        const int a = reverse ? (axis_len - 1 - pos) : pos;
+    __local float partial[MAX_SUBGROUPS];
 
-        int j0 = i0, j1 = i1, j2 = i2, j3 = i3;
-        if (axis == 0) j0 = a;
-        else if (axis == 1) j1 = a;
-        else if (axis == 2) j2 = a;
-        else j3 = a;
+    // elements past the end of the row contribute zero
+    float v = 0.0f;
+    if(i00 + tid < ne00){
+        v = src0_row[i00 + tid];
+    }
 
-        int idx = j0 + ne0 * (j1 + ne1 * (j2 + ne2 * j3));
+    // scan within the sub-group; its last work-item publishes the sub-group total
+    float s = sub_group_scan_inclusive_add(v);
+    if(sg_lid == sg_size - 1){
+        partial[sg_id] = s;
+    }
+    barrier(CLK_LOCAL_MEM_FENCE);
 
-        if (exclusive) {
-            dst[idx] = acc;
-            acc += src0[idx];
-        } else {
-            acc += src0[idx];
-            dst[idx] = acc;
-        }
+    // sub-group 0 replaces the sub-group totals by their exclusive prefix sums
+    if(sg_id == 0){
+        float x = 0.0f;
+        if(sg_lid < get_num_sub_groups()) x = partial[sg_lid];
+        float ex = sub_group_scan_exclusive_add(x);
+        if(sg_lid < get_num_sub_groups()) partial[sg_lid] = ex;
+    }
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    s += partial[sg_id];
+
+    if(i00 + tid < ne00){
+        dst_row[i00 + tid] = s;
+    }
+    // the last work-item of the work-group holds the total of the whole block
+    if(ne00 > nth && tid == nth - 1){
+        tmp_row[ib] = s;
+    }
+}
+
+// Adds entry ib - 1 of the matching tmp row to every element of block ib > 0
+// of a dst row, using the work-group layout of kernel_cumsum_blk.
+// nbt1..nbt3 are the byte strides of tmp for dims 1..3; nbt0 is unused.
+kernel void kernel_cumsum_add(
+    global char * tmp,
+    global char * dst,
+    ulong offsetd,
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne03,
+    ulong nbt0,
+    ulong nbt1,
+    ulong nbt2,
+    ulong nbt3
+) {
+    dst = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int nth = get_local_size(0);
+    const int tid = get_local_id(0);
+
+    const int ib = i1 / ne01;
+    if(ib == 0){
+        return;
+    }
+    const int i00 = ib * nth;
+    const int i01 = i1 % ne01;
+    const int i02 = i2;
+    const int i03 = i3;
+
+    global float * tmp_row = (global float *)(tmp + nbt1 * i01 + nbt2 * i02 + nbt3 * i03);
+    global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    if(i00 + tid < ne00){
+        dst_row[i00 + tid] += tmp_row[ib - 1];
     }
 }