opencl: refactor softplus

This commit is contained in:
shaoqi 2026-02-02 16:17:25 -08:00 committed by Li He
parent 0b25212801
commit 8efc26bc0a
2 changed files with 166 additions and 149 deletions

View File

@ -546,8 +546,8 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc; cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc;
cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc; cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc;
cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc; cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc;
cl_kernel kernel_softplus_f32_nd; cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc;
cl_kernel kernel_softplus_f16_nd; cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc;
cl_kernel kernel_upscale; cl_kernel kernel_upscale;
cl_kernel kernel_upscale_bilinear; cl_kernel kernel_upscale_bilinear;
cl_kernel kernel_concat_f32; cl_kernel kernel_concat_f32;
@ -1890,20 +1890,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
#else #else
const std::string kernel_src = read_file("softplus.cl"); const std::string kernel_src = read_file("softplus.cl");
#endif #endif
cl_program prog; cl_program prog =
if (!kernel_src.empty()) { build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
prog = CL_CHECK((backend_ctx->kernel_softplus_f32 = clCreateKernel(prog, "kernel_softplus_f32", &err), err));
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_softplus_f32_4 = clCreateKernel(prog, "kernel_softplus_f32_4", &err), err));
CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err)); CL_CHECK((backend_ctx->kernel_softplus_f32_nc = clCreateKernel(prog, "kernel_softplus_f32_nc", &err), err));
CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err)); CL_CHECK((backend_ctx->kernel_softplus_f16 = clCreateKernel(prog, "kernel_softplus_f16", &err), err));
GGML_LOG_CONT("."); CL_CHECK((backend_ctx->kernel_softplus_f16_4 = clCreateKernel(prog, "kernel_softplus_f16_4", &err), err));
} else { CL_CHECK((backend_ctx->kernel_softplus_f16_nc = clCreateKernel(prog, "kernel_softplus_f16_nc", &err), err));
GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n");
prog = nullptr;
backend_ctx->kernel_softplus_f32_nd = nullptr;
backend_ctx->kernel_softplus_f16_nd = nullptr;
}
CL_CHECK(clReleaseProgram(prog)); CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
} }
// upscale // upscale
@ -3299,8 +3295,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
case GGML_UNARY_OP_EXPM1: case GGML_UNARY_OP_EXPM1:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_SOFTPLUS:
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
default: default:
return false; return false;
} }
@ -7196,18 +7191,8 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0_abs = extra0->offset + src0->view_offs; cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd_abs = extrad->offset + dst->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel;
if (dst->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_softplus_f32_nd;
} else if (dst->type == GGML_TYPE_F16) {
kernel = backend_ctx->kernel_softplus_f16_nd;
} else {
GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus");
}
GGML_ASSERT(kernel != nullptr);
const int ne00 = src0->ne[0]; const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1]; const int ne01 = src0->ne[1];
@ -7219,70 +7204,74 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c
const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3]; const cl_ulong nb03 = src0->nb[3];
const int ne10 = dst->ne[0]; const cl_ulong nb0 = dst->nb[0];
const int ne11 = dst->ne[1]; const cl_ulong nb1 = dst->nb[1];
const int ne12 = dst->ne[2]; const cl_ulong nb2 = dst->nb[2];
const int ne13 = dst->ne[3]; const cl_ulong nb3 = dst->nb[3];
const cl_ulong nb10 = dst->nb[0]; cl_kernel kernel;
const cl_ulong nb11 = dst->nb[1];
const cl_ulong nb12 = dst->nb[2];
const cl_ulong nb13 = dst->nb[3];
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); if (ggml_is_contiguous(src0)) {
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); // Handle contiguous input
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); int n = ggml_nelements(dst);
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); if (n % 4 == 0) {
if (src0->type == GGML_TYPE_F32) {
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); kernel = backend_ctx->kernel_softplus_f32_4;
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); } else {
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); kernel = backend_ctx->kernel_softplus_f16_4;
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); }
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); n /= 4;
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); } else {
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); if (src0->type == GGML_TYPE_F32) {
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); kernel = backend_ctx->kernel_softplus_f32;
} else {
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); kernel = backend_ctx->kernel_softplus_f16;
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); }
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13));
size_t global_work_size[3];
if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements
return;
}
global_work_size[0] = (size_t)ne10;
global_work_size[1] = (size_t)ne11;
global_work_size[2] = (size_t)ne12;
size_t lws0 = 16, lws1 = 4, lws2 = 1;
if (ne10 < 16) lws0 = ne10;
if (ne11 < 4) lws1 = ne11;
if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1;
while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2;
while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2;
size_t local_work_size[] = {lws0, lws1, lws2};
size_t* local_work_size_ptr = local_work_size;
if (!backend_ctx->non_uniform_workgroups) {
if (global_work_size[0] % local_work_size[0] != 0 ||
global_work_size[1] % local_work_size[1] != 0 ||
global_work_size[2] % local_work_size[2] != 0) {
local_work_size_ptr = NULL;
} }
}
if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return;
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
} else {
// Handle non-contiguous input
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_softplus_f32_nc;
} else {
kernel = backend_ctx->kernel_softplus_f16_nc;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3));
int nth = 64;
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
} }
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) { static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {

View File

@ -3,86 +3,114 @@
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
// softplus // softplus
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
inline float softplus_f32(float x){
float ax = fabs(x); kernel void kernel_softplus_f32(
float m = fmax(x, 0.0f); global const float * src0,
return log1p(exp(-ax)) + m; ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
} }
kernel void kernel_softplus_f32_nd( kernel void kernel_softplus_f32_4(
global void * p_src0_base, global const float4 * src0,
ulong off_src0_abs, ulong offset0,
global void * p_dst_base, global float4 * dst,
ulong off_dst_abs, ulong offsetd
int ne00, ) {
int ne01, src0 = (global float4*)((global char*)src0 + offset0);
int ne02, dst = (global float4*)((global char*)dst + offsetd);
int ne03,
dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)]));
}
kernel void kernel_softplus_f16(
global const half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
const float x = convert_float(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f16_4(
global const half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
const float4 x = convert_float4(src0[get_global_id(0)]);
dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
}
kernel void kernel_softplus_f32_nc(
global const char * src0,
ulong offset0,
global char * dst,
ulong offsetd,
int ne00,
ulong nb00, ulong nb00,
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03, ulong nb03,
int ne10, ulong nb0,
int ne11, ulong nb1,
int ne12, ulong nb2,
int ne13, ulong nb3
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
) { ) {
int i0 = get_global_id(0); src0 = src0 + offset0;
int i1 = get_global_id(1); dst = dst + offsetd;
int i2 = get_global_id(2);
if (i0 < ne10 && i1 < ne11 && i2 < ne12) { const int i3 = get_group_id(2);
for (int i3 = 0; i3 < ne13; ++i3) { const int i2 = get_group_id(1);
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; const int i1 = get_group_id(0);
global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = softplus_f32(*src_val_ptr); *y = (*x > 20.0f) ? *x : log(1.0f + exp(*x));
}
} }
} }
kernel void kernel_softplus_f16_nd( kernel void kernel_softplus_f16_nc(
global void * p_src0_base, global const char * src0,
ulong off_src0_abs, ulong offset0,
global void * p_dst_base, global char * dst,
ulong off_dst_abs, ulong offsetd,
int ne00, int ne00,
int ne01,
int ne02,
int ne03,
ulong nb00, ulong nb00,
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03, ulong nb03,
int ne10, ulong nb0,
int ne11, ulong nb1,
int ne12, ulong nb2,
int ne13, ulong nb3
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb13
) { ) {
int i0 = get_global_id(0); src0 = src0 + offset0;
int i1 = get_global_id(1); dst = dst + offsetd;
int i2 = get_global_id(2);
if (i0 < ne10 && i1 < ne11 && i2 < ne12) { const int i3 = get_group_id(2);
for (int i3 = 0; i3 < ne13; ++i3) { const int i2 = get_group_id(1);
ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; const int i1 = get_group_id(0);
global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor);
ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
*dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr))); const float x = convert_float(*hx);
} *hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x)));
} }
} }