Merge c25eb6f7c5 into 2aa45ef9e3
This commit is contained in:
commit
3c8b4f5203
|
|
@ -732,6 +732,7 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
|
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
|
||||||
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
|
vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
|
||||||
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
|
vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
|
||||||
|
vk_pipeline pipeline_get_rel_pos_f32, pipeline_get_rel_pos_f16;
|
||||||
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
|
vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
|
||||||
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
|
||||||
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
|
vk_pipeline pipeline_topk_f32[num_topk_pipelines];
|
||||||
|
|
@ -4038,6 +4039,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f32, "get_rel_pos_f32", get_rel_pos_f32_len, get_rel_pos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_get_rel_pos_f16, "get_rel_pos_f16", get_rel_pos_f16_len, get_rel_pos_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
|
for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
|
||||||
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
|
uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
|
||||||
if (i <= device->max_workgroup_size_log2 &&
|
if (i <= device->max_workgroup_size_log2 &&
|
||||||
|
|
@ -8846,6 +8850,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return ctx->device->pipeline_fill_f32;
|
return ctx->device->pipeline_fill_f32;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
case GGML_OP_GET_REL_POS:
|
||||||
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
|
return ctx->device->pipeline_get_rel_pos_f32;
|
||||||
|
}
|
||||||
|
if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||||
|
return ctx->device->pipeline_get_rel_pos_f16;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
default:
|
default:
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
@ -9137,6 +9149,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
case GGML_OP_GLU:
|
case GGML_OP_GLU:
|
||||||
case GGML_OP_CONV_2D_DW:
|
case GGML_OP_CONV_2D_DW:
|
||||||
|
case GGML_OP_GET_REL_POS:
|
||||||
{
|
{
|
||||||
uint32_t ne = ggml_nelements(dst);
|
uint32_t ne = ggml_nelements(dst);
|
||||||
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
|
||||||
|
|
@ -10297,6 +10310,11 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
||||||
ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
|
ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_get_rel_pos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
|
vk_op_unary_push_constants pc = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
|
||||||
|
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GET_REL_POS, pc);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||||
const uint32_t * op_params = (const uint32_t *)dst->op_params;
|
const uint32_t * op_params = (const uint32_t *)dst->op_params;
|
||||||
|
|
||||||
|
|
@ -12060,6 +12078,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
case GGML_OP_ROPE_BACK:
|
case GGML_OP_ROPE_BACK:
|
||||||
ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);
|
ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_GET_REL_POS:
|
||||||
|
ggml_vk_get_rel_pos(ctx, compute_ctx, src0, node);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
if (ctx->num_additional_fused_ops) {
|
if (ctx->num_additional_fused_ops) {
|
||||||
|
|
@ -14098,6 +14120,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_OP_DIAG:
|
case GGML_OP_DIAG:
|
||||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||||
op->type == op->src[0]->type;
|
op->type == op->src[0]->type;
|
||||||
|
case GGML_OP_GET_REL_POS:
|
||||||
|
return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
|
||||||
|
(op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16);
|
||||||
case GGML_OP_ARGSORT:
|
case GGML_OP_ARGSORT:
|
||||||
{
|
{
|
||||||
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
|
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "types.glsl"
|
||||||
|
#include "generic_unary_head.glsl"
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint idx = get_idx();
|
||||||
|
if (idx >= p.ne) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
|
||||||
|
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
|
||||||
|
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
|
||||||
|
const uint i12_offset = i12*p.ne11*p.ne10;
|
||||||
|
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
|
||||||
|
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
|
||||||
|
|
||||||
|
const float kh = float(p.ne11);
|
||||||
|
const float qh = float(p.ne12);
|
||||||
|
const float k_scale = max(qh / kh, 1.0f);
|
||||||
|
const float q_scale = max(kh / qh, 1.0f);
|
||||||
|
|
||||||
|
// Add a small epsilon to avoid floating point precision issues
|
||||||
|
const float epsilon = 0.0001f;
|
||||||
|
const int pos = int(float(i12) * q_scale - float(i11) * k_scale + (kh - 1.0f) * k_scale + epsilon);
|
||||||
|
|
||||||
|
const uint src_idx = pos*p.nb01 + i10*p.nb00;
|
||||||
|
const uint dst_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
|
||||||
|
|
||||||
|
data_d[get_doffset() + dst_idx] = D_TYPE(data_a[get_aoffset() + src_idx]);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -930,6 +930,9 @@ void process_shaders() {
|
||||||
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
|
||||||
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
|
||||||
|
|
||||||
|
string_to_spv("get_rel_pos_f32", "get_rel_pos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
string_to_spv("get_rel_pos_f16", "get_rel_pos.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
|
||||||
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
|
||||||
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
|
string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue