From 0d009fe61a718942f9184c32594fb6ae66bca30a Mon Sep 17 00:00:00 2001
From: "Yu, Zijun"
Date: Fri, 9 May 2025 13:04:20 +0800
Subject: [PATCH] FEAT: Add all conversion code from ov side

---
 docs/build.md                                  |   6 +-
 ggml/src/ggml-openvino/ggml-decoder.h          |   2 +-
 .../{decoder.h => openvino/decoder.hpp}        |   1 -
 ggml/src/ggml-openvino/openvino/frontend.cpp   |  27 +++
 ggml/src/ggml-openvino/openvino/frontend.hpp   |  23 +++
 .../ggml-openvino/openvino/input_model.cpp     |  17 ++
 .../ggml-openvino/openvino/input_model.hpp     |  29 +++
 .../ggml-openvino/openvino/node_context.hpp    | 100 ++++++++++
 ggml/src/ggml-openvino/openvino/op/add.cpp     |  23 +++
 ggml/src/ggml-openvino/openvino/op/cont.cpp    |  56 ++++++
 ggml/src/ggml-openvino/openvino/op/cpy.cpp     | 106 +++++++++++
 .../ggml-openvino/openvino/op/get_rows.cpp     |  40 ++++
 ggml/src/ggml-openvino/openvino/op/mul.cpp     |  28 +++
 ggml/src/ggml-openvino/openvino/op/mulmat.cpp  | 127 +++++++++++++
 .../src/ggml-openvino/openvino/op/permute.cpp  |  22 +++
 .../src/ggml-openvino/openvino/op/reshape.cpp  |  35 ++++
 .../ggml-openvino/openvino/op/rms_norm.cpp     |  47 +++++
 ggml/src/ggml-openvino/openvino/op/rope.cpp    | 171 ++++++++++++++++++
 ggml/src/ggml-openvino/openvino/op/scale.cpp   |  31 ++++
 .../ggml-openvino/openvino/op/soft_max.cpp     |  88 +++++++++
 .../ggml-openvino/openvino/op/transpose.cpp    |  23 +++
 ggml/src/ggml-openvino/openvino/op/unary.cpp   |  24 +++
 .../ggml-openvino/openvino/op/unary_silu.cpp   |  29 +++
 ggml/src/ggml-openvino/openvino/op/view.cpp    |  26 +++
 ggml/src/ggml-openvino/openvino/op_table.cpp   |  64 +++++++
 ggml/src/ggml-openvino/openvino/op_table.hpp   |  13 ++
 .../openvino/translate_session.cpp             | 145 +++++++++++++++
 .../openvino/translate_session.hpp             |  27 +++
 ggml/src/ggml-openvino/openvino/utils.cpp      |  52 ++++++
 ggml/src/ggml-openvino/openvino/utils.hpp      |  68 +++++++
 ggml/src/ggml-openvino/utils.cpp               |  30 +--
 31 files changed, 1465 insertions(+), 15 deletions(-)
 rename ggml/src/ggml-openvino/{decoder.h => openvino/decoder.hpp} (97%)
 create mode 100644 ggml/src/ggml-openvino/openvino/frontend.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/frontend.hpp
 create mode 100644 ggml/src/ggml-openvino/openvino/input_model.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/input_model.hpp
 create mode 100644 ggml/src/ggml-openvino/openvino/node_context.hpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/add.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/cont.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/cpy.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/get_rows.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/mul.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/mulmat.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/permute.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/reshape.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/rms_norm.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/rope.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/scale.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/soft_max.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/transpose.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/unary.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/unary_silu.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op/view.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op_table.cpp
 create mode 100644 ggml/src/ggml-openvino/openvino/op_table.hpp
 create mode 100644 ggml/src/ggml-openvino/openvino/translate_session.cpp
 create mode 100644
ggml/src/ggml-openvino/openvino/translate_session.hpp create mode 100644 ggml/src/ggml-openvino/openvino/utils.cpp create mode 100644 ggml/src/ggml-openvino/openvino/utils.hpp diff --git a/docs/build.md b/docs/build.md index 3079a91211..bb7c4137a5 100644 --- a/docs/build.md +++ b/docs/build.md @@ -692,7 +692,11 @@ To read documentation for how to build on IBM Z & LinuxONE, [click here](./build git submodule update --init --recursive export OPENVINO_LLAMA_PATH=$(pwd) + ``` + Before building, change "ENABLE_OV_GGML_FRONTEND" from true to false in the CMakePresets.json file since we already have the code from the ov side in this branch of llama.cpp (`full_backend`). You could also build the master branch of ov instead. + + ``` cmake --preset Release cmake --build build/Release ``` @@ -700,7 +704,7 @@ To read documentation for how to build on IBM Z & LinuxONE, [click here](./build ### Build llama.cpp-ov ```bash - git clone https://github.com/intel-sandbox/llama.cpp-ov.git -b dev_backend_openvino + git clone https://github.com/intel-sandbox/llama.cpp-ov.git -b full_backend cd llama.cpp-ov cmake --preset ReleaseOV diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index a0f6cbea30..959e00b65d 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -5,8 +5,8 @@ #include #include -#include "decoder.h" #include "ggml.h" +#include "openvino/decoder.hpp" class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { public: diff --git a/ggml/src/ggml-openvino/decoder.h b/ggml/src/ggml-openvino/openvino/decoder.hpp similarity index 97% rename from ggml/src/ggml-openvino/decoder.h rename to ggml/src/ggml-openvino/openvino/decoder.hpp index 3404e7c211..3987760a29 100644 --- a/ggml/src/ggml-openvino/decoder.h +++ b/ggml/src/ggml-openvino/openvino/decoder.hpp @@ -8,7 +8,6 @@ namespace ov { namespace frontend { namespace ggml { -// TODO: Directly include from openvino class GgmlDecoder : public DecoderBase { public: virtual ov::Any get_attribute(const std::string& name) const = 0; diff --git a/ggml/src/ggml-openvino/openvino/frontend.cpp b/ggml/src/ggml-openvino/openvino/frontend.cpp new file mode 100644 index 0000000000..ff7f0e8392 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/frontend.cpp @@ -0,0 +1,27 @@ +#include "frontend.hpp" + +#include "input_model.hpp" +#include "op_table.hpp" +#include "translate_session.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +FrontEnd::FrontEnd() {} + +std::shared_ptr FrontEnd::convert(const InputModel::Ptr& model) { + auto ggml_model = std::dynamic_pointer_cast(model); + FRONT_END_GENERAL_CHECK(ggml_model, "Invalid input model"); + std::shared_ptr converted_model; + const auto& supported_ops = get_supported_ops(); + { + TranslateSession translate_session(model, supported_ops); + converted_model = translate_session.get_converted_model(); + } + return converted_model; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/frontend.hpp b/ggml/src/ggml-openvino/openvino/frontend.hpp new file mode 100644 index 0000000000..5cc7ff1773 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/frontend.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace frontend { +namespace ggml { + +class FrontEnd { +public: + using Ptr = std::shared_ptr; + FrontEnd(); + + static std::shared_ptr convert(const 
InputModel::Ptr& model); +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/input_model.cpp b/ggml/src/ggml-openvino/openvino/input_model.cpp new file mode 100644 index 0000000000..5fb16ea2db --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/input_model.cpp @@ -0,0 +1,17 @@ +#include "input_model.hpp" + +#include "decoder.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +InputModel::InputModel(const std::shared_ptr& gdecoder) : m_decoder(gdecoder) {} + +const std::shared_ptr& InputModel::get_model_decoder() const { + return m_decoder; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/input_model.hpp b/ggml/src/ggml-openvino/openvino/input_model.hpp new file mode 100644 index 0000000000..9bc9a28e9a --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/input_model.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "decoder.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +class FrontEnd; +class GgmlDecoder; +using ov::frontend::ggml::GgmlDecoder; + +class InputModel : public ov::frontend::InputModel { + friend class ::ov::frontend::ggml::FrontEnd; + +public: + explicit InputModel(const std::shared_ptr& gdecoder); + + const std::shared_ptr& get_model_decoder() const; + +private: + std::shared_ptr m_decoder; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/node_context.hpp b/ggml/src/ggml-openvino/openvino/node_context.hpp new file mode 100644 index 0000000000..bac135270d --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/node_context.hpp @@ -0,0 +1,100 @@ +#pragma once + +#include + +#include "decoder.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +class TranslateSession; + +typedef std::map> TensorMap; + +class NodeContext : public frontend::NodeContext { +public: + NodeContext(const std::shared_ptr& decoder, + std::shared_ptr& tensor_map, + TranslateSession* translate_session = nullptr) + : ov::frontend::NodeContext(decoder->get_op_type()), + m_decoder(decoder), + m_tensor_map(tensor_map), + m_translate_session(translate_session) { + m_input_names = decoder->get_input_names(); + m_output_names = decoder->get_output_names(); + } + + TranslateSession* get_translate_session() const { + return m_translate_session; + } + + size_t get_input_size() const override { + return m_decoder->get_input_size(); + } + + Any get_input_type(size_t index) const { + return m_decoder->get_input_type(m_input_names[index]); + } + + PartialShape get_input_shape(size_t index) const { + return m_decoder->get_input_shape(m_input_names[index]); + } + + std::vector get_input_stride(size_t index) const { + return m_decoder->get_input_stride(m_input_names[index]); + } + + PartialShape get_output_shape(size_t index) const { + return m_decoder->get_output_shape(m_output_names[index]); + } + + std::vector get_output_stride(size_t index) const { + return m_decoder->get_output_stride(m_output_names[index]); + } + + int32_t* get_input_op_params(size_t index) const { + return m_decoder->get_input_op_params(m_input_names[index]); + } + + int32_t* get_output_op_params(size_t index) const { + return m_decoder->get_output_op_params(m_output_names[index]); + } + + ov::element::Type get_output_type(size_t index) const { + return m_decoder->get_output_type(m_output_names[index]); + } + + Output get_input(int idx) const override { + return 
m_tensor_map->at(m_decoder->get_input_name(idx)); + } + + Output get_input(const std::string& name) const override { + return m_tensor_map->at(name); + } + + const std::string& get_name() const override { + return m_decoder->get_op_name(); + } + + ov::Any get_attribute_as_any(const std::string& name) const override { + return m_decoder->get_attribute(name); + } + + bool check_if_continuous() const { + return m_decoder->check_if_continuous(); + } + +private: + std::shared_ptr m_decoder; + std::shared_ptr& m_tensor_map; + TranslateSession* m_translate_session; + std::vector m_input_names; + std::vector m_output_names; +}; + +using CreatorFunction = std::function; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/add.cpp b/ggml/src/ggml-openvino/openvino/op/add.cpp new file mode 100644 index 0000000000..c218cf34de --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/add.cpp @@ -0,0 +1,23 @@ +#include "openvino/op/add.hpp" + +#include "../node_context.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_add(const NodeContext& context) { + num_inputs_check(context, 2, 2); + + auto lhs = context.get_input(0); + auto rhs = context.get_input(1); + auto add = std::make_shared(lhs, rhs); + return {add}; +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/cont.cpp b/ggml/src/ggml-openvino/openvino/op/cont.cpp new file mode 100644 index 0000000000..2ebc890fda --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/cont.cpp @@ -0,0 +1,56 @@ + +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/slice.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_cont(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + auto src_shape = context.get_input_shape(0).to_shape(); + auto dst_shape = context.get_output_shape(0).to_shape(); + + bool continuous = context.check_if_continuous(); + if (continuous) { + // The input comes from a PERMUTE + dst_shape[1] = -1; + auto result = std::make_shared( + context.get_input(0), + ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), + false); + + return {result}; + } else { + // The input comes from a VIEW + // Currently all cases are slicing at lowest dim + int32_t* op_params = context.get_input_op_params(0); + auto output_stride = context.get_output_stride(0); + + int64_t split_addr = op_params[0] / output_stride[2]; + std::vector begin = {0, 0, split_addr}; + std::vector end = {(int64_t)src_shape[0], INT_MAX, split_addr + (int64_t)src_shape[2]}; + std::vector strides = {1, 1, 1}; + + auto begin_const = ov::op::v0::Constant::create(ov::element::i64, {begin.size()}, begin); + auto end_const = ov::op::v0::Constant::create(ov::element::i64, {end.size()}, end); + auto strides_const = ov::op::v0::Constant::create(ov::element::i64, {strides.size()}, strides); + auto slice = std::make_shared(context.get_input(0), begin_const, end_const, strides_const); + + return {slice}; + } +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/cpy.cpp b/ggml/src/ggml-openvino/openvino/op/cpy.cpp new file mode 100644 index 0000000000..b4f4d59408 --- /dev/null +++ 
b/ggml/src/ggml-openvino/openvino/op/cpy.cpp @@ -0,0 +1,106 @@ +#include +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scatter_nd_update.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_cpy(const NodeContext& context) { + num_inputs_check(context, 2, 2); + auto src0 = context.get_input(0); + auto src1 = context.get_input(1); + auto past_token_len = context.get_input("past_token_len"); + + auto src0_shape = context.get_input_shape(0).to_shape(); + auto output_shape = context.get_output_shape(0).to_shape(); + bool continuous = context.check_if_continuous(); + + std::vector input0_strides = context.get_input_stride(0); + std::vector output_strides = context.get_output_stride(0); + + auto one = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); + + src0 = std::make_shared(src0, src1); + if (continuous) { + // Write K to cache_k + int64_t head_size = src0_shape[2]; + int64_t num_heads = src0_shape[1]; + + auto reshaped_src1_shape = + ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{-1, num_heads, head_size}); + auto reshaped_src1 = std::make_shared(src1, reshaped_src1_shape, false); + + auto token_len = get_dimensions(src0.get_node_shared_ptr(), {0}); + token_len = std::make_shared(token_len, + ov::op::v0::Constant::create(ov::element::i64, {0}, {}), + false); + auto total_token_len = std::make_shared(past_token_len, token_len); + std::shared_ptr indices = + std::make_shared(past_token_len, total_token_len, one, ov::element::i64); + indices = std::make_shared( + indices, + ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{1})); + + auto res = std::make_shared(reshaped_src1, indices, src0); + return {res}; + } else { + // Write V to cache_v + int64_t total_head_size = src0_shape[1]; + + auto reshaped_src0 = std::make_shared( + src0, + ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector{total_head_size, -1}), + false); + auto transposed_src0 = + std::make_shared(reshaped_src0, + ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 0})); + + auto reshaped_src1 = std::make_shared( + src1, + ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector{total_head_size, -1}), + false); + auto transposed_src1 = + std::make_shared(reshaped_src1, + ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 0})); + + auto token_len = get_dimensions(src0.get_node_shared_ptr(), {2}); + token_len = std::make_shared(token_len, + ov::op::v0::Constant::create(ov::element::i64, {0}, {}), + false); + auto total_token_len = std::make_shared(past_token_len, token_len); + std::shared_ptr indices = + std::make_shared(past_token_len, total_token_len, one, ov::element::i64); + indices = std::make_shared( + indices, + ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{1})); + + auto res = std::make_shared(transposed_src1, indices, transposed_src0); + auto transposed_res = + std::make_shared(res, ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 0})); + auto reshaped_res = std::make_shared( + transposed_res, + ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{1, total_head_size, -1}), + false); 
+ return {reshaped_res}; + } +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp new file mode 100644 index 0000000000..edb25d9124 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp @@ -0,0 +1,40 @@ +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/reshape.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_get_rows(const NodeContext& context) { + num_inputs_check(context, 2, 2); + + auto data_node = context.get_input(0); + auto indices_node = context.get_input(1); + + auto indices_shape = get_dimensions(indices_node.get_node_shared_ptr(), {2}); + Output indice_reshaped = std::make_shared(indices_node, indices_shape, false); + + auto axis_node = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); + + Output res = std::make_shared(data_node, indice_reshaped, axis_node); + if (res.get_element_type() != context.get_output_type(0)) { + res = std::make_shared(res, context.get_output_type(0)); + } + + return {res}; +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/mul.cpp b/ggml/src/ggml-openvino/openvino/op/mul.cpp new file mode 100644 index 0000000000..1b1c69f7df --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/mul.cpp @@ -0,0 +1,28 @@ +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_mul(const NodeContext& context) { + num_inputs_check(context, 2, 2); + + auto res = std::make_shared(context.get_input(0), context.get_input(1)); + return {res}; +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp new file mode 100644 index 0000000000..e00435ef81 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp @@ -0,0 +1,127 @@ +#include +#include +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/transpose.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_mulmat(const NodeContext& context) { + num_inputs_check(context, 2, 2); + + bool continuous = context.check_if_continuous(); + if (continuous) { + auto src1 = context.get_input(1); + auto src0_converted = std::make_shared(context.get_input(0), src1); + auto result = std::make_shared(src1, src0_converted, false, true); + return {result}; + } else { + /* + Two cases here: + - 21: [ 96, 32, 32, 1] VIEW k-0 [ 2, 
6144, 192, 6144] + [ 196608, 1, 1, 1] 0: NONE cache_k_l0 [ 2, 393216, 393216, 393216] + - 22: [ 96, 7, 32, 1] PERMUTE q-0 [ 4, 12288, 384, 86016] + [ 96, 32, 7, 1] 0: SCALE Qcur-0 [ 4, 384, 12288, 86016] + - 23: [ 32, 7, 32, 1] MUL_MAT kq-0 [ 4, 128, 896, 28672] + [ 96, 32, 32, 1] 0: VIEW k-0 [ 2, 6144, 192, 6144] + [ 96, 7, 32, 1] 1: PERMUTE q-0 [ 4, 12288, 384, 86016] + + - 20: [ 32, 96, 32, 1] VIEW v-0 [ 2, 128, 12288, 393216] + [ 196608, 1, 1, 1] 0: NONE cache_v_l0 [ 2, 393216, 393216, 393216] + - 25: [ 96, 7, 32, 1] MUL_MAT kqv-0 [ 4, 384, 2688, 86016] + [ 32, 96, 32, 1] 0: VIEW v-0 [ 2, 128, 12288, 393216] + [ 32, 7, 32, 1] 1: SOFT_MAX kq_soft_max_ext-0 [ 4, 128, 896, 28672] + + For case 1, for src0, Reshape + Slice + Transpose + For case 2, for src0, Reshape + Slice + */ + ov::Output A; + ov::Output B; + + auto attention_size = context.get_input("attention_size"); + + auto src0 = context.get_input(0); + auto src0_shape = context.get_input_shape(0).to_shape(); + auto src0_stride = context.get_input_stride(0); + auto permuted = is_permuted(src0_stride); + auto token_dim = permuted ? 0 : 2; + + auto src0_perm = argsort_descend(src0_stride); + auto src0_original_shape_ = permute(src0_shape, src0_perm); + std::vector src0_original_shape(src0_original_shape_.begin(), src0_original_shape_.end()); + src0_original_shape[token_dim] = -1; + + auto src0_slice_shape = src0_original_shape; + src0_slice_shape.erase(src0_slice_shape.begin() + token_dim); + + auto src0_reshape_shape = + ov::op::v0::Constant::create(ov::element::i64, {src0_original_shape.size()}, src0_original_shape); + auto src0_reshape = std::make_shared(src0, src0_reshape_shape, false); + + std::shared_ptr slice_end; + if (permuted) { + slice_end = std::make_shared( + ov::OutputVector{attention_size, ov::op::v0::Constant::create(ov::element::i64, {2}, src0_slice_shape)}, + 0); + } else { + slice_end = std::make_shared( + ov::OutputVector{ov::op::v0::Constant::create(ov::element::i64, {2}, src0_slice_shape), attention_size}, + 0); + } + auto slice_start = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector(3, 0)); + auto slice_step = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector(3, 1)); + auto src0_slice = std::make_shared(src0_reshape, slice_start, slice_end, slice_step); + + if (permuted) { + B = std::make_shared( + src0_slice, + ov::op::v0::Constant::create(ov::element::i64, {src0_perm.size()}, src0_perm)); + } else { + B = src0_slice; + } + + A = context.get_input(1); + B = std::make_shared(B, A); + + int64_t num_heads = context.get_input_shape(1).to_shape()[0]; + int64_t num_heads_kv = src0_shape[0]; + int64_t kv_num_heads_factor = num_heads / num_heads_kv; + if (kv_num_heads_factor > 1) { + auto num_heads_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{num_heads}); + auto num_heads_kv_node = + ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{num_heads_kv}); + auto B_shape_last_two = get_dimensions(B.get_node_shared_ptr(), {1, 2}); + + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + std::shared_ptr new_B_shape = + std::make_shared(ov::OutputVector{num_heads_kv_node, one, B_shape_last_two}, 0); + B = std::make_shared(B, new_B_shape, false); + + B = std::make_shared(ov::OutputVector(kv_num_heads_factor, B), 1); + new_B_shape = std::make_shared(ov::OutputVector{num_heads_node, B_shape_last_two}, 0); + B = std::make_shared(B, new_B_shape, false); + } + + auto result = std::make_shared(A, B, false, true); + return {result}; + } +}; + +} // 
namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp new file mode 100644 index 0000000000..42472f18cc --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp @@ -0,0 +1,22 @@ +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/transpose.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { +OutputVector translate_permute(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + // TODO: make this more general + auto res = std::make_shared(context.get_input(0), + ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2})); + + return {res}; +}; +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/reshape.cpp b/ggml/src/ggml-openvino/openvino/op/reshape.cpp new file mode 100644 index 0000000000..ca18b72c42 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/reshape.cpp @@ -0,0 +1,35 @@ +#include "openvino/op/reshape.hpp" + +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/constant.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_reshape(const NodeContext& context) { + num_inputs_check(context, 1, 1); + if (context.get_input_shape(0) == context.get_output_shape(0)) { + return {context.get_input(0)}; + } + + auto output_shape = context.get_output_shape(0).to_shape(); + auto new_shape_node = + ov::op::v0::Constant::create(ov::element::i64, + {3}, + std::vector{-1, (int64_t)output_shape[1], (int64_t)output_shape[2]}); + Output res = std::make_shared(context.get_input(0), new_shape_node, false); + return {res}; +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp new file mode 100644 index 0000000000..7b9783e8c9 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp @@ -0,0 +1,47 @@ +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reduce_sum.hpp" +#include "openvino/op/sqrt.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_rms_norm(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + ov::Shape input_shape = context.get_input_shape(0).to_shape(); + auto input_node = context.get_input(0); + auto square = std::make_shared(input_node, input_node); + + auto reduce_sum = + std::make_shared(square, + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}), + true); + + auto mean = std::make_shared( + reduce_sum, + ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {static_cast(input_shape[2])})); + + float eps; + memcpy(&eps, context.get_output_op_params(0), sizeof(float)); + auto rms = std::make_shared( + std::make_shared(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {eps}))); + + auto scale = + std::make_shared(ov::op::v0::Constant::create(ov::element::f32, ov::Shape{}, {1.0f}), rms); + + auto res = std::make_shared(input_node, scale); + + return {res}; +}; + 
+} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp new file mode 100644 index 0000000000..d5083ae14b --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -0,0 +1,171 @@ + +#include +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/cos.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/sin.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/split.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/transpose.hpp" + +#define GGML_ROPE_TYPE_NEOX 2 + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +} + +void ggml_rope_yarn_corr_dims(int n_dims, + int n_ctx_orig, + float freq_base, + float beta_fast, + float beta_slow, + float dims[2]) { + float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); + dims[0] = MAX(0, start); + dims[1] = MIN(n_dims - 1, end); +} + +OutputVector translate_rope(const NodeContext& context) { + num_inputs_check(context, 2, 3); + + auto data_node = context.get_input(0); + auto pos_node = context.get_input(1); + pos_node = std::make_shared(pos_node, ov::element::f32); + + auto permutation_node = + std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{2, 1, 0}); + Output pos_node_reshaped = std::make_shared(pos_node, permutation_node); + + auto output_shape = context.get_output_shape(0); + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int32_t* op_params = context.get_output_op_params(0); + const int n_dims = op_params[1]; + const int mode = op_params[2]; + const int n_ctx_orig = op_params[4]; + memcpy(&freq_base, op_params + 5, sizeof(float)); + memcpy(&freq_scale, op_params + 6, sizeof(float)); + memcpy(&ext_factor, op_params + 7, sizeof(float)); + memcpy(&attn_factor, op_params + 8, sizeof(float)); + memcpy(&beta_fast, op_params + 9, sizeof(float)); + memcpy(&beta_slow, op_params + 10, sizeof(float)); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + // TODO: corr_dims is not used in the current implementation + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + + // TODO: GGML_OP_ROPE_BACK -> false + bool forward = true; + const float sin_sign = forward ? 
1.0f : -1.0f; + + const int64_t ne0 = output_shape[2].get_length(); + std::vector factor(ne0 / 2); + factor[0] = freq_scale; + for (int64_t i = 1; i < ne0 / 2; i++) { + factor[i] = theta_scale * factor[i - 1]; + } + + Output factor_node = + std::make_shared(ov::element::f32, ov::Shape{factor.size()}, factor); + if (context.get_input_size() == 3) { + auto freq_factors_node = context.get_input(2); + factor_node = std::make_shared(factor_node, freq_factors_node); + } + + auto half_last_dim = ov::op::v0::Constant::create(ov::element::i64, Shape{1}, {output_shape[2].get_length() / 2}); + Output input_shape_node = std::make_shared( + OutputVector{get_dimensions(data_node.get_node_shared_ptr(), {0, 1}), half_last_dim}, + 0); + Output factor_broadcasted_node = std::make_shared(factor_node, input_shape_node); + + Output cos_factor_broadcasted_node = std::make_shared( + std::make_shared(factor_broadcasted_node, pos_node_reshaped)); + Output sin_factor_broadcasted_node = std::make_shared( + std::make_shared(factor_broadcasted_node, pos_node_reshaped)); + + float mscale = attn_factor; + Output mscale_node = + std::make_shared(ov::element::f32, ov::Shape{}, std::vector{mscale}); + Output mscale_sin_sign_node = + std::make_shared(ov::element::f32, ov::Shape{}, std::vector{mscale * sin_sign}); + Output cos_theta_node = std::make_shared(cos_factor_broadcasted_node, mscale_node); + Output sin_theta_node = std::make_shared(sin_factor_broadcasted_node, mscale_node); + + if (!is_neox) { + auto input_shape = context.get_input_shape(0); + + auto begin_even = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {0, 0, 0}); + auto begin_odd = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {0, 0, 1}); + auto end = std::make_shared(data_node); + auto stride = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {1, 1, 2}); + auto even_slice = std::make_shared(data_node, begin_even, end, stride); + auto odd_slice = std::make_shared(data_node, begin_odd, end, stride); + + auto first_half = + std::make_shared(std::make_shared(even_slice, cos_theta_node), + std::make_shared(odd_slice, sin_theta_node)); + auto second_half = + std::make_shared(std::make_shared(even_slice, sin_theta_node), + std::make_shared(odd_slice, cos_theta_node)); + + auto stack = std::make_shared(OutputVector{first_half, second_half}, 2); + auto shape_const = ov::op::v0::Constant::create( + ov::element::i64, + Shape{3}, + std::vector{-1, input_shape[1].get_length(), input_shape[2].get_length()}); + auto reshaped = std::make_shared(stack, shape_const, false); + + return {reshaped}; + } else { + auto slice_node = + std::make_shared(data_node, + ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}), + 2); + Output slice_data_node_0 = slice_node->outputs()[0]; + Output slice_data_node_1 = slice_node->outputs()[1]; + + auto first_half_node = std::make_shared( + std::make_shared(slice_data_node_0, cos_theta_node), + std::make_shared(slice_data_node_1, sin_theta_node)); + + auto second_half_node = std::make_shared( + std::make_shared(slice_data_node_0, sin_theta_node), + std::make_shared(slice_data_node_1, cos_theta_node)); + + auto res_node = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, 2); + return {res_node}; + } +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/scale.cpp b/ggml/src/ggml-openvino/openvino/op/scale.cpp new file mode 100644 index 0000000000..392bfc1ed4 --- /dev/null +++ 
b/ggml/src/ggml-openvino/openvino/op/scale.cpp @@ -0,0 +1,31 @@ +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/multiply.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_scale(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + float scale; + memcpy(&scale, context.get_output_op_params(0), sizeof(float)); + auto scale_node = std::make_shared(ov::element::f32, ov::Shape{}, std::vector{scale}); + + auto res = std::make_shared(context.get_input(0), scale_node); + + return {res}; +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/soft_max.cpp b/ggml/src/ggml-openvino/openvino/op/soft_max.cpp new file mode 100644 index 0000000000..27c7cefef0 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/soft_max.cpp @@ -0,0 +1,88 @@ + +#include +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/softmax.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_soft_max(const NodeContext& context) { + num_inputs_check(context, 1, 2); + + auto input_node = context.get_input(0); + + float scale = 1.0f; + float max_bias = 0.0f; + auto op_params = context.get_output_op_params(0); + memcpy(&scale, (float*)op_params + 0, sizeof(float)); + memcpy(&max_bias, (float*)op_params + 1, sizeof(float)); + + const uint32_t n_head = context.get_input_shape(0)[0].get_length(); + const uint32_t n_head_log2 = 1u << (uint32_t)floor(log2(n_head)); + + // const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + // const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const float slope = (max_bias > 0.0f) ? 1.0f : 1.0f; + // const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) + // : 1.0f; + + if (scale != 1.0f) { + auto scale_node = + std::make_shared(ov::element::f32, ov::Shape{}, std::vector{scale}); + input_node = std::make_shared(input_node, scale_node); + } + + if (context.get_input_size() == 2) { + // Calculate mask then softmax + auto mask_node = context.get_input(1); + ov::element::Type mask_type = (context.get_input_type(1)).as(); + if (mask_type == ov::element::f16) { + // Convert f16 to f32 + mask_node = std::make_shared(mask_node, ov::element::f32); + } + + // Stride slice mask node + Output mask_begin_node = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {0, 0, 0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + auto input_last_two_dim = get_dimensions(input_node.get_node_shared_ptr(), {1, 2}); + auto mask_slice_shape = std::make_shared(ov::NodeVector{one, input_last_two_dim}, 0); + Output mask_stride_node = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {1, 1, 1}); + auto mask_node_sliced = + std::make_shared(mask_node, mask_begin_node, mask_slice_shape, mask_stride_node); + + // slope * mask + auto slope_node = + std::make_shared(ov::element::f32, ov::Shape{}, std::vector{slope}); + auto slope_mask_node = std::make_shared(mask_node_sliced, slope_node); + + // input + slope * mask + auto input_slope_mask_node = std::make_shared(input_node, slope_mask_node); + + // Calculate softmax + auto res = std::make_shared(input_slope_mask_node, 2); + return {res}; + } else { + // Directly softmax + auto res = std::make_shared(input_node, 0); + return {res}; + } +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/transpose.cpp b/ggml/src/ggml-openvino/openvino/op/transpose.cpp new file mode 100644 index 0000000000..f7408f40d4 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/transpose.cpp @@ -0,0 +1,23 @@ +#include "openvino/op/transpose.hpp" + +#include "../node_context.hpp" +#include "../utils.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_transpose(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + auto perm = argsort_descend(context.get_output_stride(0)); + auto res = std::make_shared(context.get_input(0), + ov::op::v0::Constant::create(ov::element::i64, {3}, perm)); + return {res}; +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/unary.cpp b/ggml/src/ggml-openvino/openvino/op/unary.cpp new file mode 100644 index 0000000000..391e0a7599 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary.cpp @@ -0,0 +1,24 @@ + +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + return {context.get_input(0)}; +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp new file mode 100644 index 0000000000..2a90a79475 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp @@ -0,0 +1,29 @@ +#include +#include + +#include "../node_context.hpp" +#include "../utils.hpp" +#include 
"openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/sigmoid.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_silu(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto sigmoid = std::make_shared(input); + auto res = std::make_shared(input, sigmoid); + + return {res}; +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/view.cpp b/ggml/src/ggml-openvino/openvino/op/view.cpp new file mode 100644 index 0000000000..aaf117b662 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/view.cpp @@ -0,0 +1,26 @@ +#include +#include + +#include "../utils.hpp" +#include "openvino/core/node.hpp" +#include "openvino/core/node_output.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/strided_slice.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_view(const NodeContext& context) { + num_inputs_check(context, 1, 1); + + return {context.get_input(0)}; +}; + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp new file mode 100644 index 0000000000..af51bb157e --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -0,0 +1,64 @@ +#include "op_table.hpp" + +#include +#include +#include +#include +#include +#include + +#include "utils.hpp" + +using namespace ov::op; +namespace ov { +namespace frontend { +namespace ggml { + +namespace op { + +#define GGML_OP_CONVERTER(op) OutputVector op(const NodeContext& node) + +GGML_OP_CONVERTER(translate_add); +GGML_OP_CONVERTER(translate_cont); +GGML_OP_CONVERTER(translate_cpy); +GGML_OP_CONVERTER(translate_get_rows); +GGML_OP_CONVERTER(translate_mul); +GGML_OP_CONVERTER(translate_mulmat); +GGML_OP_CONVERTER(translate_permute); +GGML_OP_CONVERTER(translate_reshape); +GGML_OP_CONVERTER(translate_rms_norm); +GGML_OP_CONVERTER(translate_rope); +GGML_OP_CONVERTER(translate_scale); +GGML_OP_CONVERTER(translate_unary_silu); +GGML_OP_CONVERTER(translate_soft_max); +GGML_OP_CONVERTER(translate_transpose); +GGML_OP_CONVERTER(translate_unary); +GGML_OP_CONVERTER(translate_view); + +} // namespace op + +const std::unordered_map get_supported_ops() { + return {{"GGML_OP_ADD", op::translate_1to1_match_2_inputs}, + {"GGML_OP_ADD1", op::translate_1to1_match_2_inputs}, + {"GGML_OP_CONT", op::translate_cont}, + {"GGML_OP_CPY", op::translate_cpy}, + {"GGML_OP_DIV", op::translate_1to1_match_2_inputs}, + {"GGML_OP_GET_ROWS", op::translate_get_rows}, + // {"GGML_OP_MUL", op::translate_1to1_match_2_inputs}, + {"GGML_OP_MUL", op::translate_mul}, + {"GGML_OP_MUL_MAT", op::translate_mulmat}, + {"GGML_OP_PERMUTE", op::translate_permute}, + {"GGML_OP_RESHAPE", op::translate_reshape}, + {"GGML_OP_RMS_NORM", op::translate_rms_norm}, + {"GGML_OP_ROPE", op::translate_rope}, + {"GGML_OP_SCALE", op::translate_scale}, + {"GGML_OP_SOFT_MAX", op::translate_soft_max}, + {"GGML_OP_SUB", op::translate_1to1_match_2_inputs}, + {"GGML_OP_TRANSPOSE", op::translate_transpose}, + {"GGML_UNARY_OP_SILU", op::translate_unary_silu}, + {"GGML_OP_VIEW", op::translate_view}}; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git 
a/ggml/src/ggml-openvino/openvino/op_table.hpp b/ggml/src/ggml-openvino/openvino/op_table.hpp new file mode 100644 index 0000000000..c83aaa199f --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op_table.hpp @@ -0,0 +1,13 @@ +#pragma once + +#include "node_context.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +const std::unordered_map get_supported_ops(); + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp new file mode 100644 index 0000000000..f5b14d3a0f --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -0,0 +1,145 @@ +#include "translate_session.hpp" + +#include +#include + +#include "input_model.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +using namespace ov::op; + +TranslateSession::TranslateSession(const frontend::InputModel::Ptr& input_model, + const std::unordered_map& translator_map) + : m_input_model(input_model), + m_translator_map(translator_map), + m_ov_model(nullptr) {} + +std::shared_ptr TranslateSession::get_converted_model() { + if (m_ov_model) { + return m_ov_model; + } + m_ov_model = translate_graph(m_input_model); + // print_model_topology(); + return m_ov_model; +} + +void TranslateSession::print_model_topology() { + try { + std::ofstream outfile("model_topology.txt", std::ios::out | std::ios::app); + if (!outfile.is_open()) { + throw std::runtime_error("Failed to open file for writing model topology."); + } + + outfile << "============ Model ============" << std::endl; + for (const auto& op : m_ov_model->get_ordered_ops()) { + outfile << "Operation: " << op->get_friendly_name() << std::endl; + outfile << " Inputs:" << std::endl; + for (const auto& input : op->inputs()) { + outfile << " " << input.get_node()->get_friendly_name() << " -> " << input.get_element_type() << " " + << input.get_shape() << std::endl; + } + outfile << " Outputs:" << std::endl; + for (const auto& output : op->outputs()) { + outfile << " " << output.get_node()->get_friendly_name() << " -> " << output.get_element_type() + << " " << output.get_shape() << std::endl; + } + outfile << std::endl; + } + outfile << "===============================" << std::endl; + outfile.close(); + } catch (const std::exception& ex) { + std::cout << ex.what() << std::endl; + } +} + +std::shared_ptr TranslateSession::translate_graph(const frontend::InputModel::Ptr& input_model) { + ov::ParameterVector params; + ov::ResultVector results; + auto tensor_map = std::make_shared(); + std::shared_ptr resulting_model; + + const auto& ggml_model = std::dynamic_pointer_cast(input_model); + std::shared_ptr ggml_model_decoder = ggml_model->get_model_decoder(); + + FRONT_END_GENERAL_CHECK(ggml_model, "nullptr for InputModel is given for translation into OV Model"); + const auto& model_inputs = ggml_model->get_inputs(); + const auto& model_outputs = ggml_model->get_outputs(); + + for (const auto& it : ggml_model_decoder->get_model_inputs()) { + params.push_back(std::dynamic_pointer_cast(it.second)); + (*tensor_map)[it.first] = it.second; + } + + for (const auto& it : ggml_model_decoder->get_model_extra_inputs()) { + params.push_back(std::dynamic_pointer_cast(it.second)); + (*tensor_map)[it.first] = it.second; + } + + for (const auto& it : ggml_model_decoder->get_model_weights()) { + (*tensor_map)[it.first] = it.second; + } + + auto node_visitor = [&](std::shared_ptr node) { + auto operation_type = node->get_op_type(); + 
ov::OutputVector converted_outputs; + auto it = m_translator_map.find(operation_type); + if (it != m_translator_map.end()) { + try { + NodeContext node_context(node, tensor_map, this); + converted_outputs = it->second(node_context); + } catch (const std::exception& ex) { + std::cout << ex.what() << std::endl; + } + } else { + // TODO + } + + const auto& node_output_names = node->get_output_names(); + FRONT_END_OP_CONVERSION_CHECK(node_output_names.size() == converted_outputs.size(), + "Number of ", + operation_type, + " outputs greater than number of converted outputs, which are ", + node_output_names.size(), + " and ", + converted_outputs.size(), + " respectively."); + + for (size_t i = 0; i < node_output_names.size(); ++i) { + auto output_name = node_output_names[i]; + if (i < converted_outputs.size() && converted_outputs[i].get_node_shared_ptr() != nullptr) { + (*tensor_map)[output_name] = converted_outputs[i]; + } + } + }; + + ggml_model_decoder->visit_subgraph(node_visitor); + + for (const auto& name : ggml_model_decoder->get_model_output_names()) { + FRONT_END_GENERAL_CHECK(tensor_map->find(name) != tensor_map->end(), + "Output name not found in tensor map: ", + name); + auto result = std::make_shared(tensor_map->at(name)); + // result->set_friendly_name(it); + results.push_back(result); + } + + ov::ParameterVector used_params; + for (const auto& param : params) { + if (!param->output(0).get_target_inputs().empty()) { + used_params.push_back(param); + } + } + if (auto diff = params.size() - used_params.size()) { + std::cout << diff << " parameters are not used in the model." << std::endl; + } + resulting_model = std::make_shared(results, used_params); + + return resulting_model; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.hpp b/ggml/src/ggml-openvino/openvino/translate_session.hpp new file mode 100644 index 0000000000..5c7a9d464d --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/translate_session.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "input_model.hpp" +#include "node_context.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +class TranslateSession { +public: + TranslateSession(const frontend::InputModel::Ptr& input_model, + const std::unordered_map& translator_map); + + std::shared_ptr get_converted_model(); + std::shared_ptr translate_graph(const frontend::InputModel::Ptr& input_model); + +private: + void print_model_topology(); + const frontend::InputModel::Ptr m_input_model; + const std::unordered_map& m_translator_map; + std::shared_ptr m_ov_model; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp new file mode 100644 index 0000000000..ff16e9d4ae --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -0,0 +1,52 @@ +#include "utils.hpp" + +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { + +std::string getCurrentTime() { + std::time_t now = std::time(nullptr); + char buf[100]; + std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now)); + return buf; +} + +void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) { + auto input_size = context.get_input_size(); + FRONT_END_OP_CONVERSION_CHECK(input_size >= min_inputs, "Got less inputs than expected"); + FRONT_END_OP_CONVERSION_CHECK(input_size <= max_inputs, "Got more inputs than 
expected"); +} + +int non_cont_dim(std::vector ne, std::vector nb) { + int dim = nb.size() - 1; + size_t bytes = nb[dim]; + for (int i = dim; i > 0; i--) { + bytes *= ne[i]; + if (bytes != nb[i - 1]) { + return i; + } + } + return 0; +} + +std::shared_ptr get_dimensions(const std::shared_ptr& shape, + const std::vector& dims) { + using namespace ov::op; + const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared(shape, dims_const, zero); +} + +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims) { + return get_dimensions(std::make_shared(node), dims); +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/utils.hpp b/ggml/src/ggml-openvino/openvino/utils.hpp new file mode 100644 index 0000000000..6e106fa932 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/utils.hpp @@ -0,0 +1,68 @@ +#pragma once + +#include + +#include "node_context.hpp" + +namespace ov { +namespace frontend { +namespace ggml { + +void dump_ov_model(const std::shared_ptr model); + +void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs); + +int non_cont_dim(std::vector ne, std::vector nb); + +template +std::vector argsort_descend(const std::vector& v) { + std::vector idx(v.size()); + std::iota(idx.begin(), idx.end(), 0); + std::sort(idx.begin(), idx.end(), [&v](int i1, int i2) { + return v[i1] > v[i2]; + }); + return idx; +} + +template +std::vector sorted_descend(std::vector v) { + std::sort(v.begin(), v.end(), [](T a, T b) { + return a > b; + }); + return v; +} + +template +bool is_permuted(const std::vector& strides) { + for (size_t i = 0; i < strides.size() - 1; ++i) { + if (strides[i] < strides[i + 1]) { + return true; + } + } + return false; +} + +template +std::vector permute(const std::vector& x, const std::vector& perm) { + std::vector result; + result.reserve(perm.size()); + for (int i : perm) { + result.push_back(x[i]); + } + return result; +} + +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims); + +namespace op { +template +OutputVector translate_1to1_match_2_inputs(const NodeContext& context) { + num_inputs_check(context, 2, 2); + return {std::make_shared(context.get_input(0), context.get_input(1))}; +} +} // namespace op + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index f36700d5ec..34bcfc54a7 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -14,6 +14,8 @@ #include "ggml-impl.h" #include "ggml.h" +#include "openvino/frontend.hpp" +#include "openvino/input_model.hpp" std::shared_ptr get_ggml_decoder(struct ggml_cgraph* cgraph) { return std::make_shared(nullptr, cgraph); @@ -56,11 +58,11 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c } // auto devices = core.get_available_devices(); - static auto front_end = get_ggml_frontend(); - if (!front_end) { - GGML_LOG_ERROR("GGML FrontEnd is not initialized \n"); - return GGML_STATUS_FAILED; - } + // static auto front_end = get_ggml_frontend(); + // if (!front_end) { + // GGML_LOG_ERROR("GGML FrontEnd is not initialized \n"); + // return GGML_STATUS_FAILED; + // } using CachedItem = std::pair, 
ov::CompiledModel>; static std::unordered_map compiled_cache; @@ -79,14 +81,18 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c compiled_model = it->second.second; compile_end_time = ggml_time_us(); } else { - std::shared_ptr graph_decoder = ggml_decoder; - ov::frontend::InputModel::Ptr input_model = front_end->load(graph_decoder); - if (!input_model) { - GGML_LOG_ERROR("Input Model is not loaded \n"); - return GGML_STATUS_FAILED; - } + // std::shared_ptr graph_decoder = ggml_decoder; + // ov::frontend::InputModel::Ptr input_model = front_end->load(graph_decoder); + // if (!input_model) { + // GGML_LOG_ERROR("Input Model is not loaded \n"); + // return GGML_STATUS_FAILED; + // } + + // model = front_end->convert(input_model); + + ov::frontend::InputModel::Ptr input_model = std::make_shared(ggml_decoder); + model = ov::frontend::ggml::FrontEnd::convert(input_model); - model = front_end->convert(input_model); conversion_end_time = ggml_time_us(); if (getenv("GGML_OPENVINO_DUMP_IR")) {
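// Illustrative sketch (not part of the patch hunks above): how the conversion path
// introduced here is expected to be driven end to end. It mirrors the
// openvino_frontend_compute() changes in ggml-openvino/utils.cpp but omits the
// caching, timing and IR-dump logic. Assumptions: the GgmlOvDecoder(nullptr, cgraph)
// constructor used by get_ggml_decoder(), and "CPU" as an example device passed to
// ov::Core::compile_model(); the helper name convert_and_compile is hypothetical.
#include <memory>
#include <openvino/openvino.hpp>
#include "ggml.h"
#include "ggml-decoder.h"
#include "openvino/frontend.hpp"
#include "openvino/input_model.hpp"

static ov::CompiledModel convert_and_compile(struct ggml_cgraph* cgraph, ov::Core& core) {
    // Wrap the ggml graph in the decoder consumed by the ggml frontend.
    auto ggml_decoder = std::make_shared<GgmlOvDecoder>(nullptr, cgraph);

    // Build the frontend InputModel directly from the decoder, as utils.cpp now does.
    ov::frontend::InputModel::Ptr input_model =
        std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);

    // FrontEnd::convert() runs a TranslateSession under the hood and yields an ov::Model.
    std::shared_ptr<ov::Model> model = ov::frontend::ggml::FrontEnd::convert(input_model);

    // Compile for a device; "CPU" is only an example target here.
    return core.compile_model(model, "CPU");
}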