refactor: update kernel argument setter for ggml_tensor and simplify argument handling
This commit is contained in:
parent
ada3e50321
commit
220a33afe8
|
|
@ -2701,7 +2701,7 @@ template <> struct cl_kernel_arg_setter<float> {
|
|||
}
|
||||
};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<const ggml_tensor *> {
|
||||
template <> struct cl_kernel_arg_setter<ggml_tensor> {
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
|
||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
|
||||
static_assert(std::is_same_v<decltype(extra->data_device), cl_mem>, "data_device type mismatch");
|
||||
|
|
@ -2713,17 +2713,43 @@ template <> struct cl_kernel_arg_setter<const ggml_tensor *> {
|
|||
}
|
||||
};
|
||||
|
||||
template <> struct cl_kernel_arg_setter<ggml_tensor *> {
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
|
||||
return cl_kernel_arg_setter<const ggml_tensor *>::set_arg(kernel, index, t);
|
||||
template <> struct cl_kernel_arg_setter<int64_t [GGML_MAX_DIMS]> {
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, const int64_t (&ne)[GGML_MAX_DIMS]) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
|
||||
|
||||
const int ne0 = (int)ne[0];
|
||||
const int ne1 = (int)ne[1];
|
||||
const int ne2 = (int)ne[2];
|
||||
const int ne3 = (int)ne[3];
|
||||
CL_CHECK(clSetKernelArg(kernel, index, sizeof(int), &ne0));
|
||||
CL_CHECK(clSetKernelArg(kernel, index + 1, sizeof(int), &ne1));
|
||||
CL_CHECK(clSetKernelArg(kernel, index + 2, sizeof(int), &ne2));
|
||||
CL_CHECK(clSetKernelArg(kernel, index + 3, sizeof(int), &ne3));
|
||||
return index + 4;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs... args) {
|
||||
template <> struct cl_kernel_arg_setter<size_t [GGML_MAX_DIMS]> {
|
||||
static size_t set_arg(cl_kernel kernel, size_t index, const size_t (&nb)[GGML_MAX_DIMS]) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
|
||||
|
||||
const cl_ulong nb0 = nb[0];
|
||||
const cl_ulong nb1 = nb[1];
|
||||
const cl_ulong nb2 = nb[2];
|
||||
const cl_ulong nb3 = nb[3];
|
||||
CL_CHECK(clSetKernelArg(kernel, index, sizeof(cl_ulong), &nb0));
|
||||
CL_CHECK(clSetKernelArg(kernel, index + 1, sizeof(cl_ulong), &nb1));
|
||||
CL_CHECK(clSetKernelArg(kernel, index + 2, sizeof(cl_ulong), &nb2));
|
||||
CL_CHECK(clSetKernelArg(kernel, index + 3, sizeof(cl_ulong), &nb3));
|
||||
return index + 4;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs&&... args) {
|
||||
size_t index = 0;
|
||||
(
|
||||
[&] {
|
||||
index = cl_kernel_arg_setter<decltype(args)>::set_arg(kernel, index, args);
|
||||
index = cl_kernel_arg_setter<std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index, args);
|
||||
}(),
|
||||
...);
|
||||
return index;
|
||||
|
|
@ -5073,30 +5099,10 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||
const int ne02 = src0->ne[2];
|
||||
const int ne03 = src0->ne[3];
|
||||
|
||||
const cl_ulong nb00 = src0->nb[0];
|
||||
const cl_ulong nb01 = src0->nb[1];
|
||||
const cl_ulong nb02 = src0->nb[2];
|
||||
const cl_ulong nb03 = src0->nb[3];
|
||||
|
||||
const int ne10 = src1->ne[0];
|
||||
const int ne11 = src1->ne[1];
|
||||
const int ne12 = src1->ne[2];
|
||||
const int ne13 = src1->ne[3]; UNUSED(ne13);
|
||||
|
||||
const cl_ulong nb10 = src1->nb[0];
|
||||
const cl_ulong nb11 = src1->nb[1];
|
||||
const cl_ulong nb12 = src1->nb[2];
|
||||
const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
|
||||
|
||||
const int ne0 = dst->ne[0];
|
||||
const int ne1 = dst->ne[1];
|
||||
const int ne2 = dst->ne[2];
|
||||
const int ne3 = dst->ne[3];
|
||||
|
||||
const cl_ulong nb0 = dst->nb[0];
|
||||
const cl_ulong nb1 = dst->nb[1];
|
||||
const cl_ulong nb2 = dst->nb[2];
|
||||
const cl_ulong nb3 = dst->nb[3];
|
||||
|
||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||
|
||||
|
|
@ -5131,30 +5137,12 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
|
|||
src0,
|
||||
src1,
|
||||
dst,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
ne03,
|
||||
nb00,
|
||||
nb01,
|
||||
nb02,
|
||||
nb03,
|
||||
ne10,
|
||||
ne11,
|
||||
ne12,
|
||||
ne13,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
nb13,
|
||||
ne0,
|
||||
ne1,
|
||||
ne2,
|
||||
ne3,
|
||||
nb0,
|
||||
nb1,
|
||||
nb2,
|
||||
nb3
|
||||
src0->ne,
|
||||
src0->nb,
|
||||
src1->ne,
|
||||
src1->nb,
|
||||
dst->ne,
|
||||
dst->nb
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue