feat: add function type definitions for OpenCL kernel argument setters

This commit is contained in:
chraac 2026-01-16 22:44:32 +08:00
parent b2a283d5e9
commit 208f8454cd
1 changed files with 51 additions and 55 deletions

View File

@ -2687,6 +2687,8 @@ namespace /* anonymous */ {
// Primary template of the kernel-argument setter trait. It has no members, so a
// data type without a dedicated specialization fails to compile as soon as
// `::func_t` or `::set_arg` is looked up on it.
template <class _TData>
struct cl_kernel_arg_setter {
};
template <> struct cl_kernel_arg_setter<int> {
typedef void func_t(int);
static size_t set_arg(cl_kernel kernel, size_t index, int arg) {
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
return index + 1;
@ -2694,6 +2696,8 @@ template <> struct cl_kernel_arg_setter<int> {
};
template <> struct cl_kernel_arg_setter<std::remove_pointer_t<std::remove_reference_t<cl_mem>>> {
typedef void func_t(cl_mem);
static size_t set_arg(cl_kernel kernel, size_t index, cl_mem arg) {
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
return index + 1;
@ -2702,6 +2706,8 @@ template <> struct cl_kernel_arg_setter<std::remove_pointer_t<std::remove_refere
template <> struct cl_kernel_arg_setter<cl_ulong> {
typedef void func_t(cl_ulong);
static size_t set_arg(cl_kernel kernel, size_t index, cl_ulong arg) {
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
return index + 1;
@ -2709,6 +2715,8 @@ template <> struct cl_kernel_arg_setter<cl_ulong> {
};
template <> struct cl_kernel_arg_setter<float> {
typedef void func_t(float);
static size_t set_arg(cl_kernel kernel, size_t index, float arg) {
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
return index + 1;
@ -2716,6 +2724,8 @@ template <> struct cl_kernel_arg_setter<float> {
};
template <> struct cl_kernel_arg_setter<ggml_tensor> {
typedef void func_t(cl_mem, cl_ulong);
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
static_assert(std::is_same_v<decltype(extra->data_device), cl_mem>, "data_device type mismatch");
@ -2728,6 +2738,8 @@ template <> struct cl_kernel_arg_setter<ggml_tensor> {
};
template <> struct cl_kernel_arg_setter<int64_t[GGML_MAX_DIMS]> {
typedef void func_t(int, int, int, int);
static size_t set_arg(cl_kernel kernel, size_t index, const int64_t (&ne)[GGML_MAX_DIMS]) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
@ -2744,6 +2756,8 @@ template <> struct cl_kernel_arg_setter<int64_t[GGML_MAX_DIMS]> {
};
template <> struct cl_kernel_arg_setter<size_t[GGML_MAX_DIMS]> {
typedef void func_t(cl_ulong, cl_ulong, cl_ulong, cl_ulong);
static size_t set_arg(cl_kernel kernel, size_t index, const size_t (&nb)[GGML_MAX_DIMS]) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
@ -2771,18 +2785,41 @@ template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel
return index;
}
// Primary template of cl_kernel_invoker: an empty struct that exposes no members
// for a generic _TFunc.
// NOTE(review): this view also shows another cl_kernel_invoker definition further
// down (old and new diff lines appear interleaved) — confirm which declaration is
// the current one before relying on this comment.
template<typename _TFunc>
struct cl_kernel_invoker {};
// type_merger<void(A...), void(B...), ...>::func_t is the single function type
// `void(A..., B..., ...)`: the parameter lists of all given `void(...)`
// signatures concatenated in the order they were passed.
//
// The primary template is empty, so an argument list that is not a sequence of
// `void(...)` function types (including an empty list) has no `func_t`.
template <typename... _TSignatures> struct type_merger {};

// Base case: a single signature is already the merged result.
template <typename... _TParams> struct type_merger<void(_TParams...)> {
    using func_t = void(_TParams...);
};

// Recursive case: fuse the two leading signatures into one, then continue
// merging that fused signature with whatever signatures remain.
template <typename... _TLhsParams, typename... _TRhsParams, typename... _TRemaining>
struct type_merger<void(_TLhsParams...), void(_TRhsParams...), _TRemaining...> {
    using merged_head_t = void(_TLhsParams..., _TRhsParams...);
    using func_t        = typename type_merger<merged_head_t, _TRemaining...>::func_t;
};
// Builds the flattened `void(...)` prototype implied by a list of call-site
// argument types. Each argument type is reduced to its cl_kernel_arg_setter key
// by stripping the reference, then one pointer level, then const; the `func_t`
// fragments of the selected setters are joined left to right via type_merger.
template <typename _THead, typename... _TTail> struct cl_param_type_extractor {
    // Setter lookup key for the leading argument.
    using args_t =
        typename std::remove_const<
            typename std::remove_pointer<typename std::remove_reference<_THead>::type>::type>::type;
    // Prototype fragment contributed by the leading argument.
    using first_func_t = typename cl_kernel_arg_setter<args_t>::func_t;
    // Prototype produced by all remaining arguments (recursion on the tail).
    using rest_func_t = typename cl_param_type_extractor<_TTail...>::func_t;
    using func_t      = typename type_merger<first_func_t, rest_func_t>::func_t;
};

// Terminal case of the recursion: with one argument left, its setter's `func_t`
// is the complete result.
template <typename _TLast> struct cl_param_type_extractor<_TLast> {
    using args_t =
        typename std::remove_const<
            typename std::remove_pointer<typename std::remove_reference<_TLast>::type>::type>::type;
    using func_t = typename cl_kernel_arg_setter<args_t>::func_t;
};
template <typename _TFunc> struct cl_kernel_invoker {
template <typename... _TCalledArgs> static void invoke(cl_kernel kernel, _TCalledArgs &&... args) {
static_assert(std::is_same_v<_TFunc, typename cl_param_type_extractor<_TCalledArgs...>::func_t>,
"Kernel argument type mismatch between prototype and called arguments");
template<typename... _TArgs>
struct cl_kernel_invoker<void(_TArgs...)> {
static void invoke(cl_kernel kernel, _TArgs... args) {
size_t index = 0;
(
[&] {
index = cl_kernel_arg_setter<
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index,
args);
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TCalledArgs>>>>::set_arg(kernel,
index,
args);
}(),
...);
}
@ -5167,57 +5204,16 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_div;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3];
const cl_ulong nb10 = src1->nb[0];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb13 = src1->nb[3];
const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
cl_kernel_invoker<decltype(ocl_kernel_prototypes::kernel_div)>::invoke(
kernel,
extra0->data_device,
offset0,
extra1->data_device,
offset1,
extrad->data_device,
offsetd,
nb00,
nb01,
nb02,
nb03,
ne10,
ne11,
ne12,
ne13,
nb10,
nb11,
nb12,
nb13,
src0,
src1,
dst,
src0->nb,
src1->ne,
src1->nb,
ne0,
nb0,
nb1,
nb2,
nb3
dst->nb
);
} else {
kernel = backend_ctx->kernel_div_f16;