Support op SET_ROWS

This commit is contained in:
Author: Yu, Zijun — 2025-08-13 10:57:22 +08:00 (committed by Mustafa Cavus)
parent 9a91ca6ef9
commit 63d000ba40
8 changed files with 93 additions and 7 deletions

View File

@ -90,7 +90,7 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph) {
// 3. constructing a decoder for the whole graph naively (op test case)
void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
std::string node_name;
if (node->op == GGML_OP_CPY) {
if (node->op == GGML_OP_CPY || node->op == GGML_OP_SET_ROWS) {
// CPY and SET_ROWS update the input tensor in place. For later ov ops that use
// the input tensor of CPY/SET_ROWS, we need to make sure they get the updated
// tensor by putting the src tensor name in the tensor_map in
@ -151,9 +151,11 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
if (node->buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) {
assert(name.find("cache_k") == 0 || name.find("cache_v") == 0);
}
auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), name);
if (it == m_model_output_names.end()) {
if (auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), name);
it == m_model_output_names.end()) {
m_model_output_names.push_back(name);
}
if (auto it = std::find(m_kv_names.begin(), m_kv_names.end(), name); it == m_kv_names.end()) {
m_kv_names.push_back(name);
}
}
@ -166,6 +168,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
m_op_case = 1;
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[0]) {
m_op_case = 2;
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[1]) {
m_op_case = 3;
}
break;
}
@ -270,6 +274,8 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
} else if (name.find("cache_v") == 0) {
input_shape = ov::PartialShape{m_num_heads_kv, m_head_size, m_context_size};
} else if (get_tensor_used_op(src)->op == GGML_OP_SET_ROWS) {
input_shape = ov::PartialShape{1, 1, -1};
} else if (src->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ov::PartialShape{get_shape(src->view_src)};
@ -283,6 +289,8 @@ void GgmlOvDecoder::add_extra_inputs() {
// Extra inputs:
// 1. `past_token_len`, used to create indices for updating kv cache. Usually equal to inp_pos[0], except for
// llama-perplexity.
// Update: SET_ROWS replaces CPY for updating kv cache. The indices creation is not needed anymore. See:
// https://github.com/ggml-org/llama.cpp/pull/14285
// 2. `attention_size`, used in matmul's in the attention block. The shape of those matmul's are 32 aligned,
// see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
// Not used for NPU
@ -305,6 +313,10 @@ void GgmlOvDecoder::add_extra_inputs() {
(int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / m_head_size / m_num_heads_kv);
break;
}
if (node->op == GGML_OP_SET_ROWS && std::string(node->name).find("cache_k") == 0) {
assert(node->src[1]->type == GGML_TYPE_I64);
past_token_len = *(int64_t*) (node->src[1]->data);
}
}
if (past_token_len == -1) {
@ -342,6 +354,18 @@ void GgmlOvDecoder::add_extra_inputs() {
}
}
// Return the first node in the cgraph that consumes `tensor` as one of its
// sources. Throws if no node in the graph reads the tensor.
const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor) const {
    for (int node_idx = 0; node_idx < m_cgraph->n_nodes; ++node_idx) {
        const auto* candidate = m_cgraph->nodes[node_idx];
        const auto* srcs_begin = candidate->src;
        const auto* srcs_end = candidate->src + GGML_MAX_SRC;
        if (std::find(srcs_begin, srcs_end, tensor) != srcs_end) {
            return candidate;
        }
    }
    throw std::runtime_error("Tensor not found in cgraph");
}
std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
std::map<std::string, std::string> kv_param_res_names;
for (const auto& name : m_kv_names) {
@ -618,7 +642,8 @@ const std::string& GgmlOvDecoder::get_op_type() const {
{GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" },
{GGML_OP_SUB, "GGML_OP_SUB" },
{GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE"},
{GGML_OP_VIEW, "GGML_OP_VIEW" }
{GGML_OP_VIEW, "GGML_OP_VIEW" },
{GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" },
};
static const std::map<ggml_unary_op, std::string> unary_ops = {
{GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" },

View File

@ -117,6 +117,9 @@ public:
static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor* tensor);
static std::map<std::string, std::shared_ptr<ov::Node>> create_weight_nodes(struct ggml_cgraph* cgraph);
const ggml_tensor* get_tensor_used_op(const ggml_tensor* tensor) const;
void clear_model_weights() { m_model_weights.clear(); }
private:

View File

@ -331,7 +331,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
static const std::set<ggml_op> supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT,
GGML_OP_VIEW, GGML_OP_CONT, GGML_OP_CPY, GGML_OP_RESHAPE,
GGML_OP_PERMUTE, GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_ROPE,
GGML_OP_RMS_NORM, GGML_OP_SCALE, GGML_OP_SOFT_MAX};
GGML_OP_RMS_NORM, GGML_OP_SCALE, GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS};
static const std::set<ggml_unary_op> supported_unary_ops{
GGML_UNARY_OP_SILU,
};

View File

@ -46,6 +46,8 @@ public:
return m_decoder->get_input_stride(m_input_names[index]);
}
std::string get_output_name() const { return m_output_names[0]; }
PartialShape get_output_shape(size_t index) const {
return m_decoder->get_output_shape(m_output_names[index]);
}

View File

@ -23,7 +23,7 @@ OutputVector translate_reshape(const NodeContext& context) {
}
int op_case = context.get_op_case();
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported RESHAPE case");
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported RESHAPE case");
auto output_shape = context.get_output_shape(0).to_shape();
std::shared_ptr<ov::Node> new_shape_node;
@ -32,11 +32,14 @@ OutputVector translate_reshape(const NodeContext& context) {
ov::op::v0::Constant::create(ov::element::i64,
{3},
std::vector<int64_t>{-1, (int64_t)output_shape[1], (int64_t)output_shape[2]});
} else {
} else if (op_case == 2) {
new_shape_node =
ov::op::v0::Constant::create(ov::element::i64,
{3},
std::vector<int64_t>{(int64_t)output_shape[0], -1, (int64_t)output_shape[2]});
} else {
new_shape_node =
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{(int64_t) output_shape[0], -1, 1});
}
auto res = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), new_shape_node, false);
return rename_outputs_with_suffix({res}, context.get_name());

View File

@ -0,0 +1,51 @@
#include <cstdint>
#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
#include <openvino/frontend/exception.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scatter_update.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/squeeze.hpp>
#include "../node_context.hpp"
#include "../op_table.hpp"
#include "../utils.hpp"
namespace ov {
namespace frontend {
namespace ggml {
namespace op {
// SET_ROWS writes full rows of the first input into the destination tensor at
// the row positions given by the second (index) input. In ggml it replaces CPY
// for updating the KV cache (see llama.cpp PR #14285).
OutputVector translate_set_rows(const NodeContext& context) {
    num_inputs_check(context, 2, 2);

    auto values = context.get_input(0);
    auto row_indices = context.get_input(1);
    // The destination is looked up by the output name: SET_ROWS updates it in place.
    auto destination = context.get_input(context.get_output_name());

    auto out_shape = context.get_output_shape(0).to_shape();
    FRONT_END_OP_CONVERSION_CHECK(out_shape[0] == 1, "Unsupported shape in SET_ROWS");

    auto axis_zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});

    // Collapse the leading unit dimension so ScatterUpdate sees a 2D [rows, cols] view.
    auto flat_shape = ov::op::v0::Constant::create(ov::element::i64,
                                                   {2},
                                                   {(int64_t) out_shape[1], (int64_t) out_shape[2]});
    auto dst_2d = std::make_shared<ov::op::v1::Reshape>(destination, flat_shape, false);

    // Drop the two leading unit axes of the index tensor to get a 1D index vector.
    auto idx_1d = std::make_shared<ov::op::v0::Squeeze>(
        row_indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));

    // Cast the incoming rows to the destination element type, then squeeze off
    // the leading unit axis so the update payload is 2D as well.
    auto values_cast = std::make_shared<ov::op::v0::Convert>(values, context.get_output_type(0));
    auto values_2d = std::make_shared<ov::op::v0::Squeeze>(values_cast, axis_zero);

    auto scattered = std::make_shared<ov::op::v3::ScatterUpdate>(dst_2d, idx_1d, values_2d, axis_zero);

    // Restore the destination's original shape for downstream consumers.
    auto res = std::make_shared<ov::op::v1::Reshape>(scattered,
                                                     std::make_shared<ov::op::v0::ShapeOf>(destination),
                                                     false);
    return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op
} // namespace ggml
} // namespace frontend
} // namespace ov

View File

@ -35,6 +35,7 @@ std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
{"GGML_UNARY_OP_SILU", op::translate_unary_silu },
{"GGML_OP_VIEW", op::translate_view },
{"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu },
{"GGML_OP_SET_ROWS", op::translate_set_rows },
};
}

View File

@ -26,6 +26,7 @@ GGML_OP_CONVERTER(translate_soft_max);
GGML_OP_CONVERTER(translate_transpose);
GGML_OP_CONVERTER(translate_view);
GGML_OP_CONVERTER(translate_glu_swiglu);
GGML_OP_CONVERTER(translate_set_rows);
} // namespace op