opencl: refactor concat

This commit is contained in:
Li He 2026-01-29 16:35:15 -08:00
parent 3dd95914d0
commit 93b642ea44
2 changed files with 107 additions and 220 deletions

View File

@ -466,7 +466,6 @@ struct ggml_backend_opencl_context {
cl_program program_pad;
cl_program program_tanh;
cl_program program_upscale;
cl_program program_concat;
cl_program program_conv_2d_f16;
cl_program program_conv_2d_f32;
cl_program program_conv_2d_f16_f32;
@ -554,8 +553,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_softplus_f16_nd;
cl_kernel kernel_upscale;
cl_kernel kernel_upscale_bilinear;
cl_kernel kernel_concat_f32_contiguous;
cl_kernel kernel_concat_f32_non_contiguous;
cl_kernel kernel_concat_f32;
cl_kernel kernel_conv_2d_f16;
cl_kernel kernel_conv_2d_f32;
cl_kernel kernel_conv_2d_f16_f32;
@ -1959,22 +1957,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
#include "concat.cl.h"
};
#else
const std::string kernel_src = read_file("concat.cl");
#endif
if (!kernel_src.empty()) {
backend_ctx->program_concat =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err));
CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err));
GGML_LOG_CONT(".");
} else {
GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n");
backend_ctx->program_concat = nullptr;
backend_ctx->kernel_concat_f32_contiguous = nullptr;
backend_ctx->kernel_concat_f32_non_contiguous = nullptr;
}
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, "kernel_concat_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// timestep_embedding
@ -7591,119 +7580,75 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
cl_command_queue queue = backend_ctx->queue;
if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) {
GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__);
return;
}
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_ulong off_src0 = extra0_cl->offset + src0->view_offs;
cl_ulong off_src1 = extra1_cl->offset + src1->view_offs;
cl_ulong off_dst = extrad_cl->offset + dst->view_offs;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const int32_t dim = ((const int32_t *) dst->op_params)[0];
const cl_ulong nb00 = src0->nb[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const cl_ulong nb10 = src1->nb[0];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb13 = src1->nb[3];
const int ne0 = dst->ne[0];
const int ne1 = dst->ne[1];
const int ne2 = dst->ne[2];
const int ne3 = dst->ne[3];
const cl_ulong nb0 = dst->nb[0];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
const cl_int dim = ((const int32_t *) dst->op_params)[0];
GGML_ASSERT(dim >= 0 && dim <= 3);
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
if (dim == 3) {
int nth = MIN(64, ne0);
size_t nbytes_src0 = ggml_nbytes(src0);
size_t nbytes_src1 = ggml_nbytes(src1);
cl_kernel kernel = backend_ctx->kernel_concat_f32;
CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device,
off_src0, off_dst, nbytes_src0, 0, NULL, NULL));
CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device,
off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL));
} else {
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int), &dim));
cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous;
size_t global_work_size[3];
size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3};
size_t local_work_size[] = {(size_t)nth, 1, 1};
for (int i3 = 0; i3 < dst->ne[3]; ++i3) {
cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]);
cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]);
cl_ulong current_off_dst = off_dst + (i3 * dst->nb[3]);
int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2];
int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2];
int d_ne0 = dst->ne[0]; int d_ne1 = dst->ne[1]; int d_ne2 = dst->ne[2];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &current_off_src0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &current_off_src1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &current_off_dst));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &d_ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne10));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &d_ne11));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &d_ne12));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dim));
global_work_size[0] = d_ne0;
global_work_size[1] = d_ne1;
global_work_size[2] = d_ne2;
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
}
}
} else {
cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous;
cl_long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3];
cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3];
cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3];
cl_long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3];
cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_src1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_long), &d_ne0));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_long), &d_ne1));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_long), &d_ne2));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_long), &d_ne3));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1));
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2));
CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong), &d_nb3));
CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &dim));
size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1,
d_ne2 > 0 ? (size_t)d_ne2 : 1,
d_ne3 > 0 ? (size_t)d_ne3 : 1 };
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_nc, NULL, dst);
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) {

View File

@ -1,109 +1,51 @@
kernel void kernel_concat_f32_contiguous(
global const char * p_src0, ulong off_src0,
global const char * p_src1, ulong off_src1,
global char * p_dst, ulong off_dst,
int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice
int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)
int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice
int dim
kernel void kernel_concat_f32(
global const char * src0,
ulong offset0,
global const char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
ulong nb0,
ulong nb1,
ulong nb2,
ulong nb3,
int dim
) {
global const float * src0 = (global const float*)((global char*)p_src0 + off_src0);
global const float * src1 = (global const float*)((global char*)p_src1 + off_src1);
global float * dst = (global float*)((global char*)p_dst + off_dst);
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i0 = get_global_id(0); // Index along dst's 0th dimension
int i1 = get_global_id(1); // Index along dst's 1st dimension
int i2 = get_global_id(2); // Index along dst's 2nd dimension
const int i3 = get_group_id(2);
const int i2 = get_group_id(1);
const int i1 = get_group_id(0);
if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) {
return;
}
int o[4] = {0, 0, 0, 0};
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0;
ulong src_idx;
global const float * x;
if (dim == 0) {
if (i0 < d_ne00) { // Data from src0
src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
dst[dst_idx] = src0[src_idx];
} else { // Data from src1
src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00);
dst[dst_idx] = src1[src_idx];
}
} else if (dim == 1) {
if (i1 < d_ne01) { // Data from src0
src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
dst[dst_idx] = src0[src_idx];
} else { // Data from src1
src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0;
dst[dst_idx] = src1[src_idx];
}
} else if (dim == 2) {
if (i2 < d_ne02) { // Data from src0
src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;
dst[dst_idx] = src0[src_idx];
} else { // Data from src1
src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0;
dst[dst_idx] = src1[src_idx];
}
}
}
kernel void kernel_concat_f32_non_contiguous(
global const char * p_src0, ulong off_src0,
global const char * p_src1, ulong off_src1,
global char * p_dst, ulong off_dst,
long ne00, long ne01, long ne02, long ne03,
ulong nb00, ulong nb01, ulong nb02, ulong nb03,
ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1
long d_ne0, long d_ne1, long d_ne2, long d_ne3,
ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3,
int dim
) {
global const char * src0_base = p_src0 + off_src0;
global const char * src1_base = p_src1 + off_src1;
global char * dst_base = p_dst + off_dst;
long current_i1 = get_global_id(0); // Index for dst_dim_1
long current_i2 = get_global_id(1); // Index for dst_dim_2
long current_i3 = get_global_id(2); // Index for dst_dim_3
if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) {
return;
}
global const float * x_val_ptr;
global float * y_val_ptr;
for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) {
bool use_src0;
long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3;
if (dim == 0) {
use_src0 = (current_i0 < ne00);
if (!use_src0) { s_i0 = current_i0 - ne00; }
} else if (dim == 1) {
use_src0 = (current_i1 < ne01);
if (!use_src0) { s_i1 = current_i1 - ne01; }
} else if (dim == 2) {
use_src0 = (current_i2 < ne02);
if (!use_src0) { s_i2 = current_i2 - ne02; }
} else { // dim == 3
use_src0 = (current_i3 < ne03);
if (!use_src0) { s_i3 = current_i3 - ne03; }
}
if (use_src0) {
x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00);
for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
x = (global const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
} else {
x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10);
x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
}
y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0);
*y_val_ptr = *x_val_ptr;
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*y = *x;
}
}