refactor: update kernel argument setter for ggml_tensor and simplify argument handling

This commit is contained in:
chraac 2026-01-03 00:28:21 +08:00
parent ada3e50321
commit 220a33afe8
1 changed file with 38 additions and 50 deletions

View File

@ -2701,7 +2701,7 @@ template <> struct cl_kernel_arg_setter<float> {
}
};
template <> struct cl_kernel_arg_setter<const ggml_tensor *> {
template <> struct cl_kernel_arg_setter<ggml_tensor> {
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
static_assert(std::is_same_v<decltype(extra->data_device), cl_mem>, "data_device type mismatch");
@ -2713,17 +2713,43 @@ template <> struct cl_kernel_arg_setter<const ggml_tensor *> {
}
};
template <> struct cl_kernel_arg_setter<ggml_tensor *> {
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
return cl_kernel_arg_setter<const ggml_tensor *>::set_arg(kernel, index, t);
// Kernel-argument setter for a four-element int64_t array (for example the
// `ne` member passed as `src0->ne` by callers of cl_set_kernel_args).
//
// Every element is converted to `int` and bound to its own kernel argument
// slot: ne[0] goes to `index`, ne[1] to `index + 1`, and so on.
//
//   kernel  kernel whose arguments are being set
//   index   slot that receives ne[0]
//   ne      the values to bind
//
// Returns the first slot after the ones written here, i.e. `index + 4`.
template <> struct cl_kernel_arg_setter<int64_t [GGML_MAX_DIMS]> {
    static size_t set_arg(cl_kernel kernel, size_t index, const int64_t (&ne)[GGML_MAX_DIMS]) {
        // Exactly four slots are written below, so the array length must stay at four.
        static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");

        // The arguments are bound with sizeof(int), so each value is narrowed
        // from int64_t before its address is handed to clSetKernelArg.
        const int ne0 = static_cast<int>(ne[0]);
        CL_CHECK(clSetKernelArg(kernel, index, sizeof(int), &ne0));

        const int ne1 = static_cast<int>(ne[1]);
        CL_CHECK(clSetKernelArg(kernel, index + 1, sizeof(int), &ne1));

        const int ne2 = static_cast<int>(ne[2]);
        CL_CHECK(clSetKernelArg(kernel, index + 2, sizeof(int), &ne2));

        const int ne3 = static_cast<int>(ne[3]);
        CL_CHECK(clSetKernelArg(kernel, index + 3, sizeof(int), &ne3));

        return index + 4;
    }
};
template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs... args) {
// Kernel-argument setter for a four-element size_t array (for example the
// `nb` member passed as `src0->nb` by callers of cl_set_kernel_args).
//
// Every element is converted to `cl_ulong` and bound to its own kernel
// argument slot: nb[0] goes to `index`, nb[1] to `index + 1`, and so on.
//
//   kernel  kernel whose arguments are being set
//   index   slot that receives nb[0]
//   nb      the values to bind
//
// Returns the first slot after the ones written here, i.e. `index + 4`.
template <> struct cl_kernel_arg_setter<size_t [GGML_MAX_DIMS]> {
    static size_t set_arg(cl_kernel kernel, size_t index, const size_t (&nb)[GGML_MAX_DIMS]) {
        // Exactly four slots are written below, so the array length must stay at four.
        static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");

        // The arguments are bound with sizeof(cl_ulong), so each value is
        // converted from size_t before its address is handed to clSetKernelArg.
        const cl_ulong nb0 = static_cast<cl_ulong>(nb[0]);
        CL_CHECK(clSetKernelArg(kernel, index, sizeof(cl_ulong), &nb0));

        const cl_ulong nb1 = static_cast<cl_ulong>(nb[1]);
        CL_CHECK(clSetKernelArg(kernel, index + 1, sizeof(cl_ulong), &nb1));

        const cl_ulong nb2 = static_cast<cl_ulong>(nb[2]);
        CL_CHECK(clSetKernelArg(kernel, index + 2, sizeof(cl_ulong), &nb2));

        const cl_ulong nb3 = static_cast<cl_ulong>(nb[3]);
        CL_CHECK(clSetKernelArg(kernel, index + 3, sizeof(cl_ulong), &nb3));

        return index + 4;
    }
};
template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs&&... args) {
size_t index = 0;
(
[&] {
index = cl_kernel_arg_setter<decltype(args)>::set_arg(kernel, index, args);
index = cl_kernel_arg_setter<std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index, args);
}(),
...);
return index;
@ -5073,30 +5099,10 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = src1->ne[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3]; UNUSED(ne13);
const cl_ulong nb10 = src1->nb[0];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];
const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@ -5131,30 +5137,12 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
src0,
src1,
dst,
ne00,
ne01,
ne02,
ne03,
nb00,
nb01,
nb02,
nb03,
ne10,
ne11,
ne12,
ne13,
nb10,
nb11,
nb12,
nb13,
ne0,
ne1,
ne2,
ne3,
nb0,
nb1,
nb2,
nb3
src0->ne,
src0->nb,
src1->ne,
src1->nb,
dst->ne,
dst->nb
);
}