refactor: add specialized kernel argument setter for float and simplify kernel argument setting

This commit is contained in:
chraac 2026-01-01 23:39:09 +08:00
parent 1fae16787e
commit b1b8fd9abf
1 changed file with 12 additions and 69 deletions

View File

@ -2694,6 +2694,13 @@ template <> struct ocl_kernel_arg_setter<cl_ulong> {
}
};
// Kernel-argument setter specialization for a scalar `float` value.
template <> struct ocl_kernel_arg_setter<float> {
    // Binds `arg` to argument slot `index` of `kernel` and returns the index
    // of the slot that follows it (`index + 1`), so calls can be chained.
    // The call to clSetKernelArg is wrapped in CL_CHECK.
    static size_t set_arg(cl_kernel kernel, size_t index, float arg) {
        const size_t next_index = index + 1;
        // `arg` is a by-value copy, so its address is valid for the duration of the call.
        CL_CHECK(clSetKernelArg(kernel, index, sizeof(float), &arg));
        return next_index;
    }
};
template <> struct ocl_kernel_arg_setter<const ggml_tensor *> {
static size_t set_arg(cl_kernel kernel, size_t index, const ggml_tensor * t) {
ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) t->extra;
@ -5037,32 +5044,9 @@ static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, con
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offset2 = extra2->offset + src2->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel = backend_ctx->kernel_add_id;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
set_kernel_args(kernel, src0, src1, src2, dst, nb01, nb02, nb11, nb21, ne0, ne1);
int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel));
size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 };
@ -6855,25 +6839,13 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
return;
}
ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
cl_ulong off_dst = extra_dst->offset + dst->view_offs;
const int logical_dim = dst->op_params[0];
const int max_period = dst->op_params[1];
const int dst_nb1_bytes = dst->nb[1];
cl_kernel kernel = backend_ctx->kernel_timestep_embedding;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &dst_nb1_bytes));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &logical_dim));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &max_period));
set_kernel_args(kernel, src0, dst, dst_nb1_bytes, logical_dim, max_period);
size_t gws0 = (size_t)(((logical_dim + 1) / 2) + 1);
@ -8681,20 +8653,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float));
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel = backend_ctx->kernel_scale;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));
set_kernel_args(kernel, src0, dst, scale, bias);
int n = ggml_nelements(dst)/4;
@ -8831,24 +8792,12 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel;
if (ne00%8 == 0) {
kernel = backend_ctx->kernel_diag_mask_inf_8;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n_past));
set_kernel_args(kernel, src0, dst, ne00, ne01, n_past);
size_t global_work_size[] = {(size_t)ne00*ne01*ne02/8, 1, 1};
size_t local_work_size[] = {64, 1, 1};
@ -8857,13 +8806,7 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr
} else {
kernel = backend_ctx->kernel_diag_mask_inf;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n_past));
set_kernel_args(kernel, src0, dst, ne00, ne01, n_past);
size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02};
size_t local_work_size[] = {64, 1, 1};