117 lines
3.7 KiB
Common Lisp
117 lines
3.7 KiB
Common Lisp
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
|
|
|
//------------------------------------------------------------------------------
// softplus
//------------------------------------------------------------------------------
|
|
|
|
// Element-wise softplus over a contiguous f32 buffer:
//   dst[i] = x                if x > 20
//   dst[i] = log(1 + exp(x))  otherwise
// One work-item handles one element, indexed by get_global_id(0).
// offset0 / offsetd are byte offsets applied to src0 / dst before indexing.
//
// Why log1p: for strongly negative x, exp(x) is far below float epsilon, so
// 1.0f + exp(x) rounds to exactly 1.0f and log(...) collapses to 0 even though
// softplus(x) ~= exp(x). log1p(exp(x)) keeps that small result accurate.
// Why the threshold: for x > 20, log(1 + exp(x)) = x + log(1 + exp(-x)) and
// exp(-20) ~= 2e-9 is negligible relative to x, so returning x is exact to
// float precision.
kernel void kernel_softplus_f32(
        global const float * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd
) {
    // Apply byte offsets without discarding the const qualifier on src0.
    src0 = (global const float *)((global const char *)src0 + offset0);
    dst  = (global       float *)((global       char *)dst  + offsetd);

    const size_t gid = get_global_id(0);
    const float  x   = src0[gid];

    dst[gid] = (x > 20.0f) ? x : log1p(exp(x));
}
|
|
|
|
// Element-wise softplus over a contiguous buffer viewed as float4 vectors:
//   each lane: x if x > 20, else log(1 + exp(x))
// One work-item handles one float4, indexed by get_global_id(0).
// offset0 / offsetd are byte offsets applied to src0 / dst before indexing.
// The vector ?: selects per component based on the comparison result.
//
// log1p(exp(x)) is used instead of log(1.0f + exp(x)) so strongly negative
// lanes keep the small result (~exp(x)) instead of rounding 1.0f + exp(x) to
// 1.0f and returning 0. Above 20, softplus(x) == x to float precision.
kernel void kernel_softplus_f32_4(
        global const float4 * src0,
        ulong offset0,
        global float4 * dst,
        ulong offsetd
) {
    // Apply byte offsets without discarding the const qualifier on src0.
    src0 = (global const float4 *)((global const char *)src0 + offset0);
    dst  = (global       float4 *)((global       char *)dst  + offsetd);

    const size_t gid = get_global_id(0);
    const float4 x   = src0[gid];

    dst[gid] = (x > 20.0f) ? x : log1p(exp(x));
}
|
|
|
|
// Element-wise softplus over a contiguous f16 buffer.
// Each element is widened to float, transformed, and rounded back to half with
// round-to-nearest-even (convert_half_rte):
//   x if x > 20, else log(1 + exp(x))
// One work-item handles one element, indexed by get_global_id(0).
// offset0 / offsetd are byte offsets applied to src0 / dst before indexing.
//
// log1p(exp(x)) replaces log(1.0f + exp(x)): for strongly negative x the latter
// rounds 1.0f + exp(x) to 1.0f (or a coarse neighbour), which is visible in the
// half subnormal range; log1p keeps the small result (~exp(x)) accurate.
kernel void kernel_softplus_f16(
        global const half * src0,
        ulong offset0,
        global half * dst,
        ulong offsetd
) {
    // Apply byte offsets without discarding the const qualifier on src0.
    src0 = (global const half *)((global const char *)src0 + offset0);
    dst  = (global       half *)((global       char *)dst  + offsetd);

    const size_t gid = get_global_id(0);
    const float  x   = convert_float(src0[gid]);

    dst[gid] = convert_half_rte((x > 20.0f) ? x : log1p(exp(x)));
}
|
|
|
|
// Element-wise softplus over a contiguous buffer viewed as half4 vectors.
// Each vector is widened to float4, transformed per lane, and rounded back with
// round-to-nearest-even (convert_half4_rte):
//   each lane: x if x > 20, else log(1 + exp(x))
// One work-item handles one half4, indexed by get_global_id(0).
// offset0 / offsetd are byte offsets applied to src0 / dst before indexing.
//
// log1p(exp(x)) replaces log(1.0f + exp(x)) so strongly negative lanes keep the
// small result (~exp(x)) instead of losing it to 1.0f + exp(x) rounding.
kernel void kernel_softplus_f16_4(
        global const half4 * src0,
        ulong offset0,
        global half4 * dst,
        ulong offsetd
) {
    // Apply byte offsets without discarding the const qualifier on src0.
    src0 = (global const half4 *)((global const char *)src0 + offset0);
    dst  = (global       half4 *)((global       char *)dst  + offsetd);

    const size_t gid = get_global_id(0);
    const float4 x   = convert_float4(src0[gid]);

    dst[gid] = convert_half4_rte((x > 20.0f) ? x : log1p(exp(x)));
}
|
|
|
|
// Element-wise softplus for f32 tensors with arbitrary (non-contiguous) byte
// strides: x if x > 20, else log(1 + exp(x)).
//
// Work decomposition (from the indexing below): work-group (i1, i2, i3) =
// (group_id 0, 1, 2) selects one row; the work-items of that group walk the
// row's ne00 elements, each starting at its local id and stepping by the local
// size.
//
// nb00..nb03 are the source byte strides and nb0..nb3 the destination byte
// strides (they are added to char pointers). offset0 / offsetd are byte
// offsets applied to the base pointers.
//
// log1p(exp(x)) is used instead of log(1.0f + exp(x)) so strongly negative
// inputs keep the small result (~exp(x)) rather than collapsing to 0.
kernel void kernel_softplus_f32_nc(
        global const char * src0,
        ulong offset0,
        global char * dst,
        ulong offsetd,
        int   ne00,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = src0 + offset0;
    dst  = dst  + offsetd;

    const int i3 = get_group_id(2);
    const int i2 = get_group_id(1);
    const int i1 = get_group_id(0);

    // Row base addresses are loop-invariant; compute them once.
    global const char * src_row = src0 + i3*nb03 + i2*nb02 + i1*nb01;
    global       char * dst_row = dst  + i3*nb3  + i2*nb2  + i1*nb1;

    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
        const float x = *(global const float *)(src_row + i0*nb00);

        *(global float *)(dst_row + i0*nb0) = (x > 20.0f) ? x : log1p(exp(x));
    }
}
|
|
|
|
// Element-wise softplus for f16 tensors with arbitrary (non-contiguous) byte
// strides. Each element is widened to float, transformed
// (x if x > 20, else log(1 + exp(x))), and rounded back to half with
// round-to-nearest-even (convert_half_rte).
//
// Work decomposition (from the indexing below): work-group (i1, i2, i3) =
// (group_id 0, 1, 2) selects one row; the work-items of that group walk the
// row's ne00 elements, each starting at its local id and stepping by the local
// size.
//
// nb00..nb03 are the source byte strides and nb0..nb3 the destination byte
// strides (they are added to char pointers). offset0 / offsetd are byte
// offsets applied to the base pointers.
//
// log1p(exp(x)) is used instead of log(1.0f + exp(x)) so strongly negative
// inputs keep the small result (~exp(x)) rather than collapsing to 0.
kernel void kernel_softplus_f16_nc(
        global const char * src0,
        ulong offset0,
        global char * dst,
        ulong offsetd,
        int   ne00,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = src0 + offset0;
    dst  = dst  + offsetd;

    const int i3 = get_group_id(2);
    const int i2 = get_group_id(1);
    const int i1 = get_group_id(0);

    // Row base addresses are loop-invariant; compute them once.
    global const char * src_row = src0 + i3*nb03 + i2*nb02 + i1*nb01;
    global       char * dst_row = dst  + i3*nb3  + i2*nb2  + i1*nb1;

    for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) {
        const float x = convert_float(*(global const half *)(src_row + i0*nb00));

        *(global half *)(dst_row + i0*nb0) =
            convert_half_rte((x > 20.0f) ? x : log1p(exp(x)));
    }
}
|