feat: update OpenCL kernel argument types and invoker for division operations

This commit is contained in:
chraac 2026-01-16 23:11:16 +08:00
parent 61093a4159
commit 08dbd97356
3 changed files with 31 additions and 38 deletions

View File

@ -2714,7 +2714,7 @@ template <> struct cl_kernel_arg_setter<float> {
}; };
template <> struct cl_kernel_arg_setter<ggml_tensor> { template <> struct cl_kernel_arg_setter<ggml_tensor> {
typedef void func_t(cl_mem, cl_ulong); typedef void func_t(char *, cl_ulong);
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) { static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra; ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
@ -2799,22 +2799,17 @@ template <typename _TFinalArg> struct cl_param_type_extractor<_TFinalArg> {
}; };
template <typename _TFunc> struct cl_kernel_invoker { template <typename _TFunc> struct cl_kernel_invoker {
template <typename... _TCalledArgs> static void invoke(cl_kernel kernel, _TCalledArgs &&... args) { template <typename... _TCalledArgs> static size_t invoke(cl_kernel kernel, _TCalledArgs &&... args) {
static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>, static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>,
"Kernel argument type mismatch between prototype and called arguments"); "Kernel argument type mismatch between prototype and called arguments");
return cl_set_kernel_args(kernel, args...);
size_t index = 0;
(
[&] {
index = cl_kernel_arg_setter<
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TCalledArgs>>>>::set_arg(kernel,
index,
args);
}(),
...);
} }
}; };
template <typename _TFunc, typename... _TArgs> static inline size_t cl_set_kernel_args_safe(cl_kernel kernel, _TArgs &&... args) {
return cl_kernel_invoker<_TFunc>::invoke(kernel, args...);
}
} // namespace } // namespace
// Additional tensor extra structs for quantized tensors. // Additional tensor extra structs for quantized tensors.
@ -5194,7 +5189,7 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
} else { } else {
if (src0->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_div; kernel = backend_ctx->kernel_div;
cl_kernel_invoker<decltype(ocl_kernel_prototypes::kernel_div)>::invoke( cl_set_kernel_args_safe<decltype(ocl_kernel_prototypes::kernel_div)>(
kernel, kernel,
src0, src0,
src1, src1,

View File

@ -4,11 +4,11 @@
#include "ocl_defs.h" #include "ocl_defs.h"
OCL_KERNEL void kernel_div(ocl_global_char_ptr src0, OCL_KERNEL void kernel_div(OCL_GLOBAL char * src0,
ulong offset0, ulong offset0,
ocl_global_char_ptr src1, OCL_GLOBAL char * src1,
ulong offset1, ulong offset1,
ocl_global_char_ptr dst, OCL_GLOBAL char * dst,
ulong offsetd, ulong offsetd,
ulong nb00, ulong nb00,
ulong nb01, ulong nb01,

View File

@ -6,12 +6,10 @@
// Device (OpenCL) Definitions // Device (OpenCL) Definitions
# define OCL_KERNEL kernel # define OCL_KERNEL kernel
# define OCL_GLOBAL global # define OCL_GLOBAL global
# define ocl_global_char_ptr global char *
#else #else
// Host (C++) Definitions // Host (C++) Definitions
# define OCL_KERNEL # define OCL_KERNEL
# define OCL_GLOBAL # define OCL_GLOBAL
# define ocl_global_char_ptr cl_mem
# define __kernel # define __kernel
# define __global # define __global
# define ulong cl_ulong # define ulong cl_ulong