Initial stateful graph support
This commit is contained in:
parent
0d6f253e48
commit
5f30eacdb4
|
|
@ -44,9 +44,11 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
|
|||
ComputeParams & compute_params,
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
|
||||
bool is_static,
|
||||
bool is_stateful,
|
||||
bool is_prefill,
|
||||
int prefill_chunk_size) :
|
||||
m_is_static(is_static),
|
||||
m_is_stateful(is_stateful),
|
||||
m_is_prefill(is_prefill),
|
||||
m_prefill_chunk_size(prefill_chunk_size),
|
||||
m_cgraph(cgraph),
|
||||
|
|
@ -157,19 +159,40 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
|
|||
ggml_backend_buffer * buffer = src->buffer;
|
||||
|
||||
if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY || src->flags & GGML_TENSOR_FLAG_INPUT) {
|
||||
ov::PartialShape stateful_kv_shape;
|
||||
// GGML_BACKEND_BUFFER_USAGE_ANY are kv caches
|
||||
if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) {
|
||||
assert(src_name.find("cache_k") == 0 || src_name.find("cache_v") == 0);
|
||||
if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name); it == m_model_params.kv_names.end()) {
|
||||
m_model_params.kv_names.push_back(src_name);
|
||||
if (is_stateful()) {
|
||||
// TODO: The shape modification for stateful model below is not validated for all supported models yet. More generic solution might be needed
|
||||
// to enable additional cases. Ideally, this could be removed from decoder and done as part of a transformation later.
|
||||
auto stateless_kv_shape = get_graph_input_shape(node, src);
|
||||
assert(stateless_kv_shape.size() == 4 && stateless_kv_shape[0] == 1 && stateless_kv_shape[1] == 1
|
||||
&& stateless_kv_shape[2].is_dynamic() && stateless_kv_shape[3] == (m_model_params.n_heads_kv*m_model_params.head_size));
|
||||
stateful_kv_shape = {stateless_kv_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv, m_model_params.head_size};
|
||||
}
|
||||
}
|
||||
}
|
||||
if (m_model_inputs.find(src_name) != m_model_inputs.end()) {
|
||||
continue;
|
||||
}
|
||||
m_inputs[src_name] = src;
|
||||
auto param_node =
|
||||
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
|
||||
param_node->set_friendly_name(src_name);
|
||||
param_node->output(0).get_tensor().set_names({src_name});
|
||||
m_model_inputs[src_name] = param_node;
|
||||
assert(stateful_kv_shape.rank().is_static());
|
||||
if (stateful_kv_shape.rank().get_length() != 0) {
|
||||
auto param_node =
|
||||
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), stateful_kv_shape);
|
||||
param_node->set_friendly_name(src_name);
|
||||
param_node->output(0).get_tensor().set_names({src_name});
|
||||
m_model_inputs[src_name] = param_node;
|
||||
} else {
|
||||
auto param_node =
|
||||
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
|
||||
param_node->set_friendly_name(src_name);
|
||||
param_node->output(0).get_tensor().set_names({src_name});
|
||||
m_model_inputs[src_name] = param_node;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -378,6 +401,8 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co
|
|||
} else if (name.find("KQ_mask") == 0) {
|
||||
if (m_is_static) {
|
||||
input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx};
|
||||
} else if (m_is_stateful) {
|
||||
input_shape = ov::PartialShape{1, 1, -1, -1};
|
||||
} else {
|
||||
input_shape = ov::PartialShape{-1, 1, -1, -1};
|
||||
}
|
||||
|
|
@ -465,15 +490,15 @@ const ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
|
||||
// std::map<std::string, std::string> kv_param_res_names;
|
||||
// for (const auto & name : m_model_params.kv_names) {
|
||||
// if (name.find("cache_k") == 0 || name.find("cache_v") == 0) {
|
||||
// kv_param_res_names[name] = name;
|
||||
// }
|
||||
// }
|
||||
// return kv_param_res_names;
|
||||
// }
|
||||
std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const {
|
||||
std::map<std::string, std::string> kv_param_res_names;
|
||||
for (const auto & name : m_model_params.kv_names) {
|
||||
if (name.find("cache_k") == 0 || name.find("cache_v") == 0) {
|
||||
kv_param_res_names[name] = name;
|
||||
}
|
||||
}
|
||||
return kv_param_res_names;
|
||||
}
|
||||
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph) {
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ struct ModelParams {
|
|||
int32_t * rope_params = nullptr;
|
||||
std::vector<int> swa_layers;
|
||||
|
||||
// std::vector<std::string> kv_names;
|
||||
std::vector<std::string> kv_names;
|
||||
|
||||
bool operator==(const ModelParams & other) const {
|
||||
return n_seq == other.n_seq && n_heads == other.n_heads && n_heads_kv == other.n_heads_kv &&
|
||||
|
|
@ -66,6 +66,7 @@ public:
|
|||
ComputeParams & compute_params,
|
||||
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
|
||||
bool is_static,
|
||||
bool is_stateful = false,
|
||||
bool is_prefill = false,
|
||||
int prefill_chunk_size = 256);
|
||||
|
||||
|
|
@ -171,10 +172,12 @@ public:
|
|||
|
||||
virtual int32_t * get_rope_params() const override { return m_model_params.rope_params; }
|
||||
|
||||
// virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
|
||||
virtual std::map<std::string, std::string> get_kv_param_res_names() const override;
|
||||
|
||||
virtual bool is_static() const override { return m_is_static; }
|
||||
|
||||
virtual bool is_stateful() const override { return m_is_stateful; }
|
||||
|
||||
ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const;
|
||||
|
||||
static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);
|
||||
|
|
@ -200,6 +203,7 @@ public:
|
|||
void set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; }
|
||||
|
||||
bool m_is_static = false;
|
||||
bool m_is_stateful = false;
|
||||
bool m_is_prefill = false;
|
||||
int m_prefill_chunk_size = 0;
|
||||
|
||||
|
|
|
|||
|
|
@ -59,10 +59,13 @@ public:
|
|||
virtual std::vector<std::string> get_model_output_names() const = 0;
|
||||
|
||||
virtual int32_t* get_rope_params() const = 0;
|
||||
// virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
|
||||
|
||||
virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;
|
||||
|
||||
virtual bool is_static() const = 0;
|
||||
|
||||
virtual bool is_stateful() const = 0;
|
||||
|
||||
virtual int is_swa_layer(int layer) const = 0;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -91,8 +91,11 @@ public:
|
|||
int get_op_case() const {
|
||||
return m_decoder->get_op_case(m_node_idx);
|
||||
}
|
||||
|
||||
bool is_static() const { return m_decoder->is_static(); }
|
||||
|
||||
bool is_stateful() const { return m_decoder->is_stateful(); }
|
||||
|
||||
private:
|
||||
std::shared_ptr<GgmlDecoder> m_decoder;
|
||||
std::shared_ptr<TensorMap>& m_tensor_map;
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ OutputVector translate_get_rows(const NodeContext & context) {
|
|||
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
|
||||
data = std::make_shared<ov::op::v0::Squeeze>(data, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
|
||||
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
|
||||
} else if (context.is_stateful() && data.get_partial_shape().rank() == 3) {
|
||||
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
|
||||
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
|
||||
} else {
|
||||
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
|
||||
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
|
||||
|
|
@ -45,7 +48,9 @@ OutputVector translate_get_rows(const NodeContext & context) {
|
|||
if (res.get_element_type() != context.get_output_type()) {
|
||||
res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type());
|
||||
}
|
||||
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
|
||||
if (!(context.is_stateful())) {
|
||||
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
|
||||
}
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ OutputVector translate_permute(const NodeContext & context) {
|
|||
auto src = context.get_input(0);
|
||||
auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3});
|
||||
|
||||
if (op_case == 1) {
|
||||
if (op_case == 1 || context.is_stateful()) {
|
||||
res = std::make_shared<ov::op::v1::Transpose>(src, perm);
|
||||
} else if (op_case == 4) {
|
||||
auto output_shape = context.get_output_shape().to_shape();
|
||||
|
|
|
|||
|
|
@ -32,10 +32,15 @@ OutputVector translate_reshape(const NodeContext & context) {
|
|||
auto output_shape = context.get_output_shape().to_shape();
|
||||
std::shared_ptr<ov::Node> new_shape_node;
|
||||
if (op_case == 1) {
|
||||
new_shape_node = ov::op::v0::Constant::create(
|
||||
ov::element::i64, {4},
|
||||
std::vector<int64_t>{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
|
||||
|
||||
if (context.is_stateful()) {
|
||||
new_shape_node = ov::op::v0::Constant::create(
|
||||
ov::element::i64, {3},
|
||||
std::vector<int64_t>{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
|
||||
} else {
|
||||
new_shape_node = ov::op::v0::Constant::create(
|
||||
ov::element::i64, {4},
|
||||
std::vector<int64_t>{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
|
||||
}
|
||||
} else if (op_case == 2) {
|
||||
new_shape_node = ov::op::v0::Constant::create(
|
||||
ov::element::i64, {4},
|
||||
|
|
@ -50,8 +55,13 @@ OutputVector translate_reshape(const NodeContext & context) {
|
|||
return {context.get_input(0).get_node_shared_ptr()->input_value(0)};
|
||||
|
||||
} else if (op_case == 5) {
|
||||
std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
|
||||
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec);
|
||||
if (context.is_stateful()) {
|
||||
std::vector<int64_t> shape_vec = {1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
|
||||
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {3}, shape_vec);
|
||||
} else {
|
||||
std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]};
|
||||
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec);
|
||||
}
|
||||
|
||||
// // Alternative
|
||||
// auto token_len = context.get_input("token_len");
|
||||
|
|
|
|||
|
|
@ -54,9 +54,18 @@ OutputVector translate_rope(const NodeContext & context) {
|
|||
// The input comes from a VIEW
|
||||
int slice_len = output_shape[2] * output_shape[3];
|
||||
data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr();
|
||||
auto data_shape = ov::op::v0::Constant::create(
|
||||
ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
|
||||
data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);
|
||||
if (context.is_stateful()) {
|
||||
auto data_shape = ov::op::v0::Constant::create(
|
||||
ov::element::i64, {3}, std::vector<int64_t>{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
|
||||
data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);
|
||||
} else {
|
||||
auto data_shape = ov::op::v0::Constant::create(
|
||||
ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
|
||||
data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);
|
||||
}
|
||||
//auto data_shape = ov::op::v0::Constant::create(
|
||||
// ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
|
||||
//data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);
|
||||
}
|
||||
|
||||
const int mode = op_params[2];
|
||||
|
|
@ -67,10 +76,19 @@ OutputVector translate_rope(const NodeContext & context) {
|
|||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
|
||||
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
|
||||
auto even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
|
||||
auto odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
|
||||
Output<Node> even_slice;
|
||||
Output<Node> odd_slice;
|
||||
int32_t unsqueeze_dim = 4;
|
||||
if (context.is_stateful()) {
|
||||
unsqueeze_dim = 3;
|
||||
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, two);
|
||||
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, two);
|
||||
} else {
|
||||
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
|
||||
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
|
||||
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
|
||||
}
|
||||
|
||||
Output<Node> first_half =
|
||||
std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),
|
||||
|
|
@ -80,10 +98,10 @@ OutputVector translate_rope(const NodeContext & context) {
|
|||
std::make_shared<ov::op::v1::Multiply>(odd_slice, cos_theta_node));
|
||||
|
||||
first_half = std::make_shared<ov::op::v0::Unsqueeze>(first_half,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {1}, {4}));
|
||||
ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim}));
|
||||
second_half = std::make_shared<ov::op::v0::Unsqueeze>(second_half,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {1}, {4}));
|
||||
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 4);
|
||||
ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim}));
|
||||
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, unsqueeze_dim);
|
||||
|
||||
auto data_shape = ov::op::v0::Constant::create(
|
||||
ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
|
||||
|
|
@ -102,7 +120,11 @@ OutputVector translate_rope(const NodeContext & context) {
|
|||
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
|
||||
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
|
||||
|
||||
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, 3);
|
||||
int32_t concat_dim = 3;
|
||||
if (context.is_stateful()) {
|
||||
concat_dim = 2;
|
||||
}
|
||||
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, concat_dim);
|
||||
}
|
||||
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
|
|
|
|||
|
|
@ -45,7 +45,17 @@ OutputVector translate_set_rows(const NodeContext & context) {
|
|||
false);
|
||||
auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2});
|
||||
|
||||
Output<Node> res = std::make_shared<ov::op::v3::ScatterUpdate>(dst, ind_squeezed, data_reshaped, axes);
|
||||
Output<Node> res;
|
||||
if (context.is_stateful()) {
|
||||
int concat_axis = 1;
|
||||
int64_t dim2 = dst.get_partial_shape()[2].get_length();
|
||||
int64_t dim3 = dst.get_partial_shape()[3].get_length();
|
||||
data = std::make_shared<ov::op::v1::Reshape>(
|
||||
data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false);
|
||||
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, concat_axis);
|
||||
} else {
|
||||
res = std::make_shared<ov::op::v3::ScatterUpdate>(dst, ind_squeezed, data_reshaped, axes);
|
||||
}
|
||||
|
||||
if (auto dst_reshape = std::dynamic_pointer_cast<ov::op::v1::Reshape>(dst.get_node_shared_ptr())) {
|
||||
// Fix the case of multiple sequences, reshape back to original shape [1, n_seq, ctx_per_seq, emb]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@
|
|||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/cos.hpp>
|
||||
#include <openvino/op/divide.hpp>
|
||||
#include <openvino/op/gather.hpp>
|
||||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <openvino/op/range.hpp>
|
||||
|
|
@ -82,6 +83,20 @@ void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
|
|||
std::shared_ptr<ov::Node> mask_sliced;
|
||||
if (is_static) {
|
||||
mask_sliced = mask;
|
||||
} else if (ggml_model_decoder.is_stateful()) {
|
||||
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
|
||||
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
|
||||
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {-2,-1});
|
||||
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
|
||||
auto shape_of_inp_pos = std::make_shared<ov::op::v3::ShapeOf>(inp_pos);
|
||||
auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(shape_of_inp_pos, two_1d, zero_1d);
|
||||
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len_per_seq, gather_inp_pos}, 0);
|
||||
mask_sliced =
|
||||
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
|
||||
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
|
||||
mask_sliced->set_friendly_name(sliced_name);
|
||||
} else {
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
|
|
@ -226,11 +241,11 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
|
|||
manager.set_per_pass_validation(true);
|
||||
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
|
||||
|
||||
// if (!ggml_model_decoder->is_static()) {
|
||||
// const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
|
||||
// const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
|
||||
// manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
|
||||
// }
|
||||
if (ggml_model_decoder->is_stateful()) {
|
||||
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
|
||||
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
|
||||
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
|
||||
}
|
||||
|
||||
if (ggml_model_decoder->is_static()) {
|
||||
manager.register_pass<pass::EliminateZeroPoints>();
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#include <openvino/op/multiply.hpp>
|
||||
#include <openvino/op/shape_of.hpp>
|
||||
#include <openvino/op/sin.hpp>
|
||||
#include <openvino/op/squeeze.hpp>
|
||||
#include <openvino/op/subtract.hpp>
|
||||
#include <openvino/op/transpose.hpp>
|
||||
#include <string>
|
||||
|
|
@ -113,11 +114,20 @@ void ggml_rope_yarn_corr_dims(int n_dims,
|
|||
|
||||
std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params,
|
||||
std::shared_ptr<ov::Node> inp_pos,
|
||||
std::shared_ptr<ov::Node> rope_freqs_weight) {
|
||||
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
|
||||
auto pos_perm =
|
||||
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
|
||||
inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);
|
||||
std::shared_ptr<ov::Node> rope_freqs_weight,
|
||||
bool stateful) {
|
||||
if (stateful) {
|
||||
inp_pos = std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
|
||||
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
|
||||
auto pos_perm =
|
||||
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{2, 1, 0});
|
||||
inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);
|
||||
} else {
|
||||
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
|
||||
auto pos_perm =
|
||||
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
|
||||
inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);
|
||||
}
|
||||
|
||||
float freq_base;
|
||||
float freq_scale;
|
||||
|
|
@ -145,8 +155,14 @@ std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params
|
|||
factor[i] = theta_scale * factor[i - 1];
|
||||
}
|
||||
|
||||
Output<Node> freq_factors =
|
||||
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor);
|
||||
Output<Node> freq_factors;
|
||||
if (stateful) {
|
||||
freq_factors =
|
||||
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor);
|
||||
} else {
|
||||
freq_factors =
|
||||
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor);
|
||||
}
|
||||
if (rope_freqs_weight) {
|
||||
freq_factors = std::make_shared<ov::op::v1::Divide>(freq_factors, rope_freqs_weight);
|
||||
}
|
||||
|
|
@ -161,7 +177,12 @@ std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params
|
|||
theta = theta_interp;
|
||||
} else {
|
||||
auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor);
|
||||
auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f});
|
||||
Output<Node> one;
|
||||
if (stateful) {
|
||||
one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f});
|
||||
} else {
|
||||
one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f});
|
||||
}
|
||||
auto one_minus_ramp = std::make_shared<ov::op::v1::Subtract>(one, ramp_mix);
|
||||
|
||||
theta = std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(theta_interp, one_minus_ramp),
|
||||
|
|
|
|||
|
|
@ -66,7 +66,8 @@ OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::
|
|||
|
||||
std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t* rope_params,
|
||||
std::shared_ptr<ov::Node> inp_pos,
|
||||
std::shared_ptr<ov::Node> rope_freqs_weight = nullptr);
|
||||
std::shared_ptr<ov::Node> rope_freqs_weight = nullptr,
|
||||
bool stateful = false);
|
||||
|
||||
ov::Output<ov::Node> process_view_input(const NodeContext& context, int input_index, int slice_len = 0);
|
||||
|
||||
|
|
|
|||
|
|
@ -46,10 +46,14 @@ enum ggml_status ov_graph_compute(ggml_cgraph * cgraph) {
|
|||
// Use device from singleton (initialized during backend init)
|
||||
const auto & device = ggml_openvino_get_device_name();
|
||||
const auto is_static = ggml_openvino_is_npu();
|
||||
return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device);
|
||||
bool stateful = false;
|
||||
if (getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !is_static) {
|
||||
stateful = true;
|
||||
}
|
||||
return is_static ? ov_graph_compute_static(cgraph) : ov_graph_compute_dynamic(cgraph, device, stateful);
|
||||
}
|
||||
|
||||
enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::string & device) {
|
||||
enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::string & device, bool stateful) {
|
||||
auto & core = ov_singleton_core();
|
||||
const auto & config = ggml_openvino_get_compile_config();
|
||||
static auto is_static = false;
|
||||
|
|
@ -99,6 +103,12 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
|
|||
ggml_decoder->add_extra_inputs();
|
||||
infer_request = infer_request_cache[key];
|
||||
|
||||
auto * inp_pos = get_inp_pos_tensor(cgraph);
|
||||
int32_t * pos_data = (int32_t *) inp_pos->data;
|
||||
if (pos_data[0] == 0) {
|
||||
infer_request->reset_state();
|
||||
}
|
||||
|
||||
decoder_end_time = ggml_time_us();
|
||||
conversion_end_time = decoder_end_time;
|
||||
compile_end_time = decoder_end_time;
|
||||
|
|
@ -108,7 +118,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
|
|||
std::shared_ptr<ov::Model> model;
|
||||
auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);
|
||||
|
||||
ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static);
|
||||
ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static, stateful);
|
||||
decoder_end_time = ggml_time_us();
|
||||
|
||||
auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder);
|
||||
|
|
@ -202,6 +212,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
|
|||
|
||||
static std::string device = "NPU";
|
||||
static auto is_static = true;
|
||||
static auto stateful = false;
|
||||
static auto prefill_chunk_size = get_prefill_chunk_size();
|
||||
const auto & config = ggml_openvino_get_compile_config();
|
||||
|
||||
|
|
@ -265,9 +276,9 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
|
|||
auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph);
|
||||
|
||||
auto ggml_decoder_prefill = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights,
|
||||
is_static, true, prefill_chunk_size);
|
||||
is_static, stateful, true, prefill_chunk_size);
|
||||
auto ggml_decoder_decode = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights,
|
||||
is_static, false, prefill_chunk_size);
|
||||
is_static, stateful, false, prefill_chunk_size);
|
||||
decoder_end_time = ggml_time_us();
|
||||
|
||||
auto input_model_prefill = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_prefill);
|
||||
|
|
@ -606,8 +617,17 @@ ov::Tensor get_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, con
|
|||
if (ggml_decoder->is_static() && result_name == "result_output" && output_shape[2] == 0) {
|
||||
output_shape[2] = 1;
|
||||
}
|
||||
ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data);
|
||||
return output_tensor;
|
||||
if (ggml_decoder->is_stateful() && result_name == "result_output") {
|
||||
std::vector<long unsigned int> output_shape_3d;
|
||||
for (size_t i=1; i<output_shape.size(); i++) {
|
||||
output_shape_3d.push_back(output_shape[i]);
|
||||
}
|
||||
ov::Tensor output_tensor(output_type, output_shape_3d, ggml_tensor->data);
|
||||
return output_tensor;
|
||||
} else {
|
||||
ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data);
|
||||
return output_tensor;
|
||||
}
|
||||
}
|
||||
|
||||
size_t checksum(const void * data, size_t size) {
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ struct graph_key_hash {
|
|||
|
||||
enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph);
|
||||
|
||||
enum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, const std::string & device);
|
||||
enum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, const std::string & device, bool stateful = false);
|
||||
enum ggml_status ov_graph_compute_static(struct ggml_cgraph * cgraph);
|
||||
|
||||
size_t checksum(const void * data, size_t size);
|
||||
|
|
|
|||
Loading…
Reference in New Issue