This commit is contained in:
chraac 2026-01-08 10:16:50 +08:00
parent 0699465aec
commit 06d0f91e42
1 changed file with 14 additions and 8 deletions

View File

@ -2678,6 +2678,8 @@ struct ggml_tensor_extra_cl {
}
};
namespace /* anonymous */ {
// Primary template for the per-type kernel-argument setters. It is deliberately empty:
// it has no set_arg() member, so asking for a setter of a type that has no explicit
// specialization fails at compile time at the point where set_arg() is named.
// The template parameter avoids the underscore + uppercase spelling, which C++ reserves
// for the implementation.
template <typename TData> struct cl_kernel_arg_setter {};
template <> struct cl_kernel_arg_setter<int> {
@ -2713,14 +2715,14 @@ template <> struct cl_kernel_arg_setter<ggml_tensor> {
}
};
template <> struct cl_kernel_arg_setter<int64_t [GGML_MAX_DIMS]> {
template <> struct cl_kernel_arg_setter<int64_t[GGML_MAX_DIMS]> {
static size_t set_arg(cl_kernel kernel, size_t index, const int64_t (&ne)[GGML_MAX_DIMS]) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
const int ne0 = (int)ne[0];
const int ne1 = (int)ne[1];
const int ne2 = (int)ne[2];
const int ne3 = (int)ne[3];
const int ne0 = (int) ne[0];
const int ne1 = (int) ne[1];
const int ne2 = (int) ne[2];
const int ne3 = (int) ne[3];
CL_CHECK(clSetKernelArg(kernel, index, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, index + 1, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, index + 2, sizeof(int), &ne2));
@ -2729,7 +2731,7 @@ template <> struct cl_kernel_arg_setter<int64_t [GGML_MAX_DIMS]> {
}
};
template <> struct cl_kernel_arg_setter<size_t [GGML_MAX_DIMS]> {
template <> struct cl_kernel_arg_setter<size_t[GGML_MAX_DIMS]> {
static size_t set_arg(cl_kernel kernel, size_t index, const size_t (&nb)[GGML_MAX_DIMS]) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS changed, update cl_kernel_arg_setter accordingly");
@ -2745,16 +2747,20 @@ template <> struct cl_kernel_arg_setter<size_t [GGML_MAX_DIMS]> {
}
};
// Binds a variadic list of values to consecutive argument slots of `kernel`.
//
// NOTE(review): in this diff view the signature line and the set_arg() statement each
// appear twice (the same code before and after reformatting); the comments here describe
// the single function those lines represent.
//
// For each argument, the deduced type has its reference, one level of pointer and then
// the remaining const qualifier stripped; the resulting type selects the
// cl_kernel_arg_setter specialization whose static set_arg() is invoked with the current
// slot index. Whatever set_arg() returns becomes the index handed to the next argument,
// so a setter can consume more than one slot.
//
// Returns the index left over after the last argument (0 when the pack is empty).
template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs&&... args) {
template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs &&... args) {
size_t index = 0;
// Fold over the comma operator: one immediately-invoked lambda per pack element,
// evaluated left to right, each updating `index` through the by-reference capture.
(
[&] {
index = cl_kernel_arg_setter<std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index, args);
index = cl_kernel_arg_setter<
std::remove_const_t<std::remove_pointer_t<std::remove_reference_t<_TArgs>>>>::set_arg(kernel, index,
args);
}(),
...);
return index;
}
} // namespace
// Additional tensor extra structs for quantized tensors.
// These tensors are loaded from files and should not be allocated in scratch --
// they should always be allocated from the pool. Hence, they do not have an