llama.cpp/ggml/src/ggml-openvino/openvino/node_context.hpp

123 lines
3.6 KiB
C++

#pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <openvino/frontend/node_context.hpp>

#include "decoder.hpp"
namespace ov {
namespace frontend {
namespace ggml {
// Forward declaration: within this header TranslateSession is only ever used
// through a pointer, so its full definition is not needed here.
class TranslateSession;

// Associates a tensor name (string key) with the Output<Node> stored for it.
using TensorMap = std::map<std::string, Output<Node>>;
class NodeContext : public frontend::NodeContext {
public:
NodeContext(const std::shared_ptr<GgmlDecoder>& decoder,
std::shared_ptr<TensorMap>& tensor_map,
TranslateSession* translate_session = nullptr)
: ov::frontend::NodeContext(decoder->get_op_type()),
m_decoder(decoder),
m_tensor_map(tensor_map),
m_translate_session(translate_session) {
m_input_names = decoder->get_input_names();
m_output_names = decoder->get_output_names();
}
TranslateSession* get_translate_session() const {
return m_translate_session;
}
size_t get_input_size() const override {
return m_decoder->get_input_size();
}
ov::element::Type get_input_type(size_t index) const {
return m_decoder->get_input_type(m_input_names[index]);
}
PartialShape get_input_shape(size_t index) const {
return m_decoder->get_input_shape(m_input_names[index]);
}
std::vector<size_t> get_input_stride(size_t index) const {
return m_decoder->get_input_stride(m_input_names[index]);
}
PartialShape get_output_shape(size_t index) const {
return m_decoder->get_output_shape(m_output_names[index]);
}
std::vector<size_t> get_output_stride(size_t index) const {
return m_decoder->get_output_stride(m_output_names[index]);
}
int32_t* get_input_op_params(size_t index) const {
return m_decoder->get_input_op_params(m_input_names[index]);
}
int32_t* get_output_op_params(size_t index) const {
return m_decoder->get_output_op_params(m_output_names[index]);
}
ov::element::Type get_output_type(size_t index) const {
return m_decoder->get_output_type(m_output_names[index]);
}
Output<Node> get_input(int idx) const override {
return m_tensor_map->at(m_decoder->get_input_name(idx));
}
Output<Node> get_input(const std::string& name) const override {
if (m_tensor_map->find(name) == m_tensor_map->end()) {
throw std::runtime_error("'" + name + "' not found in tensor map.");
}
return m_tensor_map->at(name);
}
bool has_input(const std::string& name) const {
return m_tensor_map->find(name) != m_tensor_map->end();
}
const std::string& get_name() const override {
return m_decoder->get_op_name();
}
ov::Any get_attribute_as_any(const std::string& name) const override {
return m_decoder->get_attribute(name);
}
int get_op_case() const {
return m_decoder->get_op_case();
}
bool is_static() const {
return m_decoder->is_static();
}
bool is_first_token() const {
return m_decoder->is_first_token();
}
int get_num_heads() const { return m_decoder->get_num_heads(); }
int get_num_heads_kv() const { return m_decoder->get_num_heads_kv(); }
int get_head_size() const { return m_decoder->get_head_size(); }
int get_context_size() const { return m_decoder->get_context_size(); }
private:
std::shared_ptr<GgmlDecoder> m_decoder;
std::shared_ptr<TensorMap>& m_tensor_map;
TranslateSession* m_translate_session;
std::vector<std::string> m_input_names;
std::vector<std::string> m_output_names;
};
// Callable signature that receives a const ggml NodeContext and returns an ov::OutputVector.
// NOTE(review): presumably the type used for per-operation translator functions — confirm
// against the place where these callables are registered.
using CreatorFunction = std::function<ov::OutputVector(const ov::frontend::ggml::NodeContext&)>;
} // namespace ggml
} // namespace frontend
} // namespace ov