opencl: add SOFTPLUS op support (#18726)
This commit is contained in:
parent
b137718878
commit
707cbafcaa
|
|
@ -122,6 +122,7 @@ set(GGML_OPENCL_KERNELS
|
||||||
upscale
|
upscale
|
||||||
tanh
|
tanh
|
||||||
expm1
|
expm1
|
||||||
|
softplus
|
||||||
pad
|
pad
|
||||||
repeat
|
repeat
|
||||||
mul_mat_f16_f32
|
mul_mat_f16_f32
|
||||||
|
|
|
||||||
|
|
@ -540,6 +540,8 @@ struct ggml_backend_opencl_context {
|
||||||
cl_kernel kernel_tanh_f16_nd;
|
cl_kernel kernel_tanh_f16_nd;
|
||||||
cl_kernel kernel_expm1_f32_nd;
|
cl_kernel kernel_expm1_f32_nd;
|
||||||
cl_kernel kernel_expm1_f16_nd;
|
cl_kernel kernel_expm1_f16_nd;
|
||||||
|
cl_kernel kernel_softplus_f32_nd;
|
||||||
|
cl_kernel kernel_softplus_f16_nd;
|
||||||
cl_kernel kernel_upscale;
|
cl_kernel kernel_upscale;
|
||||||
cl_kernel kernel_upscale_bilinear;
|
cl_kernel kernel_upscale_bilinear;
|
||||||
cl_kernel kernel_concat_f32_contiguous;
|
cl_kernel kernel_concat_f32_contiguous;
|
||||||
|
|
@ -1826,6 +1828,31 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
|
||||||
CL_CHECK(clReleaseProgram(prog));
|
CL_CHECK(clReleaseProgram(prog));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// softplus
|
||||||
|
{
|
||||||
|
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||||
|
const std::string kernel_src {
|
||||||
|
#include "softplus.cl.h"
|
||||||
|
};
|
||||||
|
#else
|
||||||
|
const std::string kernel_src = read_file("softplus.cl");
|
||||||
|
#endif
|
||||||
|
cl_program prog;
|
||||||
|
if (!kernel_src.empty()) {
|
||||||
|
prog =
|
||||||
|
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
|
||||||
|
CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err));
|
||||||
|
CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err));
|
||||||
|
GGML_LOG_CONT(".");
|
||||||
|
} else {
|
||||||
|
GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n");
|
||||||
|
prog = nullptr;
|
||||||
|
backend_ctx->kernel_softplus_f32_nd = nullptr;
|
||||||
|
backend_ctx->kernel_softplus_f16_nd = nullptr;
|
||||||
|
}
|
||||||
|
CL_CHECK(clReleaseProgram(prog));
|
||||||
|
}
|
||||||
|
|
||||||
// upscale
|
// upscale
|
||||||
{
|
{
|
||||||
#ifdef GGML_OPENCL_EMBED_KERNELS
|
#ifdef GGML_OPENCL_EMBED_KERNELS
|
||||||
|
|
@ -3138,6 +3165,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
|
||||||
case GGML_UNARY_OP_EXPM1:
|
case GGML_UNARY_OP_EXPM1:
|
||||||
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
|
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
|
||||||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
|
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
|
||||||
|
case GGML_UNARY_OP_SOFTPLUS:
|
||||||
|
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
|
||||||
|
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
@ -6596,6 +6626,108 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons
|
||||||
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enqueues the element-wise softplus kernel computing dst from src0.
//
//   backend : backend whose `context` is a ggml_backend_opencl_context
//   src0    : input tensor; `extra` must be a ggml_tensor_extra_cl
//   src1    : unused
//   dst     : output tensor; `extra` must be a ggml_tensor_extra_cl. Its type
//             (F32 or F16) selects the kernel; any other type asserts.
//
// The launch grid covers dst dims 0..2; the kernel itself loops over dim 3.
static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0);
    GGML_ASSERT(src0->extra);
    GGML_ASSERT(dst);
    GGML_ASSERT(dst->extra);

    UNUSED(src1);

    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;

    // Initialised so the variable is never read indeterminate, even if the
    // assertion in the default branch were compiled out.
    cl_kernel kernel = nullptr;
    switch (dst->type) {
        case GGML_TYPE_F32:
            kernel = backend_ctx->kernel_softplus_f32_nd;
            break;
        case GGML_TYPE_F16:
            kernel = backend_ctx->kernel_softplus_f16_nd;
            break;
        default:
            GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus");
    }
    GGML_ASSERT(kernel != nullptr);

    const int ne10 = dst->ne[0];
    const int ne11 = dst->ne[1];
    const int ne12 = dst->ne[2];
    const int ne13 = dst->ne[3];

    // Empty destination: nothing to compute, so skip argument setup entirely.
    if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) {
        return;
    }

    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;

    // Byte offsets of the tensor data inside their device buffers.
    cl_ulong offset0_abs = extra0->offset + src0->view_offs;
    cl_ulong offsetd_abs = extrad->offset + dst->view_offs;

    const int ne00 = src0->ne[0];
    const int ne01 = src0->ne[1];
    const int ne02 = src0->ne[2];
    const int ne03 = src0->ne[3];

    const cl_ulong nb00 = src0->nb[0];
    const cl_ulong nb01 = src0->nb[1];
    const cl_ulong nb02 = src0->nb[2];
    const cl_ulong nb03 = src0->nb[3];

    const cl_ulong nb10 = dst->nb[0];
    const cl_ulong nb11 = dst->nb[1];
    const cl_ulong nb12 = dst->nb[2];
    const cl_ulong nb13 = dst->nb[3];

    // Argument order must match the kernel_softplus_*_nd signatures.
    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0_abs));
    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extrad->data_device));
    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offsetd_abs));

    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));
    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne02));
    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));
    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb00));
    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));

    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10));
    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne11));
    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne12));
    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne13));
    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));
    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));

    // One work-item per element of dims 0..2 (all non-zero at this point).
    size_t global_work_size[3] = { (size_t)ne10, (size_t)ne11, (size_t)ne12 };

    // Preferred work-group shape is 16x4x1, clamped to the tensor extent.
    size_t local_work_size[3] = {
        ne10 < 16 ? (size_t)ne10 : (size_t)16,
        ne11 <  4 ? (size_t)ne11 : (size_t)4,
        (size_t)1,
    };

    // Without non-uniform work-group support the local size has to divide the
    // global size in every dimension; otherwise pass NULL for the local size.
    size_t * local_work_size_ptr = local_work_size;
    if (!backend_ctx->non_uniform_workgroups) {
        if (global_work_size[0] % local_work_size[0] != 0 ||
            global_work_size[1] % local_work_size[1] != 0 ||
            global_work_size[2] % local_work_size[2] != 0) {
            local_work_size_ptr = NULL;
        }
    }

    backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
|
||||||
|
|
||||||
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
|
static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) {
|
||||||
GGML_ASSERT(src0);
|
GGML_ASSERT(src0);
|
||||||
GGML_ASSERT(src0->extra);
|
GGML_ASSERT(src0->extra);
|
||||||
|
|
@ -9775,6 +9907,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
|
||||||
}
|
}
|
||||||
func = ggml_cl_expm1;
|
func = ggml_cl_expm1;
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_SOFTPLUS:
|
||||||
|
if (!any_on_device) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
func = ggml_cl_softplus;
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
} break;
|
} break;
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,88 @@
|
||||||
|
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||||
|
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// softplus
|
||||||
|
//------------------------------------------------------------------------------
|
||||||
|
// softplus(x) = log(1 + e^x), evaluated as log1p(e^-|x|) + max(x, 0) so that
// exp() is only ever called with a non-positive argument.
inline float softplus_f32(float x) {
    const float positive_part = fmax(x, 0.0f);
    const float decay         = exp(-fabs(x));
    return log1p(decay) + positive_part;
}
|
||||||
|
|
||||||
|
// Element-wise softplus on float data addressed through byte strides.
// Global ids 0..2 select the element along dims 0..2; dim 3 is walked in a
// loop by each work-item. Bounds come from the destination extents
// (ne10..ne13); ne00..ne03 are accepted but not read here.
kernel void kernel_softplus_f32_nd(
    global void * p_src0_base,
    ulong off_src0_abs,
    global void * p_dst_base,
    ulong off_dst_abs,
    int ne00,
    int ne01,
    int ne02,
    int ne03,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne10,
    int ne11,
    int ne12,
    int ne13,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    ulong nb13
) {
    const int d0 = get_global_id(0);
    const int d1 = get_global_id(1);
    const int d2 = get_global_id(2);

    // Work-items that fall outside the destination have nothing to do.
    if (d0 >= ne10 || d1 >= ne11 || d2 >= ne12) {
        return;
    }

    // Byte addresses of element (d0, d1, d2, 0) in source and destination.
    global const char * src_bytes = (global const char *)p_src0_base + off_src0_abs
                                  + (ulong)d0*nb00 + (ulong)d1*nb01 + (ulong)d2*nb02;
    global char * dst_bytes = (global char *)p_dst_base + off_dst_abs
                            + (ulong)d0*nb10 + (ulong)d1*nb11 + (ulong)d2*nb12;

    for (int d3 = 0; d3 < ne13; ++d3) {
        global const float * in  = (global const float *)(src_bytes + (ulong)d3*nb03);
        global float       * out = (global float *)(dst_bytes + (ulong)d3*nb13);

        *out = softplus_f32(*in);
    }
}
|
||||||
|
|
||||||
|
// Element-wise softplus on half data addressed through byte strides.
// Each value is widened to float, passed through softplus_f32, and narrowed
// back to half. Global ids 0..2 select the element along dims 0..2; dim 3 is
// walked in a loop by each work-item. Bounds come from the destination
// extents (ne10..ne13); ne00..ne03 are accepted but not read here.
kernel void kernel_softplus_f16_nd(
    global void * p_src0_base,
    ulong off_src0_abs,
    global void * p_dst_base,
    ulong off_dst_abs,
    int ne00,
    int ne01,
    int ne02,
    int ne03,
    ulong nb00,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    int ne10,
    int ne11,
    int ne12,
    int ne13,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    ulong nb13
) {
    const int d0 = get_global_id(0);
    const int d1 = get_global_id(1);
    const int d2 = get_global_id(2);

    // Work-items that fall outside the destination have nothing to do.
    if (d0 >= ne10 || d1 >= ne11 || d2 >= ne12) {
        return;
    }

    // Byte addresses of element (d0, d1, d2, 0) in source and destination.
    global const char * src_bytes = (global const char *)p_src0_base + off_src0_abs
                                  + (ulong)d0*nb00 + (ulong)d1*nb01 + (ulong)d2*nb02;
    global char * dst_bytes = (global char *)p_dst_base + off_dst_abs
                            + (ulong)d0*nb10 + (ulong)d1*nb11 + (ulong)d2*nb12;

    for (int d3 = 0; d3 < ne13; ++d3) {
        global const half * in  = (global const half *)(src_bytes + (ulong)d3*nb03);
        global half       * out = (global half *)(dst_bytes + (ulong)d3*nb13);

        const float widened = (float)(*in);
        *out = (half)(softplus_f32(widened));
    }
}
|
||||||
Loading…
Reference in New Issue