From 208f8454cd332c5f4001798cbf61f3c2e209e5a9 Mon Sep 17 00:00:00 2001 From: chraac Date: Fri, 16 Jan 2026 22:44:32 +0800 Subject: [PATCH] feat: add function type definitions for OpenCL kernel argument setters --- ggml/src/ggml-opencl/ggml-opencl.cpp | 106 +++++++++++++-------------- 1 file changed, 51 insertions(+), 55 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 7284bcdfe2..61386c3918 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2687,6 +2687,8 @@ namespace /* anonymous */ { template struct cl_kernel_arg_setter {}; template <> struct cl_kernel_arg_setter { + typedef void func_t(int); + static size_t set_arg(cl_kernel kernel, size_t index, int arg) { CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg)); return index + 1; @@ -2694,6 +2696,8 @@ template <> struct cl_kernel_arg_setter { }; template <> struct cl_kernel_arg_setter>> { + typedef void func_t(cl_mem); + static size_t set_arg(cl_kernel kernel, size_t index, cl_mem arg) { CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg)); return index + 1; @@ -2702,6 +2706,8 @@ template <> struct cl_kernel_arg_setter struct cl_kernel_arg_setter { + typedef void func_t(cl_ulong); + static size_t set_arg(cl_kernel kernel, size_t index, cl_ulong arg) { CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg)); return index + 1; @@ -2709,6 +2715,8 @@ template <> struct cl_kernel_arg_setter { }; template <> struct cl_kernel_arg_setter { + typedef void func_t(float); + static size_t set_arg(cl_kernel kernel, size_t index, float arg) { CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg)); return index + 1; @@ -2716,6 +2724,8 @@ template <> struct cl_kernel_arg_setter { }; template <> struct cl_kernel_arg_setter { + typedef void func_t(cl_mem, cl_ulong); + static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) { ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra; static_assert(std::is_same_vdata_device), cl_mem>, "data_device type mismatch"); @@ -2728,6 +2738,8 @@ template <> struct cl_kernel_arg_setter { }; template <> struct cl_kernel_arg_setter { + typedef void func_t(int, int, int, int); + static size_t set_arg(cl_kernel kernel, size_t index, const int64_t (&ne)[GGML_MAX_DIMS]) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly"); @@ -2744,6 +2756,8 @@ template <> struct cl_kernel_arg_setter { }; template <> struct cl_kernel_arg_setter { + typedef void func_t(cl_ulong, cl_ulong, cl_ulong, cl_ulong); + static size_t set_arg(cl_kernel kernel, size_t index, const size_t (&nb)[GGML_MAX_DIMS]) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly"); @@ -2771,18 +2785,41 @@ template static inline size_t cl_set_kernel_args(cl_kernel return index; } -template -struct cl_kernel_invoker {}; +template struct type_merger {}; + +template +struct type_merger { + using func_t = typename type_merger::func_t; +}; + +template struct type_merger { + using func_t = void(_TInnerArgs...); +}; + +template struct cl_param_type_extractor { + using args_t = std::remove_const_t>>; + using first_func_t = + typename cl_kernel_arg_setter::func_t; + using func_t = typename type_merger::func_t>::func_t; +}; + +template struct cl_param_type_extractor<_TFinalArg> { + using args_t = std::remove_const_t>>; + using func_t = typename cl_kernel_arg_setter::func_t; +}; + +template struct cl_kernel_invoker { + template static void invoke(cl_kernel kernel, _TCalledArgs &&... args) { + static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>, + "Kernel argument type mismatch between prototype and called arguments"); -template -struct cl_kernel_invoker { - static void invoke(cl_kernel kernel, _TArgs... args) { size_t index = 0; ( [&] { index = cl_kernel_arg_setter< - std::remove_const_t>>>::set_arg(kernel, index, - args); + std::remove_const_t>>>::set_arg(kernel, + index, + args); }(), ...); } @@ -5167,57 +5204,16 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const } else { if (src0->type == GGML_TYPE_F32) { kernel = backend_ctx->kernel_div; - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - - cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offset1 = extra1->offset + src1->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; - - const cl_ulong nb00 = src0->nb[0]; - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; - - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; - - const cl_ulong nb10 = src1->nb[0]; - const cl_ulong nb11 = src1->nb[1]; - const cl_ulong nb12 = src1->nb[2]; - const cl_ulong nb13 = src1->nb[3]; - - const cl_ulong nb0 = dst->nb[0]; - const cl_ulong nb1 = dst->nb[1]; - const cl_ulong nb2 = dst->nb[2]; - const cl_ulong nb3 = dst->nb[3]; - cl_kernel_invoker::invoke( kernel, - extra0->data_device, - offset0, - extra1->data_device, - offset1, - extrad->data_device, - offsetd, - nb00, - nb01, - nb02, - nb03, - ne10, - ne11, - ne12, - ne13, - nb10, - nb11, - nb12, - nb13, + src0, + src1, + dst, + src0->nb, + src1->ne, + src1->nb, ne0, - nb0, - nb1, - nb2, - nb3 + dst->nb ); } else { kernel = backend_ctx->kernel_div_f16;