This commit is contained in:
chraac 2026-01-02 23:53:03 +08:00
parent b1b8fd9abf
commit ada3e50321
1 changed file with 17 additions and 16 deletions

View File

@ -2678,30 +2678,30 @@ struct ggml_tensor_extra_cl {
}
};
// Primary template of the per-type kernel-argument setter. It has no members:
// set_arg is supplied only by the explicit specializations, so using the setter
// with a type that has no specialization fails to compile.
// NOTE(review): this diff view lists both the pre-rename (ocl_kernel_arg_setter)
// and post-rename (cl_kernel_arg_setter) declaration; confirm only one is live.
// NOTE(review): `_TData` starts with an underscore followed by an uppercase
// letter, which is an identifier form reserved for the implementation.
template <typename _TData> struct ocl_kernel_arg_setter {};
template <typename _TData> struct cl_kernel_arg_setter {};
// Setter specialization for `int` kernel arguments.
// NOTE(review): both the pre-rename (ocl_) and post-rename (cl_) header lines
// appear here because this is a diff view.
template <> struct ocl_kernel_arg_setter<int> {
template <> struct cl_kernel_arg_setter<int> {
// Binds `arg` to argument slot `index` of `kernel` and returns the next slot
// (index + 1) so calls can be chained.
static size_t set_arg(cl_kernel kernel, size_t index, int arg) {
// The value is passed by address with its own size; CL_CHECK is defined
// elsewhere (presumably it validates the returned OpenCL status - confirm).
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
return index + 1;
}
};
// Setter specialization for `cl_ulong` kernel arguments.
// NOTE(review): both the pre-rename (ocl_) and post-rename (cl_) header lines
// appear here because this is a diff view.
template <> struct ocl_kernel_arg_setter<cl_ulong> {
template <> struct cl_kernel_arg_setter<cl_ulong> {
// Binds `arg` to argument slot `index` of `kernel` and returns the next slot
// (index + 1) so calls can be chained.
static size_t set_arg(cl_kernel kernel, size_t index, cl_ulong arg) {
// Passed by address with sizeof(cl_ulong); CL_CHECK is defined elsewhere
// (presumably it validates the returned OpenCL status - confirm).
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
return index + 1;
}
};
// Setter specialization for `float` kernel arguments.
// NOTE(review): both the pre-rename (ocl_) and post-rename (cl_) header lines
// appear here because this is a diff view.
template <> struct ocl_kernel_arg_setter<float> {
template <> struct cl_kernel_arg_setter<float> {
// Binds `arg` to argument slot `index` of `kernel` and returns the next slot
// (index + 1) so calls can be chained.
static size_t set_arg(cl_kernel kernel, size_t index, float arg) {
// Passed by address with sizeof(float); CL_CHECK is defined elsewhere
// (presumably it validates the returned OpenCL status - confirm).
CL_CHECK(clSetKernelArg(kernel, index, sizeof(arg), &arg));
return index + 1;
}
};
template <> struct ocl_kernel_arg_setter<const ggml_tensor *> {
template <> struct cl_kernel_arg_setter<const ggml_tensor *> {
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
static_assert(std::is_same_v<decltype(extra->data_device), cl_mem>, "data_device type mismatch");
@ -2713,19 +2713,20 @@ template <> struct ocl_kernel_arg_setter<const ggml_tensor *> {
}
};
// Setter specialization for non-const `ggml_tensor *` arguments.
// It adds no logic of its own: the pointer is accepted as `const ggml_tensor *`
// and forwarded to the const-pointer specialization, whose returned next index
// is passed straight back to the caller.
// NOTE(review): old (ocl_) and new (cl_) lines are both shown by this diff view.
template <> struct ocl_kernel_arg_setter<ggml_tensor *> {
template <> struct cl_kernel_arg_setter<ggml_tensor *> {
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
return ocl_kernel_arg_setter<const ggml_tensor *>::set_arg(kernel, index, t);
return cl_kernel_arg_setter<const ggml_tensor *>::set_arg(kernel, index, t);
}
};
// Sets a run of kernel arguments in call order, starting at argument index 0.
// Each argument is dispatched to the setter specialization chosen by its
// by-value parameter type (decltype(args)); every setter returns the next free
// index, which is fed into the following call.
// The post-rename version (cl_set_kernel_args) returns the index after the last
// argument set; the pre-rename version (set_kernel_args) returned void.
// NOTE(review): this diff view interleaves the old and new signature and body
// lines; only one of each pair exists in the actual file.
template <typename... _TArgs> static inline void set_kernel_args(cl_kernel kernel, _TArgs... args) {
template <typename... _TArgs> static inline size_t cl_set_kernel_args(cl_kernel kernel, _TArgs... args) {
size_t index = 0;
// Comma fold over one immediately-invoked lambda per pack element, so the
// setters run left to right and share `index` by reference.
(
[&] {
index = ocl_kernel_arg_setter<decltype(args)>::set_arg(kernel, index, args);
index = cl_kernel_arg_setter<decltype(args)>::set_arg(kernel, index, args);
}(),
...);
return index;
}
// Additional tensor extra structs for quantized tensors.
@ -5046,7 +5047,7 @@ static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, con
cl_kernel kernel = backend_ctx->kernel_add_id;
set_kernel_args(kernel, src0, src1, src2, dst, nb01, nb02, nb11, nb21, ne0, ne1);
cl_set_kernel_args(kernel, src0, src1, src2, dst, nb01, nb02, nb11, nb21, ne0, ne1);
int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel));
size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 };
@ -5117,7 +5118,7 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
kernel = backend_ctx->kernel_mul_row_f16;
}
set_kernel_args(kernel, src0, src1, dst, ne);
cl_set_kernel_args(kernel, src0, src1, dst, ne);
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_mul;
@ -5125,7 +5126,7 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const
kernel = backend_ctx->kernel_mul_f16;
}
set_kernel_args(
cl_set_kernel_args(
kernel,
src0,
src1,
@ -6845,7 +6846,7 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
cl_kernel kernel = backend_ctx->kernel_timestep_embedding;
set_kernel_args(kernel, src0, dst, dst_nb1_bytes, logical_dim, max_period);
cl_set_kernel_args(kernel, src0, dst, dst_nb1_bytes, logical_dim, max_period);
size_t gws0 = (size_t)(((logical_dim + 1) / 2) + 1);
@ -8655,7 +8656,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
cl_kernel kernel = backend_ctx->kernel_scale;
set_kernel_args(kernel, src0, dst, scale, bias);
cl_set_kernel_args(kernel, src0, dst, scale, bias);
int n = ggml_nelements(dst)/4;
@ -8797,7 +8798,7 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr
if (ne00%8 == 0) {
kernel = backend_ctx->kernel_diag_mask_inf_8;
set_kernel_args(kernel, src0, dst, ne00, ne01, n_past);
cl_set_kernel_args(kernel, src0, dst, ne00, ne01, n_past);
size_t global_work_size[] = {(size_t)ne00*ne01*ne02/8, 1, 1};
size_t local_work_size[] = {64, 1, 1};
@ -8806,7 +8807,7 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr
} else {
kernel = backend_ctx->kernel_diag_mask_inf;
set_kernel_args(kernel, src0, dst, ne00, ne01, n_past);
cl_set_kernel_args(kernel, src0, dst, ne00, ne01, n_past);
size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02};
size_t local_work_size[] = {64, 1, 1};