refactor: simplify kernel argument setting with variadic template function
This commit is contained in:
parent
cef1d23c5a
commit
522ef487a1
|
|
@ -53,6 +53,15 @@
|
|||
|
||||
bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);
|
||||
|
||||
template <typename... _TArgs> static inline void set_kernel_args(cl_kernel kernel, _TArgs... args) {
|
||||
size_t index = 0;
|
||||
(
|
||||
[&] {
|
||||
CL_CHECK(clSetKernelArg(kernel, index++, sizeof(args), &args));
|
||||
}(),
|
||||
...);
|
||||
}
|
||||
|
||||
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
||||
// Precompute mp (m' in the paper) and L such that division
|
||||
// can be computed using a multiply (high 32b of 64b result)
|
||||
|
|
@ -5098,13 +5107,16 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||
kernel = backend_ctx->kernel_mul_row_f16;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne));
|
||||
set_kernel_args(
|
||||
kernel,
|
||||
extra0->data_device,
|
||||
offset0,
|
||||
extra1->data_device,
|
||||
offset1,
|
||||
extrad->data_device,
|
||||
offsetd,
|
||||
ne
|
||||
);
|
||||
} else {
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
kernel = backend_ctx->kernel_mul;
|
||||
|
|
@ -5112,36 +5124,39 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||
kernel = backend_ctx->kernel_mul_f16;
|
||||
}
|
||||
|
||||
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
|
||||
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
|
||||
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
|
||||
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
|
||||
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
|
||||
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
|
||||
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
|
||||
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
|
||||
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
|
||||
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
|
||||
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3));
|
||||
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
|
||||
CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
|
||||
CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
|
||||
set_kernel_args(
|
||||
kernel,
|
||||
extra0->data_device,
|
||||
offset0,
|
||||
extra1->data_device,
|
||||
offset1,
|
||||
extrad->data_device,
|
||||
offsetd,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
ne03,
|
||||
nb00,
|
||||
nb01,
|
||||
nb02,
|
||||
nb03,
|
||||
ne10,
|
||||
ne11,
|
||||
ne12,
|
||||
ne13,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
nb13,
|
||||
ne0,
|
||||
ne1,
|
||||
ne2,
|
||||
ne3,
|
||||
nb0,
|
||||
nb1,
|
||||
nb2,
|
||||
nb3
|
||||
);
|
||||
}
|
||||
|
||||
if (bcast_row) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue