Support op SET_ROWS
This commit is contained in:
parent
9a91ca6ef9
commit
63d000ba40
|
|
@ -90,7 +90,7 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph) {
|
|||
// 3. constructing a decoder for the whole graph naively (op test case)
|
||||
void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
|
||||
std::string node_name;
|
||||
if (node->op == GGML_OP_CPY) {
|
||||
if (node->op == GGML_OP_CPY || node->op == GGML_OP_SET_ROWS) {
|
||||
// CPY updates the input tensor in place. For later ov op that uses the
|
||||
// input tensor of CPY, we need to make sure they get the updated tensor
|
||||
// by putting the src tensor name in the tensor_map in
|
||||
|
|
@ -151,9 +151,11 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
|
|||
if (node->buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) {
|
||||
assert(name.find("cache_k") == 0 || name.find("cache_v") == 0);
|
||||
}
|
||||
auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), name);
|
||||
if (it == m_model_output_names.end()) {
|
||||
if (auto it = std::find(m_model_output_names.begin(), m_model_output_names.end(), name);
|
||||
it == m_model_output_names.end()) {
|
||||
m_model_output_names.push_back(name);
|
||||
}
|
||||
if (auto it = std::find(m_kv_names.begin(), m_kv_names.end(), name); it == m_kv_names.end()) {
|
||||
m_kv_names.push_back(name);
|
||||
}
|
||||
}
|
||||
|
|
@ -166,6 +168,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
|
|||
m_op_case = 1;
|
||||
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[0]) {
|
||||
m_op_case = 2;
|
||||
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[1]) {
|
||||
m_op_case = 3;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
@ -270,6 +274,8 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
|
|||
input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
|
||||
} else if (name.find("cache_v") == 0) {
|
||||
input_shape = ov::PartialShape{m_num_heads_kv, m_head_size, m_context_size};
|
||||
} else if (get_tensor_used_op(src)->op == GGML_OP_SET_ROWS) {
|
||||
input_shape = ov::PartialShape{1, 1, -1};
|
||||
} else if (src->op == GGML_OP_VIEW) {
|
||||
// This case is added to make test-backend-ops work
|
||||
input_shape = ov::PartialShape{get_shape(src->view_src)};
|
||||
|
|
@ -283,6 +289,8 @@ void GgmlOvDecoder::add_extra_inputs() {
|
|||
// Extra inputs:
|
||||
// 1. `past_token_len`, used to create indices for updating kv cache. Usually equal to inp_pos[0], except for
|
||||
// llama-perplexity.
|
||||
// Update: SET_ROWS replaces CPY for updating the kv cache, so creating the indices is no longer needed. See:
|
||||
// https://github.com/ggml-org/llama.cpp/pull/14285
|
||||
// 2. `attention_size`, used in the matmuls in the attention block. The shapes of those matmuls are 32-aligned,
|
||||
// see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
|
||||
// Not used for NPU
|
||||
|
|
@ -305,6 +313,10 @@ void GgmlOvDecoder::add_extra_inputs() {
|
|||
(int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / m_head_size / m_num_heads_kv);
|
||||
break;
|
||||
}
|
||||
if (node->op == GGML_OP_SET_ROWS && std::string(node->name).find("cache_k") == 0) {
|
||||
assert(node->src[1]->type == GGML_TYPE_I64);
|
||||
past_token_len = *(int64_t*) (node->src[1]->data);
|
||||
}
|
||||
}
|
||||
|
||||
if (past_token_len == -1) {
|
||||
|
|
@ -342,6 +354,18 @@ void GgmlOvDecoder::add_extra_inputs() {
|
|||
}
|
||||
}
|
||||
|
||||
// Finds the first node of m_cgraph (in node order) that lists `tensor` among its src entries.
//
// @param tensor  Tensor to look for in the src slots of each graph node.
// @return        The first node whose src array contains `tensor`.
// @throws std::runtime_error if no node in the graph references `tensor` as a source.
const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor) const {
    for (int node_idx = 0; node_idx < m_cgraph->n_nodes; ++node_idx) {
        const ggml_tensor* candidate = m_cgraph->nodes[node_idx];

        // Advance through the src slots until one matches or the slots run out.
        int slot = 0;
        while (slot < GGML_MAX_SRC && candidate->src[slot] != tensor) {
            ++slot;
        }
        if (slot < GGML_MAX_SRC) {
            return candidate;
        }
    }
    throw std::runtime_error("Tensor not found in cgraph");
}
|
||||
|
||||
std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
|
||||
std::map<std::string, std::string> kv_param_res_names;
|
||||
for (const auto& name : m_kv_names) {
|
||||
|
|
@ -618,7 +642,8 @@ const std::string& GgmlOvDecoder::get_op_type() const {
|
|||
{GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" },
|
||||
{GGML_OP_SUB, "GGML_OP_SUB" },
|
||||
{GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE"},
|
||||
{GGML_OP_VIEW, "GGML_OP_VIEW" }
|
||||
{GGML_OP_VIEW, "GGML_OP_VIEW" },
|
||||
{GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" },
|
||||
};
|
||||
static const std::map<ggml_unary_op, std::string> unary_ops = {
|
||||
{GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" },
|
||||
|
|
|
|||
|
|
@ -117,6 +117,9 @@ public:
|
|||
|
||||
static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor* tensor);
|
||||
static std::map<std::string, std::shared_ptr<ov::Node>> create_weight_nodes(struct ggml_cgraph* cgraph);
|
||||
|
||||
const ggml_tensor* get_tensor_used_op(const ggml_tensor* tensor) const;
|
||||
|
||||
// Calls clear() on m_model_weights.
void clear_model_weights() {
    m_model_weights.clear();
}
|
||||
|
||||
private:
|
||||
|
|
|
|||
|
|
@ -331,7 +331,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
|
|||
static const std::set<ggml_op> supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT,
|
||||
GGML_OP_VIEW, GGML_OP_CONT, GGML_OP_CPY, GGML_OP_RESHAPE,
|
||||
GGML_OP_PERMUTE, GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_ROPE,
|
||||
GGML_OP_RMS_NORM, GGML_OP_SCALE, GGML_OP_SOFT_MAX};
|
||||
GGML_OP_RMS_NORM, GGML_OP_SCALE, GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS};
|
||||
static const std::set<ggml_unary_op> supported_unary_ops{
|
||||
GGML_UNARY_OP_SILU,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ public:
|
|||
return m_decoder->get_input_stride(m_input_names[index]);
|
||||
}
|
||||
|
||||
// Returns, by value, the entry stored at index 0 of m_output_names.
std::string get_output_name() const {
    const auto& first_name = m_output_names[0];
    return first_name;
}
|
||||
|
||||
// Looks up the name stored at position `index` of m_output_names and returns the shape
// the decoder reports for that name.
PartialShape get_output_shape(size_t index) const {
    const auto& output_name = m_output_names[index];
    return m_decoder->get_output_shape(output_name);
}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ OutputVector translate_reshape(const NodeContext& context) {
|
|||
}
|
||||
|
||||
int op_case = context.get_op_case();
|
||||
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported RESHAPE case");
|
||||
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported RESHAPE case");
|
||||
|
||||
auto output_shape = context.get_output_shape(0).to_shape();
|
||||
std::shared_ptr<ov::Node> new_shape_node;
|
||||
|
|
@ -32,11 +32,14 @@ OutputVector translate_reshape(const NodeContext& context) {
|
|||
ov::op::v0::Constant::create(ov::element::i64,
|
||||
{3},
|
||||
std::vector<int64_t>{-1, (int64_t)output_shape[1], (int64_t)output_shape[2]});
|
||||
} else {
|
||||
} else if (op_case == 2) {
|
||||
new_shape_node =
|
||||
ov::op::v0::Constant::create(ov::element::i64,
|
||||
{3},
|
||||
std::vector<int64_t>{(int64_t)output_shape[0], -1, (int64_t)output_shape[2]});
|
||||
} else {
|
||||
new_shape_node =
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{(int64_t) output_shape[0], -1, 1});
|
||||
}
|
||||
auto res = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), new_shape_node, false);
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <openvino/core/node.hpp>
|
||||
#include <openvino/core/node_output.hpp>
|
||||
#include <openvino/frontend/exception.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/scatter_update.hpp>
|
||||
#include <openvino/op/shape_of.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/squeeze.hpp>
|
||||
|
||||
#include "../node_context.hpp"
|
||||
#include "../op_table.hpp"
|
||||
#include "../utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace ggml {
|
||||
namespace op {
|
||||
|
||||
// Translates GGML_OP_SET_ROWS into OpenVINO ops.
//
// Input 0 carries the data to write and input 1 the indices; the destination tensor is
// fetched as an input under this op's own output name. Only destinations whose leading
// dimension is 1 are accepted. The destination is viewed as a 2D tensor
// (dst_shape[1] x dst_shape[2]), updated with ScatterUpdate along axis 0, and the result
// is reshaped back to the destination's shape (taken from ShapeOf at graph level).
//
// @param context  Node context giving access to the op's inputs, output name/shape/type.
// @return         A single output holding the updated destination, renamed with the op name.
OutputVector translate_set_rows(const NodeContext& context) {
    num_inputs_check(context, 2, 2);

    auto rows = context.get_input(0);
    auto row_ids = context.get_input(1);
    auto target = context.get_input(context.get_output_name());
    auto target_dims = context.get_output_shape(0).to_shape();
    FRONT_END_OP_CONVERSION_CHECK(target_dims[0] == 1, "Unsupported shape in SET_ROWS");

    // Axis 0: used both to squeeze the converted data and as the ScatterUpdate axis.
    auto axis0 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});

    // View the destination as 2D using its two trailing static dimensions.
    auto flat_dims = ov::op::v0::Constant::create(
        ov::element::i64,
        {2},
        {static_cast<int64_t>(target_dims[1]), static_cast<int64_t>(target_dims[2])});
    auto target_2d = std::make_shared<ov::op::v1::Reshape>(target, flat_dims, false);

    // Remove axes 0 and 1 from the indices.
    auto leading_axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1});
    auto row_ids_squeezed = std::make_shared<ov::op::v0::Squeeze>(row_ids, leading_axes);

    // Convert the data to the output element type, then remove its axis 0.
    auto rows_cast = std::make_shared<ov::op::v0::Convert>(rows, context.get_output_type(0));
    auto rows_squeezed = std::make_shared<ov::op::v0::Squeeze>(rows_cast, axis0);

    auto scattered =
        std::make_shared<ov::op::v3::ScatterUpdate>(target_2d, row_ids_squeezed, rows_squeezed, axis0);

    // Restore the destination's original shape.
    auto target_shape = std::make_shared<ov::op::v0::ShapeOf>(target);
    auto res = std::make_shared<ov::op::v1::Reshape>(scattered, target_shape, false);
    return rename_outputs_with_suffix({res}, context.get_name());
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace ggml
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
|
|
@ -35,6 +35,7 @@ std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
|
|||
{"GGML_UNARY_OP_SILU", op::translate_unary_silu },
|
||||
{"GGML_OP_VIEW", op::translate_view },
|
||||
{"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu },
|
||||
{"GGML_OP_SET_ROWS", op::translate_set_rows },
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ GGML_OP_CONVERTER(translate_soft_max);
|
|||
GGML_OP_CONVERTER(translate_transpose);
|
||||
GGML_OP_CONVERTER(translate_view);
|
||||
GGML_OP_CONVERTER(translate_glu_swiglu);
|
||||
GGML_OP_CONVERTER(translate_set_rows);
|
||||
|
||||
} // namespace op
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue