feat: add function type definitions for OpenCL kernel argument setters
This commit is contained in:
parent
b2a283d5e9
commit
208f8454cd
|
|
@ -2687,6 +2687,8 @@ namespace /* anonymous */ {
|
|||
// Primary template, intentionally empty: it provides neither `func_t` nor
// `set_arg`. Argument kinds are supported only through the explicit
// specializations that follow (e.g. `cl_kernel_arg_setter<int>`), each of which
// declares a `func_t` function type and a static `set_arg(kernel, index, arg)`.
template <typename _TData> struct cl_kernel_arg_setter {};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<int> {
|
||||
typedef void func_t(int);
|
||||
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, int arg) {
|
||||
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
|
||||
return index + 1;
|
||||
|
|
@ -2694,6 +2696,8 @@ template <> struct cl_kernel_arg_setter<int> {
|
|||
};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<std::remove_pointer_t<std::remove_reference_t<cl_mem>>> {
|
||||
typedef void func_t(cl_mem);
|
||||
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, cl_mem arg) {
|
||||
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
|
||||
return index + 1;
|
||||
|
|
@ -2702,6 +2706,8 @@ template <> struct cl_kernel_arg_setter<std::remove_pointer_t<std::remove_refere
|
|||
|
||||
|
||||
template <> struct cl_kernel_arg_setter<cl_ulong> {
|
||||
typedef void func_t(cl_ulong);
|
||||
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, cl_ulong arg) {
|
||||
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
|
||||
return index + 1;
|
||||
|
|
@ -2709,6 +2715,8 @@ template <> struct cl_kernel_arg_setter<cl_ulong> {
|
|||
};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<float> {
|
||||
typedef void func_t(float);
|
||||
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, float arg) {
|
||||
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
|
||||
return index + 1;
|
||||
|
|
@ -2716,6 +2724,8 @@ template <> struct cl_kernel_arg_setter<float> {
|
|||
};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<ggml_tensor> {
|
||||
typedef void func_t(cl_mem, cl_ulong);
|
||||
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
|
||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
|
||||
static_assert(std::is_same_v<decltype(extra->data_device), cl_mem>, "data_device type mismatch");
|
||||
|
|
@ -2728,6 +2738,8 @@ template <> struct cl_kernel_arg_setter<ggml_tensor> {
|
|||
};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<int64_t[GGML_MAX_DIMS]> {
|
||||
typedef void func_t(int, int, int, int);
|
||||
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, const int64_t (&ne)[GGML_MAX_DIMS]) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
|
||||
|
||||
|
|
@ -2744,6 +2756,8 @@ template <> struct cl_kernel_arg_setter<int64_t[GGML_MAX_DIMS]> {
|
|||
};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<size_t[GGML_MAX_DIMS]> {
|
||||
typedef void func_t(cl_ulong, cl_ulong, cl_ulong, cl_ulong);
|
||||
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, const size_t (&nb)[GGML_MAX_DIMS]) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
|
||||
|
||||
|
|
@ -2771,18 +2785,41 @@ template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel
|
|||
return index;
|
||||
}
|
||||
|
||||
template<typename _TFunc>
|
||||
struct cl_kernel_invoker {};
|
||||
// Primary template, intentionally empty (no `func_t` member). Only the partial
// specializations taking `void(...)` function types define `func_t`, so any
// other instantiation fails to compile at the point `::func_t` is named.
template <typename... _TArgs> struct type_merger {};
|
||||
|
||||
// Recursive case: when the first two template arguments are `void(...)`
// function types, splice their parameter lists into a single
// `void(_TInnerArgs1..., _TInnerArgs2...)` and recurse with the remaining
// `_TArgs...`. `func_t` is whatever the recursion finally yields, i.e. one
// function type carrying all parameters in their original order.
template <typename... _TArgs, typename... _TInnerArgs1, typename... _TInnerArgs2>
|
||||
struct type_merger<void(_TInnerArgs1...), void(_TInnerArgs2...), _TArgs...> {
|
||||
using func_t = typename type_merger<void(_TInnerArgs1..., _TInnerArgs2...), _TArgs...>::func_t;
|
||||
};
|
||||
|
||||
// Terminal case: exactly one `void(...)` function type is left, so `func_t` is
// that function type unchanged.
template <typename... _TInnerArgs> struct type_merger<void(_TInnerArgs...)> {
|
||||
using func_t = void(_TInnerArgs...);
|
||||
};
|
||||
|
||||
// Maps a pack of argument types to one `void(...)` function type.
//   args_t       - `_TFirstArg` with its reference, then one pointer level, then
//                  top-level const removed (in that order); used as the key for
//                  the `cl_kernel_arg_setter` specialization lookup.
//   first_func_t - the `func_t` declared by `cl_kernel_arg_setter<args_t>`.
//   func_t       - `first_func_t`'s parameters followed by the parameters
//                  produced recursively for `_TRestArgs...`, joined through
//                  `type_merger`.
// Compilation fails if `cl_kernel_arg_setter<args_t>` has no `func_t` member.
template <typename _TFirstArg, typename... _TRestArgs> struct cl_param_type_extractor {
|
||||
using args_t = std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TFirstArg>>>;
|
||||
using first_func_t =
|
||||
typename cl_kernel_arg_setter<args_t>::func_t;
|
||||
using func_t = typename type_merger<first_func_t, typename cl_param_type_extractor<_TRestArgs...>::func_t>::func_t;
|
||||
};
|
||||
|
||||
// Terminal case for a single argument type: strip reference, one pointer level
// and top-level const from `_TFinalArg` (same normalization as the general
// case), then expose the matching `cl_kernel_arg_setter<args_t>::func_t`
// directly as `func_t`.
template <typename _TFinalArg> struct cl_param_type_extractor<_TFinalArg> {
|
||||
using args_t = std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TFinalArg>>>;
|
||||
using func_t = typename cl_kernel_arg_setter<args_t>::func_t;
|
||||
};
|
||||
|
||||
template <typename _TFunc> struct cl_kernel_invoker {
|
||||
template <typename... _TCalledArgs> static void invoke(cl_kernel kernel, _TCalledArgs &&... args) {
|
||||
static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>,
|
||||
"Kernel argument type mismatch between prototype and called arguments");
|
||||
|
||||
template<typename... _TArgs>
|
||||
struct cl_kernel_invoker<void(_TArgs...)> {
|
||||
static void invoke(cl_kernel kernel, _TArgs... args) {
|
||||
size_t index = 0;
|
||||
(
|
||||
[&] {
|
||||
index = cl_kernel_arg_setter<
|
||||
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index,
|
||||
args);
|
||||
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TCalledArgs>>>>::set_arg(kernel,
|
||||
index,
|
||||
args);
|
||||
}(),
|
||||
...);
|
||||
}
|
||||
|
|
@ -5167,57 +5204,16 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_div;
|
||||
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
|
||||
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
|
||||
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
|
||||
|
||||
cl_ulong offset0 = extra0->offset + src0->view_offs;
|
||||
cl_ulong offset1 = extra1->offset + src1->view_offs;
|
||||
cl_ulong offsetd = extrad->offset + dst->view_offs;
|
||||
|
||||
const cl_ulong nb00 = src0->nb[0];
|
||||
const cl_ulong nb01 = src0->nb[1];
|
||||
const cl_ulong nb02 = src0->nb[2];
|
||||
const cl_ulong nb03 = src0->nb[3];
|
||||
|
||||
const int ne12 = src1->ne[2];
|
||||
const int ne13 = src1->ne[3];
|
||||
|
||||
const cl_ulong nb10 = src1->nb[0];
|
||||
const cl_ulong nb11 = src1->nb[1];
|
||||
const cl_ulong nb12 = src1->nb[2];
|
||||
const cl_ulong nb13 = src1->nb[3];
|
||||
|
||||
const cl_ulong nb0 = dst->nb[0];
|
||||
const cl_ulong nb1 = dst->nb[1];
|
||||
const cl_ulong nb2 = dst->nb[2];
|
||||
const cl_ulong nb3 = dst->nb[3];
|
||||
|
||||
cl_kernel_invoker<decltype(ocl_kernel_prototypes::kernel_div)>::invoke(
|
||||
kernel,
|
||||
extra0->data_device,
|
||||
offset0,
|
||||
extra1->data_device,
|
||||
offset1,
|
||||
extrad->data_device,
|
||||
offsetd,
|
||||
nb00,
|
||||
nb01,
|
||||
nb02,
|
||||
nb03,
|
||||
ne10,
|
||||
ne11,
|
||||
ne12,
|
||||
ne13,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
nb13,
|
||||
src0,
|
||||
src1,
|
||||
dst,
|
||||
src0->nb,
|
||||
src1->ne,
|
||||
src1->nb,
|
||||
ne0,
|
||||
nb0,
|
||||
nb1,
|
||||
nb2,
|
||||
nb3
|
||||
dst->nb
|
||||
);
|
||||
} else {
|
||||
kernel = backend_ctx->kernel_div_f16;
|
||||
|
|
|
|||
Loading…
Reference in New Issue