diff --git a/docs/ops.md b/docs/ops.md index cab0b37a77..019b087cef 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -47,6 +47,7 @@ Legend: | FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | | FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | +| GATED_DELTA_NET | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | | GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | @@ -92,7 +93,7 @@ Legend: | SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ | | SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | -| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | +| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | diff --git a/docs/ops/Vulkan.csv b/docs/ops/Vulkan.csv index 011dffcd8e..e693154968 100644 --- a/docs/ops/Vulkan.csv +++ b/docs/ops/Vulkan.csv @@ -1,8 +1,8 @@ "backend_name","op_name","op_params","test_mode","supported","error_message","backend_reg_name" "Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" @@ -85,8 +85,8 @@ "Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" -"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan" -"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan" +"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" +"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" "Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan" "Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan" @@ -13591,3 +13591,16 @@ "Vulkan0","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","Vulkan" "Vulkan0","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" "Vulkan0","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","0","no","Vulkan" +"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","0","no","Vulkan" diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 70e992ce23..61d112c50a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -763,6 +763,7 @@ struct vk_device_struct { vk_pipeline pipeline_ceil[2]; vk_pipeline pipeline_floor[2]; vk_pipeline pipeline_trunc[2]; + vk_pipeline pipeline_sgn[2]; vk_pipeline pipeline_add1_f16_f16; vk_pipeline pipeline_add1_f16_f32; @@ -4393,6 +4394,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(ceil) CREATE_UNARY(floor) CREATE_UNARY(trunc) + CREATE_UNARY(sgn) #undef CREATE_UNARY #define CREATE_UNARY_RTE(name) \ @@ -9281,6 +9283,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_TRUNC: return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SGN: + return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16]; default: break; } @@ -12875,6 +12879,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_SGN: ggml_vk_unary(ctx, compute_ctx, src0, node); break; case GGML_UNARY_OP_XIELU: @@ -15004,6 +15009,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_SGN: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -16170,6 +16176,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_TRUNC: tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_SGN: + tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp new file mode 100644 index 0000000000..a9c147bf9a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + data_d[i] = D_TYPE(sign(float(data_a[i]))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ed077dfb6c..fb8941232b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -871,6 +871,8 @@ void process_shaders() { string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sgn_f16", "sgn.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sgn_f32", "sgn.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});