feat: enhance OpenCL kernel division operations with new argument setters and invoker

This commit is contained in:
chraac 2026-01-15 22:29:08 +08:00
parent 5fbff1aa3a
commit b2a283d5e9
3 changed files with 115 additions and 37 deletions

View File

@ -2693,6 +2693,14 @@ template <> struct cl_kernel_arg_setter<int> {
}
};
template <> struct cl_kernel_arg_setter<std::remove_pointer_t<std::remove_reference_t<cl_mem>>> {
static size_t set_arg(cl_kernel kernel, size_t index, cl_mem arg) {
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
return index + 1;
}
};
template <> struct cl_kernel_arg_setter<cl_ulong> {
static size_t set_arg(cl_kernel kernel, size_t index, cl_ulong arg) {
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
@ -2763,6 +2771,23 @@ template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel
return index;
}
template<typename _TFunc>
struct cl_kernel_invoker {};
template<typename... _TArgs>
struct cl_kernel_invoker<void(_TArgs...)> {
static void invoke(cl_kernel kernel, _TArgs... args) {
size_t index = 0;
(
[&] {
index = cl_kernel_arg_setter<
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index,
args);
}(),
...);
}
};
} // namespace
// Additional tensor extra structs for quantized tensors.
@ -5142,21 +5167,72 @@ static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_div;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3];
const cl_ulong nb10 = src1->nb[0];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb13 = src1->nb[3];
const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
cl_kernel_invoker<decltype(ocl_kernel_prototypes::kernel_div)>::invoke(
kernel,
extra0->data_device,
offset0,
extra1->data_device,
offset1,
extrad->data_device,
offsetd,
nb00,
nb01,
nb02,
nb03,
ne10,
ne11,
ne12,
ne13,
nb10,
nb11,
nb12,
nb13,
ne0,
nb0,
nb1,
nb2,
nb3
);
} else {
kernel = backend_ctx->kernel_div_f16;
cl_set_kernel_args(
kernel,
src0,
src1,
dst,
src0->nb,
src1->ne,
src1->nb,
ne0,
dst->nb
);
}
cl_set_kernel_args(
kernel,
src0,
src1,
dst,
src0->nb,
src1->ne,
src1->nb,
ne0,
dst->nb
);
}
if (bcast_row) {

View File

@ -4,28 +4,28 @@
#include "ocl_defs.h"
OCL_KERNEL void kernel_div(OCL_GLOBAL char * src0,
ulong offset0,
OCL_GLOBAL char * src1,
ulong offset1,
OCL_GLOBAL char * dst,
ulong offsetd,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3);
OCL_KERNEL void kernel_div(ocl_global_char_ptr src0,
ulong offset0,
ocl_global_char_ptr src1,
ulong offset1,
ocl_global_char_ptr dst,
ulong offsetd,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3);
#endif // __KERNELS_DIV_H__

View File

@ -4,12 +4,14 @@
#ifdef __OPENCL_C_VERSION__
// Device (OpenCL) Definitions
# define OCL_KERNEL kernel
# define OCL_GLOBAL global
# define OCL_KERNEL kernel
# define OCL_GLOBAL global
# define ocl_global_char_ptr global char *
#else
// Host (C++) Definitions
# define OCL_KERNEL
# define OCL_GLOBAL
# define ocl_global_char_ptr cl_mem
# define __kernel
# define __global
# define ulong cl_ulong