From 220a33afe87204bf419543ea8d3deb89dd72a33f Mon Sep 17 00:00:00 2001
From: chraac
Date: Sat, 3 Jan 2026 00:28:21 +0800
Subject: [PATCH] refactor: update kernel argument setter for ggml_tensor and simplify argument handling
---
// NOTE(review): summary of what the hunks below show.
//   * The header line of the cl_kernel_arg_setter specialization whose
//     set_arg() takes `const ggml_tensor * t` is replaced.
//   * The specialization whose set_arg() only forwarded to another
//     cl_kernel_arg_setter specialization is removed.
//   * Two specializations are added, taking `const int64_t (&ne)[GGML_MAX_DIMS]`
//     and `const size_t (&nb)[GGML_MAX_DIMS]`. Each one static_asserts
//     GGML_MAX_DIMS == 4, sets four consecutive kernel arguments (as `int` and
//     `cl_ulong` respectively) and returns index + 4.
//   * cl_set_kernel_args() now takes its parameter pack as `_TArgs&&...`.
//   * In ggml_cl_mul() the kernel-argument call passes src0->ne, src0->nb,
//     src1->ne, src1->nb, dst->ne, dst->nb instead of 24 scalar locals, keeping
//     the same per-tensor "ne then nb" order. The stride locals and the
//     ne12/ne13/ne1/ne2/ne3 locals are deleted.
// NOTE(review): `(int)ne[i]` narrows int64_t to int; the removed
//   `const int neXY = ...->ne[i];` locals performed the same narrowing implicitly.
// NOTE(review): this copy appears to have lost its angle-bracket text. The
//   removed and added specialization headers read identically, and
//   `cl_kernel_arg_setter>>>::set_arg`, `std::is_same_vdata_device), cl_mem>` and
//   `template static inline` are not valid C++ as written. Presumably the
//   template argument lists (and the author's e-mail address) were stripped
//   during extraction - TODO restore them from the original commit before
//   applying this patch.
// NOTE(review): line breaks follow the hunk line counts; the indentation of the
//   hunk lines is assumed (4-space based) - verify against the target file.
 ggml/src/ggml-opencl/ggml-opencl.cpp | 88 ++++++++++++----------------
 1 file changed, 38 insertions(+), 50 deletions(-)

diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 7289fcc37c..e5fe616dba 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -2701,7 +2701,7 @@ template <> struct cl_kernel_arg_setter {
     }
 };
 
-template <> struct cl_kernel_arg_setter {
+template <> struct cl_kernel_arg_setter {
     static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
         ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
         static_assert(std::is_same_vdata_device), cl_mem>, "data_device type mismatch");
@@ -2713,17 +2713,43 @@ template <> struct cl_kernel_arg_setter {
     }
 };
 
-template <> struct cl_kernel_arg_setter {
-    static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
-        return cl_kernel_arg_setter::set_arg(kernel, index, t);
+template <> struct cl_kernel_arg_setter {
+    static size_t set_arg(cl_kernel kernel, size_t index, const int64_t (&ne)[GGML_MAX_DIMS]) {
+        static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
+
+        const int ne0 = (int)ne[0];
+        const int ne1 = (int)ne[1];
+        const int ne2 = (int)ne[2];
+        const int ne3 = (int)ne[3];
+        CL_CHECK(clSetKernelArg(kernel, index, sizeof(int), &ne0));
+        CL_CHECK(clSetKernelArg(kernel, index + 1, sizeof(int), &ne1));
+        CL_CHECK(clSetKernelArg(kernel, index + 2, sizeof(int), &ne2));
+        CL_CHECK(clSetKernelArg(kernel, index + 3, sizeof(int), &ne3));
+        return index + 4;
     }
 };
 
-template static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs... args) {
+template <> struct cl_kernel_arg_setter {
+    static size_t set_arg(cl_kernel kernel, size_t index, const size_t (&nb)[GGML_MAX_DIMS]) {
+        static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
+
+        const cl_ulong nb0 = nb[0];
+        const cl_ulong nb1 = nb[1];
+        const cl_ulong nb2 = nb[2];
+        const cl_ulong nb3 = nb[3];
+        CL_CHECK(clSetKernelArg(kernel, index, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, index + 1, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, index + 2, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, index + 3, sizeof(cl_ulong), &nb3));
+        return index + 4;
+    }
+};
+
+template static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs&&... args) {
     size_t index = 0;
     (
         [&] {
-            index = cl_kernel_arg_setter::set_arg(kernel, index, args);
+            index = cl_kernel_arg_setter>>>::set_arg(kernel, index, args);
         }(),
         ...);
     return index;
@@ -5073,30 +5099,10 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
     const int ne02 = src0->ne[2];
     const int ne03 = src0->ne[3];
 
-    const cl_ulong nb00 = src0->nb[0];
-    const cl_ulong nb01 = src0->nb[1];
-    const cl_ulong nb02 = src0->nb[2];
-    const cl_ulong nb03 = src0->nb[3];
-
     const int ne10 = src1->ne[0];
     const int ne11 = src1->ne[1];
-    const int ne12 = src1->ne[2];
-    const int ne13 = src1->ne[3]; UNUSED(ne13);
-
-    const cl_ulong nb10 = src1->nb[0];
-    const cl_ulong nb11 = src1->nb[1];
-    const cl_ulong nb12 = src1->nb[2];
-    const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
 
     const int ne0 = dst->ne[0];
-    const int ne1 = dst->ne[1];
-    const int ne2 = dst->ne[2];
-    const int ne3 = dst->ne[3];
-
-    const cl_ulong nb0 = dst->nb[0];
-    const cl_ulong nb1 = dst->nb[1];
-    const cl_ulong nb2 = dst->nb[2];
-    const cl_ulong nb3 = dst->nb[3];
 
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
 
@@ -5131,30 +5137,12 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
             src0,
             src1,
             dst,
-            ne00,
-            ne01,
-            ne02,
-            ne03,
-            nb00,
-            nb01,
-            nb02,
-            nb03,
-            ne10,
-            ne11,
-            ne12,
-            ne13,
-            nb10,
-            nb11,
-            nb12,
-            nb13,
-            ne0,
-            ne1,
-            ne2,
-            ne3,
-            nb0,
-            nb1,
-            nb2,
-            nb3
+            src0->ne,
+            src0->nb,
+            src1->ne,
+            src1->nb,
+            dst->ne,
+            dst->nb
         );
     }
 