ggml vulkan: add hardsigmoid and hardswish operations (#15762)
This commit is contained in:
parent
661ae31c9c
commit
0014fb4add
|
|
@ -529,6 +529,8 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_relu[2];
|
vk_pipeline pipeline_relu[2];
|
||||||
vk_pipeline pipeline_tanh[2];
|
vk_pipeline pipeline_tanh[2];
|
||||||
vk_pipeline pipeline_sigmoid[2];
|
vk_pipeline pipeline_sigmoid[2];
|
||||||
|
vk_pipeline pipeline_hardsigmoid[2];
|
||||||
|
vk_pipeline pipeline_hardswish[2];
|
||||||
|
|
||||||
vk_pipeline pipeline_geglu[2];
|
vk_pipeline pipeline_geglu[2];
|
||||||
vk_pipeline pipeline_reglu[2];
|
vk_pipeline pipeline_reglu[2];
|
||||||
|
|
@ -3261,6 +3263,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_UNARY(relu)
|
CREATE_UNARY(relu)
|
||||||
CREATE_UNARY(tanh)
|
CREATE_UNARY(tanh)
|
||||||
CREATE_UNARY(sigmoid)
|
CREATE_UNARY(sigmoid)
|
||||||
|
CREATE_UNARY(hardsigmoid)
|
||||||
|
CREATE_UNARY(hardswish)
|
||||||
#undef CREATE_UNARY
|
#undef CREATE_UNARY
|
||||||
|
|
||||||
#define CREATE_GLU(name) \
|
#define CREATE_GLU(name) \
|
||||||
|
|
@ -7533,6 +7537,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
|
return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
|
return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
|
return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
|
return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -10201,6 +10209,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -10571,6 +10581,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
|
ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|
@ -10813,6 +10825,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
buf = tensor->buffer;
|
buf = tensor->buffer;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|
@ -11764,6 +11778,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_UNARY_OP_RELU:
|
case GGML_UNARY_OP_RELU:
|
||||||
case GGML_UNARY_OP_TANH:
|
case GGML_UNARY_OP_TANH:
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
return ggml_is_contiguous(op->src[0]) &&
|
return ggml_is_contiguous(op->src[0]) &&
|
||||||
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||||
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||||
|
|
@ -12580,6 +12596,12 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||||
case GGML_UNARY_OP_SIGMOID:
|
case GGML_UNARY_OP_SIGMOID:
|
||||||
tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
|
tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
|
||||||
break;
|
break;
|
||||||
|
case GGML_UNARY_OP_HARDSIGMOID:
|
||||||
|
tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
|
||||||
|
break;
|
||||||
|
case GGML_UNARY_OP_HARDSWISH:
|
||||||
|
tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

// Elementwise hardsigmoid: y = clamp((x + 3) / 6, 0, 1).
void main() {
    // Flatten the 3D dispatch into one linear element index
    // (512 invocations per workgroup, 512 * 512 per z-slice).
    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    // Tail invocations past the tensor's element count do nothing.
    if (idx >= p.KX) {
        return;
    }

    const float val = float(data_a[idx]);
    // clamp(v, lo, hi) is specified as min(max(v, lo), hi), so this is
    // bit-identical to the explicit min/max formulation.
    data_d[idx] = D_TYPE(clamp((val + 3.0f) / 6.0f, 0.0f, 1.0f));
}
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

// Elementwise hardswish: y = x * clamp((x + 3) / 6, 0, 1)
// (i.e. x multiplied by its hardsigmoid).
void main() {
    // Flatten the 3D dispatch into one linear element index
    // (512 invocations per workgroup, 512 * 512 per z-slice).
    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    // Tail invocations past the tensor's element count do nothing.
    if (idx >= p.KX) {
        return;
    }

    const float val = float(data_a[idx]);
    // clamp(v, lo, hi) is specified as min(max(v, lo), hi), so this is
    // bit-identical to the explicit min/max formulation.
    data_d[idx] = D_TYPE(val * clamp((val + 3.0f) / 6.0f, 0.0f, 1.0f));
}
|
||||||
|
|
@ -657,6 +657,10 @@ void process_shaders() {
|
||||||
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
|
||||||
for (auto rte : {false, true}) {
|
for (auto rte : {false, true}) {
|
||||||
std::string suffix = rte ? "_rte" : "";
|
std::string suffix = rte ? "_rte" : "";
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue