This commit is contained in:
Yu, Zijun 2025-08-14 15:40:36 +08:00 committed by Mustafa Cavus
parent 63d000ba40
commit 7bda5021f9
4 changed files with 65 additions and 6 deletions

View File

@ -193,6 +193,14 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
}
break;
}
case GGML_OP_SET_ROWS: {
if (std::string(node->name).find("cache_k") == 0) {
m_op_case = 1;
} else {
m_op_case = 2;
}
break;
}
case GGML_OP_PERMUTE: {
if (node->src[0]->view_src == nullptr) {
// Permute Qcur
@ -274,8 +282,18 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
} else if (name.find("cache_v") == 0) {
input_shape = ov::PartialShape{m_num_heads_kv, m_head_size, m_context_size};
} else if (get_tensor_used_op(src)->op == GGML_OP_SET_ROWS) {
} else if (const auto* op = get_tensor_used_op(src); op->op == GGML_OP_SET_ROWS) {
input_shape = ov::PartialShape{1, 1, -1};
if (m_is_static) {
if (m_is_first_token) {
// Dummy static shape, since the indices are not used in this case
input_shape = ov::PartialShape{1};
} else if (std::string(op->name).find("cache_k") == 0) {
input_shape = ov::PartialShape{1, 1, 1};
} else {
input_shape = ov::PartialShape{1, 1, m_num_heads_kv * m_head_size};
}
}
} else if (src->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ov::PartialShape{get_shape(src->view_src)};
@ -316,6 +334,7 @@ void GgmlOvDecoder::add_extra_inputs() {
if (node->op == GGML_OP_SET_ROWS && std::string(node->name).find("cache_k") == 0) {
assert(node->src[1]->type == GGML_TYPE_I64);
past_token_len = *(int64_t*) (node->src[1]->data);
break;
}
}
@ -366,6 +385,22 @@ const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor)
throw std::runtime_error("Tensor not found in cgraph");
}
/// @brief Find a source tensor in the cgraph by its name.
/// @param name Name to match against each node's source tensors.
/// @return Pointer to the first source tensor whose name equals @p name,
///         or nullptr if no node's sources carry that name.
/// @note Only node *sources* are scanned, not the nodes themselves.
const ggml_tensor* GgmlOvDecoder::get_tensor_from_name(const std::string& name) const {
    for (int i = 0; i < m_cgraph->n_nodes; i++) {
        const auto* node = m_cgraph->nodes[i];
        for (int j = 0; j < GGML_MAX_SRC; j++) {
            const auto* src = node->src[j];
            if (src == nullptr) {
                // ggml packs src[] contiguously; the first null ends the list.
                break;
            }
            // Compare against the NUL-terminated name in place; avoids
            // constructing a temporary std::string per source tensor.
            if (name == src->name) {
                return src;
            }
        }
    }
    return nullptr;
}
std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
std::map<std::string, std::string> kv_param_res_names;
for (const auto& name : m_kv_names) {

View File

@ -119,6 +119,7 @@ public:
static std::map<std::string, std::shared_ptr<ov::Node>> create_weight_nodes(struct ggml_cgraph* cgraph);
const ggml_tensor* get_tensor_used_op(const ggml_tensor* tensor) const;
const ggml_tensor* get_tensor_from_name(const std::string& name) const;
void clear_model_weights() { m_model_weights.clear(); }

View File

@ -11,6 +11,7 @@
#include <openvino/op/shape_of.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/transpose.hpp>
#include "../node_context.hpp"
#include "../op_table.hpp"
@ -25,21 +26,40 @@ OutputVector translate_set_rows(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto data = context.get_input(0);
auto indices = context.get_input(1);
auto dst = context.get_input(context.get_output_name());
data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type(0));
auto dst_shape = context.get_output_shape(0).to_shape();
FRONT_END_OP_CONVERSION_CHECK(dst_shape[0] == 1, "Unsupported shape in SET_ROWS");
auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
if (context.is_static() && context.is_first_token()) {
Output<Node> res;
if (context.get_op_case() == 2) {
res = std::make_shared<ov::op::v1::Reshape>(
data,
ov::op::v0::Constant::create(
ov::element::i64,
{3},
{context.get_context_size(), context.get_num_heads_kv(), context.get_head_size()}),
false);
res = std::make_shared<ov::op::v1::Transpose>(
res, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 2, 0}));
} else {
res = data;
}
return rename_outputs_with_suffix({res}, context.get_name());
}
auto indices = context.get_input(1);
auto dst = context.get_input(context.get_output_name());
auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
dst,
ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
false);
auto indices_reshaped =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
auto data_converted = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type(0));
auto data_reshaped = std::make_shared<ov::op::v0::Squeeze>(data_converted, zero);
auto data_reshaped = std::make_shared<ov::op::v0::Squeeze>(data, zero);
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
return rename_outputs_with_suffix({res}, context.get_name());

View File

@ -328,6 +328,9 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
std::copy(padded_data.begin(), padded_data.end(), data_ptr);
}
} else if (const auto* op = ggml_decoder->get_tensor_used_op(ggml_decoder->get_tensor_from_name(param_name));
op->op == GGML_OP_SET_ROWS && is_static && is_first_token) {
input_tensor = ov::Tensor(ov::element::i64, ov::Shape{1});
} else {
input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
}