diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 2d96bf1572..97bd938567 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -147,7 +147,6 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { continue; } std::string src_name = std::string(src->name); - m_inputs[src_name] = src; current_node_info.node_inputs[src_name] = src; current_node_info.node_inputs_names.push_back(src_name); @@ -163,6 +162,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) { if (m_model_inputs.find(src_name) != m_model_inputs.end()) { continue; } + m_inputs[src_name] = src; auto param_node = std::make_shared(get_ov_type(src), get_graph_input_shape(node, src)); param_node->set_friendly_name(src_name); @@ -751,6 +751,10 @@ ov::element::Type GgmlOvDecoder::get_input_type(const std::string & name) const return get_ov_type(m_inputs.at(name)); } +ov::element::Type GgmlOvDecoder::get_input_type(int node_idx, const std::string & name) const { + return get_ov_type(m_node_info_list[node_idx].node_inputs.at(name)); +} + size_t GgmlOvDecoder::get_input_size() const { return m_model_inputs.size(); } diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h index 336833d8af..c76315f8af 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.h +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -85,6 +85,8 @@ public: virtual ov::element::Type get_input_type(const std::string & name) const override; + virtual ov::element::Type get_input_type(int node_idx, const std::string & name) const override; + virtual size_t get_input_size() const override; virtual size_t get_input_size(int node_idx) const override; diff --git a/ggml/src/ggml-openvino/openvino/decoder.hpp b/ggml/src/ggml-openvino/openvino/decoder.hpp index 2cc6dbba46..ef4b3a7593 100644 --- a/ggml/src/ggml-openvino/openvino/decoder.hpp +++ b/ggml/src/ggml-openvino/openvino/decoder.hpp @@ -22,6 +22,8 @@ public: virtual element::Type get_input_type(const std::string& name) const = 0; + virtual element::Type get_input_type(int node_idx, const std::string& name) const = 0; + virtual size_t get_input_size() const = 0; virtual size_t get_input_size(int node_idx) const = 0; diff --git a/ggml/src/ggml-openvino/openvino/node_context.hpp b/ggml/src/ggml-openvino/openvino/node_context.hpp index e95bafc269..a0666b21ac 100644 --- a/ggml/src/ggml-openvino/openvino/node_context.hpp +++ b/ggml/src/ggml-openvino/openvino/node_context.hpp @@ -40,7 +40,7 @@ public: } ov::element::Type get_input_type(size_t index) const { - return m_decoder->get_input_type(m_input_names[index]); + return m_decoder->get_input_type(m_node_idx, m_input_names[index]); } PartialShape get_input_shape(size_t input_index) const {