From b2a283d5e9c3e92619134e277f30e5a941aef58f Mon Sep 17 00:00:00 2001 From: chraac Date: Thu, 15 Jan 2026 22:29:08 +0800 Subject: [PATCH] feat: enhance OpenCL kernel division operations with new argument setters and invoker --- ggml/src/ggml-opencl/ggml-opencl.cpp | 100 +++++++++++++++++++++--- ggml/src/ggml-opencl/kernels/div.h | 46 +++++------ ggml/src/ggml-opencl/kernels/ocl_defs.h | 6 +- 3 files changed, 115 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 38d72cba6b..7284bcdfe2 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2693,6 +2693,14 @@ template <> struct cl_kernel_arg_setter { } }; +template <> struct cl_kernel_arg_setter>> { + static size_t set_arg(cl_kernel kernel, size_t index, cl_mem arg) { + CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg)); + return index + 1; + } +}; + + template <> struct cl_kernel_arg_setter { static size_t set_arg(cl_kernel kernel, size_t index, cl_ulong arg) { CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg)); @@ -2763,6 +2771,23 @@ template static inline size_t cl_set_kernel_args(cl_kernel return index; } +template +struct cl_kernel_invoker {}; + +template +struct cl_kernel_invoker { + static void invoke(cl_kernel kernel, _TArgs... args) { + size_t index = 0; + ( + [&] { + index = cl_kernel_arg_setter< + std::remove_const_t>>>::set_arg(kernel, index, + args); + }(), + ...); + } +}; + } // namespace // Additional tensor extra structs for quantized tensors. @@ -5142,21 +5167,72 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const } else { if (src0->type == GGML_TYPE_F32) { kernel = backend_ctx->kernel_div; + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + cl_kernel_invoker::invoke( + kernel, + extra0->data_device, + offset0, + extra1->data_device, + offset1, + extrad->data_device, + offsetd, + nb00, + nb01, + nb02, + nb03, + ne10, + ne11, + ne12, + ne13, + nb10, + nb11, + nb12, + nb13, + ne0, + nb0, + nb1, + nb2, + nb3 + ); } else { kernel = backend_ctx->kernel_div_f16; + cl_set_kernel_args( + kernel, + src0, + src1, + dst, + src0->nb, + src1->ne, + src1->nb, + ne0, + dst->nb + ); } - - cl_set_kernel_args( - kernel, - src0, - src1, - dst, - src0->nb, - src1->ne, - src1->nb, - ne0, - dst->nb - ); } if (bcast_row) { diff --git a/ggml/src/ggml-opencl/kernels/div.h b/ggml/src/ggml-opencl/kernels/div.h index fd65b5706f..f09d88cc19 100644 --- a/ggml/src/ggml-opencl/kernels/div.h +++ b/ggml/src/ggml-opencl/kernels/div.h @@ -4,28 +4,28 @@ #include "ocl_defs.h" -OCL_KERNEL void kernel_div(OCL_GLOBAL char * src0, - ulong offset0, - OCL_GLOBAL char * src1, - ulong offset1, - OCL_GLOBAL char * dst, - ulong offsetd, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3); +OCL_KERNEL void kernel_div(ocl_global_char_ptr src0, + ulong offset0, + ocl_global_char_ptr src1, + ulong offset1, + ocl_global_char_ptr dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3); #endif // __KERNELS_DIV_H__ diff --git a/ggml/src/ggml-opencl/kernels/ocl_defs.h b/ggml/src/ggml-opencl/kernels/ocl_defs.h index 3faf1ecf56..b6743890b3 100644 --- a/ggml/src/ggml-opencl/kernels/ocl_defs.h +++ b/ggml/src/ggml-opencl/kernels/ocl_defs.h @@ -4,12 +4,14 @@ #ifdef __OPENCL_C_VERSION__ // Device (OpenCL) Definitions -# define OCL_KERNEL kernel -# define OCL_GLOBAL global +# define OCL_KERNEL kernel +# define OCL_GLOBAL global +# define ocl_global_char_ptr global char * #else // Host (C++) Definitions # define OCL_KERNEL # define OCL_GLOBAL +# define ocl_global_char_ptr cl_mem # define __kernel # define __global # define ulong cl_ulong