feat: update OpenCL kernel argument types and invoker for division operations

This commit is contained in:
chraac 2026-01-16 23:11:16 +08:00
parent 61093a4159
commit 08dbd97356
3 changed files with 31 additions and 38 deletions

View File

@ -2714,7 +2714,7 @@ template <> struct cl_kernel_arg_setter<float> {
};
template <> struct cl_kernel_arg_setter<ggml_tensor> {
typedef void func_t(cl_mem, cl_ulong);
typedef void func_t(char *, cl_ulong);
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
@ -2799,22 +2799,17 @@ template <typename _TFinalArg> struct cl_param_type_extractor<_TFinalArg> {
};
// Compile-time checked front end for setting kernel arguments. _TFunc is the
// expected kernel prototype (a function type); invoke() only compiles when the
// function type that cl_param_type_extractor derives from the call's argument
// pack is exactly _TFunc.
// NOTE(review): this span looks like a rendered diff with removed and added
// lines interleaved (two `invoke` declarations; an inline fold-expression body
// followed by a cl_set_kernel_args call) — confirm against the real file which
// lines are live before editing the code itself.
template <typename _TFunc> struct cl_kernel_invoker {
// Variant declared with a `void` return type.
template <typename... _TCalledArgs> static void invoke(cl_kernel kernel, _TCalledArgs &&... args) {
// Variant declared with a `size_t` return type; the `return` statement below belongs to this one.
template <typename... _TCalledArgs> static size_t invoke(cl_kernel kernel, _TCalledArgs &&... args) {
// Reject the call at compile time when the called argument types do not
// reproduce the prototype type exactly.
static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>,
"Kernel argument type mismatch between prototype and called arguments");
// Running index threaded through the per-argument setters below.
size_t index = 0;
// Comma fold over the argument pack: for each argument, pick the
// cl_kernel_arg_setter specialisation for its type with the reference, one
// pointer level and the const qualifier stripped (in that order), call
// set_arg(kernel, index, arg) and store the returned value back into `index`.
(
[&] {
index = cl_kernel_arg_setter<
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TCalledArgs>>>>::set_arg(kernel,
index,
args);
}(),
...);
// Hand the whole pack to cl_set_kernel_args and return its result.
return cl_set_kernel_args(kernel, args...);
}
};
// Type-checked wrapper for setting kernel arguments: passes `kernel` and the
// argument pack to cl_kernel_invoker<_TFunc>::invoke and returns whatever it
// returns. The arguments are passed on as named (lvalue) expressions, not
// perfectly forwarded.
template <typename _TFunc, typename... _TArgs>
static inline size_t cl_set_kernel_args_safe(cl_kernel kernel, _TArgs &&... args) {
    using invoker_t = cl_kernel_invoker<_TFunc>;
    const size_t result = invoker_t::invoke(kernel, args...);
    return result;
}
} // namespace
// Additional tensor extra structs for quantized tensors.
@ -5194,7 +5189,7 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_div;
cl_kernel_invoker<decltype(ocl_kernel_prototypes::kernel_div)>::invoke(
cl_set_kernel_args_safe<decltype(ocl_kernel_prototypes::kernel_div)>(
kernel,
src0,
src1,

View File

@ -4,28 +4,28 @@
#include "ocl_defs.h"
// Prototype of the `kernel_div` kernel, with buffer parameters spelled through
// the `ocl_global_char_ptr` macro. Parameters are grouped per line by prefix;
// order and types are unchanged.
// NOTE(review): the ne*/nb* meanings are not visible here — presumably element
// counts and byte strides per dimension; confirm against the kernel source.
OCL_KERNEL void kernel_div(
    ocl_global_char_ptr src0, ulong offset0,
    ocl_global_char_ptr src1, ulong offset1,
    ocl_global_char_ptr dst, ulong offsetd,
    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
    int ne10, int ne11, int ne12, int ne13,
    ulong nb10, ulong nb11, ulong nb12, ulong nb13,
    int ne0,
    ulong nb0, ulong nb1, ulong nb2, ulong nb3);
// Prototype of the `kernel_div` kernel, with buffer parameters spelled as
// `OCL_GLOBAL char *`. Parameters are grouped per line by prefix; order and
// types are unchanged.
// NOTE(review): the ne*/nb* meanings are not visible here — presumably element
// counts and byte strides per dimension; confirm against the kernel source.
OCL_KERNEL void kernel_div(
    OCL_GLOBAL char * src0, ulong offset0,
    OCL_GLOBAL char * src1, ulong offset1,
    OCL_GLOBAL char * dst, ulong offsetd,
    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
    int ne10, int ne11, int ne12, int ne13,
    ulong nb10, ulong nb11, ulong nb12, ulong nb13,
    int ne0,
    ulong nb0, ulong nb1, ulong nb2, ulong nb3);
#endif // __KERNELS_DIV_H__

View File

@ -6,12 +6,10 @@
// Device (OpenCL) Definitions
# define OCL_KERNEL kernel
# define OCL_GLOBAL global
# define ocl_global_char_ptr global char *
#else
// Host (C++) Definitions
# define OCL_KERNEL
# define OCL_GLOBAL
# define ocl_global_char_ptr cl_mem
# define __kernel
# define __global
# define ulong cl_ulong