vulkan: support GGML_OP_SET (#19584)

Jeff Bolz 2026-02-13 21:36:38 -08:00 committed by GitHub
parent 2dec548094
commit 53aef25a88
2 changed files with 25 additions and 4 deletions
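
GGML_OP_SET writes src1 into a strided window of src0 and passes the rest of src0 through to dst; the window is described by byte strides nb1/nb2/nb3 and a byte offset stored in the node's op_params. Rather than adding a new shader, the commit reuses the acc_f32 compute shader and switches between accumulate (ACC) and overwrite (SET) behavior with a specialization constant. As a refresher on the op's semantics, a minimal CPU-side sketch using the public ggml API (shapes, context size, and thread count here are arbitrary choices):

    // Write a 4x4 tensor b into an 8x8 tensor a, one row down and one column in.
    #include "ggml.h"
    #include "ggml-cpu.h"

    int main() {
        struct ggml_init_params params = { /*mem_size*/ 16u*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8);
        struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);

        // nb1/nb2/nb3 are byte strides of the destination window; the last
        // argument is a byte offset into a: one row (a->nb[1]) plus one column (a->nb[0])
        struct ggml_tensor * r = ggml_set(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], a->nb[1] + a->nb[0]);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, r);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);
        ggml_free(ctx);
    }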

ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -688,6 +688,7 @@ struct vk_device_struct {
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32;
+vk_pipeline pipeline_set_f32;
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
vk_pipeline pipeline_add[2][2][2];
@@ -4182,7 +4183,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
-ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
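
The new trailing {0, 1} / {0, 0} arguments are SPIR-V specialization-constant values for the shared acc_f32 shader; ggml_vk_create_pipeline maps them positionally onto constant_id 0, 1, ..., so the second value lands in the ACC flag declared in the shader below, and set_f32 is simply the same shader specialized with ACC = false. For readers unfamiliar with the mechanism, a minimal raw-Vulkan sketch of how such a flag is fed at pipeline creation (illustrative only, not the helper's actual internals):

    #include <vulkan/vulkan.h>
    #include <cstdint>

    // Describe two 32-bit spec-constant slots; the second is the ACC flag.
    // 'values' must stay alive until vkCreateComputePipelines consumes it.
    VkSpecializationInfo make_spec_info(const uint32_t (&values)[2]) {
        static const VkSpecializationMapEntry entries[2] = {
            { /*constantID*/ 0, /*offset*/ 0,                /*size*/ sizeof(uint32_t) },
            { /*constantID*/ 1, /*offset*/ sizeof(uint32_t), /*size*/ sizeof(uint32_t) },
        };
        VkSpecializationInfo info = {};
        info.mapEntryCount = 2;
        info.pMapEntries   = entries;
        info.dataSize      = sizeof(values);
        info.pData         = values;
        return info;
    }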
@@ -8822,6 +8824,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_acc_f32;
}
return nullptr;
+case GGML_OP_SET:
+if (src0->type == src1->type && src0->type == dst->type &&
+    (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
+    return ctx->device->pipeline_set_f32;
+}
+return nullptr;
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
@@ -9813,7 +9821,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
int nb3 = dst->op_params[2] / src0_type_size; // byte stride -> element stride
int offset = dst->op_params[3] / src0_type_size; // byte offset -> element offset
-ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
+ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
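
The op_params on an ACC/SET node hold byte strides and a byte offset (that is what ggml_set and ggml_acc store there), while the shader indexes in elements, hence the divisions by the element size above. A worked example, assuming a contiguous 8x8x8 F32 src0 (all numbers illustrative):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t src0_type_size = 4;                     // sizeof(float)
        const int32_t  op_params[4]   = { 32, 256, 2048, 36 }; // byte strides + byte offset

        int nb1    = op_params[0] / src0_type_size;  // 8   elements per row
        int nb2    = op_params[1] / src0_type_size;  // 64  elements per 8x8 plane
        int nb3    = op_params[2] / src0_type_size;  // 512 elements per 8x8x8 volume
        int offset = op_params[3] / src0_type_size;  // 9   elements: one row + one column in

        assert(nb1 == 8 && nb2 == 64 && nb3 == 512 && offset == 9);
    }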
@@ -12507,6 +12515,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_ACC:
+case GGML_OP_SET:
ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
break;
@@ -14967,7 +14976,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
-return op->src[0]->type == GGML_TYPE_F32;
+return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+case GGML_OP_SET:
+return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
+    (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
case GGML_OP_CONCAT:
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
case GGML_OP_ADD1:
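
Note the two rules diverge: ACC tightens to all-F32 (it performs floating-point addition), while the new SET case also admits I32 as long as src0, src1, and dst types all agree, presumably safe because SET only copies values rather than combining them. The acceptance rule restated as a standalone predicate (a sketch mirroring the hunk above, not a public API):

    #include "ggml.h"

    // What the Vulkan backend now accepts for GGML_OP_SET (sketch).
    static bool vk_supports_set(const struct ggml_tensor * op) {
        const enum ggml_type t = op->src[0]->type;
        return t == op->src[1]->type && t == op->type &&
               (t == GGML_TYPE_F32 || t == GGML_TYPE_I32);
    }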
@@ -15618,6 +15630,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_ACC) {
tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
+} else if (tensor->op == GGML_OP_SET) {
+tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) {

ggml/src/ggml-vulkan/vulkan-shaders/acc.comp

@@ -3,6 +3,9 @@
#include "types.glsl"
#include "generic_binary_head.glsl"
+// false for SET, true for ACC
+layout(constant_id = 1) const bool ACC = true;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
@@ -23,7 +26,11 @@ void main() {
uint i00, i01, i02, i03;
if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
-data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+if (ACC) {
+    data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+} else {
+    data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+}
} else {
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
}
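
With the specialization constant in place, the per-element rule of the shared kernel is: inside the src1 window, ACC accumulates and SET overwrites; outside it, both copy src0 through unchanged. A scalar C++ restatement of that rule (indexing simplified: idx is the flat dst/src0 element, j stands in for the shader's src1_idx()):

    // Per-element behavior of the acc/set kernel above (illustrative sketch).
    inline float acc_or_set_element(bool ACC, bool inside_window,
                                    const float * src0, const float * src1,
                                    unsigned idx, unsigned j) {
        if (!inside_window) {
            return src0[idx];             // outside src1's window: pass src0 through
        }
        return ACC ? src0[idx] + src1[j]  // GGML_OP_ACC: accumulate into the window
                   : src1[j];             // GGML_OP_SET: overwrite the window
    }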