diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0ccdb4c328..6ffcc4bef1 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2714,7 +2714,7 @@ template <> struct cl_kernel_arg_setter { }; template <> struct cl_kernel_arg_setter { - typedef void func_t(cl_mem, cl_ulong); + typedef void func_t(char *, cl_ulong); static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) { ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra; @@ -2799,22 +2799,17 @@ template struct cl_param_type_extractor<_TFinalArg> { }; template struct cl_kernel_invoker { - template static void invoke(cl_kernel kernel, _TCalledArgs &&... args) { + template static size_t invoke(cl_kernel kernel, _TCalledArgs &&... args) { static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>, "Kernel argument type mismatch between prototype and called arguments"); - - size_t index = 0; - ( - [&] { - index = cl_kernel_arg_setter< - std::remove_const_t>>>::set_arg(kernel, - index, - args); - }(), - ...); + return cl_set_kernel_args(kernel, args...); } }; +template static inline size_t cl_set_kernel_args_safe(cl_kernel kernel, _TArgs &&... args) { + return cl_kernel_invoker<_TFunc>::invoke(kernel, args...); +} + } // namespace // Additional tensor extra structs for quantized tensors. @@ -5194,7 +5189,7 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const } else { if (src0->type == GGML_TYPE_F32) { kernel = backend_ctx->kernel_div; - cl_kernel_invoker::invoke( + cl_set_kernel_args_safe( kernel, src0, src1, diff --git a/ggml/src/ggml-opencl/kernels/div.h b/ggml/src/ggml-opencl/kernels/div.h index f09d88cc19..fd65b5706f 100644 --- a/ggml/src/ggml-opencl/kernels/div.h +++ b/ggml/src/ggml-opencl/kernels/div.h @@ -4,28 +4,28 @@ #include "ocl_defs.h" -OCL_KERNEL void kernel_div(ocl_global_char_ptr src0, - ulong offset0, - ocl_global_char_ptr src1, - ulong offset1, - ocl_global_char_ptr dst, - ulong offsetd, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3); +OCL_KERNEL void kernel_div(OCL_GLOBAL char * src0, + ulong offset0, + OCL_GLOBAL char * src1, + ulong offset1, + OCL_GLOBAL char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3); #endif // __KERNELS_DIV_H__ diff --git a/ggml/src/ggml-opencl/kernels/ocl_defs.h b/ggml/src/ggml-opencl/kernels/ocl_defs.h index b6743890b3..edf32ad0a7 100644 --- a/ggml/src/ggml-opencl/kernels/ocl_defs.h +++ b/ggml/src/ggml-opencl/kernels/ocl_defs.h @@ -6,12 +6,10 @@ // Device (OpenCL) Definitions # define OCL_KERNEL kernel # define OCL_GLOBAL global -# define ocl_global_char_ptr global char * #else // Host (C++) Definitions # define OCL_KERNEL # define OCL_GLOBAL -# define ocl_global_char_ptr cl_mem # define __kernel # define __global # define ulong cl_ulong