feat: update OpenCL kernel argument types and invoker for division operations
This commit is contained in:
parent
61093a4159
commit
08dbd97356
|
|
@ -2714,7 +2714,7 @@ template <> struct cl_kernel_arg_setter<float> {
|
|||
};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<ggml_tensor> {
|
||||
typedef void func_t(cl_mem, cl_ulong);
|
||||
typedef void func_t(char *, cl_ulong);
|
||||
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
|
||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
|
||||
|
|
@ -2799,22 +2799,17 @@ template <typename _TFinalArg> struct cl_param_type_extractor<_TFinalArg> {
|
|||
};
|
||||
|
||||
template <typename _TFunc> struct cl_kernel_invoker {
|
||||
template <typename... _TCalledArgs> static void invoke(cl_kernel kernel, _TCalledArgs &&... args) {
|
||||
template <typename... _TCalledArgs> static size_t invoke(cl_kernel kernel, _TCalledArgs &&... args) {
|
||||
static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>,
|
||||
"Kernel argument type mismatch between prototype and called arguments");
|
||||
|
||||
size_t index = 0;
|
||||
(
|
||||
[&] {
|
||||
index = cl_kernel_arg_setter<
|
||||
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TCalledArgs>>>>::set_arg(kernel,
|
||||
index,
|
||||
args);
|
||||
}(),
|
||||
...);
|
||||
return cl_set_kernel_args(kernel, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename _TFunc, typename... _TArgs> static inline size_t cl_set_kernel_args_safe(cl_kernel kernel, _TArgs &&... args) {
|
||||
return cl_kernel_invoker<_TFunc>::invoke(kernel, args...);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Additional tensor extra structs for quantized tensors.
|
||||
|
|
@ -5194,7 +5189,7 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_div;
|
||||
cl_kernel_invoker<decltype(ocl_kernel_prototypes::kernel_div)>::invoke(
|
||||
cl_set_kernel_args_safe<decltype(ocl_kernel_prototypes::kernel_div)>(
|
||||
kernel,
|
||||
src0,
|
||||
src1,
|
||||
|
|
|
|||
|
|
@ -4,28 +4,28 @@
|
|||
|
||||
#include "ocl_defs.h"
|
||||
|
||||
OCL_KERNEL void kernel_div(ocl_global_char_ptr src0,
|
||||
ulong offset0,
|
||||
ocl_global_char_ptr src1,
|
||||
ulong offset1,
|
||||
ocl_global_char_ptr dst,
|
||||
ulong offsetd,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3);
|
||||
OCL_KERNEL void kernel_div(OCL_GLOBAL char * src0,
|
||||
ulong offset0,
|
||||
OCL_GLOBAL char * src1,
|
||||
ulong offset1,
|
||||
OCL_GLOBAL char * dst,
|
||||
ulong offsetd,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3);
|
||||
|
||||
#endif // __KERNELS_DIV_H__
|
||||
|
|
|
|||
|
|
@ -6,12 +6,10 @@
|
|||
// Device (OpenCL) Definitions
|
||||
# define OCL_KERNEL kernel
|
||||
# define OCL_GLOBAL global
|
||||
# define ocl_global_char_ptr global char *
|
||||
#else
|
||||
// Host (C++) Definitions
|
||||
# define OCL_KERNEL
|
||||
# define OCL_GLOBAL
|
||||
# define ocl_global_char_ptr cl_mem
|
||||
# define __kernel
|
||||
# define __global
|
||||
# define ulong cl_ulong
|
||||
|
|
|
|||
Loading…
Reference in New Issue