Remove the second, per-node decoder; move its functionality into the model (whole-graph) decoder

This commit is contained in:
XuejunZhai 2025-11-25 17:58:54 -08:00 committed by Mustafa Cavus
parent 4400b5cb4b
commit ae936519d2
5 changed files with 232 additions and 176 deletions

View File

@ -35,36 +35,12 @@
#include <string>
#include <vector>
GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
ggml_cgraph * cgraph,
bool is_static,
int context_size,
int context_size_swa,
int num_heads,
int num_heads_kv,
int head_size,
const std::vector<int> & swa_layers) :
m_is_static(is_static),
m_cgraph(cgraph),
m_node(node),
m_op_name(std::string(node->name)),
m_ctx(context_size),
m_ctx_swa(context_size_swa),
m_n_heads(num_heads),
m_n_heads_kv(num_heads_kv),
m_head_size(head_size),
m_swa_layers(swa_layers) {
set_input_output(node);
}
GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
bool is_static) :
m_is_static(is_static),
m_cgraph(cgraph),
m_op_name(m_node ? std::string(m_node->name) : ""),
m_model_weights(model_weights),
m_is_static(is_static) {
m_model_weights(model_weights) {
if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") {
#ifdef _WIN32
_putenv_s("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS", "");
@ -83,6 +59,11 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
set_input_output(cur_node);
}
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node);
m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node);
}
add_extra_inputs();
}
@ -104,6 +85,7 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::sh
// 3. constructing a decoder for the whole graph naively (op test case)
void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
std::string node_name;
NodeInfo current_node_info;
if (node->op == GGML_OP_SET_ROWS) {
// SET_ROWS updates the tensor in place. For later ov op that uses the
// the view_src of SET_ROWS, we need to make sure they get the updated tensor
@ -117,6 +99,12 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
m_output_names.push_back(node_name);
m_outputs[node_name] = node;
current_node_info.node = node;
current_node_info.node_name = node_name;
current_node_info.node_outputs[node_name] = node;
current_node_info.node_outputs_names.push_back(node_name);
current_node_info.node_op_case = 0;
for (int i = 0; i < GGML_MAX_SRC; i++) {
auto * src = node->src[i];
if (src == nullptr) {
@ -125,7 +113,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
std::string src_name = std::string(src->name);
m_input_names.push_back(src_name);
m_inputs[src_name] = src;
m_op_node_name.emplace_back(src_name, ggml_op_name(node->op));
current_node_info.node_inputs[src_name] = src;
current_node_info.node_inputs_names.push_back(src_name);
// Add model inputs and weights constants, if called for the whole graph
if (naive) {
@ -137,7 +126,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
m_model_inputs[src_name] = param_node;
}
} else if (!m_node && !src->view_src) {
} else if (!src->view_src) {
ggml_backend_buffer * buffer = src->buffer;
if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY || src->flags & GGML_TENSOR_FLAG_INPUT) {
@ -160,7 +149,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
// Add model outputs, if called for the whole graph
if (naive) {
m_model_output_names.push_back(node_name);
} else if (!m_node) {
} else {
// Model outputs are tensors with GGML_TENSOR_FLAG_OUTPUT flag and kv_caches
static std::set<std::string> debug_output_names = {};
// Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph
@ -179,92 +168,92 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
}
}
if (m_node) {
switch (node->op) {
case GGML_OP_RESHAPE: {
auto * src = node->src[0];
if (src->op == GGML_OP_RESHAPE && src->src[0]->ne[0] == node->ne[0] && src->src[0]->ne[1] == node->ne[1]) {
m_op_case = 4;
} else if (node->ne[0] * node->ne[1] == src->ne[0]) {
m_op_case = 1;
} else if (src->ne[0] * src->ne[1] == node->ne[0]) {
m_op_case = 2;
if (src->ne[2] * src->ne[3] == node->ne[1]) {
m_op_case = 5;
}
} else if (src->ne[0] * src->ne[1] == node->ne[1]) {
m_op_case = 3;
} else if (src->ne[1] * src->ne[2] == node->ne[1]) {
m_op_case = 6;
m_node_info_list.push_back(current_node_info);
}
int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) {
int op_case = 0;
switch (node->op) {
case GGML_OP_RESHAPE: {
auto * src = node->src[0];
if (src->op == GGML_OP_RESHAPE && src->src[0]->ne[0] == node->ne[0] && src->src[0]->ne[1] == node->ne[1]) {
op_case = 4;
} else if (node->ne[0] * node->ne[1] == src->ne[0]) {
op_case = 1;
} else if (src->ne[0] * src->ne[1] == node->ne[0]) {
op_case = 2;
if (src->ne[2] * src->ne[3] == node->ne[1]) {
op_case = 5;
}
break;
}
case GGML_OP_CONT: {
if (node->src[0]->op == GGML_OP_PERMUTE) {
m_op_case = 1;
} else if (node->src[0]->op == GGML_OP_TRANSPOSE) {
m_op_case = 2;
} else if (node->src[0]->op == GGML_OP_VIEW) {
// The input comes from a VIEW which is subtensor
m_op_case = 3;
}
break;
}
case GGML_OP_PERMUTE: {
if (node->src[0]->op != GGML_OP_VIEW) {
m_op_case = 1;
} else if (ggml_is_contiguous(node->src[0])) {
std::string src_name(node->view_src->name);
if (src_name.find("cache") == std::string::npos) {
// permute Qcur
m_op_case = 4;
} else {
// Permute kv cache (view)
int layer = extract_layer_from_name(src_name);
if (!is_swa_layer(layer)) {
m_op_case = 2;
} else {
m_op_case = 3;
}
}
}
break;
}
case GGML_OP_MUL_MAT: {
if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) {
m_op_case = 2;
} else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) {
// test-backend-ops case
m_op_case = 3;
}
break;
}
case GGML_OP_GET_ROWS: {
if (node->src[1]->op == GGML_OP_VIEW) {
m_op_case = 2;
}
break;
}
case GGML_OP_ROPE: {
if (node->src[0]->op == GGML_OP_VIEW) {
m_op_case = 2;
}
break;
}
case GGML_OP_VIEW: {
if (node->src[0]->op == GGML_OP_VIEW) {
auto * src = node->src[0];
if (ggml_nelements(node) != ggml_nelements(src)) {
throw std::runtime_error("Unsupported VIEW case");
}
// This view is a reshape, slicing happens at src->op
m_op_case = 2;
}
}
default:
break;
} else if (src->ne[0] * src->ne[1] == node->ne[1]) {
op_case = 3;
} else if (src->ne[1] * src->ne[2] == node->ne[1]) {
op_case = 6;
}
break;
}
case GGML_OP_CONT: {
if (node->src[0]->op == GGML_OP_PERMUTE) {
op_case = 1;
} else if (node->src[0]->op == GGML_OP_TRANSPOSE) {
op_case = 2;
} else if (node->src[0]->op == GGML_OP_VIEW) {
op_case = 3;
}
break;
}
case GGML_OP_PERMUTE: {
if (node->src[0]->op != GGML_OP_VIEW) {
op_case = 1;
} else if (ggml_is_contiguous(node->src[0])) {
std::string src_name(node->view_src->name);
if (src_name.find("cache") == std::string::npos) {
op_case = 4;
} else {
int layer = extract_layer_from_name(src_name);
if (!is_swa_layer(layer)) {
op_case = 2;
} else {
op_case = 3;
}
}
}
break;
}
case GGML_OP_MUL_MAT: {
if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) {
op_case = 2;
} else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) {
op_case = 3;
}
break;
}
case GGML_OP_GET_ROWS: {
if (node->src[1]->op == GGML_OP_VIEW) {
op_case = 2;
}
break;
}
case GGML_OP_ROPE: {
if (node->src[0]->op == GGML_OP_VIEW) {
op_case = 2;
}
break;
}
case GGML_OP_VIEW: {
if (node->src[0]->op == GGML_OP_VIEW) {
auto * src = node->src[0];
if (ggml_nelements(node) != ggml_nelements(src)) {
throw std::runtime_error("Unsupported VIEW case");
}
op_case = 2;
}
break;
}
default:
break;
}
return op_case;
}
int extract_layer_from_name(const std::string & name) {
@ -722,10 +711,18 @@ ov::PartialShape GgmlOvDecoder::get_input_shape(const std::string & name) const
return ov::PartialShape(get_shape(m_inputs.at(name)));
}
// Per-node accessor: shape of input tensor `name` of the node at `node_idx`,
// looked up in m_node_info_list (filled per cgraph node by set_input_output).
// Throws std::out_of_range if `name` is not an input of that node (map::at).
ov::PartialShape GgmlOvDecoder::get_input_shape(int node_idx, const std::string & name) const {
return ov::PartialShape(get_shape(m_node_info_list[node_idx].node_inputs.at(name)));
}
// Graph-level accessor: stride of graph input tensor `name` (from m_inputs).
// Throws std::out_of_range if `name` is unknown (map::at).
std::vector<size_t> GgmlOvDecoder::get_input_stride(const std::string & name) const {
return get_stride(m_inputs.at(name));
}
// Per-node accessor: stride of input tensor `name` of the node at `node_idx`.
std::vector<size_t> GgmlOvDecoder::get_input_stride(int node_idx, const std::string & name) const {
return get_stride(m_node_info_list[node_idx].node_inputs.at(name));
}
// Graph-level accessor: OpenVINO element type of graph input tensor `name`.
ov::element::Type GgmlOvDecoder::get_input_type(const std::string & name) const {
return get_ov_type(m_inputs.at(name));
}
@ -734,15 +731,18 @@ size_t GgmlOvDecoder::get_input_size() const {
return m_input_names.size();
}
std::string & GgmlOvDecoder::get_input_name(size_t index) const {
m_name = m_input_names[index];
return m_name;
// Per-node accessor: number of inputs recorded for the node at `node_idx`.
size_t GgmlOvDecoder::get_input_size(int node_idx) const {
return m_node_info_list[node_idx].node_inputs_names.size();
}
// Graph-level accessor: all recorded input names. Returned by value (a copy).
std::vector<std::string> GgmlOvDecoder::get_input_names() const {
return m_input_names;
}
// Per-node accessor: input names of the node at `node_idx`, in the order they
// were recorded. Returned by value (a copy).
std::vector<std::string> GgmlOvDecoder::get_input_names(int node_idx) const {
return m_node_info_list[node_idx].node_inputs_names;
}
// Graph-level accessor: stride of output tensor `name` (from m_outputs).
std::vector<size_t> GgmlOvDecoder::get_output_stride(const std::string & name) const {
return get_stride(m_outputs.at(name));
}
@ -755,40 +755,58 @@ ov::PartialShape GgmlOvDecoder::get_output_shape(const std::string & name) const
return ov::PartialShape(get_shape(ggml_tensor));
}
ov::element::Type GgmlOvDecoder::get_output_type(const std::string & name) const {
return get_ov_type(m_outputs.at(name));
// Per-node accessor: shape of output tensor `name` of the node at `node_idx`.
// SET_ROWS updates its target tensor in place, so its effective output shape
// is that of the viewed tensor (view_src), not of the SET_ROWS node itself.
ov::PartialShape GgmlOvDecoder::get_output_shape(int node_idx, const std::string & name) const {
auto * ggml_tensor = m_node_info_list[node_idx].node_outputs.at(name);
if (ggml_tensor->op == GGML_OP_SET_ROWS) {
ggml_tensor = ggml_tensor->view_src;
}
return ov::PartialShape(get_shape(ggml_tensor));
}
std::string & GgmlOvDecoder::get_output_name(size_t index) const {
m_name = std::string(m_output_names[index]);
return m_name;
// Graph-level accessor: OpenVINO element type of output tensor `name`.
ov::element::Type GgmlOvDecoder::get_output_type(const std::string & name) const {
return get_ov_type(m_outputs.at(name));
}
// Graph-level accessor: all recorded output names. Returned by value (a copy).
std::vector<std::string> GgmlOvDecoder::get_output_names() const {
return m_output_names;
}
// Per-node accessor: output names of the node at `node_idx`. Returned by value.
std::vector<std::string> GgmlOvDecoder::get_output_names(int node_idx) const {
return m_node_info_list[node_idx].node_outputs_names;
}
const std::string & GgmlOvDecoder::get_op_name() const {
return m_op_name;
static const std::string unknown_name = "UNKNOWN_OP_NAME";
return unknown_name;
}
// Per-node accessor: name of the node at `node_idx` (NodeInfo::node_name,
// taken from the ggml tensor name when the node info was recorded).
const std::string & GgmlOvDecoder::get_op_name(int node_idx) const {
return m_node_info_list[node_idx].node_name;
}
// Graph-level accessor: raw (mutable) pointer into the ggml tensor's op_params
// array for input `name`. Lifetime is tied to the underlying ggml tensor.
int32_t * GgmlOvDecoder::get_input_op_params(const std::string & name) const {
return m_inputs.at(name)->op_params;
}
// Per-node accessor: op_params of input `name` of the node at `node_idx`.
int32_t * GgmlOvDecoder::get_input_op_params(int node_idx, const std::string & name) const {
return m_node_info_list[node_idx].node_inputs.at(name)->op_params;
}
// Graph-level accessor: op_params of output tensor `name`.
int32_t * GgmlOvDecoder::get_output_op_params(const std::string & name) const {
return m_outputs.at(name)->op_params;
}
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
for (const auto & node : m_nodes) {
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_ctx, m_ctx_swa, m_n_heads,
m_n_heads_kv, m_head_size, m_swa_layers);
node_visitor(decoder);
// Per-node accessor: op_params of output `name` of the node at `node_idx`.
int32_t * GgmlOvDecoder::get_output_op_params(int node_idx, const std::string & name) const {
return m_node_info_list[node_idx].node_outputs.at(name)->op_params;
}
// Invokes `node_visitor` once per cgraph node, in cgraph order, passing the
// decoder plus the node index so the visitor can use the (int node_idx)
// accessor overloads.
//
// Fix: the previous body called std::make_shared<GgmlOvDecoder>(*this) inside
// the loop, deep-copying the whole decoder (input/output maps, weight map,
// m_node_info_list) once per node -- O(n^2) work in graph size. One shared
// copy is created up front and reused for every visit.
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const {
    // NOTE(review): assumes visitors only read decoder state (the GgmlDecoder
    // accessors are const) -- confirm no visitor relies on a private copy.
    auto shared_self = std::make_shared<GgmlOvDecoder>(*this);
    for (int node_idx = 0; node_idx < m_cgraph->n_nodes; node_idx++) {
        node_visitor(shared_self, node_idx);
    }
}
const std::string & GgmlOvDecoder::get_op_type() const {
std::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) {
static const std::map<ggml_op, std::string> ops = {
{GGML_OP_NONE, "GGML_OP_NONE" },
{GGML_OP_ACC, "GGML_OP_ACC" },
@ -836,14 +854,23 @@ const std::string & GgmlOvDecoder::get_op_type() const {
{GGML_GLU_OP_REGLU, "GGML_GLU_OP_REGLU" }
};
switch (m_node->op) {
switch (node->op) {
case GGML_OP_UNARY:
return unary_ops.at(ggml_get_unary_op(m_node));
return unary_ops.at(ggml_get_unary_op(node));
case GGML_OP_GLU:
return glu_ops.at(ggml_get_glu_op(m_node));
return glu_ops.at(ggml_get_glu_op(node));
default:
return ops.at(m_node->op);
return ops.at(node->op);
}
static const std::string unknown_op = "UNKNOWN_GGML_OP";
return unknown_op;
}
// Per-node accessor: the op-type string precomputed for this node in the graph
// constructor (via compute_op_type), e.g. "GGML_OP_MUL_MAT".
const std::string & GgmlOvDecoder::get_op_type(int node_idx) const {
return m_node_info_list[node_idx].node_op_type;
}
// Legacy overload retained to satisfy the GgmlDecoder interface; it always
// returns a placeholder. Callers should use get_op_type(int node_idx) instead.
const std::string & GgmlOvDecoder::get_op_type() const {
static const std::string unknown_op = "UNKNOWN_GGML_OP";
return unknown_op;
}

View File

@ -13,22 +13,21 @@
class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
public:
// Per-node metadata captured while walking the cgraph (one entry per node in
// m_node_info_list); lets the single graph decoder answer per-node queries by
// node index instead of needing a second, per-node decoder.
struct NodeInfo {
ggml_tensor * node;
// Inputs/outputs kept both as name->tensor maps (lookup) and as name vectors
// (stable recorded order).
std::map<std::string, ggml_tensor *> node_inputs;
std::vector<std::string> node_inputs_names;
std::map<std::string, ggml_tensor *> node_outputs;
std::vector<std::string> node_outputs_names;
int node_op_case = 0;  // op-specific translation variant, see compute_op_case()
std::string node_op_type;  // e.g. "GGML_OP_MUL_MAT", see compute_op_type()
std::string node_name;
};
// Graph decoder
GgmlOvDecoder(ggml_cgraph * cgraph,
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
bool is_static);
// Node decoder, called in GgmlOvDecoder::visit_subgraph
GgmlOvDecoder(ggml_tensor * node,
ggml_cgraph * cgraph,
bool is_static,
int context_size,
int context_size_swa,
int num_heads,
int num_heads_kv,
int head_size,
const std::vector<int> & swa_layers);
// Naive graph decoder
GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights);
@ -39,12 +38,18 @@ public:
virtual ov::PartialShape get_input_shape(const std::string & name) const override;
virtual ov::PartialShape get_input_shape(int node_idx, const std::string & name) const override;
virtual std::vector<size_t> get_input_stride(const std::string & name) const override;
virtual std::vector<size_t> get_input_stride(int node_idx, const std::string & name) const override;
virtual ov::element::Type get_input_type(const std::string & name) const override;
virtual size_t get_input_size() const override;
virtual size_t get_input_size(int node_idx) const override;
virtual void get_input_node(size_t input_port_idx,
std::string & producer_name,
std::string & producer_output_port_name,
@ -55,35 +60,45 @@ public:
GGML_UNUSED(producer_output_port_index);
}
virtual std::string & get_input_name(size_t index) const override;
virtual std::vector<std::string> get_input_names() const override;
virtual std::vector<std::string> get_input_names(int node_idx) const override;
virtual ov::PartialShape get_output_shape(const std::string & name) const override;
virtual ov::PartialShape get_output_shape(int node_idx, const std::string & name) const override;
virtual std::vector<size_t> get_output_stride(const std::string & name) const override;
virtual ov::element::Type get_output_type(const std::string & name) const override;
virtual int32_t * get_input_op_params(const std::string & name) const override;
virtual int32_t * get_input_op_params(int node_idx, const std::string & name) const override;
virtual int32_t * get_output_op_params(const std::string & name) const override;
virtual std::string & get_output_name(size_t index) const override;
virtual int32_t * get_output_op_params(int node_idx, const std::string & name) const override;
virtual std::vector<std::string> get_output_names() const override;
virtual std::vector<std::string> get_output_names(int node_idx) const override;
virtual const std::string & get_op_type() const override;
virtual const std::string & get_op_type(int node_idx) const override;
virtual const std::string & get_op_name() const override;
virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const override;
virtual const std::string & get_op_name(int node_idx) const override;
virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const override;
ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); }
ggml_tensor * get_output_ggml_tensor(const std::string & name) const { return m_outputs.at(name); }
virtual int get_op_case() const override { return m_op_case; }
virtual int get_op_case(int node_idx) const override { return m_node_info_list[node_idx].node_op_case; }
virtual const std::map<std::string, std::shared_ptr<ov::Node>> & get_model_inputs() const override {
return m_model_inputs;
@ -150,6 +165,8 @@ private:
static std::vector<size_t> get_shape(const ggml_tensor * tensor);
static std::vector<size_t> get_stride(const ggml_tensor * tensor);
static ov::element::Type get_ov_type(const ggml_tensor * tensor);
int compute_op_case(const ggml_tensor * node);
std::string compute_op_type(const ggml_tensor * node);
void set_llm_params();
void validate_cgraph() const;
@ -157,21 +174,18 @@ private:
bool m_is_static = false;
ggml_cgraph * m_cgraph = nullptr;
ggml_tensor * m_node = nullptr;
std::vector<ggml_tensor *> m_nodes;
std::map<std::string, ggml_tensor *> m_inputs;
std::vector<std::string> m_input_names;
std::map<std::string, ggml_tensor *> m_outputs;
std::vector<std::string> m_output_names;
std::string m_op_name;
mutable std::string m_name;
int m_op_case = 0;
std::vector<std::pair<std::string, std::string>> m_op_node_name;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_inputs;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_extra_inputs;
std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
std::vector<std::string> m_model_output_names;
std::vector<NodeInfo> m_node_info_list;
// Fixed for a model
int m_ctx = -1;

View File

@ -16,42 +16,58 @@ public:
virtual PartialShape get_input_shape(const std::string& name) const = 0;
virtual PartialShape get_input_shape(int node_idx, const std::string& name) const = 0;
virtual std::vector<size_t> get_input_stride(const std::string& name) const = 0;
virtual std::vector<size_t> get_input_stride(int node_idx, const std::string& name) const = 0;
virtual element::Type get_input_type(const std::string& name) const = 0;
virtual size_t get_input_size() const = 0;
virtual size_t get_input_size(int node_idx) const = 0;
virtual void get_input_node(size_t input_port_idx,
std::string& producer_name,
std::string& producer_output_port_name,
size_t& producer_output_port_index) const = 0;
virtual std::string& get_input_name(size_t index) const = 0;
virtual std::vector<std::string> get_input_names() const = 0;
virtual std::vector<std::string> get_input_names(int node_idx) const = 0;
virtual PartialShape get_output_shape(const std::string& name) const = 0;
virtual PartialShape get_output_shape(int node_idx, const std::string& name) const = 0;
virtual std::vector<size_t> get_output_stride(const std::string& name) const = 0;
virtual element::Type get_output_type(const std::string& name) const = 0;
virtual int32_t* get_input_op_params(const std::string& name) const = 0;
virtual int32_t* get_input_op_params(int node_idx, const std::string& name) const = 0;
virtual int32_t* get_output_op_params(const std::string& name) const = 0;
virtual std::string& get_output_name(size_t index) const = 0;
virtual int32_t* get_output_op_params(int node_idx, const std::string& name) const = 0;
virtual std::vector<std::string> get_output_names() const = 0;
virtual std::vector<std::string> get_output_names(int node_idx) const = 0;
virtual const std::string& get_op_type() const = 0;
virtual const std::string& get_op_type(int node_idx) const = 0;
virtual const std::string& get_op_name() const = 0;
virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const = 0;
virtual const std::string& get_op_name(int node_idx) const = 0;
virtual int get_op_case() const = 0;
virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const = 0;
virtual int get_op_case(int node_idx) const = 0;
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_inputs() const = 0;
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_extra_inputs() const = 0;

View File

@ -18,13 +18,15 @@ class NodeContext : public frontend::NodeContext {
public:
NodeContext(const std::shared_ptr<GgmlDecoder>& decoder,
std::shared_ptr<TensorMap>& tensor_map,
int node_idx,
TranslateSession* translate_session = nullptr)
: ov::frontend::NodeContext(decoder->get_op_type()),
: ov::frontend::NodeContext(decoder->get_op_type(node_idx)),
m_decoder(decoder),
m_tensor_map(tensor_map),
m_node_idx(node_idx),
m_translate_session(translate_session) {
m_input_names = decoder->get_input_names();
m_output_names = decoder->get_output_names();
m_input_names = decoder->get_input_names(m_node_idx);
m_output_names = decoder->get_output_names(m_node_idx);
}
TranslateSession* get_translate_session() const {
@ -34,7 +36,7 @@ public:
const std::vector<std::string>& get_input_names() const { return m_input_names; }
size_t get_input_size() const override {
return m_decoder->get_input_size();
return m_decoder->get_input_size(m_node_idx);
}
ov::element::Type get_input_type(size_t index) const {
@ -42,29 +44,25 @@ public:
}
PartialShape get_input_shape(size_t index) const {
return m_decoder->get_input_shape(m_input_names[index]);
return m_decoder->get_input_shape(m_node_idx, m_input_names[index]);
}
std::vector<size_t> get_input_stride(size_t index) const {
return m_decoder->get_input_stride(m_input_names[index]);
return m_decoder->get_input_stride(m_node_idx, m_input_names[index]);
}
std::string get_output_name() const { return m_output_names[0]; }
PartialShape get_output_shape(size_t index) const {
return m_decoder->get_output_shape(m_output_names[index]);
}
std::vector<size_t> get_output_stride(size_t index) const {
return m_decoder->get_output_stride(m_output_names[index]);
return m_decoder->get_output_shape(m_node_idx, m_output_names[index]);
}
int32_t* get_input_op_params(size_t index) const {
return m_decoder->get_input_op_params(m_input_names[index]);
return m_decoder->get_input_op_params(m_node_idx, m_input_names[index]);
}
int32_t* get_output_op_params(size_t index) const {
return m_decoder->get_output_op_params(m_output_names[index]);
return m_decoder->get_output_op_params(m_node_idx, m_output_names[index]);
}
ov::element::Type get_output_type(size_t index) const {
@ -72,7 +70,7 @@ public:
}
Output<Node> get_input(int idx) const override {
return m_tensor_map->at(m_decoder->get_input_name(idx));
return m_tensor_map->at(m_input_names[idx]);
}
Output<Node> get_input(const std::string& name) const override {
@ -87,7 +85,7 @@ public:
}
const std::string& get_name() const override {
return m_decoder->get_op_name();
return m_decoder->get_op_name(m_node_idx);
}
ov::Any get_attribute_as_any(const std::string& name) const override {
@ -95,13 +93,14 @@ public:
}
int get_op_case() const {
return m_decoder->get_op_case();
return m_decoder->get_op_case(m_node_idx);
}
bool is_static() const { return m_decoder->is_static(); }
private:
std::shared_ptr<GgmlDecoder> m_decoder;
std::shared_ptr<TensorMap>& m_tensor_map;
int m_node_idx;
TranslateSession* m_translate_session;
std::vector<std::string> m_input_names;
std::vector<std::string> m_output_names;

View File

@ -164,8 +164,8 @@ std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputMo
(*tensor_map)[it.first] = it.second;
}
auto node_visitor = [&](std::shared_ptr<GgmlDecoder> node) {
auto operation_type = node->get_op_type();
auto node_visitor = [&](std::shared_ptr<GgmlDecoder> decoder, int node_idx) {
auto operation_type = decoder->get_op_type(node_idx);
if (operation_type == "GGML_OP_NONE") {
return;
}
@ -174,10 +174,10 @@ std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputMo
auto it = m_translator_map.find(operation_type);
FRONT_END_OP_CONVERSION_CHECK(it != m_translator_map.end(), "Translation for operation type ", operation_type,
" is not implemented.");
NodeContext node_context(node, tensor_map, this);
NodeContext node_context(decoder, tensor_map, node_idx, this);
converted_outputs = it->second(node_context);
const auto & node_output_names = node->get_output_names();
const auto & node_output_names = decoder->get_output_names(node_idx);
FRONT_END_OP_CONVERSION_CHECK(node_output_names.size() == converted_outputs.size(), "Number of ",
operation_type, " outputs greater than number of converted outputs, which are ",
node_output_names.size(), " and ", converted_outputs.size(), " respectively.");