feat: enhance OpenCL kernel division operations with new argument setters and invoker
This commit is contained in:
parent
5fbff1aa3a
commit
b2a283d5e9
|
|
@ -2693,6 +2693,14 @@ template <> struct cl_kernel_arg_setter<int> {
|
|||
}
|
||||
};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<std::remove_pointer_t<std::remove_reference_t<cl_mem>>> {
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, cl_mem arg) {
|
||||
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
|
||||
return index + 1;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <> struct cl_kernel_arg_setter<cl_ulong> {
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, cl_ulong arg) {
|
||||
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
|
||||
|
|
@ -2763,6 +2771,23 @@ template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel
|
|||
return index;
|
||||
}
|
||||
|
||||
template<typename _TFunc>
|
||||
struct cl_kernel_invoker {};
|
||||
|
||||
template<typename... _TArgs>
|
||||
struct cl_kernel_invoker<void(_TArgs...)> {
|
||||
static void invoke(cl_kernel kernel, _TArgs... args) {
|
||||
size_t index = 0;
|
||||
(
|
||||
[&] {
|
||||
index = cl_kernel_arg_setter<
|
||||
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index,
|
||||
args);
|
||||
}(),
|
||||
...);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Additional tensor extra structs for quantized tensors.
|
||||
|
|
@ -5142,21 +5167,72 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_div;
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const cl_ulong nb00 = src0->nb[0];
|
||||
const cl_ulong nb01 = src0->nb[1];
|
||||
const cl_ulong nb02 = src0->nb[2];
|
||||
const cl_ulong nb03 = src0->nb[3];
|
||||
|
||||
const int ne12 = src1->ne[2];
|
||||
const int ne13 = src1->ne[3];
|
||||
|
||||
const cl_ulong nb10 = src1->nb[0];
|
||||
const cl_ulong nb11 = src1->nb[1];
|
||||
const cl_ulong nb12 = src1->nb[2];
|
||||
const cl_ulong nb13 = src1->nb[3];
|
||||
|
||||
const cl_ulong nb0 = dst->nb[0];
|
||||
const cl_ulong nb1 = dst->nb[1];
|
||||
const cl_ulong nb2 = dst->nb[2];
|
||||
const cl_ulong nb3 = dst->nb[3];
|
||||
|
||||
cl_kernel_invoker<decltype(ocl_kernel_prototypes::kernel_div)>::invoke(
|
||||
kernel,
|
||||
extra0->data_device,
|
||||
offset0,
|
||||
extra1->data_device,
|
||||
offset1,
|
||||
extrad->data_device,
|
||||
offsetd,
|
||||
nb00,
|
||||
nb01,
|
||||
nb02,
|
||||
nb03,
|
||||
ne10,
|
||||
ne11,
|
||||
ne12,
|
||||
ne13,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
nb13,
|
||||
ne0,
|
||||
nb0,
|
||||
nb1,
|
||||
nb2,
|
||||
nb3
|
||||
);
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_div_f16;
|
||||
cl_set_kernel_args(
|
||||
kernel,
|
||||
src0,
|
||||
src1,
|
||||
dst,
|
||||
src0->nb,
|
||||
src1->ne,
|
||||
src1->nb,
|
||||
ne0,
|
||||
dst->nb
|
||||
);
|
||||
}
|
||||
|
||||
cl_set_kernel_args(
|
||||
kernel,
|
||||
src0,
|
||||
src1,
|
||||
dst,
|
||||
src0->nb,
|
||||
src1->ne,
|
||||
src1->nb,
|
||||
ne0,
|
||||
dst->nb
|
||||
);
|
||||
}
|
||||
|
||||
if (bcast_row) {
|
||||
|
|
|
|||
|
|
@ -4,28 +4,28 @@
|
|||
|
||||
#include "ocl_defs.h"
|
||||
|
||||
OCL_KERNEL void kernel_div(OCL_GLOBAL char * src0,
|
||||
ulong offset0,
|
||||
OCL_GLOBAL char * src1,
|
||||
ulong offset1,
|
||||
OCL_GLOBAL char * dst,
|
||||
ulong offsetd,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3);
|
||||
OCL_KERNEL void kernel_div(ocl_global_char_ptr src0,
|
||||
ulong offset0,
|
||||
ocl_global_char_ptr src1,
|
||||
ulong offset1,
|
||||
ocl_global_char_ptr dst,
|
||||
ulong offsetd,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne10,
|
||||
int ne11,
|
||||
int ne12,
|
||||
int ne13,
|
||||
ulong nb10,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
ulong nb0,
|
||||
ulong nb1,
|
||||
ulong nb2,
|
||||
ulong nb3);
|
||||
|
||||
#endif // __KERNELS_DIV_H__
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
|
||||
#ifdef __OPENCL_C_VERSION__
|
||||
// Device (OpenCL) Definitions
|
||||
# define OCL_KERNEL kernel
|
||||
# define OCL_GLOBAL global
|
||||
# define OCL_KERNEL kernel
|
||||
# define OCL_GLOBAL global
|
||||
# define ocl_global_char_ptr global char *
|
||||
#else
|
||||
// Host (C++) Definitions
|
||||
# define OCL_KERNEL
|
||||
# define OCL_GLOBAL
|
||||
# define ocl_global_char_ptr cl_mem
|
||||
# define __kernel
|
||||
# define __global
|
||||
# define ulong cl_ulong
|
||||
|
|
|
|||
Loading…
Reference in New Issue