commit abc503d7d8
lhez authored 2026-02-02 09:18:13 +08:00, committed by GitHub
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

5 changed files with 400 additions and 447 deletions

ggml/src/ggml-opencl/ggml-opencl.cpp

@@ -453,7 +453,6 @@ struct ggml_backend_opencl_context {
     cl_program program_rms_norm;
     cl_program program_group_norm;
     cl_program program_rope;
-    cl_program program_scale;
     cl_program program_silu;
     cl_program program_sigmoid;
     cl_program program_softmax_f32;
@@ -462,11 +461,8 @@ struct ggml_backend_opencl_context {
     cl_program program_softmax_4_f16;
     cl_program program_argsort_f32_i32;
     cl_program program_sum_rows_f32;
-    cl_program program_repeat;
     cl_program program_pad;
-    cl_program program_tanh;
     cl_program program_upscale;
-    cl_program program_concat;
     cl_program program_conv_2d_f16;
     cl_program program_conv_2d_f32;
     cl_program program_conv_2d_f16_f32;
@@ -485,7 +481,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16;
     cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
     cl_kernel kernel_add_id;
-    cl_kernel kernel_scale;
+    cl_kernel kernel_scale_f32, kernel_scale_f32_4;
     cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
     cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
     cl_kernel kernel_mean_f32;
@@ -544,18 +540,17 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_im2col_f32, kernel_im2col_f16;
     cl_kernel kernel_argsort_f32_i32;
     cl_kernel kernel_sum_rows_f32;
-    cl_kernel kernel_repeat;
+    cl_kernel kernel_repeat_f32;
     cl_kernel kernel_pad;
-    cl_kernel kernel_tanh_f32_nd;
-    cl_kernel kernel_tanh_f16_nd;
+    cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc;
+    cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc;
    cl_kernel kernel_expm1_f32_nd;
    cl_kernel kernel_expm1_f16_nd;
    cl_kernel kernel_softplus_f32_nd;
    cl_kernel kernel_softplus_f16_nd;
     cl_kernel kernel_upscale;
     cl_kernel kernel_upscale_bilinear;
-    cl_kernel kernel_concat_f32_contiguous;
-    cl_kernel kernel_concat_f32_non_contiguous;
+    cl_kernel kernel_concat_f32;
     cl_kernel kernel_conv_2d_f16;
     cl_kernel kernel_conv_2d_f32;
     cl_kernel kernel_conv_2d_f16_f32;
@@ -1483,10 +1478,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("scale.cl");
 #endif
-        backend_ctx->program_scale =
+        cl_program prog =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

-        CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err));
+        CL_CHECK((backend_ctx->kernel_scale_f32 = clCreateKernel(prog, "kernel_scale_f32", &err), err));
+        CL_CHECK((backend_ctx->kernel_scale_f32_4 = clCreateKernel(prog, "kernel_scale_f32_4", &err), err));
+        CL_CHECK(clReleaseProgram(prog));

        GGML_LOG_CONT(".");
    }
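
Note: this commit drops the per-op cl_program members from the context struct. Each loader now builds into a local prog, creates its kernels, and releases the program right away; per the OpenCL spec, clReleaseProgram only drops a reference, and the program object stays alive until every kernel created from it is released, so the handle no longer needs to be stored. A minimal sketch of the pattern (hypothetical helper; build_program_from_source and CL_CHECK are the existing helpers in this file):

    // Build a program, extract one kernel, and immediately drop the
    // program reference. The kernel keeps the program alive, so the
    // returned cl_kernel remains valid.
    static cl_kernel build_single_kernel(cl_context context, cl_device_id device,
                                         const char * src, const char * opts,
                                         const char * name) {
        cl_int err;
        cl_program prog = build_program_from_source(context, device, src, opts);
        cl_kernel kernel = clCreateKernel(prog, name, &err);
        CL_CHECK(err);
        CL_CHECK(clReleaseProgram(prog));
        return kernel;
    }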
@@ -1814,16 +1811,11 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("repeat.cl");
 #endif
-        if (!kernel_src.empty()) {
-            backend_ctx->program_repeat =
-                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
-            CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err));
-            GGML_LOG_CONT(".");
-        } else {
-            GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n");
-            backend_ctx->program_repeat = nullptr;
-            backend_ctx->kernel_repeat = nullptr;
-        }
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_repeat_f32 = clCreateKernel(prog, "kernel_repeat_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
    }

    // pad
@@ -1856,18 +1848,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
 #else
         const std::string kernel_src = read_file("tanh.cl");
 #endif
-        if (!kernel_src.empty()) {
-            backend_ctx->program_tanh =
-                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
-            CL_CHECK((backend_ctx->kernel_tanh_f32_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f32_nd", &err), err));
-            CL_CHECK((backend_ctx->kernel_tanh_f16_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f16_nd", &err), err));
-            GGML_LOG_CONT(".");
-        } else {
-            GGML_LOG_WARN("ggml_opencl: tanh kernel source not found or empty. Tanh operation will not be available.\n");
-            backend_ctx->program_tanh = nullptr;
-            backend_ctx->kernel_tanh_f32_nd = nullptr;
-            backend_ctx->kernel_tanh_f16_nd = nullptr;
-        }
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_tanh_f32    = clCreateKernel(prog, "kernel_tanh_f32",    &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f32_4  = clCreateKernel(prog, "kernel_tanh_f32_4",  &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f32_nc = clCreateKernel(prog, "kernel_tanh_f32_nc", &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f16    = clCreateKernel(prog, "kernel_tanh_f16",    &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f16_4  = clCreateKernel(prog, "kernel_tanh_f16_4",  &err), err));
+        CL_CHECK((backend_ctx->kernel_tanh_f16_nc = clCreateKernel(prog, "kernel_tanh_f16_nc", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
    }

    // expm1
@@ -1959,22 +1949,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
             #include "concat.cl.h"
         };
 #else
         const std::string kernel_src = read_file("concat.cl");
 #endif
-        if (!kernel_src.empty()) {
-            backend_ctx->program_concat =
-                build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
-
-            CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err));
-            CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err));
-            GGML_LOG_CONT(".");
-        } else {
-            GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n");
-            backend_ctx->program_concat = nullptr;
-            backend_ctx->kernel_concat_f32_contiguous = nullptr;
-            backend_ctx->kernel_concat_f32_non_contiguous = nullptr;
-        }
+        cl_program prog =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+        CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, "kernel_concat_f32", &err), err));
+        CL_CHECK(clReleaseProgram(prog));
+        GGML_LOG_CONT(".");
    }

    // timestep_embedding
@@ -3318,8 +3299,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_UNARY_OP_SIGMOID:
             return ggml_is_contiguous(op->src[0]);
         case GGML_UNARY_OP_TANH:
-            return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
-                   (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
+            return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
         case GGML_UNARY_OP_EXPM1:
             return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
                    (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
@@ -7029,79 +7009,87 @@ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const
     ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
     ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

-    cl_ulong offset0_abs = extra0->offset + src0->view_offs;
-    cl_ulong offsetd_abs = extrad->offset + dst->view_offs;
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];

     cl_kernel kernel;
-    if (dst->type == GGML_TYPE_F32) {
-        kernel = backend_ctx->kernel_tanh_f32_nd;
-    } else if (dst->type == GGML_TYPE_F16) {
-        kernel = backend_ctx->kernel_tanh_f16_nd;
-    } else {
-        GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh");
-    }
-    GGML_ASSERT(kernel != nullptr);
-
-    const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3];
-    const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3];
-
-    const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3];
-    const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3];
-
-    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs));
-    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs));
-
-    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
-    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
-    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
-    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
-    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
-    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
-
-    size_t global_work_size[3];
-    if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
-        return;
-    }
-    global_work_size[0] = (size_t)ne10;
-    global_work_size[1] = (size_t)ne11;
-    global_work_size[2] = (size_t)ne12;
-
-    size_t lws0 = 16, lws1 = 4, lws2 = 1;
-    if (ne10 < 16) lws0 = ne10;
-    if (ne11 < 4) lws1 = ne11;
-    if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
-    while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
-    while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
-    while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
-    size_t local_work_size[] = {lws0, lws1, lws2};
-
-    size_t* local_work_size_ptr = local_work_size;
-    if (!backend_ctx->non_uniform_workgroups) {
-        if (global_work_size[0] % local_work_size[0] != 0 ||
-            global_work_size[1] % local_work_size[1] != 0 ||
-            global_work_size[2] % local_work_size[2] != 0) {
-            local_work_size_ptr = NULL;
-        }
-    }
-    if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
-
-    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+
+    if (ggml_is_contiguous(src0)) {
+        // Handle contiguous input
+        int n = ggml_nelements(dst);
+
+        if (n % 4 == 0) {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_tanh_f32_4;
+            } else {
+                kernel = backend_ctx->kernel_tanh_f16_4;
+            }
+            n /= 4;
+        } else {
+            if (src0->type == GGML_TYPE_F32) {
+                kernel = backend_ctx->kernel_tanh_f32;
+            } else {
+                kernel = backend_ctx->kernel_tanh_f16;
+            }
+        }
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+        size_t global_work_size[] = {(size_t)n, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        size_t * local_work_size_ptr = local_work_size;
+        if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
+            local_work_size_ptr = nullptr;
+        }
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
+    } else {
+        // Handle non-contiguous input
+        if (src0->type == GGML_TYPE_F32) {
+            kernel = backend_ctx->kernel_tanh_f32_nc;
+        } else {
+            kernel = backend_ctx->kernel_tanh_f16_nc;
+        }
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
+
+        int nth = 64;
+        size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
+    }
 }

 static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
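
Note: ggml_cl_tanh now dispatches on layout instead of always using the old _nd kernels. Contiguous tensors are treated as a flat array, with the _4 variants consuming float4/half4 when the element count is divisible by 4; non-contiguous tensors use the strided _nc kernels, launched with one 64-wide workgroup per ne01 row so the group ids map to (i1, i2, i3) and the local id strides over i0. A worked example of the two launch shapes (sizes are illustrative, not taken from the diff):

    // Contiguous src0, ne = {32, 8, 4, 1}: n = 1024 elements.
    // 1024 % 4 == 0, so kernel_tanh_f32_4 runs over n/4 = 256 float4 items:
    //     global = {256, 1, 1}, local = {64, 1, 1}
    //
    // Non-contiguous src0, ne = {32, 8, 4, 2}: one workgroup per row:
    //     global = {8*64, 4, 2}, local = {64, 1, 1}
    // so get_group_id(0..2) = (i1, i2, i3) and local ids 0..63 stride over i0 < 32.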
@@ -7319,53 +7307,58 @@ static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, con
     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

-    if (backend_ctx->kernel_repeat == nullptr) {
-        GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__);
-        return;
-    }
-
-    ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra;
-    ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra;
-
-    cl_ulong off_src0 = extra_src0->offset + src0->view_offs;
-    cl_ulong off_dst = extra_dst->offset + dst->view_offs;
-
-    const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3];
-    const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3];
-
-    const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3];
-    const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3];
-
-    cl_kernel kernel = backend_ctx->kernel_repeat;
-
-    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_dst->data_device));
-    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_src0));
-    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst));
-    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &src0_ne0));
-    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &src0_ne1));
-    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &src0_ne2));
-    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &src0_ne3));
-    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &src0_nb0));
-    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &src0_nb1));
-    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2));
-    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3));
-    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &dst_ne0));
-    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &dst_ne1));
-    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &dst_ne2));
-    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dst_ne3));
-    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0));
-    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1));
-    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2));
-    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3));
-
-    size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1;
-    size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1;
-    size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1;
-
-    size_t global_work_size[] = { gws0, gws1, gws2 };
-
-    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
+
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    cl_kernel kernel = backend_ctx->kernel_repeat_f32;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
+
+    int nth = 64;
+    size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }

 static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -7589,121 +7582,76 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

     ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
-    cl_command_queue queue = backend_ctx->queue;
-
-    if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) {
-        GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__);
-        return;
-    }

-    ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra;
-    ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra;
-    ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra;
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

-    cl_ulong off_src0 = extra0_cl->offset + src0->view_offs;
-    cl_ulong off_src1 = extra1_cl->offset + src1->view_offs;
-    cl_ulong off_dst = extrad_cl->offset + dst->view_offs;
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;

-    const int32_t dim = ((const int32_t *) dst->op_params)[0];
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb00 = src0->nb[0];
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const cl_ulong nb10 = src1->nb[0];
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
+
+    const cl_ulong nb0 = dst->nb[0];
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    const cl_int dim = ((const int32_t *) dst->op_params)[0];
     GGML_ASSERT(dim >= 0 && dim <= 3);

-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-        if (dim == 3) {
-            size_t nbytes_src0 = ggml_nbytes(src0);
-            size_t nbytes_src1 = ggml_nbytes(src1);
-            CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device,
-                                         off_src0, off_dst, nbytes_src0, 0, NULL, NULL));
-            CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device,
-                                         off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL));
-        } else {
-            cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous;
-            size_t global_work_size[3];
-
-            for (int i3 = 0; i3 < dst->ne[3]; ++i3) {
-                cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]);
-                cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]);
-                cl_ulong current_off_dst = off_dst + (i3 * dst->nb[3]);
-
-                int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2];
-                int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2];
-                int d_ne0 = dst->ne[0]; int d_ne1 = dst->ne[1]; int d_ne2 = dst->ne[2];
-
-                CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device));
-                CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &current_off_src0));
-                CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device));
-                CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &current_off_src1));
-                CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device));
-                CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &current_off_dst));
-                CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &d_ne00));
-                CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne01));
-                CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne02));
-                CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne10));
-                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &d_ne11));
-                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &d_ne12));
-                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0));
-                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1));
-                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2));
-                CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dim));
-
-                global_work_size[0] = d_ne0;
-                global_work_size[1] = d_ne1;
-                global_work_size[2] = d_ne2;
-
-                backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
-            }
-        }
-    } else {
-        cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous;
-
-        cl_long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
-        cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
-        cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
-        cl_long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3];
-        cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3];
-
-        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
-        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_src1));
-        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device));
-        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst));
-        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &ne00));
-        CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &ne01));
-        CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &ne02));
-        CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &ne03));
-        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
-        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
-        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
-        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
-        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));
-        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
-        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
-        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
-        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_long), &d_ne0));
-        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_long), &d_ne1));
-        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_long), &d_ne2));
-        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_long), &d_ne3));
-        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0));
-        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1));
-        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2));
-        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong), &d_nb3));
-        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &dim));
-
-        size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1,
-                                         d_ne2 > 0 ? (size_t)d_ne2 : 1,
-                                         d_ne3 > 0 ? (size_t)d_ne3 : 1 };
-
-        backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_nc, NULL, dst);
-    }
+    int nth = MIN(64, ne0);
+
+    cl_kernel kernel = backend_ctx->kernel_concat_f32;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int), &dim));
+
+    size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }

 static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {
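
Note: the rewrite collapses the three previous concat paths (the dim == 3 clEnqueueCopyBuffer fast path, per-i3-slice launches of the contiguous kernel, and the separate non-contiguous kernel) into one strided kernel that covers every case, so the host side no longer needs contiguity checks or nullptr guards. The launch assigns one workgroup per destination row; an illustrative shape (sizes are not taken from the diff):

    // dst ne = {100, 8, 4, 2}:
    //     nth    = MIN(64, 100) = 64
    //     global = {8*64, 4, 2}, local = {64, 1, 1}
    // Each workgroup walks one (i1, i2, i3) row of dst; the kernel decides
    // per element whether it reads from src0 or src1.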
@@ -8394,6 +8342,7 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t
     CL_CHECK(clReleaseMemObject(D_sub_buffer));
     CL_CHECK(clReleaseMemObject(D_image1d));
 #else
+    GGML_UNUSED(backend);
     GGML_UNUSED(src0);
     GGML_UNUSED(src1);
     GGML_UNUSED(dst);
@@ -9913,7 +9862,16 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     cl_ulong offset0 = extra0->offset + src0->view_offs;
     cl_ulong offsetd = extrad->offset + dst->view_offs;

-    cl_kernel kernel = backend_ctx->kernel_scale;
+    cl_kernel kernel;
+
+    int n = ggml_nelements(dst);
+
+    if (n % 4 == 0) {
+        kernel = backend_ctx->kernel_scale_f32_4;
+        n /= 4;
+    } else {
+        kernel = backend_ctx->kernel_scale_f32;
+    }

     CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
     CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
@@ -9922,8 +9880,6 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
     CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
     CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));

-    int n = ggml_nelements(dst)/4;
-
     size_t global_work_size[] = {(size_t)n, 1, 1};
     size_t local_work_size[] = {64, 1, 1};
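
Note: the scalar/vec4 split also tightens correctness: the old code always enqueued the float4 kernel over ggml_nelements(dst)/4 items, which only covered every element when the count was a multiple of 4. A minimal sketch of the new dispatch rule (hypothetical helper; the fields are the ones declared in this file):

    static cl_kernel pick_scale_kernel(ggml_backend_opencl_context * ctx, int * n) {
        if (*n % 4 == 0) {
            *n /= 4;                        // each work-item handles one float4
            return ctx->kernel_scale_f32_4;
        }
        return ctx->kernel_scale_f32;       // scalar path, one float per item
    }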

ggml/src/ggml-opencl/kernels/concat.cl

@@ -1,109 +1,51 @@
-kernel void kernel_concat_f32_contiguous(
-    global const char * p_src0, ulong off_src0,
-    global const char * p_src1, ulong off_src1,
-    global char * p_dst, ulong off_dst,
-    int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice
-    int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)
-    int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice
-    int dim
+kernel void kernel_concat_f32(
+    global const char * src0,
+    ulong offset0,
+    global const char * src1,
+    ulong offset1,
+    global char * dst,
+    ulong offsetd,
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne03,
+    ulong nb00,
+    ulong nb01,
+    ulong nb02,
+    ulong nb03,
+    ulong nb10,
+    ulong nb11,
+    ulong nb12,
+    ulong nb13,
+    int ne0,
+    ulong nb0,
+    ulong nb1,
+    ulong nb2,
+    ulong nb3,
+    int dim
 ) {
-    global const float * src0 = (global const float*)((global char*)p_src0 + off_src0);
-    global const float * src1 = (global const float*)((global char*)p_src1 + off_src1);
-    global float * dst = (global float*)((global char*)p_dst + off_dst);
-
-    int i0 = get_global_id(0); // Index along dst's 0th dimension
-    int i1 = get_global_id(1); // Index along dst's 1st dimension
-    int i2 = get_global_id(2); // Index along dst's 2nd dimension
-
-    if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) {
-        return;
-    }
-
-    ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0;
-    ulong src_idx;
-
-    if (dim == 0) {
-        if (i0 < d_ne00) { // Data from src0
-            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
-            dst[dst_idx] = src0[src_idx];
-        } else { // Data from src1
-            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00);
-            dst[dst_idx] = src1[src_idx];
-        }
-    } else if (dim == 1) {
-        if (i1 < d_ne01) { // Data from src0
-            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
-            dst[dst_idx] = src0[src_idx];
-        } else { // Data from src1
-            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0;
-            dst[dst_idx] = src1[src_idx];
-        }
-    } else if (dim == 2) {
-        if (i2 < d_ne02) { // Data from src0
-            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
-            dst[dst_idx] = src0[src_idx];
-        } else { // Data from src1
-            src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0;
-            dst[dst_idx] = src1[src_idx];
-        }
-    }
-}
-
-kernel void kernel_concat_f32_non_contiguous(
-    global const char * p_src0, ulong off_src0,
-    global const char * p_src1, ulong off_src1,
-    global char * p_dst, ulong off_dst,
-    long ne00, long ne01, long ne02, long ne03,
-    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
-    ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1
-    long d_ne0, long d_ne1, long d_ne2, long d_ne3,
-    ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3,
-    int dim
-) {
-    global const char * src0_base = p_src0 + off_src0;
-    global const char * src1_base = p_src1 + off_src1;
-    global char * dst_base = p_dst + off_dst;
-
-    long current_i1 = get_global_id(0); // Index for dst_dim_1
-    long current_i2 = get_global_id(1); // Index for dst_dim_2
-    long current_i3 = get_global_id(2); // Index for dst_dim_3
-
-    if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) {
-        return;
-    }
-
-    global const float * x_val_ptr;
-    global float * y_val_ptr;
-
-    for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) {
-        bool use_src0;
-        long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3;
-
-        if (dim == 0) {
-            use_src0 = (current_i0 < ne00);
-            if (!use_src0) { s_i0 = current_i0 - ne00; }
-        } else if (dim == 1) {
-            use_src0 = (current_i1 < ne01);
-            if (!use_src0) { s_i1 = current_i1 - ne01; }
-        } else if (dim == 2) {
-            use_src0 = (current_i2 < ne02);
-            if (!use_src0) { s_i2 = current_i2 - ne02; }
-        } else { // dim == 3
-            use_src0 = (current_i3 < ne03);
-            if (!use_src0) { s_i3 = current_i3 - ne03; }
-        }
-
-        if (use_src0) {
-            x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00);
-        } else {
-            x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10);
-        }
-
-        y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0);
-        *y_val_ptr = *x_val_ptr;
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    int o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+    global const float * x;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (global const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
+        } else {
+            x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+        }
+
+        global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
     }
 }
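
Note: all of the boundary bookkeeping sits in o[]: only the concat axis gets a non-zero offset, so an index past src0's extent on that axis is shifted back into src1's coordinates, while every other axis indexes src1 with the same i0/i1/i2/i3 as dst. A worked example, assuming a concat along dim 1 with ne01 = 4:

    // dim == 1, ne01 == 4  =>  o = {0, 4, 0, 0}
    //
    // dst row i1 = 2: 2 <  ne01, bounds check passes -> read src0 at i1 = 2
    // dst row i1 = 6: 6 >= ne01, check fails         -> read src1 at i1 - o[1] = 2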

ggml/src/ggml-opencl/kernels/repeat.cl

@@ -1,39 +1,38 @@
-kernel void kernel_repeat(
-    global const char * src0_data_in,
-    global char * dst_data_in,
-    ulong src0_offset,
-    ulong dst_offset,
-    int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3,
-    ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3,
-    int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3,
-    ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3
+kernel void kernel_repeat_f32(
+    global const char * src0,
+    ulong offset0,
+    global char * dst,
+    ulong offsetd,
+    int ne00,
+    int ne01,
+    int ne02,
+    int ne03,
+    ulong nb00,
+    ulong nb01,
+    ulong nb02,
+    ulong nb03,
+    int ne0,
+    ulong nb0,
+    ulong nb1,
+    ulong nb2,
+    ulong nb3
 ) {
-    global const char * src0_data = src0_data_in + src0_offset;
-    global char * dst_data = dst_data_in + dst_offset;
-
-    const int d3 = get_global_id(2);
-    const int d2 = get_global_id(1);
-    const int d1 = get_global_id(0);
-
-    if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) {
-        return;
-    }
-
-    const int s3 = d3 % src0_ne3;
-    const int s2 = d2 % src0_ne2;
-    const int s1 = d1 % src0_ne1;
-
-    const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1;
-    global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1;
-
-    for (int d0 = 0; d0 < dst_ne0; ++d0) {
-        // Determine source index for dimension 0 based on tiling/broadcasting.
-        const int s0 = d0 % src0_ne0;
-        const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0;
-        global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0;
-        for (int k = 0; k < src0_nb0; ++k) {
-            current_dst_el_ptr[k] = current_src_el_ptr[k];
-        }
+    src0 = src0 + offset0;
+    dst = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    const int i03 = i3%ne03;
+    const int i02 = i2%ne02;
+    const int i01 = i1%ne01;
+
+    global const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i00 = i0%ne00;
+        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i00*nb00));
     }
 }
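
Note: the rewritten kernel tiles src0 into dst with plain modulo indexing (i01 = i1 % ne01, and so on), matching how repeat broadcasting is defined in ggml, and copies whole float elements instead of the old per-byte inner loop. A host-side reference of the same mapping, with illustrative sizes:

    // src0 ne = {2, 3, 1, 1}, dst ne = {4, 6, 1, 1} -- a 2x2 tiling.
    for (int i1 = 0; i1 < 6; ++i1) {
        for (int i0 = 0; i0 < 4; ++i0) {
            dst[i1*4 + i0] = src0[(i1 % 3)*2 + (i0 % 2)];
        }
    }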

ggml/src/ggml-opencl/kernels/scale.cl

@@ -1,9 +1,19 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable

-//------------------------------------------------------------------------------
-// scale
-//------------------------------------------------------------------------------
-kernel void kernel_scale(
+kernel void kernel_scale_f32(
+    global float * src0,
+    ulong offset0,
+    global float * dst,
+    ulong offsetd,
+    float scale,
+    float bias
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+    dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
+}
+
+kernel void kernel_scale_f32_4(
     global float4 * src0,
     ulong offset0,
     global float4 * dst,
ggml/src/ggml-opencl/kernels/tanh.cl

@@ -1,63 +1,109 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable

-#ifdef cl_intel_required_subgroup_size
-#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
-#define INTEL_GPU 1
-#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
-#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
-#elif defined(cl_qcom_reqd_sub_group_size)
-#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
-#define ADRENO_GPU 1
-#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
-#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
-#endif
-
-kernel void kernel_tanh_f32_nd(
-    global void * p_src0_base, ulong off_src0_abs,
-    global void * p_dst_base, ulong off_dst_abs,
-    int ne00, int ne01, int ne02, int ne03,
-    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
-    int ne10, int ne11, int ne12, int ne13,
-    ulong nb10, ulong nb11, ulong nb12, ulong nb13
+kernel void kernel_tanh_f32(
+    global const float * src0,
+    ulong offset0,
+    global float * dst,
+    ulong offsetd
 ) {
-    int i0 = get_global_id(0);
-    int i1 = get_global_id(1);
-    int i2 = get_global_id(2);
-
-    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
-        for (int i3 = 0; i3 < ne13; ++i3) {
-            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
-            global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
-
-            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
-            global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
-
-            *dst_val_ptr = tanh(*src_val_ptr);
-        }
-    }
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);
+}
+
+kernel void kernel_tanh_f32_4(
+    global const float4 * src0,
+    ulong offset0,
+    global float4 * dst,
+    ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);
+}
+
+kernel void kernel_tanh_f16(
+    global const half * src0,
+    ulong offset0,
+    global half * dst,
+    ulong offsetd
+) {
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst = (global half*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);
+}
+
+kernel void kernel_tanh_f16_4(
+    global const half4 * src0,
+    ulong offset0,
+    global half4 * dst,
+    ulong offsetd
+) {
+    src0 = (global half4*)((global char*)src0 + offset0);
+    dst = (global half4*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = tanh(src0[get_global_id(0)]);
+}
+
+kernel void kernel_tanh_f32_nc(
+    global const char * src0,
+    ulong offset0,
+    global char * dst,
+    ulong offsetd,
+    int ne00,
+    ulong nb00,
+    ulong nb01,
+    ulong nb02,
+    ulong nb03,
+    ulong nb0,
+    ulong nb1,
+    ulong nb2,
+    ulong nb3
+) {
+    src0 = src0 + offset0;
+    dst = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = tanh(*x);
+    }
 }

-kernel void kernel_tanh_f16_nd(
-    global void * p_src0_base, ulong off_src0_abs,
-    global void * p_dst_base, ulong off_dst_abs,
-    int ne00, int ne01, int ne02, int ne03,
-    ulong nb00, ulong nb01, ulong nb02, ulong nb03,
-    int ne10, int ne11, int ne12, int ne13,
-    ulong nb10, ulong nb11, ulong nb12, ulong nb13
+kernel void kernel_tanh_f16_nc(
+    global const char * src0,
+    ulong offset0,
+    global char * dst,
+    ulong offsetd,
+    int ne00,
+    ulong nb00,
+    ulong nb01,
+    ulong nb02,
+    ulong nb03,
+    ulong nb0,
+    ulong nb1,
+    ulong nb2,
+    ulong nb3
 ) {
-    int i0 = get_global_id(0);
-    int i1 = get_global_id(1);
-    int i2 = get_global_id(2);
-
-    if (i0 < ne10 && i1 < ne11 && i2 < ne12) {
-        for (int i3 = 0; i3 < ne13; ++i3) {
-            ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03;
-            global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
-
-            ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13;
-            global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor);
-
-            *dst_val_ptr = tanh(*src_val_ptr);
-        }
-    }
+    src0 = src0 + offset0;
+    dst = dst + offsetd;
+
+    const int i3 = get_group_id(2);
+    const int i2 = get_group_id(1);
+    const int i1 = get_group_id(0);
+
+    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
+        global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+        global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = tanh(*x);
+    }
 }