Yihao Wang 2025-12-16 16:23:07 +08:00 committed by GitHub
commit 3c8b4f5203
3 changed files with 63 additions and 0 deletions

ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -732,6 +732,7 @@ struct vk_device_struct {
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
vk_pipeline pipeline_get_rel_pos_f32, pipeline_get_rel_pos_f16;
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
@@ -4038,6 +4039,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
}
ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f32, "get_rel_pos_f32", get_rel_pos_f32_len, get_rel_pos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f16, "get_rel_pos_f16", get_rel_pos_f16_len, get_rel_pos_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
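
Both get_rel_pos pipelines bind two descriptors (source and destination), reuse the generic vk_op_unary_push_constants block, and dispatch a 512-wide workgroup, matching local_size_x in get_rel_pos.comp below.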
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
if (i <= device->max_workgroup_size_log2 &&
@@ -8846,6 +8850,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_fill_f32;
}
return nullptr;
case GGML_OP_GET_REL_POS:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_get_rel_pos_f32;
}
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_get_rel_pos_f16;
}
return nullptr;
default:
return nullptr;
}
@@ -9137,6 +9149,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_UNARY:
case GGML_OP_GLU:
case GGML_OP_CONV_2D_DW:
case GGML_OP_GET_REL_POS:
{
uint32_t ne = ggml_nelements(dst);
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
@@ -10297,6 +10310,11 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
}
static void ggml_vk_get_rel_pos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants pc = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GET_REL_POS, pc);
}
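
For context, a sketch of how a graph produces the node this function executes, assuming ggml's public ggml_get_rel_pos helper (src0 is a [C, 2*kh - 1] relative-position table; the result is [C, kh, qh]):

#include "ggml.h"

// Sketch only: builds a GET_REL_POS node that the Vulkan backend can now execute.
struct ggml_init_params ip = { /*mem_size*/ 16u*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
struct ggml_context * ctx = ggml_init(ip);

const int C = 64, qh = 14, kh = 14;  // ggml's CPU reference expects qh == kh
struct ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, C, 2*kh - 1);
struct ggml_tensor * out = ggml_get_rel_pos(ctx, rel, qh, kh);  // op = GGML_OP_GET_REL_POS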
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
const uint32_t * op_params = (const uint32_t *)dst->op_params;
@@ -12060,6 +12078,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ROPE_BACK:
ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);
break;
case GGML_OP_GET_REL_POS:
ggml_vk_get_rel_pos(ctx, compute_ctx, src0, node);
break;
case GGML_OP_ARGSORT:
if (ctx->num_additional_fused_ops) {
@@ -14098,6 +14120,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_DIAG:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->type == op->src[0]->type;
case GGML_OP_GET_REL_POS:
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
case GGML_OP_ARGSORT:
{
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {

ggml/src/ggml-vulkan/vulkan-shaders/get_rel_pos.comp

@@ -0,0 +1,35 @@
#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
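// Decompose the flat element index into tensor coordinates (i13, i12, i11, i10)
// using the fastdiv magic constants precomputed in the push constants.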
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
const uint i12_offset = i12*p.ne11*p.ne10;
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
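// SAM-style relative-position indexing: scale the query (i12) and key (i11)
// coordinates onto a common grid, then shift by (kh - 1) * k_scale so the
// table index is non-negative.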
const float kh = float(p.ne11);
const float qh = float(p.ne12);
const float k_scale = max(qh / kh, 1.0f);
const float q_scale = max(kh / qh, 1.0f);
// Add a small epsilon to avoid floating point precision issues
const float epsilon = 0.0001f;
const int pos = int(float(i12) * q_scale - float(i11) * k_scale + (kh - 1.0f) * k_scale + epsilon);
const uint src_idx = pos*p.nb01 + i10*p.nb00;
const uint dst_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + src_idx]);
}
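
For reference, the addressing above mirrors the relative-position lookup in SAM's image encoder, which ggml's CPU backend also implements. A minimal host-side sketch of the same gather over contiguous f32 data (hypothetical standalone helper, not the actual ggml code):

#include <algorithm>
#include <cstdint>

// src: relative-position table, ne10 floats per row, 2*max(qh, kh) - 1 rows.
// dst: contiguous [ne10, kh, qh] output.
static void get_rel_pos_ref(const float * src, float * dst,
                            int64_t ne10, int64_t kh, int64_t qh) {
    const float k_scale = std::max((float) qh / kh, 1.0f);
    const float q_scale = std::max((float) kh / qh, 1.0f);
    const float eps     = 0.0001f; // same epsilon as the shader

    for (int64_t i12 = 0; i12 < qh; ++i12) {     // query position
        for (int64_t i11 = 0; i11 < kh; ++i11) { // key position
            const int pos = (int) ((float) i12 * q_scale - (float) i11 * k_scale
                                   + ((float) kh - 1.0f) * k_scale + eps);
            for (int64_t i10 = 0; i10 < ne10; ++i10) {
                dst[(i12*kh + i11)*ne10 + i10] = src[pos*ne10 + i10];
            }
        }
    }
}

When qh == kh both scales collapse to 1.0 and pos reduces to (kh - 1) - i11 + i12, a plain relative offset into the 2*kh - 1 row table.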

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -930,6 +930,9 @@ void process_shaders() {
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
string_to_spv("get_rel_pos_f32", "get_rel_pos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("get_rel_pos_f16", "get_rel_pos.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});