refactor: simplify kernel argument setting with variadic template function

This commit is contained in:
chraac 2026-01-01 13:32:43 +08:00
parent cef1d23c5a
commit 522ef487a1
1 changed files with 52 additions and 37 deletions

View File

@ -53,6 +53,15 @@
bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);
template <typename... _TArgs> static inline void set_kernel_args(cl_kernel kernel, _TArgs... args) {
size_t index = 0;
(
[&] {
CL_CHECK(clSetKernelArg(kernel, index++, sizeof(args), &args));
}(),
...);
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
@ -5098,13 +5107,16 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
kernel = backend_ctx->kernel_mul_row_f16;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
set_kernel_args(
kernel,
extra0->data_device,
offset0,
extra1->data_device,
offset1,
extrad->data_device,
offsetd,
ne
);
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_mul;
@ -5112,36 +5124,39 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
kernel = backend_ctx->kernel_mul_f16;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
set_kernel_args(
kernel,
extra0->data_device,
offset0,
extra1->data_device,
offset1,
extrad->data_device,
offsetd,
ne00,
ne01,
ne02,
ne03,
nb00,
nb01,
nb02,
nb03,
ne10,
ne11,
ne12,
ne13,
nb10,
nb11,
nb12,
nb13,
ne0,
ne1,
ne2,
ne3,
nb0,
nb1,
nb2,
nb3
);
}
if (bcast_row) {