refactor: update kernel argument setter for ggml_tensor and simplify argument handling
This commit is contained in:
parent
ada3e50321
commit
220a33afe8
|
|
@ -2701,7 +2701,7 @@ template <> struct cl_kernel_arg_setter<float> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <> struct cl_kernel_arg_setter<const ggml_tensor *> {
|
template <> struct cl_kernel_arg_setter<ggml_tensor> {
|
||||||
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
|
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
|
||||||
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
|
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
|
||||||
static_assert(std::is_same_v<decltype(extra->data_device), cl_mem>, "data_device type mismatch");
|
static_assert(std::is_same_v<decltype(extra->data_device), cl_mem>, "data_device type mismatch");
|
||||||
|
|
@ -2713,17 +2713,43 @@ template <> struct cl_kernel_arg_setter<const ggml_tensor *> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <> struct cl_kernel_arg_setter<ggml_tensor *> {
|
template <> struct cl_kernel_arg_setter<int64_t [GGML_MAX_DIMS]> {
|
||||||
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
|
static size_t set_arg(cl_kernel kernel, size_t index, const int64_t (&ne)[GGML_MAX_DIMS]) {
|
||||||
return cl_kernel_arg_setter<const ggml_tensor *>::set_arg(kernel, index, t);
|
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
|
||||||
|
|
||||||
|
const int ne0 = (int)ne[0];
|
||||||
|
const int ne1 = (int)ne[1];
|
||||||
|
const int ne2 = (int)ne[2];
|
||||||
|
const int ne3 = (int)ne[3];
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, index, sizeof(int), &ne0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, index + 1, sizeof(int), &ne1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, index + 2, sizeof(int), &ne2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, index + 3, sizeof(int), &ne3));
|
||||||
|
return index + 4;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs... args) {
|
template <> struct cl_kernel_arg_setter<size_t [GGML_MAX_DIMS]> {
|
||||||
|
static size_t set_arg(cl_kernel kernel, size_t index, const size_t (&nb)[GGML_MAX_DIMS]) {
|
||||||
|
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
|
||||||
|
|
||||||
|
const cl_ulong nb0 = nb[0];
|
||||||
|
const cl_ulong nb1 = nb[1];
|
||||||
|
const cl_ulong nb2 = nb[2];
|
||||||
|
const cl_ulong nb3 = nb[3];
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, index, sizeof(cl_ulong), &nb0));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, index + 1, sizeof(cl_ulong), &nb1));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, index + 2, sizeof(cl_ulong), &nb2));
|
||||||
|
CL_CHECK(clSetKernelArg(kernel, index + 3, sizeof(cl_ulong), &nb3));
|
||||||
|
return index + 4;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs&&... args) {
|
||||||
size_t index = 0;
|
size_t index = 0;
|
||||||
(
|
(
|
||||||
[&] {
|
[&] {
|
||||||
index = cl_kernel_arg_setter<decltype(args)>::set_arg(kernel, index, args);
|
index = cl_kernel_arg_setter<std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index, args);
|
||||||
}(),
|
}(),
|
||||||
...);
|
...);
|
||||||
return index;
|
return index;
|
||||||
|
|
@ -5073,30 +5099,10 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||||
const int ne02 = src0->ne[2];
|
const int ne02 = src0->ne[2];
|
||||||
const int ne03 = src0->ne[3];
|
const int ne03 = src0->ne[3];
|
||||||
|
|
||||||
const cl_ulong nb00 = src0->nb[0];
|
|
||||||
const cl_ulong nb01 = src0->nb[1];
|
|
||||||
const cl_ulong nb02 = src0->nb[2];
|
|
||||||
const cl_ulong nb03 = src0->nb[3];
|
|
||||||
|
|
||||||
const int ne10 = src1->ne[0];
|
const int ne10 = src1->ne[0];
|
||||||
const int ne11 = src1->ne[1];
|
const int ne11 = src1->ne[1];
|
||||||
const int ne12 = src1->ne[2];
|
|
||||||
const int ne13 = src1->ne[3]; UNUSED(ne13);
|
|
||||||
|
|
||||||
const cl_ulong nb10 = src1->nb[0];
|
|
||||||
const cl_ulong nb11 = src1->nb[1];
|
|
||||||
const cl_ulong nb12 = src1->nb[2];
|
|
||||||
const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13);
|
|
||||||
|
|
||||||
const int ne0 = dst->ne[0];
|
const int ne0 = dst->ne[0];
|
||||||
const int ne1 = dst->ne[1];
|
|
||||||
const int ne2 = dst->ne[2];
|
|
||||||
const int ne3 = dst->ne[3];
|
|
||||||
|
|
||||||
const cl_ulong nb0 = dst->nb[0];
|
|
||||||
const cl_ulong nb1 = dst->nb[1];
|
|
||||||
const cl_ulong nb2 = dst->nb[2];
|
|
||||||
const cl_ulong nb3 = dst->nb[3];
|
|
||||||
|
|
||||||
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
|
||||||
|
|
||||||
|
|
@ -5131,30 +5137,12 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
|
||||||
src0,
|
src0,
|
||||||
src1,
|
src1,
|
||||||
dst,
|
dst,
|
||||||
ne00,
|
src0->ne,
|
||||||
ne01,
|
src0->nb,
|
||||||
ne02,
|
src1->ne,
|
||||||
ne03,
|
src1->nb,
|
||||||
nb00,
|
dst->ne,
|
||||||
nb01,
|
dst->nb
|
||||||
nb02,
|
|
||||||
nb03,
|
|
||||||
ne10,
|
|
||||||
ne11,
|
|
||||||
ne12,
|
|
||||||
ne13,
|
|
||||||
nb10,
|
|
||||||
nb11,
|
|
||||||
nb12,
|
|
||||||
nb13,
|
|
||||||
ne0,
|
|
||||||
ne1,
|
|
||||||
ne2,
|
|
||||||
ne3,
|
|
||||||
nb0,
|
|
||||||
nb1,
|
|
||||||
nb2,
|
|
||||||
nb3
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue