feat: enhance OpenCL kernel division operations with new argument setters and invoker
This commit is contained in:
parent
5fbff1aa3a
commit
b2a283d5e9
|
|
@ -2693,6 +2693,14 @@ template <> struct cl_kernel_arg_setter<int> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <> struct cl_kernel_arg_setter<std::remove_pointer_t<std::remove_reference_t<cl_mem>>> {
|
||||||
|
static size_t set_arg(cl_kernel kernel, size_t index, cl_mem arg) {
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
|
||||||
|
return index + 1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template <> struct cl_kernel_arg_setter<cl_ulong> {
|
template <> struct cl_kernel_arg_setter<cl_ulong> {
|
||||||
static size_t set_arg(cl_kernel kernel, size_t index, cl_ulong arg) {
|
static size_t set_arg(cl_kernel kernel, size_t index, cl_ulong arg) {
|
||||||
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
|
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
|
||||||
|
|
@ -2763,6 +2771,23 @@ template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel
|
||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename _TFunc>
|
||||||
|
struct cl_kernel_invoker {};
|
||||||
|
|
||||||
|
template<typename... _TArgs>
|
||||||
|
struct cl_kernel_invoker<void(_TArgs...)> {
|
||||||
|
static void invoke(cl_kernel kernel, _TArgs... args) {
|
||||||
|
size_t index = 0;
|
||||||
|
(
|
||||||
|
[&] {
|
||||||
|
index = cl_kernel_arg_setter<
|
||||||
|
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index,
|
||||||
|
args);
|
||||||
|
}(),
|
||||||
|
...);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Additional tensor extra structs for quantized tensors.
|
// Additional tensor extra structs for quantized tensors.
|
||||||
|
|
@ -5142,21 +5167,72 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||||
} else {
|
} else {
|
||||||
if (src0->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32) {
|
||||||
kernel = backend_ctx->kernel_div;
|
kernel = backend_ctx->kernel_div;
|
||||||
|
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||||
|
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||||
|
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||||
|
|
||||||
|
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||||
|
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||||
|
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||||
|
|
||||||
|
const cl_ulong nb00 = src0->nb[0];
|
||||||
|
const cl_ulong nb01 = src0->nb[1];
|
||||||
|
const cl_ulong nb02 = src0->nb[2];
|
||||||
|
const cl_ulong nb03 = src0->nb[3];
|
||||||
|
|
||||||
|
const int ne12 = src1->ne[2];
|
||||||
|
const int ne13 = src1->ne[3];
|
||||||
|
|
||||||
|
const cl_ulong nb10 = src1->nb[0];
|
||||||
|
const cl_ulong nb11 = src1->nb[1];
|
||||||
|
const cl_ulong nb12 = src1->nb[2];
|
||||||
|
const cl_ulong nb13 = src1->nb[3];
|
||||||
|
|
||||||
|
const cl_ulong nb0 = dst->nb[0];
|
||||||
|
const cl_ulong nb1 = dst->nb[1];
|
||||||
|
const cl_ulong nb2 = dst->nb[2];
|
||||||
|
const cl_ulong nb3 = dst->nb[3];
|
||||||
|
|
||||||
|
cl_kernel_invoker<decltype(ocl_kernel_prototypes::kernel_div)>::invoke(
|
||||||
|
kernel,
|
||||||
|
extra0->data_device,
|
||||||
|
offset0,
|
||||||
|
extra1->data_device,
|
||||||
|
offset1,
|
||||||
|
extrad->data_device,
|
||||||
|
offsetd,
|
||||||
|
nb00,
|
||||||
|
nb01,
|
||||||
|
nb02,
|
||||||
|
nb03,
|
||||||
|
ne10,
|
||||||
|
ne11,
|
||||||
|
ne12,
|
||||||
|
ne13,
|
||||||
|
nb10,
|
||||||
|
nb11,
|
||||||
|
nb12,
|
||||||
|
nb13,
|
||||||
|
ne0,
|
||||||
|
nb0,
|
||||||
|
nb1,
|
||||||
|
nb2,
|
||||||
|
nb3
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
kernel = backend_ctx->kernel_div_f16;
|
kernel = backend_ctx->kernel_div_f16;
|
||||||
|
cl_set_kernel_args(
|
||||||
|
kernel,
|
||||||
|
src0,
|
||||||
|
src1,
|
||||||
|
dst,
|
||||||
|
src0->nb,
|
||||||
|
src1->ne,
|
||||||
|
src1->nb,
|
||||||
|
ne0,
|
||||||
|
dst->nb
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
cl_set_kernel_args(
|
|
||||||
kernel,
|
|
||||||
src0,
|
|
||||||
src1,
|
|
||||||
dst,
|
|
||||||
src0->nb,
|
|
||||||
src1->ne,
|
|
||||||
src1->nb,
|
|
||||||
ne0,
|
|
||||||
dst->nb
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (bcast_row) {
|
if (bcast_row) {
|
||||||
|
|
|
||||||
|
|
@ -4,28 +4,28 @@
|
||||||
|
|
||||||
#include "ocl_defs.h"
|
#include "ocl_defs.h"
|
||||||
|
|
||||||
OCL_KERNEL void kernel_div(OCL_GLOBAL char * src0,
|
OCL_KERNEL void kernel_div(ocl_global_char_ptr src0,
|
||||||
ulong offset0,
|
ulong offset0,
|
||||||
OCL_GLOBAL char * src1,
|
ocl_global_char_ptr src1,
|
||||||
ulong offset1,
|
ulong offset1,
|
||||||
OCL_GLOBAL char * dst,
|
ocl_global_char_ptr dst,
|
||||||
ulong offsetd,
|
ulong offsetd,
|
||||||
ulong nb00,
|
ulong nb00,
|
||||||
ulong nb01,
|
ulong nb01,
|
||||||
ulong nb02,
|
ulong nb02,
|
||||||
ulong nb03,
|
ulong nb03,
|
||||||
int ne10,
|
int ne10,
|
||||||
int ne11,
|
int ne11,
|
||||||
int ne12,
|
int ne12,
|
||||||
int ne13,
|
int ne13,
|
||||||
ulong nb10,
|
ulong nb10,
|
||||||
ulong nb11,
|
ulong nb11,
|
||||||
ulong nb12,
|
ulong nb12,
|
||||||
ulong nb13,
|
ulong nb13,
|
||||||
int ne0,
|
int ne0,
|
||||||
ulong nb0,
|
ulong nb0,
|
||||||
ulong nb1,
|
ulong nb1,
|
||||||
ulong nb2,
|
ulong nb2,
|
||||||
ulong nb3);
|
ulong nb3);
|
||||||
|
|
||||||
#endif // __KERNELS_DIV_H__
|
#endif // __KERNELS_DIV_H__
|
||||||
|
|
|
||||||
|
|
@ -4,12 +4,14 @@
|
||||||
|
|
||||||
#ifdef __OPENCL_C_VERSION__
|
#ifdef __OPENCL_C_VERSION__
|
||||||
// Device (OpenCL) Definitions
|
// Device (OpenCL) Definitions
|
||||||
# define OCL_KERNEL kernel
|
# define OCL_KERNEL kernel
|
||||||
# define OCL_GLOBAL global
|
# define OCL_GLOBAL global
|
||||||
|
# define ocl_global_char_ptr global char *
|
||||||
#else
|
#else
|
||||||
// Host (C++) Definitions
|
// Host (C++) Definitions
|
||||||
# define OCL_KERNEL
|
# define OCL_KERNEL
|
||||||
# define OCL_GLOBAL
|
# define OCL_GLOBAL
|
||||||
|
# define ocl_global_char_ptr cl_mem
|
||||||
# define __kernel
|
# define __kernel
|
||||||
# define __global
|
# define __global
|
||||||
# define ulong cl_ulong
|
# define ulong cl_ulong
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue