diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 6ffcc4bef1..f8034fa175 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2775,22 +2775,22 @@ template static inline size_t cl_set_kernel_args(cl_kernel return index; } -template struct type_merger {}; +template struct cl_func_args_concatenator {}; template -struct type_merger { - using func_t = typename type_merger::func_t; +struct cl_func_args_concatenator { + using func_t = typename cl_func_args_concatenator::func_t; }; -template struct type_merger { +template struct cl_func_args_concatenator { using func_t = void(_TInnerArgs...); }; template struct cl_param_type_extractor { - using args_t = std::remove_const_t>>; - using first_func_t = - typename cl_kernel_arg_setter::func_t; - using func_t = typename type_merger::func_t>::func_t; + using args_t = std::remove_const_t>>; + using first_func_t = typename cl_kernel_arg_setter::func_t; + using func_t = typename cl_func_args_concatenator::func_t>::func_t; }; template struct cl_param_type_extractor<_TFinalArg> { @@ -2798,16 +2798,11 @@ template struct cl_param_type_extractor<_TFinalArg> { using func_t = typename cl_kernel_arg_setter::func_t; }; -template struct cl_kernel_invoker { - template static size_t invoke(cl_kernel kernel, _TCalledArgs &&... args) { - static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>, - "Kernel argument type mismatch between prototype and called arguments"); - return cl_set_kernel_args(kernel, args...); - } -}; - -template static inline size_t cl_set_kernel_args_safe(cl_kernel kernel, _TArgs &&... args) { - return cl_kernel_invoker<_TFunc>::invoke(kernel, args...); +template +static inline size_t cl_set_kernel_args_safe(cl_kernel kernel, _TArgs &&... args) { + static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TArgs...>::func_t>, + "Kernel argument type mismatch between prototype and called arguments"); + return cl_set_kernel_args(kernel, args...); } } // namespace