diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index fa5fadd112..bb2e2ecd5c 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -121,6 +121,7 @@ set(GGML_OPENCL_KERNELS ssm_conv sub sum_rows + cumsum transpose concat tsembd diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 508b2b8f03..5ddfc4a95c 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -540,6 +540,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; cl_kernel kernel_sum_rows_f32; + cl_kernel kernel_cumsum_blk; + cl_kernel kernel_cumsum_add; + cl_kernel kernel_repeat; cl_kernel kernel_repeat_f32; cl_kernel kernel_pad; cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; @@ -1768,6 +1771,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // cumsum + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cumsum.cl.h" + }; +#else + const std::string kernel_src = read_file("cumsum.cl"); +#endif + cl_program prog; + prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_cumsum_blk = clCreateKernel(prog, "kernel_cumsum_blk", &err), err)); + CL_CHECK((backend_ctx->kernel_cumsum_add = clCreateKernel(prog, "kernel_cumsum_add", &err), err)); + GGML_LOG_CONT("."); + CL_CHECK(clReleaseProgram(prog)); + } + // sigmoid { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3422,6 +3443,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32; } case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_MEAN: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_FLASH_ATTN_EXT: @@ -10619,6 +10642,119 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_cumsum(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + cl_kernel kernel = backend_ctx->kernel_cumsum_blk; + + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + int nth = 1; + while (nth < ne00 && 2*nth <= max_workgroup_size) { + nth *= 2; + } + + GGML_ASSERT(ne00 <= nth*nth); + + const int net0 = (ne00 + nth - 1) / nth; + const int net1 = ne01; + const int net2 = ne02; + + const cl_ulong nbt0 = sizeof(float); + const cl_ulong nbt1 = net0*nbt0; + const cl_ulong nbt2 = net1*nbt1; + const cl_ulong nbt3 = net2*nbt2; + + cl_int status; + cl_mem tmp = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, net0 * ne01 * ne02 * ne03 * sizeof(float), NULL, &status); + CL_CHECK(status); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2)); + + size_t global_work_size[] = { (size_t)(nth * net0 * ne01), (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = { (size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + if(ne00 > nth){ + cl_ulong offsett = 0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsett)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tmp)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsett)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &net0)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nbt0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nbt1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nbt2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nbt3)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + kernel = backend_ctx->kernel_cumsum_add; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &nbt0)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &nbt1)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &nbt2)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &nbt3)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } + CL_CHECK(clReleaseMemObject(tmp)); +} + static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -11031,6 +11167,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_sum_rows; break; + case GGML_OP_CUMSUM: + if (!any_on_device) { + return false; + } + func = ggml_cl_cumsum; + break; case GGML_OP_FLASH_ATTN_EXT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/cumsum.cl b/ggml/src/ggml-opencl/kernels/cumsum.cl new file mode 100644 index 0000000000..3416b96d8b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/cumsum.cl @@ -0,0 +1,116 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +//------------------------------------------------------------------------------ +// cumsum +//------------------------------------------------------------------------------ +#define MAX_SUBGROUPS 16 +kernel void kernel_cumsum_blk( + global char * src0, + ulong offset0, + global char * tmp, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + uint net0, + uint net1, + uint net2 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int nth = get_local_size(0); + const int tid = get_local_id(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + const int ib = i1 / ne01; + const int i00 = ib * nth; + const int i01 = i1 % ne01; + const int i02 = i2; + const int i03 = i3; + + global const float * src0_row = (global const float *)(src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * tmp_row = (global float *)tmp + net0 * i01 + net0 * net1 * i02 + net0 * net1 * net2 * i03; + global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + __local float partial[MAX_SUBGROUPS]; + + float v = 0.0f; + if(i00 + tid < ne00){ + v = src0_row[i00 + tid]; + } + + float s = sub_group_scan_inclusive_add(v); + if(sg_lid == sg_size - 1){ + partial[sg_id] = s; + } + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_id == 0){ + float x = 0.0f; + if(sg_lid < get_num_sub_groups()) x = partial[sg_lid]; + float ex = sub_group_scan_exclusive_add(x); + if(sg_lid < get_num_sub_groups()) partial[sg_lid] = ex; + } + barrier(CLK_LOCAL_MEM_FENCE); + + s += partial[sg_id]; + + if(i00 + tid < ne00){ + dst_row[i00 + tid] = s; + } + if(ne00 > nth && tid == nth - 1){ + tmp_row[ib] = s; + } +} + +kernel void kernel_cumsum_add( + global char * tmp, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + uint nbt0, + uint nbt1, + uint nbt2, + uint nbt3 +) { + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int nth = get_local_size(0); + const int tid = get_local_id(0); + + const int ib = i1 / ne01; + if(ib == 0){ + return; + } + const int i00 = ib * nth; + const int i01 = i1 % ne01; + const int i02 = i2; + const int i03 = i3; + + global float * tmp_row = (global float *)(tmp + nbt1 * i01 + nbt2 * i02 + nbt3 * i03); + global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + if(i00 + tid < ne00){ + dst_row[i00 + tid] += tmp_row[ib - 1]; + } +}