ggml-vulkan: add SGN operator, auto-generate Vulkan.csv and ops.md (#20219)

This commit is contained in:
Bertay Eren 2026-03-09 09:24:16 +03:00 committed by GitHub
parent b2f460bd3c
commit 0beb8db3a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 51 additions and 5 deletions

View File

@ -47,6 +47,7 @@ Legend:
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_DELTA_NET | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
@ -92,7 +93,7 @@ Legend:
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SET_ROWS | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | | ✅ | ❌ | ❌ |
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |

View File

@ -1,8 +1,8 @@
"backend_name","op_name","op_params","test_mode","supported","error_message","backend_reg_name"
"Vulkan0","ABS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","ABS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
"Vulkan0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","STEP","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
@ -85,8 +85,8 @@
"Vulkan0","TRUNC","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","Vulkan"
"Vulkan0","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","Vulkan"
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","Vulkan"
"Vulkan0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
"Vulkan0","NEG","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","Vulkan"
"Vulkan0","STEP","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","Vulkan"
@ -13591,3 +13591,16 @@
"Vulkan0","CROSS_ENTROPY_LOSS_BACK","type=f32,ne=[30000,1,1,1]","support","0","no","Vulkan"
"Vulkan0","OPT_STEP_ADAMW","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","OPT_STEP_SGD","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=32,head_size=128,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=16,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=1,kda=0","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=1,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=32,n_seq_tokens=4,n_seqs=1,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=8,head_size=32,n_seq_tokens=4,n_seqs=2,v_repeat=2,permuted=0,kda=1","support","0","no","Vulkan"
"Vulkan0","GATED_DELTA_NET","type=f32,head_count=4,head_size=64,n_seq_tokens=4,n_seqs=2,v_repeat=1,permuted=1,kda=1","support","0","no","Vulkan"

Can't render this file because it is too large.

View File

@ -763,6 +763,7 @@ struct vk_device_struct {
vk_pipeline pipeline_ceil[2];
vk_pipeline pipeline_floor[2];
vk_pipeline pipeline_trunc[2];
vk_pipeline pipeline_sgn[2];
vk_pipeline pipeline_add1_f16_f16;
vk_pipeline pipeline_add1_f16_f32;
@ -4393,6 +4394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_UNARY(ceil)
CREATE_UNARY(floor)
CREATE_UNARY(trunc)
CREATE_UNARY(sgn)
#undef CREATE_UNARY
#define CREATE_UNARY_RTE(name) \
@ -9281,6 +9283,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_TRUNC:
return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
case GGML_UNARY_OP_SGN:
return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16];
default:
break;
}
@ -12875,6 +12879,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
case GGML_UNARY_OP_SGN:
ggml_vk_unary(ctx, compute_ctx, src0, node);
break;
case GGML_UNARY_OP_XIELU:
@ -15004,6 +15009,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_TRUNC:
case GGML_UNARY_OP_SGN:
return ggml_is_contiguous(op->src[0]) &&
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@ -16170,6 +16176,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_UNARY_OP_TRUNC:
tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
break;
case GGML_UNARY_OP_SGN:
tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]);
break;
default:
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
GGML_ABORT("fatal error");

View File

@ -0,0 +1,21 @@
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(sign(float(data_a[i])));
}

View File

@ -871,6 +871,8 @@ void process_shaders() {
string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sgn_f16", "sgn.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sgn_f32", "sgn.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});