feat: update OpenCL kernel argument types and invoker for division operations
This commit is contained in:
parent
61093a4159
commit
08dbd97356
|
|
@ -2714,7 +2714,7 @@ template <> struct cl_kernel_arg_setter<float> {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <> struct cl_kernel_arg_setter<ggml_tensor> {
|
template <> struct cl_kernel_arg_setter<ggml_tensor> {
|
||||||
typedef void func_t(cl_mem, cl_ulong);
|
typedef void func_t(char *, cl_ulong);
|
||||||
|
|
||||||
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
|
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
|
||||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
|
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
|
||||||
|
|
@ -2799,22 +2799,17 @@ template <typename _TFinalArg> struct cl_param_type_extractor<_TFinalArg> {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename _TFunc> struct cl_kernel_invoker {
|
template <typename _TFunc> struct cl_kernel_invoker {
|
||||||
template <typename... _TCalledArgs> static void invoke(cl_kernel kernel, _TCalledArgs &&... args) {
|
template <typename... _TCalledArgs> static size_t invoke(cl_kernel kernel, _TCalledArgs &&... args) {
|
||||||
static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>,
|
static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>,
|
||||||
"Kernel argument type mismatch between prototype and called arguments");
|
"Kernel argument type mismatch between prototype and called arguments");
|
||||||
|
return cl_set_kernel_args(kernel, args...);
|
||||||
size_t index = 0;
|
|
||||||
(
|
|
||||||
[&] {
|
|
||||||
index = cl_kernel_arg_setter<
|
|
||||||
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TCalledArgs>>>>::set_arg(kernel,
|
|
||||||
index,
|
|
||||||
args);
|
|
||||||
}(),
|
|
||||||
...);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename _TFunc, typename... _TArgs> static inline size_t cl_set_kernel_args_safe(cl_kernel kernel, _TArgs &&... args) {
|
||||||
|
return cl_kernel_invoker<_TFunc>::invoke(kernel, args...);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Additional tensor extra structs for quantized tensors.
|
// Additional tensor extra structs for quantized tensors.
|
||||||
|
|
@ -5194,7 +5189,7 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||||
} else {
|
} else {
|
||||||
if (src0->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
kernel = backend_ctx->kernel_div;
|
kernel = backend_ctx->kernel_div;
|
||||||
cl_kernel_invoker<decltype(ocl_kernel_prototypes::kernel_div)>::invoke(
|
cl_set_kernel_args_safe<decltype(ocl_kernel_prototypes::kernel_div)>(
|
||||||
kernel,
|
kernel,
|
||||||
src0,
|
src0,
|
||||||
src1,
|
src1,
|
||||||
|
|
|
||||||
|
|
@ -4,28 +4,28 @@
|
||||||
|
|
||||||
#include "ocl_defs.h"
|
#include "ocl_defs.h"
|
||||||
|
|
||||||
OCL_KERNEL void kernel_div(ocl_global_char_ptr src0,
|
OCL_KERNEL void kernel_div(OCL_GLOBAL char * src0,
|
||||||
ulong offset0,
|
ulong offset0,
|
||||||
ocl_global_char_ptr src1,
|
OCL_GLOBAL char * src1,
|
||||||
ulong offset1,
|
ulong offset1,
|
||||||
ocl_global_char_ptr dst,
|
OCL_GLOBAL char * dst,
|
||||||
ulong offsetd,
|
ulong offsetd,
|
||||||
ulong nb00,
|
ulong nb00,
|
||||||
ulong nb01,
|
ulong nb01,
|
||||||
ulong nb02,
|
ulong nb02,
|
||||||
ulong nb03,
|
ulong nb03,
|
||||||
int ne10,
|
int ne10,
|
||||||
int ne11,
|
int ne11,
|
||||||
int ne12,
|
int ne12,
|
||||||
int ne13,
|
int ne13,
|
||||||
ulong nb10,
|
ulong nb10,
|
||||||
ulong nb11,
|
ulong nb11,
|
||||||
ulong nb12,
|
ulong nb12,
|
||||||
ulong nb13,
|
ulong nb13,
|
||||||
int ne0,
|
int ne0,
|
||||||
ulong nb0,
|
ulong nb0,
|
||||||
ulong nb1,
|
ulong nb1,
|
||||||
ulong nb2,
|
ulong nb2,
|
||||||
ulong nb3);
|
ulong nb3);
|
||||||
|
|
||||||
#endif // __KERNELS_DIV_H__
|
#endif // __KERNELS_DIV_H__
|
||||||
|
|
|
||||||
|
|
@ -6,12 +6,10 @@
|
||||||
// Device (OpenCL) Definitions
|
// Device (OpenCL) Definitions
|
||||||
# define OCL_KERNEL kernel
|
# define OCL_KERNEL kernel
|
||||||
# define OCL_GLOBAL global
|
# define OCL_GLOBAL global
|
||||||
# define ocl_global_char_ptr global char *
|
|
||||||
#else
|
#else
|
||||||
// Host (C++) Definitions
|
// Host (C++) Definitions
|
||||||
# define OCL_KERNEL
|
# define OCL_KERNEL
|
||||||
# define OCL_GLOBAL
|
# define OCL_GLOBAL
|
||||||
# define ocl_global_char_ptr cl_mem
|
|
||||||
# define __kernel
|
# define __kernel
|
||||||
# define __global
|
# define __global
|
||||||
# define ulong cl_ulong
|
# define ulong cl_ulong
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue